TensorRT Custom Plugin Example

Introduction

TensorRT is a high-performance deep learning inference SDK that accelerates deep learning inference on NVIDIA GPUs. It allows users to create custom plugins for neural network layers that are not natively supported by TensorRT.

In this blog post, I would like to demonstrate how to implement and integrate a custom plugin into TensorRT using a concrete and self-contained example.

Identity ONNX Model

To keep the custom plugin implementation and integration uncomplicated, we created a simple identity ONNX model that consists of three Conv nodes whose weights and attributes are orchestrated so that each convolution is an identity operation: every Conv uses a 1 x 1 kernel, zero padding, a number of groups equal to the number of channels, and all-ones weights, so each output element equals the corresponding input element. The second Conv node in the ONNX model is replaced with a custom node IdentityConv that is not defined in the ONNX operator set. Without a custom plugin for the IdentityConv node, a TensorRT engine cannot be created from the ONNX model.

create_identity_neural_network.py
# Create a neural network that consists of three identity convolutional layers.

import os
import numpy as np
import onnx
import onnx_graphsurgeon as gs


def main():

    opset_version = 13
    data_directory_path = "data"
    onnx_file_name = "identity_neural_network.onnx"
    onnx_file_path = os.path.join(data_directory_path, onnx_file_name)

    input_shape = (1, 3, 480, 960)
    input_data = np.random.rand(*input_shape).astype(np.float32)
    input_channels = input_shape[1]

    # Configure a dummy depthwise convolution whose 1 x 1 all-ones weights
    # make the convolution an identity operation.
    weights_shape = (input_channels, 1, 1, 1)
    num_groups = input_channels
    weights_data = np.ones(weights_shape, dtype=np.float32)

    # Generate the ONNX model.
    X0 = gs.Variable(name="X0", dtype=np.float32, shape=input_shape)
    W0 = gs.Constant(name="W0", values=weights_data)
    X1 = gs.Variable(name="X1", dtype=np.float32, shape=input_shape)
    W1 = gs.Constant(name="W1", values=weights_data)
    X2 = gs.Variable(name="X2", dtype=np.float32, shape=input_shape)
    W2 = gs.Constant(name="W2", values=weights_data)
    X3 = gs.Variable(name="X3", dtype=np.float32, shape=input_shape)

    node_1 = gs.Node(name="Conv-1", op="Conv",
                     inputs=[X0, W0],
                     outputs=[X1],
                     attrs={
                         "kernel_shape": [1, 1],
                         "strides": [1, 1],
                         "pads": [0, 0, 0, 0],
                         "group": num_groups
                     })
    # Use a custom operator IdentityConv instead.
    # This operator is not defined by ONNX and cannot be parsed by the ONNX
    # parser without a custom plugin.
    # node_2 = gs.Node(name="Conv-2", op="Conv",
    #                  inputs=[X1, W1],
    #                  outputs=[X2],
    #                  attrs={
    #                      "kernel_shape": [1, 1],
    #                      "strides": [1, 1],
    #                      "pads": [0, 0, 0, 0],
    #                      "group": num_groups
    #                  })
    node_2 = gs.Node(name="Conv-2", op="IdentityConv",
                     inputs=[X1, W1],
                     outputs=[X2],
                     attrs={
                         "kernel_shape": [1, 1],
                         "strides": [1, 1],
                         "pads": [0, 0, 0, 0],
                         "group": num_groups
                     })
    node_3 = gs.Node(name="Conv-3", op="Conv",
                     inputs=[X2, W2],
                     outputs=[X3],
                     attrs={
                         "kernel_shape": [1, 1],
                         "strides": [1, 1],
                         "pads": [0, 0, 0, 0],
                         "group": num_groups
                     })

    graph = gs.Graph(nodes=[node_1, node_2, node_3],
                     inputs=[X0],
                     outputs=[X3],
                     opset=opset_version)
    model = gs.export_onnx(graph)
    # Shape inference does not quite work here because of the custom operator.
    # model = onnx.shape_inference.infer_shapes(model)
    onnx.save(model, onnx_file_path)


if __name__ == '__main__':

    main()

IdentityConv Custom Plugin Implementation

The custom plugin class has to derive from either the nvinfer1::IPluginV2IOExt or the nvinfer1::IPluginV2DynamicExt class. The older nvinfer1::IPluginV2Ext class has been deprecated and should not be used. In this example, the nvinfer1::IPluginV2IOExt class is used.

IdentityConvPlugin.h
#ifndef TENSORRT_IDENTITY_CONV_PLUGIN_H
#define TENSORRT_IDENTITY_CONV_PLUGIN_H

#include <string>
#include <vector>

#include <cuda_runtime.h>

#include <NvInferRuntimePlugin.h>

constexpr char const* const kIDENTITY_CONV_PLUGIN_NAME{"IdentityConv"};
constexpr char const* const kIDENTITY_CONV_PLUGIN_VERSION{"1"};

namespace nvinfer1
{
namespace plugin
{

struct IdentityConvParameters
{
int32_t group;
nvinfer1::DataType dtype;
int32_t channelSize;
int32_t height;
int32_t width;
size_t dtypeBytes;
};

class IdentityConv : public nvinfer1::IPluginV2IOExt
{
public:
IdentityConv(IdentityConvParameters params);

IdentityConv(void const* data, size_t length);

~IdentityConv() override = default;

int32_t getNbOutputs() const noexcept override;

nvinfer1::Dims getOutputDimensions(int32_t index,
nvinfer1::Dims const* inputs,
int32_t nbInputDims) noexcept override;

int32_t initialize() noexcept override;

void terminate() noexcept override;

size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept override;

int32_t enqueue(int32_t batchSize, void const* const* inputs,
void* const* outputs, void* workspace,
cudaStream_t stream) noexcept override;

size_t getSerializationSize() const noexcept override;

void serialize(void* buffer) const noexcept override;

void configurePlugin(nvinfer1::PluginTensorDesc const* in, int32_t nbInput,
nvinfer1::PluginTensorDesc const* out,
int32_t nbOutput) noexcept override;

bool supportsFormatCombination(int32_t pos,
nvinfer1::PluginTensorDesc const* inOut,
int32_t nbInputs,
int32_t nbOutputs) const noexcept override;

char const* getPluginType() const noexcept override;

char const* getPluginVersion() const noexcept override;

void destroy() noexcept override;

IPluginV2IOExt* clone() const noexcept override;

nvinfer1::DataType
getOutputDataType(int32_t index, nvinfer1::DataType const* inputType,
int32_t nbInputs) const noexcept override;

void setPluginNamespace(char const* pluginNamespace) noexcept override;

char const* getPluginNamespace() const noexcept override;

bool isOutputBroadcastAcrossBatch(int32_t outputIndex,
bool const* inputIsBroadcasted,
int32_t nbInputs) const noexcept override;

bool
canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept override;

private:
void deserialize(uint8_t const* data, size_t length);

// TensorRT plugin parameters.
IdentityConvParameters mParams;

char const* mPluginNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // TENSORRT_IDENTITY_CONV_PLUGIN_H

To perform the IdentityConv operation, the custom plugin class has to override the nvinfer1::IPluginV2IOExt::enqueue method. In our case, we simply copy the input tensor to the output tensor using cudaMemcpyAsync.

IdentityConvPlugin.cpp
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <vector>

#include <NvInferRuntime.h>
#include <NvInferRuntimePlugin.h>

#include "IdentityConvPlugin.h"
#include "PluginUtils.h"

namespace nvinfer1
{
namespace plugin
{

// Write values into buffer
template <typename Type, typename BufferType>
void write(BufferType*& buffer, Type const& val)
{
static_assert(sizeof(BufferType) == 1, "BufferType must be a 1 byte type.");
std::memcpy(buffer, &val, sizeof(Type));
buffer += sizeof(Type);
}

// Read values from buffer
template <typename OutType, typename BufferType>
OutType read(BufferType const*& buffer)
{
static_assert(sizeof(BufferType) == 1, "BufferType must be a 1 byte type.");
OutType val{};
std::memcpy(&val, static_cast<void const*>(buffer), sizeof(OutType));
buffer += sizeof(OutType);
return val;
}

IdentityConv::IdentityConv(IdentityConvParameters params) : mParams{params} {}

IdentityConv::IdentityConv(void const* data, size_t length)
{
deserialize(static_cast<uint8_t const*>(data), length);
}

void IdentityConv::deserialize(uint8_t const* data, size_t length)
{
// In our simple use case, even though there is no parameter used for this
// plugin, we deserialize and serialize some attributes for demonstration
// purposes.
uint8_t const* d{data};
mParams.group = read<int32_t>(d);
mParams.dtype = read<nvinfer1::DataType>(d);
mParams.channelSize = read<int32_t>(d);
mParams.height = read<int32_t>(d);
mParams.width = read<int32_t>(d);
mParams.dtypeBytes = read<size_t>(d);
PLUGIN_ASSERT(d == data + length);
}

int32_t IdentityConv::getNbOutputs() const noexcept { return 1; }

void IdentityConv::configurePlugin(nvinfer1::PluginTensorDesc const* in,
int32_t nbInput,
nvinfer1::PluginTensorDesc const* out,
int32_t nbOutput) noexcept
{
// Communicates the number of inputs and outputs, dimensions, and datatypes
// of all inputs and outputs, broadcast information for all inputs and
// outputs, the chosen plugin format, and maximum batch size. At this point,
// the plugin sets up its internal state and selects the most appropriate
// algorithm and data structures for the given configuration. Note: Resource
// allocation is not allowed in this API because it causes a resource leak.

// This member function will only be called during engine build time.

// Validate input arguments.
PLUGIN_ASSERT(nbInput == 2);
PLUGIN_ASSERT(nbOutput == 1);
PLUGIN_ASSERT(in[0].dims.nbDims == 3);
PLUGIN_ASSERT(out[0].dims.nbDims == 3);
PLUGIN_ASSERT(in[0].dims.d[0] == out[0].dims.d[0]);
PLUGIN_ASSERT(in[0].dims.d[1] == out[0].dims.d[1]);
PLUGIN_ASSERT(in[0].dims.d[2] == out[0].dims.d[2]);
PLUGIN_ASSERT(in[0].type == out[0].type);

mParams.dtype = in[0].type;
mParams.channelSize = in[0].dims.d[0];
mParams.height = in[0].dims.d[1];
mParams.width = in[0].dims.d[2];

if (mParams.dtype == nvinfer1::DataType::kINT8)
{
mParams.dtypeBytes = 1;
}
else if (mParams.dtype == nvinfer1::DataType::kHALF)
{
mParams.dtypeBytes = 2;
}
else if (mParams.dtype == nvinfer1::DataType::kFLOAT)
{
mParams.dtypeBytes = 4;
}
else
{
PLUGIN_ASSERT(false);
}
}

int32_t IdentityConv::initialize() noexcept
{
// The configuration is known at this time, and the inference engine is
// being created, so the plugin can set up its internal data structures and
// prepare for execution. Such setup might include initializing libraries,
// allocating memory, etc. In our case, we don't need to prepare anything.
return 0;
}

void IdentityConv::terminate() noexcept
{
// The engine context is destroyed, and all the resources held by the plugin
// must be released.
}

nvinfer1::Dims IdentityConv::getOutputDimensions(int32_t index,
nvinfer1::Dims const* inputs,
int32_t nbInputDims) noexcept
{
// Even though non-IPluginV2DynamicExt plugins are compatible with explicit
// batch mode networks, their implementation must be independent of the type
// of network (implicit/explicit batch mode) in which it is expected to be
// used. As such, when using such plugins in explicit batch mode networks:
// * The leading dimension of the first input (before being passed to the
// plugin) is inferred to be the batch dimension.
// * TensorRT pops this first dimension identified above before inputs are
// passed to the plugin, and pushes it to the front of any outputs emitted
// by the plugin. This means that the batch dimension must not be specified
// in getOutputDimensions.
PLUGIN_ASSERT(index == 0);
PLUGIN_ASSERT(nbInputDims == 2);
PLUGIN_ASSERT(inputs != nullptr);
// CHW
nvinfer1::Dims dimsOutput;
PLUGIN_ASSERT(inputs[0].nbDims == 3);
// Identity operation.
// Just copy the dimensions from the input tensor.
dimsOutput.nbDims = inputs[0].nbDims;
dimsOutput.d[0] = inputs[0].d[0];
dimsOutput.d[1] = inputs[0].d[1];
dimsOutput.d[2] = inputs[0].d[2];

return dimsOutput;
}

size_t IdentityConv::getWorkspaceSize(int32_t maxBatchSize) const noexcept
{
// No scratch space is required for this plugin.
return 0;
}

size_t IdentityConv::getSerializationSize() const noexcept
{
// return sizeof(IdentityConvParameters);
return sizeof(int32_t) * 4 + sizeof(nvinfer1::DataType) + sizeof(size_t);
}

void IdentityConv::serialize(void* buffer) const noexcept
{
char* d{reinterpret_cast<char*>(buffer)};
char* const a{d};
// Be cautious, the order has to match deserialization.
write(d, mParams.group);
write(d, mParams.dtype);
write(d, mParams.channelSize);
write(d, mParams.height);
write(d, mParams.width);
write(d, mParams.dtypeBytes);
PLUGIN_ASSERT(d == a + getSerializationSize());
}

bool IdentityConv::supportsFormatCombination(
int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs,
int32_t nbOutputs) const noexcept
{
// For this method inputs are numbered 0..(nbInputs-1) and outputs are
// numbered nbInputs..(nbInputs+nbOutputs-1). Using this numbering, pos is
// an index into InOut, where 0 <= pos < nbInputs+nbOutputs.
PLUGIN_ASSERT(nbInputs == 2 && nbOutputs == 1 &&
pos < nbInputs + nbOutputs);
bool isValidCombination = false;

// Suppose we support only a limited number of format configurations.
isValidCombination |=
(inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kFLOAT);
isValidCombination |=
(inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF);
// Make sure the input tensor and output tensor types and formats are same.
isValidCombination &=
(pos < nbInputs || (inOut[pos].format == inOut[0].format &&
inOut[pos].type == inOut[0].type));

return isValidCombination;
}

char const* IdentityConv::getPluginType() const noexcept
{
return kIDENTITY_CONV_PLUGIN_NAME;
}

char const* IdentityConv::getPluginVersion() const noexcept
{
return kIDENTITY_CONV_PLUGIN_VERSION;
}

void IdentityConv::destroy() noexcept { delete this; }

nvinfer1::IPluginV2IOExt* IdentityConv::clone() const noexcept
{
// It's possible to encounter errors during cloning.
// For example, if the memory to allocate is insufficient, exceptions can be
// thrown.
try
{
IPluginV2IOExt* const plugin{new IdentityConv{mParams}};
plugin->setPluginNamespace(mPluginNamespace);
return plugin;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

void IdentityConv::setPluginNamespace(char const* pluginNamespace) noexcept
{
mPluginNamespace = pluginNamespace;
}

char const* IdentityConv::getPluginNamespace() const noexcept
{
return mPluginNamespace;
}

nvinfer1::DataType
IdentityConv::getOutputDataType(int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept
{
// One output.
PLUGIN_ASSERT(index == 0);
PLUGIN_ASSERT(nbInputs == 2);
// The output type is the same as the input type.
return inputTypes[0];
}

bool IdentityConv::isOutputBroadcastAcrossBatch(int32_t outputIndex,
bool const* inputIsBroadcasted,
int32_t nbInputs) const noexcept
{
return false;
}

bool IdentityConv::canBroadcastInputAcrossBatch(
int32_t inputIndex) const noexcept
{
return false;
}

int32_t IdentityConv::enqueue(int32_t batchSize, void const* const* inputs,
void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
size_t const inputSize{static_cast<size_t>(batchSize * mParams.channelSize *
mParams.height * mParams.width)};
size_t const inputSizeBytes{inputSize * mParams.dtypeBytes};
cudaError_t const status{cudaMemcpyAsync(outputs[0], inputs[0],
inputSizeBytes,
cudaMemcpyDeviceToDevice, stream)};
return status;
}

} // namespace plugin
} // namespace nvinfer1

The implementation above also demonstrates a few other details, such as how to declare the supported plugin format combinations and how to serialize and deserialize the plugin parameters.

IdentityConv Custom Plugin Creator Implementation

TensorRT will call the nvinfer1::IPluginCreator::createPlugin method to create the custom plugin instance. Therefore, we have to create a custom plugin creator class that is derived from the nvinfer1::IPluginCreator class.

IdentityConvPluginCreator.h
#ifndef TENSORRT_IDENTITY_CONV_PLUGIN_CREATOR_H
#define TENSORRT_IDENTITY_CONV_PLUGIN_CREATOR_H

#include <vector>

#include <NvInferRuntime.h>

namespace nvinfer1
{
namespace plugin
{

class BaseCreator : public nvinfer1::IPluginCreator
{
public:
void setPluginNamespace(char const* libNamespace) noexcept override
{
mNamespace = libNamespace;
}

char const* getPluginNamespace() const noexcept override
{
return mNamespace.c_str();
}

protected:
std::string mNamespace;
};

// Plugin factory class.
class IdentityConvCreator : public BaseCreator
{
public:
IdentityConvCreator();

~IdentityConvCreator() override = default;

char const* getPluginName() const noexcept override;

char const* getPluginVersion() const noexcept override;

nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

nvinfer1::IPluginV2IOExt*
createPlugin(char const* name,
nvinfer1::PluginFieldCollection const* fc) noexcept override;

nvinfer1::IPluginV2IOExt*
deserializePlugin(char const* name, void const* serialData,
size_t serialLength) noexcept override;

private:
nvinfer1::PluginFieldCollection mFC;
std::vector<nvinfer1::PluginField> mPluginAttributes;

protected:
std::string mNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // TENSORRT_IDENTITY_CONV_PLUGIN_CREATOR_H

TensorRT supports both static and dynamic registration of custom plugins. For static registration, the REGISTER_TENSORRT_PLUGIN macro registers the custom plugin creator class with the global plugin registry when the plugin library is loaded. In this example, we use dynamic registration instead, so the REGISTER_TENSORRT_PLUGIN macro is commented out.
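For reference, a minimal sketch of what static registration would look like, assuming the macro is placed in a translation unit of the plugin library: the macro instantiates a registrar object that adds the creator to the global plugin registry when the library is loaded, so no separate entry point is needed.

// Static registration sketch (not used in this example).
// Placing this in a translation unit of the plugin library registers the
// creator with the global TensorRT plugin registry at library load time.
#include <NvInferRuntime.h>

#include "IdentityConvPluginCreator.h"

namespace nvinfer1
{
namespace plugin
{

REGISTER_TENSORRT_PLUGIN(IdentityConvCreator);

} // namespace plugin
} // namespace nvinfer1

With static registration, the application could then look up the creator in the global registry, for example via getPluginRegistry()->getPluginCreator("IdentityConv", "1"), instead of loading the library with loadLibrary.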

IdentityConvPluginCreator.cpp
#include <exception>
#include <iostream>
#include <mutex>

#include <NvInferRuntimePlugin.h>

#include "IdentityConvPlugin.h"
#include "IdentityConvPluginCreator.h"
#include "PluginUtils.h"

namespace nvinfer1
{
namespace plugin
{

// This is not needed for plugin dynamic registration.
// REGISTER_TENSORRT_PLUGIN(IdentityConvCreator);

// Plugin creator
IdentityConvCreator::IdentityConvCreator()
{
// Declare the ONNX attributes that the ONNX parser will collect from the
// ONNX model that contains the IdentityConv node.

// In our dummy case,
// attrs={
// "kernel_shape": [1, 1],
// "strides": [1, 1],
// "pads": [0, 0, 0, 0],
// "group": num_groups
// }

mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField(
"kernel_shape", nullptr, PluginFieldType::kINT32, 2));
mPluginAttributes.emplace_back(
nvinfer1::PluginField("strides", nullptr, PluginFieldType::kINT32, 2));
mPluginAttributes.emplace_back(
nvinfer1::PluginField("pads", nullptr, PluginFieldType::kINT32, 4));
mPluginAttributes.emplace_back(
nvinfer1::PluginField("group", nullptr, PluginFieldType::kINT32, 1));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

char const* IdentityConvCreator::getPluginName() const noexcept
{
return kIDENTITY_CONV_PLUGIN_NAME;
}

char const* IdentityConvCreator::getPluginVersion() const noexcept
{
return kIDENTITY_CONV_PLUGIN_VERSION;
}

nvinfer1::PluginFieldCollection const*
IdentityConvCreator::getFieldNames() noexcept
{
return &mFC;
}

nvinfer1::IPluginV2IOExt* IdentityConvCreator::createPlugin(
char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept
{
// The attributes from the ONNX node will be parsed and passed via fc.
try
{
nvinfer1::PluginField const* fields{fc->fields};
int32_t nbFields{fc->nbFields};

PLUGIN_VALIDATE(nbFields == 4);

std::vector<int32_t> kernelShape{};
std::vector<int32_t> strides{};
std::vector<int32_t> pads{};
int32_t group{};

for (int32_t i{0}; i < nbFields; ++i)
{
char const* attrName = fields[i].name;
if (!strcmp(attrName, "kernel_shape"))
{
PLUGIN_VALIDATE(fields[i].type ==
nvinfer1::PluginFieldType::kINT32);
int32_t const* const kernelShapeData{
static_cast<int32_t const*>(fields[i].data)};
for (int32_t j{0}; j < fields[i].length; ++j)
{
kernelShape.push_back(kernelShapeData[j]);
}
}
if (!strcmp(attrName, "strides"))
{
PLUGIN_VALIDATE(fields[i].type ==
nvinfer1::PluginFieldType::kINT32);
int32_t const* const stridesData{
static_cast<int32_t const*>(fields[i].data)};
for (int32_t j{0}; j < fields[i].length; ++j)
{
strides.push_back(stridesData[j]);
}
}
if (!strcmp(attrName, "pads"))
{
PLUGIN_VALIDATE(fields[i].type ==
nvinfer1::PluginFieldType::kINT32);
int32_t const* const padsData{
static_cast<int32_t const*>(fields[i].data)};
for (int32_t j{0}; j < fields[i].length; ++j)
{
pads.push_back(padsData[j]);
}
}
if (!strcmp(attrName, "group"))
{
PLUGIN_VALIDATE(fields[i].type ==
nvinfer1::PluginFieldType::kINT32);
PLUGIN_VALIDATE(fields[i].length == 1);
group = *(static_cast<int32_t const*>(fields[i].data));
}
}

// Log the attributes parsed from ONNX node.
std::stringstream ss;
ss << "Plugin Attributes:";
logInfo(ss.str().c_str());

ss.str("");
ss << "kernel_shape: ";
for (auto const& val : kernelShape)
{
ss << val << " ";
}
logInfo(ss.str().c_str());

ss.str("");
ss << "strides: ";
for (auto const& val : strides)
{
ss << val << " ";
}
logInfo(ss.str().c_str());

ss.str("");
ss << "pads: ";
for (auto const& val : pads)
{
ss << val << " ";
}
logInfo(ss.str().c_str());

ss.str("");
ss << "group: " << group;
logInfo(ss.str().c_str());

IdentityConvParameters const params{.group = group};

IdentityConv* const plugin{new IdentityConv{params}};
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

nvinfer1::IPluginV2IOExt*
IdentityConvCreator::deserializePlugin(char const* name, void const* serialData,
size_t serialLength) noexcept
{
try
{
IdentityConv* plugin = new IdentityConv{serialData, serialLength};
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

} // namespace plugin
} // namespace nvinfer1

The implementation above also demonstrates how the attributes of the custom ONNX node, collected by the TensorRT ONNX parser, are handed to the plugin creator through the nvinfer1::PluginFieldCollection in createPlugin.

Expose IdentityConv Custom Plugin Creator to TensorRT

Because the custom plugin library will be loaded dynamically by TensorRT, we have to expose the plugin creator to TensorRT from the dynamic library. The setLoggerFinder and getPluginCreators functions must be implemented with C linkage so that TensorRT can find and call them when the plugin library is loaded.

PluginRegistration.h
#ifndef TENSORRT_PLUGIN_REGISTRATION_H
#define TENSORRT_PLUGIN_REGISTRATION_H

#include <NvInferRuntime.h>

// These are the functions that TensorRT library will call at the runtime.

extern "C" void setLoggerFinder(nvinfer1::ILoggerFinder* finder);

extern "C" nvinfer1::IPluginCreator* const*
getPluginCreators(int32_t& nbCreators);

#endif // TENSORRT_PLUGIN_REGISTRATION_H

If we have multiple custom plugins, we can register them all in the getPluginCreators function. In this example, we only have one custom plugin.
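For illustration, a hypothetical getPluginCreators that exposes two creators could look like the sketch below, where SomeOtherPluginCreator is a made-up placeholder for a second plugin creator class.

// Hypothetical sketch: exposing two plugin creators from the same library.
// SomeOtherPluginCreator is a placeholder and does not exist in this example.
extern "C" nvinfer1::IPluginCreator* const*
getPluginCreators(int32_t& nbCreators)
{
    static nvinfer1::plugin::IdentityConvCreator identityConvCreator{};
    static nvinfer1::plugin::SomeOtherPluginCreator someOtherPluginCreator{};
    static nvinfer1::IPluginCreator* const pluginCreatorList[] = {
        &identityConvCreator, &someOtherPluginCreator};
    nbCreators = 2;
    return pluginCreatorList;
}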

PluginRegistration.cpp
#include <iostream>
#include <mutex>

#include <NvInferRuntime.h>

#include "IdentityConvPluginCreator.h"

class ThreadSafeLoggerFinder
{
public:
ThreadSafeLoggerFinder() = default;

// Set the logger finder.
void setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder == nullptr && finder != nullptr)
{
mLoggerFinder = finder;
}
}

// Get the logger.
nvinfer1::ILogger* getLogger() noexcept
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder != nullptr)
{
return mLoggerFinder->findLogger();
}
return nullptr;
}

private:
nvinfer1::ILoggerFinder* mLoggerFinder{nullptr};
std::mutex mMutex;
};

ThreadSafeLoggerFinder gLoggerFinder;

// Not exposing this function to the user to get the plugin logger for the
// moment. Can switch the plugin logger to this in the future.

// ILogger* getPluginLogger()
// {
// return gLoggerFinder.getLogger();
// }

extern "C" void setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
gLoggerFinder.setLoggerFinder(finder);
}

extern "C" nvinfer1::IPluginCreator* const*
getPluginCreators(int32_t& nbCreators)
{
nbCreators = 1;
static nvinfer1::plugin::IdentityConvCreator identityConvCreator{};
static nvinfer1::IPluginCreator* const pluginCreatorList[] = {
&identityConvCreator};
return pluginCreatorList;
}

Build Engine With Custom Plugin

The custom plugin library can be dynamically loaded and registered by TensorRT via the nvinfer1::IPluginRegistry::loadLibrary method. The ONNX parser can then map the custom ONNX node to the custom plugin so that TensorRT can run inference for it.

The custom plugin library can also be serialized into the TensorRT engine file via the nvinfer1::IBuilderConfig::setPluginsToSerialize method so that it does not have to be loaded again at inference time.
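Distilled from the full program below, the two plugin-related calls are loading the plugin library into the builder's plugin registry and marking the same library for serialization into the engine:

// Key plugin-related calls, excerpted from build_engine.cpp below
// (error handling omitted).
char const* const plugin_library_path_c_str{plugin_library_path.c_str()};
// Dynamically load the plugin library and register its plugin creators.
void* const plugin_handle{
    builder->getPluginRegistry().loadLibrary(plugin_library_path_c_str)};
// Embed the plugin library into the serialized engine so that it does not
// have to be loaded again at inference time.
config->setPluginsToSerialize(&plugin_library_path_c_str, 1);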

build_engine.cpp
// Create a TensorRT engine building program that builds an engine from an ONNX
// file and uses a custom plugin.

#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>

#include <NvInfer.h>
#include <NvOnnxParser.h>


class CustomLogger : public nvinfer1::ILogger
{
void log(nvinfer1::ILogger::Severity severity,
const char* msg) noexcept override
{
if (severity <= nvinfer1::ILogger::Severity::kINFO)
{
std::cout << msg << std::endl;
}
}
};

struct InferDeleter
{
template <typename T>
void operator()(T* obj) const
{
delete obj;
}
};

int main(int argc, char** argv)
{
CustomLogger logger{};

std::string const data_dir_path{"data"};
std::string const onnx_file_name{"identity_neural_network.onnx"};
std::string const engine_file_name{"identity_neural_network.engine"};
std::string const onnx_file_path{data_dir_path + "/" + onnx_file_name};
std::string const engine_file_path{data_dir_path + "/" + engine_file_name};
std::string const plugin_library_name{"libidentity_conv.so"};
std::string const plugin_library_dir_path{"build/src"};
std::string const plugin_library_path{plugin_library_dir_path + "/" +
plugin_library_name};
char const* const plugin_library_path_c_str{plugin_library_path.c_str()};

// Create the builder.
std::unique_ptr<nvinfer1::IBuilder, InferDeleter> builder{
nvinfer1::createInferBuilder(logger)};
if (builder == nullptr)
{
std::cerr << "Failed to create the builder." << std::endl;
return EXIT_FAILURE;
}
void* const plugin_handle{
builder->getPluginRegistry().loadLibrary(plugin_library_path.c_str())};
if (plugin_handle == nullptr)
{
std::cerr << "Failed to load the plugin library." << std::endl;
return EXIT_FAILURE;
}

// Create the network.
uint32_t const flag{
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)};
std::unique_ptr<nvinfer1::INetworkDefinition, InferDeleter> network{
builder->createNetworkV2(flag)};
if (network == nullptr)
{
std::cerr << "Failed to create the network." << std::endl;
return EXIT_FAILURE;
}

// Create the parser.
std::unique_ptr<nvonnxparser::IParser, InferDeleter> parser{
nvonnxparser::createParser(*network, logger)};
if (parser == nullptr)
{
std::cerr << "Failed to create the parser." << std::endl;
return EXIT_FAILURE;
}
parser->parseFromFile(
onnx_file_path.c_str(),
static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
for (int32_t i = 0; i < parser->getNbErrors(); ++i)
{
std::cout << parser->getError(i)->desc() << std::endl;
}

// Set the allowed IO tensor formats.
uint32_t const formats{
1U << static_cast<uint32_t>(nvinfer1::TensorFormat::kLINEAR)};
nvinfer1::DataType const dtype{nvinfer1::DataType::kFLOAT};
network->getInput(0)->setAllowedFormats(formats);
network->getInput(0)->setType(dtype);
network->getOutput(0)->setAllowedFormats(formats);
network->getOutput(0)->setType(dtype);

// Build the engine.
std::unique_ptr<nvinfer1::IBuilderConfig, InferDeleter> config{
builder->createBuilderConfig()};
if (config == nullptr)
{
std::cerr << "Failed to create the builder config." << std::endl;
return EXIT_FAILURE;
}
config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1U << 20);
config->setFlag(nvinfer1::BuilderFlag::kFP16);
config->setPluginsToSerialize(&plugin_library_path_c_str, 1);

std::unique_ptr<nvinfer1::IHostMemory, InferDeleter> serializedModel{
builder->buildSerializedNetwork(*network, *config)};

// Write the serialized engine to a file.
std::ofstream engineFile{engine_file_path.c_str(), std::ios::binary};
if (!engineFile.is_open())
{
std::cerr << "Failed to open the engine file." << std::endl;
return EXIT_FAILURE;
}
engineFile.write(static_cast<char const*>(serializedModel->data()),
serializedModel->size());
engineFile.close();

std::cout << "Successfully serialized the engine to the file: "
<< engine_file_path << std::endl;

return EXIT_SUCCESS;
}

Run Engine With Custom Plugin

Because the custom plugin library has been serialized into the TensorRT engine file, we don't need to load it at inference time. We can simply create the runtime and deserialize the engine from the engine file.
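If the plugin library had not been serialized into the engine, it would have to be loaded into the runtime's plugin registry before deserializing the engine. A minimal sketch, assuming a TensorRT version that provides nvinfer1::IRuntime::getPluginRegistry:

// Sketch: only needed when the plugin library was NOT embedded into the
// engine via setPluginsToSerialize. Assumes IRuntime::getPluginRegistry()
// is available.
std::string const plugin_library_path{"build/src/libidentity_conv.so"};
void* const plugin_handle{
    runtime->getPluginRegistry().loadLibrary(plugin_library_path.c_str())};
if (plugin_handle == nullptr)
{
    std::cerr << "Failed to load the plugin library." << std::endl;
    return EXIT_FAILURE;
}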

run_engine.cpp
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <vector>

#include <NvInfer.h>

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, const char* const func, const char* const file,
const int line)
{
if (err != cudaSuccess)
{
std::cerr << "CUDA Runtime Error at: " << file << ":" << line
<< std::endl;
std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
std::exit(EXIT_FAILURE);
}
}

#define CHECK_LAST_CUDA_ERROR() check_last(__FILE__, __LINE__)
void check_last(const char* const file, const int line)
{
cudaError_t const err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Runtime Error at: " << file << ":" << line
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}
}

class CustomLogger : public nvinfer1::ILogger
{
void log(nvinfer1::ILogger::Severity severity,
const char* msg) noexcept override
{
// suppress info-level messages
if (severity <= nvinfer1::ILogger::Severity::kINFO)
{
std::cout << msg << std::endl;
}
}
};

struct InferDeleter
{
template <typename T>
void operator()(T* obj) const
{
delete obj;
}
};

void create_random_data(float* data, size_t const size, unsigned int seed = 1U)
{
std::default_random_engine eng(seed);
std::uniform_int_distribution<int32_t> dis(-16, 16);
auto const rand = [&dis, &eng]() { return dis(eng); };
for (size_t i{0U}; i < size; ++i)
{
data[i] = static_cast<float>(rand());
}
}

bool all_close(float const* a, float const* b, size_t size, float rtol = 1e-5f,
float atol = 1e-8f)
{
for (size_t i{0U}; i < size; ++i)
{
float const diff{std::abs(a[i] - b[i])};
if (diff > (atol + rtol * std::abs(b[i])))
{
std::cout << "a[" << i << "]: " << a[i] << std::endl;
std::cout << "b[" << i << "]: " << b[i] << std::endl;
return false;
}
}
return true;
}

int main(int argc, char** argv)
{
CustomLogger logger{};

// The plugin has already been serialized with the engine.
// There is no need to load the plugin library.
std::string const data_dir_path{"data"};
std::string const engine_file_name{"identity_neural_network.engine"};
std::string const engine_file_path{data_dir_path + "/" + engine_file_name};

// Create CUDA stream.
cudaStream_t stream;
CHECK_CUDA_ERROR(cudaStreamCreate(&stream));

// The engine we built is FP32 NCHW IO.
nvinfer1::DataType const expected_dtype{nvinfer1::DataType::kFLOAT};
size_t const expected_dtype_byte_size{4U};
nvinfer1::TensorFormat const expected_format{
nvinfer1::TensorFormat::kLINEAR};

// IO tensor information and buffers.
std::vector<nvinfer1::Dims> input_tensor_shapes{};
std::vector<nvinfer1::Dims> output_tensor_shapes{};
std::vector<size_t> input_tensor_sizes{};
std::vector<size_t> output_tensor_sizes{};
std::vector<char const*> input_tensor_names{};
std::vector<char const*> output_tensor_names{};
std::vector<void*> input_tensor_host_buffers{};
std::vector<void*> input_tensor_device_buffers{};
std::vector<void*> output_tensor_host_buffers{};
std::vector<void*> output_tensor_device_buffers{};

// Error tolerance for unit test.
float const rtol{1e-5f};
float const atol{1e-8f};

// Deserialize the engine.
std::unique_ptr<nvinfer1::IRuntime, InferDeleter> runtime{
nvinfer1::createInferRuntime(logger)};
if (runtime == nullptr)
{
std::cerr << "Failed to create the runtime." << std::endl;
return EXIT_FAILURE;
}

std::ifstream engine_file{engine_file_path, std::ios::binary};
if (!engine_file)
{
std::cerr << "Failed to open the engine file." << std::endl;
return EXIT_FAILURE;
}

engine_file.seekg(0, std::ios::end);
size_t const engine_file_size{static_cast<size_t>(engine_file.tellg())};
engine_file.seekg(0, std::ios::beg);

std::unique_ptr<char[]> engine_data{new char[engine_file_size]};
engine_file.read(engine_data.get(), engine_file_size);

std::unique_ptr<nvinfer1::ICudaEngine, InferDeleter> engine{
runtime->deserializeCudaEngine(engine_data.get(), engine_file_size)};
if (engine == nullptr)
{
std::cerr << "Failed to deserialize the engine." << std::endl;
return EXIT_FAILURE;
}

// Create the execution context.
std::unique_ptr<nvinfer1::IExecutionContext, InferDeleter> context{
engine->createExecutionContext()};
if (context == nullptr)
{
std::cerr << "Failed to create the execution context." << std::endl;
return EXIT_FAILURE;
}

// Check the number of IO tensors.
int32_t const num_io_tensors{engine->getNbIOTensors()};
std::cout << "Number of IO Tensors: " << num_io_tensors << std::endl;
for (int32_t i{0}; i < num_io_tensors; ++i)
{
char const* const tensor_name{engine->getIOTensorName(i)};
std::cout << "Tensor name: " << tensor_name << std::endl;
nvinfer1::TensorIOMode const io_mode{
engine->getTensorIOMode(tensor_name)};
nvinfer1::DataType const dtype{engine->getTensorDataType(tensor_name)};
if (dtype != expected_dtype)
{
std::cerr << "Invalid data type." << std::endl;
return EXIT_FAILURE;
}
nvinfer1::TensorFormat const format{
engine->getTensorFormat(tensor_name)};
if (format != expected_format)
{
std::cerr << "Invalid tensor format." << std::endl;
return EXIT_FAILURE;
}
// Because the input and output shapes are static,
// there is no need to set the IO tensor shapes.
nvinfer1::Dims const shape{engine->getTensorShape(tensor_name)};
// Print out dims.
size_t tensor_size{1U};
std::cout << "Tensor Dims: ";
for (int32_t j{0}; j < shape.nbDims; ++j)
{
tensor_size *= shape.d[j];
std::cout << shape.d[j] << " ";
}
std::cout << std::endl;

// FP32 NCHW tensor format.
size_t tensor_size_bytes{tensor_size * expected_dtype_byte_size};

// Allocate host memory for the tensor.
void* tensor_host_buffer{nullptr};
CHECK_CUDA_ERROR(
cudaMallocHost(&tensor_host_buffer, tensor_size_bytes));
// Allocate device memory for the tensor.
void* tensor_device_buffer{nullptr};
CHECK_CUDA_ERROR(cudaMalloc(&tensor_device_buffer, tensor_size_bytes));

if (io_mode == nvinfer1::TensorIOMode::kINPUT)
{
input_tensor_host_buffers.push_back(tensor_host_buffer);
input_tensor_device_buffers.push_back(tensor_device_buffer);
input_tensor_shapes.push_back(shape);
input_tensor_sizes.push_back(tensor_size);
input_tensor_names.push_back(tensor_name);
}
else
{
output_tensor_host_buffers.push_back(tensor_host_buffer);
output_tensor_device_buffers.push_back(tensor_device_buffer);
output_tensor_shapes.push_back(shape);
output_tensor_sizes.push_back(tensor_size);
output_tensor_names.push_back(tensor_name);
}
}

// Create random input values.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
size_t const tensor_size{input_tensor_sizes.at(i)};
create_random_data(static_cast<float*>(input_tensor_host_buffers.at(i)),
tensor_size);
}

// Copy input data from host to device.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
size_t const tensor_size_bytes{input_tensor_sizes.at(i) *
expected_dtype_byte_size};
CHECK_CUDA_ERROR(cudaMemcpy(input_tensor_device_buffers.at(i),
input_tensor_host_buffers.at(i),
tensor_size_bytes, cudaMemcpyHostToDevice));
}

// Bind IO tensor buffers to the execution context.
for (size_t i{0U}; i < input_tensor_device_buffers.size(); ++i)
{
char const* const tensor_name{input_tensor_names.at(i)};
context->setTensorAddress(tensor_name,
input_tensor_device_buffers.at(i));
}
for (size_t i{0U}; i < output_tensor_device_buffers.size(); ++i)
{
char const* const tensor_name{output_tensor_names.at(i)};
context->setTensorAddress(tensor_name,
output_tensor_device_buffers.at(i));
}

// Run inference a couple of times.
size_t const num_iterations{8U};
for (size_t i{0U}; i < num_iterations; ++i)
{
bool const status{context->enqueueV3(stream)};
if (!status)
{
std::cerr << "Failed to run inference." << std::endl;
return EXIT_FAILURE;
}
}

// Synchronize.
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

// Copy output data from device to host.
for (size_t i{0U}; i < output_tensor_host_buffers.size(); ++i)
{
size_t const tensor_size_bytes{output_tensor_sizes.at(i) *
expected_dtype_byte_size};
CHECK_CUDA_ERROR(cudaMemcpy(output_tensor_host_buffers.at(i),
output_tensor_device_buffers.at(i),
tensor_size_bytes, cudaMemcpyDeviceToHost));
}

// Verify the output given it's an identity neural network.
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
if (input_tensor_sizes.at(i) != output_tensor_sizes.at(i))
{
std::cerr << "Input and output tensor sizes do not match."
<< std::endl;
return EXIT_FAILURE;
}
if (!all_close(static_cast<float*>(input_tensor_host_buffers.at(i)),
static_cast<float*>(output_tensor_host_buffers.at(i)),
input_tensor_sizes.at(i), rtol, atol))
{
std::cerr << "Input and output tensor values do not match."
<< std::endl;
return EXIT_FAILURE;
}
}

std::cout << "Successfully verified the output." << std::endl;

// Release resources.
CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
for (size_t i{0U}; i < input_tensor_host_buffers.size(); ++i)
{
CHECK_CUDA_ERROR(cudaFreeHost(input_tensor_host_buffers.at(i)));
}
for (size_t i{0U}; i < input_tensor_device_buffers.size(); ++i)
{
CHECK_CUDA_ERROR(cudaFree(input_tensor_device_buffers.at(i)));
}
for (size_t i{0U}; i < output_tensor_host_buffers.size(); ++i)
{
CHECK_CUDA_ERROR(cudaFreeHost(output_tensor_host_buffers.at(i)));
}
for (size_t i{0U}; i < output_tensor_device_buffers.size(); ++i)
{
CHECK_CUDA_ERROR(cudaFree(output_tensor_device_buffers.at(i)));
}
}

Because the output tensor values match the input tensor values, we have successfully verified the implementation and integration of the custom plugin.

Source Code

The source code of the example can be found in my GitHub repository “TensorRT Custom Plugin Example”.
