While most Pytorch operators are implemented in C++/CUDA, I wanted to see how difficult it would be to write one in Rust. There are already great Rust bindings for Pytorch, and they even come with an example Python extension that can operate on Pytorch tensors.
However, the approach above doesn't really extend Pytorch: it just creates Python bindings for a Rust function that happens to take Pytorch tensors as arguments. Of course, there is still a lot of magic happening under the hood: the Pytorch (Python) types are implicitly converted to Rust types, and the tch-rs library exposes many Pytorch APIs so that those tensors can actually be used from Rust code.
The main drawback of the "binding" approach is that it doesn't play nicely with Pytorch's compiler (docs). This is easy to see by trying to torch.compile code that uses that function with fullgraph=True. For example:
import torch
import tch_ext

def graph(x):
    x = x + 1
    x = tch_ext.add_one(x)
    x = x + 1
    return x

opt_graph = torch.compile(graph, fullgraph=True)
t = torch.randn(2, 2)
opt_graph(t)
will crash because of a graph break:
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(add_one) __call__ [TensorVariable()] {}
from user code:
File "/home/niko/code/tch-ext/test.py", line 10, in graph
x = tch_ext.add_one(x)
This article describes one way to make an actual custom operator that properly supports backpropagation and compilation, as described in the official docs.
The official docs recommend using the TORCH_LIBRARY C++ macro to declare the operator. The implementation of the registration system lives here.
At this point, we have two choices about how to proceed:
1. Extend tch-rs to bind TORCH_LIBRARY (and all the code needed to support it) to Rust, so that we can write Rust code to register new operators.
2. Side-step the issue and register the operator in C++ code (using the existing macros), with a small piece of C++ plumbing that calls back into Rust for the actual logic.
The first option would be cleaner (no extra C++ code) but also more work to maintain, as many Pytorch implementation details would need to be exposed to Rust. So I decided to go with the second option.
Luckily, neither of these two options requires binding any Python code to/from Rust, as Pytorch already does this internally. However, we still need a (native) Python extension to register the Pytorch operator so it can be used in a script. We achieve this with the pyo3 crate. Note that we don't need to bind any method from this module: all the work is done by the TORCH_LIBRARY macro described above.
As described above, we write a bit of C++ glue code to register the operator. For convenience, we also allocate the result tensor here, so we don't need to worry about lifetimes on the Rust side:
#include <torch/torch.h>
#include <torch/library.h>

torch::Tensor fill_natural_impl(torch::IntArrayRef size) {
  auto tensor = torch::empty(size, torch::kInt64);
  fill_with_natural_numbers(&tensor);
  return tensor;
}

TORCH_LIBRARY(torchrust, m) { m.def("fill_natural", &fill_natural_impl); }
We also add the prototype for an extern "C" function that will contain the Rust logic:
extern "C" {
  void fill_with_natural_numbers(torch::Tensor* tensor);
}
Finally, on the Rust side, we can implement fill_with_natural_numbers by taking the C++ tensor pointer and constructing a Rust tensor from it:
use tch::Tensor;
use torch_sys::C_tensor;

#[no_mangle]
pub unsafe extern "C" fn fill_with_natural_numbers(c_tensor: *mut C_tensor) {
    unsafe {
        let tensor = Tensor::from_ptr(c_tensor);
        ...
        std::mem::forget(tensor);
    }
}
Note the #[no_mangle] attribute: without it, the symbol would be mangled and the C++ linker could not find it.
The Tensor type in tch-rs is designed to take ownership of the C++ tensor pointer, which is inconvenient in this simple case. So we use std::mem::forget to avoid invoking Tensor's Drop implementation, and let the Python side free the memory when done.
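The elided body of the function can be sketched without any tch-rs specifics. In the real operator the buffer would be built from the tensor's data pointer and element count; the hypothetical helper below just shows the core logic over a plain i64 slice:

```rust
// Hypothetical helper: fill a contiguous i64 buffer with 0, 1, 2, ...
// In the real operator, `buf` would be constructed from the tensor's
// data pointer and numel (the tensor is allocated contiguously as
// torch::kInt64 by the C++ glue above).
fn fill_natural_buffer(buf: &mut [i64]) {
    for (i, slot) in buf.iter_mut().enumerate() {
        *slot = i as i64;
    }
}

fn main() {
    let mut buf = vec![0i64; 6];
    fill_natural_buffer(&mut buf);
    assert_eq!(buf, vec![0, 1, 2, 3, 4, 5]);
}
```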
The build setup is not as straightforward as the code, because we need to integrate Rust, C++, and Python.
First of all, we use the cc crate to manage the C++ build within Cargo and configure it in build.rs. Besides adding the C++ include paths, we also have to add the -force_load (macOS) / --whole-archive (Linux) linker flag to ensure that the static initializers created by the TORCH_LIBRARY macro are invoked when the extension is imported.
In .cargo/config.toml, we add -L and -l flags to link against libtorch, and adjust the rpath accordingly.
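For illustration, the .cargo/config.toml might look like this, assuming libtorch is unpacked at /path/to/libtorch (the path and the exact set of -l libraries are placeholders; which libraries are needed depends on the libtorch build):

```toml
[build]
rustflags = [
  "-L", "native=/path/to/libtorch/lib",
  "-l", "torch",
  "-l", "torch_cpu",
  "-l", "c10",
  # Make sure the dynamic linker can find libtorch at runtime.
  "-C", "link-arg=-Wl,-rpath,/path/to/libtorch/lib",
]
```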
Finally, we can use the new operator from Python after loading the extension module with torch.ops.load_library:
torch.ops.load_library('lib.so')
torch.ops.torchrust.fill_natural([1, 2, 3])
The full code is on GitHub.