Converting a PyTorch model to ONNX and running ONNX inference often involve saving and loading ONNX files on the hard drive. With streams, sometimes referred to as file-like objects, it is possible to save and load ONNX models entirely in memory, which is significantly faster than doing the same on the hard drive.
In this blog post, I would like to discuss ONNX IO with streams.
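Before diving into ONNX, a quick note on what "file-like object" means here: io.BytesIO from the Python standard library exposes the same write/read/seek interface as a file opened in binary mode, except that the bytes live entirely in memory. Below is a minimal, ONNX-free sketch of that interface.

import io

# An in-memory bytes buffer behaves like a binary file handle.
with io.BytesIO() as f:
    f.write(b"some serialized bytes")
    f.seek(0)  # Rewind to the start before reading.
    assert f.read() == b"some serialized bytes"
    # getvalue() returns the entire buffer regardless of the stream position.
    assert f.getvalue() == b"some serialized bytes"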
ONNX IO With Streams
The following example performs PyTorch to ONNX export and ONNX Runtime inference without any interaction with the hard drive.
import io

import numpy as np
import onnx
import onnxruntime as rt
import torch
import torchvision


def main() -> None:

    input_shape = (1, 3, 224, 224)

    # Create a PyTorch model for ONNX export.
    torch_model = torchvision.models.resnet18(pretrained=False)

    # Create a file-like binary stream using an in-memory bytes buffer.
    with io.BytesIO() as f:
        # Export the model to the binary stream.
        torch.onnx.export(model=torch_model,
                          args=torch.randn(*input_shape),
                          f=f)
        # Use the ONNX load_model API to load a model from a binary stream.
        # Change the stream position to the start of the stream first.
        f.seek(0)
        model_proto_from_binary_stream = onnx.load_model(f)
        # Use the ONNX load_model_from_string API to load a model from a
        # binary string.
        model_proto_from_binary_string = onnx.load_model_from_string(
            f.getvalue())

    # The two ONNX models loaded using different approaches are equivalent.
    assert model_proto_from_binary_stream == model_proto_from_binary_string

    model_proto = model_proto_from_binary_stream

    with io.BytesIO() as f:
        # Use the ONNX save_model API to save a model to a binary stream.
        onnx.save_model(model_proto, f)
        # Use the ONNX load_model API to load a model from a binary stream.
        # Change the stream position to the start of the stream first.
        f.seek(0)
        model_proto_from_binary_stream = onnx.load_model(f)
        # Use the ONNX load_model_from_string API to load a model from a
        # binary string.
        model_proto_from_binary_string = onnx.load_model_from_string(
            f.getvalue())
        # Use ONNX _serialize to get a binary string from an ONNX model.
        model_proto_bytes = onnx._serialize(model_proto)
        assert model_proto_bytes == f.getvalue()

    # Use ONNX _deserialize to get an ONNX model from a binary string.
    model_proto_from_deserialization = onnx._deserialize(
        model_proto_bytes, onnx.ModelProto())
    assert model_proto == model_proto_from_deserialization

    # Run ONNX Runtime.
    # InferenceSession can also take the serialized model bytes instead of a
    # file path.
    inference_session = rt.InferenceSession(
        model_proto_bytes, providers=["CPUExecutionProvider"])
    onnxruntime_random_input = np.random.randn(*input_shape).astype(np.float32)
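The snippet above stops just before executing the session. As a minimal sketch of the remaining step, assuming the exported ResNet-18 graph has a single input and the random tensor above is used as the feed, the inference could be run as follows. Here inference_session and onnxruntime_random_input refer to the objects created in the example.

# Continues the example above: `inference_session` and
# `onnxruntime_random_input` were created there.
input_name = inference_session.get_inputs()[0].name
output_name = inference_session.get_outputs()[0].name
outputs = inference_session.run(
    output_names=[output_name],
    input_feed={input_name: onnxruntime_random_input})
# For torchvision ResNet-18 with a (1, 3, 224, 224) input, the output
# should have shape (1, 1000).
print(outputs[0].shape)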