Nicolò Valigi

Implementing custom PyTorch operators in Rust

While most PyTorch operators are implemented in C++/CUDA, I wanted to see how difficult it would be to write one in Rust. There are already great Rust bindings for PyTorch (tch-rs), and they even come with an example Python extension that operates on PyTorch tensors.

However, the approach above doesn't really extend PyTorch: it just creates Python bindings for a Rust function that happens to take PyTorch tensors as arguments. Of course, there is still a lot of magic happening under the hood: the PyTorch (Python) types are implicitly converted to Rust types, and the tch-rs library exposes many PyTorch APIs for actually working with those tensors in Rust code.

The main drawback of this "binding" approach is that it doesn't play nicely with PyTorch's compiler (docs). This is easy to see by compiling a function that calls the extension with torch.compile and fullgraph=True, which turns any graph break into an error. For example:

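import torch
import tch_ext  # Python extension built from the tch-rs example (exposes add_one)

def graph(x):
    x = x + 1
    x = tch_ext.add_one(x)  # opaque to the compiler: causes a graph break
    x = x + 1
    return x

opt_graph = torch.compile(graph, fullgraph=True)

t = torch.randn(2, 2)
opt_graph(t)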

fails with the following error because of the graph break:

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(add_one) __call__ [TensorVariable()] {}

from user code:
   File "/home/niko/code/tch-ext/test.py", line 10, in graph
    x = tch_ext.add_one(x)

This article describes one way to build a real custom operator that properly supports backpropagation and compilation, following the approach in the official docs.

Implementation choices

The official docs recommend using the TORCH_LIBRARY C++ macro to declare the operator. The implementation of the registration system lives here.

At this point, we have two choices about how to proceed:

  1. Extend tch-rs to bind TORCH_LIBRARY (and all the code needed to support it) to Rust, so that we can write Rust code to register new operators.

  2. Sidestep the issue: register the operator in C++ (using the existing macros), and use a small piece of C++ plumbing to call back into Rust for the actual logic.

The first option would be cleaner (no extra C++ code) but also more work to maintain, since many PyTorch implementation details would need to be exposed to Rust. So I decided to go with the second option.

Luckily, neither option requires binding any Python code to or from Rust, as PyTorch already does this internally. However, we still need a native Python extension to register the PyTorch operator so it can be used from a script. We build it with the PyO3 crate. Note that this module doesn't need to expose any functions: all the work is done by the TORCH_LIBRARY macro described above.

The C++ glue

As described above, we write a bit of C++ glue code to register the operator. For convenience, we also allocate the result tensor here, so we don't need to worry about lifetimes on the Rust side:

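#include <torch/library.h>
#include <torch/torch.h>

// Allocate the output in C++ so Rust never has to manage its lifetime.
// (The extern "C" prototype below must precede this in the .cpp file.)
torch::Tensor fill_natural_impl(torch::IntArrayRef size) {
  auto tensor = torch::empty(size, torch::kInt64);
  fill_with_natural_numbers(&tensor);  // implemented in Rust
  return tensor;
}

// Registers the operator as torch.ops.torchrust.fill_natural.
TORCH_LIBRARY(torchrust, m) { m.def("fill_natural", &fill_natural_impl); }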
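torch::empty(size, torch::kInt64) returns a contiguous tensor, so whatever the Rust side writes goes into a single row-major buffer of int64 values. As a plain-Rust sketch (no tch; the helper names numel and contiguous_strides are hypothetical), the element count and row-major strides for a shape like the [1, 2, 3] used at the end work out as follows, assuming every dimension is at least 1:

```rust
/// Number of elements a tensor of this shape holds (product of the dims).
/// An empty shape (a 0-dim scalar) holds exactly one element.
fn numel(size: &[i64]) -> i64 {
    size.iter().product()
}

/// Row-major (contiguous) strides, in elements, for the given shape:
/// the last dimension has stride 1, and each earlier stride is the
/// product of all later dimension sizes.
/// Assumes every dimension is >= 1.
fn contiguous_strides(size: &[i64]) -> Vec<i64> {
    let mut strides = vec![1i64; size.len()];
    // Walk from the second-to-last dimension back to the first.
    for i in (0..size.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * size[i + 1];
    }
    strides
}
```

For [1, 2, 3] that gives 6 elements with strides [6, 3, 1]: one flat buffer that fill_with_natural_numbers has to fill.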

We also add the prototype for an extern "C" function that will contain the Rust logic:

extern "C" {
void fill_with_natural_numbers(torch::Tensor* tensor);
}
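The prototype above is the entire contract between the two languages: C++ calls a plain C-ABI symbol, and Rust must export a function with that exact, unmangled name. The sketch below shows the same pattern in std-only Rust. It uses a hypothetical fill_naturals_raw that takes a raw int64 pointer and an element count instead of a torch::Tensor*, and it assumes "natural numbers" means 0, 1, 2, … in memory order. The article's actual fill logic is not shown.

```rust
/// Hypothetical C-ABI entry point: fills `len` int64 values starting at
/// `data` with 0, 1, 2, ... (the starting value is an assumption).
///
/// `#[no_mangle]` keeps the symbol name exactly `fill_naturals_raw`, so a
/// C++ `extern "C"` prototype can link against it.
/// (On Rust edition 2024 the attribute is spelled `#[unsafe(no_mangle)]`.)
///
/// # Safety
/// When `len > 0`, `data` must be non-null, aligned, and valid for `len`
/// consecutive `i64` writes.
#[no_mangle]
pub unsafe extern "C" fn fill_naturals_raw(data: *mut i64, len: usize) {
    // Empty tensors may hand over a null data pointer; nothing to fill.
    if len == 0 {
        return;
    }
    // SAFETY: the caller guarantees `data` is valid for `len` writes.
    let values = unsafe { std::slice::from_raw_parts_mut(data, len) };
    for (i, v) in values.iter_mut().enumerate() {
        *v = i as i64;
    }
}
```

From the C++ glue, such a variant would be declared as extern "C" void fill_naturals_raw(int64_t* data, size_t len); and called as fill_naturals_raw(tensor.data_ptr<int64_t>(), tensor.numel()). That call relies on the freshly allocated tensor being contiguous. Passing raw data keeps tch out of the picture entirely, at the cost of losing tch-rs's tensor APIs on the Rust side.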

Finally, on the Rust side, we implement fill_with_natural_numbers by taking the C++ tensor pointer and wrapping it in a tch-rs Tensor:

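use tch::Tensor;
use torch_sys::C_tensor;

#[no_mangle] // keep the exact C symbol name (edition 2024: #[unsafe(no_mangle)])
pub unsafe extern "C" fn fill_with_natural_numbers(c_tensor: *mut C_tensor) {
    unsafe {
        // Wrap the C++ pointer; tch's Tensor assumes it owns it.
        let tensor = Tensor::from_ptr(c_tensor);
        ...
        // Still owned by the C++ caller: skip Drop so Rust doesn't free it.
        std::mem::forget(tensor);
    }
}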

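The Tensor type in tch-rs takes ownership of the C++ tensor pointer it wraps and frees it when dropped. That is wrong here, because the pointer refers to a tensor that still belongs to the C++ caller. So we call std::mem::forget to skip Tensor's Drop implementation, and leave freeing the memory to PyTorch once the returned tensor is no longer used from Python.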
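The ownership problem is easier to see with a stand-in type. The sketch below uses a hypothetical OwningHandle that, like tch-rs's Tensor, assumes it owns what it wraps and releases it in Drop. Here it only bumps a counter, so the example is safe to run. It contrasts the std::mem::forget approach with std::mem::ManuallyDrop, a standard-library alternative that the article's code does not use:

```rust
use std::cell::Cell;
use std::mem::ManuallyDrop;

/// Hypothetical stand-in for tch's `Tensor`: a handle that assumes it owns
/// its target and "frees" it on drop. Freeing only increments a counter.
struct OwningHandle<'a> {
    frees: &'a Cell<u32>,
}

impl Drop for OwningHandle<'_> {
    fn drop(&mut self) {
        self.frees.set(self.frees.get() + 1);
    }
}

/// Borrowing via `std::mem::forget`, as in the article: the destructor
/// never runs, as long as control actually reaches the `forget` call.
fn borrow_with_forget(frees: &Cell<u32>) {
    let handle = OwningHandle { frees };
    // ... use `handle` ...
    std::mem::forget(handle);
}

/// Same effect with `ManuallyDrop`: the destructor is disabled up front,
/// so an early return cannot free what belongs to the caller.
fn borrow_with_manually_drop(frees: &Cell<u32>) {
    let handle = ManuallyDrop::new(OwningHandle { frees });
    // `ManuallyDrop` derefs to the wrapped value, so fields and methods
    // are used exactly as before.
    let _frees_so_far = handle.frees.get();
}

/// For contrast: letting the handle go out of scope runs its destructor.
fn borrow_and_drop(frees: &Cell<u32>) {
    let _handle = OwningHandle { frees };
}
```

Applied to fill_with_natural_numbers, the ManuallyDrop variant would be let tensor = ManuallyDrop::new(Tensor::from_ptr(c_tensor)); with no forget at the end. Because ManuallyDrop<Tensor> dereferences to Tensor, the rest of the body stays the same.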

Build setup

The build setup is not as straightforward as the code, because we need to integrate Rust, C++, and Python.

First of all, we use the cc crate to drive the C++ build from Cargo, configured in build.rs. Besides adding the C++ include paths, we also have to pass the -force_load (macOS) / --whole-archive (Linux) linker flag. Nothing in the extension directly references the object file that contains the TORCH_LIBRARY registration, so without this flag the linker may drop it, and the static initializers created by the macro would never run when the extension is loaded.
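A build.rs has to pick the right flag for each platform. Here is a sketch of that choice as a pure function, using the standard spellings for Apple's ld and for GNU ld/lld. The name whole_archive_link_args and the archive path in the usage are hypothetical; this is not the article's build script.

```rust
/// Hypothetical build.rs helper: linker arguments that force every object
/// file in the static archive at `archive` into the final extension, so
/// static initializers (such as TORCH_LIBRARY's registration) are kept.
///
/// Returns `None` for targets this sketch does not handle.
fn whole_archive_link_args(target_os: &str, archive: &str) -> Option<Vec<String>> {
    match target_os {
        // Apple's ld: -force_load takes the archive path directly.
        "macos" => Some(vec![format!("-Wl,-force_load,{archive}")]),
        // GNU ld / lld: bracket the archive with whole-archive toggles.
        "linux" => Some(vec![
            "-Wl,--whole-archive".to_string(),
            archive.to_string(),
            "-Wl,--no-whole-archive".to_string(),
        ]),
        _ => None,
    }
}
```

In build.rs, target_os would come from the CARGO_CFG_TARGET_OS environment variable that Cargo sets for build scripts, and each returned string would be printed as a cargo:rustc-link-arg=… line. Recent Rust toolchains also support a +whole-archive link modifier (cargo:rustc-link-lib=static:+whole-archive=NAME) that may make the per-platform branching unnecessary.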

In .cargo/config.toml, we add -L and -l flags to link against libtorch, and set the rpath accordingly so the libtorch shared libraries are found at load time.
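The flags themselves are ordinary rustc options: -L for the search path, -l for each library, and -C link-arg=-Wl,-rpath,… to embed the runtime search path. A hypothetical helper that assembles them for a libtorch install might look like this. Which libraries to list (for example torch, torch_cpu, c10) depends on the libtorch build and is an assumption here.

```rust
use std::path::Path;

/// Hypothetical helper: the rustc flags a `.cargo/config.toml` `rustflags`
/// entry would carry to link against the libtorch install at `libtorch_dir`.
fn libtorch_rustflags(libtorch_dir: &Path, libs: &[&str]) -> Vec<String> {
    let lib_path = libtorch_dir.join("lib");
    let lib_dir = lib_path.display();

    // Where the linker looks for libtorch at build time.
    let mut flags = vec!["-L".to_string(), lib_dir.to_string()];
    // One -l per library to link.
    for lib in libs {
        flags.push("-l".to_string());
        flags.push(lib.to_string());
    }
    // Embed the same directory as an rpath so the dynamic loader finds
    // libtorch when Python loads the extension.
    flags.push("-C".to_string());
    flags.push(format!("link-arg=-Wl,-rpath,{lib_dir}"));
    flags
}
```

Each string would become one element of the rustflags array under [build] (or a [target.<triple>] table) in .cargo/config.toml.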

Victory at last

Finally, we can use the new operator in Python after loading the compiled extension with torch.ops.load_library:

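import torch

torch.ops.load_library('lib.so')  # path to the compiled extension

# Returns an int64 tensor of shape (1, 2, 3), filled by the Rust code.
torch.ops.torchrust.fill_natural([1, 2, 3])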

The full code is on GitHub.