Grouped Query Attention Performance Theoretical Analysis

Introduction

The performance of transformer models is often bottlenecked by the attention layers. Based on our previous theoretical analysis, the vanilla attention layer is memory-bound during decoding. It therefore makes sense to propose alternative attention mechanisms that reduce the memory IO pressure of the attention layer. Grouped query attention is one such mechanism.

In this blog post, we will briefly analyze the performance of grouped query attention from a theoretical perspective.

Grouped Query Attention

The arithmetic intensity of the grouped query attention for a single head is the same as that of the vanilla attention. In our previous theoretical analysis of the vanilla attention, the number of memory accesses in bytes is asymptotically $\mathcal{\Theta}(b \cdot l_{q} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{qk} + b \cdot l_{kv} \cdot d_{v} + b \cdot l_{q} \cdot d_{v})$ and the number of math operations is asymptotically $\mathcal{\Theta}(b \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b \cdot l_{q} \cdot l_{kv} \cdot d_{v})$.
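
To make the counting concrete, the two asymptotic expressions can be written down as a minimal Python sketch. The function and variable names below are illustrative rather than from any library, and constant factors such as the number of bytes per element are dropped, since they are absorbed into the $\mathcal{\Theta}$ notation.

```python
# A minimal sketch of the asymptotic counts for a single vanilla attention
# head. Constant factors (e.g., bytes per element) are dropped, so these are
# element counts up to a constant factor, not exact byte counts.
def vanilla_attention_memory_accesses(b, l_q, l_kv, d_qk, d_v):
    # Read Q, read K, read V, and write the output O.
    return b * l_q * d_qk + b * l_kv * d_qk + b * l_kv * d_v + b * l_q * d_v


def vanilla_attention_math_ops(b, l_q, l_kv, d_qk, d_v):
    # Q @ K^T contributes l_q * l_kv * d_qk and P @ V contributes
    # l_q * l_kv * d_v multiply-accumulate operations per batch.
    return b * l_q * l_kv * d_qk + b * l_q * l_kv * d_v
```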

In grouped query attention, the queries from a group of heads share the same key and value tensors. Suppose the query group size is $g$ and the number of heads is $h$; then there are $\frac{h}{g}$ groups in grouped query attention. There are $h$ query tensors, $\frac{h}{g}$ key tensors, and $\frac{h}{g}$ value tensors to access per grouped query attention layer. The number of memory accesses in bytes asymptotically becomes $\mathcal{\Theta}(b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v})$. The number of math operations asymptotically becomes $\mathcal{\Theta}(b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v})$.
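
The same counting can be sketched for the grouped query attention over all $h$ heads. Again the names are illustrative, and setting $g = 1$ recovers $h$ independent vanilla attention heads.

```python
# A sketch of the asymptotic counts for grouped query attention with h heads
# and query group size g, assuming g divides h.
def gqa_memory_accesses(b, h, g, l_q, l_kv, d_qk, d_v):
    assert h % g == 0
    # h query and output tensors, but only h / g key and value tensors.
    return (b * h * l_q * d_qk + b * (h // g) * l_kv * d_qk
            + b * (h // g) * l_kv * d_v + b * h * l_q * d_v)


def gqa_math_ops(b, h, g, l_q, l_kv, d_qk, d_v):
    # Every head still performs the full Q @ K^T and P @ V math.
    return b * h * l_q * l_kv * d_qk + b * h * l_q * l_kv * d_v
```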

The arithmetic intensity of the grouped query attention can be computed as

$$
\begin{align}
\frac{\mathcal{\Theta}(b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v})}{\mathcal{\Theta}(b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v})}
&= \mathcal{\Theta}\left( \frac{ b h \cdot l_{q} \cdot l_{kv} \cdot d_{qk} + b h \cdot l_{q} \cdot l_{kv} \cdot d_{v} }{b h \cdot l_{q} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{qk} + b \frac{h}{g} \cdot l_{kv} \cdot d_{v} + b h \cdot l_{q} \cdot d_{v}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right) \\
\end{align}
$$
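
Using the two sketches above, this ratio can be evaluated directly for concrete shapes. The numbers below are arbitrary illustrative values, not measurements.

```python
b, h, g, l_q, l_kv, d_qk, d_v = 4, 32, 8, 1024, 1024, 128, 128
intensity = (gqa_math_ops(b, h, g, l_q, l_kv, d_qk, d_v)
             / gqa_memory_accesses(b, h, g, l_q, l_kv, d_qk, d_v))
print(f"arithmetic intensity: {intensity:.1f} math ops per element accessed")
```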

Assuming $d_{qk} = d_{v} = d_{m}$, we have the following simplification.

$$
\begin{align}
\mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{qk} + l_{q} \cdot l_{kv} \cdot d_{v} }{l_{q} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{qk} + \frac{1}{g} \cdot l_{kv} \cdot d_{v} + l_{q} \cdot d_{v}} \right)
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} + l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m} + l_{q} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 2 \cdot l_{q} \cdot l_{kv} \cdot d_{m} }{2 \cdot l_{q} \cdot d_{m} + 2 \cdot \frac{1}{g} \cdot l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} \cdot d_{m} }{l_{q} \cdot d_{m} + \frac{1}{g} \cdot l_{kv} \cdot d_{m}} \right) \\
&= \mathcal{\Theta}\left( \frac{ l_{q} \cdot l_{kv} }{l_{q} + \frac{1}{g} \cdot l_{kv}} \right) \\
&= \mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}} \right) \\
\end{align}
$$
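
Because the simplification only divides out common factors, the final closed form matches the full ratio exactly when $d_{qk} = d_{v}$, provided the asymptotic counts are treated as exact. A quick numerical check against the sketch functions above, with illustrative shapes, confirms this.

```python
import math

d_m = 128
for g, l_q, l_kv in [(1, 1, 4096), (4, 1, 4096), (8, 512, 2048)]:
    full = (gqa_math_ops(1, 32, g, l_q, l_kv, d_m, d_m)
            / gqa_memory_accesses(1, 32, g, l_q, l_kv, d_m, d_m))
    closed_form = 1.0 / (1.0 / (g * l_q) + 1.0 / l_kv)
    assert math.isclose(full, closed_form)
```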

In our previous theoretical analysis, we showed that the arithmetic intensity of the vanilla attention is $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{l_{q}} + \frac{1}{l_{kv}}}\right)$. The arithmetic intensity of the grouped query attention is $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}}\right)$, and $\mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{g \cdot l_{q}} + \frac{1}{l_{kv}}}\right) \geq \mathcal{\Theta}\left( \frac{ 1 }{\frac{1}{l_{q}} + \frac{1}{l_{kv}}}\right)$ for $g \geq 1$. Therefore, the grouped query attention always has a higher arithmetic intensity than the vanilla attention when $g > 1$.
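
A small sweep over the group size makes this inequality concrete: $g = 1$ recovers the vanilla arithmetic intensity, and larger group sizes only increase it. The shapes are illustrative.

```python
l_q, l_kv = 512, 4096
for g in [1, 2, 4, 8, 16]:
    intensity = 1.0 / (1.0 / (g * l_q) + 1.0 / l_kv)
    print(f"g = {g:2d}: arithmetic intensity = {intensity:.1f}")
```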

For large language models, in the decode stage where $l_{q} = 1$ and $l_{kv} \gg g \cdot l_{q}$, the arithmetic intensity of the grouped query attention becomes $\mathcal{\Theta}\left( g \right)$. Compared to the vanilla attention, whose arithmetic intensity is $\mathcal{\Theta}\left( 1 \right)$ during decoding, the grouped query attention has a higher arithmetic intensity, especially when $g \gg 1$.
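
Plugging $l_{q} = 1$ into the closed form shows the arithmetic intensity saturating at $g$ as the KV cache length grows, which is a quick way to see the $\mathcal{\Theta}\left( g \right)$ behavior. The values are illustrative.

```python
g = 8
for l_kv in [64, 1024, 16384, 262144]:
    intensity = 1.0 / (1.0 / g + 1.0 / l_kv)  # l_q = 1 during decoding
    print(f"l_kv = {l_kv:6d}: arithmetic intensity = {intensity:.3f}")
# The intensity approaches g = 8 as l_kv grows.
```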

Conclusions

Grouped query attention is a mechanism proposed to reduce the memory IO pressure of the attention layer. The group size can be increased up to the point where the arithmetic intensity of the attention layer is sufficiently improved while the model accuracy is not significantly affected. Multi-query attention is a special case of grouped query attention where $g = h$.
