PyTorch Pruning

Introduction

State-of-the-art neural networks have become heavily over-parameterized in order to maximize prediction accuracy. However, over-parameterized models are costly to run, and inference latency can become a bottleneck. On resource-constrained edge devices in particular, the model faces tight memory and compute budgets and cannot be parameterized as freely.

A sparse neural network can match a dense one in prediction accuracy, and in theory its inference latency should be much lower thanks to the smaller model size. Neural network pruning is a method for creating sparse neural networks from pre-trained dense ones.

In this blog post, I would like to show how to use PyTorch to do pruning. More details about the mathematical foundations of pruning for neural networks can be found in my article “Pruning for Neural Networks”.

PyTorch Pruning

To demonstrate the effectiveness of pruning, a ResNet18 model is first pre-trained on the CIFAR-10 dataset, achieving a prediction accuracy of $86.9\%$. The pre-trained model is then pruned and fine-tuned. The number of parameters can be reduced by $98\%$, i.e., $50\times$ compression, while maintaining the prediction accuracy within $1\%$ of the original model. The source code can be downloaded from GitHub.

Pruning is overall straightforward in PyTorch if we don’t need to customize the pruning algorithm. In this case, ResNet18 achieves $50\times$ compression using L1 unstructured pruning on the weights, i.e., pruning the weights that have the smallest absolute values.
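
Before walking through the full training script below, it helps to see what a single pruning call actually does to a module. prune.l1_unstructured reparametrizes the module: the original weights are kept in a weight_orig parameter, a binary weight_mask buffer is added, and the effective weight is recomputed as their element-wise product by a forward pre-hook. A minimal standalone sketch, separate from the script:

import torch
import torch.nn.utils.prune as prune

conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

# Zero out the 30% of weights with the smallest absolute values.
prune.l1_unstructured(conv, name="weight", amount=0.3)

# "weight" is now computed as weight_orig * weight_mask by a forward pre-hook.
print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in conv.named_buffers()))     # ['weight_mask']
print(float((conv.weight == 0).float().mean()))             # ~0.3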

import os
import copy
import torch
import torch.nn.utils.prune as prune
from utils import (set_random_seeds, create_model, prepare_dataloader,
                   train_model, save_model, load_model, evaluate_model,
                   create_classification_report)


def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):

    num_zeros = 0
    num_elements = 0

    if use_mask:
        for buffer_name, buffer in module.named_buffers():
            if "weight_mask" in buffer_name and weight:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
            if "bias_mask" in buffer_name and bias:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
    else:
        for param_name, param in module.named_parameters():
            if "weight" in param_name and weight:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()
            if "bias" in param_name and bias:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity


def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):

    num_zeros = 0
    num_elements = 0

    for module_name, module in model.named_modules():

        if isinstance(module, torch.nn.Conv2d):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=conv2d_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

        elif isinstance(module, torch.nn.Linear):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=linear_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity


def iterative_pruning_finetuning(model,
                                 train_loader,
                                 test_loader,
                                 device,
                                 learning_rate,
                                 l1_regularization_strength,
                                 l2_regularization_strength,
                                 learning_rate_decay=0.1,
                                 conv2d_prune_amount=0.4,
                                 linear_prune_amount=0.2,
                                 num_iterations=10,
                                 num_epochs_per_iteration=10,
                                 model_filename_prefix="pruned_model",
                                 model_dir="saved_models",
                                 grouped_pruning=False):

    for i in range(num_iterations):

        print("Pruning and Finetuning {}/{}".format(i + 1, num_iterations))

        print("Pruning...")

        if grouped_pruning:
            # Global pruning: collect the weights of all Conv2d modules and
            # prune them jointly. I would rather call it grouped pruning.
            parameters_to_prune = []
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    parameters_to_prune.append((module, "weight"))
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=conv2d_prune_amount,
            )
        else:
            # Local pruning: prune each module independently.
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    prune.l1_unstructured(module,
                                          name="weight",
                                          amount=conv2d_prune_amount)
                elif isinstance(module, torch.nn.Linear):
                    prune.l1_unstructured(module,
                                          name="weight",
                                          amount=linear_prune_amount)

        _, eval_accuracy = evaluate_model(model=model,
                                          test_loader=test_loader,
                                          device=device,
                                          criterion=None)

        classification_report = create_classification_report(
            model=model, test_loader=test_loader, device=device)

        # The zeros live in the mask buffers at this point, so measure the
        # sparsity from the masks.
        num_zeros, num_elements, sparsity = measure_global_sparsity(
            model,
            weight=True,
            bias=False,
            conv2d_use_mask=True,
            linear_use_mask=False)

        print("Test Accuracy: {:.3f}".format(eval_accuracy))
        print("Classification Report:")
        print(classification_report)
        print("Global Sparsity:")
        print("{:.2f}".format(sparsity))

        # print(model.conv1._forward_pre_hooks)

        print("Fine-tuning...")

        train_model(model=model,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    device=device,
                    l1_regularization_strength=l1_regularization_strength,
                    l2_regularization_strength=l2_regularization_strength,
                    learning_rate=learning_rate * (learning_rate_decay**i),
                    num_epochs=num_epochs_per_iteration)

        _, eval_accuracy = evaluate_model(model=model,
                                          test_loader=test_loader,
                                          device=device,
                                          criterion=None)

        classification_report = create_classification_report(
            model=model, test_loader=test_loader, device=device)

        num_zeros, num_elements, sparsity = measure_global_sparsity(
            model,
            weight=True,
            bias=False,
            conv2d_use_mask=True,
            linear_use_mask=False)

        print("Test Accuracy: {:.3f}".format(eval_accuracy))
        print("Classification Report:")
        print(classification_report)
        print("Global Sparsity:")
        print("{:.2f}".format(sparsity))

        model_filename = "{}_{}.pt".format(model_filename_prefix, i + 1)
        model_filepath = os.path.join(model_dir, model_filename)
        save_model(model=model,
                   model_dir=model_dir,
                   model_filename=model_filename)
        model = load_model(model=model,
                           model_filepath=model_filepath,
                           device=device)

    return model


def remove_parameters(model):

    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            try:
                prune.remove(module, "weight")
            except ValueError:
                # The parameter was never pruned.
                pass
            try:
                prune.remove(module, "bias")
            except ValueError:
                pass
        elif isinstance(module, torch.nn.Linear):
            try:
                prune.remove(module, "weight")
            except ValueError:
                pass
            try:
                prune.remove(module, "bias")
            except ValueError:
                pass

    return model


def main():

    num_classes = 10
    random_seed = 1
    l1_regularization_strength = 0
    l2_regularization_strength = 1e-4
    learning_rate = 1e-3
    learning_rate_decay = 1

    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")

    model_dir = "saved_models"
    model_filename = "resnet18_cifar10.pt"
    model_filename_prefix = "pruned_model"
    pruned_model_filename = "resnet18_pruned_cifar10.pt"
    model_filepath = os.path.join(model_dir, model_filename)
    pruned_model_filepath = os.path.join(model_dir, pruned_model_filename)

    set_random_seeds(random_seed=random_seed)

    # Create an untrained model.
    model = create_model(num_classes=num_classes)

    # Load a pretrained model.
    model = load_model(model=model,
                       model_filepath=model_filepath,
                       device=cuda_device)

    train_loader, test_loader, classes = prepare_dataloader(
        num_workers=8, train_batch_size=128, eval_batch_size=256)

    _, eval_accuracy = evaluate_model(model=model,
                                      test_loader=test_loader,
                                      device=cuda_device,
                                      criterion=None)

    classification_report = create_classification_report(
        model=model, test_loader=test_loader, device=cuda_device)

    num_zeros, num_elements, sparsity = measure_global_sparsity(model)

    print("Test Accuracy: {:.3f}".format(eval_accuracy))
    print("Classification Report:")
    print(classification_report)
    print("Global Sparsity:")
    print("{:.2f}".format(sparsity))

    print("Iterative Pruning + Fine-Tuning...")

    pruned_model = copy.deepcopy(model)

    iterative_pruning_finetuning(
        model=pruned_model,
        train_loader=train_loader,
        test_loader=test_loader,
        device=cuda_device,
        learning_rate=learning_rate,
        learning_rate_decay=learning_rate_decay,
        l1_regularization_strength=l1_regularization_strength,
        l2_regularization_strength=l2_regularization_strength,
        conv2d_prune_amount=0.98,
        linear_prune_amount=0,
        num_iterations=1,
        num_epochs_per_iteration=500,
        model_filename_prefix=model_filename_prefix,
        model_dir=model_dir,
        grouped_pruning=True)

    # Apply the masks to the parameters permanently and remove the masks.
    remove_parameters(model=pruned_model)

    _, eval_accuracy = evaluate_model(model=pruned_model,
                                      test_loader=test_loader,
                                      device=cuda_device,
                                      criterion=None)

    classification_report = create_classification_report(
        model=pruned_model, test_loader=test_loader, device=cuda_device)

    num_zeros, num_elements, sparsity = measure_global_sparsity(pruned_model)

    print("Test Accuracy: {:.3f}".format(eval_accuracy))
    print("Classification Report:")
    print(classification_report)
    print("Global Sparsity:")
    print("{:.2f}".format(sparsity))

    # Save the pruned model.
    save_model(model=pruned_model,
               model_dir=model_dir,
               model_filename=pruned_model_filename)


if __name__ == "__main__":

    main()
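
One detail worth noting about remove_parameters above: prune.remove does not undo the pruning. It makes the pruning permanent by folding weight_orig * weight_mask into an ordinary weight parameter and deleting the mask buffer and the forward pre-hook. A minimal standalone sketch:

import torch
import torch.nn.utils.prune as prune

linear = torch.nn.Linear(64, 10)
prune.l1_unstructured(linear, name="weight", amount=0.5)

# Fold the mask into the weight; "weight" becomes a plain parameter again.
prune.remove(linear, "weight")
print(sorted(name for name, _ in linear.named_parameters()))  # ['bias', 'weight']
print(float((linear.weight == 0).float().mean()))             # 0.5

This is also why measure_global_sparsity is called with conv2d_use_mask=True during pruning but with the default use_mask=False after remove_parameters: before the removal the zeros live in the mask buffers, afterwards they live in the weights themselves.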

Caveats

Sparsity for Iterative Pruning

The prune.l1_unstructured function takes an amount argument, which can be either the fraction of connections to prune (if it is a float between $0$ and $1$) or the absolute number of connections to prune (if it is a non-negative integer). When it is a fraction, it is relative to the number of currently unmasked parameters in the module. For example, in iterative pruning, suppose we prune the weights of a certain layer with amount=0.2 in the first iteration and prune the same layer with amount=0.2 again in the second iteration. The fraction of remaining valid parameters after pruning will be $1 \times (1 - 0.2) \times (1 - 0.2) = 0.64$, and the sparsity of the parameters, i.e., the prune rate, of this module will be $1 - 0.64 = 0.36$.
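
A quick standalone sanity check of this compounding behavior:

import torch
import torch.nn.utils.prune as prune

linear = torch.nn.Linear(100, 100, bias=False)

# Prune 20% of the *remaining* weights, twice.
prune.l1_unstructured(linear, name="weight", amount=0.2)
print(float((linear.weight == 0).float().mean()))  # 0.2
prune.l1_unstructured(linear, name="weight", amount=0.2)
print(float((linear.weight == 0).float().mean()))  # 0.36, not 0.4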

Formally, the final prune rate can be calculated using the following equation. Suppose the relative prune rate in each iteration is $\gamma$; the final prune rate after $n$ iterations will be

$$
1 - (1 - \gamma)^n
$$

Similarly, when the relative prune rate differs across iterations, say $\gamma_i$ in iteration $i$, the final prune rate after $n$ iterations becomes

$$
1 - \prod_{i=1}^{n} (1 - \gamma_i)
$$

Local Pruning vs. Grouped Pruning

Local pruning prunes the parameters module by module; the parameters of other modules do not affect which parameters get pruned. We can explicitly specify the prune rate for each layer in the network.

Grouped pruning, often referred to as global pruning, groups many different modules together and prunes their parameters as if they belonged to a single module. We can still specify the overall prune rate explicitly; however, the resulting prune rate of each individual layer will differ.
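
The two styles map to two different API calls. A toy sketch contrasting them (the two-layer model is made up purely for illustration):

import torch
import torch.nn.utils.prune as prune

def make_model():
    return torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3),
                               torch.nn.Conv2d(8, 16, 3))

# Local pruning: every Conv2d layer ends up exactly 40% sparse.
model = make_model()
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.4)

# Grouped (global) pruning: 40% of all Conv2d weights are pruned jointly,
# so the per-layer sparsity varies depending on the weight magnitudes.
model = make_model()
parameters_to_prune = [(m, "weight") for m in model.modules()
                       if isinstance(m, torch.nn.Conv2d)]
prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.4)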

In our ResNet18-CIFAR10 example, grouped pruning performs much better than local pruning. With grouped pruning, we can maintain a prediction accuracy of $86.8\%$ at a prune rate of $98\%$, whereas with local pruning we can only maintain a prediction accuracy of around $82.8\%$ at a prune rate of $94\%$.

One-Time vs. Multi-Time Iterative Pruning + Fine-Tuning

Unlike one-time iterative pruning + fine-tuning, which reaches the desired prune rate by pruning and fine-tuning once, multi-time iterative pruning + fine-tuning reaches the desired prune rate by pruning and fine-tuning multiple times. For example, to reach a desired prune rate of $98\%$, we could run pruning and fine-tuning for many iterations, reaching prune rates of $30\%$, $50\%$, $66\%$, $76\%$, $\cdots$, $98\%$ across the iterations.
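
Using the compounding formula from the previous section, the constant per-iteration rate $\gamma$ that reaches a target prune rate $p$ after $n$ iterations is $\gamma = 1 - (1 - p)^{1/n}$. A small sketch of such a schedule (the numbers are illustrative):

# Constant per-iteration rate to reach a 98% final prune rate in 10 steps.
target_prune_rate = 0.98
num_iterations = 10

gamma = 1 - (1 - target_prune_rate)**(1 / num_iterations)
print("Per-iteration rate: {:.4f}".format(gamma))  # ~0.3237

# Cumulative prune rate after each iteration.
for i in range(1, num_iterations + 1):
    print("Iteration {}: {:.4f}".format(i, 1 - (1 - gamma)**i))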

Usually, multi-time iterative pruning + fine-tuning works better than one-time iterative pruning + fine-tuning. In our ResNet18-CIFAR10 example, however, there is almost no difference: using grouped pruning, both variants maintain a prediction accuracy of around $86.8\%$ at a prune rate of $98\%$.

Final Remarks

It seems that PyTorch does not yet support converting pruned neural networks to use sparse tensors for inference. Once it is supported, we will be able to see how much faster a pruned sparse neural network actually runs compared to its original dense counterpart.
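
In the meantime, we can at least inspect how compressible the pruned weights are using PyTorch's sparse tensor formats (a beta feature in recent versions), even though torch.nn modules still expect dense weights. A hedged standalone sketch:

import torch

# A roughly 98%-sparse tensor standing in for a pruned layer's weight.
dense_weight = torch.randn(512, 512)
dense_weight[torch.rand(512, 512) < 0.98] = 0.0

# COO storage scales with the number of non-zeros, hinting at the potential
# memory savings, but nn.Linear/nn.Conv2d cannot consume it for inference.
sparse_weight = dense_weight.to_sparse()
print(sparse_weight.values().numel(), "non-zeros out of", dense_weight.numel())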
