PyTorch Export
Introduction
PyTorch is easy to use thanks to its flexibility and dynamic nature. That same dynamism, however, often makes deploying PyTorch models for inference suboptimal. PyTorch 2 introduces a new feature, torch.export, that allows users to export their PyTorch models into a static graph representation, which can then be optimized and compiled for inference.
In this blog post, I would like to discuss some common practices for torch.export and its downstream inference using Python and C++.
Custom Operations
It is very common to have custom operations in PyTorch models. There are two ways to define custom operations in PyTorch:
- Define and register the custom operation as a stateless torch.ops function. This is the recommended way to define custom operations for export, as it allows the operation to be represented in the graph and exported to C++.
- Define the custom operation as a stateful torch.classes class and register a fake Python class for it. This approach is heavier to implement, since it requires implementing the custom operation in both C++ and Python.
Custom C++ Class and Fake Python Class
This example demonstrates how to define a custom C++ class and a fake Python class for export. The fake Python class is necessary for torch.export to trace the custom operation and include it in the exported graph. However, its implementation does not have to match the C++ class exactly, as long as its output shapes are correct and it does not introduce graph breaks. For example, if we know that the custom operation's output shape is always the same as its input shape, the fake Python class can simply return the input tensor without implementing the actual logic of the custom operation, no matter how complicated that logic is in the C++ class.
In addition, we have to make sure strict=False is used for torch.export.export when exporting models that use custom operations. This is because with strict=True, torch.export.export can only export models that can be strictly represented using a specific set of PyTorch 2 IR operators, which does not include custom operations.
No Data-Dependent Shapes Or Control Flows
PyTorch export does not allow data-dependent shapes or control flow in the exported graph. In this sense, it converts a PyTorch model into something closer to a TensorFlow 1, JAX, or ONNX model, none of which allow data-dependent shapes or control flow either. This requirement also applies to custom operations.
Because a C++ custom operation implementation can hide intermediate data-dependent shapes or control flow, small deviations from this requirement are possible inside a custom operation, as long as the apparent output shape of the custom operation is not data-dependent. However, this situation is rare. The best practice is to avoid data-dependent shapes and control flow in every part of the model, including custom operations, so that the model can be represented as a static graph.
Running Exported Models
The exported model can be saved and loaded using torch.export.save and torch.export.load in Python, respectively. Unlike models that are typically saved and loaded via a state dict using torch.save and torch.load, the loaded exported model can be run in Python even without the model definition code. However, the exported model cannot be loaded and run in C++ directly.
To run the exported model in C++, it has to be optimized and compiled using AOTInductor for server platforms such as x86 CPUs and CUDA GPUs, or ExecuTorch for edge devices.
Conclusions
We could make a few analogies to better understand PyTorch export.
- PyTorch exported model is analogous to an ONNX model.
- PyTorch IR is analogous to ONNX Opset.
- PyTorch custom operation is analogous to ONNX custom operator: it does not introduce a graph break, but it does prevent graph optimization and fusion across the custom operation boundary.
- There is nothing analogous to ONNX Runtime in PyTorch yet.
- AOTInductor and ExecuTorch are analogous to TensorRT, which can optimize and compile the exported model for inference.
Custom operations can hardly save a model that has graph breaks. Therefore, the first step toward creating a good model for inference is to ensure the model contains no graph breaks.
References
PyTorch Export