Grouped Query Attention Performance Theoretical Analysis

Introduction

The performance of transformer models is bottlenecked by the attention layers. Based on our previous theoretical analysis, the vanilla attention layer is memory-bound during decoding. Therefore, it makes sense to propose alternative attention mechanisms that reduce the memory IO pressure of the attention layer. The grouped query attention is one of the mechanisms proposed for this purpose.

In this blog post, we will quickly analyze the performance of the grouped query attention theoretically.

Grouped Query Attention

The arithmetic intensity of the grouped query attention for a single head is the same as that of the vanilla attention. In our previous theoretical analysis of the vanilla attention, for a single head, the number of memory accesses in bytes is asymptotically $\mathcal{\Theta}(b \cdot l_{q} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{v} + b \cdot l_{q} \cdot d_{v})$ and the number of math operations is asymptotically $\mathcal{\Theta}(b \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b \cdot l_{q} \cdot l_{kv} \cdot d_{v})$.

In grouped query attention, the queries from a group of heads share the same key and value tensors. Suppose the query group size is $g$ and the number of query heads is $h$; then there are $\frac{h}{g}$ groups in grouped query attention. There are $h$ query tensors, $\frac{h}{g}$ key tensors, and $\frac{h}{g}$ value tensors to access per grouped query attention layer. Equivalently, we often say there are $h$ query heads, $\frac{h}{g}$ key heads, and $\frac{h}{g}$ value heads. The number of memory accesses in bytes asymptotically becomes $\mathcal{\Theta}(b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v})$. The number of math operations asymptotically becomes $\mathcal{\Theta}(b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v})$.
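To make these asymptotic counts concrete, below is a minimal Python sketch of the per-layer memory IO and math operation counts. The function name, the choice of 2 bytes per element (FP16/BF16), and the example dimensions are illustrative assumptions rather than part of the original analysis; setting $g = 1$ recovers the vanilla multi-head attention counts.

```python
# Minimal sketch: estimate memory IO and math operations of one grouped query
# attention layer, following the asymptotic counts in the text.
# Assumptions (illustrative): 2 bytes per element (FP16/BF16), and only the
# dominant Q/K/V/output tensor traffic is counted; constant factors in the
# math operation count are dropped, as in the asymptotic analysis.

def gqa_io_and_flops(b, h, g, l_q, l_kv, d_qk, d_v, bytes_per_element=2):
    """Return (memory_bytes, math_ops) for one grouped query attention layer."""
    num_kv_heads = h // g  # h query heads, h / g key heads, h / g value heads
    memory_bytes = bytes_per_element * (
        b * h * l_q * d_qk                # query tensors
        + b * num_kv_heads * l_kv * d_qk  # key tensors
        + b * num_kv_heads * l_kv * d_v   # value tensors
        + b * h * l_q * d_v               # output tensors
    )
    # Q K^T and attention-weighted V products are still computed for all h
    # query heads, so the math operation count is unchanged from vanilla MHA.
    math_ops = b * h * l_q * l_kv * d_qk + b * h * l_q * l_kv * d_v
    return memory_bytes, math_ops


if __name__ == "__main__":
    # g = 1 recovers the vanilla multi-head attention counts.
    vanilla = gqa_io_and_flops(b=1, h=32, g=1, l_q=1, l_kv=4096, d_qk=128, d_v=128)
    gqa = gqa_io_and_flops(b=1, h=32, g=8, l_q=1, l_kv=4096, d_qk=128, d_v=128)
    print("vanilla (bytes, ops):", vanilla)
    print("gqa     (bytes, ops):", gqa)
```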

The arithmetic intensity of the grouped query attention can be computed as

$$
\begin{align}
\frac{\mathcal{\Theta}(b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v})}{\mathcal{\Theta}(b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v})}
&= \mathcal{\Theta}\left( \frac{ b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v} }{b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right) \\
\end{align}
$$

Assuming $d_{qk} = d_{v} = d_{m}$, we have the following simplification.

$$
\begin{align}
\mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right)
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} + l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m} + l_{q} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 2 \cdot l_{q} \cdot l_{kv} \cdot d_{m} }{2 \cdot l_{q} \cdot d_{m} + 2 \cdot \frac{1}{g} \cdot l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} }{l_{q} + \frac{1}{g} \cdot l_{kv}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}} \right) \\
\end{align}
$$

In our previous theoretical analysis, we have shown that the arithmetic intensity of the vanilla attention is $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{l_{q}} + \frac{1}{l_{kv}}}\right)$. The arithmetic intensity of the grouped query attention is $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}}\right)$, and $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}}\right) \geq \mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{l_{q}} + \frac{1}{l_{kv}}}\right)$. Therefore, the grouped query attention always has an arithmetic intensity at least as high as that of the vanilla attention, and strictly higher when $g > 1$.

For large language models, during the decode stage where $l_{q} = 1$ and $l_{kv} \gg g \cdot l_{q}$, the arithmetic intensity of the grouped query attention becomes $\mathcal{\Theta}\left( g \right)$. Compared to the vanilla attention, whose arithmetic intensity is $\mathcal{\Theta}\left( 1 \right)$ during decoding, the grouped query attention has a roughly $g$ times higher arithmetic intensity, which is especially significant when $g \gg 1$.
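As a quick numerical illustration of the decode-stage behavior, the following sketch evaluates the simplified arithmetic intensity $\frac{1}{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}}$ for $l_{q} = 1$; the group sizes and the context length $l_{kv} = 4096$ are illustrative assumptions.

```python
# Minimal sketch: evaluate the simplified arithmetic intensity
# 1 / (1 / (g * l_q) + 1 / l_kv) during decoding (l_q = 1).
# The parameter values below are illustrative assumptions.

def arithmetic_intensity(g, l_q, l_kv):
    return 1.0 / (1.0 / (g * l_q) + 1.0 / l_kv)

for g in (1, 4, 8):
    intensity = arithmetic_intensity(g=g, l_q=1, l_kv=4096)
    # For l_kv >> g * l_q, the intensity is approximately g.
    print(f"g = {g}: arithmetic intensity ~ {intensity:.3f}")
```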

FAQs

How Was the Number of Key and Value Heads Reduced?

When a sequence of embeddings $X \in \mathbb{R}^{b \times l \times d_{m}}$ is linearly transformed into query, key and value tensors, we have weight tensors $W_{q} \in \mathbb{R}^{d_{m} \times d_{qk}}$ for query, $W_{k} \in \mathbb{R}^{d_{m} \times d_{qk}}$ for key and $W_{v} \in \mathbb{R}^{d_{m} \times d_{v}}$ for value. The query, key and value tensors are computed as $Q = X W_{q}$, $K = X W_{k}$ and $V = X W_{v}$, respectively. For multi-head attention with $h$ heads, there are $h$ sets of such weight tensors for the linear transformation. The weight tensors can be concatenated into mega tensors $W_{q}^{\prime} \in \mathbb{R}^{d_{m} \times h \cdot d_{qk}}$ for query, $W_{k}^{\prime} \in \mathbb{R}^{d_{m} \times h \cdot d_{qk}}$ for key and $W_{v}^{\prime} \in \mathbb{R}^{d_{m} \times h \cdot d_{v}}$ for value. The mega query, key and value tensors are computed as $Q^{\prime} = X W_{q}^{\prime}$, $K^{\prime} = X W_{k}^{\prime}$ and $V^{\prime} = X W_{v}^{\prime}$, respectively. The resulting mega query, key, and value tensors are of shape $Q^{\prime} \in \mathbb{R}^{b \times l \times h \cdot d_{qk}}$, $K^{\prime} \in \mathbb{R}^{b \times l \times h \cdot d_{qk}}$ and $V^{\prime} \in \mathbb{R}^{b \times l \times h \cdot d_{v}}$, respectively. The mega query, key and value tensors can then be split into individual query, key and value tensors that have shape $Q \in \mathbb{R}^{b \times l \times d_{qk}}$, $K \in \mathbb{R}^{b \times l \times d_{qk}}$ and $V \in \mathbb{R}^{b \times l \times d_{v}}$, respectively, for each head.
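The following PyTorch-style sketch illustrates the fused projection into mega query, key and value tensors and the subsequent split into per-head tensors; the tensor names and dimensions are illustrative assumptions.

```python
import torch

# Minimal sketch of the fused multi-head attention projections described above.
# Dimensions are illustrative assumptions.
b, l, d_m, h, d_qk, d_v = 2, 16, 512, 8, 64, 64

X = torch.randn(b, l, d_m)
W_q = torch.randn(d_m, h * d_qk)  # mega query weight tensor
W_k = torch.randn(d_m, h * d_qk)  # mega key weight tensor
W_v = torch.randn(d_m, h * d_v)   # mega value weight tensor

Q = X @ W_q  # (b, l, h * d_qk)
K = X @ W_k  # (b, l, h * d_qk)
V = X @ W_v  # (b, l, h * d_v)

# Split the mega tensors into per-head tensors of shape (b, h, l, d).
Q = Q.view(b, l, h, d_qk).transpose(1, 2)
K = K.view(b, l, h, d_qk).transpose(1, 2)
V = V.view(b, l, h, d_v).transpose(1, 2)
print(Q.shape, K.shape, V.shape)  # (b, h, l, d_qk), (b, h, l, d_qk), (b, h, l, d_v)
```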

In grouped query attention, because we have $h$ query heads and $\frac{h}{g}$ key and value heads, the mega weight tensors become $W_{q}^{\prime} \in \mathbb{R}^{d_{m} \times h \cdot d_{qk}}$ for query, $W_{k}^{\prime} \in \mathbb{R}^{d_{m} \times \frac{h}{g} \cdot d_{qk}}$ for key and $W_{v}^{\prime} \in \mathbb{R}^{d_{m} \times \frac{h}{g} \cdot d_{v}}$ for value. The mega query, key and value tensors are computed as $Q^{\prime} = X W_{q}^{\prime}$, $K^{\prime} = X W_{k}^{\prime}$ and $V^{\prime} = X W_{v}^{\prime}$, respectively. The resulting mega query, key, and value tensors are of shape $Q^{\prime} \in \mathbb{R}^{b \times l \times h \cdot d_{qk}}$, $K^{\prime} \in \mathbb{R}^{b \times l \times \frac{h}{g} \cdot d_{qk}}$ and $V^{\prime} \in \mathbb{R}^{b \times l \times \frac{h}{g} \cdot d_{v}}$, respectively. The mega query, key and value tensors can then be split into individual query, key and value tensors that have shape $Q \in \mathbb{R}^{b \times l \times d_{qk}}$, $K \in \mathbb{R}^{b \times l \times d_{qk}}$ and $V \in \mathbb{R}^{b \times l \times d_{v}}$, respectively, for each query head, key head and value head. This not only reduces the size of the key and value tensors for caching and memory IO but also reduces the computation of the linear transformation for the key and value tensors.
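A corresponding sketch for grouped query attention is shown below, again with illustrative dimensions and group size. The key and value projections now produce only $\frac{h}{g}$ heads, and each key or value head is shared by $g$ query heads, for example by repeating it along the head dimension before the usual per-head attention computation.

```python
import torch

# Minimal sketch of the grouped query attention projections described above.
# Dimensions and group size are illustrative assumptions.
b, l, d_m, h, g, d_qk, d_v = 2, 16, 512, 8, 4, 64, 64
h_kv = h // g  # number of key/value heads

X = torch.randn(b, l, d_m)
W_q = torch.randn(d_m, h * d_qk)     # mega query weight tensor
W_k = torch.randn(d_m, h_kv * d_qk)  # smaller mega key weight tensor
W_v = torch.randn(d_m, h_kv * d_v)   # smaller mega value weight tensor

Q = (X @ W_q).view(b, l, h, d_qk).transpose(1, 2)     # (b, h, l, d_qk)
K = (X @ W_k).view(b, l, h_kv, d_qk).transpose(1, 2)  # (b, h / g, l, d_qk)
V = (X @ W_v).view(b, l, h_kv, d_v).transpose(1, 2)   # (b, h / g, l, d_v)

# Share each key/value head across g query heads by repeating along the
# head dimension, so that the usual per-head attention can be applied.
K_shared = K.repeat_interleave(g, dim=1)  # (b, h, l, d_qk)
V_shared = V.repeat_interleave(g, dim=1)  # (b, h, l, d_v)
print(Q.shape, K_shared.shape, V_shared.shape)
```

Note that only the $\frac{h}{g}$ key and value heads need to be stored in the KV cache; the repetition, or an equivalent broadcast inside a fused attention kernel, happens at compute time, which is what reduces the KV cache size and the memory IO.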

Conclusions

The grouped query attention is a mechanism proposed to reduce the memory IO pressure of the attention layer. We could increase the group size to a value at which the model accuracy is not significantly affected and the arithmetic intensity of the attention layer becomes better balanced. The multi-query attention is a special case of the grouped query attention where $g = h$, and the vanilla attention is a special case of the grouped query attention where $g = 1$.
