Autoregressive Model and Autoregressive Decoding for Sequence to Sequence Tasks
Introduction
Sequence to sequence tasks, such as machine translation, have been modeled using autoregressive models, such as recurrent neural networks and Transformers. Given an input sequence, the output sequence is also generated in an autoregressive fashion.
In this blog post, I would like to discuss mathematically why autoregressive models and autoregressive decoding have been applied to sequence to sequence tasks, as well as their drawbacks for latency-constrained inference in practice.
Discriminative Model for Sequence to Sequence Tasks
The sequence to sequence discriminative model usually learns a probability distribution $P(Y | X)$, where $X$ is a sequence of input variables $X = \{X_1, X_2, \cdots, X_{T^{\prime}}\}$ and $Y$ is a sequence of output variables $Y = \{Y_1, Y_2, \cdots, Y_{T}\}$. For simplicity, we assume $X$ and $Y$ are sequences of discrete variables. This assumption is valid in many problems, such as language translation. Note that although the model seems to be generating an output sequence, it is still a discriminative model rather than a generative model. If the reader does not know the difference between discriminative models and generative models, please check my blog post “Discriminative Model VS Generative Model”.
The goal of inference is to find or sample the most likely output sequence $y$ given the input sequence $x$. This step is also called “decoding”. Mathematically, it could be expressed as
$$
\DeclareMathOperator*{\argmin}{argmin}
\DeclareMathOperator*{\argmax}{argmax}
y = \argmax_{Y} P(Y | X = x)
$$
One of the problems with the sequence to sequence discriminative model is that it is actually an ensemble model, or an adaptive model, that models the distributions $P(Y_1, Y_2, \cdots, Y_{T} | X_1, X_2, \cdots, X_{T^{\prime}})$ for all $T^{\prime} \in \{1, 2, \cdots\}$ and $T \in \{1, 2, \cdots\}$. This means that $P(Y = \{y_1\} | X = x)$ and $P(Y = \{y_1, y_2\} | X = x)$ come from two different conditional distributions, and they are not directly comparable. Therefore, during inference, we cannot determine the output sequence length by comparing the conditional probabilities of output sequences of different lengths.
In addition, even if we know that the output sequence length is $T = t$ given the input sequence $x$, as long as we know nothing about the mathematical properties of the model $P(Y_1, Y_2, \cdots, Y_{t} | X)$, the only way to find the most likely $y$ is brute-force decoding: iterate through all possible value combinations of $Y$ and pick the $y$ that has the maximum probability.
Note that here we have no independence assumptions for the variables in the output sequence, i.e.,
$$
P(Y_1, Y_2, \cdots, Y_{t} | X) \neq \prod_{i=1}^{t} P(Y_i | X)
$$
Otherwise, finding the optimal output sequence would be much easier, because each $Y_i$ could be maximized independently.
Suppose $Y_i$ is a binary variable for $i = 1, 2, \cdots, t$. Then it takes $O(2^t)$ time to find the optimal $y$ by brute force, which is intractable for large $t$. Brute-force decoding works only if the output sequence length is very small. However, in many practical problems, the sequence length $t$ can be very large. For example, even for an output sequence of length $32$, there are $2^{32} = 4294967296$ candidates, roughly $4.3$ billion.
Therefore, decoding, i.e., searching for the most likely output sequence, is the most critical and the hardest problem to solve for the sequence to sequence discriminative model. Concretely, we have to solve two problems: how to determine the length of the output sequence, and how to develop an efficient search algorithm that finds the output sequence with the maximum conditional probability.
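To make the cost concrete, here is a minimal Python sketch of brute-force decoding for a fixed length $t$. The functions `brute_force_decode` and `toy_score` are hypothetical illustrations, not part of any library. `toy_score` is an unnormalized log-score (a log probability up to an additive constant, which does not change the argmax), chosen so that no output position can be decided in isolation.

```python
import itertools
import math


def brute_force_decode(score_fn, vocab, length):
    """Return the highest-scoring sequence among all len(vocab) ** length candidates."""
    best_seq, best_score, num_evaluated = None, -math.inf, 0
    for seq in itertools.product(vocab, repeat=length):
        num_evaluated += 1
        score = score_fn(seq)
        if score > best_score:
            best_seq, best_score = seq, score
    return best_seq, best_score, num_evaluated


def toy_score(seq):
    """Hypothetical unnormalized stand-in for log P(Y = seq | X = x).

    It rewards adjacent tokens that differ, so the tokens are not independent
    and no single position can be chosen on its own.
    """
    return sum(1.0 for a, b in zip(seq, seq[1:]) if a != b) - 0.1 * seq[0]


best, score, n = brute_force_decode(toy_score, vocab=(0, 1), length=10)
print(best, n)  # n == 2 ** 10 == 1024 candidates were scored
```

Doubling the length to $20$ would already require $2^{20} = 1048576$ evaluations, which is the exponential growth described above.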
Autoregressive Modeling and Autoregressive Decoding
Given a piece of training data, i.e., the input sequence $x = \{x_1, x_2, \cdots, x_{t^{\prime}}\}$ and its ground truth output sequence $y = \{y_1, y_2, \cdots, y_{t}\}$, the autoregressive model applies the probability chain rule and creates a temporal model for the problem.
$$
\begin{align}
P(Y | X; \theta) = P(Y_0 | X_{1:t^{\prime}}; \theta) \prod_{i=1}^{t} P(Y_i | Y_{0:i-1}, X_{1:t^{\prime}}; \theta)
\end{align}
$$
Very often, $Y_0$ is not actually a variable. Usually, $Y_0 \equiv \langle \text{BOS} \rangle$ (the beginning of the sequence), and $P(Y_0 = \langle \text{BOS} \rangle | X_{1:t^{\prime}}; \theta) = 1$.
During training, the autoregressive model maximizes the likelihood of the training data $(x, y)$, whose sequences have different lengths, with respect to the model parameters $\theta$. Because $P(Y_0 = \langle \text{BOS} \rangle | X_{1:t^{\prime}}; \theta) = 1$, the $Y_0$ term drops out of the objective.
$$
\begin{align}
\argmax_{\theta} P(Y = y | X = x; \theta) &= \argmax_{\theta} \prod_{i=1}^{t} P(Y_i = y_i | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)
\end{align}
$$
By optimizing $P(Y | X ; \theta)$, the model also optimizes $P(Y_i | Y_{0:i-1}, X_{1:t^{\prime}}; \theta)$ for all $i = 1, 2, \cdots$. This means that during inference, instead of computing $P(Y | X)$ for all possible value combinations of $Y = \{Y_1, Y_2, \cdots \}$ to find the optimal $y$, we can choose the variables of the output sequence greedily, one at a time. Suppose $Y_i$ is a binary variable for $i = 1, 2, \cdots$. If we know the output sequence length is $t$, brute-force decoding takes $O(2^t)$ time to find the optimal $y$, which is intractable, whereas greedy autoregressive decoding only takes $O(t)$.
Note that, theoretically, we cannot guarantee that the following equation holds during autoregressive decoding, and in general it is false.
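The following Python sketch contrasts this with greedy autoregressive decoding in a toy setup. The conditional model `toy_conditional` is a hypothetical stand-in for a trained network $P(Y_i | Y_{0:i-1}, X; \theta)$ over binary tokens, and `greedy_decode` is an illustrative helper. The decoder calls the model exactly once per output position, so the number of model calls grows linearly with $t$.

```python
import math


def toy_conditional(prefix, x):
    """Hypothetical P(Y_i | Y_{0:i-1} = prefix, X = x) over binary tokens.

    prefix[0] is a placeholder for <BOS>. Returns [P(Y_i = 0), P(Y_i = 1)].
    The output depends on the previous token, so the Y_i are not independent.
    """
    i = len(prefix) - 1
    p1 = 0.8 if x[i] == 1 else 0.3
    if prefix[-1] == 1:
        p1 -= 0.1
    return [1.0 - p1, p1]


def greedy_decode(conditional, x, length):
    """Pick the most probable token at each step; one model call per position."""
    prefix = [None]  # None stands in for <BOS>
    log_prob, num_calls = 0.0, 0
    for _ in range(length):
        probs = conditional(prefix, x)
        num_calls += 1
        token = max(range(len(probs)), key=lambda k: probs[k])
        log_prob += math.log(probs[token])
        prefix.append(token)  # condition the next step on the chosen token
    return prefix[1:], log_prob, num_calls


tokens, log_prob, num_calls = greedy_decode(toy_conditional, x=[1, 0, 1, 1], length=4)
print(tokens, num_calls)  # [1, 0, 1, 1] 4
```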
$$
\begin{align}
\max_{Y_1, Y_2, \cdots, Y_t} \prod_{i=1}^{t} P(Y_i | Y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta) = \prod_{i=1}^{t} \max_{Y_i} P(Y_i | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)
\end{align}
$$
where
$$
\begin{align}
y_i = \argmax_{Y_i} P(Y_i | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)
\end{align}
$$
This means that, unlike brute-force decoding, which finds the global optimum, autoregressive decoding does not necessarily find the global optimum. However, $\prod_{i=1}^{t} \max_{Y_i} P(Y_i | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)$ is usually a high probability in practice. This lays the foundation for why autoregressive decoding is a valid approach.
Sometimes, a modified greedy autoregressive decoding method, usually referred to as beam search, achieves slightly better decoding results. It can produce better results than greedy autoregressive decoding because it explores a larger search space, keeping the $k$ most probable partial sequences at each step instead of only one. The time complexity of beam search decoding is usually $O(k^2 t)$, where $k$ is a constant representing the beam size. Since $k$ is a constant, the time complexity remains $O(t)$.
Please also note that we can apply autoregressive decoding only if the model was trained in an autoregressive fashion.
Compared with brute-force decoding, autoregressive decoding is an efficient search algorithm for finding an output sequence with a high conditional probability among the output sequences of a fixed length $T = t$, although it does not guarantee the global maximum. However, the sequence length problem remains. If we don’t know the sequence length, we still have an infinite number of output sequence candidates.
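A two-step toy example, sketched in Python below, shows how greedy decoding can miss the global optimum. The probability tables `P_Y1` and `P_Y2_GIVEN_Y1` are made up for illustration and stand in for $P(Y_i | Y_{0:i-1}, X = x; \theta)$ with a fixed input: the first token $0$ is locally more likely, but both continuations after it are mediocre, while the less likely first token $1$ is followed by a very confident second token.

```python
import itertools

# Hypothetical conditionals for a fixed input x.
P_Y1 = {0: 0.6, 1: 0.4}
P_Y2_GIVEN_Y1 = {0: {0: 0.55, 1: 0.45}, 1: {0: 0.9, 1: 0.1}}


def joint(y1, y2):
    """Chain rule: P(Y_1 = y1, Y_2 = y2 | X = x)."""
    return P_Y1[y1] * P_Y2_GIVEN_Y1[y1][y2]


# Greedy autoregressive decoding: commit to the best Y_1, then the best Y_2 given it.
g1 = max(P_Y1, key=P_Y1.get)
g2 = max(P_Y2_GIVEN_Y1[g1], key=P_Y2_GIVEN_Y1[g1].get)
greedy = ((g1, g2), joint(g1, g2))

# Brute force: enumerate every (Y_1, Y_2) pair.
best = max(itertools.product((0, 1), repeat=2), key=lambda y: joint(*y))
brute = (best, joint(*best))

print(greedy)  # (0, 0) with probability about 0.33
print(brute)   # (1, 0) with probability about 0.36
```

Greedy decoding returns $(0, 0)$ with probability $0.6 \times 0.55 = 0.33$, whereas the global optimum is $(1, 0)$ with probability $0.4 \times 0.9 = 0.36$.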
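Below is a minimal beam search sketch in Python. It assumes a hypothetical `next_probs(prefix)` function that returns the conditional distribution of the next token as a dictionary, and for brevity it decodes a fixed number of steps and ignores $\langle \text{EOS} \rangle$. The toy distribution `toy_next_probs` is made up so that the locally best first token leads to a worse overall sequence: a beam size of $1$ reduces to greedy decoding, while a beam size of $2$ recovers the better sequence.

```python
import math


def beam_search(next_probs, length, beam_size):
    """Keep the beam_size highest-scoring prefixes at every step.

    next_probs(prefix) is a hypothetical stand-in for
    P(Y_i | Y_{0:i-1} = prefix, X = x) and returns {token: probability}.
    """
    beams = [((), 0.0)]  # (prefix, log probability)
    for _ in range(length):
        candidates = []
        for prefix, score in beams:
            for token, p in next_probs(prefix).items():
                if p > 0.0:
                    candidates.append((prefix + (token,), score + math.log(p)))
        # Keep only the best beam_size partial sequences.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams[0]


def toy_next_probs(prefix):
    """Hypothetical two-step model: first token 0 is likelier but leads nowhere good."""
    if len(prefix) == 0:
        return {0: 0.6, 1: 0.4}
    return {0: 0.55, 1: 0.45} if prefix[0] == 0 else {0: 0.9, 1: 0.1}


print(beam_search(toy_next_probs, length=2, beam_size=1)[0])  # (0, 0), same as greedy
print(beam_search(toy_next_probs, length=2, beam_size=2)[0])  # (1, 0)
```

This sketch scores every vocabulary continuation of each beam, i.e., $O(kV)$ candidates per step for a vocabulary of size $V$. A common variant first keeps only the top $k$ continuations of each beam, which gives the $k^2$ candidates per step behind the $O(k^2 t)$ estimate above.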
One approach is to create a model $P(T | X_{1:T^{\prime}}; \theta)$ that predicts the output sequence length directly from the input sequence. However, this sometimes does not work well in practice, because $T$ is a discrete variable that takes values from an infinite set $\{1, 2, \cdots \}$, and for many sequence to sequence tasks the tolerance for errors in $T$ is very low. For example, given an input sequence $x$, autoregressive decoding generates the sequences $\{y_1\}$, $\{y_1, y_2\}$, $\{y_1, y_2, y_3\}$, $\{y_1, y_2, y_3, y_4\}$, $\cdots$. If $\{y_1, y_2, y_3\}$ is a very good output sequence but the predicted length is $T=2$, the output sequence selected by the algorithm will be $\{y_1, y_2\}$, which is absurd in many sequence to sequence tasks. Think of predicting “How are” because the model did not predict the output sequence length correctly, whereas “How are you” actually makes more sense.
So a more commonly used approach is to implicitly encode the sequence length information in the output sequence, rather than directly predicting the value of the output sequence length. Concretely, given a piece of training data, the input sequence $x = \{x_1, x_2, \cdots, x_{t^{\prime}}\}$ and its ground truth output sequence $y = \{y_1, y_2, \cdots, y_{t}\}$, the ground truth output sequence used during training is actually $y = \{y_1, y_2, \cdots, y_{t}, y_{t+1}\}$, where $y_{t+1} = \langle \text{EOS} \rangle$ (the end of the sequence). Therefore, during autoregressive decoding, whenever the algorithm sees that the new output token is $\langle \text{EOS} \rangle$, it knows it is time to stop decoding. The only concerns are whether $\langle \text{EOS} \rangle$ will come up too early (truncated sequence prediction) or never show up (infinite sequence prediction). This is usually not a problem in practice if the discriminative model has learned the sequence to sequence task well. If $y_{0:i-1}$ is already a sequence that matches $x_{1:t^{\prime}}$ quite well, $P(Y_i = \langle \text{EOS} \rangle | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)$ is usually the largest among the conditional probabilities of all values of $Y_i$. On the contrary, if $y_{0:i-1}$ is not a good match for $x_{1:t^{\prime}}$, $P(Y_i = \langle \text{EOS} \rangle | Y_{0:i-1} = y_{0:i-1}, X_{1:t^{\prime}} = x_{1:t^{\prime}}; \theta)$ is usually almost zero compared to the conditional probabilities for $Y_i \neq \langle \text{EOS} \rangle$.
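The stopping rule can be sketched in Python as follows. `toy_next_token_probs` is a hypothetical model that has learned to copy its input, so $\langle \text{EOS} \rangle$ receives almost no probability mass until the prefix covers the input, and most of the mass afterwards. The `max_len` cap is an assumption added here as a guard against infinite sequence prediction; it is a common safeguard, not something the model itself provides.

```python
BOS, EOS = "<BOS>", "<EOS>"


def greedy_decode_until_eos(next_token_probs, x, max_len):
    """Greedy autoregressive decoding that stops as soon as <EOS> is chosen.

    max_len caps the number of steps in case <EOS> never becomes the best token.
    """
    prefix = [BOS]
    for _ in range(max_len):
        probs = next_token_probs(prefix, x)
        token = max(probs, key=probs.get)
        if token == EOS:
            break
        prefix.append(token)
    return prefix[1:]  # drop <BOS>


def toy_next_token_probs(prefix, x):
    """Hypothetical P(Y_i | Y_{0:i-1} = prefix, X = x) for a model that copies x."""
    pos = len(prefix) - 1  # number of tokens generated so far
    if pos >= len(x):
        return {EOS: 0.95, x[-1]: 0.05}  # prefix already matches x: stop
    return {x[pos]: 0.9, EOS: 0.01, "<other>": 0.09}  # EOS is unlikely mid-sequence


output = greedy_decode_until_eos(toy_next_token_probs, ["How", "are", "you"], max_len=50)
print(output)  # ['How', 'are', 'you']
```

With `max_len=2`, the same call returns `['How', 'are']`, which is exactly the kind of truncated output discussed above.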
Autoregressive Summary
Given a sequence of input variables $X = \{X_1, X_2, \cdots, X_{T^{\prime}}\}$ and a sequence of output variables $Y = \{Y_1, Y_2, \cdots, Y_{T}, Y_{T+1}\}$, according to the chain rule,
$$
\begin{align}
P(Y | X; \theta) = \prod_{t=1}^{T + 1} P(Y_t | Y_{0:t-1}, X_{1:T^{\prime}}; \theta)
\end{align}
$$
where $Y_0 \equiv \langle \text{BOS} \rangle$ and $Y_{T+1} \equiv \langle \text{EOS} \rangle$. Note that the input sequence length and the output sequence length are also random variables.
The principle of autoregressive model optimization is as follows.
$$
\begin{align}
\argmax_{\theta} \log P(Y | X; \theta) &= \argmax_{\theta} \log \prod_{t=1}^{T+1} P(Y_t | Y_{0:t-1}, X_{1:T^{\prime}}; \theta) \\
&= \argmax_{\theta} \sum_{t=1}^{T+1} \log P(Y_t | Y_{0:t-1}, X_{1:T^{\prime}}; \theta) \\
\end{align}
$$
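As a small illustration of this objective, the Python sketch below computes $\log P(Y = y | X = x; \theta)$ as a sum of per-step log probabilities, feeding the ground truth prefix back in at every step (often called teacher forcing) and appending $\langle \text{EOS} \rangle$ as $y_{T+1}$. The model `toy_next_token_probs` is a hypothetical stand-in for a trained network; negating the returned value gives the usual negative log-likelihood (cross-entropy) training loss.

```python
import math

BOS, EOS = "<BOS>", "<EOS>"


def toy_next_token_probs(prefix, x):
    """Hypothetical P(Y_t | Y_{0:t-1} = prefix, X = x) that copies x, then emits <EOS>.

    It puts 0.9 of the mass on the 'correct' next token and spreads the rest uniformly.
    """
    pos = len(prefix) - 1  # tokens generated after <BOS>
    target = x[pos] if pos < len(x) else EOS
    vocab = sorted(set(x)) + [EOS]
    rest = 0.1 / (len(vocab) - 1)
    return {tok: (0.9 if tok == target else rest) for tok in vocab}


def sequence_log_likelihood(y, x, next_token_probs):
    """log P(Y = y | X = x) = sum_{t=1}^{T+1} log P(y_t | y_{0:t-1}, x), y_{T+1} = <EOS>."""
    prefix = [BOS]
    total = 0.0
    for token in list(y) + [EOS]:
        total += math.log(next_token_probs(prefix, x)[token])
        prefix.append(token)  # teacher forcing: condition on the ground truth token
    return total


x = ["a", "b", "c"]
ll_correct = sequence_log_likelihood(x, x, toy_next_token_probs)
ll_swapped = sequence_log_likelihood(["a", "c", "b"], x, toy_next_token_probs)
print(ll_correct > ll_swapped)  # True
```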
The principle of greedy autoregressive model decoding is as follows.
$$
\begin{align}
\max_{Y_1, Y_2, \cdots, Y_T, Y_{T+1}} \prod_{t=1}^{T + 1} P(Y_t | Y_{0:t-1}, X_{1:T^{\prime}} = x_{1:T^{\prime}}; \theta) \approx \prod_{t=1}^{T + 1} \max_{Y_t} P(Y_t | Y_{0:t-1} = y_{0:t-1}, X_{1:T^{\prime}} = x_{1:T^{\prime}}; \theta)
\end{align}
$$
where
$$
\begin{align}
y_t = \argmax_{Y_t} P(Y_t | Y_{0:t-1} = y_{0:t-1}, X_{1:T^{\prime}} = x_{1:T^{\prime}}; \theta)
\end{align}
$$
Autoregressive Drawbacks
It sounds like autoregressive models and autoregressive decoding are very useful for the sequence to sequence discriminative model. However, they still have some drawbacks in certain use cases.
In a latency-constrained inference system, autoregressive models and autoregressive decoding are not favored because the decoding computation cannot be parallelized across the output sequence: the tokens have to be generated one by one, running the model once per token. For long output sequences, the inference time can easily exceed the latency budget and cause various problems. This means that even though the time complexity of autoregressive decoding is $O(t)$, it may still be unsatisfactory because the generation of the tokens in the output sequence cannot be parallelized.
In some rare use cases, we can assume that the variables in the output sequence are conditionally independent given the input, for both training and inference. If, in addition, the output sequence length is known or predicted from the input, the temporal autoregressive model degenerates to a non-autoregressive model with independence assumptions.
$$
\begin{align}
P(Y | X; \theta) = P(T | X_{1:t^{\prime}}; \theta) \prod_{i=1}^{T} P(Y_i | X_{1:t^{\prime}}; \theta)
\end{align}
$$
Then the generation of the tokens can easily be parallelized. Semantic segmentation is an example of such a use case, where we treat each pixel as a variable and the output pixels are conditionally independent of each other given the input. However, it is no longer a temporal model; even though segmentation could be framed as “sequence to sequence”, we hardly ever talk about autoregressive decoding for generating segmentation labels.
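A minimal NumPy sketch of decoding under this factorization follows. The inputs `length_probs` and `token_probs` are hypothetical model outputs, and `non_autoregressive_decode` is a name made up for illustration. Because the $Y_i$ are conditionally independent given $X$, every position's best token comes from a single vectorized `argmax`, and the best length is found by comparing $\log P(T = l | X) + \sum_{i=1}^{l} \log \max_{Y_i} P(Y_i | X)$ for all $l$ at once with a cumulative sum. In semantic segmentation the output length (the number of pixels) is fixed, so only the per-position `argmax` would be needed.

```python
import numpy as np


def non_autoregressive_decode(length_probs, token_probs):
    """MAP decoding under P(Y | X) = P(T | X) * prod_{i=1}^{T} P(Y_i | X).

    length_probs: shape (L,), hypothetical P(T = l | X = x) for l = 1..L.
    token_probs:  shape (L, V), hypothetical P(Y_i = v | X = x) for positions i = 1..L.
    """
    # Independence lets every position be maximized separately, all at once.
    best_tokens = np.argmax(token_probs, axis=-1)
    best_token_logp = np.log(np.max(token_probs, axis=-1))
    # Log score of the best sequence for every candidate length T = 1..L.
    length_scores = np.log(length_probs) + np.cumsum(best_token_logp)
    t = int(np.argmax(length_scores)) + 1
    return t, best_tokens[:t]


rng = np.random.default_rng(0)
token_probs = rng.dirichlet(np.ones(5), size=8)  # 8 positions, vocabulary of 5
length_probs = np.array([0.05, 0.1, 0.6, 0.1, 0.05, 0.05, 0.03, 0.02])
t, tokens = non_autoregressive_decode(length_probs, token_probs)
print(t, tokens)
```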
Notes
Conventional Transformer models have a mask in the decoder that prevents each output position from attending to future tokens in the output sequence. This forces the Transformer model to learn in an autoregressive fashion. As a result, at inference time, Transformer decoding can (and has to) be autoregressive.
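The mask can be sketched with NumPy as a lower-triangular boolean matrix applied to the attention scores before the softmax. This is a simplified illustration of the idea, not the implementation of any particular Transformer library; `causal_mask` and `masked_softmax` are hypothetical helper names.

```python
import numpy as np


def causal_mask(n):
    """Entry (i, j) is True if output position i may attend to position j (j <= i)."""
    return np.tril(np.ones((n, n), dtype=bool))


def masked_softmax(scores, mask):
    """Row-wise softmax in which masked (future) positions get zero weight."""
    scores = np.where(mask, scores, -np.inf)  # block attention to future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
weights = masked_softmax(rng.normal(size=(4, 4)), causal_mask(4))
print(np.round(weights, 3))  # upper triangle is all zeros
```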
Conclusions
We have learned how the autoregressive model learns a task and how autoregressive decoding reduces the time complexity of finding (locally optimal or suboptimal) output sequences. We have also learned the drawbacks of autoregressive models and autoregressive decoding.
https://leimao.github.io/blog/Autoregressive-Model-Autoregressive-Decoding/