# PyTorch Quantization Aware Training

## Introduction

Static quantization allows the user to generate a quantized integer model that is highly efficient during inference. However, even with careful post-training calibration, the model accuracy may sometimes be degraded to an unacceptable extent. In such cases, post-training calibration alone is not sufficient to generate a quantized integer model, and we would have to train the model in a way that takes the quantization effect into account. Quantization aware training is capable of modeling the quantization effect during training.

The mechanism of quantization aware training is simple: it places fake quantization modules, i.e., quantization and dequantization modules, at the places where quantization happens during the floating-point to quantized integer model conversion, to simulate the effects of clamping and rounding brought by integer quantization. The fake quantization modules also monitor the scales and zero points of the weights and activations. Once quantization aware training is finished, the floating point model can be converted to a quantized integer model immediately using the information stored in the fake quantization modules.
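The clamping and rounding effect described above can be illustrated with a minimal sketch. The function below is not PyTorch's internal implementation; it is a hand-written illustration of what a fake quantization module computes, with the scale and zero point chosen arbitrarily for the example:

```python
import torch

def fake_quantize(x, scale, zero_point, quant_min=-128, quant_max=127):
    """Simulate int8 quantization: quantize, then immediately dequantize.

    The output is still floating point, but it carries the rounding and
    clamping error that real integer quantization would introduce.
    """
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

x = torch.tensor([0.03, -0.26, 1.234, -0.5])
# With scale=0.1 and zero_point=0, every output is a multiple of 0.1,
# so each value carries a rounding error of at most scale / 2.
x_fq = fake_quantize(x, scale=0.1, zero_point=0)
```

Because the forward pass is computed on these "snapped" values, the training loss reflects the quantization error, and gradient descent can compensate for it.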

In this blog post, I would like to show how to use PyTorch to do quantization aware training. More details about the mathematical foundations of quantization for neural networks can be found in my article “Quantization for Neural Networks”.

## PyTorch Quantization Aware Training

Unlike TensorFlow 2.3.0, which supports integer quantization using arbitrary bitwidths from 2 to 16, PyTorch 1.7.0 only supports 8-bit integer quantization. The workflow could be as easy as loading a pre-trained floating point model and applying a quantization aware training wrapper. However, without layer fusion, such a simple approach sometimes does not result in good model performance.

As before, I will use the ResNet18 from TorchVision models as an example. All the steps prior to the quantization aware training step, including layer fusion and skip connection replacement, are exactly the same as the ones used in “PyTorch Static Quantization”. The source code can also be downloaded from GitHub.

The quantization aware training steps are also very similar to post-training calibration:

- Train a floating point model or load a pre-trained floating point model.
- Move the model to CPU and switch model to training mode.
- Apply layer fusion.
- Switch model to evaluation mode, check if the layer fusion results in correct model, and switch back to training mode.
- Apply `torch.quantization.QuantStub()` and `torch.quantization.DeQuantStub()` to the inputs and outputs, respectively.
- Specify quantization configurations, such as symmetric quantization or asymmetric quantization, etc.
- Prepare quantization model for quantization aware training.
- Move the model to CUDA and run quantization aware training using CUDA.
- Move the model to CPU and convert the quantization aware trained floating point model to quantized integer model.
- [Optional] Verify accuracies and inference performance gain.
- Save the quantized integer model.
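The stub-wrapping step in the list above can be sketched as a small wrapper module. The class and argument names here are illustrative, not taken from the original script; any floating point model (e.g., the fused ResNet18) could be passed in:

```python
import torch

class QuantizedModel(torch.nn.Module):
    """Wrap a floating point model with quantization stubs at its boundaries."""

    def __init__(self, model_fp32):
        super().__init__()
        # QuantStub converts fp32 tensors to int8 at the model input
        # (after conversion); before conversion it observes statistics.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts int8 tensors back to fp32 at the output.
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        return self.dequant(x)
```

Before `prepare_qat` is applied, both stubs act as identity functions, so the wrapper does not change the floating point model's behavior.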

The quantization aware training script is very similar to the one used in “PyTorch Static Quantization”:

```python
import os
```
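As a hedged sketch of the core QAT calls rather than the full script, the workflow might look like the following. A tiny stand-in model replaces the fused ResNet18, and the qconfig, optimizer, and training loop are illustrative (a single step on random data; in practice the loop would run on CUDA over the real dataset):

```python
import torch

# Toy stand-in for the fused floating point model (illustrative only).
model = torch.nn.Sequential(
    torch.quantization.QuantStub(),
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
    torch.quantization.DeQuantStub(),
)

# Prepare for QAT in training mode; "fbgemm" is the common x86 backend.
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Quantization aware training loop (one step shown on random data).
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
inputs, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))
optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
optimizer.step()

# Convert to a quantized integer model on CPU once training is done.
model.eval()
model_int8 = torch.quantization.convert(model, inplace=False)
```

During training the inserted fake quantization modules record scales and zero points, which `torch.quantization.convert` then uses to produce the integer model.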

The accuracies and inference performance for the quantized model with layer fusion are:

```
FP32 evaluation accuracy: 0.869
```

## Conclusions

Compared to the accuracy and inference performance from “PyTorch Static Quantization”, PyTorch quantization aware training results in the same inference performance on CPU with better accuracy.

## References

PyTorch Quantization Aware Training

https://leimao.github.io/blog/PyTorch-Quantization-Aware-Training/