Transformer Autoregressive Inference Optimization
Introduction
Transformer-based neural networks have been extremely successful in a wide range of artificial intelligence tasks since the architecture was introduced in 2017. Transformer architectures usually fall into three groups: encoder-decoder, encoder-only, and decoder-only. For example, the model developed in the original transformer paper “Attention Is All You Need” for language translation uses an encoder-decoder architecture; BERT, the once very popular model capable of working on a variety of natural language tasks, uses an encoder-only architecture; and the famous OpenAI GPT series models, such as ChatGPT, which has been shown to exhibit artificial general intelligence to a good extent, use a decoder-only architecture.
Transformer-based neural networks, especially large language models, usually have billions of parameters. Therefore, even running inference from a pretrained transformer model at a massive scale is a challenge. In this article, I would like to discuss, from a mathematical perspective, how to optimize transformer autoregressive inference, specifically for encoder-decoder and decoder-only transformers.
Transformer Autoregressive Inference Computational Complexity
Prior to the birth of the transformer, there had been other types of neural networks that run inference in an autoregressive fashion. One of the most famous was the recurrent neural network. Recurrent neural networks can “memorize” the prior context during autoregressive inference, to some extent, using a latent state feature vector. Their autoregressive inference is not complicated: at each time step, the recurrent neural network consumes only one token, and the inference for that step is an $O(1)$ operation. Therefore, running autoregressive inference $n$ times with a recurrent neural network has an asymptotic computational complexity of $O(n)$.
Transformer autoregressive inference, however, can be asymptotically much slower if it is not well optimized. Transformer models do not “memorize” the prior context like recurrent neural networks do. Instead, they look directly at the prior context and pay different amounts of attention to the prior tokens. In transformer decoder autoregressive inference using the “vanilla” recipe, at the time step $n$ the decoder has to consume $n$ tokens, and the inference for this time step is an $O(n^2)$ operation due to the self-attention mechanism. Therefore, running autoregressive inference $n$ times with a transformer decoder has an asymptotic computational complexity of $O(n^3)$, which is usually something we cannot afford.
However, a close examination of the transformer decoder mathematics reveals that it is possible to reduce the asymptotic complexity of the inference at the time step $n$ from $O(n^2)$ to $O(n)$. Therefore, the asymptotic computational complexity of running autoregressive inference $n$ times with a transformer decoder can be reduced from the vanilla $O(n^3)$ to $O(n^2)$.
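To make the totals explicit, summing the per-step costs over the $n$ decoding steps gives
$$
\sum_{t=1}^{n} O(t^{2}) = O(n^{3}), \qquad \sum_{t=1}^{n} O(t) = O(n^{2})
$$
for the vanilla and the optimized recipes, respectively.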
Transformer Autoregressive Inference Optimization
We will examine and optimize transformer autoregressive decoding for each type of layer in the transformer architecture: the masked self-attention head, the cross-attention head used in the encoder-decoder architecture, and the multi-head attention layer. It turns out that, with appropriate caching, the inference for each type of layer is incremental, i.e., adding a new input token to a layer does not affect the outputs for the previous input tokens. Therefore, we need to feed only one token instead of $n$ tokens at the time step $n$ and compute the attention only with respect to that one token, which significantly reduces the autoregressive inference cost.
Masked Self-Attention Head
In the autoregressive decoding, at the time step $n$, the number of input tokens to the decoder is $n$. The input to the masked self-attention head is an attention tensor $X_{n} \in \mathbb{R}^{n \times d_{\text{model}}}$, where $n$ is the number of input tokens to the decoder and $d_{\text{model}}$ is the number of attention features per token. The parameters of the masked self-attention head are $W^{Q} \in \mathbb{R}^{d_{\text{model}} \times d_{k}}$ for the query feature transformation, $W^{K} \in \mathbb{R}^{d_{\text{model}} \times d_{k}}$ for the key feature transformation, and $W^{V} \in \mathbb{R}^{d_{\text{model}} \times d_{v}}$ for the value feature transformation, where $d_{k}$ is the number of query or key features per token in the attention head and $d_{v}$ is the number of value features per token in the attention head. The output of the masked self-attention head is an attention tensor $Y_{n} \in \mathbb{R}^{n \times d_{v}}$.
The query tensor $Q_{n} \in \mathbb{R}^{n \times d_{k}}$, the key tensor $K_{n} \in \mathbb{R}^{n \times d_{k}}$, and the value tensor $V_{n} \in \mathbb{R}^{n \times d_{v}}$ can be computed using linear transformations. Concretely,
$$
Q_{n} = X_{n} W^{Q}
$$
$$
K_{n} = X_{n} W^{K}
$$
$$
V_{n} = X_{n} W^{V}
$$
The output attention tensor $Y_{n}$ from the masked self-attention head is computed as follows.
$$
Y_{n} = \text{softmax} \left( \text{Mask} \left( \frac{ Q_{n} K_{n}^{\top}}{\sqrt{d_k}} \right) \right) V_{n}
$$
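As a concrete reference, the following is a minimal NumPy sketch of this masked self-attention head for a single head, ignoring batching; the function names and shapes are illustrative only, not any particular library's API.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def masked_self_attention(X, W_Q, W_K, W_V):
    # X: (n, d_model); W_Q, W_K: (d_model, d_k); W_V: (d_model, d_v).
    n, d_k = X.shape[0], W_Q.shape[1]
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(d_k)  # (n, n)
    # Causal mask: token i may only attend to tokens 0, ..., i.
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ V  # Y_n: (n, d_v)
```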
In the next time step $n + 1$, we have a new attention token $x_{n+1} \in \mathbb{R}^{1 \times d_{\text{model}}}$, usually derived from the token predicted at the time step $n$, going into the transformer decoder. The input attention tensor $X_{n+1} \in \mathbb{R}^{(n+1) \times d_{\text{model}}}$ is the concatenation of $X_{n}$ and $x_{n+1}$.
$$
\begin{align}
X_{n+1}
&=
\left [
\begin{array}{c|c}
X_{n} \\
x_{n+1} \\
\end{array}
\right ]
\end{align}
$$
The new query tensor $Q_{n+1} \in \mathbb{R}^{(n+1) \times d_{k}}$ at the time step $n+1$ could be computed using the same feature transformation matrices.
$$
\begin{align}
Q_{n+1}
&= X_{n+1} W^{Q} \\
&=
\left [
\begin{array}{c|c}
X_{n} \\
x_{n+1} \\
\end{array}
\right ] W^{Q} \\
&=
\left [
\begin{array}{c|c}
X_{n} W^{Q} \\
x_{n+1} W^{Q} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Q_{n} \\
q_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the query tensor $q_{n+1} \in \mathbb{R}^{1 \times d_{k}}$ for the new attention token $x_{n+1}$ can be computed using
$$
q_{n+1} = x_{n+1} W^{Q}
$$
Here, computing the new query tensor $q_{n+1}$ for the new attention token $x_{n+1}$ is an $O(1)$ operation, as $d_{\text{model}}$ and $d_{k}$ are constants.
This means that, to compute the query tensor $Q_{n+1}$ at the time step $n+1$, we do not actually need to use the entire input $X_{n+1}$. We just need to compute the query tensor for the new attention token and concatenate it with the query tensor $Q_{n}$ from the previous time step $n$.
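A quick NumPy check of this block decomposition, with hypothetical sizes chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 16, 8  # hypothetical sizes
X_n = rng.standard_normal((n, d_model))
x_new = rng.standard_normal((1, d_model))  # x_{n+1}
W_Q = rng.standard_normal((d_model, d_k))

Q_full = np.vstack([X_n, x_new]) @ W_Q  # recompute Q_{n+1} from scratch
Q_incremental = np.vstack([X_n @ W_Q, x_new @ W_Q])  # reuse Q_n, add one new row
assert np.allclose(Q_full, Q_incremental)
```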
Similarly, the new key tensor $K_{n+1} \in \mathbb{R}^{(n+1) \times d_{k}}$ at the time step $n+1$ can be computed using the concatenation of the key tensor for the new attention token $k_{n+1} \in \mathbb{R}^{1 \times d_{k}}$ with the key tensor $K_{n}$ from the previous time step $n$.
$$
\begin{align}
K_{n+1}
&= X_{n+1} W^{K} \\
&=
\left [
\begin{array}{c|c}
X_{n} \\
x_{n+1} \\
\end{array}
\right ] W^{K} \\
&=
\left [
\begin{array}{c|c}
X_{n} W^{K} \\
x_{n+1} W^{K} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
K_{n} \\
k_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the new key tensor $k_{n+1}$ for the new attention token $x_{n+1}$ can be computed using
$$
k_{n+1} = x_{n+1} W^{K}
$$
Here, computing the new key tensor $k_{n+1}$ for the new attention token $x_{n+1}$ is an $O(1)$ operation, as $d_{\text{model}}$ and $d_{k}$ are constants.
Similarly, the new value tensor $V_{n+1} \in \mathbb{R}^{(n+1) \times d_{v}}$ at the time step $n+1$ can be computed using the concatenation of the value tensor for the new attention token $v_{n+1} \in \mathbb{R}^{1 \times d_{v}}$ with the value tensor $V_{n}$ from the previous time step $n$.
$$
\begin{align}
V_{n+1}
&= X_{n+1} W^{V} \\
&=
\left [
\begin{array}{c|c}
X_{n} \\
x_{n+1} \\
\end{array}
\right ] W^{V} \\
&=
\left [
\begin{array}{c|c}
X_{n} W^{V} \\
x_{n+1} W^{V} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the new value tensor $v_{n+1}$ for the new attention token $x_{n+1}$ can be computed using
$$
v_{n+1} = x_{n+1} W^{V}
$$
Here, computing the new value tensor $v_{n+1}$ for the new attention token $x_{n+1}$ is an $O(1)$ operation, as $d_{\text{model}}$ and $d_{v}$ are constants.
Let’s further check how the masked self-attention head is computed at the time step $n + 1$. The output attention tensor $Y_{n+1} \in \mathbb{R}^{(n+1) \times d_{v}}$ can be computed as follows.
$$
\begin{align}
Y_{n+1}
&= \text{softmax} \left( \text{Mask} \left( \frac{ Q_{n+1} K_{n+1}^{\top}}{\sqrt{d_k}} \right) \right) V_{n+1} \\
&= \text{softmax} \left( \text{Mask} \left(
\frac{1}{\sqrt{d_k}}
\left [
\begin{array}{c|c}
Q_{n} \\
q_{n+1} \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
K_{n} \\
k_{n+1} \\
\end{array}
\right ]^{\top} \right) \right)
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&= \text{softmax} \left( \text{Mask} \left(
\frac{1}{\sqrt{d_k}}
\left [
\begin{array}{c|c}
Q_{n} \\
q_{n+1} \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
K_{n}^{\top} & k_{n+1}^{\top} \\
\end{array}
\right ] \right) \right)
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&= \text{softmax} \left( \text{Mask} \left(
\frac{1}{\sqrt{d_k}}
\left [
\begin{array}{c|c}
Q_{n}K_{n}^{\top} & Q_{n}k_{n+1}^{\top} \\
\hline
q_{n+1}K_{n}^{\top} & q_{n+1}k_{n+1}^{\top} \\
\end{array}
\right ]
\right) \right)
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&= \text{softmax} \left(
\left [
\begin{array}{c|c}
\text{Mask} \left( \frac{1}{\sqrt{d_k}} Q_{n}K_{n}^{\top}\right) & -\infty \\
\hline
\frac{1}{\sqrt{d_k}} q_{n+1}K_{n}^{\top} & \frac{1}{\sqrt{d_k}} q_{n+1}k_{n+1}^{\top} \\
\end{array}
\right ]
\right)
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
\left [
\begin{array}{c|c}
\text{softmax} \left(\text{Mask} \left( \frac{1}{\sqrt{d_k}} Q_{n}K_{n}^{\top}\right) \right) & 0 \\
\end{array}
\right ] \\
\hline
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1}
\left [
\begin{array}{c|c}
K_{n}^{\top} & k_{n+1}^{\top} \\
\end{array}
\right ]
\right) \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
\text{softmax} \left(\text{Mask} \left( \frac{1}{\sqrt{d_k}} Q_{n}K_{n}^{\top}\right) \right) V_{n} \\
\hline
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K_{n+1}^{\top}
\right) V_{n + 1} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Y_{n} \\
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K_{n+1}^{\top}
\right) V_{n + 1} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Y_{n} \\
y_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the new attention tensor $y_{n+1}$ for the new attention token $x_{n+1}$ can be computed using
$$
\begin{align}
y_{n+1}
&= \text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K_{n+1}^{\top}
\right) V_{n + 1} \\
&=
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1}
\left [
\begin{array}{c|c}
K_{n}^{\top} & k_{n+1}^{\top} \\
\end{array}
\right ]
\right)
\left [
\begin{array}{c|c}
V_{n} \\
v_{n+1} \\
\end{array}
\right ] \\
&=
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}} x_{n+1} W^{Q}
\left [
\begin{array}{c|c}
K_{n}^{\top} & W^{K\top} x_{n+1}^{\top} \\
\end{array}
\right ]
\right)
\left [
\begin{array}{c|c}
V_{n} \\
x_{n+1} W^{V} \\
\end{array}
\right ] \\
\end{align}
$$
Computing the new attention tensor $y_{n+1}$ for the new attention token $x_{n+1}$ is an $O(n)$ operation; the computation time grows as the number of input tokens grows. Fortunately, with the recipe above, we do not have to recompute $Y_{n}$ in order to compute $Y_{n+1}$ at each time step, so running the autoregressive inference $n$ times becomes an $O(n^2)$ operation in total.
This mathematics suggests that we can cache the intermediate tensors from the masked self-attention head at the current time step, specifically $K_{n}$ and $V_{n}$ and nothing else, to accelerate the inference at the next time step.
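The following is a minimal sketch of this key-value caching recipe for a single head, reusing `numpy` and the `softmax` helper from the sketch above; the function name and the cache layout are illustrative assumptions rather than any particular library's API, and the caches are assumed to start as zero-row arrays of shapes $(0, d_{k})$ and $(0, d_{v})$.

```python
def masked_self_attention_step(x_new, W_Q, W_K, W_V, K_cache, V_cache):
    # x_new: (1, d_model); K_cache: (n, d_k) and V_cache: (n, d_v) from previous steps.
    d_k = W_Q.shape[1]
    q_new = x_new @ W_Q  # q_{n+1}: (1, d_k), O(1)
    K_cache = np.vstack([K_cache, x_new @ W_K])  # append k_{n+1}: O(1)
    V_cache = np.vstack([V_cache, x_new @ W_V])  # append v_{n+1}: O(1)
    # The newest token attends to all n + 1 cached keys: an O(n) step, no mask needed.
    scores = q_new @ K_cache.T / np.sqrt(d_k)  # (1, n + 1)
    y_new = softmax(scores) @ V_cache  # y_{n+1}: (1, d_v)
    return y_new, K_cache, V_cache
```

At every time step, only `K_cache` and `V_cache` need to be carried over; the returned row `y_new` equals the last row of the fully recomputed $Y_{n+1}$.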
Cross Attention Head
For the encoder-decoder transformer, there are usually cross-attention heads that allow the decoding process to attend to the encoder inputs. Specifically, unlike in the masked self-attention head discussed above, the key tensor $K \in \mathbb{R}^{m \times d_{k}}$ and the value tensor $V \in \mathbb{R}^{m \times d_{v}}$ used in the attention equation, where $m$ is the number of input tokens to the encoder, are computed once from the encoder inputs and remain constant throughout the entire autoregressive decoding. The query tensor $Q_{n} \in \mathbb{R}^{n \times d_{k}}$ in the attention equation comes from the decoder.
In the autoregressive decoding, at the time step $n$, the number of input tokens to the decoder is $n$.
$$
Q_{n} = X_{n} W^{Q}
$$
The output attention tensor $Y_{n} \in \mathbb{R}^{n \times d_{v}}$ from the cross-attention head is computed as follows. No masks are used this time.
$$
Y_{n} = \text{softmax} \left( \frac{ Q_{n} K^{\top}}{\sqrt{d_k}} \right) V
$$
At the time step $n+1$, we can analyze the output attention tensor $Y_{n+1} \in \mathbb{R}^{(n + 1) \times d_{v}}$ using a method similar to the one we used for the masked self-attention head.
$$
\begin{align}
Q_{n+1}
&= X_{n+1} W^{Q} \\
&=
\left [
\begin{array}{c|c}
X_{n} \\
x_{n+1} \\
\end{array}
\right ] W^{Q} \\
&=
\left [
\begin{array}{c|c}
X_{n} W^{Q} \\
x_{n+1} W^{Q} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Q_{n} \\
q_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the query tensor $q_{n+1} \in \mathbb{R}^{1 \times d_{k}}$ for the new attention token $x_{n+1}$ can be computed using
$$
q_{n+1} = x_{n+1} W^{Q}
$$
Here, computing the new query tensor $q_{n+1}$ for the new attention token $x_{n+1}$ is an $O(1)$ operation, as $d_{\text{model}}$ and $d_{k}$ are constants.
$$
\begin{align}
Y_{n+1}
&= \text{softmax} \left( \frac{ Q_{n+1} K^{\top}}{\sqrt{d_k}} \right) V \\
&= \text{softmax} \left(
\frac{1}{\sqrt{d_k}}
\left [
\begin{array}{c|c}
Q_{n} \\
q_{n+1} \\
\end{array}
\right ]
K^{\top} \right)
V \\
&= \text{softmax} \left(
\frac{1}{\sqrt{d_k}}
\left [
\begin{array}{c|c}
Q_{n}K^{\top} \\
q_{n+1}K^{\top} \\
\end{array}
\right ]
\right)
V \\
&= \text{softmax} \left(
\left [
\begin{array}{c|c}
\frac{1}{\sqrt{d_k}} Q_{n}K^{\top} \\
\frac{1}{\sqrt{d_k}} q_{n+1}K^{\top} \\
\end{array}
\right ]
\right)
V \\
&=
\left [
\begin{array}{c|c}
\text{softmax} \left(\frac{1}{\sqrt{d_k}} Q_{n}K^{\top} \right) V \\
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K^{\top}
\right) V \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Y_{n} \\
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K^{\top}
\right) V \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Y_{n} \\
y_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the new attention tensor $y_{n+1}$ for the new attention token $x_{n+1}$ can be computed using
$$
\begin{align}
y_{n+1}
&=
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}}q_{n+1} K^{\top}
\right) V \\
&=
\text{softmax}
\left(
\frac{1}{\sqrt{d_k}} x_{n+1} W^{Q} K^{\top}
\right) V \\
\end{align}
$$
Computing the new attention tensor $y_{n+1}$ for the new attention token $x_{n+1}$ is an $O(m)$ operation. Because the number of encoder input tokens $m$ is fixed during decoding and does not grow with $n$, this is effectively an $O(1)$ operation per time step.
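A corresponding sketch of the cached cross-attention step, assuming the encoder key tensor `K_enc` and value tensor `V_enc` have been computed once from the encoder inputs and are reused at every decoding step (again reusing `numpy` and the `softmax` helper above; the names are illustrative):

```python
def cross_attention_step(x_new, W_Q, K_enc, V_enc):
    # x_new: (1, d_model); K_enc: (m, d_k), V_enc: (m, d_v) are fixed per encoder input.
    d_k = W_Q.shape[1]
    q_new = x_new @ W_Q  # q_{n+1}: (1, d_k)
    scores = q_new @ K_enc.T / np.sqrt(d_k)  # (1, m); m does not grow with n
    return softmax(scores) @ V_enc  # y_{n+1}: (1, d_v)
```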
Multi-Head Attention
The multi-head self-attention or cross-attention essentially performs self-attention or cross-attention on an input attention tensor $X_{n} \in \mathbb{R}^{n \times d_{\text{model}}}$ in $h$ independent attention heads, producing $h$ output attention tensors $Y_{n,1} \in \mathbb{R}^{n \times d_{v}}$, $Y_{n,2} \in \mathbb{R}^{n \times d_{v}}$, $\cdots$, $Y_{n,h} \in \mathbb{R}^{n \times d_{v}}$. It then concatenates them into an attention tensor $Y_{n} \in \mathbb{R}^{n \times hd_{v}}$ and performs a linear feature transformation using a matrix $W^{O} \in \mathbb{R}^{hd_{v} \times d_{\text{model}}}$, resulting in an output attention tensor $Z_{n} \in \mathbb{R}^{n \times d_{\text{model}}}$.
Notice that in this context, $Y_{n} \in \mathbb{R}^{n \times hd_{v}}$ is different from the $Y_{n} \in \mathbb{R}^{n \times d_{v}}$ we have discussed so far. Instead, $Y_{n,1} \in \mathbb{R}^{n \times d_{v}}$, $Y_{n,2} \in \mathbb{R}^{n \times d_{v}}$, $\cdots$, $Y_{n,h} \in \mathbb{R}^{n \times d_{v}}$ are the tensors corresponding to the $Y_{n} \in \mathbb{R}^{n \times d_{v}}$ discussed in the masked self-attention head and cross-attention head sections.
Although many transformer models use the following settings,
$$
d_{k} = d_{v} = \frac{d_{\text{model}}}{h}
$$
these are not essential constraints imposed by the transformer attention mathematics. Even if we set the transformer to use the following settings,
$$
d_{k} \neq d_{v} \neq \frac{d_{\text{model}}}{h}
$$
mathematically, it would still work.
In the autoregressive decoding, at the time step $n$, the number of input tokens to the decoder is $n$. We also split $W^{O} \in \mathbb{R}^{hd_{v} \times d_{\text{model}}}$ into $h$ smaller matrices $W_{1}^{O} \in \mathbb{R}^{d_{v} \times d_{\text{model}}}$, $W_{2}^{O} \in \mathbb{R}^{d_{v} \times d_{\text{model}}}$, $\cdots$, $W_{h}^{O} \in \mathbb{R}^{d_{v} \times d_{\text{model}}}$.
$$
\begin{align}
Z_{n}
&=
Y_{n} W^{O} \\
&=
\left [
\begin{array}{c|c}
Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, h} \\
\end{array}
\right ] W^{O} \\
&=
\left [
\begin{array}{c|c}
Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, h} \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
W_{1}^{O} \\
W_{2}^{O} \\
\vdots \\
W_{h}^{O} \\
\end{array}
\right ] \\
&=
Y_{n, 1} W_{1}^{O} + Y_{n, 2} W_{2}^{O} + \cdots + Y_{n, h} W_{h}^{O}
\\
&=
\sum_{i = 1}^{h} Y_{n, i} W_{i}^{O}
\\
\end{align}
$$
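A quick NumPy check of this identity, with hypothetical sizes: splitting $W^{O}$ row-wise into per-head blocks and summing the per-head projections gives the same result as concatenating the head outputs first.

```python
import numpy as np

rng = np.random.default_rng(0)
h, n, d_v, d_model = 4, 5, 8, 32  # hypothetical sizes
Ys = [rng.standard_normal((n, d_v)) for _ in range(h)]  # per-head outputs Y_{n, i}
W_O = rng.standard_normal((h * d_v, d_model))

Z_concat = np.concatenate(Ys, axis=1) @ W_O  # concatenate heads, then project
Z_sum = sum(Y_i @ W_i for Y_i, W_i in zip(Ys, np.split(W_O, h, axis=0)))
assert np.allclose(Z_concat, Z_sum)
```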
At the time step $n+1$, $Z_{n+1} \in \mathbb{R}^{(n + 1) \times d_{\text{model}}}$ can be computed as follows.
$$
\begin{align}
Z_{n+1}
&=
Y_{n+1} W^{O} \\
&=
\left [
\begin{array}{c|c}
Y_{n+1, 1} & Y_{n+1, 2} & \cdots & Y_{n+1, h} \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
W_{1}^{O} \\
W_{2}^{O} \\
\vdots \\
W_{h}^{O} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, h} \\
\hline
y_{n+1, 1} & y_{n+1, 2} & \cdots & y_{n+1, h} \\
\end{array}
\right ]
\left [
\begin{array}{c|c}
W_{1}^{O} \\
W_{2}^{O} \\
\vdots \\
W_{h}^{O} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
\sum_{i = 1}^{h} Y_{n, i} W_{i}^{O} \\
\sum_{i = 1}^{h} y_{n+1, i} W_{i}^{O} \\
\end{array}
\right ] \\
&=
\left [
\begin{array}{c|c}
Z_{n} \\
z_{n+1} \\
\end{array}
\right ] \\
\end{align}
$$
where the new attention tensor $z_{n+1}$ for the new attention token $x_{n+1}$ can be computed using
$$
\begin{align}
z_{n+1}
&= \sum_{i = 1}^{h} y_{n+1, i} W_{i}^{O}
\end{align}
$$
Computing the new attention tensor $z_{n+1}$ for the new attention token $x_{n+1}$ is an $O(n)$ operation if the multi-head attention is a masked multi-head self-attention or an $O(1)$ operation if the multi-head attention is a multi-head cross attention. Because the number of heads $h$ is a constant, it does not change the asymptotic complexity of the algorithm.
Other Layers
The layer normalization layers, the feed-forward layers, the final softmax layer, etc., operate on each token’s features individually. Therefore, at each time step we only have to compute the output for the new attention token, without recomputing the outputs for the attention tokens from the previous time steps.
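As a small illustration, here is a NumPy check, with hypothetical shapes and weights, that a position-wise feed-forward layer acts on each token row independently, so only the newest token’s row has to be computed at each time step; layer normalization behaves the same way.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_ff = 5, 16, 64  # hypothetical sizes
X = rng.standard_normal((n + 1, d_model))  # n cached tokens plus the new one
W_1 = rng.standard_normal((d_model, d_ff))
W_2 = rng.standard_normal((d_ff, d_model))


def feed_forward(X):
    # Position-wise feed-forward network: applied to each token row independently.
    return np.maximum(X @ W_1, 0.0) @ W_2


full = feed_forward(X)  # outputs for all n + 1 tokens
new_row_only = feed_forward(X[-1:])  # output for just the newest token
assert np.allclose(full[-1:], new_row_only)
```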
Conclusions
In transformer autoregressive inference, with appropriate caching, the attentions for the previous tokens do not need to be recomputed at each time step. This reduces the asymptotic computational complexity of running the autoregressive inference $n$ times to $O(n^2)$.
However, running the autoregressive inference $n$ times in $O(n^2)$ is still very expensive. That is partially why the allowed number of tokens is usually limited in commercial applications. For example, ChatGPT only allows a context of a few thousand tokens. This also means that if the user talks to ChatGPT for too long, ChatGPT will not be able to look back at the context all the way from the beginning of the conversation.
Finally, the entire transformer autoregressive inference optimization is valid only if the inference of each type of layer in the transformer decoder is incremental. If some layers are changed so that their inference is no longer incremental, such as by removing the mask from the self-attention heads, the outputs for the previous tokens will change upon the addition of a new token. In that case, the autoregressive inference behavior becomes “unexpected” or “undefined”, even if no caching optimization is applied.