LoRA and LoRAPrune
Introduction
Large language models have so many parameters that even fine-tuning one on a small dataset can be extremely computationally expensive.
Pruning a large language model can be just as expensive, because the importance of every parameter in the model has to be evaluated.
In this blog post, I would like to discuss how to accelerate fine-tuning of large language models using low-rank adaptation (LoRA) and how to accelerate pruning of large language models using LoRAPrune.
LoRA
LoRA assumes that the parameter update matrix learned during fine-tuning is low-rank and represents it as the product of two low-rank matrices. During fine-tuning, LoRA freezes the full-rank parameter matrix and updates only the two low-rank matrices, which have far fewer parameters. This significantly reduces the computational cost of fine-tuning a large language model.
Concretely, given a pre-trained weight matrix $W_{0} \in \mathbb{R}^{d \times k}$, the fine-tuning specific update matrix $\Delta W$ is decomposed into two low-rank matrices $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ such that $\Delta W = B A$, where $r \ll \min(d, k)$.
For the pre-trained model forward pass $h = W_{0} x$, the modified forward pass during fine-tuning becomes
$$
\begin{align}
h &= \left( W_{0} + \Delta W \right) x \\
&= \left( W_{0} + B A \right) x \\
&= W_{0} x + B A x \\
&= W_{0} x + B \left(A x\right) \\
\end{align}
$$
Note that to save memory and computation, we usually compute $A x$ first and then compute $B \left(A x\right)$, rather than forming $B A$.
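The effect of the evaluation order can be checked with a minimal NumPy sketch. The sizes `d`, `k`, `r` and the random matrices below are hypothetical and chosen only for illustration. The sketch confirms that both orders give the same $h$ and counts the multiply-adds spent on the LoRA branch for one input vector.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 64, 48, 4  # hypothetical layer sizes, with r much smaller than d and k

W0 = rng.standard_normal((d, k))  # frozen pre-trained weight
B = rng.standard_normal((d, r))   # LoRA factor B
A = rng.standard_normal((r, k))   # LoRA factor A
x = rng.standard_normal((k, 1))   # one input column vector

# Order used in practice: two thin matrix-vector products.
h_factored = W0 @ x + B @ (A @ x)

# Reference order: materialize the d x k update first.
h_dense = (W0 + B @ A) @ x

# Multiply-add counts for the LoRA branch only, per input vector.
macs_factored = r * k + d * r   # A x, then B (A x)
macs_dense = d * r * k + d * k  # form B A, then (B A) x

print(np.allclose(h_factored, h_dense))
print(macs_factored, macs_dense)
```

With these sizes the factored order needs 448 multiply-adds for the LoRA branch, while forming $B A$ first needs 15360.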
In the modified backward pass during fine-tuning, we only need to compute the gradients with respect to $A$ and $B$ and update $A$ and $B$ while the original full-rank weight matrix $W_{0}$ is frozen.
$$
\begin{align}
\frac{\partial \mathcal{L}}{\partial B}
&= \frac{\partial \mathcal{L}}{\partial h} \frac{\partial h}{\partial B} \\
&= \frac{\partial \mathcal{L}}{\partial h} \left(A x\right)^{\top} \\
\end{align}
$$
$$
\begin{align}
\frac{\partial \mathcal{L}}{\partial A}
&= \frac{\partial \mathcal{L}}{\partial h} \frac{\partial h}{\partial \left(A x\right)} \frac{\partial \left(A x\right)}{\partial A} \\
&= B^{\top} \frac{\partial \mathcal{L}}{\partial h} x^{\top} \\
\end{align}
$$
Note that the matrix $BA$ is never materialized in either the forward pass or the backward pass.
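The two gradient formulas can be verified numerically. The sketch below assumes a squared-error loss $\mathcal{L} = \frac{1}{2} \lVert h - y \rVert^{2}$, which is not part of the derivation and is used only because its gradient $\frac{\partial \mathcal{L}}{\partial h} = h - y$ is easy to write down. The helper `numerical_grad` is a hypothetical central-difference checker.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 6, 5, 2  # hypothetical sizes
W0 = rng.standard_normal((d, k))
B = rng.standard_normal((d, r))
A = rng.standard_normal((r, k))
x = rng.standard_normal((k, 1))
y = rng.standard_normal((d, 1))  # assumed regression target


def loss(B_, A_):
    """Squared-error loss of the LoRA forward pass (assumed for this check)."""
    h = W0 @ x + B_ @ (A_ @ x)
    return 0.5 * float(np.sum((h - y) ** 2))


# Analytic gradients from the derivation; B A is never formed.
u = A @ x                 # A x, shape (r, 1)
g = (W0 @ x + B @ u) - y  # dL/dh, shape (d, 1)
grad_B = g @ u.T          # dL/dB = dL/dh (A x)^T
grad_A = (B.T @ g) @ x.T  # dL/dA = B^T dL/dh x^T


def numerical_grad(f, M, eps=1e-6):
    """Central finite differences, one entry of M at a time."""
    G = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        Mp, Mm = M.copy(), M.copy()
        Mp[idx] += eps
        Mm[idx] -= eps
        G[idx] = (f(Mp) - f(Mm)) / (2 * eps)
    return G


num_B = numerical_grad(lambda M: loss(M, A), B)
num_A = numerical_grad(lambda M: loss(B, M), A)

print(np.allclose(grad_B, num_B, atol=1e-5))
print(np.allclose(grad_A, num_A, atol=1e-5))
```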
Just like RepVGG, the fine-tuning specific update matrix can be fused into the pre-trained weight matrix, resulting in no additional inference overhead in production. In other words, after fine-tuning, the fine-tuned weight matrix for production becomes
$$
\begin{align}
W &= W_{0} + \Delta W \\
&= W_{0} + B A \\
\end{align}
$$
This is the only time we compute the matrix $BA$.
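A minimal sketch of the merge step, again with hypothetical sizes and random matrices, shows that the fused weight has the same shape as $W_{0}$ and reproduces the LoRA outputs on a batch of inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, n = 8, 6, 2, 10  # hypothetical sizes; n is the batch size
W0 = rng.standard_normal((d, k))
B = rng.standard_normal((d, r))
A = rng.standard_normal((r, k))
X = rng.standard_normal((k, n))  # n input vectors stored as columns

# Fine-tuning-time forward pass: frozen weight plus low-rank branch.
H_lora = W0 @ X + B @ (A @ X)

# One-time merge for production: the only place B A is formed.
W = W0 + B @ A

# Production forward pass: a single matrix product, same cost as the original layer.
H_merged = W @ X

print(W.shape == W0.shape)
print(np.allclose(H_lora, H_merged))
```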
LoRAPrune
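LoRAPrune was proposed to accelerate parameter importance evaluation when LoRA is used for both pruning and fine-tuning a neural network.
In the context of LoRA, pruning the group of parameters $\mathbf{W}_{\mathcal{S}}$ is equivalent to setting $\left(\mathbf{BA}\right)_{\mathcal{S}} = -W_{\mathcal{S}}$, so that the effective weights $W_{\mathcal{S}} + \left(\mathbf{BA}\right)_{\mathcal{S}}$ become zero. The importance of the group can therefore be formulated as follows.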
$$
\begin{align}
\mathcal{I}(\left(\mathbf{BA}\right)_{\mathcal{S}})
&= \left( \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right) \right) - \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\mathbf{\left(BA\right)}_{\mathcal{S}} = -W_{\mathcal{S}}} \right) \right)^{2}
\end{align}
$$
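To make the definition concrete, the sketch below evaluates this importance exactly by ablation for a single linear layer. Several things here are assumptions for illustration only: the mean squared-error loss, the random data `X` and `Y`, and the choice of treating each output row of the weight matrix as one group $\mathcal{S}$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, n = 5, 4, 2, 16  # hypothetical sizes; n is the dataset size
W0 = rng.standard_normal((d, k))
B = rng.standard_normal((d, r))
A = rng.standard_normal((r, k))
X = rng.standard_normal((k, n))  # assumed inputs, one column per sample
Y = rng.standard_normal((d, n))  # assumed targets


def loss(delta_W):
    """Mean squared-error loss of the layer with update matrix delta_W."""
    H = (W0 + delta_W) @ X
    return 0.5 * float(np.mean(np.sum((H - Y) ** 2, axis=0)))


delta = B @ A  # materialized here only to make the ablation explicit
base = loss(delta)

importance = np.zeros(d)
for i in range(d):
    pruned = delta.copy()
    pruned[i, :] = -W0[i, :]  # row i of W0 + BA becomes exactly zero
    importance[i] = (base - loss(pruned)) ** 2

print(importance)
```

Each group costs one extra forward pass over the dataset, which is the expense that motivates the approximation derived next.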
We define
$$
\begin{align}
f(\left(\mathbf{BA}\right)_{\mathcal{S}})
&= \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right) \right) - \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\mathbf{\left(BA\right)}_{\mathcal{S}}} \right)
\end{align}
$$
Note that
$$
\begin{align}
f(-W_{\mathcal{S}})
&= \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right) \right) - \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\mathbf{\left(BA\right)}_{\mathcal{S}} = -W_{\mathcal{S}}} \right)
\end{align}
$$
We can approximate the function $f(\mathbf{\left(BA\right)}_{\mathcal{S}})$ using the Taylor expansion around the current parameter values of $\mathbf{BA}$, denoted as $\left(BA\right)_{\mathcal{S}}$.
$$
\begin{align}
f(\mathbf{\left(BA\right)}_{\mathcal{S}})
&= f(\left(BA\right)_{\mathcal{S}}) + \nabla f(\left(BA\right)_{\mathcal{S}}) \cdot (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \frac{1}{2} (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}})^{\top} \nabla^{2} f(\left(BA\right)_{\mathcal{S}}) (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \cdots
\end{align}
$$
We notice that $f(\left(BA\right)_{\mathcal{S}}) = 0$. Therefore, we have
$$
\begin{align}
f(\mathbf{\left(BA\right)}_{\mathcal{S}})
&= \nabla f(\left(BA\right)_{\mathcal{S}}) \cdot (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \frac{1}{2} (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}})^{\top} \nabla^{2} f(\left(BA\right)_{\mathcal{S}}) (\mathbf{\left(BA\right)}_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \cdots
\end{align}
$$
This function evaluated at $\left(\mathbf{BA}\right)_{\mathcal{S}} = -W_{\mathcal{S}}$ is
$$
\begin{align}
f(-W_{\mathcal{S}})
&= \nabla f(\left(BA\right)_{\mathcal{S}}) \cdot (-W_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \frac{1}{2} (-W_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}})^{\top} \nabla^{2} f(\left(BA\right)_{\mathcal{S}}) (-W_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) + \cdots
\end{align}
$$
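The first order derivative $\nabla f(\left(\mathbf{BA}\right)_{\mathcal{S}})$ is the negative of the gradient of the loss function $\mathcal{L}$ with respect to the parameters $\left(\mathbf{BA}\right)_{\mathcal{S}}$, because the first term of $f$ does not depend on $\left(\mathbf{BA}\right)_{\mathcal{S}}$.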
$$
\begin{align}
\nabla f(\left(\mathbf{BA}\right)_{\mathcal{S}})
&= - \nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}}} \right)
\end{align}
$$
The first order derivative $\nabla f(\left(\mathbf{BA}\right)_{\mathcal{S}})$ evaluated at $\left(BA\right)_{\mathcal{S}}$, $\nabla f(\left(BA\right)_{\mathcal{S}})$, is the kind of quantity that backpropagation already computes when training a conventional neural network, where we would get it for free.
Thus, keeping only the first-order term, we have
$$
\begin{align}
f(-W_{\mathcal{S}})
&\approx - \nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right) \cdot (-W_{\mathcal{S}} - \left(BA\right)_{\mathcal{S}}) \\
&= \nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right) \cdot (W_{\mathcal{S}} + \left(BA\right)_{\mathcal{S}}) \\
\end{align}
$$
and the importance of the group of parameters $\mathbf{W}_{\mathcal{S}}$ can be approximated as
$$
\begin{align}
\mathcal{I}(\left(\mathbf{BA}\right)_{\mathcal{S}})
&\approx \left( \nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right) \cdot (W_{\mathcal{S}} + \left(BA\right)_{\mathcal{S}}) \right)^{2}
\end{align}
$$
This approximation is similar to the one derived for conventional neural networks without LoRA. The problem, as mentioned in the previous section, is that the matrix $BA$ is never materialized in either the forward pass or the backward pass. Recomputing $\left(BA\right)_{\mathcal{S}}$ is probably no big deal, but the LoRA backward pass does not produce $\nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right)$, so we still need a way to compute it exactly.
In the original LoRAPrune paper, the authors proposed that $\nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right)$ is proportional to the change of $\left(\mathbf{BA}\right)_{\mathcal{S}}$ between the current time step $t$ and the previous time step $t-1$. To recover $\left(\mathbf{BA}\right)_{\mathcal{S}}$ at time step $t-1$ from time step $t$, they used the gradient information with respect to $\mathbf{B}$ and $\mathbf{A}$.
In my opinion, there is no need to compute this term in such a complicated and approximate way. Following the derivation of the gradients with respect to the matrices $A$ and $B$ in the previous section, suppose $BA$ were actually computed, so that
$$
\begin{align}
h &= \left( W_{0} + \Delta W \right) x \\
&= \left( W_{0} + B A \right) x \\
&= W_{0} x + B A x \\
&= W_{0} x + \left(B A\right) x \\
\end{align}
$$
Then the exact $\nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right)$ is just
$$
\begin{align}
\nabla \mathcal{L}\left( \mathcal{D}, \left(\mathbf{BA}\right)|{\left(\mathbf{BA}\right)_{\mathcal{S}} = \left(BA\right)_{\mathcal{S}}} \right)
&= \left( \frac{\partial \mathcal{L}}{\partial h} x^{\top} \right)_{\mathcal{S}} \\
&= \left( \frac{\partial \mathcal{L}}{\partial h} \right)_{\mathcal{S}} \left( x^{\top} \right)_{\mathcal{S}} \\
\end{align}
$$
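where $\left( \frac{\partial \mathcal{L}}{\partial h} \right)_{\mathcal{S}}$ and $\left( x^{\top} \right)_{\mathcal{S}}$ are the submatrices of $\frac{\partial \mathcal{L}}{\partial h}$ and $x^{\top}$, respectively, whose product gives the gradient with respect to $\left(\mathbf{BA}\right)_{\mathcal{S}}$. Both $\frac{\partial \mathcal{L}}{\partial h}$ and $x$ are already available during the LoRA backward pass, since they are needed for the gradients with respect to $A$ and $B$.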
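The identity above can be sketched in NumPy. This is an illustration of the argument in this post, not the procedure from the LoRAPrune paper. It assumes a batched mean squared-error loss, random data, and rows of the weight matrix as pruning groups. It checks that the gradient with respect to $BA$ is consistent with the LoRA gradients via the chain rule, that a sub-block can be computed from slices of $\frac{\partial \mathcal{L}}{\partial h}$ and $x$ alone, and then forms the first-order importance scores.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, n = 5, 4, 2, 16  # hypothetical sizes; n is the batch size
W0 = rng.standard_normal((d, k))
B = rng.standard_normal((d, r))
A = rng.standard_normal((r, k))
X = rng.standard_normal((k, n))  # assumed inputs, one column per sample
Y = rng.standard_normal((d, n))  # assumed targets

# Forward pass without forming B A; loss is 1/(2n) * sum ||h - y||^2.
H = W0 @ X + B @ (A @ X)
G = (H - Y) / n  # dL/dH, shape (d, n)

# Gradient with respect to the product B A, summed over the batch.
grad_BA = G @ X.T  # shape (d, k)

# LoRA gradients from the earlier derivation, in batched form.
grad_B = G @ (A @ X).T
grad_A = B.T @ G @ X.T

# Sub-block for a group S given by a set of rows and columns (hypothetical choice).
rows, cols = [1, 3], [0, 2]
grad_S = G[rows, :] @ X[cols, :].T

# First-order importance, treating each output row as one group.
W_eff = W0 + B @ A  # recomputed once for scoring
importance = np.sum(grad_BA * W_eff, axis=1) ** 2

print(np.allclose(grad_B, grad_BA @ A.T), np.allclose(grad_A, B.T @ grad_BA))
print(np.allclose(grad_S, grad_BA[np.ix_(rows, cols)]))
print(importance)
```

Note that the full `grad_BA` has the same shape as $W_{0}$, so in practice it may be preferable to compute only the sub-blocks that are needed for scoring.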
In the worst case, we could always cache the matrices $A$ and $B$ from the last time step and recompute the matrix $BA$ from the last time step. Storing the matrices $A$ and $B$ in memory is not a big deal because they are low-rank matrices.
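A quick back-of-the-envelope calculation shows how small the cached factors are. The layer size, rank, and 2-byte storage format below are assumptions chosen for illustration, not values taken from the LoRAPrune paper.

```python
# Hypothetical layer: d = k = 4096, LoRA rank r = 8, values stored in 2 bytes (e.g. fp16).
d, k, r = 4096, 4096, 8
bytes_per_value = 2

factors_bytes = (d * r + r * k) * bytes_per_value  # cached copies of B and A
dense_bytes = d * k * bytes_per_value              # a cached copy of B A

print(factors_bytes / 2**10, "KiB for A and B")
print(dense_bytes / 2**20, "MiB for BA")
print(dense_bytes // factors_bytes, "times larger")
```

For this configuration, caching $A$ and $B$ takes 128 KiB per layer, while caching $BA$ would take 32 MiB, which is 256 times larger.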