TensorRT is NVIDIA's high-performance SDK for deep learning inference on NVIDIA GPUs. It allows users to create custom plugins for neural network layers that TensorRT does not support out of the box.
In this blog post, I would like to demonstrate how to implement and integrate a custom plugin into TensorRT using a concrete and self-contained example.
Identity ONNX Model
To keep the custom plugin implementation and integration simple, we created a simple identity ONNX model that consists of three Conv nodes whose weights and attributes are orchestrated so that each convolution is an identity operation. The second Conv node in the ONNX model is then replaced with a custom node IdentityConv that is not defined in the ONNX operator set. Without a custom plugin for the IdentityConv node, a TensorRT engine cannot be built from this ONNX model.
graph = gs.Graph(nodes=[node_1, node_2, node_3], inputs=[X0], outputs=[X3],
                 opset=opset_version)
model = gs.export_onnx(graph)
# Shape inference does not quite work here because of the custom operator.
# model = onnx.shape_inference.infer_shapes(model)
onnx.save(model, onnx_file_path)

if __name__ == '__main__':
    main()
IdentityConv Custom Plugin Implementation
The custom plugin class has to be derived from the nvinfer1::IPluginV2IOExt or nvinfer1::IPluginV2DynamicExt class. The nvinfer1::IPluginV2Ext class has been deprecated and should not be used. In this example, the nvinfer1::IPluginV2IOExt class is used.
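For reference, the class declaration might look roughly like the following sketch. The parameter struct and the member variable names, such as mParams and mPluginNamespace, are assumptions inferred from the method definitions shown below rather than a verbatim copy of the header.

#include <cstdint>
#include <cuda_runtime.h>
#include <NvInfer.h>

// Plugin parameters assumed from the serialize/deserialize methods below.
struct IdentityConvParams
{
    int32_t group;
    nvinfer1::DataType dtype;
    int32_t channelSize;
    int32_t height;
    int32_t width;
    size_t dtypeBytes;
};

class IdentityConv : public nvinfer1::IPluginV2IOExt
{
public:
    IdentityConv(IdentityConvParams params);
    IdentityConv(void const* data, size_t length); // Deserialization constructor.

    // Metadata.
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int32_t getNbOutputs() const noexcept override;
    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

    // Shape, type, and format negotiation.
    nvinfer1::Dims getOutputDimensions(int32_t index,
                                       nvinfer1::Dims const* inputs,
                                       int32_t nbInputDims) noexcept override;
    nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes,
        int32_t nbInputs) const noexcept override;
    bool supportsFormatCombination(int32_t pos,
                                   nvinfer1::PluginTensorDesc const* inOut,
                                   int32_t nbInputs,
                                   int32_t nbOutputs) const noexcept override;
    bool isOutputBroadcastAcrossBatch(int32_t outputIndex,
                                      bool const* inputIsBroadcasted,
                                      int32_t nbInputs) const noexcept override;
    bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept override;
    void configurePlugin(nvinfer1::PluginTensorDesc const* in, int32_t nbInput,
                         nvinfer1::PluginTensorDesc const* out,
                         int32_t nbOutput) noexcept override;

    // Lifecycle and execution.
    int32_t initialize() noexcept override;
    void terminate() noexcept override;
    void destroy() noexcept override;
    size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept override;
    int32_t enqueue(int32_t batchSize, void const* const* inputs,
                    void* const* outputs, void* workspace,
                    cudaStream_t stream) noexcept override;

    // Serialization and cloning.
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    nvinfer1::IPluginV2IOExt* clone() const noexcept override;

private:
    void deserialize(uint8_t const* data, size_t length);

    IdentityConvParams mParams;
    char const* mPluginNamespace;
};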
To perform the IdentityConv operation, the custom plugin class has to override the nvinfer1::IPluginV2IOExt::enqueue method. In our case, we simply copy the input tensor to the output tensor using cudaMemcpyAsync.
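A minimal sketch of such an enqueue implementation, assuming the tensor dimensions and the data type byte size have been recorded in mParams, is shown below.

int32_t IdentityConv::enqueue(int32_t batchSize, void const* const* inputs,
                              void* const* outputs, void* workspace,
                              cudaStream_t stream) noexcept
{
    // The identity operation is just an asynchronous device-to-device copy of
    // the input tensor to the output tensor on the provided CUDA stream.
    size_t const copy_size_bytes{static_cast<size_t>(batchSize) *
                                 mParams.channelSize * mParams.height *
                                 mParams.width * mParams.dtypeBytes};
    cudaError_t const status{cudaMemcpyAsync(outputs[0], inputs[0],
                                             copy_size_bytes,
                                             cudaMemcpyDeviceToDevice, stream)};
    // A non-zero return value signals failure to TensorRT.
    return status == cudaSuccess ? 0 : -1;
}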
void IdentityConv::deserialize(uint8_t const* data, size_t length)
{
    // In our simple use case, even though there is no parameter used for this
    // plugin, we deserialize and serialize some attributes for demonstration
    // purposes.
    uint8_t const* d{data};
    mParams.group = read<int32_t>(d);
    mParams.dtype = read<nvinfer1::DataType>(d);
    mParams.channelSize = read<int32_t>(d);
    mParams.height = read<int32_t>(d);
    mParams.width = read<int32_t>(d);
    mParams.dtypeBytes = read<size_t>(d);
    PLUGIN_ASSERT(d == data + length);
}
void IdentityConv::configurePlugin(nvinfer1::PluginTensorDesc const* in,
                                   int32_t nbInput,
                                   nvinfer1::PluginTensorDesc const* out,
                                   int32_t nbOutput) noexcept
{
    // Communicates the number of inputs and outputs, dimensions, and datatypes
    // of all inputs and outputs, broadcast information for all inputs and
    // outputs, the chosen plugin format, and maximum batch size. At this
    // point, the plugin sets up its internal state and selects the most
    // appropriate algorithm and data structures for the given configuration.
    // Note: Resource allocation is not allowed in this API because it causes a
    // resource leak.

    // This member function will only be called during engine build time.

    // In our case, we record the data type and dimensions of the input tensor
    // so that they can be used by enqueue and serialized with the plugin.
    PLUGIN_ASSERT(nbInput == 2);
    PLUGIN_ASSERT(nbOutput == 1);
    PLUGIN_ASSERT(in[0].dims.nbDims == 3);
    mParams.dtype = in[0].type;
    mParams.channelSize = in[0].dims.d[0];
    mParams.height = in[0].dims.d[1];
    mParams.width = in[0].dims.d[2];
    // Only kFLOAT and kHALF are supported (see supportsFormatCombination).
    mParams.dtypeBytes = (mParams.dtype == nvinfer1::DataType::kFLOAT) ? 4U : 2U;
}
int32_t IdentityConv::initialize() noexcept
{
    // The configuration is known at this time, and the inference engine is
    // being created, so the plugin can set up its internal data structures and
    // prepare for execution. Such setup might include initializing libraries,
    // allocating memory, etc. In our case, we don't need to prepare anything.
    return 0;
}
void IdentityConv::terminate() noexcept
{
    // The engine context is destroyed, and all the resources held by the
    // plugin must be released.
}
nvinfer1::Dims IdentityConv::getOutputDimensions(int32_t index,
                                                 nvinfer1::Dims const* inputs,
                                                 int32_t nbInputDims) noexcept
{
    // Even though non-IPluginV2DynamicExt plugins are compatible with explicit
    // batch mode networks, their implementation must be independent of the
    // type of network (implicit/explicit batch mode) in which it is expected
    // to be used. As such, when using such plugins in explicit batch mode
    // networks:
    // * The leading dimension of the first input (before being passed to the
    //   plugin) is inferred to be the batch dimension.
    // * TensorRT pops this first dimension identified above before inputs are
    //   passed to the plugin, and pushes it to the front of any outputs
    //   emitted by the plugin. This means that the batch dimension must not be
    //   specified in getOutputDimensions.
    PLUGIN_ASSERT(index == 0);
    PLUGIN_ASSERT(nbInputDims == 2);
    PLUGIN_ASSERT(inputs != nullptr);
    // CHW
    nvinfer1::Dims dimsOutput;
    PLUGIN_ASSERT(inputs[0].nbDims == 3);
    // Identity operation.
    // Just copy the dimensions from the input tensor.
    dimsOutput.nbDims = inputs[0].nbDims;
    dimsOutput.d[0] = inputs[0].d[0];
    dimsOutput.d[1] = inputs[0].d[1];
    dimsOutput.d[2] = inputs[0].d[2];

    return dimsOutput;
}
size_t IdentityConv::getWorkspaceSize(int32_t maxBatchSize) const noexcept
{
    // No scratch space is required for this plugin.
    return 0;
}
void IdentityConv::serialize(void* buffer) const noexcept
{
    char* d{reinterpret_cast<char*>(buffer)};
    char* const a{d};
    // Be cautious, the order has to match deserialization.
    write(d, mParams.group);
    write(d, mParams.dtype);
    write(d, mParams.channelSize);
    write(d, mParams.height);
    write(d, mParams.width);
    write(d, mParams.dtypeBytes);
    PLUGIN_ASSERT(d == a + getSerializationSize());
}
bool IdentityConv::supportsFormatCombination(
    int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs,
    int32_t nbOutputs) const noexcept
{
    // For this method inputs are numbered 0..(nbInputs-1) and outputs are
    // numbered nbInputs..(nbInputs+nbOutputs-1). Using this numbering, pos is
    // an index into InOut, where 0 <= pos < nbInputs+nbOutputs.
    PLUGIN_ASSERT(nbInputs == 2 && nbOutputs == 1 &&
                  pos < nbInputs + nbOutputs);
    bool isValidCombination = false;

    // Suppose we support only a limited number of format configurations.
    isValidCombination |=
        (inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
         inOut[pos].type == nvinfer1::DataType::kFLOAT);
    isValidCombination |=
        (inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
         inOut[pos].type == nvinfer1::DataType::kHALF);
    // Make sure the input tensor and output tensor types and formats are the
    // same.
    isValidCombination &=
        (pos < nbInputs || (inOut[pos].format == inOut[0].format &&
                            inOut[pos].type == inOut[0].type));

    return isValidCombination;
}
nvinfer1::IPluginV2IOExt* IdentityConv::clone() const noexcept
{
    // It's possible to encounter errors during cloning.
    // For example, if the memory to allocate is insufficient, exceptions can
    // be thrown.
    try
    {
        IPluginV2IOExt* const plugin{new IdentityConv{mParams}};
        plugin->setPluginNamespace(mPluginNamespace);
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
nvinfer1::DataType IdentityConv::getOutputDataType(
    int32_t index, nvinfer1::DataType const* inputTypes,
    int32_t nbInputs) const noexcept
{
    // One output.
    PLUGIN_ASSERT(index == 0);
    PLUGIN_ASSERT(nbInputs == 2);
    // The output type is the same as the input type.
    return inputTypes[0];
}
The implementation above also demonstrates some other caveats, such as how to specify the format combinations the custom plugin supports and how to serialize and deserialize the plugin parameters.
IdentityConv Custom Plugin Creator Implementation
TensorRT will call the nvinfer1::IPluginCreator::createPlugin method to create the custom plugin instance. Therefore, we have to create a custom plugin creator class that is derived from the nvinfer1::IPluginCreator class.
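A declaration sketch of such a creator class is shown below. The member variables mFC, mPluginAttributes, and mNamespace are assumptions for illustration; the essential part is the set of nvinfer1::IPluginCreator methods that must be overridden.

#include <string>
#include <vector>
#include <NvInfer.h>

class IdentityConvCreator : public nvinfer1::IPluginCreator
{
public:
    IdentityConvCreator();

    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

    // Called with the attributes that the ONNX parser collected from the
    // IdentityConv node.
    nvinfer1::IPluginV2IOExt* createPlugin(
        char const* name,
        nvinfer1::PluginFieldCollection const* fc) noexcept override;

    // Called when deserializing an engine that contains this plugin.
    nvinfer1::IPluginV2IOExt* deserializePlugin(
        char const* name, void const* serialData,
        size_t serialLength) noexcept override;

    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    nvinfer1::PluginFieldCollection mFC{};
    std::vector<nvinfer1::PluginField> mPluginAttributes{};
    std::string mNamespace{};
};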
TensorRT allows both static and dynamic registration of custom plugins. The REGISTER_TENSORRT_PLUGIN macro performs static registration of the custom plugin creator class. In this example, we use dynamic registration instead; therefore, the REGISTER_TENSORRT_PLUGIN macro is commented out.
// This is not needed for plugin dynamic registration.
// REGISTER_TENSORRT_PLUGIN(IdentityConvCreator);
// Plugin creator
IdentityConvCreator::IdentityConvCreator()
{
    // Declare the ONNX attributes that the ONNX parser will collect from the
    // ONNX model that contains the IdentityConv node, and expose them to
    // TensorRT via the plugin field collection. Any other attributes carried
    // by the node would be declared in the same way.
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(nvinfer1::PluginField(
        "group", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
    mFC.nbFields = static_cast<int32_t>(mPluginAttributes.size());
    mFC.fields = mPluginAttributes.data();
}
nvinfer1::IPluginV2IOExt* IdentityConvCreator::createPlugin(
    char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept
{
    // The attributes from the ONNX node will be parsed and passed via fc.
    try
    {
        nvinfer1::PluginField const* fields{fc->fields};
        int32_t nbFields{fc->nbFields};
        // Parse the fields into the plugin parameters, construct the
        // IdentityConv plugin instance from them, and return it.
        // ...
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
The implementation above also demonstrates another caveat: how to rely on the TensorRT ONNX parser to parse the attributes of the custom ONNX node and pass them to the plugin creator.
Expose IdentityConv Custom Plugin Creator to TensorRT
Because the custom plugin library will be loaded dynamically by TensorRT, we have to expose the plugin creator class to TensorRT from the dynamic library. The setLoggerFinder and getPluginCreators functions must be implemented and exported so that TensorRT can call them when it loads the library and creates the plugins.
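A minimal sketch of these two exported functions, assuming a file-scope nvinfer1::ILoggerFinder pointer and a single static IdentityConvCreator instance, is shown below. The global variable names are illustrative only, and the snippet is assumed to live in the plugin library source file where IdentityConvCreator is declared.

#include <mutex>
#include <NvInfer.h>

// Assumed file-scope state for this sketch.
static std::mutex sLoggerFinderMutex{};
static nvinfer1::ILoggerFinder* sLoggerFinder{nullptr};

extern "C" void setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
    // TensorRT hands the library an ILoggerFinder so that the plugin code can
    // retrieve the logger associated with the builder or runtime.
    std::lock_guard<std::mutex> const lock{sLoggerFinderMutex};
    sLoggerFinder = finder;
}

extern "C" nvinfer1::IPluginCreator* const* getPluginCreators(
    int32_t& nbCreators)
{
    // Expose all the plugin creators implemented in this library.
    nbCreators = 1;
    static IdentityConvCreator sIdentityConvCreator{};
    static nvinfer1::IPluginCreator* const sPluginCreatorList[]{
        &sIdentityConvCreator};
    return sPluginCreatorList;
}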
The custom plugin library can be dynamically registered and loaded by TensorRT via the nvinfer1::IPluginRegistry::loadLibrary method. The ONNX parser can then parse the custom ONNX node and use the custom plugin to run the inference for the custom ONNX node.
The custom plugin can also be serialized into the TensorRT engine file via the nvinfer1::IBuilderConfig::setPluginsToSerialize method so that the custom plugin library is not required to be loaded during the inference time.
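In the engine-building program, these two pieces might be wired up roughly as follows. The builder and config objects are assumed to exist in the surrounding build code, and the plugin library path shown here is hypothetical.

// Hypothetical path to the custom plugin shared library.
char const* const plugin_library_path{"build/src/libidentity_conv.so"};

// Dynamically load and register the plugin creators exported by the library
// so that the ONNX parser can resolve the IdentityConv node.
builder->getPluginRegistry().loadLibrary(plugin_library_path);

// Package the plugin library into the serialized engine so that it does not
// have to be loaded explicitly at inference time.
char const* const plugin_libraries[]{plugin_library_path};
config->setPluginsToSerialize(plugin_libraries, 1);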
// Write the serialized engine to a file.
std::ofstream engineFile{engine_file_path.c_str(), std::ios::binary};
if (!engineFile.is_open())
{
    std::cerr << "Failed to open the engine file." << std::endl;
    return EXIT_FAILURE;
}
engineFile.write(static_cast<char const*>(serializedModel->data()),
                 serializedModel->size());
engineFile.close();
std::cout << "Successfully serialized the engine to the file: " << engine_file_path << std::endl;
return EXIT_SUCCESS; }
Run Engine With Custom Plugin
Because the custom plugin has been serialized into the TensorRT engine file, we don’t need to load the custom plugin library during the inference time. We can simply create the runtime and deserialize the engine from the engine file.
// The plugin has already been serialized with the engine.
// There is no need to load the plugin library.
std::string const data_dir_path{"data"};
std::string const engine_file_name{"identity_neural_network.engine"};
std::string const engine_file_path{data_dir_path + "/" + engine_file_name};
// Create CUDA stream.
cudaStream_t stream;
CHECK_CUDA_ERROR(cudaStreamCreate(&stream));
// The engine we built is FP32 NCHW IO.
nvinfer1::DataType const expected_dtype{nvinfer1::DataType::kFLOAT};
size_t const expected_dtype_byte_size{4U};
nvinfer1::TensorFormat const expected_format{
    nvinfer1::TensorFormat::kLINEAR};
// Error tolerance for unit test.
float const rtol{1e-5f};
float const atol{1e-8f};
// Deserialize the engine.
std::unique_ptr<nvinfer1::IRuntime, InferDeleter> runtime{
    nvinfer1::createInferRuntime(logger)};
if (runtime == nullptr)
{
    std::cerr << "Failed to create the runtime." << std::endl;
    return EXIT_FAILURE;
}
std::ifstream engine_file{engine_file_path, std::ios::binary};
if (!engine_file)
{
    std::cerr << "Failed to open the engine file." << std::endl;
    return EXIT_FAILURE;
}
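The step that follows reads the engine file into memory, deserializes it into an engine, and creates an execution context. The sketch below is a simplified reconstruction of that step under the same variable naming; the allocation of the input and output host and device buffers and the query of the IO tensor names are omitted here.

// Read the serialized engine into host memory.
engine_file.seekg(0, std::ios::end);
size_t const engine_file_size{static_cast<size_t>(engine_file.tellg())};
engine_file.seekg(0, std::ios::beg);
std::vector<char> engine_data(engine_file_size);
engine_file.read(engine_data.data(), engine_file_size);

// Deserialize the engine and create an execution context. The IdentityConv
// plugin creator is found automatically because the plugin library was
// serialized into the engine.
std::unique_ptr<nvinfer1::ICudaEngine, InferDeleter> engine{
    runtime->deserializeCudaEngine(engine_data.data(), engine_data.size())};
if (engine == nullptr)
{
    std::cerr << "Failed to deserialize the engine." << std::endl;
    return EXIT_FAILURE;
}
std::unique_ptr<nvinfer1::IExecutionContext, InferDeleter> context{
    engine->createExecutionContext()};
if (context == nullptr)
{
    std::cerr << "Failed to create the execution context." << std::endl;
    return EXIT_FAILURE;
}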
// Create random input values.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
    size_t const tensor_size{input_tensor_sizes.at(i)};
    create_random_data(static_cast<float*>(input_tensor_host_buffers.at(i)),
                       tensor_size);
}
// Copy input data from host to device.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
    size_t const tensor_size_bytes{input_tensor_sizes.at(i) *
                                   expected_dtype_byte_size};
    CHECK_CUDA_ERROR(cudaMemcpy(input_tensor_device_buffers.at(i),
                                input_tensor_host_buffers.at(i),
                                tensor_size_bytes, cudaMemcpyHostToDevice));
}
// Bind IO tensor buffers to the execution context.
for (size_t i{0U}; i < input_tensor_device_buffers.size(); ++i)
{
    char const* const tensor_name{input_tensor_names.at(i)};
    context->setTensorAddress(tensor_name, input_tensor_device_buffers.at(i));
}
for (size_t i{0U}; i < output_tensor_device_buffers.size(); ++i)
{
    char const* const tensor_name{output_tensor_names.at(i)};
    context->setTensorAddress(tensor_name, output_tensor_device_buffers.at(i));
}
// Run inference a couple of times.
size_t const num_iterations{8U};
for (size_t i{0U}; i < num_iterations; ++i)
{
    bool const status{context->enqueueV3(stream)};
    if (!status)
    {
        std::cerr << "Failed to run inference." << std::endl;
        return EXIT_FAILURE;
    }
}
// Copy output data from device to host.
for (size_t i{0U}; i < output_tensor_host_buffers.size(); ++i)
{
    size_t const tensor_size_bytes{output_tensor_sizes.at(i) *
                                   expected_dtype_byte_size};
    CHECK_CUDA_ERROR(cudaMemcpy(output_tensor_host_buffers.at(i),
                                output_tensor_device_buffers.at(i),
                                tensor_size_bytes, cudaMemcpyDeviceToHost));
}
// Verify the output given it's an identity neural network.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
    if (input_tensor_sizes.at(i) != output_tensor_sizes.at(i))
    {
        std::cerr << "Input and output tensor sizes do not match." << std::endl;
        return EXIT_FAILURE;
    }
    if (!all_close(static_cast<float*>(input_tensor_host_buffers.at(i)),
                   static_cast<float*>(output_tensor_host_buffers.at(i)),
                   input_tensor_sizes.at(i), rtol, atol))
    {
        std::cerr << "Input and output tensor values do not match."
                  << std::endl;
        return EXIT_FAILURE;
    }
}
std::cout << "Successfully verified the output." << std::endl;
// Release resources.
CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
    CHECK_CUDA_ERROR(cudaFreeHost(input_tensor_host_buffers.at(i)));
}
for (size_t i{0U}; i < input_tensor_device_buffers.size(); ++i)
{
    CHECK_CUDA_ERROR(cudaFree(input_tensor_device_buffers.at(i)));
}
for (size_t i{0U}; i < output_tensor_host_buffers.size(); ++i)
{
    CHECK_CUDA_ERROR(cudaFreeHost(output_tensor_host_buffers.at(i)));
}
for (size_t i{0U}; i < output_tensor_device_buffers.size(); ++i)
{
    CHECK_CUDA_ERROR(cudaFree(output_tensor_device_buffers.at(i)));
}

return EXIT_SUCCESS;
}
Because the output tensor values match the input tensor values, we have successfully verified the implementation and integration of the custom plugin.