AWQ: Activation-Aware Weight Quantization
Introduction
To accelerate large language model (LLM) inference, low-bit-width group-wise quantization is a common technique for reducing the memory footprint and increasing the computational efficiency. To maintain the accuracy of the quantized model, usually only the weights are quantized and the activations are kept in half precision. Even so, the accuracy of the quantized model is sometimes still not satisfactory.
Activation-aware weight quantization (AWQ) is a quantization technique that reduces the quantization error of group-wise weight-only quantization by scaling the weights and inversely scaling the activations. The analysis shows that the quantization error of group-wise weight-only quantization is activation-aware, meaning that the quantization error is proportional to the input activation magnitude. By reducing the effective input activation magnitude, the quantization error can be reduced. Additional kernel fusion keeps the computational efficiency of the AWQ model the same as that of the ordinary group-wise weight-only quantization model.
In this blog post, we will discuss the theory behind AWQ and how it can be used to quantize large language models (LLMs) to achieve the same performance as group-wise weight-only quantization but with better accuracy.
Group-Wise Quantization
Group-Wise Quantization, Per-Tensor Quantization and Per-Channel Quantization
To quantize a tensor, different elements in the tensor can be grouped together and quantized using the same quantization scale and quantization zero point. This is called group-wise quantization. The group size can be adjusted to balance the trade-off between quantization accuracy and quantization metadata size. In general, a larger group size leads to smaller quantization metadata size but lower quantization accuracy, and a smaller group size leads to larger quantization metadata size but higher quantization accuracy.
The commonly seen per-tensor quantization and per-channel quantization are special cases of group-wise quantization. In per-tensor quantization, the entire tensor is quantized as a single group, and in per-channel quantization, each channel is quantized as a single group.
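For a concrete picture of these granularities, the following minimal PyTorch sketch computes only the symmetric quantization scales at each granularity; the tensor shape, the 4-bit setting, and the group size of 32 are arbitrary choices for illustration rather than anything prescribed by AWQ.

```python
import torch

# A toy weight matrix of shape (out_features, in_features).
weight = torch.randn(128, 512)
num_bits = 4
q_max = 2 ** (num_bits - 1) - 1  # 7 for 4-bit symmetric quantization

# Per-tensor quantization: the entire tensor is one group with one scale.
scale_per_tensor = weight.abs().max() / q_max

# Per-channel quantization: each output channel (row) is one group.
scale_per_channel = weight.abs().amax(dim=1) / q_max  # shape (128,)

# Group-wise quantization: every `group_size` input channels (columns)
# within a row form one group and share one scale.
group_size = 32
scale_group_wise = (
    weight.reshape(128, 512 // group_size, group_size).abs().amax(dim=-1) / q_max
)  # shape (128, 16)

print(scale_per_tensor.shape, scale_per_channel.shape, scale_group_wise.shape)
```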
Quantization Compression Ratio
Suppose we have a tensor of $N$ elements, each of size $a$ bits, and we want to perform group-wise quantization with a group size of $G$ so that the element size is reduced to $b$ bits. Assuming the (average) size of the quantization metadata per group is $s$ bits, the compression ratio of quantization can be defined as the ratio of the original tensor size to the sum of the quantized tensor size and the quantization metadata size.
$$
\begin{align}
R &= \frac{N a}{\frac{N}{G} s + N b} \\
&= \frac{a}{\frac{s}{G} + b}
\end{align}
$$
where $\frac{s}{G}$ is the extra bits per element introduced by quantization metadata.
In per-tensor quantization, $G = N$ and usually $N \gg s$, so $R \approx \frac{a}{b}$, which is the maximum compression ratio that can be achieved. However, per-tensor quantization sometimes leads to a large quantization error, and group-wise quantization with a smaller group size can be used to reduce the quantization error. The goal is then to keep the compression ratio $R$ as high as possible, i.e., to keep the ratio of the extra bits per element over the quantized element size, $\frac{s}{Gb}$, as small as possible.
In some neural networks such as large language models (LLMs), because the autoregressive decoding process is usually memory-bound (even if batching is used), the higher the compression ratio, the faster the decoding process. In order to achieve a high compression ratio, very low-bitwidth quantization, e.g., 4-bit quantization, is used, i.e., $b = 4$ bits. Usually, the (average) size of quantization metadata per group $s$ is dominated by the quantization scale, which can just be a half-precision floating-point number, i.e., $s = 16$ bits, for symmetric quantization. Therefore, the key factor that affects the compression ratio is the group size $G$. By increasing the group size, the compression ratio can be increased, but the quantization error may also increase. Because $b = 4$ is already very small, $G$ cannot be too small, otherwise the compression ratio will be significantly affected. For example, with $s = 16$ bits, $b = 4$ bits, and $G = 32$, we have $\frac{s}{G} = \frac{16}{32} = 0.5$ extra bits per element introduced by quantization metadata. Because of this metadata overhead of $\frac{s}{Gb} = \frac{16}{32 \times 4} = 0.125 = 12.5\%$ relative to the quantized element size, the memory-bound decoding performance of the quantized model can be up to $12.5\%$ worse than that of a model quantized with the same bitwidth but with per-tensor quantization.
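To double-check the arithmetic above, here is a small Python helper (the function name is just for illustration) that reproduces these numbers, assuming the original weights are half precision, i.e., $a = 16$ bits.

```python
def compression_ratio(a: float, b: float, s: float, group_size: float) -> float:
    """Compression ratio R = a / (s / G + b) for group-wise quantization."""
    return a / (s / group_size + b)

# FP16 weights quantized to 4 bits, with a 16-bit scale per group of 32 elements.
a, b, s, G = 16, 4, 16, 32
print(compression_ratio(a, b, s, G))  # 16 / 4.5 ≈ 3.56, vs. 4.0 for per-tensor quantization
print(s / (G * b))                    # 0.125, i.e., 12.5% metadata overhead
```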
Group-Wise Quantization In Practice
While group-wise quantization can theoretically have finer or coarser granularity than per-channel quantization, in practice, and in the context of AWQ and many papers from Song Han’s lab, group-wise quantization refers to grouping multiple channels together. Accordingly, the group size, $G$, refers to the number of channels in a group, and the number of elements, $N$, is the total number of channels in the tensor. That is to say, a channel (neuron) is counted as a single element.
AWQ: Activation-Aware Weight Quantization
Group-Wise Weight-Only-Quantization
Assume that a group in group-wise quantization consists of multiple channels and that all the channels in the group share the same quantization metadata. Consider a weight matrix $\mathbf{w}$ consisting of a group of $G$ channels, the corresponding input activation tensor $\mathbf{x}$, and the output activation tensor $\mathbf{y} = \mathbf{w} \mathbf{x}$. The weight-only-quantization counterpart is $\mathbf{y} \approx Q(\mathbf{w}) \mathbf{x}$, where $Q(\mathbf{w})$ is the quantized and then dequantized weight matrix. More specifically, the symmetric quantization and dequantization function $Q$ is defined as
$$
\begin{align}
Q(\mathbf{w}) &= \Delta \cdot \text{Round}\left(\frac{\mathbf{w}}{\Delta}\right) \\
\end{align}
$$
where
$$
\begin{align}
\Delta &= \frac{\max( \lvert \mathbf{w} \rvert )}{2^{b-1}-1}
\end{align}
$$
is commonly denoted as the quantization scale, and $b$ is the bitwidth of the quantized value.
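As a minimal sketch, the quantize-dequantize function $Q$ could be implemented in PyTorch as follows, assuming the whole input tensor is one group sharing a single scale; the function name and the default 4-bit setting are illustrative.

```python
import torch

def quantize_dequantize(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Symmetric round-to-nearest quantization followed by dequantization,
    with one scale shared by the whole group."""
    q_max = 2 ** (num_bits - 1) - 1
    delta = w.abs().max() / q_max  # quantization scale Delta
    # With this choice of Delta, w / Delta already lies in [-q_max, q_max],
    # so no clipping is needed.
    return delta * torch.round(w / delta)

w = torch.randn(32, 128)  # one group of weights
w_dq = quantize_dequantize(w)
print((w - w_dq).abs().max())  # worst-case absolute error is about Delta / 2
```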
Group-Wise Weight-Only-Quantization Error
Unlike activation quantization, the quantization process in weight-only quantization only involves rounding errors and no clipping, because the range of the weight matrix is always known and fixed. The absolute quantization error of the weight-only-quantized multiplication $y = Q(w) x$ (note that this is written element-wise rather than as a matrix multiplication) can be expressed as
$$
\begin{align}
\text{Err}(Q(w) x) = \Delta \cdot \text{RoundErr}\left(\frac{w}{\Delta}\right) \cdot x \\
\end{align}
$$
where
$$
\begin{align}
\text{RoundErr}\left(\frac{w}{\Delta}\right) = \left\lvert \frac{w}{\Delta} - \text{Round}\left(\frac{w}{\Delta}\right) \right\rvert
\end{align}
$$
Assuming each element of $\frac{w}{\Delta}$ is uniformly distributed, the absolute rounding error follows a uniform distribution in the range $[0, 0.5]$. Therefore, the expected absolute rounding error is $\mathbb{E}\left[\text{RoundErr}\left(\frac{w}{\Delta}\right)\right] = 0.25$, which is a constant. In this case, the absolute quantization error of the weight-only-quantized multiplication is proportional to the quantization scale $\Delta$ and the input activation $x$. Within one group, because $\Delta$ is shared by all the weights in the group, the quantization error of the group is proportional only to the input activation $x$. This means that the quantization error of weight-only quantization is activation-aware: if the input activation magnitude is small, the quantization error is small, and if the input activation magnitude is large, the quantization error is large. Therefore, if we could somehow reduce the input activation magnitude in weight-only quantization, the quantization error could be reduced.
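This expectation is easy to verify empirically. The following Monte Carlo sketch (a sanity check, not part of the AWQ algorithm) estimates the mean absolute rounding error and shows that the per-element output error grows linearly with the activation magnitude.

```python
import torch

torch.manual_seed(0)

# Stand-in for w / Delta; any reasonably wide range works for the check.
v = torch.rand(1_000_000) * 100.0
round_err = (v - torch.round(v)).abs()
print(round_err.mean())  # ~ 0.25

# Err(Q(w) x) = Delta * RoundErr(w / Delta) * |x| is proportional to |x|.
delta = 0.1
for x in (0.1, 1.0, 10.0):
    print(x, (delta * round_err * x).mean())  # ~ 0.025 * x
```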
Scaling Weights and Activations
An intuitive idea to reduce the input activation magnitude is to scale the weights and inversely scale the activations in weight-only-quantization. That is to say,
$$
\begin{align}
\mathbf{y} &\approx Q(\mathbf{w}) \mathbf{x} \\
&\approx Q(\mathbf{w} \cdot \text{diag}(\mathbf{s})) (\text{diag}(\mathbf{s})^{-1} \cdot \mathbf{x})
\end{align}
$$
where $\mathbf{s}$ is a vector of per-channel scaling factors for the group, meaning that the scaling factor can be different for each channel.
In weight-only quantization where the weights are scaled by $\mathbf{s}$, we have
$$
\begin{align}
Q(\mathbf{w} \cdot \text{diag}(\mathbf{s})) &= \Delta^{\prime} \cdot \text{Round}\left(\frac{\mathbf{w} \cdot \text{diag}(\mathbf{s})}{\Delta^{\prime}}\right) \\
\end{align}
$$
where
$$
\begin{align}
\Delta^{\prime} &= \frac{\max( \lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert )}{2^{b-1}-1}
\end{align}
$$
The absolute quantization error of the weight-only-quantized multiplication with the weights scaled by $s$ becomes
$$
\begin{align}
\text{Err}(Q(w s) \frac{x}{s}) = \Delta^{\prime} \cdot \text{RoundErr}\left(\frac{w s}{\Delta^{\prime}}\right) \cdot \frac{x}{s} \\
\end{align}
$$
Assuming $\frac{w s}{\Delta^{\prime}}$ also follows a uniform distribution, the expected value of the absolute error of rounding remains $\mathbb{E}\left[\text{RoundErr}\left(\frac{w s}{\Delta^{\prime}}\right)\right] = 0.25$, which is the same as the one without scaling.
The quantization error reduction factor can then be defined as $\frac{\Delta^{\prime}}{\Delta} \cdot \frac{1}{s}$, because the expected absolute rounding error is the same with or without scaling.
$$
\begin{align}
\frac{\text{Err}(Q(w s) \frac{x}{s}) }{\text{Err}(Q(w) x)} &= \frac{\Delta^{\prime}}{\Delta} \cdot \frac{1}{s}
\end{align}
$$
If $s > 1$ and the quantization scale $\Delta^{\prime}$ remains the same or does not increase too much, so that the inverse scaling of the activation is not completely canceled out, the quantization error of weight-only quantization is reduced.
Suppose $\max(\lvert \mathbf{w} \rvert)$ is the maximum absolute value in the weight matrix $\mathbf{w}$. A single weight $w$ can be scaled by the factor $s = \frac{\max(\lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert)}{\lvert w \rvert}$ to maximize the quantization error reduction, and the error is guaranteed to be no worse than the error without scaling, because
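The following toy numerical check illustrates the effect: the input channel with the dominant activation magnitude has its weight channel scaled up by 2 and its activation scaled down by 2. As long as $\Delta^{\prime}$ grows by less than a factor of 2, that channel’s contribution to the output error shrinks, and because it dominates the total error, the overall error typically drops. The weight statistics, the choice of salient channel, and the scaling factor of 2 are all arbitrary illustration choices.

```python
import torch

torch.manual_seed(0)

def quantize_dequantize(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Symmetric round-to-nearest quantization with one scale for the whole group."""
    q_max = 2 ** (num_bits - 1) - 1
    delta = w.abs().max() / q_max
    return delta * torch.round(w / delta)

group_size = 32
w = torch.randn(64, group_size) * 0.02  # one group of input channels
x = torch.randn(group_size)
x[0] = 20.0                             # one salient channel with a large activation

s = torch.ones(group_size)
s[0] = 2.0                              # scale up only the salient weight channel

y_ref = w @ x
err_plain = (quantize_dequantize(w) @ x - y_ref).abs().mean()
err_awq = (quantize_dequantize(w * s) @ (x / s) - y_ref).abs().mean()
print(err_plain, err_awq)               # err_awq is typically smaller than err_plain
```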
$$
\begin{align}
\frac{\Delta^{\prime}}{\Delta} \cdot \frac{1}{s} &= \frac{\max( \lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert )}{\max( \lvert \mathbf{w} \rvert )} \cdot \frac{1}{s} \\
&= \frac{\max( \lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert )}{\max( \lvert \mathbf{w} \rvert )} \cdot \frac{\lvert w \rvert}{\max(\lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert)} \\
&= \frac{\lvert w \rvert}{\max(\lvert \mathbf{w} \rvert)} \\
&\leq 1 \\
\end{align}
$$
Even if the scaling is naively the same for every channel, i.e., $\mathbf{s} = \{s, s, \ldots, s\}$, the quantization error remains the same as the one without scaling, because
$$
\begin{align}
\frac{\Delta^{\prime}}{\Delta} \cdot \frac{1}{s} &= \frac{\max( \lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert )}{\max( \lvert \mathbf{w} \rvert )} \cdot \frac{1}{s} \\
&= \frac{\max( \lvert \mathbf{w} \rvert ) \cdot s}{\max( \lvert \mathbf{w} \rvert )} \cdot \frac{1}{s} \\
&= s \cdot \frac{1}{s} \\
&= 1 \\
\end{align}
$$
The only scenario in which the quantization error of weight-only quantization can be increased by scaling is when the large weights are scaled by large factors and the small weights are scaled by small factors. In this case, for some channels whose weights, and hence scaling factors, are small, $\frac{\Delta^{\prime}}{\Delta} \geq s$, so that
$$
\begin{align}
\frac{\Delta^{\prime}}{\Delta} \cdot \frac{1}{s} &= \frac{\max( \lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert )}{\max( \lvert \mathbf{w} \rvert )} \cdot \frac{1}{s} \\
&\geq s \cdot \frac{1}{s} \\
&= 1 \\
\end{align}
$$
However, depending on the distribution of the weights across different channels, even if the quantization error of some channels is increased, the quantization error of the entire group can still be reduced.
Optimizing AWQ Scaling Factors
Even though we have shown that scaling the weights and inversely scaling the activations can usually reduce the quantization error of weight-only quantization, we still have not said how to choose those scaling factors. The formula $s = \frac{\max(\lvert \mathbf{w} \cdot \text{diag}(\mathbf{s}) \rvert)}{\lvert w \rvert}$ mentioned previously was only used for illustration and cannot be used for optimization.
The optimization goal is as simple as follows:
$$
\begin{align}
\mathbf{s}^{\ast} &= \arg \min_{\mathbf{s}} \mathcal{L}(\mathbf{s})
\end{align}
$$
where
$$
\begin{align}
\mathcal{L}(\mathbf{s}) &= \mathbb{E}_{\mathbf{x} \sim \mathbf{X}} \left[ \left \lVert Q(\mathbf{w} \cdot \text{diag}(\mathbf{s})) (\text{diag}(\mathbf{s})^{-1} \cdot \mathbf{x}) - \mathbf{w} \mathbf{x} \right \rVert^{2} \right]
\end{align}
$$
However, because of the non-differentiable quantization and dequantization function $Q$, this problem cannot be directly optimized by gradient descent. Approximate gradient-based optimization, such as the straight-through estimator (STE), can be used, but it seems to suffer from unstable convergence according to the findings in the paper.
Therefore, based on our previous analysis, a heuristic is used for the optimization instead: large scaling factors should be applied to the weight channels with large input activations to reduce the quantization error. After that, there is still room for the weight channels with small input activations to reduce the quantization error by scaling as well, just not by as much as the channels with large input activations.
The optimization surrogate becomes as follows:
$$
\begin{align}
\mathbf{s}^{\ast} &= \mathbf{s}_{\mathbf{x}} ^{\alpha^{\ast}}
\end{align}
$$
where
$$
\begin{align}
\alpha^{\ast} &= \arg \min_{\alpha} \mathcal{L}\left( \mathbf{s}_{\mathbf{x}} ^{\alpha} \right)
\end{align}
$$
and $\mathbf{s}_{\mathbf{x}} = \mathbb{E}_{\mathbf{x} \sim \mathbf{X}} \left[ \lvert \mathbf{x} \rvert \right]$ is the per-channel average magnitude of the input activations, usually computed from a calibration dataset that represents the input activations seen during inference.
The optimal exponent $\alpha^{\ast}$ can be found by a grid search over $\alpha \in [0, 1]$.
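A minimal sketch of this grid search for a single linear layer is shown below, assuming a small batch of calibration activations is available; the helper names, the group size, the grid resolution, and the clamping of $\mathbf{s}_{\mathbf{x}}$ are illustrative choices rather than the reference AWQ implementation.

```python
import torch

def quantize_dequantize(w: torch.Tensor, num_bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Group-wise symmetric round-to-nearest quantization along the input channels."""
    out_features, in_features = w.shape
    q_max = 2 ** (num_bits - 1) - 1
    w = w.reshape(out_features, in_features // group_size, group_size)
    delta = w.abs().amax(dim=-1, keepdim=True) / q_max
    w = delta * torch.round(w / delta)
    return w.reshape(out_features, in_features)

@torch.no_grad()
def search_awq_scales(w: torch.Tensor, x_calib: torch.Tensor, num_grid: int = 20) -> torch.Tensor:
    """Grid search over alpha in [0, 1] for the scaling factors s = s_x ** alpha."""
    y_ref = x_calib @ w.t()                  # full-precision reference output
    s_x = x_calib.abs().mean(dim=0)          # per-input-channel average magnitude
    best_loss, best_s = float("inf"), torch.ones_like(s_x)
    for i in range(num_grid + 1):
        alpha = i / num_grid
        s = s_x.clamp(min=1e-4) ** alpha
        y_q = (x_calib / s) @ quantize_dequantize(w * s).t()
        loss = (y_q - y_ref).pow(2).mean().item()
        if loss < best_loss:
            best_loss, best_s = loss, s
    return best_s

w = torch.randn(256, 512)          # weights of a toy linear layer
x_calib = torch.randn(1024, 512)   # calibration activations
s = search_awq_scales(w, x_calib)
```

Once the best $\mathbf{s}$ is found, $Q(\mathbf{w} \cdot \text{diag}(\mathbf{s}))$ can be computed offline, and the inverse scaling of the activations is handled as described in the next section.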
Operator Fusion and CUDA Acceleration
The existing high-performance CUDA kernels developed for group-wise weight-only quantization can be reused for AWQ. Like $Q(\mathbf{w})$, $Q(\mathbf{w} \cdot \text{diag}(\mathbf{s}))$ can be computed offline, and there is still only one set of quantization metadata per group, so the CUDA kernel sees no difference between the two. The only issue is that the input activation $\mathbf{x}$ has to be scaled by $\text{diag}(\mathbf{s})^{-1}$ before the existing kernel can be called. If this activation scaling operation is performed separately, additional computation overhead is introduced, compromising the performance of neural network inference, which is undesirable. Fortunately, this activation scaling operation can usually be fused into the layer right before the layer to be weight-only quantized.
In LLMs, the layer before the layer to be weight-only quantized is usually a layer normalization layer. Due to the elementwise affine operation in layer normalization, the number of weights and biases of the layer norm, i.e., the embedding size, is the same as the number of input channels of the layer to be weight-only quantized, i.e., the size of $\mathbf{s}$. Therefore, the scaling operation can be fused into the layer normalization layer, and the neural network inference performance is not compromised. A reference implementation of the fused layer normalization and activation scaling operation can be found in Torch Init.
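The fusion itself is simple when the preceding layer is a standard `torch.nn.LayerNorm` with elementwise affine parameters: folding $\text{diag}(\mathbf{s})^{-1}$ into the layer normalization weight and bias makes the activation scaling free at inference time. The sketch below is a minimal illustration of this idea, not the reference implementation mentioned above.

```python
import torch

@torch.no_grad()
def fuse_scale_into_layer_norm(ln: torch.nn.LayerNorm, s: torch.Tensor) -> None:
    """Fold the inverse per-channel scaling diag(s)^-1 into the affine parameters
    of the preceding LayerNorm, so that its output is already the scaled
    activation diag(s)^-1 x expected by the AWQ-quantized layer."""
    ln.weight.div_(s)
    if ln.bias is not None:
        ln.bias.div_(s)

embedding_size = 512
ln = torch.nn.LayerNorm(embedding_size)
s = torch.rand(embedding_size) + 0.5  # AWQ scaling factors for the following linear layer
x = torch.randn(4, embedding_size)

y_unfused = ln(x) / s                 # explicit activation scaling (extra overhead)
fuse_scale_into_layer_norm(ln, s)
y_fused = ln(x)                       # scaling is now free
print(torch.allclose(y_unfused, y_fused, atol=1e-6))  # True
```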
FAQ
Performance and Accuracy of AWQ VS Per-Channel Quantization
Assuming group-wise weight-only quantization is coarser than per-channel quantization, group-wise weight-only quantization is usually faster than per-channel quantization because the quantization metadata size is smaller. Its quantization error, however, is usually larger than that of per-channel quantization. AWQ reuses the group-wise weight-only-quantization CUDA kernels and acceleration techniques, so its performance is usually the same as that of ordinary group-wise weight-only quantization, while its quantization error is usually lower than that of ordinary group-wise weight-only quantization but not better than that of per-channel quantization. Therefore, AWQ sits somewhere between ordinary group-wise weight-only quantization and per-channel quantization, and it can usually just replace ordinary group-wise weight-only quantization when performance is more critical than quantization error.
Conclusions
AWQ is a post-training group-wise weight-only-quantization technique that results in lower quantization errors than the vanilla post-training group-wise weight-only-quantization.
References
AWQ: Activation-Aware Weight Quantization
https://leimao.github.io/blog/AWQ-Activation-Aware-Weight-Quantization/