Pruning for Neural Networks

Introduction

Neural networks have become extremely popular because they can often “solve” complex artificial intelligence problems that conventional models cannot. The current convention is that with more data and larger neural networks, we can achieve better accuracy from training. GPT-$3$, a state-of-the-art language model, consists of $175$ billion parameters ($700$ GB in FP$32$ precision). While these neural networks perform impressively well, they are costly to run and not suitable for edge devices.

Sparse neural networks, compared to dense neural networks, have sparse weight and neuron tensors in the model. The number of useful parameters in the model is therefore smaller, and the model is more suitable for running on edge devices. Recent studies show that sparse neural networks can perform as well as their corresponding dense neural networks. Such sparse neural networks can be found by pruning dense neural networks.

In this article, I would like to elucidate the mathematics behind pruning neural networks and discuss the general protocol for pruning neural networks.

Sparse Tensor vs. Dense Tensor

Let’s look at a two-dimensional tensor (matrix) $W \in \mathbb{R}^{m \times n}$ as an example.

Data Structure

Dense Tensor Data Structure

The dense tensor data structure can be illustrated using the familiar matrix notation.

$$
\begin{align}
W &=
\begin{bmatrix}
W_{1,1} & W_{1,2} & \cdots & W_{1,n}\\
W_{2,1} & W_{2,2} & \cdots & W_{2,n}\\
\vdots & \vdots & \ddots & \vdots \\
W_{m,1} & W_{m,2} & \cdots & W_{m,n}\\
\end{bmatrix} \nonumber\\
\end{align}
$$

The memory used for a dense tensor data structure is determined by its shape and is fixed at runtime. A sparse tensor can also be represented using the dense tensor data structure; in that case, most of the entries in the dense tensor data structure will be $0$.

Sparse Tensor Data Structure

$$
\begin{align}
W &=
\begin{bmatrix}
m \\
n \\
(i_1, j_1), W_{i_1, j_1} \\
(i_2, j_2), W_{i_2, j_2} \\
\vdots \\
(i_k, j_k), W_{i_k, j_k} \\
\vdots \\
\end{bmatrix} \nonumber\\
\end{align}
$$

where $m$ and $n$ are the shape of the tensor, and $(i_k, j_k)$ are the indices of the entries in the tensor whose values are not $0$.

The memory used for the sparse tensor data structure can be much smaller than for the dense tensor data structure if the tensor is truly sparse. However, if a dense tensor is represented using the sparse tensor data structure, the memory usage would be much higher than with the dense tensor data structure, because we also record the shape of the tensor and the indices of the non-zero entries.
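
For concreteness, here is a minimal NumPy sketch of such a coordinate-list (COO) style representation; the helper names `to_coo` and `to_dense` are made up for illustration.

```python
import numpy as np

def to_coo(dense):
    """Convert a 2-D dense array into a simple COO-style representation:
    the shape plus ((i, j), value) pairs for the non-zero entries."""
    m, n = dense.shape
    entries = [((int(i), int(j)), float(dense[i, j]))
               for i, j in zip(*np.nonzero(dense))]
    return (m, n), entries

def to_dense(shape, entries):
    """Rebuild the dense array from the COO-style representation."""
    dense = np.zeros(shape)
    for (i, j), value in entries:
        dense[i, j] = value
    return dense

W = np.array([[0.0, 1.5, 0.0],
              [0.0, 0.0, 0.0],
              [2.0, 0.0, 0.0]])
shape, entries = to_coo(W)
print(shape, entries)  # (3, 3) [((0, 1), 1.5), ((2, 0), 2.0)]
assert np.array_equal(to_dense(shape, entries), W)
```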

Tensor Multiplication

Suppose we have two matrices, $W_1 \in \mathbb{R}^{m_1 \times m_2}$ and $W_2 \in \mathbb{R}^{m_2 \times m_3}$.

To compute the matrix product $W_3 = W_1 \times W_2$, for each entry in $W_3$ we compute

$$
W_{3, i, j} = \sum_{k=1}^{m_2} W_{1, i, k} W_{2, k, j}
$$

So we need to do $m_1 m_2 m_3$ multiplications and $m_1 m_2 m_3$ additions (accumulating into a zero-initialized $W_3$) in order to complete $W_3 = W_1 \times W_2$ using dense matrix multiplication.

Calculating the number of multiplications and additions for $W_3 = W_1 \times W_2$ using sparse matrix multiplication is more complicated, because it is determined by the matrix contents at runtime. There can be many different implementations of tensor multiplication using the sparse tensor data structure. However, the convention and the principle are that if the matrices are truly sparse, the number of multiplications and additions required is much smaller, the additional overhead is not significant, and thus the computation is faster.
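
Reusing the COO-style representation sketched above, the following toy implementation multiplies a sparse matrix by a dense one and counts the scalar multiplications actually performed.

```python
import numpy as np

def sparse_dense_matmul(shape1, entries1, W2):
    """Multiply a COO-style sparse matrix W1 by a dense matrix W2.
    Only the non-zero entries of W1 contribute, so the number of scalar
    multiplications is nnz(W1) * m3 instead of m1 * m2 * m3."""
    m1, m2 = shape1
    m3 = W2.shape[1]
    W3 = np.zeros((m1, m3))
    num_multiplications = 0
    for (i, k), value in entries1:
        W3[i, :] += value * W2[k, :]   # one row update per non-zero entry of W1
        num_multiplications += m3
    return W3, num_multiplications

rng = np.random.default_rng(0)
W1 = rng.normal(size=(64, 128))
W1[np.abs(W1) < 1.0] = 0.0             # make W1 roughly 68% sparse
W2 = rng.normal(size=(128, 32))

entries1 = [((int(i), int(k)), float(W1[i, k])) for i, k in zip(*np.nonzero(W1))]
W3, num_mult = sparse_dense_matmul(W1.shape, entries1, W2)

assert np.allclose(W3, W1 @ W2)
print(num_mult, 64 * 128 * 32)         # far fewer multiplications than the dense count
```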

Neural Network Pruning

Neural Network Representation

We have $\mathbf{x}_1 \in \mathbb{R}^{m_1}$, $\mathbf{W}_1 \in \mathbb{R}^{m_2 \times m_1}$, $\mathbf{b}_1 \in \mathbb{R}^{m_2}$, $\mathbf{x}_2 \in \mathbb{R}^{m_2}$, $\mathbf{W}_2 \in \mathbb{R}^{m_3 \times m_2}$, $\mathbf{b}_2 \in \mathbb{R}^{m_3}$, $\cdots$ All the tensors are represented using the dense tensor data structure.

$$
\begin{gather}
\mathbf{x}_2 = \mathbf{W}_1 \mathbf{x}_1 + \mathbf{b}_1 \\
\mathbf{x}_3 = \mathbf{W}_2 \mathbf{x}_2 + \mathbf{b}_2 \\
\vdots
\end{gather}
$$

For the sake of simplicity, we ignore the activation functions used between two adjacent layers.

We train the neural network using training data, and the weights $\mathbf{W}_1$, $\mathbf{W}_2$, $\cdots$ and biases $\mathbf{b}_1$, $\mathbf{b}_2$, $\cdots$ are determined.
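
A minimal NumPy sketch of this two-layer linear model, with random weights standing in for trained ones:

```python
import numpy as np

rng = np.random.default_rng(0)
m1, m2, m3 = 8, 16, 4

# Dense weights and biases; here randomly initialized in place of trained values.
W1, b1 = rng.normal(size=(m2, m1)), rng.normal(size=m2)
W2, b2 = rng.normal(size=(m3, m2)), rng.normal(size=m3)

def forward(x1):
    """Two linear layers; activation functions are omitted as in the text."""
    x2 = W1 @ x1 + b1
    x3 = W2 @ x2 + b2
    return x3

x1 = rng.normal(size=m1)
print(forward(x1).shape)  # (4,)
```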

Pruning Weights

The most critical step for neural network pruning is to find out the unimportant synapse connections, i.e., weights, and set the weights to exactly zero. This step is also called pruning weights.

We use binary mask tensors $\mathbf{M}_{1} \in \{0,1\}^{m_2 \times m_1}$, $\mathbf{M}_{2} \in \{0,1\}^{m_3 \times m_2}$, $\mathbf{m}_{1} \in \{0,1\}^{m_2}$, $\mathbf{m}_{2} \in \{0,1\}^{m_3}$ to indicate which synapse connections, i.e., weights $w_{i,j} \in \mathbf{W}_1, \mathbf{W}_2$ or biases $b_{i} \in \mathbf{b}_1, \mathbf{b}_2$, are not important and should be pruned. These binary masks are kept constant during neural network training or fine-tuning.

There are many ways to prune weights. Some straightforward methods use the magnitude of the weights to determine which weights are not important. For example,

$$
\begin{align}
\mathbf{M}_{i, j} =
\begin{cases}
0 & \text{if $|\mathbf{W}_{i, j}| < \lambda$} \\
1 & \text{else} \\
\end{cases}
\end{align}
$$
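
A minimal NumPy sketch of this magnitude-based criterion (the threshold $\lambda = 0.5$ is arbitrary here):

```python
import numpy as np

def magnitude_mask(W, threshold):
    """Binary mask with 0 where |W| < threshold and 1 elsewhere."""
    return (np.abs(W) >= threshold).astype(W.dtype)

rng = np.random.default_rng(0)
W1 = rng.normal(size=(16, 8))
M1 = magnitude_mask(W1, threshold=0.5)
print(float(1.0 - M1.mean()))  # fraction of weights pruned, roughly 0.4 for standard normal weights
```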

Other methods can be more sophisticated, such as using the second-order derivatives (the Hessian) to determine which weights are not important.

Pruning methods can also be categorized as structured and unstructured. For structured pruning methods, the locations of the pruned weights in the weight tensor follow a regular pattern, so the remaining weight tensor after pruning is still well-suited to be represented using the dense tensor data structure or certain specialized data structures. Typical structured pruning methods use regularization techniques such as group lasso. For unstructured pruning methods, the locations of the pruned weights do not follow any particular pattern or rule, so the remaining weight tensor after pruning can usually only be represented using the sparse tensor data structure.

Once the mask tensors are determined, the unimportant weights are “pruned” by multiplying the weight tensors element-wise with the mask tensors. Mathematically, the neural network actually computes

$$
\begin{gather}
\mathbf{x}_2 = (\mathbf{W}_1 \odot \mathbf{M}_{1}) \mathbf{x}_1 + (\mathbf{b}_1 \odot \mathbf{m}_{1}) \\
\mathbf{x}_3 = (\mathbf{W}_2 \odot \mathbf{M}_{2}) \mathbf{x}_2 + (\mathbf{b}_2 \odot \mathbf{m}_{2}) \\
\vdots
\end{gather}
$$
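
A minimal PyTorch sketch of this masked forward pass, with random weights and magnitude-based masks standing in for trained and carefully chosen ones:

```python
import torch

torch.manual_seed(0)
m1, m2, m3 = 8, 16, 4

W1, b1 = torch.randn(m2, m1, requires_grad=True), torch.randn(m2, requires_grad=True)
W2, b2 = torch.randn(m3, m2, requires_grad=True), torch.randn(m3, requires_grad=True)

# Constant binary masks, e.g. from magnitude pruning; frozen during fine-tuning.
M1, m1_mask = (W1.detach().abs() >= 0.5).float(), (b1.detach().abs() >= 0.5).float()
M2, m2_mask = (W2.detach().abs() >= 0.5).float(), (b2.detach().abs() >= 0.5).float()

def masked_forward(x1):
    """Forward pass with the masks applied to the weights and biases."""
    x2 = (W1 * M1) @ x1 + b1 * m1_mask
    x3 = (W2 * M2) @ x2 + b2 * m2_mask
    return x3

print(masked_forward(torch.randn(m1)).shape)  # torch.Size([4])
```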

We could stop pruning the neural network here. However, empirically the one-shot pruned model usually does not perform as well as the original model in terms of accuracy. A common remedy is to fine-tune the pruned neural network.

Fine-Tuning

Fine-tuning the pruned neural network is almost the same as fine-tuning an ordinary neural network. The only difference is that this time we have constant mask tensors in our neural network. Let’s see how the mask tensors affect the fine-tuning.

Forward Propagation

The forward propagation in neural networks is straightforward. For example, to compute the entry $\mathbf{x}_{2, i}$ of $\mathbf{x}_{2}$, we have

$$
\begin{align}
\mathbf{x}_{2, i} &= \bigg( \sum_{j = 1}^{m_1} (\mathbf{W}_{1, i, j} \mathbf{M}_{1, i, j}) \mathbf{x}_{1, j} \bigg) + \mathbf{b}_{1, i} \mathbf{m}_{1, i} \\
\end{align}
$$

Back Propagation

The back propagation is slightly different from the conventional one because of the mask. To compute the derivative with respect to $\mathbf{W}_{1, i, j}$, we have

$$
\begin{align}
\frac{\partial \mathbf{x}_{2, i} }{\partial \mathbf{W}_{1, i, j}}
&= \mathbf{M}_{1, i, j} \mathbf{x}_{1, j} \\
&=
\begin{cases}
0 & \text{if $\mathbf{M}_{1, i, j} = 0$} \\
\mathbf{x}_{1, j} & \text{if $\mathbf{M}_{1, i, j} = 1$} \\
\end{cases}
\end{align}
$$

$$
\begin{align}
\frac{\partial L }{\partial \mathbf{W}_{1, i, j}}
&= \frac{\partial L }{\partial \mathbf{x}_{2, i}} \frac{\partial \mathbf{x}_{2, i} }{\partial \mathbf{W}_{1, i, j}} \\
&=
\begin{cases}
0 & \text{if $\mathbf{M}_{1, i, j} = 0$} \\
\mathbf{x}_{1, j} \frac{\partial L }{\partial \mathbf{x}_{2, i}} & \text{if $\mathbf{M}_{1, i, j} = 1$} \\
\end{cases}
\end{align}
$$

where $L$ is the loss.

So if $\mathbf{M}_{1, i, j} = 0$, indicating that the synapse connection between $\mathbf{x}_{2, i}$ and $\mathbf{x}_{1, j}$ is dead, the value of $\mathbf{W}_{1, i, j}$ remains unchanged during back propagation.

Note that having a batch does not change the above fact. We use the expected (mean) value of the gradient over the batch to update the weights.

$$
\begin{align}
\mathbb{E} \bigg( \frac{\partial L }{\partial \mathbf{W}_{1, i, j}} \bigg)
&= \mathbb{E} \bigg( \frac{\partial L }{\partial \mathbf{x}_{2, i}} \frac{\partial \mathbf{x}_{2, i} }{\partial \mathbf{W}_{1, i, j}} \bigg) \\
&= \frac{1}{n} \sum_{k=1}^{n} \bigg( \frac{\partial L }{\partial \mathbf{x}_{2, i}^{(k)}} \frac{\partial \mathbf{x}_{2, i}^{(k)} }{\partial \mathbf{W}_{1, i, j}} \bigg) \\
&=
\begin{cases}
0 & \text{if $\mathbf{M}_{1, i, j} = 0$} \\
\frac{1}{n} \sum_{k=1}^{n} \mathbf{x}_{1, j}^{(k)} \frac{\partial L }{\partial \mathbf{x}_{2, i}^{(k)}} & \text{if $\mathbf{M}_{1, i, j} = 1$} \\
\end{cases}
\end{align}
$$

where $\mathbf{x}_{1, j}^{(k)}$ is the value of neuron $\mathbf{x}_{1, j}$ for sample $k$ in the batch, and $n$ is the batch size.
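
The following PyTorch sketch checks this on a batch: because the masks are applied in the forward pass, automatic differentiation produces exactly zero gradients for the pruned weights (a stand-in loss is used here).

```python
import torch

torch.manual_seed(0)
m1, m2, batch = 8, 16, 32

W1 = torch.randn(m2, m1, requires_grad=True)
b1 = torch.randn(m2, requires_grad=True)
M1 = (W1.detach().abs() >= 0.5).float()        # constant weight mask
m1_mask = (b1.detach().abs() >= 0.5).float()   # constant bias mask

x1 = torch.randn(batch, m1)                    # a batch of inputs
x2 = x1 @ (W1 * M1).t() + b1 * m1_mask         # masked forward pass for the batch
loss = x2.pow(2).mean()                        # stand-in loss
loss.backward()

# Gradients of pruned weights are exactly zero, so those weights never change.
assert torch.all(W1.grad[M1 == 0] == 0)
assert torch.all(b1.grad[m1_mask == 0] == 0)
```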

We iterate the weight pruning step and the fine-tuning step many times to obtain the desired model. This is called “iterative pruning”.

With some orchestration, the number of pruned weights can grow incrementally. For example, starting with $\mathbf{M}_1 = \mathbf{1}^{m_2 \times m_1}$, $\mathbf{m}_1 = \mathbf{1}^{m_2}$, $\mathbf{M}_2 = \mathbf{1}^{m_3 \times m_2}$, $\mathbf{m}_2 = \mathbf{1}^{m_3}$, $\cdots$, the weight pruning algorithm ensures that mask entries can only be flipped from $1$ to $0$ but never from $0$ to $1$. This allows us to eventually reach a target pruning percentage, which is very useful for running neural networks on resource-constrained devices.
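
A sketch of such an incremental schedule using magnitude pruning; the sparsity targets are made up for illustration, and the fine-tuning between rounds is elided.

```python
import torch

torch.manual_seed(0)
W = torch.randn(16, 8)
M = torch.ones_like(W)                          # start with nothing pruned

# Hypothetical schedule: prune roughly 20%, then 40%, then 60% of the weights.
for target_sparsity in (0.2, 0.4, 0.6):
    # The threshold is a quantile of the magnitudes of the masked weights, so
    # mask entries only ever flip from 1 to 0, never back from 0 to 1.
    threshold = torch.quantile((W * M).abs().flatten(), target_sparsity)
    M = M * ((W * M).abs() >= threshold).float()
    # ... fine-tune the masked network here before the next pruning round ...
    print(target_sparsity, float((M == 0).float().mean()))
```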

Finalizing Weights

Once the fine-tuning is finished, the model weights become final. We combine the masks and the weights together.

$$
\begin{gather}
\mathbf{W}_1 \leftarrow \mathbf{W}_1 \odot \mathbf{M}_{1} \\
\mathbf{W}_2 \leftarrow \mathbf{W}_2 \odot \mathbf{M}_{2} \\
\vdots
\end{gather}
$$

$$
\begin{gather}
\mathbf{b}_1 \leftarrow \mathbf{b}_1 \odot \mathbf{m}_{1} \\
\mathbf{b}_2 \leftarrow \mathbf{b}_2 \odot \mathbf{m}_{2} \\
\vdots
\end{gather}
$$

Now $\mathbf{W}_1$, $\mathbf{W}_2$, $\cdots$, $\mathbf{b}_1$, $\mathbf{b}_2$, $\cdots$, have become (very) sparse tensors, and the mask tensors $\mathbf{M}_{1}$, $\mathbf{M}_{2}$, $\cdots$, $\mathbf{m}_{1}$, $\mathbf{m}_{2}$, $\cdots$, are no longer useful.
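
A minimal PyTorch sketch of this finalization step, with a magnitude-based mask standing in for the one produced by pruning and fine-tuning:

```python
import torch

torch.manual_seed(0)
W1 = torch.randn(16, 8)
M1 = (W1.abs() >= 0.5).float()           # stand-in for the final constant mask

# Fold the constant mask into the weight; after this the mask can be discarded.
W1 = W1 * M1
print(float((W1 == 0).float().mean()))   # sparsity of the finalized (still dense) weight tensor
```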

Note that so far we have been using the dense tensor data structure to run all the computation for the neural network.

Pruning Neural Networks

Once all the weight pruning is done, the weights are still represented using the dense tensor data structure. This means that even though the weight tensors are sparse, the computation cost of the sparse model is exactly the same as that of the dense model. To reduce the computation cost, especially for unstructured pruning algorithms, we would have to use the sparse tensor data structure to represent the weights and the intermediate neurons.

The weight sparsity is introduced by various kinds of mechanisms. The neuron sparsity, however, is usually a direct consequence of the sparsity of the weights. For example,

$$
\begin{align}
\mathbf{x}_{2, i} &= \bigg( \sum_{j = 1}^{m_1} \mathbf{W}_{1, i, j} \mathbf{x}_{1, j} \bigg) + \mathbf{b}_{1, i} \\
\end{align}
$$

So if $\mathbf{W}_{1, i, j} = 0$ for all $j \in [1, m_1]$ and $\mathbf{b}_{1, i} = 0$, we must have $\mathbf{x}_{2, i} \equiv 0$. In this case, the neuron $\mathbf{x}_{2, i}$ becomes a dead neuron and no longer affects the next layer, so we can set $\mathbf{M}_{2, k, i} = 0$ for all $k \in [1, m_3]$, indicating that $\mathbf{x}_{2, i}$ is a dead neuron.

We iterate through all the layers in the neural network, finding all the dead synapse connections and dead neurons. With this information, we can use the sparse tensor data structure to represent the sparse weight tensors and sparse neuron tensors. In this way, the memory consumption and the computation cost of the model become much smaller.
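
In a framework with sparse tensor support, this conversion might look like the following PyTorch sketch; the sparsity pattern here is synthetic rather than the result of actual pruning.

```python
import torch

torch.manual_seed(0)
m1, m2 = 16, 64
W1 = torch.randn(m2, m1)
W1[torch.rand_like(W1) < 0.9] = 0.0            # pretend roughly 90% of the weights were pruned
b1 = torch.zeros(m2)

# Dead neurons: rows whose weights are all zero and whose bias is zero.
dead = (W1.abs().sum(dim=1) == 0) & (b1 == 0)
print(int(dead.sum()), "dead neurons out of", m2)

# Represent the pruned weight tensor with a sparse (COO) data structure.
W1_sparse = W1.to_sparse()
x1 = torch.randn(m1, 1)
x2 = torch.sparse.mm(W1_sparse, x1) + b1.unsqueeze(1)
assert torch.allclose(x2, W1 @ x1 + b1.unsqueeze(1), atol=1e-5)
```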

Note that for unstructured pruning methods this step is rarely done in practice, mainly because sparse tensor multiplication has so far been rarely supported in deep learning frameworks and hardware. Most pruning research projects are only concerned with how many weights and neurons they can prune from the model in theory, not with measuring how much faster the pruned model is than the intact model. To the best of my knowledge, I have not seen unstructured-pruned neural networks actually running on such hardware.

Others

There are pruning algorithms that directly prune neurons, which is more aggressive. For example, MorphNet puts regularization on the scale parameters in the batch normalization layers. The sparse scale parameters are then element-wise multiplied with the neurons.
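
As a rough illustration of the general flavor of such approaches (not MorphNet's exact regularizer), one could add an L1 penalty on the batch normalization scale parameters in PyTorch; channels whose scale is driven to zero can then be pruned away entirely.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
)

x = torch.randn(4, 3, 16, 16)
y = model(x)
task_loss = y.pow(2).mean()                     # stand-in for the real task loss

# L1 penalty on the batch-norm scale parameters encourages some of them to become zero.
l1_penalty = sum(m.weight.abs().sum() for m in model.modules()
                 if isinstance(m, nn.BatchNorm2d))
loss = task_loss + 1e-4 * l1_penalty
loss.backward()
```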
