ONNX Runtime C++ Inference

Introduction

ONNX is the open standard format for neural network model interoperability. ONNX Runtime can execute neural network models using different execution providers, such as CPU, CUDA, and TensorRT. While there have been many examples of running inference using the ONNX Runtime Python APIs, examples using the ONNX Runtime C++ APIs are quite limited.

In this blog post, I would like to discuss how to do image processing using the OpenCV C++ APIs and how to run inference using the ONNX Runtime C++ APIs.

Example

In this example, I used the public SqueezeNet ONNX model and royalty-free images from Pixabay. The implementation, however, should also be compatible with most of the ImageNet classification neural networks and images from other sources with slight modifications. In addition, I also compared the inference latencies measured from the CPU and CUDA execution providers.

The implementation and the Docker container are available on GitHub.

Installation

In this example, we used OpenCV for image processing and ONNX Runtime for inference. The C++ headers and libraries for OpenCV and ONNX Runtime are usually not available on the system or in a well-maintained Docker container, so we would have to build OpenCV and ONNX Runtime from source and install them. Both OpenCV and ONNX Runtime support CUDA, so we would have to build the CUDA components for at least ONNX Runtime. The build takes a very long time, and I recommend using the prepared Dockerfile to build a Docker container instead of building the libraries manually.

Image Processing

Image processing using the OpenCV C++ APIs is not as straightforward as using the OpenCV Python APIs. We would have to

  1. Read an image in HWC BGR UINT8 format.
  2. Resize the image.
  3. Convert the image to HWC RGB UINT8 format.
  4. Convert the image to HWC RGB float format by dividing each pixel by 255.
  5. Split the RGB channels from the image.
  6. Normalize each channel.
  7. Merge the RGB channels back to the image.
  8. Convert the image to CHW RGB float format.

The implementation looks as follows.

cv::Mat imageBGR = cv::imread(imageFilepath, cv::ImreadModes::IMREAD_COLOR);
cv::Mat resizedImageBGR, resizedImageRGB, resizedImage, preprocessedImage;
// inputDims is in NCHW order and cv::Size takes (width, height), so the width
// is inputDims.at(3) and the height is inputDims.at(2). The interpolation flag
// has to be passed as the sixth argument of cv::resize.
cv::resize(imageBGR, resizedImageBGR,
           cv::Size(inputDims.at(3), inputDims.at(2)), 0, 0,
           cv::InterpolationFlags::INTER_CUBIC);
cv::cvtColor(resizedImageBGR, resizedImageRGB,
             cv::ColorConversionCodes::COLOR_BGR2RGB);
resizedImageRGB.convertTo(resizedImage, CV_32F, 1.0 / 255);

cv::Mat channels[3];
cv::split(resizedImage, channels);
// Normalization per channel
// Normalization parameters obtained from
// https://github.com/onnx/models/tree/master/vision/classification/squeezenet
channels[0] = (channels[0] - 0.485) / 0.229;
channels[1] = (channels[1] - 0.456) / 0.224;
channels[2] = (channels[2] - 0.406) / 0.225;
cv::merge(channels, 3, resizedImage);
// HWC to CHW
cv::dnn::blobFromImage(resizedImage, preprocessedImage);

Run Inference

To run inference using ONNX Runtime, the user is responsible for creating and managing the input and output buffers. These buffers could be created and managed via std::vector. The linear-format input data should be copied to the buffer for ONNX Runtime inference.

size_t inputTensorSize = vectorProduct(inputDims);
std::vector<float> inputTensorValues(inputTensorSize);
inputTensorValues.assign(preprocessedImage.begin<float>(),
                         preprocessedImage.end<float>());

size_t outputTensorSize = vectorProduct(outputDims);
assert(("Output tensor size should equal to the label set size.",
        labels.size() == outputTensorSize));
std::vector<float> outputTensorValues(outputTensorSize);
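
The vectorProduct helper above is a small utility from the accompanying implementation; a minimal sketch, assuming it simply multiplies the tensor dimensions together, could look like this.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements in a tensor with the given dimensions,
// e.g. {1, 3, 224, 224} -> 150528.
size_t vectorProduct(const std::vector<int64_t>& v)
{
    return std::accumulate(v.begin(), v.end(), static_cast<size_t>(1),
                           std::multiplies<size_t>());
}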

Once the buffers have been created, they would be used for creating instances of Ort::Value, which is the tensor format for ONNX Runtime. There could be multiple inputs for a neural network, so we have to prepare an array of Ort::Value instances for the inputs and the outputs respectively, even if we only have one input and one output.

std::vector<Ort::Value> inputTensors;
std::vector<Ort::Value> outputTensors;
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(
    OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
inputTensors.push_back(Ort::Value::CreateTensor<float>(
    memoryInfo, inputTensorValues.data(), inputTensorSize, inputDims.data(),
    inputDims.size()));
outputTensors.push_back(Ort::Value::CreateTensor<float>(
    memoryInfo, outputTensorValues.data(), outputTensorSize,
    outputDims.data(), outputDims.size()));

Creating the ONNX Runtime inference session and querying the input and output names, dimensions, and types are trivial, and I will skip the details here.
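
For reference, a minimal sketch of those steps, assuming a modelFilepath string and the ONNX Runtime 1.6 C++ API, could look like the following; enabling CUDA only requires appending the CUDA execution provider to the session options.

#include <onnxruntime_cxx_api.h>
// #include <cuda_provider_factory.h>  // only needed for the CUDA execution provider

Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "squeezenet-inference");
Ort::SessionOptions sessionOptions;
sessionOptions.SetGraphOptimizationLevel(
    GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// For the CUDA execution provider (requires a CUDA-enabled build):
// OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
Ort::Session session(env, modelFilepath.c_str(), sessionOptions);

Ort::AllocatorWithDefaultOptions allocator;
size_t numInputNodes = session.GetInputCount();
size_t numOutputNodes = session.GetOutputCount();

// Query the input name, element type, and dimensions.
const char* inputName = session.GetInputName(0, allocator);
Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
auto inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType inputType = inputTensorInfo.GetElementType();
std::vector<int64_t> inputDims = inputTensorInfo.GetShape();

// Query the output name and dimensions.
const char* outputName = session.GetOutputName(0, allocator);
Ort::TypeInfo outputTypeInfo = session.GetOutputTypeInfo(0);
auto outputTensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();
std::vector<int64_t> outputDims = outputTensorInfo.GetShape();

std::vector<const char*> inputNames{inputName};
std::vector<const char*> outputNames{outputName};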

To run inference, we provide the run options, an array of input names corresponding to the input tensors, an array of input tensors, the number of inputs, an array of output names corresponding to the output tensors, an array of output tensors, and the number of outputs.

// https://github.com/microsoft/onnxruntime/blob/rel-1.6.0/include/onnxruntime/core/session/onnxruntime_cxx_api.h#L353
session.Run(Ort::RunOptions{nullptr}, inputNames.data(),
            inputTensors.data(), 1, outputNames.data(),
            outputTensors.data(), 1);

The inference results could be found in the buffers for the output tensors, which are usually the buffers backing the std::vector instances.
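
For example, a minimal sketch of reading the top prediction from the output buffer, assuming the outputTensorValues buffer and the labels vector from above, could look like this.

#include <algorithm>
#include <iostream>
#include <iterator>

// The index of the largest score in the output buffer is the predicted label ID.
auto maxIt = std::max_element(outputTensorValues.begin(),
                              outputTensorValues.end());
int predId = static_cast<int>(
    std::distance(outputTensorValues.begin(), maxIt));
std::cout << "Predicted Label ID: " << predId << std::endl;
std::cout << "Predicted Label: " << labels.at(predId) << std::endl;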

Demo

We fed a bee eater image to the neural network and ran inference using the CPU and CUDA execution providers.

Bee Eater
$ ./inference  --use_cpu
Inference Execution Provider: CPU
Number of Input Nodes: 1
Number of Output Nodes: 1
Input Name: data
Input Type: float
Input Dimensions: [1, 3, 224, 224]
Output Name: squeezenet0_flatten0_reshape0
Output Type: float
Output Dimensions: [1, 1000]
Predicted Label ID: 92
Predicted Label: n01828970 bee eater
Uncalibrated Confidence: 0.996137
Minimum Inference Latency: 7.45 ms
$ ./inference  --use_cuda
Inference Execution Provider: CUDA
Number of Input Nodes: 1
Number of Output Nodes: 1
Input Name: data
Input Type: float
Input Dimensions: [1, 3, 224, 224]
Output Name: squeezenet0_flatten0_reshape0
Output Type: float
Output Dimensions: [1, 1000]
Predicted Label ID: 92
Predicted Label: n01828970 bee eater
Uncalibrated Confidence: 0.996137
Minimum Inference Latency: 0.98 ms

The ONNX Runtime inference implementation has successfully classified the bee eater image as a bee eater with high confidence. The inference latency using CUDA is 0.98 ms on an NVIDIA RTX 2080 Ti GPU, whereas the inference latency using CPU is 7.45 ms on an Intel i9-9900K CPU.
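
The timing code is not shown above; a minimal sketch of how such a minimum latency could be measured with std::chrono, assuming the session and tensors from the previous sections and a hypothetical numRuns count, might look like this.

#include <algorithm>
#include <chrono>
#include <limits>

// Warm up once so that lazy initialization is not included in the measurement.
session.Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1,
            outputNames.data(), outputTensors.data(), 1);

constexpr int numRuns = 100;  // hypothetical number of measurement runs
double minLatencyMs = std::numeric_limits<double>::max();
for (int i = 0; i < numRuns; ++i)
{
    auto begin = std::chrono::steady_clock::now();
    session.Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(),
                1, outputNames.data(), outputTensors.data(), 1);
    auto end = std::chrono::steady_clock::now();
    minLatencyMs = std::min(
        minLatencyMs,
        std::chrono::duration<double, std::milli>(end - begin).count());
}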

Final Remarks

Using the TensorRT execution provider might result in even better inference latency. However, I did not measure it, because creating a correct Docker container and building the correct libraries are very tedious.
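
For reference, with a TensorRT-enabled build of ONNX Runtime, the TensorRT execution provider could be appended in much the same way as the CUDA one; a minimal sketch, assuming the provider factory header is available, could look like this.

#include <onnxruntime_cxx_api.h>
#include <tensorrt_provider_factory.h>  // requires a TensorRT-enabled build

Ort::SessionOptions sessionOptions;
// Append the TensorRT execution provider on GPU 0; operators that TensorRT
// does not support fall back to the other registered execution providers.
OrtSessionOptionsAppendExecutionProvider_Tensorrt(sessionOptions, 0);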

Author: Lei Mao

Posted on: 12-23-2020

Updated on: 12-23-2020
