
Lei Mao

Machine Learning, Artificial Intelligence. On the Move.


Introduction

The Transformer is one of the most influential sequence transduction models. It is based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Previously, deep learning sequence transduction models usually relied on recurrent architectures, which could hardly take advantage of modern parallel computing. With recurrence removed, the Transformer not only trained faster but also achieved better machine translation results than the state of the art in 2017.


After going through the architecture of the Transformer, I found that it is actually simple mathematically, probably because it has no recurrence and no convolutions. In this blog post, I will try to describe the Transformer model mathematically on one single page.

Overview

Model Architecture

I have to admit that when I first read the Transformer paper “Attention is All You Need”, I did not understand the model thoroughly. This was probably because I had not been studying natural language processing for very long. Jay Alammar has a very good blog post illustrating how the Transformer works with animated GIFs. It helped me understand the model quite well.

Transformer Architecture

Later I found that explaining the Transformer mathematically is much easier and more concise, and it causes less confusion about the model for beginners.

Machine Translation

I will use the machine translation example from Jay Alammar's blog post. The input is the French sentence “je suis étudiant”, and the translation program translates it to “i am a student” in English. I am using this French-to-English example in particular because the number of input words differs from the number of output words, and I want to use that difference to help readers keep the encoder and decoder dimensions apart.

Mathematical Formulations

Variables

These are all the variables we will use in the Transformer model.


$n_\text{input}$: number of input words (embeddings).

$n_\text{output}$: number of output words (embeddings).

$N$: number of stacks in the encoder, which is also the number of stacks in the decoder.

$d_\text{model}$: length of word embedding vector.

$d_k$: length of query vector, length of key vector.

$d_v$: length of value vector.

$d_{ff}$: length of the inner layer of the feed forward networks.

$h$: number of attention heads in multi-head attention.

$n_\text{vol}$: vocabulary size.

$W_{\text{Embedding}} \in \mathbb{R}^{n_\text{vol} \times d_\text{model}}$: word embedding matrix.

$W_{i,j}^{Q,E} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for queries in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i,j}^{K,E} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for keys in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i,j}^{V,E} \in \mathbb{R}^{d_\text{model} \times d_v}$: linear transformation matrix for values in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i}^{O,E} \in \mathbb{R}^{h d_v \times d_\text{model}}$: linear transformation matrix for concatenated attention heads of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i,j}^{Q,D} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for queries in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,j}^{K,D} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for keys in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,j}^{V,D} \in \mathbb{R}^{d_\text{model} \times d_v}$: linear transformation matrix for values in attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i}^{O,D} \in \mathbb{R}^{h d_v \times d_\text{model}}$: linear transformation matrix for concatenated attention heads of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,j}^{Q,DM} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for queries in masked attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,j}^{K,DM} \in \mathbb{R}^{d_\text{model} \times d_k}$: linear transformation matrix for keys in masked attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,j}^{V,DM} \in \mathbb{R}^{d_\text{model} \times d_v}$: linear transformation matrix for values in masked attention head $j \in \{ 1, 2, \cdots, h \}$ of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i}^{O,DM} \in \mathbb{R}^{h d_v \times d_\text{model}}$: linear transformation matrix for concatenated masked attention heads of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,1}^{FF,E} \in \mathbb{R}^{d_\text{model} \times d_{ff}}$: linear transformation matrix to the inner layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$b_{i,1}^{FF,E} \in \mathbb{R}^{d_{ff}}$: bias term to the inner layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i,2}^{FF,E} \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$: linear transformation matrix to the output layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$b_{i,2}^{FF,E} \in \mathbb{R}^{d_{\text{model}}}$: bias term to the output layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the encoder.

$W_{i,1}^{FF,D} \in \mathbb{R}^{d_\text{model} \times d_{ff}}$: linear transformation matrix to the inner layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$b_{i,1}^{FF,D} \in \mathbb{R}^{d_{ff}}$: bias term to the inner layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W_{i,2}^{FF,D} \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$: linear transformation matrix to the output layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$b_{i,2}^{FF,D} \in \mathbb{R}^{d_{\text{model}}}$: bias term to the output layer in the feed forward network of stack $i \in \{ 1, 2, \cdots, N \}$ in the decoder.

$W^{S} \in \mathbb{R}^{d_{\text{model}} \times n_\text{vol}}$: linear transformation matrix to the final softmax layer after decoder.

$b^{S} \in \mathbb{R}^{n_\text{vol}}$: bias term to the final softmax layer after decoder.


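We also have the following constraint, which keeps the shapes compatible in the residual connections, where the output of multi-head attention is added to its input:

$$h d_v = d_\text{model}$$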

Functions

These are all the functions we will use in the Transformer model.


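To calculate the embeddings after attention, for queries $Q \in \mathbb{R}^{n_q \times d_k}$, keys $K \in \mathbb{R}^{n_k \times d_k}$, and values $V \in \mathbb{R}^{n_k \times d_v}$,

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where the softmax is applied to each row independently.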
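Here is a minimal NumPy sketch of this attention function. It assumes each matrix stores one word per row. The names `softmax` and `attention` are chosen for illustration and are not a library API. The example deliberately uses 4 queries and 3 keys to show that the output has one row per query.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp()."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))  # (n_q, n_k), each row sums to 1
    return weights @ V                         # (n_q, d_v)


rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))  # 4 queries
K = rng.normal(size=(3, 8))  # 3 keys
V = rng.normal(size=(3, 5))  # 3 values of length d_v = 5
out = attention(Q, K, V)
print(out.shape)  # (4, 5)
```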

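For the multi-head attention of stack $i$ in the encoder,

$$\text{MultiHead}_{i}^{E}(Q, K, V) = \text{Concat}\left(\text{head}_{i,1}^{E}, \text{head}_{i,2}^{E}, \cdots, \text{head}_{i,h}^{E}\right) W_{i}^{O,E}$$

where

$$\text{head}_{i,j}^{E} = \text{Attention}\left(Q W_{i,j}^{Q,E}, K W_{i,j}^{K,E}, V W_{i,j}^{V,E}\right)$$

Here $\text{Concat}$ places the $h$ head outputs side by side, giving a matrix with $h d_v$ columns.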
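The following is a sketch of one multi-head attention layer built on scaled dot-product attention. It passes the per-head matrices as Python lists and sets $d_k = d_v = d_\text{model} / h$, which is a common choice. All names are illustrative. Passing the same `X` as queries, keys, and values gives self-attention, as in the encoder.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V


def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    """W_Q, W_K: h matrices of shape (d_model, d_k); W_V: h matrices of shape
    (d_model, d_v); W_O: (h * d_v, d_model)."""
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O  # (n_q, h * d_v) @ (h * d_v, d_model)


h, d_model = 2, 8
d_k = d_v = d_model // h  # so that h * d_v == d_model
rng = np.random.default_rng(0)
W_Q = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_K = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_V = [rng.normal(size=(d_model, d_v)) for _ in range(h)]
W_O = rng.normal(size=(h * d_v, d_model))

X = rng.normal(size=(3, d_model))  # 3 words
Z = multi_head(X, X, X, W_Q, W_K, W_V, W_O)
print(Z.shape)  # (3, 8)
```

With $h = 1$ and identity matrices for every projection, `multi_head` reduces to plain `attention`. This is a quick way to sanity-check the wiring.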

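Similarly, for the multi-head attention of stack $i$ in the decoder,

$$\text{MultiHead}_{i}^{D}(Q, K, V) = \text{Concat}\left(\text{head}_{i,1}^{D}, \text{head}_{i,2}^{D}, \cdots, \text{head}_{i,h}^{D}\right) W_{i}^{O,D}$$

where

$$\text{head}_{i,j}^{D} = \text{Attention}\left(Q W_{i,j}^{Q,D}, K W_{i,j}^{K,D}, V W_{i,j}^{V,D}\right)$$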

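Masked multi-head attention only exists in the decoder. For the masked multi-head attention of stack $i$ in the decoder during training,

$$\text{MaskedMultiHead}_{i}^{D}(Q, K, V) = \text{Concat}\left(\text{head}_{i,1}^{DM}, \text{head}_{i,2}^{DM}, \cdots, \text{head}_{i,h}^{DM}\right) W_{i}^{O,DM}$$

where

$$\text{head}_{i,j}^{DM} = \text{softmax}\left(\text{Mask}\left(\frac{Q W_{i,j}^{Q,DM} \left(K W_{i,j}^{K,DM}\right)^\top}{\sqrt{d_k}}\right)\right) V W_{i,j}^{V,DM}$$

$\text{Mask}(x)$ is a function which takes a square matrix $x$ as input. We have $x' = \text{Mask}(x)$ where

$$x'_{ij} = \begin{cases} x_{ij} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

so that after the softmax, position $i$ puts zero weight on every position $j > i$.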
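Below is a sketch of the mask and of a single masked attention head. The learned projections are omitted to keep the example short. Setting the scores above the diagonal to $-\infty$ makes their softmax weights exactly 0, so the first word can only attend to itself.

```python
import numpy as np


def mask(x):
    """Keep x[i, j] for j <= i and replace entries above the diagonal with -inf."""
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where j > i
    return np.where(future, -np.inf, x)


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # the diagonal is always finite
    e = np.exp(x)                          # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)


def masked_attention(Q, K, V):
    return softmax(mask(Q @ K.T / np.sqrt(Q.shape[-1]))) @ V


rng = np.random.default_rng(0)
Y = rng.normal(size=(4, 6))  # 4 right-shifted output words, each of length 6
weights = softmax(mask(Y @ Y.T / np.sqrt(6)))
print(np.round(weights, 2))  # entries above the diagonal are exactly 0
```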

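For the masked multi-head attention of stack $i$ in the decoder during inference, the mask function is removed. That is to say, the $\text{MaskedMultiHead}$ layer is the same as the $\text{MultiHead}$ layer during inference:

$$\text{MaskedMultiHead}_{i}^{D}(Q, K, V) = \text{Concat}\left(\text{head}_{i,1}^{DM}, \text{head}_{i,2}^{DM}, \cdots, \text{head}_{i,h}^{DM}\right) W_{i}^{O,DM}$$

where

$$\text{head}_{i,j}^{DM} = \text{Attention}\left(Q W_{i,j}^{Q,DM}, K W_{i,j}^{K,DM}, V W_{i,j}^{V,DM}\right)$$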

$\text{LayerNorm}$ is a normalization layer which does not change the shape of tensors. Please find the details of layer normalization in one of my blog posts.
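For readers who want something concrete, here is a rough sketch of standard layer normalization. It normalizes each row (one word) to zero mean and unit variance. It has an optional learned gain `gamma` and bias `beta`, and a small `eps` for numerical stability. These parameter names and the `eps` value are assumptions of this sketch. The shape is unchanged, and the last lines show the "residual connection, then LayerNorm" pattern used throughout the model.

```python
import numpy as np


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize each row of x to zero mean and unit variance; shape is unchanged."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps)
    if gamma is not None:
        y = y * gamma  # learned per-feature scale
    if beta is not None:
        y = y + beta   # learned per-feature shift
    return y


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(3, 8))
sublayer_out = rng.normal(size=(3, 8))
z = layer_norm(x + sublayer_out)  # residual connection followed by LayerNorm
print(z.shape)  # (3, 8)
```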


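For the feed forward neural network of stack $i$ in the encoder,

$$\text{FFN}_{i}^{E}(x) = \max\left(0, x W_{i,1}^{FF,E} + b_{i,1}^{FF,E}\right) W_{i,2}^{FF,E} + b_{i,2}^{FF,E}$$

where $\max$ is applied elementwise and the bias vectors are added to every row.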

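Similarly, for the feed forward neural network of stack $i$ in the decoder,

$$\text{FFN}_{i}^{D}(x) = \max\left(0, x W_{i,1}^{FF,D} + b_{i,1}^{FF,D}\right) W_{i,2}^{FF,D} + b_{i,2}^{FF,D}$$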
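Here is a sketch of the position-wise feed forward network with random weights. The sizes are arbitrary illustrative values. The same weights and biases are applied to each row, so feeding one word alone gives the same result as feeding it together with the others.

```python
import numpy as np


def ffn(x, W1, b1, W2, b2):
    """max(0, x W1 + b1) W2 + b2, applied to every row (word) independently."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2


d_model, d_ff = 8, 32
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(d_model, d_ff)), rng.normal(size=d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), rng.normal(size=d_model)

x = rng.normal(size=(3, d_model))  # 3 words
out = ffn(x, W1, b1, W2, b2)
print(out.shape)  # (3, 8)
```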

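For positional encoding, $\text{PE}$ takes a two-dimensional matrix $x \in \mathbb{R}^{n \times d_\text{model}}$ as input, and we have $x' = \text{PE}(x)$ with $x' \in \mathbb{R}^{n \times d_\text{model}}$, where

$$x'_{p,2k} = \sin\left(\frac{p}{10000^{2k/d_\text{model}}}\right), \quad x'_{p,2k+1} = \cos\left(\frac{p}{10000^{2k/d_\text{model}}}\right)$$

for positions $p \in \{0, 1, \cdots, n-1\}$ and $k \in \{0, 1, \cdots, d_\text{model}/2 - 1\}$. Note that $\text{PE}(x)$ depends only on the shape of $x$.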
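Here is a NumPy sketch of the sinusoidal positional encoding, with positions indexed from 0 and an even $d_\text{model}$ assumed. The function name is illustrative. Only the shape of the input matters, and the encoding is then added to the embeddings.

```python
import numpy as np


def positional_encoding(n, d_model):
    """PE[p, 2k] = sin(p / 10000^(2k / d_model)), PE[p, 2k+1] = cos(same angle)."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    p = np.arange(n)[:, None]                  # (n, 1) positions 0..n-1
    two_k = np.arange(0, d_model, 2)[None, :]  # (1, d_model / 2) even column indices
    angles = p / np.power(10000.0, two_k / d_model)
    pe = np.zeros((n, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


X = np.zeros((3, 4))  # 3 words, d_model = 4; only the shape matters
pe = positional_encoding(*X.shape)
X_prime = X + pe
print(np.round(pe, 3))
```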

Training Phase

Let us first look at the encoder.


We have an input consisting of $n_\text{input}$ words. The input matrix $X_{\text{input}} \in \mathbb{R}^{n_\text{input} \times n_\text{vol}}$ is $X_{\text{input}} = \{x_1, x_2, \cdots, x_{n_\text{input}}\}$, where each $x_k$ is a one-hot encoded row vector of length $n_\text{vol}$ for the corresponding word.


Then we have the input embedding matrix $X \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$:

$$X = X_{\text{input}} W_{\text{Embedding}}$$
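A small sketch of the embedding step follows. It uses a toy vocabulary and hypothetical word ids, since the real ids depend on the tokenizer. Multiplying the one-hot matrix by the embedding matrix simply selects rows of the embedding matrix.

```python
import numpy as np

n_vol, d_model = 10, 4
rng = np.random.default_rng(0)
W_embedding = rng.normal(size=(n_vol, d_model))

token_ids = np.array([7, 2, 5])     # hypothetical vocabulary ids for "je suis étudiant"
X_input = np.eye(n_vol)[token_ids]  # (n_input, n_vol), one one-hot row per word
X = X_input @ W_embedding           # (n_input, d_model)
print(X.shape)  # (3, 4)
```

In practice this product is usually computed as the row lookup `W_embedding[token_ids]`. The lookup gives the same result without building the one-hot matrix.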

We further add positional encoding to the embedding matrix, resulting in the matrix $X' \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$:

$$X' = X + \text{PE}(X)$$

Here we start to apply self-attention! $Z_{1,1}^{E} \in \mathbb{R}^{n_{\text{input}} \times h d_v}$

$$Z_{1,1}^{E} = \text{MultiHead}_{1}^{E}(X', X', X')$$

We use a residual connection and apply layer normalization, $Z_{1,2}^{E} \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$

$$Z_{1,2}^{E} = \text{LayerNorm}\left(X' + Z_{1,1}^{E}\right)$$

Then we apply the feed forward neural network, $Z_{1,3}^{E} \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$

$$Z_{1,3}^{E} = \text{FFN}_{1}^{E}\left(Z_{1,2}^{E}\right)$$

We use a residual connection and apply layer normalization, $Z_{1,4}^{E} \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$

$$Z_{1,4}^{E} = \text{LayerNorm}\left(Z_{1,2}^{E} + Z_{1,3}^{E}\right)$$

$Z_{1,4}^{E}$ will be used as the input to the next stack (stack 2) in the encoder, $Z_{2,1}^{E} \in \mathbb{R}^{n_\text{input} \times h d_v}$

$$Z_{2,1}^{E} = \text{MultiHead}_{2}^{E}\left(Z_{1,4}^{E}, Z_{1,4}^{E}, Z_{1,4}^{E}\right)$$

We iterate this process $N$ times and get $Z_{N,4}^{E} \in \mathbb{R}^{n_\text{input} \times d_\text{model}}$ from stack $N$ in the encoder. These encoded embeddings $Z_{N,4}^{E}$ will be used in each of the stacks in the decoder.
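Here is a compact sketch of the whole encoder. For each of the $N$ stacks, it chains self-attention, a residual connection with LayerNorm, the feed forward network, and another residual connection with LayerNorm. Several parts are simplifying assumptions of this sketch: the random initialization, the parameter dictionary layout, the hypothetical `init_stack` helper, and a LayerNorm without learned gain and bias. Positional encoding is assumed to have been added already.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head(Q, K, V, p):
    heads = [softmax((Q @ wq) @ (K @ wk).T / np.sqrt(wq.shape[1])) @ (V @ wv)
             for wq, wk, wv in zip(p["W_Q"], p["W_K"], p["W_V"])]
    return np.concatenate(heads, axis=-1) @ p["W_O"]


def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def ffn(x, p):
    return np.maximum(0.0, x @ p["W1"] + p["b1"]) @ p["W2"] + p["b2"]


def init_stack(rng, d_model, h, d_ff):
    """Random parameters for one encoder stack (hypothetical initialization)."""
    d_k = d_v = d_model // h  # h * d_v == d_model

    def r(*shape):
        return rng.normal(scale=0.1, size=shape)

    return {"W_Q": [r(d_model, d_k) for _ in range(h)],
            "W_K": [r(d_model, d_k) for _ in range(h)],
            "W_V": [r(d_model, d_v) for _ in range(h)],
            "W_O": r(h * d_v, d_model),
            "W1": r(d_model, d_ff), "b1": r(d_ff),
            "W2": r(d_ff, d_model), "b2": r(d_model)}


def encoder(X_prime, stacks):
    Z = X_prime
    for p in stacks:                 # stack i = 1, ..., N
        Z1 = multi_head(Z, Z, Z, p)  # Z_{i,1}
        Z2 = layer_norm(Z + Z1)      # Z_{i,2}
        Z3 = ffn(Z2, p)              # Z_{i,3}
        Z = layer_norm(Z2 + Z3)      # Z_{i,4}, the input to stack i + 1
    return Z                         # Z_{N,4}


rng = np.random.default_rng(0)
N, d_model, h, d_ff = 2, 8, 2, 32
stacks = [init_stack(rng, d_model, h, d_ff) for _ in range(N)]
X_prime = rng.normal(size=(3, d_model))  # stands in for X + PE(X)
Z_enc = encoder(X_prime, stacks)
print(Z_enc.shape)  # (3, 8)
```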


We then look at the decoder.


We have the translation output, which is the translation of the input, consisting of $n_{\text{output}}$ words. The output matrix $Y_{\text{output}} \in \mathbb{R}^{n_{\text{output}} \times n_{\text{vol}}}$ is $Y_{\text{output}} = \{ y_1, y_2, \cdots, {y_{n_{\text{output}}}} \}$, where each $y_k$ is a one-hot encoded row vector of length $n_\text{vol}$ for the corresponding word.


The input to the decoder is the right-shifted ${Y_{\text{output}}}$, denoted ${Y_{\text{output}}^{\prime}}$. Concretely, ${Y_{\text{output}}^{\prime}} = \{ y_0, y_1, y_2, \cdots, {y_{n_{\text{output}}-1}} \}$, where $y_0$ could simply be a zero vector.
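Here is a short sketch of the right shift. It uses a zero row for $y_0$, as suggested in the text, and hypothetical word ids for the target sentence. The helper name `shift_right` is illustrative.

```python
import numpy as np


def shift_right(Y_output):
    """Prepend y_0 (a zero row) and drop the last row."""
    y0 = np.zeros((1, Y_output.shape[1]))
    return np.vstack([y0, Y_output[:-1]])


n_vol = 6
Y_output = np.eye(n_vol)[[1, 2, 3, 4]]  # hypothetical ids for "i am a student"
Y_output_shifted = shift_right(Y_output)
print(Y_output_shifted)
```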


Given ${Y_{\text{output}}^{\prime}}$, we want the decoder to generate ${Y_{\text{output}}}$.


Similar to the encoder, we have the output embedding matrix $Y \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$:

$$Y = Y_{\text{output}}^{\prime} W_{\text{Embedding}}$$

We further add positional encoding to the embedding matrix, resulting in the matrix $Y' \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$:

$$Y' = Y + \text{PE}(Y)$$

Then we start to apply self-attention. However, to prevent each word from peeking at the words that follow it, we use masked multi-head attention instead of the ordinary multi-head attention used in the encoder. Here is the intuition. During training, the full right-shifted output ${Y_{\text{output}}^{\prime}}$ is fed to the decoder at once. Without the mask, after applying self-attention, word $y_i$ would attend to words $y_j$ with $j > i$, which is unwanted. With the mask, the attention of word $y_i$ to word $y_j$ for $j > i$ is 0. We have $Z_{1,1}^{D} \in \mathbb{R}^{n_\text{output} \times h d_v}$

$$Z_{1,1}^{D} = \text{MaskedMultiHead}_{1}^{D}(Y', Y', Y')$$

We use a residual connection and apply layer normalization, $Z_{1,2}^{D} \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$

$$Z_{1,2}^{D} = \text{LayerNorm}\left(Y' + Z_{1,1}^{D}\right)$$

Then we take the queries from $Z_{1,2}^{D}$ and the keys and values from $Z_{N,4}^{E}$. We then apply multi-head attention, so that each output word attends to the encoded embeddings of the input words. $Z_{1,3}^{D} \in \mathbb{R}^{n_{\text{output}} \times h d_v}$

$$Z_{1,3}^{D} = \text{MultiHead}_{1}^{D}\left(Z_{1,2}^{D}, Z_{N,4}^{E}, Z_{N,4}^{E}\right)$$
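The following sketch shows a single head of this encoder-decoder attention, with random stand-ins for the decoder and encoder matrices. It uses the 3-word input and 4-word output from the translation example. The attention weight matrix has one row per output word and one column per input word. No mask is needed here, because the output words may look at every input word.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


n_input, n_output, d_model, d_k = 3, 4, 8, 4
rng = np.random.default_rng(0)
Z_enc = rng.normal(size=(n_input, d_model))   # stands in for Z_{N,4}^E
Z_dec = rng.normal(size=(n_output, d_model))  # stands in for Z_{1,2}^D
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))

Q = Z_dec @ W_Q  # queries from the decoder: (n_output, d_k)
K = Z_enc @ W_K  # keys from the encoder:   (n_input, d_k)
V = Z_enc @ W_V  # values from the encoder: (n_input, d_k), here d_v = d_k
weights = softmax(Q @ K.T / np.sqrt(d_k))  # (n_output, n_input), no mask
head = weights @ V                         # (n_output, d_v)
print(weights.shape, head.shape)  # (4, 3) (4, 4)
```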

We use a residual connection and apply layer normalization, $Z_{1,4}^{D} \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$

$$Z_{1,4}^{D} = \text{LayerNorm}\left(Z_{1,2}^{D} + Z_{1,3}^{D}\right)$$

Then we apply the feed forward neural network, $Z_{1,5}^{D} \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$

$$Z_{1,5}^{D} = \text{FFN}_{1}^{D}\left(Z_{1,4}^{D}\right)$$

We use a residual connection and apply layer normalization, $Z_{1,6}^{D} \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$

$$Z_{1,6}^{D} = \text{LayerNorm}\left(Z_{1,4}^{D} + Z_{1,5}^{D}\right)$$

$Z_{1,6}^{D}$ will be used as the input to the next stack (stack 2) in the decoder, $Z_{2,1}^{D} \in \mathbb{R}^{n_\text{output} \times h d_v}$

$$Z_{2,1}^{D} = \text{MaskedMultiHead}_{2}^{D}\left(Z_{1,6}^{D}, Z_{1,6}^{D}, Z_{1,6}^{D}\right)$$

We iterate this process $N$ times and get $Z_{N,6}^{D} \in \mathbb{R}^{n_\text{output} \times d_\text{model}}$ from stack $N$ in the decoder. These embeddings $Z_{N,6}^{D}$ are passed through a linear transformation followed by softmax to generate the output probabilities for the decoded word sequence. Concretely, we have $Z \in \mathbb{R}^{n_\text{output} \times n_\text{vol}}$ where

$$Z = Z_{N,6}^{D} W^{S} + b^{S}$$

We then apply softmax to get the probability matrix $P \in \mathbb{R}^{n_\text{output} \times n_\text{vol}}$ where

$$P = \text{softmax}(Z)$$

with the softmax applied to each row of $Z$.

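Finally, we can calculate the cross entropy loss using $Y_{\text{output}}$ and $P$. More specifically, the $i$th rows of $Y_{\text{output}}$ and $P$ are $y_i$ and $p_i$, respectively, and

$$L = -\sum_{i=1}^{n_\text{output}} y_i \log\left(p_i\right)^\top$$

where $\log$ is applied elementwise.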
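Here is a sketch of the output layer and loss. It uses random weights, a random stand-in for the final decoder output, and hypothetical target ids. The loss is summed over the output positions, as written here; many implementations average over positions instead. Because each $y_i$ is one-hot, the loss is the negative log probability of the correct word at each position, summed.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def cross_entropy(Y_output, P):
    """-sum_i y_i log(p_i)^T, summed over the n_output positions."""
    return -np.sum(Y_output * np.log(P))


n_output, d_model, n_vol = 4, 8, 6
rng = np.random.default_rng(0)
Z_dec = rng.normal(size=(n_output, d_model))  # stands in for Z_{N,6}^D
W_S = rng.normal(size=(d_model, n_vol))
b_S = np.zeros(n_vol)

P = softmax(Z_dec @ W_S + b_S)          # (n_output, n_vol), rows sum to 1
Y_output = np.eye(n_vol)[[1, 2, 3, 4]]  # hypothetical ids for "i am a student"
loss = cross_entropy(Y_output, P)
print(loss)
```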

Inference Phase

To predict the $k$th word in the translation, we feed the $k-1$ previously predicted words (preceded by $y_0$) to the decoder and get new predictions for the first $k$ words. We iterate this step until a stop sign word is predicted. Please do check the concrete example below.
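Here is a sketch of this inference loop as described in this post. At each step the start symbol plus all the latest predictions are fed back in, and the loop stops once the newest prediction is the stop word. `decoder` is a hypothetical stand-in for the trained model. It takes the $k$ decoder inputs and returns $k$ predictions. The `<start>` and `<stop>` symbols and `max_len` are illustrative choices.

```python
from typing import Callable, List


def iterative_decode(decoder: Callable[[List[str]], List[str]],
                     start: str, stop: str, max_len: int = 50) -> List[str]:
    """Re-feed the start symbol plus all current predictions until `stop` appears.

    `decoder` is hypothetical: given the k inputs [y_0, ..., y_{k-1}] it returns
    k predictions, one for each of the first k words.
    """
    predictions: List[str] = []
    for _ in range(max_len):
        predictions = decoder([start] + predictions)  # new predictions for the first k words
        if predictions[-1] == stop:
            break
    return predictions


# Toy decoder that always "knows" the target sentence, for illustration only.
target = ["i", "am", "a", "student", "<stop>"]


def toy_decoder(inputs: List[str]) -> List[str]:
    return target[:len(inputs)]


print(iterative_decode(toy_decoder, "<start>", "<stop>"))
# ['i', 'am', 'a', 'student', '<stop>']
```

The `max_len` cap keeps the loop from running forever if the model never predicts the stop word.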

Caveats

Although Jay Alammar has done a good job animating all the components of the Transformer, his interpretation of the decoder and of inference is incorrect. In his blog post, the right-shifted words are fed into the decoder one at a time, sequentially. This is incorrect in both training and inference. In the training phase, we feed all the shifted words to the decoder together and use the mask to prevent peeking at future words. In the inference phase, all the words that have already been output are fed into the decoder to predict the next word. If the words were fed one by one, the self-attention in the decoder would not make sense at all.

Concrete Example

In the training phase, we have the following translation to train on: “je suis étudiant” -> “i am a student”. We first feed “je suis étudiant” all together into the encoder and get its keys and values. Then we feed “$\emptyset$ i am a” all together into the decoder and get the final probabilities to calculate the cross entropy loss.


In the test phase, say we again have the French sentence “je suis étudiant” to translate. We first feed “je suis étudiant” all together into the encoder and get its keys and values. To get the first translation word, we feed “$\emptyset$” into the decoder and get the first word prediction, say “he” (note that this translation is “incorrect”). To get the second translation word, we feed “$\emptyset$ he” into the decoder and get a new prediction for the first word, say “i” (note that this time the translation is “correct”), and the prediction for the second word, say “am”. To get the third translation word, we feed “$\emptyset$ i am” into the decoder and get a new prediction for the first word, say “i”, a new prediction for the second word, say “am”, and the prediction for the third word, say “a”. We iterate until a stop sign word is predicted. A complete output could be “i am a good student $\times$”, where $\times$ denotes the stop sign.

Final Remarks

Actually, with this purely mathematical recipe, you don’t even have to know what a Transformer is in order to make it work. I estimate that all of these formulas could fit on one single page of an ordinary paper.

Acknowledgement

Thanks to Yang Chen for the active discussion on the inference part of this topic.

References