PyTorch Custom ONNX Operator Export

Introduction

In my previous article “TensorRT Custom Plugin Example”, I discussed how to implement a TensorRT custom plugin and how to integrate it into a TensorRT engine built from an ONNX model containing an ONNX custom operator. The ONNX custom operator in that article was created from scratch using ONNX GraphSurgeon. In practice, however, as the ONNX model graph becomes more complex, performing ONNX graph surgery to create custom operators can become very difficult. Exporting ONNX custom operators directly from PyTorch modules is therefore a more practical approach.

In this article, I will discuss how to export PyTorch modules to ONNX custom operators for TensorRT custom plugin integration.

PyTorch Custom ONNX Operator Export

The PyTorch official tutorials “How to Export Pytorch Model with Custom Op to ONNX and Run it in ONNX Runtime” and “Extending the ONNX Registry” also discuss how to export PyTorch modules to ONNX custom operators. However, those tutorials assume that the custom operator implementations are already available in the ONNX registry or have been built as PyTorch C++ extensions so that they can be run with PyTorch in Python. This assumption does not always hold in practice. For example, TensorRT custom plugins are not straightforward to run as part of the PyTorch and ONNX export forward pass.

In our use case, we want to export PyTorch modules to ONNX custom operators without having to implement the custom operators and integrate them into the ONNX registry or a PyTorch C++ extension.

ONNX Opset 15 or Above

To export an ONNX model using ONNX Opset 15 or above (ONNX IR >= 8), we can use the export_modules_as_functions argument of the torch.onnx.export function. This argument takes a set of PyTorch module types whose instances will be exported as ONNX local functions, which serve as our ONNX custom operators. The TensorRT ONNX parser will then be able to map the ONNX custom operator functions to TensorRT custom plugins.
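For illustration, the call has the following shape. This is a minimal sketch, in which model, input_data, and MyCustomModule are placeholders for an actual model instance, example inputs, and the module type to be exported as an ONNX custom operator function:

import torch

# A minimal sketch of exporting a module type as an ONNX local function.
# model, input_data, and MyCustomModule are placeholders.
torch.onnx.export(model=model,
                  args=(input_data, ),
                  f="model.onnx",
                  opset_version=15,
                  export_modules_as_functions={MyCustomModule})

Note that export_modules_as_functions takes a set of module types, not module instances. A complete, runnable example follows.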

scripts/export_identity_neural_network_new_opset.py
# Define a PyTorch identity neural network and export it to ONNX format with custom ONNX operators.
# This only works for ONNX Opset 15 and above.

import os
from typing import List, ClassVar

import torch
import torch.nn as nn


class IdentityConvBase(nn.Module):

    def __init__(self, channels):
        super(IdentityConvBase, self).__init__()
        self.conv = nn.Conv2d(in_channels=channels,
                              out_channels=channels,
                              kernel_size=(1, 1),
                              stride=(1, 1),
                              padding=(0, 0),
                              dilation=(1, 1),
                              groups=channels,
                              bias=False)
        # Fill the weight values with 1.0.
        self.conv.weight.data = torch.ones(channels, 1, 1, 1)
        # Turn off the gradient for the weights, making it a constant.
        self.conv.weight.requires_grad = False

    def forward(self, x):
        return self.conv(x)


class IdentityConv(IdentityConvBase):

    __constants__ = ["kernel_shape", "strides", "pads", "group"]
    # Attributes to match the plugin requirements.
    # Must follow the type annotations via PEP 526-style.
    # https://peps.python.org/pep-0526/#class-and-instance-variable-annotations
    kernel_shape: ClassVar[List[int]]
    strides: ClassVar[List[int]]
    pads: ClassVar[List[int]]
    group: int

    def __init__(self, channels):
        super(IdentityConv, self).__init__(channels)
        self.kernel_shape = list(self.conv.kernel_size)
        self.strides = list(self.conv.stride)
        self.pads = list(self.conv.padding)
        # ONNX expects a list of 4 pad values whereas PyTorch uses a list of 2 pad values.
        self.pads = self.pads + self.pads
        self.group = self.conv.groups

    def forward(self, x):
        # Call the parent class method.
        x = super(IdentityConv, self).forward(x)
        # Apply the identity operation.
        return x


class IdentityNeuralNetwork(nn.Module):

    def __init__(self, channels):
        super(IdentityNeuralNetwork, self).__init__()
        self.conv1 = IdentityConvBase(channels)
        self.conv2 = IdentityConv(channels)
        self.conv3 = IdentityConvBase(channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


if __name__ == "__main__":

    opset_version = 15
    data_directory_path = "data"
    onnx_file_name = "identity_neural_network.onnx"
    onnx_file_path = os.path.join(data_directory_path, onnx_file_name)

    input_shape = (1, 3, 480, 960)
    input_data = torch.rand(*input_shape).float()
    input_channels = input_shape[1]

    # Create an instance of the identity neural network.
    identity_neural_network = IdentityNeuralNetwork(input_channels)
    # Set the model to evaluation mode.
    identity_neural_network.eval()
    # Export the model to ONNX format with the custom operator as ONNX custom functions.
    # TensorRT ONNX parser can parse the ONNX custom functions to TensorRT plugins.
    # References: https://github.com/pytorch/pytorch/issues/65199
    torch.onnx.export(model=identity_neural_network,
                      args=(input_data, ),
                      f=onnx_file_path,
                      input_names=["X0"],
                      output_names=["X3"],
                      opset_version=opset_version,
                      export_modules_as_functions={IdentityConv})
    print(
        f"Exported the identity neural network to ONNX format: {onnx_file_path}"
    )
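To confirm that the custom operator was exported as an ONNX local function, we can inspect the exported model with the onnx Python package. The following is a minimal sketch, assuming the onnx package is installed and the export script above has been run:

import onnx

model = onnx.load("data/identity_neural_network.onnx")
# Local functions are supported for ONNX IR version 8 or above.
print(model.ir_version)
# The IdentityConv module should appear as an ONNX local function.
print([function.name for function in model.functions])
# One of the nodes in the top-level graph should call the IdentityConv function.
print([node.op_type for node in model.graph.node])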

ONNX Opset 14 or Below

To export an ONNX model using ONNX Opset 14 or below (ONNX IR < 8), the export_modules_as_functions argument of the torch.onnx.export function cannot be used. Instead, we have to implement a dummy torch.autograd.Function that consumes and produces tensors of exactly the same shapes and types as the PyTorch module we want to export as an ONNX custom operator. A correct forward implementation is not necessary, as long as the downstream ONNX graph is not data dependent. We then implement its symbolic method so that PyTorch exports it as an ONNX custom operator, and wrap it in a dummy PyTorch module. During the export, we simply switch the forward pass of the original PyTorch module to the forward pass of the dummy module.

scripts/export_identity_neural_network_old_opset.py
# Define a PyTorch identity neural network and export it to ONNX format with custom ONNX operators.

import os

import torch
import torch.nn as nn
from torch.onnx.symbolic_helper import _get_tensor_sizes


class IdentityConvBase(nn.Module):

    def __init__(self, channels):
        super(IdentityConvBase, self).__init__()
        self.conv = nn.Conv2d(in_channels=channels,
                              out_channels=channels,
                              kernel_size=(1, 1),
                              stride=(1, 1),
                              padding=(0, 0),
                              dilation=(1, 1),
                              groups=channels,
                              bias=False)
        # Fill the weight values with 1.0.
        self.conv.weight.data = torch.ones(channels, 1, 1, 1)
        # Turn off the gradient for the weights, making it a constant.
        self.conv.weight.requires_grad = False

    def forward(self, x):
        return self.conv(x)


class DummyIdentityConvOp(torch.autograd.Function):

    @staticmethod
    def symbolic(g, input, weight, kernel_shape, strides, pads, group):
        args = [input, weight]
        # These become the operator attributes.
        kwargs = {
            "kernel_shape_i": kernel_shape,
            "strides_i": strides,
            "pads_i": pads,
            "group_i": group
        }
        output_type = input.type().with_sizes(_get_tensor_sizes(input))
        return g.op("CustomTorchOps::IdentityConv", *args,
                    **kwargs).setType(output_type)

    @staticmethod
    def forward(ctx, input, weight, kernel_shape, strides, pads, group):
        # We don't have to actually implement the correct forward pass,
        # if the downstream graph is not data dependent,
        # as long as the shape of the output is correct.
        return input


class DummyIdentityConv(nn.Module):

    def __init__(self, channels):
        super(DummyIdentityConv, self).__init__()
        self.kernel_shape = (1, 1)
        self.strides = (1, 1)
        self.pads = (0, 0, 0, 0)
        self.group = channels

        # Fill the weight values with 1.0.
        self.weight = torch.ones(channels, 1, 1, 1)
        # Turn off the gradient for the weights, making it a constant.
        self.weight.requires_grad = False

    def forward(self, x):
        x = DummyIdentityConvOp.apply(x, self.weight, self.kernel_shape,
                                      self.strides, self.pads, self.group)
        return x


class IdentityConv(IdentityConvBase):

    def __init__(self, channels):
        super(IdentityConv, self).__init__(channels)
        self.kernel_shape = list(self.conv.kernel_size)
        self.strides = list(self.conv.stride)
        self.pads = list(self.conv.padding)
        # ONNX expects a list of 4 pad values whereas PyTorch uses a list of 2 pad values.
        self.pads = self.pads + self.pads
        self.group = self.conv.groups

    def forward(self, x):
        # Call the parent class method.
        x = super(IdentityConv, self).forward(x)
        # Apply the identity operation.
        return x


class IdentityNeuralNetwork(nn.Module):

    def __init__(self, channels):
        super(IdentityNeuralNetwork, self).__init__()
        self.conv1 = IdentityConvBase(channels)
        self.conv2 = IdentityConv(channels=channels)
        self.conv3 = IdentityConvBase(channels)
        # Create a dummy identity convolution only used for ONNX export.
        self.conv2_export = DummyIdentityConv(channels=channels)

    def forward(self, x):
        x = self.conv1(x)
        if torch.onnx.is_in_onnx_export():
            x = self.conv2_export(x)
        else:
            x = self.conv2(x)
        x = self.conv3(x)
        return x


if __name__ == "__main__":

    opset_version = 13
    data_directory_path = "data"
    onnx_file_name = "identity_neural_network.onnx"
    onnx_file_path = os.path.join(data_directory_path, onnx_file_name)

    input_shape = (1, 3, 480, 960)
    input_data = torch.rand(*input_shape).float()
    input_channels = input_shape[1]

    # Create an instance of the identity neural network.
    identity_neural_network = IdentityNeuralNetwork(input_channels)
    # Set the model to evaluation mode.
    identity_neural_network.eval()
    output = identity_neural_network(input_data)
    # Export the model to ONNX format with the custom operator as ONNX custom functions.
    torch.onnx.export(model=identity_neural_network,
                      args=(input_data, ),
                      f=onnx_file_path,
                      input_names=["X0"],
                      output_names=["X3"],
                      opset_version=opset_version)

    print(
        f"Exported the identity neural network to ONNX format: {onnx_file_path}"
    )
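Again, we can inspect the exported model with the onnx Python package to verify that the custom operator was exported with the expected domain, name, and attributes. The following is a minimal sketch, assuming the onnx package is installed and the export script above has been run:

import onnx

model = onnx.load("data/identity_neural_network.onnx")
for node in model.graph.node:
    # The custom operator node should have the domain "CustomTorchOps",
    # the op type "IdentityConv", and the attributes kernel_shape, strides,
    # pads, and group.
    print(node.domain, node.op_type,
          [attribute.name for attribute in node.attribute])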

Custom ONNX Operator and TensorRT Custom Plugin Validation

Using the TensorRT custom plugin, the TensorRT engine builder, and the TensorRT engine runner created in the previous article, we can validate both of our ONNX custom operator export approaches.

$ python scripts/export_identity_neural_network_new_opset.py
$ ./build/src/build_engine
$ ./build/src/run_engine
$ python scripts/export_identity_neural_network_old_opset.py
$ ./build/src/build_engine
$ ./build/src/run_engine

Source Code

The source code of the example can be found in my GitHub repository “TensorRT Custom Plugin Example”.
