Custom operations are common in PyTorch models. They can be implemented as custom classes and custom functions in C++ and CUDA, and used from both Python and C++ inference programs. In this blog post, I would like to share how to implement PyTorch custom operations in C++ and CUDA, and how to use them in PyTorch models and AOTInductor-compiled inference programs, using a simple identity convolution as the example.
PyTorch Custom Function
PyTorch custom functions can be implemented in C++ and CUDA and registered using the TORCH_LIBRARY_IMPL macro. Both the CPU and CUDA implementations can be provided, and PyTorch will dispatch to the correct implementation based on the device of the input tensors.
```cpp
// ---------------------------------------------------------------------------
// CPU implementation: plain element-wise copy via clone().
// ---------------------------------------------------------------------------
torch::Tensor identity_conv_cpu_impl(const torch::Tensor& input) {
  TORCH_CHECK(!input.is_cuda(), "identity_conv_cpu_impl: input must be a CPU tensor");
  return input.clone();
}

// ---------------------------------------------------------------------------
// Host-side dispatcher.
// ---------------------------------------------------------------------------
torch::Tensor identity_conv_cuda_impl(const torch::Tensor& input) {
  TORCH_CHECK(input.is_cuda(), "identity_conv_cuda_impl: input must be a CUDA tensor");

  // Output has the same shape, dtype, and strides as input.
  auto output = torch::empty_like(input);
  const int64_t numel = input.numel();
  if (numel == 0) return output;

  // Upload shape and strides to the device so the kernel can read them.
  const int ndim = input.dim();
  const auto opts =
      torch::TensorOptions().dtype(torch::kInt64).device(input.device());
  const auto shape_dev = torch::tensor(
      std::vector<int64_t>(input.sizes().begin(), input.sizes().end()), opts);
  const auto strides_dev = torch::tensor(
      std::vector<int64_t>(input.strides().begin(), input.strides().end()), opts);

  // The launch of the strided element-wise copy kernel is elided in this
  // excerpt; it reads ndim, shape_dev, and strides_dev and copies `input`
  // into `output`.
  return output;
}

// CUDA kernel implementation for my_ops::identity_conv_op.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
  m.impl("identity_conv_op", identity_conv_cuda_impl);
}

// CPU fallback.
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
  m.impl("identity_conv_op", identity_conv_cpu_impl);
}
```
PyTorch Custom Class
PyTorch custom functions are stateless and cannot hold any parameters. If we would like a custom operation that holds parameters and exposes a forward() method callable from Python, we can derive a custom class from torch::CustomClassHolder in C++ and register it with the TORCH_LIBRARY macro.
```cpp
// ---------------------------------------------------------------------------
// IdentityConvClass
//
// A custom class registered with torch.classes so that it can be embedded
// in a torch.nn.Module, exported with torch.export, and compiled with
// AOTInductor.
//
// The forward() method delegates to the CUDA identity kernel. The
// `channels_` field is preserved for semantic completeness and is serialised
// via def_pickle so that the class survives export/import round-trips.
// ---------------------------------------------------------------------------
struct IdentityConvClass : torch::CustomClassHolder {
  int64_t channels_;

  explicit IdentityConvClass(int64_t channels) : channels_(channels) {}

  // Delegates to the CUDA identity kernel.
  torch::Tensor forward(const torch::Tensor& input) {
    return identity_conv_cuda_impl(input);
  }

  int64_t get_channels() {
    return channels_;
  }
};

// ---------------------------------------------------------------------------
// Operator / class registration
//
// This file has no pybind11 dependency and is compiled into
// libidentity_conv_ops.so, which can be dlopen'd by a pure C++ binary
// without needing libpython.
// ---------------------------------------------------------------------------
TORCH_LIBRARY(my_ops, m) {
  // Register IdentityConvClass so Python can instantiate it as
  // torch.classes.my_ops.IdentityConvClass(channels).
  m.class_<IdentityConvClass>("IdentityConvClass")
      .def(torch::init<int64_t>())
      .def("forward", &IdentityConvClass::forward)
      .def("get_channels", &IdentityConvClass::get_channels)
      // __obj_flatten__ is called by torch.export's non-strict tracer on
      // the *real* C++ object before it switches to FakeTensor mode.
      // Must return a tuple of (str, value) pair-tuples so that
      // _check_valid_flat_script_obj passes (it checks isinstance(item,
      // tuple) for every element in the flat sequence). We encode
      // `channels_` as a single named entry; there are no tensor leaves.
      .def("__obj_flatten__",
           [](const c10::intrusive_ptr<IdentityConvClass>& self) {
             return std::make_tuple(
                 std::make_tuple(std::string("channels"), self->channels_));
           })
      // def_pickle enables TorchScript serialisation.
      .def_pickle(
          [](const c10::intrusive_ptr<IdentityConvClass>& self) -> int64_t {
            return self->channels_;
          },
          [](int64_t channels) -> c10::intrusive_ptr<IdentityConvClass> {
            return c10::make_intrusive<IdentityConvClass>(channels);
          });

  // Register the schema for identity_conv_op.
  m.def("identity_conv_op(Tensor x) -> Tensor");
}
```
Using Custom Operations and Classes In PyTorch
The C++ custom classes, custom functions, and their registrations are built into a shared library (libidentity_conv_ops.so) that can be loaded into PyTorch with torch.ops.load_library. For torch.compile and torch.export compatibility, we also need to register "fake" (abstract) versions of the custom classes and functions using @register_fake_class and @torch.library.register_fake, so that FakeTensor-based symbolic tracing can proceed without executing the actual C++/CUDA code.
""" custom_ops.py ============= Loads the C++ / CUDA shared library and sets up all custom PyTorch operations used by the IdentityModel: 1. torch.classes.my_ops.IdentityConvClass (registered by the shared library) - A fake/abstract version is registered here so that torch.export can trace through module attributes that hold an instance of this class. 2. my_ops::identity_conv_op (schema + CPU + CUDA registered by the shared library) - register_fake: abstract implementation for torch.export / FakeTensor. """
import os
import torch import torch.library
# --------------------------------------------------------------------------- # 1. Load the C++ / CUDA shared library. # This triggers the TORCH_LIBRARY(my_ops, ...) static initialiser which # registers torch.classes.my_ops.IdentityConvClass into PyTorch's global # operator registry. # # The library path can be overridden via the IDENTITY_CONV_OPS_LIB # environment variable; it defaults to ../ext/libidentity_conv_ops.so # relative to this file. # --------------------------------------------------------------------------- _default_lib = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "ext", "libidentity_conv_ops.so") _lib_path = os.path.abspath( os.environ.get("IDENTITY_CONV_OPS_LIB", _default_lib)) torch.ops.load_library(_lib_path)
# --------------------------------------------------------------------------- # 2. Register a "fake" (abstract) version of IdentityConvClass for # torch.export tracing. # # torch.export uses FakeTensor-based symbolic tracing. When it encounters # a custom-class attribute on a module it looks for: # • __obj_flatten__ - returns (leaves, context) for pytree flattening # • __obj_unflatten__ - reconstructs the object from (leaves, context) # These are provided by the @register_fake_class-decorated Python class. # --------------------------------------------------------------------------- from torch._library.fake_class_registry import register_fake_class
@register_fake_class("my_ops::IdentityConvClass") classFakeIdentityConvClass: """Abstract counterpart of IdentityConvClass used during torch.export."""
# -- pytree protocol required by torch.export ---------------------------- def__obj_flatten__(self): # Must return a tuple of (str, value) pair-tuples, matching the C++ # __obj_flatten__ which returns (("channels", N),). return (("channels", self.channels_), )
@classmethod def__obj_unflatten__(cls, flat): # `flat` is the (possibly tensor-fakified) sequence of (key, value) # pairs produced by maybe_to_fake_obj. Reconstruct from it. return cls(dict(flat)["channels"])
# --------------------------------------------------------------------------- # 3. Register the fake (abstract) implementation of identity_conv_op for # torch.export tracing. # # The schema and both implementations (CUDA and CPU) are already registered # by the C++ extension via TORCH_LIBRARY / TORCH_LIBRARY_IMPL. Python only # needs to provide the abstract shape/dtype computation so that # torch.export's FakeTensor interpreter can trace through the op. # --------------------------------------------------------------------------- @torch.library.register_fake("my_ops::identity_conv_op") def_identity_conv_op_fake(x: torch.Tensor) -> torch.Tensor: """Abstract implementation used by torch.export / FakeTensor tracing.""" return torch.empty_like(x)
# Convenience alias so other modules can do: from custom_ops import identity_conv_op identity_conv_op = torch.ops.my_ops.identity_conv_op
After the shared library has been loaded, PyTorch custom classes are accessible through the torch.classes namespace and PyTorch custom functions through the torch.ops namespace.
""" model.py ======== Defines the four-layer IdentityModel used in the AOTInductor demo. Layer layout ------------ layer1 : IdentityConv - native PyTorch operators layer2 : IdentityConvCustomClass - torch.classes C++/CUDA custom class layer3 : IdentityConvCustomOp - torch.library.custom_op C++/CUDA op layer4 : IdentityConv - native PyTorch operators Every layer is an identity transformation, so model(x) == x for any input x. """
import torch import torch.nn as nn
# Importing custom_ops registers the C++ extension, the fake class, and # the custom op - this must happen before any model is instantiated. from custom_ops import identity_conv_op # noqa: F401
# --------------------------------------------------------------------------- # Layer 1 / 4 - native PyTorch depthwise 1×1 convolution (identity weights) # --------------------------------------------------------------------------- classIdentityConv(nn.Module): """Identity convolution implemented with native PyTorch operators. Uses a depthwise Conv2d with kernel_size=1 and weight=1.0, which is equivalent to a no-op (output == input). This layer is compatible with torch.export and AOTInductor out of the box. """
def__init__(self, channels: int) -> None: super().__init__() self.conv = nn.Conv2d( in_channels=channels, out_channels=channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=channels, bias=False, ) # Set all weights to 1.0 so that the convolution acts as identity. self.conv.weight.data = torch.ones(channels, 1, 1, 1) # Freeze the weights - they are constants, not learnable parameters. self.conv.weight.requires_grad = False
# --------------------------------------------------------------------------- # Layer 2 - custom C++/CUDA class via torch.classes # --------------------------------------------------------------------------- classIdentityConvCustomClass(nn.Module): """Identity convolution backed by a torch.classes C++/CUDA custom class. At runtime the forward call is dispatched to the CUDA kernel registered inside IdentityConvClass (csrc/identity_conv.cpp + .cu). For torch.export compatibility a FakeIdentityConvClass is registered in custom_ops.py via @register_fake_class so that symbolic tracing works. """
# --------------------------------------------------------------------------- # Layer 3 - custom C++/CUDA op via torch.library.custom_op # --------------------------------------------------------------------------- classIdentityConvCustomOp(nn.Module): """Identity convolution backed by a torch.library.custom_op C++/CUDA op. The op (my_ops::identity_conv_op) is defined in custom_ops.py with: • a register_fake implementation for torch.export tracing • a register_kernel("cuda") implementation that calls the CUDA kernel """
# --------------------------------------------------------------------------- # Full model # --------------------------------------------------------------------------- classIdentityModel(nn.Module): """Four-layer identity model for AOTInductor demo."""
defforward(self, x: torch.Tensor) -> torch.Tensor: x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x
defcreate_model(channels: int = 3) -> IdentityModel: """Return an IdentityModel in eval mode on the default CUDA device.""" return IdentityModel(channels=channels).cuda().eval()
PyTorch Model Export and Lowering
A PyTorch model that uses custom classes and custom functions can be exported with torch.export, provided fake (abstract) versions of all the custom classes and functions are registered for torch.export's symbolic tracing.
```python
#!/usr/bin/env python3
"""
export_compile.py
=================

Exports the IdentityModel with torch.export and compiles it with
torch._inductor.aoti_compile_and_package. The resulting package (model.pt2)
is written to the artifacts/ directory and can be loaded by both
run_inference.py (Python) and the C++ inference binary.

Usage (run from the python/ directory):

    python export_compile.py
"""

import os
import sys

# Ensure the python/ directory is on the path so that local modules are found.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch

# Importing custom_ops loads the C++ extension and registers all custom ops.
import custom_ops  # noqa: F401
from model import create_model

# Save the compiled package in the artifacts/ directory at the project root.
PACKAGE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            "..", "artifacts", "model.pt2")
```
From the exported graph we can see that the custom class IdentityConvClass.forward is represented as a call to torch.ops.higher_order.call_torchbind. The custom op identity_conv_op is represented as a call to torch.ops.my_ops.identity_conv_op.
The exported program can be compiled and packaged with torch._inductor.aoti_compile_and_package to produce a model.pt2 package that can be loaded by both Python and C++ inference programs. The custom class and custom op implementations will be loaded from the shared library and correctly dispatched at runtime when the compiled model is executed.
```python
#!/usr/bin/env python3
"""
run_inference.py
================

Loads the AOTInductor-compiled IdentityModel package (model.pt2) and runs
inference to verify correctness. The output of the identity model must equal
the input within a tight floating-point tolerance.

Usage (run from the python/ directory after export_compile.py):

    python run_inference.py [MODEL_PATH [OP_LIB_PATH]]

Arguments:
    MODEL_PATH   Path to the compiled model package (.pt2). Defaults to
                 ../artifacts/model.pt2 relative to this script.
    OP_LIB_PATH  Path to the custom-op shared library (.so). When provided,
                 the library path is forwarded to custom_ops.py via the
                 IDENTITY_CONV_OPS_LIB environment variable so that
                 torch.ops.load_library uses that file instead of the
                 default ../ext/libidentity_conv_ops.so.
"""

import os
import sys

# Ensure the python/ directory is on the path so that local modules are found.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch
import torch._inductor.codecache  # required before aoti_load_package

# Command-line arguments (see the module docstring).
_default_model = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              "..", "artifacts", "model.pt2")
MODEL_PATH = sys.argv[1] if len(sys.argv) > 1 else _default_model
OP_LIB_PATH = sys.argv[2] if len(sys.argv) > 2 else None

# ---------------------------------------------------------------------------
# If an explicit library path was given, pass it to custom_ops.py via an
# environment variable so that torch.ops.load_library uses that file.
# ---------------------------------------------------------------------------
if OP_LIB_PATH is not None:
    os.environ["IDENTITY_CONV_OPS_LIB"] = os.path.abspath(OP_LIB_PATH)

# Importing custom_ops loads the shared library and registers all custom ops
# BEFORE the compiled model is loaded.
import custom_ops  # noqa: F401

# ---------------------------------------------------------------------------
# Configuration - must match the values used in export_compile.py
# ---------------------------------------------------------------------------
CHANNELS = 3
BATCH_SIZE = 1
HEIGHT = 224
WIDTH = 224
```
In a pure C++ inference program, the custom class and custom function shared library can be loaded and registered with a plain dlopen call, without any pybind11 or libpython dependency.
```cpp
/*
 * main.cpp
 * ========
 * C++ inference program for the AOTInductor-compiled IdentityModel.
 *
 * Prerequisites
 * -------------
 * • Build libidentity_conv_ops.so (top-level CMakeLists.txt) first.
 * • Run export_compile.py to produce model.pt2.
 *
 * The custom operator library (libidentity_conv_ops.so) must be loaded before
 * the compiled model is executed so that torch.classes.my_ops and
 * my_ops::identity_conv_op are present in the operator registry.
 *
 * libidentity_conv_ops.so has no pybind11 dependency and does not link
 * libtorch_python.so, so no libpython pre-loading is required.
 *
 * Usage
 * -----
 *   ./run_inference <path/to/model.pt2> <path/to/libidentity_conv_ops.so>
 *
 * Verification
 * ------------
 * The model is an identity transform, so output should equal the random
 * input within floating-point rounding tolerance (< 1e-5).
 */

#include <dlfcn.h>

#include <iostream>
#include <string>
#include <vector>

#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/torch.h>

// Must match the values used in export_compile.py.
constexpr int64_t kBatchSize = 1;
constexpr int64_t kChannels = 3;
constexpr int64_t kHeight = 224;
constexpr int64_t kWidth = 224;

int main(int argc, char* argv[]) {
  if (argc != 3) {
    std::cerr << "Usage: " << argv[0]
              << " <path/to/model.pt2> <path/to/libidentity_conv_ops.so>\n";
    return 1;
  }
  const std::string model_path = argv[1];
  const std::string custom_op_lib = argv[2];

  std::cout << "================================================\n"
            << "AOTInductor - C++ Inference\n"
            << "================================================\n";

  try {
    // ------------------------------------------------------------------
    // Step 1: Load the custom operator shared library.
    //
    // libidentity_conv_ops.so contains only TORCH_LIBRARY registrations
    // and the CPU/CUDA kernels. It has no pybind11 dependency and does
    // not link libtorch_python.so, so no libpython pre-loading is needed.
    // ------------------------------------------------------------------
    std::cout << "\n[1/4] Loading custom op library:\n    " << custom_op_lib
              << std::endl;
    void* handle = dlopen(custom_op_lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
    TORCH_CHECK(handle != nullptr, "dlopen failed: ", dlerror());

    // ------------------------------------------------------------------
    // Step 2: Load the compiled model package.
    //
    // AOTIModelPackageLoader unpacks the .pt2 archive and prepares the
    // AOTIModelContainerRunner for the target device.
    // ------------------------------------------------------------------
    std::cout << "\n[2/4] Loading model package:\n    " << model_path
              << std::endl;
    torch::inductor::AOTIModelPackageLoader loader(model_path);
    auto runner = loader.get_runner();
    std::cout << "    Model loaded." << std::endl;

    // ------------------------------------------------------------------
    // Step 3: Prepare a random input tensor on CUDA.
    // ------------------------------------------------------------------
    std::cout << "\n[3/4] Preparing random CUDA input tensor." << std::endl;
    auto options = torch::TensorOptions()
                       .dtype(torch::kFloat32)
                       .device(torch::kCUDA, 0);
    auto input =
        torch::randn({kBatchSize, kChannels, kHeight, kWidth}, options);

    // ------------------------------------------------------------------
    // Step 4: Run inference and verify the identity property.
    // ------------------------------------------------------------------
    std::cout << "\n[4/4] Running inference and verifying output."
              << std::endl;
    std::vector<torch::Tensor> inputs{input};
    auto outputs = runner->run(inputs);
    TORCH_CHECK(
        torch::allclose(outputs[0], input, /*rtol=*/1e-5, /*atol=*/1e-5),
        "output does not match input");
    std::cout << "    C++ inference OK." << std::endl;
  } catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
  }
  return 0;
}
```