PyTorch Model Export to ONNX Failed Due to ATen

Introduction

ONNX is an open format for representing deep learning models. With ONNX, AI developers can more easily move models between state-of-the-art tools and choose the combination that works best for them. So we often see people convert models between different deep learning frameworks by first exporting the model to an ONNX model from the original framework and then loading the ONNX model into the other framework. PyTorch is one of the few deep learning frameworks that natively supports ONNX. Here “natively” means that ONNX export is included in the PyTorch package, and that the PyTorch team actively communicates with the ONNX team, adding new features and support for PyTorch to ONNX when necessary. Frameworks such as TensorFlow that do not natively support ONNX do not include ONNX in their packages; conversion from TensorFlow to ONNX relies on unofficial third-party efforts and often does not work in many scenarios.

ONNX also has ONNX Runtime, which can serve ONNX models in a high-performance manner for model deployment. NVIDIA TensorRT is likewise a platform for high-performance deep learning inference, and it supports PyTorch models via the ONNX format: people convert PyTorch models to ONNX models, and TensorRT takes in an ONNX model, parses it, and builds a serving engine. We can see that, at least so far, ONNX has been very important to PyTorch. While exporting simple and conventional deep learning models from PyTorch to ONNX works most of the time, exporting a sophisticated model to an ONNX model can be a frustrating experience.

One of the problems that causes the conversion of PyTorch models to ONNX models to fail is ATen operators. ATen stands for “A Tensor Library for C++11”. If you are using PyTorch classes or functions that were implemented with ATen operators (possibly without being aware of it), or you are implementing PyTorch C++/CUDA extensions using the ATen library, you might run into problems exporting your PyTorch model to an ONNX model. Unlike other export errors, which are due to bad API design or bugs, this ATen problem originates from the mismatch between PyTorch and ONNX. Although the PyTorch team keeps contributing to ONNX, this gap might not be filled easily in the near future.

In this blog post, I will talk about some of these experiences and a painful “solution”.

ATen Library

This is a list of all the classes and functions that the ATen library currently has. Just like kernel fusion in TensorRT, many ATen operators do several basic math operations in one kernel call. For example, if you want to compute A * B + C without (automatic framework) optimization, it takes two kernel calls: first we call matrix multiplication, then we call addition. Each call has computation overhead, and depending on the math the operation is doing, that overhead might become the bottleneck of the computation. Therefore, we want to minimize the number of kernel calls. This is where the ATen library comes into play. For A * B + C, we just call at::addmm(C, A, B) and the calculation is finished in one call, as sketched below. For more sophisticated math, there could be more optimizations inside the single kernel to further accelerate the computation.
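
As a quick illustration in Python (torch.addmm mirrors at::addmm), the fused call below computes the same result as the unfused two-call version:

import torch

A = torch.randn(4, 5)
B = torch.randn(5, 6)
C = torch.randn(4, 6)

# Unfused: two kernel calls, a matrix multiplication followed by an addition.
out_unfused = torch.add(torch.mm(A, B), C)

# Fused: one addmm call, which computes C + A @ B directly.
out_fused = torch.addmm(C, A, B)

assert torch.allclose(out_unfused, out_fused, atol=1e-6)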

The ATen library is necessary and beneficial for PyTorch, but it might cause problems when exporting PyTorch models to ONNX models. Let’s see a concrete example that I found online.

import torch
import torch.nn as nn
import torch.nn.functional as F

class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()

    def forward(self, x):
        n, c, h, w = x.size()
        x = F.adaptive_avg_pool2d(x, (1, 1))

        # return F.interpolate(x, (h, w), mode='bilinear', align_corners=False)
        # RuntimeError: ONNX symbolic expected a constant value in the trace

        return F.interpolate(x, (480, 640), mode='bilinear', align_corners=False)  # this is ok

        # return F.interpolate(x, (h, w), mode='bilinear', align_corners=True)
        # RuntimeError: ONNX symbolic expected a constant value in the trace

        # return F.interpolate(x, (480, 640), mode='bilinear', align_corners=True)
        # UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
        # RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_bilinear2d


net = mynet()
x = torch.randn(1, 3, 480, 640)

device = torch.device("cuda:0")
net = net.to(device)
x = x.to(device)

out = net(x)
print('out.size ', out.size())  # (1, 3, 480, 640)

torch.onnx.export(net, x, "test.onnx", verbose=True)

In the above example, if mynet returns F.interpolate(x, (480, 640), mode='bilinear', align_corners=True) instead, we get the following error in PyTorch 1.2.0a0.

out.size  torch.Size([1, 3, 480, 640])
/opt/conda/lib/python3.6/site-packages/torch/onnx/symbolic_helper.py:171: UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
graph(%x.1 : Float(1, 3, 480, 640)):
  %1 : Float(1, 3, 1, 1) = onnx::GlobalAveragePool(%x.1), scope: mynet
  %2 : int[] = onnx::Constant[value= 480 640 [ Variable[CPULongType]{2} ]]()
  %3 : Long() = onnx::Constant[value={1}](), scope: mynet
  %4 : Float(1, 3, 480, 640) = aten::upsample_bilinear2d(%1, %2, %3), scope: mynet
  return (%4)

Traceback (most recent call last):
  File "test2.py", line 35, in <module>
    torch.onnx.export(net, x, "test.onnx", verbose=True)
  File "/opt/conda/lib/python3.6/site-packages/torch/onnx/__init__.py", line 32, in export
    return utils.export(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/onnx/utils.py", line 133, in export
    example_outputs=example_outputs, strip_doc_string=strip_doc_string)
  File "/opt/conda/lib/python3.6/site-packages/torch/onnx/utils.py", line 371, in _export
    strip_doc_string)
RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_bilinear2d

Defined at:


Graph we tried to export:
graph(%x.1 : Float(1, 3, 480, 640)):
  %1 : Float(1, 3, 1, 1) = onnx::GlobalAveragePool(%x.1), scope: mynet
  %2 : int[] = onnx::Constant[value= 480 640 [ Variable[CPULongType]{2} ]]()
  %3 : Long() = onnx::Constant[value={1}](), scope: mynet
  %4 : Float(1, 3, 480, 640) = aten::upsample_bilinear2d(%1, %2, %3), scope: mynet
  return (%4)

This is because aten::upsample_bilinear2d was used to implement F.interpolate(x, (480, 640), mode='bilinear', align_corners=True) in PyTorch, but there is no corresponding representation and implementation of aten::upsample_bilinear2d in ONNX, so ONNX does not recognize or understand it. Currently ONNX does not allow bypassing unknown operators, and therefore the export of the model from PyTorch to ONNX failed.

ONNX does have a lot of operators that correspond to many of the ATen operators, and ONNX export recognizes those ATen operators because the PyTorch team (or the user) has registered a symbolic function that maps each ATen operator to ONNX operators. However, if the required operator implementation does not exist in ONNX at all, registering a symbolic function does not help.
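
For reference, newer PyTorch versions (this API postdates the PyTorch 1.2.0a0 used above) expose this mechanism to users via torch.onnx.register_custom_op_symbolic. The following is a minimal sketch of the idea; the choice of aten::relu6 and its decomposition into onnx::Clip are just for illustration.

import torch
import torch.onnx

def relu6_symbolic(g, input):
    # relu6(x) = min(max(x, 0), 6). Clip in ONNX opset 9 takes min/max as
    # float attributes, so the whole ATen operator becomes one ONNX node.
    return g.op("Clip", input, min_f=0.0, max_f=6.0)

# Tell the exporter to use the symbolic function above whenever it
# encounters aten::relu6 in the traced graph, targeting opset 9.
torch.onnx.register_custom_op_symbolic("aten::relu6", relu6_symbolic, 9)

Crucially, this only works because Clip already exists in ONNX. For aten::upsample_bilinear2d with align_corners=True there was simply no ONNX operator to map to at the time; ONNX opset 11 later added a Resize operator that can express it.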

Although I believe the PyTorch and ONNX teams are collaborating and trying hard to support all the ATen operators in the list, judging from the progress I have seen so far, this might not be achieved in the near future. I also do not quite understand why, since there are already open-source C++ implementations of those ATen operators, ONNX cannot simply absorb them. More ATen operators will likely appear in the future, and as long as a single ATen operator is not supported by ONNX, users will complain.

A Painful Solution

A very painful solution for exporting a PyTorch model that contains an ONNX-unsupported ATen operator is to modify the PyTorch code, replacing the ATen operator with several smaller ONNX-supported operators. This requires us to use our math knowledge of deep learning.

For example, assume the at::addmm operator is not supported by ONNX but the torch::mul and torch::add operators are. We also assume that in your PyTorch code torch.nn.functional.affine (which might not exist at all) uses at::addmm, that torch.mul corresponds to torch::mul, and that torch.add corresponds to torch::add. We may just have to replace

torch.nn.functional.affine(A, B, C)

with

torch.add(torch.mul(A, B), C)

In this way, ONNX will recognize all the operators used in the PyTorch model, and the conversion will succeed. However, the example I gave above is not realistic. In practice it could be much harder: you need to know all the details, check the implementation of the ATen operator, and re-implement it on your own using ONNX-supported operators. Since this is not automatic, people will certainly not appreciate it.
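
To make this concrete with the earlier upsampling example: after F.adaptive_avg_pool2d(x, (1, 1)) the feature map is 1 x 1 per channel, and bilinearly upsampling a 1 x 1 map (with either align_corners setting) merely broadcasts the single value. So the unsupported aten::upsample_bilinear2d can be replaced with an expand, which the exporter does support. Here is a minimal sketch, assuming a fixed input shape (mynet_exportable is a hypothetical name, and this trick only works in this 1 x 1 special case):

import torch
import torch.nn as nn
import torch.nn.functional as F

class mynet_exportable(nn.Module):
    def forward(self, x):
        n, c, _, _ = x.size()
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # Upsampling a 1 x 1 map bilinearly is a constant broadcast, so
        # expand reproduces F.interpolate(x, (480, 640), mode='bilinear',
        # align_corners=True) exactly, without aten::upsample_bilinear2d.
        return x.expand(n, c, 480, 640)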

Final Remarks

It seems that the PyTorch team is migrating the at:: namespace to the torch:: namespace.
