Parameter Importance Approximation Via Taylor Expansion In Neural Network Pruning
Introduction
In neural network pruning, we typically evaluate the importance of each parameter in the network using some criterion and remove the parameters with the smallest importance. Instead of simply taking the absolute value of a parameter as its importance, we can measure importance as the change in the loss function when the parameter is removed, which is intuitively more meaningful.
However, such an evaluation is usually computationally expensive. In this blog post, I would like to discuss how to approximate the importance of the parameters in a neural network using a Taylor expansion, which accelerates the parameter importance evaluation considerably.
Parameter Importance Evaluation
The importance of a parameter in a neural network can be quantified by the change in the loss function when the parameter is removed.
The importance of a parameter $w_{m}$ in a neural network is a function of $w_{m}$, denoted as $\mathcal{I}(w_{m})$. It can be formulated as follows.
$$
\mathcal{I}(w_{m}) = \left( \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{w_{m} = 0} \right) \right)^{2}
$$
where $\mathcal{L}\left( \mathcal{D}, \mathbf{W} \right)$ is the loss function of the neural network with all parameters $\mathbf{W}$, and $\mathcal{L}\left( \mathcal{D}, \mathbf{W}|{w_{m} = 0} \right)$ is the loss function of the neural network with the parameter $w_{m}$ being set to zero.
More generally, the parameters can be grouped and the importance of a group of parameters can be quantified by the change in the loss function when the group of parameters $\mathbf{W}_{\mathcal{S}}$ is removed.
$$
\mathcal{I}(\mathbf{W}_{\mathcal{S}}) = \left( \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = 0} \right) \right)^{2}
$$
where $\mathcal{L}\left( \mathcal{D}, \mathbf{W} \right)$ is the loss function of the neural network with all parameters $\mathbf{W}$, and $\mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = 0} \right)$ is the loss function of the neural network with the group of parameters $\mathbf{W}_{\mathcal{S}}$ being set to zero.
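To make this concrete, below is a minimal sketch of the exact (non-approximated) evaluation in PyTorch. The model, the data, the loss function, and the choice of group (one output row of a weight matrix) are all assumptions made purely for illustration, not part of any particular pruning library.
```python
import torch
import torch.nn as nn

# Assumed toy setup for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
x, y = torch.randn(64, 16), torch.randint(0, 4, (64,))
loss_fn = nn.CrossEntropyLoss()

def loss_with_group_zeroed(param: torch.Tensor, row: int) -> torch.Tensor:
    """Evaluate L(D, W | W_S = 0), where W_S is one output row of `param`."""
    original = param.data[row].clone()
    param.data[row] = 0.0                      # remove the group
    with torch.no_grad():
        loss = loss_fn(model(x), y)
    param.data[row] = original                 # restore the group
    return loss

with torch.no_grad():
    full_loss = loss_fn(model(x), y)           # L(D, W)

weight = model[0].weight                       # each group = one row of this matrix
exact_importance = [
    (full_loss - loss_with_group_zeroed(weight, r)) ** 2
    for r in range(weight.shape[0])
]
```
Note that the loop above re-runs a forward pass for every group, which is exactly the cost we would like to avoid.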
Quantifying the importance of all the parameters in a neural network this way is computationally expensive, because the formula above has to be applied to each parameter or each group of parameters, each application requires re-evaluating the loss, and the number of parameters or groups in a neural network can be very large.
We therefore need a way to approximate the importance of the parameters so that the computational cost is reduced.
Parameter Importance Approximation Via Taylor Expansion
We define the difference between the loss of the neural network with the original parameters and the loss of the neural network in which the group of parameters $\mathbf{W}_{\mathcal{S}}$ is allowed to vary as a function of $\mathbf{W}_{\mathcal{S}}$, denoted as $f(\mathbf{W}_{\mathcal{S}})$.
$$
\begin{align}
f(\mathbf{W}_{\mathcal{S}}) &= \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)
\end{align}
$$
Note that
$$
\begin{align}
f(0) &= \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = 0} \right)
\end{align}
$$
We can approximate the function $f(\mathbf{W}_{\mathcal{S}})$ using the Taylor expansion around the original parameter values of $\mathbf{W}_{\mathcal{S}}$ in $\mathbf{W}$, denoted as $W_{\mathcal{S}}$.
$$
\begin{align}
f(\mathbf{W}_{\mathcal{S}})
&= f(W_{\mathcal{S}}) + \nabla f(W_{\mathcal{S}}) \cdot (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \frac{1}{2} (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}})^{\top} \nabla^{2} f(W_{\mathcal{S}}) (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \cdots
\end{align}
$$
We notice that $f(W_{\mathcal{S}}) = 0$, because the two loss terms coincide when $\mathbf{W}_{\mathcal{S}}$ takes its original values $W_{\mathcal{S}}$. Therefore, we have
$$
\begin{align}
f(\mathbf{W}_{\mathcal{S}})
&= \nabla f(W_{\mathcal{S}}) \cdot (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \frac{1}{2} (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}})^{\top} \nabla^{2} f(W_{\mathcal{S}}) (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \cdots
\end{align}
$$
This function evaluated at $\mathbf{W}_{\mathcal{S}} = 0$ is
$$
\begin{align}
f(0)
&= - \nabla f(W_{\mathcal{S}}) \cdot W_{\mathcal{S}} + \frac{1}{2} W_{\mathcal{S}}^{\top} \nabla^{2} f(W_{\mathcal{S}}) W_{\mathcal{S}} + \cdots
\end{align}
$$
The first-order derivative $\nabla f(\mathbf{W}_{\mathcal{S}})$ is the negative of the gradient of the loss function $\mathcal{L}$ with respect to the parameters $\mathbf{W}_{\mathcal{S}}$, because $\mathcal{L}\left( \mathcal{D}, \mathbf{W} \right)$ is a constant that does not depend on the variable $\mathbf{W}_{\mathcal{S}}$.
$$
\begin{align}
\nabla f(\mathbf{W}_{\mathcal{S}}) &= \frac{\partial f(\mathbf{W}_{\mathcal{S}})}{\partial \mathbf{W}_{\mathcal{S}}} \\
&= \frac{\partial \left( \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right) \right)}{\partial \mathbf{W}_{\mathcal{S}}} \\
&= \frac{\partial \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right)}{\partial \mathbf{W}_{\mathcal{S}}} - \frac{\partial \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)}{\partial \mathbf{W}_{\mathcal{S}}} \\
&= 0 - \frac{\partial \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)}{\partial \mathbf{W}_{\mathcal{S}}} \\
&= - \frac{\partial \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)}{\partial \mathbf{W}_{\mathcal{S}}} \\
&= - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)
\end{align}
$$
The first-order derivative $\nabla f(\mathbf{W}_{\mathcal{S}})$ evaluated at $W_{\mathcal{S}}$, i.e., $\nabla f(W_{\mathcal{S}})$, is (up to its sign) exactly the gradient already computed by backpropagation during neural network training, so we get it essentially for free.
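For instance, in PyTorch the relevant gradients are simply the `.grad` tensors populated by a standard backward pass. Here is a minimal sketch; the model and data are illustrative assumptions.
```python
import torch
import torch.nn as nn

# Assumed toy setup for illustration.
model = nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))

loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()                                 # the usual training backward pass

# The gradient of the loss with respect to the parameters is now available
# at no extra cost; the gradient of f is simply its negative.
grad = model.weight.grad
print(grad.shape)
```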
Similarly, the second-order derivative $\nabla^{2} f(\mathbf{W}_{\mathcal{S}})$ is the negative of the Hessian of the loss function $\mathcal{L}$ with respect to the parameters $\mathbf{W}_{\mathcal{S}}$.
$$
\begin{align}
\nabla^{2} f(\mathbf{W}_{\mathcal{S}})
&= - \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}}} \right)
\end{align}
$$
The second-order derivative $\nabla^{2} f(\mathbf{W}_{\mathcal{S}})$ evaluated at $W_{\mathcal{S}}$, however, is not free to compute, so we usually keep only the first-order term of the Taylor expansion for the approximation.
Thus, we have
$$
\begin{align}
f(0) \approx - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) \cdot W_{\mathcal{S}}
\end{align}
$$
and the importance of the group of parameters $\mathbf{W}_{\mathcal{S}}$ can be approximated as
$$
\begin{align}
\mathcal{I}(\mathbf{W}_{\mathcal{S}})
&= \left( f(0) \right)^{2} \\
&\approx \left( \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) \cdot W_{\mathcal{S}} \right)^{2}
\end{align}
$$
Note that this is just the square of the dot product between the gradient of the loss with respect to the group of parameters $\mathbf{W}_{\mathcal{S}}$ and the values of those parameters, which is straightforward to compute.
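A minimal sketch of this first-order criterion in PyTorch, again assuming purely for illustration that a group is one output row of a weight matrix:
```python
import torch
import torch.nn as nn

# Assumed toy setup for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
x, y = torch.randn(64, 16), torch.randint(0, 4, (64,))

loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()                                 # gradients for all parameters at once

weight = model[0].weight                        # each group = one output row
# Importance of group S: (gradient of L w.r.t. W_S dotted with W_S)^2,
# i.e. a squared dot product computed row by row.
approx_importance = (weight.grad * weight.data).sum(dim=1) ** 2
```
A single backward pass yields the importance scores of every group in the matrix at once.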
An alternative way to derive these formulas is to perform the Taylor expansion directly on the loss function $\mathcal{L}$ with respect to the group of parameters $\mathbf{W}_{\mathcal{S}}$, around the original values $W_{\mathcal{S}}$.
$$
\begin{align}
\mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} \right)
&= \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) + \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) \cdot (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \frac{1}{2} (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}})^{\top} \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) (\mathbf{W}_{\mathcal{S}} - W_{\mathcal{S}}) + \cdots
\end{align}
$$
When the group of parameters $\mathbf{W}_{\mathcal{S}}$ is set to zero, we have
$$
\begin{align}
\mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = 0 \right)
&= \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) + \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) \cdot (- W_{\mathcal{S}}) + \frac{1}{2} (- W_{\mathcal{S}})^{\top} \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) (- W_{\mathcal{S}}) + \cdots \\
&= \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) \cdot W_{\mathcal{S}} + \frac{1}{2} W_{\mathcal{S}}^{\top} \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) W_{\mathcal{S}} + \cdots
\end{align}
$$
Thus, the change in the loss function $f(0)$ can be approximated as
$$
\begin{align}
f(0) &= \mathcal{L}\left( \mathcal{D}, \mathbf{W} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = 0} \right) \\
&= \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) - \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = 0} \right) \\
&= \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) - \left( \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) \cdot W_{\mathcal{S}} + \frac{1}{2} W_{\mathcal{S}}^{\top} \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W} | \mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}} \right) W_{\mathcal{S}} + \cdots \right) \\
&= - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) \cdot W_{\mathcal{S}} + \frac{1}{2} W_{\mathcal{S}}^{\top} \nabla^{2} \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) W_{\mathcal{S}} + \cdots \\
&\approx - \nabla \mathcal{L}\left( \mathcal{D}, \mathbf{W}|{\mathbf{W}_{\mathcal{S}} = W_{\mathcal{S}}} \right) \cdot W_{\mathcal{S}}
\end{align}
$$
Therefore, computing the importance of the group of parameters $\mathbf{W}_{\mathcal{S}}$ with this approximation, which needs only the gradients from a single backward pass, can be much faster than computing it with the original formula, which requires re-evaluating the loss for every group.
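To make the comparison tangible, the two criteria can be computed side by side on a toy model. The setup below is again an illustrative assumption rather than a prescribed recipe.
```python
import torch
import torch.nn as nn

# Assumed toy setup for illustration; groups are the output rows of one weight matrix.
torch.manual_seed(0)
model = nn.Linear(16, 4)
x, y = torch.randn(64, 16), torch.randint(0, 4, (64,))
loss_fn = nn.CrossEntropyLoss()

# Approximated importance: a single backward pass covers every group.
loss = loss_fn(model(x), y)
loss.backward()
w, g = model.weight.data, model.weight.grad
approx = ((g * w).sum(dim=1)) ** 2

# Exact importance: one extra forward pass per group.
exact = torch.empty(w.shape[0])
with torch.no_grad():
    base = loss_fn(model(x), y)
    for r in range(w.shape[0]):
        saved = w[r].clone()
        w[r] = 0.0
        exact[r] = (base - loss_fn(model(x), y)) ** 2
        w[r] = saved

print(torch.stack([approx, exact], dim=1))      # side-by-side scores per group
```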
Conclusions
Structured neural network pruning usually requires grouping the parameters of the network and pruning the groups with the smallest importance. The importance of a group of parameters can be approximated using a first-order Taylor expansion, which is much faster to compute than the original non-approximated criterion.