PyTorch Static Quantization

Introduction

Static quantization quantizes both the weights and the activations of a model. It also allows the user to fuse activations into preceding layers where possible. Unlike dynamic quantization, where the scales and zero points of the activations are computed on the fly during inference, static quantization determines the scales and zero points before inference using a representative dataset. Static quantization is therefore theoretically faster than dynamic quantization, while the model size and memory bandwidth consumption remain the same. As a result, statically quantized models are usually preferable to dynamically quantized models for inference.
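
To make the difference concrete, here is a minimal pure-Python sketch, assuming 8-bit unsigned affine quantization. The helper names (`qparams_from_range`, `quantize`) are hypothetical and are not PyTorch APIs. The static path fixes the range once from calibration data, whereas the dynamic path recomputes it from each incoming input.

```python
# Minimal sketch of static vs. dynamic quantization parameters.
# All helper names are illustrative; they are not part of PyTorch.

def qparams_from_range(x_min, x_max, q_min=0, q_max=255):
    """Derive (scale, zero_point) for affine quantization from a value range."""
    # Extend the range to include 0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (q_max - q_min) or 1.0
    zero_point = int(round(q_min - x_min / scale))
    return scale, max(q_min, min(q_max, zero_point))


def quantize(values, scale, zero_point, q_min=0, q_max=255):
    """Map floats to integers, clamping anything outside the representable range."""
    return [max(q_min, min(q_max, int(round(v / scale)) + zero_point)) for v in values]


# Static: the range is measured once on a representative dataset.
calibration_data = [-1.0, -0.25, 0.0, 0.5, 3.0]
static_scale, static_zp = qparams_from_range(min(calibration_data), max(calibration_data))

# Dynamic: the range is recomputed from each input at inference time.
inference_input = [0.0, 1.0, 5.0]
dynamic_scale, dynamic_zp = qparams_from_range(min(inference_input), max(inference_input))

static_q = quantize(inference_input, static_scale, static_zp)
dynamic_q = quantize(inference_input, dynamic_scale, dynamic_zp)
print("static :", static_q)   # 5.0 lies outside the calibrated range and is clamped
print("dynamic:", dynamic_q)
```

The static path does no range computation at inference time, which is where the speed advantage comes from. The cost is that inputs outside the calibrated range are clipped, so the calibration data should be representative.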

In this blog post, I would like to show how to perform static quantization using PyTorch. More details about the mathematical foundations of quantization for neural networks can be found in my article “Quantization for Neural Networks”.

PyTorch Static Quantization

Unlike TensorFlow 2.3.0, which supports integer quantization with arbitrary bitwidths from 2 to 16, PyTorch 1.7.0 only supports 8-bit integer quantization. The workflow can be as easy as loading a pre-trained floating point model and applying a static quantization wrapper. However, without layer fusion, this simple approach sometimes does not result in good model performance.

In this post, I would like to use the ResNet18 from the TorchVision models as an example. I will apply post-training quantization with and without layer fusion and compare the results. The source code can also be downloaded from GitHub.

ResNet has skip connection additions, and the TorchVision implementation performs these additions using +. We have to replace this + (equivalent to torch.add) with FloatFunctional.add (equivalent to torch.add followed by torch.nn.Identity) in the model definition. This is because the torch.nn.Identity serves as a flag for activation quantization. Without it, no activation quantization statistics are collected for the skip connection additions, resulting in erroneous quantization calibration.
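
The following pure-Python sketch illustrates why the addition needs its own quantization parameters. It is a conceptual model only: the helper names are hypothetical, zero points are assumed to be 0 for brevity, and the actual PyTorch kernels may be implemented differently. The two inputs of a skip connection generally have different scales, so their integer values cannot simply be added. The output scale `scale_out` plays the role of the statistics that calibration would collect for the addition.

```python
# Conceptual sketch of a quantized skip-connection addition.
# Helper names are illustrative; zero points are assumed to be 0.

def quantize(x, scale, q_min=0, q_max=255):
    return max(q_min, min(q_max, int(round(x / scale))))


def dequantize(q, scale):
    return scale * q


def quantized_add(q_a, scale_a, q_b, scale_b, scale_out):
    """Add two quantized values and requantize with the output scale."""
    real_sum = dequantize(q_a, scale_a) + dequantize(q_b, scale_b)
    return quantize(real_sum, scale_out)


scale_a, scale_b, scale_out = 0.0625, 0.5, 0.25  # scale_out comes from calibration
q_a = quantize(1.5, scale_a)    # identity branch
q_b = quantize(40.0, scale_b)   # residual branch

naive_sum = q_a + q_b           # mixes two different scales, so it is meaningless
q_sum = quantized_add(q_a, scale_a, q_b, scale_b, scale_out)

print("naive integer sum:", naive_sum)
print("requantized sum  :", q_sum, "->", dequantize(q_sum, scale_out))
```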

In addition, we would like to test layer fusion, such as fusing Conv2D, BatchNorm, and ReLU. To do layer fusion, each layer to be fused must be a distinct torch.nn.Module instance with its own name; the same module cannot be shared across multiple places in the model. Otherwise, the quantization calibration will be erroneous. For example, in an ordinary FP32 model, we could define one parameter-free relu = torch.nn.ReLU() and reuse this relu module everywhere. However, if we want to fuse some specific ReLUs, the ReLU modules have to be explicitly separated. So in this case, we have to define relu1 = torch.nn.ReLU(), relu2 = torch.nn.ReLU(), etc. Sometimes, layer fusion is compulsory, since some floating point layers, such as BatchNorm, have no corresponding quantized layer implementation.
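
As a rough illustration of what fusing Conv2D, BatchNorm, and ReLU means numerically, the sketch below folds the batch normalization statistics into the weight and bias of the preceding layer. For simplicity, the convolution is reduced to a single scalar weight and bias, and the function names are hypothetical. The folding uses the fixed running statistics of batch normalization, which is why the model has to be in evaluation mode before fusion.

```python
import math

# Sketch of folding inference-mode BatchNorm into the preceding layer.
# The "convolution" is reduced to y = w * x + b for a single channel.

def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_batch_norm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return the weight and bias of a layer equivalent to conv followed by BN."""
    factor = gamma / math.sqrt(var + eps)
    return w * factor, (b - mean) * factor + beta


w, b = 0.8, 0.1                                 # convolution parameters
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0    # BatchNorm parameters and running stats
w_fused, b_fused = fold_batch_norm(w, b, gamma, beta, mean, var)

inputs = [-2.0, 0.0, 1.0, 3.5]
reference = [max(0.0, batch_norm(w * x + b, gamma, beta, mean, var)) for x in inputs]
fused = [max(0.0, w_fused * x + b_fused) for x in inputs]

for x, r, f in zip(inputs, reference, fused):
    print(f"x={x:+.1f}  conv+bn+relu={r:.6f}  fused={f:.6f}")
```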

Taken together, the modified ResNet module definition resnet.py is as follows.

resnet.py
# Modified from
# https://github.com/pytorch/vision/blob/release/0.8.0/torchvision/models/resnet.py

import torch
from torch import Tensor
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        # Rename relu to relu1
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.skip_add = nn.quantized.FloatFunctional()
        # Remember to use two independent ReLU for layer fusion.
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Use FloatFunctional for addition for quantization compatibility
        # out += identity
        out = self.skip_add.add(identity, out)
        out = self.relu2(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        # Use one independent ReLU module per activation so that each
        # one can be fused separately. The forward pass previously called
        # an undefined self.relu after bn2.
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.skip_add = nn.quantized.FloatFunctional()
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # out += identity
        out = self.skip_add.add(identity, out)
        out = self.relu3(out)

        return out


class ResNet(nn.Module):

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)

The next steps are:

  1. Train a floating point model or load a pre-trained floating point model.
  2. Move the model to CPU and switch the model to evaluation mode.
  3. Apply layer fusion and check that the fused model is equivalent to the original model.
  4. Apply torch.quantization.QuantStub() and torch.quantization.DeQuantStub() to the inputs and outputs, respectively.
  5. Specify the quantization configuration, such as symmetric quantization or asymmetric quantization.
  6. Prepare the model for post-training calibration.
  7. Run post-training calibration.
  8. Convert the calibrated floating point model to a quantized integer model.
  9. [Optional] Verify the accuracy and the inference performance gain.
  10. Save the quantized integer model.

Note that step 4 asks PyTorch to explicitly collect quantization statistics for the model inputs and outputs. By default, PyTorch only collects quantization statistics for weights and activations, so without these stubs there would be no statistics available for quantizing the inputs and dequantizing the outputs.
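
The calibration in steps 6 to 8 can be pictured with the toy observer below. It is a simplified stand-in written in pure Python, not the PyTorch observer implementation, and the class name `ToyMinMaxObserver` is hypothetical. An observer attached to a tensor records the running minimum and maximum over the calibration batches, and the quantization parameters are derived from that range at conversion time. A tensor without an observer, such as the model input without a QuantStub, would have no range to derive them from.

```python
# Toy stand-in for a min/max observer used during post-training calibration.
# This is an illustrative sketch, not the PyTorch implementation.

class ToyMinMaxObserver:
    def __init__(self, q_min=0, q_max=255):
        self.q_min, self.q_max = q_min, q_max
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        """Update the running range with one calibration batch."""
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self):
        """Derive (scale, zero_point) from the observed range."""
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (self.q_max - self.q_min) or 1.0
        zero_point = int(round(self.q_min - lo / scale))
        return scale, max(self.q_min, min(self.q_max, zero_point))


# "Prepare": attach an observer where the input enters the model.
input_observer = ToyMinMaxObserver()

# "Calibrate": run representative batches through the observer.
calibration_batches = [[0.1, 0.9, -0.4], [2.0, 0.3, -0.1], [1.2, -0.5, 0.0]]
for batch in calibration_batches:
    input_observer.observe(batch)

# "Convert": freeze the statistics into quantization parameters.
scale, zero_point = input_observer.calculate_qparams()
print("observed range:", input_observer.min_val, input_observer.max_val)
print("scale:", scale, "zero_point:", zero_point)
```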

cifar.py
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

import time
import copy
import numpy as np

from resnet import resnet18

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)
    # We will use test set for validation and test in this project.
    # Do not use test set for validation in practice!
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform)

    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=train_batch_size,
        sampler=train_sampler, num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=eval_batch_size,
        sampler=test_sampler, num_workers=num_workers)

    return train_loader, test_loader

def evaluate_model(model, test_loader, device, criterion=None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy

def train_model(model, train_loader, test_loader, device):

    # The training configurations were not carefully selected.
    learning_rate = 1e-2
    num_epochs = 20

    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    for epoch in range(num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)

        print("Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

    return model

def calibrate_model(model, loader, device=torch.device("cpu:0")):

    model.to(device)
    model.eval()

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        _ = model(inputs)

def measure_inference_latency(model,
                              device,
                              input_size=(1, 3, 32, 32),
                              num_samples=100,
                              num_warmups=10):

    model.to(device)
    model.eval()

    x = torch.rand(size=input_size).to(device)

    with torch.no_grad():
        for _ in range(num_warmups):
            _ = model(x)
    torch.cuda.synchronize()

    with torch.no_grad():
        start_time = time.time()
        for _ in range(num_samples):
            _ = model(x)
            torch.cuda.synchronize()
        end_time = time.time()
    elapsed_time = end_time - start_time
    elapsed_time_ave = elapsed_time / num_samples

    return elapsed_time_ave

def save_model(model, model_dir, model_filename):

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)

def load_model(model, model_filepath, device):

    model.load_state_dict(torch.load(model_filepath, map_location=device))

    return model

def save_torchscript_model(model, model_dir, model_filename):

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)

def load_torchscript_model(model_filepath, device):

    model = torch.jit.load(model_filepath, map_location=device)

    return model

def create_model(num_classes=10):

    # The number of channels in ResNet18 is divisible by 8.
    # This is required for fast GEMM integer matrix multiplication.
    # model = torchvision.models.resnet18(pretrained=False)
    model = resnet18(num_classes=num_classes, pretrained=False)

    # We would use the pretrained ResNet18 as a feature extractor.
    # for param in model.parameters():
    #     param.requires_grad = False

    # Modify the last FC layer
    # num_features = model.fc.in_features
    # model.fc = nn.Linear(num_features, 10)

    return model

class QuantizedResNet18(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedResNet18, self).__init__()
        # QuantStub converts tensors from floating point to quantized.
        # This will only be used for inputs.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point.
        # This will only be used for outputs.
        self.dequant = torch.quantization.DeQuantStub()
        # FP32 model
        self.model_fp32 = model_fp32

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.model_fp32(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

def model_equivalence(model_1, model_2, device, rtol=1e-05, atol=1e-08, num_tests=100, input_size=(1,3,32,32)):

    model_1.to(device)
    model_2.to(device)

    for _ in range(num_tests):
        x = torch.rand(size=input_size).to(device)
        y1 = model_1(x).detach().cpu().numpy()
        y2 = model_2(x).detach().cpu().numpy()
        if np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False) == False:
            print("Model equivalence test sample failed: ")
            print(y1)
            print(y2)
            return False

    return True

def main():

    random_seed = 0
    num_classes = 10
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")

    model_dir = "saved_models"
    model_filename = "resnet18_cifar10.pt"
    quantized_model_filename = "resnet18_quantized_cifar10.pt"
    model_filepath = os.path.join(model_dir, model_filename)
    quantized_model_filepath = os.path.join(model_dir, quantized_model_filename)

    set_random_seeds(random_seed=random_seed)

    # Create an untrained model.
    model = create_model(num_classes=num_classes)

    train_loader, test_loader = prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256)

    # Train model.
    model = train_model(model=model, train_loader=train_loader, test_loader=test_loader, device=cuda_device)
    # Save model.
    save_model(model=model, model_dir=model_dir, model_filename=model_filename)
    # Load a pretrained model.
    model = load_model(model=model, model_filepath=model_filepath, device=cuda_device)
    # Move the model to CPU since static quantization does not support CUDA currently.
    model.to(cpu_device)
    # Make a copy of the model for layer fusion
    fused_model = copy.deepcopy(model)

    model.eval()
    # The model has to be switched to evaluation mode before any layer fusion.
    # Otherwise the quantization will not work correctly.
    fused_model.eval()

    # Fuse the model in place rather manually.
    fused_model = torch.quantization.fuse_modules(fused_model, [["conv1", "bn1", "relu"]], inplace=True)
    for module_name, module in fused_model.named_children():
        if "layer" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(basic_block, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]], inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "downsample":
                        torch.quantization.fuse_modules(sub_block, [["0", "1"]], inplace=True)

    # Print FP32 model.
    print(model)
    # Print fused model.
    print(fused_model)

    # Model and fused model should be equivalent.
    assert model_equivalence(model_1=model, model_2=fused_model, device=cpu_device, rtol=1e-03, atol=1e-06, num_tests=100, input_size=(1,3,32,32)), "Fused model is not equivalent to the original model!"

    # Prepare the model for static quantization. This inserts observers in
    # the model that will observe activation tensors during calibration.
    quantized_model = QuantizedResNet18(model_fp32=fused_model)
    # Using un-fused model will fail.
    # Because there is no quantized layer implementation for a single batch normalization layer.
    # quantized_model = QuantizedResNet18(model_fp32=model)
    # Select quantization schemes from
    # https://pytorch.org/docs/stable/quantization-support.html
    quantization_config = torch.quantization.get_default_qconfig("fbgemm")
    # Custom quantization configurations
    # quantization_config = torch.quantization.default_qconfig
    # quantization_config = torch.quantization.QConfig(activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8), weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

    quantized_model.qconfig = quantization_config

    # Print quantization configurations
    print(quantized_model.qconfig)

    torch.quantization.prepare(quantized_model, inplace=True)

    # Use training data for calibration.
    calibrate_model(model=quantized_model, loader=train_loader, device=cpu_device)

    quantized_model = torch.quantization.convert(quantized_model, inplace=True)

    # Using high-level static quantization wrapper
    # The above steps, including torch.quantization.prepare, calibrate_model, and torch.quantization.convert, are also equivalent to
    # quantized_model = torch.quantization.quantize(model=quantized_model, run_fn=calibrate_model, run_args=[train_loader], mapping=None, inplace=False)

    quantized_model.eval()

    # Print quantized model.
    print(quantized_model)

    # Save quantized model.
    save_torchscript_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)

    # Load quantized model.
    quantized_jit_model = load_torchscript_model(model_filepath=quantized_model_filepath, device=cpu_device)

    _, fp32_eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=cpu_device, criterion=None)
    _, int8_eval_accuracy = evaluate_model(model=quantized_jit_model, test_loader=test_loader, device=cpu_device, criterion=None)

    # Skip this assertion since the values might deviate a lot.
    # assert model_equivalence(model_1=model, model_2=quantized_jit_model, device=cpu_device, rtol=1e-01, atol=1e-02, num_tests=100, input_size=(1,3,32,32)), "Quantized model deviates from the original model too much!"

    print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy))
    print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))

    fp32_cpu_inference_latency = measure_inference_latency(model=model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
    int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
    int8_jit_cpu_inference_latency = measure_inference_latency(model=quantized_jit_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
    fp32_gpu_inference_latency = measure_inference_latency(model=model, device=cuda_device, input_size=(1,3,32,32), num_samples=100)

    print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000))
    print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(fp32_gpu_inference_latency * 1000))
    print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(int8_cpu_inference_latency * 1000))
    print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(int8_jit_cpu_inference_latency * 1000))

if __name__ == "__main__":

    main()

The accuracy and inference performance of the quantized model with layer fusion are as follows.

FP32 evaluation accuracy: 0.869
INT8 evaluation accuracy: 0.868
FP32 CPU Inference Latency: 4.68 ms / sample
FP32 CUDA Inference Latency: 3.70 ms / sample
INT8 CPU Inference Latency: 2.03 ms / sample
INT8 JIT CPU Inference Latency: 0.45 ms / sample

Note that the evaluation accuracy of ResNet18 on the CIFAR10 ($32 \times 32$) dataset is not as high as 0.95. This is because the original ResNet was designed specifically for ImageNet ($224 \times 224$) classification, so its first layers downsample small inputs too aggressively. To improve the accuracy of ResNet on the CIFAR10 dataset, we could change the kernel size of the first convolution from $7$ to $3$ and reduce the stride of the first convolution from $2$ to $1$.
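
The effect of this change on the feature map resolution can be checked with the standard output size formula for convolution and pooling, $\lfloor (n + 2p - k) / s \rfloor + 1$. The sketch below traces a $32 \times 32$ input through the ResNet18 stages. The function names are hypothetical, and a padding of $1$ for the modified $3 \times 3$ convolution is my assumption, since only the kernel size and stride are specified above.

```python
# Trace the spatial size of a square input through the ResNet18 stages.
# Assumption: the modified 3x3 stem convolution uses padding=1.

def output_size(size, kernel_size, stride, padding):
    """Output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


def resnet18_feature_map_sizes(input_size, stem_kernel, stem_stride, stem_padding):
    sizes = {}
    size = output_size(input_size, stem_kernel, stem_stride, stem_padding)
    sizes["conv1"] = size
    size = output_size(size, 3, 2, 1)      # maxpool: kernel 3, stride 2, padding 1
    sizes["maxpool"] = size
    sizes["layer1"] = size                 # layer1 keeps the resolution (stride 1)
    for name in ("layer2", "layer3", "layer4"):
        size = output_size(size, 3, 2, 1)  # first 3x3 conv of each stage has stride 2
        sizes[name] = size
    return sizes


original = resnet18_feature_map_sizes(32, stem_kernel=7, stem_stride=2, stem_padding=3)
modified = resnet18_feature_map_sizes(32, stem_kernel=3, stem_stride=1, stem_padding=1)
print("original stem:", original)
print("modified stem:", modified)
```

With the original stem, the feature map is already down to a single pixel by the last stage, whereas the modified stem keeps twice the resolution at every stage.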

Conclusions

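PyTorch static quantization results in much faster inference on CPU with minimal accuracy loss. In the experiment above, the INT8 model was about 2.3 times faster than the FP32 model on CPU in eager mode and about 10 times faster with TorchScript, while the evaluation accuracy only dropped from 0.869 to 0.868.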

Extensions

To run quantized inference on CUDA, please refer to TensorRT for symmetric post-training quantization. The scale values of PyTorch symmetrically quantized models can also be used by TensorRT to generate an inference engine without additional post-training quantization.
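
For reference, a symmetric scale can be sketched as follows. This is a generic pure-Python illustration with hypothetical function names, not the TensorRT or PyTorch API. It assumes signed 8-bit values restricted to $[-127, 127]$ with the zero point fixed at $0$; the exact integer range convention should be checked against the target framework.

```python
# Sketch of symmetric per-tensor quantization (zero point fixed at 0).
# Assumption: signed 8-bit values restricted to [-127, 127].

def symmetric_scale(values, num_bits=8):
    """Scale that maps the largest absolute value to the largest integer."""
    q_max = 2 ** (num_bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    return max_abs / q_max if max_abs > 0 else 1.0


def quantize_symmetric(x, scale, num_bits=8):
    q_max = 2 ** (num_bits - 1) - 1
    return max(-q_max, min(q_max, int(round(x / scale))))


calibration_values = [-2.54, -1.0, 0.0, 0.5, 1.0]
scale = symmetric_scale(calibration_values)
quantized = [quantize_symmetric(v, scale) for v in calibration_values]
print("scale:", scale)
print("quantized:", quantized)
```

Because the zero point is always $0$, a single scale value per tensor fully describes the quantization, which is what makes such scales easy to carry over from one framework to another.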

Author

Lei Mao

Posted on

11-28-2020

Updated on

04-29-2021
