PyTorch Dynamic Quantization

Introduction

Dynamic quantization quantizes the weights of neural networks to integers ahead of time, while the activations are dynamically quantized during inference. Compared to floating point neural networks, a dynamically quantized model is much smaller in size, since its weights are stored as low-bitwidth integers. Compared to other quantization techniques, dynamic quantization does not require any data for calibration or fine-tuning. More details about the mathematical foundations of quantization for neural networks can be found in my article “Quantization for Neural Networks”.
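To make the idea concrete, below is a minimal numerical sketch of symmetric int8 weight quantization. The tensor name and the symmetric per-tensor scheme are my own illustrative assumptions; the actual PyTorch quantization backends may use per-channel or asymmetric schemes.

import torch

# A toy weight tensor standing in for a layer weight (illustrative only).
weight = torch.randn(4, 8)

# Map the floating point range [-max_abs, max_abs] onto the int8 range [-127, 127].
max_abs = weight.abs().max()
scale = max_abs / 127.0

# Quantize: scale, round to the nearest integer, and clamp to the int8 range.
weight_int8 = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)

# Dequantize: recover an approximation of the original floating point weights.
weight_dequantized = weight_int8.float() * scale

print("Maximum quantization error: {:.6f}".format((weight - weight_dequantized).abs().max().item()))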

Given a pre-trained floating point model, we can easily create a dynamically quantized model, run inference, and potentially achieve better latency without much additional effort. In this blog post, I would like to show how to use PyTorch to do dynamic quantization.

PyTorch Dynamic Quantization

Unlike TensorFlow 2.3.0, which supports integer quantization using arbitrary bitwidths from 2 to 16, PyTorch 1.7.0 only supports 8-bit integer quantization. The workflow is as simple as loading a pre-trained floating point model and applying a dynamic quantization wrapper, as sketched below.
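As a minimal sketch of this workflow, the toy two-layer model below (an assumption purely for illustration, not the model used later in this post) is wrapped with torch.quantization.quantize_dynamic so that its torch.nn.Linear modules run with int8 dynamically quantized kernels.

import torch

# A toy floating point model (illustrative only).
model_fp32 = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)
model_fp32.eval()

# Apply the dynamic quantization wrapper to all torch.nn.Linear modules.
model_int8 = torch.quantization.quantize_dynamic(model_fp32, {torch.nn.Linear}, dtype=torch.qint8)

# The quantized model is used exactly like the floating point one.
with torch.no_grad():
    outputs = model_int8(torch.randn(1, 128))

# Printing the model shows the Linear modules replaced by dynamically quantized ones.
print(model_int8)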

In this case, I would like to use the BERT-QA model from HuggingFace Transformers as an example. I dynamically quantized the torch.nn.Linear layers of the BERT-QA model, since the majority of the computation for Transformer-based models is matrix multiplication. The source code can also be downloaded from GitHub.

qa.py
import os
import time
import torch
from transformers import BertTokenizer, BertForQuestionAnswering

def measure_inference_latency(model, inputs, num_samples=100, num_warmups=10):

    with torch.no_grad():
        for _ in range(num_warmups):
            _ = model(**inputs)
    torch.cuda.synchronize()

    with torch.no_grad():
        start_time = time.time()
        for _ in range(num_samples):
            _ = model(**inputs)
        torch.cuda.synchronize()
        end_time = time.time()
    elapsed_time = end_time - start_time
    elapsed_time_ave = elapsed_time / num_samples

    return elapsed_time_ave

def get_bert_qa_model(model_name="deepset/bert-base-cased-squad2", cache_dir="./saved_models"):

    # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertForQuestionAnswering
    tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = BertForQuestionAnswering.from_pretrained(model_name, cache_dir=cache_dir, return_dict=True)

    return model, tokenizer

def prepare_qa_inputs(question, text, tokenizer, device=None):

    inputs = tokenizer(question, text, return_tensors="pt")
    if device is not None:
        inputs_cuda = dict()
        for input_name in inputs.keys():
            inputs_cuda[input_name] = inputs[input_name].to(device)
        inputs = inputs_cuda

    return inputs

def move_inputs_to_device(inputs, device=None):

    inputs_cuda = dict()
    for input_name in inputs.keys():
        inputs_cuda[input_name] = inputs[input_name].to(device)

    return inputs_cuda

def run_qa(model, tokenizer, question, text, device=None):

    inputs = prepare_qa_inputs(question=question, text=text, tokenizer=tokenizer)

    all_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].numpy()[0])

    if device is not None:
        inputs = move_inputs_to_device(inputs, device=device)
        model = model.to(device)

    outputs = model(**inputs)

    start_scores = outputs.start_logits
    end_scores = outputs.end_logits

    answer_start_idx = torch.argmax(start_scores, 1)[0]
    answer_end_idx = torch.argmax(end_scores, 1)[0] + 1

    answer = " ".join(all_tokens[answer_start_idx : answer_end_idx])

    return answer

def get_model_size(model, temp_dir="/tmp"):

    model_dir = os.path.join(temp_dir, "temp")
    torch.save(model.state_dict(), model_dir)
    # model.save_pretrained(model_dir)
    size = os.path.getsize(model_dir)
    os.remove(model_dir)

    return size

def main():

    cuda_device = torch.device("cuda:0")
    num_samples = 100

    model, tokenizer = get_bert_qa_model(model_name="deepset/bert-base-cased-squad2")
    model.eval()
    # https://pytorch.org/docs/stable/torch.quantization.html?highlight=torch%20quantization%20quantize_dynamic#torch.quantization.quantize_dynamic
    quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    print("=" * 75)
    print("Model Sizes")
    print("=" * 75)

    model_size = get_model_size(model=model)
    quantized_model_size = get_model_size(model=quantized_model)

    print("FP32 Model Size: {:.2f} MB".format(model_size / (2 ** 20)))
    print("INT8 Model Size: {:.2f} MB".format(quantized_model_size / (2 ** 20)))

    question = "What publication printed that the wealthiest 1% have more money than those in the bottom 90%?"

    text = "According to PolitiFact the top 400 richest Americans \"have more wealth than half of all Americans combined.\" According to the New York Times on July 22, 2014, the \"richest 1 percent in the United States now own more wealth than the bottom 90 percent\". Inherited wealth may help explain why many Americans who have become rich may have had a \"substantial head start\". In September 2012, according to the Institute for Policy Studies, \"over 60 percent\" of the Forbes richest 400 Americans \"grew up in substantial privilege\"."

    inputs = prepare_qa_inputs(question=question, text=text, tokenizer=tokenizer)
    answer = run_qa(model=model, tokenizer=tokenizer, question=question, text=text)
    answer_quantized = run_qa(model=quantized_model, tokenizer=tokenizer, question=question, text=text)

    print("=" * 75)
    print("BERT QA Example")
    print("=" * 75)

    print("Text: ")
    print(text)
    print("Question: ")
    print(question)
    print("Model Answer: ")
    print(answer)
    print("Dynamic Quantized Model Answer: ")
    print(answer_quantized)

    print("=" * 75)
    print("BERT QA Inference Latencies")
    print("=" * 75)

    model_latency = measure_inference_latency(model=model, inputs=inputs, num_samples=num_samples)
    print("CPU Inference Latency: {:.2f} ms / sample".format(model_latency * 1000))

    quantized_model_latency = measure_inference_latency(model=quantized_model, inputs=inputs, num_samples=num_samples)
    print("Dynamic Quantized CPU Inference Latency: {:.2f} ms / sample".format(quantized_model_latency * 1000))

    inputs_cuda = move_inputs_to_device(inputs, device=cuda_device)
    model.to(cuda_device)
    model_cuda_latency = measure_inference_latency(model=model, inputs=inputs_cuda, num_samples=num_samples)
    print("CUDA Inference Latency: {:.2f} ms / sample".format(model_cuda_latency * 1000))

    # No CUDA backend for dynamic quantization in PyTorch 1.7.0
    # quantized_model_cuda = quantized_model.to(cuda_device)
    # quantized_model_cuda_latency = measure_inference_latency(model=quantized_model_cuda, inputs=inputs_cuda, num_samples=num_samples)
    # print("Dynamic Quantized GPU Inference Latency: {:.2f} ms / sample".format(quantized_model_cuda_latency * 1000))

if __name__ == "__main__":

    main()

With PyTorch 1.7.0, we can do dynamic quantization on x86-64 and aarch64 CPUs. However, NVIDIA GPUs are not yet supported for PyTorch dynamic quantization.
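To double check which layers were actually swapped by dynamic quantization, one could inspect the module types of the quantized model. The snippet below is my own hedged sketch, not part of the script above; it simply counts the dynamically quantized Linear modules.

import torch
from transformers import BertForQuestionAnswering

# Load the FP32 BERT-QA model and apply dynamic quantization to its Linear layers.
model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# torch.nn.Linear modules should have been replaced with dynamically quantized
# Linear modules, while embeddings and layer norms remain in FP32.
num_dynamic_linear = sum(
    isinstance(module, torch.nn.quantized.dynamic.Linear)
    for module in quantized_model.modules()
)
print("Number of dynamically quantized Linear modules: {}".format(num_dynamic_linear))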

$ python qa.py 
===========================================================================
Model Sizes
===========================================================================
FP32 Model Size: 411.00 MB
INT8 Model Size: 168.05 MB
===========================================================================
BERT QA Example
===========================================================================
Text:
According to PolitiFact the top 400 richest Americans "have more wealth than half of all Americans combined." According to the New York Times on July 22, 2014, the "richest 1 percent in the United States now own more wealth than the bottom 90 percent". Inherited wealth may help explain why many Americans who have become rich may have had a "substantial head start". In September 2012, according to the Institute for Policy Studies, "over 60 percent" of the Forbes richest 400 Americans "grew up in substantial privilege".
Question:
What publication printed that the wealthiest 1% have more money than those in the bottom 90%?
Model Answer:
New York Times
Dynamic Quantized Model Answer:
New York Times
===========================================================================
BERT QA Inference Latencies
===========================================================================
CPU Inference Latency: 52.27 ms / sample
Dynamic Quantized CPU Inference Latency: 40.63 ms / sample
CUDA Inference Latency: 7.02 ms / sample

We can see that the INT8 quantized model is much smaller than the FP32 model. The CPU inference latency of the INT8 dynamically quantized model is also much lower than that of the ordinary FP32 model on CPU. However, FP32 inference on an NVIDIA GPU is still the fastest.
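Note that the INT8 model is not a full 4x smaller than the FP32 model, because only the torch.nn.Linear weights are stored as int8, while embeddings, layer normalization parameters, and biases remain in FP32. A rough, hedged way to see this (my own check, not part of the original script) is to measure the fraction of parameters that live in Linear weights.

import torch
from transformers import BertForQuestionAnswering

model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")

# Parameters stored as int8 after dynamic quantization: only the Linear weights.
linear_weight_params = sum(
    module.weight.numel()
    for module in model.modules()
    if isinstance(module, torch.nn.Linear)
)
total_params = sum(parameter.numel() for parameter in model.parameters())

print("Fraction of parameters in Linear weights: {:.2%}".format(linear_weight_params / total_params))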
