Transformer Vanilla Attention Performance Theoretical Analysis
Introduction
The transformer is a neural network architecture that has been widely used across artificial intelligence tasks, including the most popular large language models. To serve large language models, it is critical to understand the performance bottlenecks of the transformer architecture so that the corresponding optimizations can be made to improve its performance.
In this blog post, because the attention layers are the performance-dominating layers in the transformer, we will theoretically analyze the performance of the vanilla attention layer in the transformer.
Transformer Vanilla Attention Performance Theoretical Analysis
The inputs to the attention layer are the query tensor $Q \in \mathbb{R}^{b \times l_{q} \times d_{qk}}$, key tensor $K \in \mathbb{R}^{b \times l_{kv} \times d_{qk}}$, and value tensor $V \in \mathbb{R}^{b \times l_{kv} \times d_{v}}$, where $b$ is the batch size, $l_{q}$ is the query sequence length, $l_{kv}$ is the key and value sequence length, $d_{qk}$ is the query and key dimension, and $d_{v}$ is the value dimension. They are usually linearly transformed from the input tensor or input tensors of the transformer block for self-attention or cross-attention, respectively.
The attention layer for each single head in the transformer is defined as follows.
$$
Y = \text{softmax} \left( \text{Mask} \left( \frac{ Q K^{\top}}{\sqrt{d_{qk}}} \right) \right) V
$$
The resulting tensor $Y \in \mathbb{R}^{b \times l_{q} \times d_{v}}$ is the output of the masked attention layer. The resulting tensors from multiple attention heads will be concatenated and linearly transformed to the output tensor of the transformer block.
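To make the definition concrete, below is a minimal NumPy sketch of the single-head masked attention above. The function name, the example shapes, and the boolean-mask convention are illustrative assumptions rather than part of any particular library; production implementations typically use fused kernels instead of materializing the full score matrix.

```python
import numpy as np

def masked_attention(Q, K, V, mask=None):
    """Single-head masked attention: softmax(Mask(Q K^T / sqrt(d_qk))) V.

    Q: (b, l_q, d_qk), K: (b, l_kv, d_qk), V: (b, l_kv, d_v).
    Returns Y: (b, l_q, d_v).
    """
    d_qk = Q.shape[-1]
    # Scaled dot-product scores of shape (b, l_q, l_kv).
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_qk)
    if mask is not None:
        # Positions where mask is False receive -inf so that softmax assigns them zero weight.
        scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

# Illustrative shapes: b = 2, l_q = 4, l_kv = 6, d_qk = d_v = 8.
rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 4, 8))
K = rng.standard_normal((2, 6, 8))
V = rng.standard_normal((2, 6, 8))
Y = masked_attention(Q, K, V)
assert Y.shape == (2, 4, 8)
```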
To compute the arithmetic intensity of the masked attention, we need to compute the number of math operations and the number of memory accesses in bytes.
The number of memory accesses in bytes is straightforward to compute. The query tensor $Q$, key tensor $K$, and value tensor $V$ are read once, and the resulting tensor $Y$ is written once. Therefore, the number of memory accesses in bytes asymptotically is $\mathcal{\Theta}(b \cdot l_{q} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{v} + b \cdot l_{q} \cdot d_{v})$.
The number of math operations is also straightforward to compute. The number of math operations involved in the $Q K^{\top}$ matrix multiplication is $b \cdot l_{q} \cdot l_{kv} \cdot d_{qk}$ multiplications and $b \cdot l_{q} \cdot l_{kv} \cdot d_{qk}$ additions. The number of math operations involved in the mask, scale, and softmax operations is negligible compared to the previous matrix multiplication, because those operations only touch each entry of the $b \times l_{q} \times l_{kv}$ score matrix a constant number of times, i.e., $\mathcal{\Theta}(b \cdot l_{q} \cdot l_{kv})$ operations, which is a lower-order term. The number of math operations involved in the subsequent matrix multiplication with $V$ is $b \cdot l_{q} \cdot l_{kv} \cdot d_{v}$ multiplications and $b \cdot l_{q} \cdot l_{kv} \cdot d_{v}$ additions. Therefore, the number of math operations asymptotically is $\mathcal{\Theta}(b \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b \cdot l_{q} \cdot l_{kv} \cdot d_{v})$.
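These counts can be written down as two small helper functions. This is a sketch under the stated assumptions only: just the two matrix multiplications are counted (two FLOPs per multiply-add), $Q$, $K$, and $V$ are each read from memory once, $Y$ is written once, the intermediate score matrix is assumed to stay on chip, and the default element size of 2 bytes corresponds to FP16/BF16 storage.

```python
def attention_flops(b, l_q, l_kv, d_qk, d_v):
    """Math operations of Q K^T and of the product with V (2 FLOPs per multiply-add).

    The mask, scale, and softmax operations are lower-order terms and are ignored here.
    """
    qk_flops = 2 * b * l_q * l_kv * d_qk   # Q K^T
    av_flops = 2 * b * l_q * l_kv * d_v    # attention weights times V
    return qk_flops + av_flops


def attention_bytes(b, l_q, l_kv, d_qk, d_v, element_size=2):
    """Bytes moved, assuming Q, K, V are each read once and Y is written once.

    element_size=2 corresponds to FP16/BF16 storage (an assumption).
    """
    q_bytes = b * l_q * d_qk * element_size
    k_bytes = b * l_kv * d_qk * element_size
    v_bytes = b * l_kv * d_v * element_size
    y_bytes = b * l_q * d_v * element_size
    return q_bytes + k_bytes + v_bytes + y_bytes
```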
Thus the arithmetic intensity of the single-head masked attention can be computed as
$$
\begin{align}
\frac{\mathcal{\Theta}(b \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b \cdot l_{q} \cdot l_{kv} \cdot d_{v})}{\mathcal{\Theta}(b \cdot l_{q} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{v} + b \cdot l_{q} \cdot d_{v})}
&= \mathcal{\Theta}\left( \frac{ b \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b \cdot l_{q} \cdot l_{kv} \cdot d_{v} }{b \cdot l_{q} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{v} + b \cdot l_{q} \cdot d_{v}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + l_{kv} \cdot d_{qk} + l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right) \\
\end{align}
$$
Usually $d_{qk}$ and $d_{v}$ are of the same value or at least of the same order of magnitude. Assuming $d_{qk} = d_{v} = d_{m}$, we have the following simplification.
$$
\begin{align}
\mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + l_{kv} \cdot d_{qk} + l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right)
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} + l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + l_{kv} \cdot d_{m} + l_{kv} \cdot d_{m} + l_{q} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 2 \cdot l_{q} \cdot l_{kv} \cdot d_{m} }{2 \cdot l_{q} \cdot d_{m} + 2 \cdot l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} }{l_{q} + l_{kv}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{l_{q}} + \frac{1}{l_{kv}}} \right) \\
\end{align}
$$
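As a sanity check on the simplification, the sketch below computes the exact FLOPs-to-bytes ratio under the same counting assumptions with $d_{qk} = d_{v} = d_{m}$ and compares it against the simplified form $\frac{1}{\frac{1}{l_{q}} + \frac{1}{l_{kv}}}$. With 2-byte elements the constant factors happen to cancel, so the two coincide exactly; in general they agree up to a constant factor.

```python
def attention_arithmetic_intensity(b, l_q, l_kv, d_m, element_size=2):
    """Exact FLOPs-to-bytes ratio with d_qk = d_v = d_m, under the counting assumptions above."""
    flops = 4 * b * l_q * l_kv * d_m                         # two matmuls, 2 FLOPs per multiply-add
    bytes_moved = 2 * b * d_m * (l_q + l_kv) * element_size  # read Q, K, V once, write Y once
    return flops / bytes_moved


def simplified_intensity(l_q, l_kv):
    """The asymptotic form Theta(1 / (1 / l_q + 1 / l_kv))."""
    return 1.0 / (1.0 / l_q + 1.0 / l_kv)


# With 2-byte elements, the exact ratio and the simplified form coincide.
for l_q, l_kv in [(4096, 4096), (1, 4096)]:
    exact = attention_arithmetic_intensity(b=1, l_q=l_q, l_kv=l_kv, d_m=128)
    approx = simplified_intensity(l_q, l_kv)
    print(f"l_q={l_q}, l_kv={l_kv}: exact={exact:.2f}, simplified={approx:.2f}")
```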
When there are multiple heads used in the attention layer, i.e., multi-head attention, the arithmetic intensity of the masked attention remains the same asymptotically, because the number of heads $h$ is a constant and it can always be folded into the batch size $b$ so that the multi-head attention can be treated as a single-head attention, where $Q \in \mathbb{R}^{bh \times l_{q} \times d_{qk}}$, $K \in \mathbb{R}^{bh \times l_{kv} \times d_{qk}}$, and $V \in \mathbb{R}^{bh \times l_{kv} \times d_{v}}$.
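The folding of the head dimension into the batch dimension can be expressed as a plain reshape. The sketch below assumes a $(b, h, l, d)$ layout for the per-head tensors, which is an assumption made for illustration, and reuses the `masked_attention` sketch from earlier.

```python
import numpy as np

# Illustrative multi-head shapes: b = 2, h = 8, l_q = 4, l_kv = 6, d_qk = d_v = 16.
b, h, l_q, l_kv, d_qk, d_v = 2, 8, 4, 6, 16, 16
rng = np.random.default_rng(0)
Q = rng.standard_normal((b, h, l_q, d_qk))
K = rng.standard_normal((b, h, l_kv, d_qk))
V = rng.standard_normal((b, h, l_kv, d_v))

# Fold the head dimension into the batch dimension, so that the multi-head
# attention becomes a single-head attention with batch size b * h.
Q_folded = Q.reshape(b * h, l_q, d_qk)
K_folded = K.reshape(b * h, l_kv, d_qk)
V_folded = V.reshape(b * h, l_kv, d_v)

Y_folded = masked_attention(Q_folded, K_folded, V_folded)  # from the earlier sketch
Y = Y_folded.reshape(b, h, l_q, d_v)
assert Y.shape == (b, h, l_q, d_v)
```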
For large language models, we can draw a few conclusions from this performance analysis.
- During the prefill stage, because $l_{q}$ and $l_{kv}$ are usually large, the arithmetic intensity of the attention is very high, which means the attention is math-bound.
- During the decode stage, which employs the KV cache optimization, because $l_{q} = 1$ and $l_{kv}$ is usually large, the arithmetic intensity of the attention is $\mathcal{\Theta}(1)$, which is very low. This means the attention is memory-bound.
- Unlike other common operations in neural networks, such as convolution, having a larger batch size does not affect the arithmetic intensity of the attention, because the batch size $b$ cancels out of the ratio. This means autoregressive decoding with the KV cache optimization will always remain memory-bound no matter how many user requests are processed in parallel, as the numeric sketch after this list illustrates.
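A small numeric sketch, reusing the `attention_arithmetic_intensity` helper above with illustrative shapes ($d_{m} = 128$, 2-byte elements, sequence length 4096, all of which are assumptions), makes these conclusions concrete.

```python
# Prefill-like shapes (l_q = l_kv = 4096) versus decode-like shapes (l_q = 1, l_kv = 4096),
# across several batch sizes.
for b in (1, 8, 64):
    prefill = attention_arithmetic_intensity(b=b, l_q=4096, l_kv=4096, d_m=128)
    decode = attention_arithmetic_intensity(b=b, l_q=1, l_kv=4096, d_m=128)
    print(f"batch={b}: prefill intensity={prefill:.1f}, decode intensity={decode:.2f}")
# The intensities are independent of the batch size: prefill stays at 2048.0 FLOPs/byte
# (well above the ridge point of typical accelerators, hence math-bound), while decode
# stays at about 1.0 FLOP/byte (memory-bound).
```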
Conclusions
Because the attention layers are the performance-dominating layers in the transformer, to optimize the performance of the transformer, we need to optimize the attention layers. During the decode stage, because the attention layers are memory-bound, we need to reduce the memory accesses in the attention layers to improve the performance of the transformer. Because of this, efforts have been made to optimize the memory accesses of the attention layers specifically for serving large language models, such as KV cache quantization and pruning, which reduce the size of the key and value tensors, and grouped-query attention, which reuses the key and value tensors for multiple query heads.