ONNX IO Stream

Introduction

Converting a PyTorch model to ONNX and running ONNX inference usually involve saving and loading ONNX files on the hard drive. With streams, sometimes referred to as file-like objects, it is possible to save and load ONNX models entirely in memory, which is significantly faster than doing the same on the hard drive.
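
To illustrate the idea, here is a minimal sketch, not specific to ONNX, showing that an io.BytesIO buffer supports the same write, seek, and read interface as a binary file on the hard drive.

import io

# An in-memory bytes buffer behaves like a file opened in binary mode.
with io.BytesIO() as f:
    f.write(b"serialized model bytes")
    # Rewind to the start before reading, just as with a real file.
    f.seek(0)
    data = f.read()
    assert data == b"serialized model bytes"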

In this blog post, I would like to discuss ONNX IO with streams.

ONNX IO With Streams

The following example exports a PyTorch model to ONNX and runs ONNX Runtime inference without any interaction with the hard drive.

onnx_io.py
import io

import onnx
import onnxruntime as rt
import numpy as np
import torch
import torchvision


def main() -> None:

    input_shape = (1, 3, 224, 224)

    # Create a PyTorch model for ONNX export.
    torch_model = torchvision.models.resnet18(pretrained=False)

    # Create a file-like binary stream using an in-memory bytes buffer.
    with io.BytesIO() as f:

        # Export the model to the binary stream.
        torch.onnx.export(model=torch_model,
                          args=torch.randn(*input_shape),
                          f=f)

        # Use ONNX load_model API to load a model from a binary stream.
        # Change the stream position to the start of the stream.
        f.seek(0)
        model_proto_from_binary_stream = onnx.load_model(f, onnx.ModelProto)

        # Use ONNX load_model_from_string API to load a model from a binary string.
        model_proto_from_binary_string = onnx.load_model_from_string(
            f.getvalue(), onnx.ModelProto)

    # Equivalence of the two ONNX models loaded using different approaches.
    assert model_proto_from_binary_stream == model_proto_from_binary_string

    model_proto = model_proto_from_binary_stream

    with io.BytesIO() as f:

        # Use ONNX save_model API to save model to a binary stream.
        onnx.save_model(model_proto, f)

        # Use ONNX load_model API to load a model from a binary stream.
        # Change the stream position to the start of the stream.
        f.seek(0)
        model_proto_from_binary_stream = onnx.load_model(f, onnx.ModelProto)

        # Use ONNX load_model_from_string API to load a model from a binary string.
        model_proto_from_binary_string = onnx.load_model_from_string(
            f.getvalue(), onnx.ModelProto)

        assert model_proto == model_proto_from_binary_stream
        assert model_proto == model_proto_from_binary_string

        # Use ONNX _serialize to get binary string from ONNX model.
        model_proto_bytes = onnx._serialize(model_proto)
        assert model_proto_bytes == f.getvalue()

        # Use ONNX _deserialize to get ONNX model from binary string.
        model_proto_from_deserialization = onnx._deserialize(
            model_proto_bytes, onnx.ModelProto())
        assert model_proto == model_proto_from_deserialization

    # Run ONNX Runtime.
    # InferenceSession could also take bytes.
    inference_session = rt.InferenceSession(model_proto_bytes)
    onnxruntime_random_input = np.random.randn(*input_shape).astype(np.float32)

    input_name = inference_session.get_inputs()[0].name
    prediction = inference_session.run(
        None, {input_name: onnxruntime_random_input})[0]


if __name__ == "__main__":

    main()
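
For comparison, a conventional workflow sketched below writes the ONNX model to the hard drive and reads it back. The file path "resnet18.onnx" is only illustrative. The result is functionally equivalent to the stream-based version above, but it incurs disk I/O.

import torch
import torchvision
import onnx
import onnxruntime as rt

torch_model = torchvision.models.resnet18(pretrained=False)
input_shape = (1, 3, 224, 224)

# Export to a file on the hard drive.
torch.onnx.export(model=torch_model,
                  args=torch.randn(*input_shape),
                  f="resnet18.onnx")

# Load the model back from the hard drive.
model_proto = onnx.load_model("resnet18.onnx")

# ONNX Runtime can also create a session directly from the file path.
inference_session = rt.InferenceSession("resnet18.onnx")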

Author: Lei Mao

Posted on: 01-03-2022

Updated on: 01-03-2022
