Deformable Attention

Introduction

Transformer multi-head attention is a mechanism that allows a neural network to focus on a small set of features from a large set of features. However, the vanilla multi-head attention mechanism is usually very expensive to compute and slows down training convergence, especially for computer vision models. Deformable multi-head attention was developed to reduce the computational complexity of the attention mechanism for computer vision models.

In this blog post, I would like to discuss the vanilla multi-head attention, the deformable multi-head attention, and the deformable multi-head attention v2 in detail.

Multi-Head Attention

Multi-Head Attention Formulation

Let $q \in \Omega_q$ index a query element with representation feature $z_q \in \mathbb{R}^{C \times 1}$, and $k \in \Omega_k$ index a key element with representation feature $x_k \in \mathbb{R}^{C \times 1}$, where $C$ is the feature dimension, and $\Omega_q$ and $\Omega_k$ are the query and key element sets, respectively. The multi-head attention feature is computed as follows:

$$\text{MultiHeadAttention}(z_q, X) = \sum_{m=1}^{M} W_m \left[ \sum_{k \in \Omega_k} A_{m,q,k} W'_m x_k \right]$$

where $m$ indexes the attention head, $M$ is the number of attention heads, $W_m \in \mathbb{R}^{C \times C_v}$ and $W'_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, $C_v = \frac{C}{M}$ by default, and $A_{m,q,k} \in \mathbb{R}$ is the attention weight between the query element $q$ and the key element $k$ for the $m$-th attention head.

The attention weight $A_{m,q,k}$, which is normalized, is computed as follows using a softmax function:

$$A_{m,q,k} = \frac{\exp\left(\text{score}(z_q, x_k)\right)}{\sum_{k' \in \Omega_k} \exp\left(\text{score}(z_q, x_{k'})\right)} = \frac{\exp\left(\frac{(U_m z_q)^{\top} V_m x_k}{\sqrt{C_v}}\right)}{\sum_{k' \in \Omega_k} \exp\left(\frac{(U_m z_q)^{\top} V_m x_{k'}}{\sqrt{C_v}}\right)} = \frac{\exp\left(\frac{z_q^{\top} U_m^{\top} V_m x_k}{\sqrt{C_v}}\right)}{\sum_{k' \in \Omega_k} \exp\left(\frac{z_q^{\top} U_m^{\top} V_m x_{k'}}{\sqrt{C_v}}\right)}$$

where $U_m \in \mathbb{R}^{C_v \times C}$ and $V_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, and $\text{score}(z_q, x_k) = \frac{z_q^{\top} U_m^{\top} V_m x_k}{\sqrt{C_v}}$ is the attention score between the query element $q$ and the key element $k$ for the $m$-th attention head.

$\sum_{k \in \Omega_k} A_{m,q,k} W'_m x_k$ for all the queries can be computed using vectorization as follows:

$$A_m X W_m'^{\top} = \text{softmax}\left(\frac{Z U_m^{\top} (X V_m^{\top})^{\top}}{\sqrt{C_v}}\right) X W_m'^{\top} = \text{softmax}\left(\frac{Z U_m^{\top} V_m X^{\top}}{\sqrt{C_v}}\right) X W_m'^{\top}$$

where $Z \in \mathbb{R}^{N_q \times C}$ and $X \in \mathbb{R}^{N_k \times C}$ are the query feature matrix and the key feature matrix, respectively, and $A_m \in \mathbb{R}^{N_q \times N_k}$ is the attention weight matrix for the $m$-th attention head.

The multi-head attention feature for all the queries is computed as follows:

$$\text{MultiHeadAttention}(Z, X) = \sum_{m=1}^{M} A_m X W_m'^{\top} W_m^{\top} = \sum_{m=1}^{M} \text{softmax}\left(\frac{Z U_m^{\top} V_m X^{\top}}{\sqrt{C_v}}\right) X W_m'^{\top} W_m^{\top}$$
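
To make the vectorized formulation concrete, below is a minimal PyTorch sketch of the multi-head attention, with $U_m$, $V_m$, $W'_m$, and $W_m$ for all heads stacked into single linear layers. The module and parameter names are illustrative assumptions, not from any particular library.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    # Minimal sketch of the vectorized multi-head attention above.
    # u and v project queries and keys to C_v = C / M dimensions per head,
    # w_prime projects the values, and w projects the concatenated heads back to C dimensions.
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # C_v
        self.u = nn.Linear(embed_dim, embed_dim, bias=False)        # U_m stacked over heads
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)        # V_m stacked over heads
        self.w_prime = nn.Linear(embed_dim, embed_dim, bias=False)  # W'_m stacked over heads
        self.w = nn.Linear(embed_dim, embed_dim, bias=False)        # W_m stacked over heads

    def forward(self, z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # z: (B, N_q, C) query features, x: (B, N_k, C) key features.
        b, n_q, c = z.shape
        n_k = x.shape[1]
        m, c_v = self.num_heads, self.head_dim
        q = self.u(z).view(b, n_q, m, c_v).transpose(1, 2)            # (B, M, N_q, C_v)
        k = self.v(x).view(b, n_k, m, c_v).transpose(1, 2)            # (B, M, N_k, C_v)
        val = self.w_prime(x).view(b, n_k, m, c_v).transpose(1, 2)    # (B, M, N_k, C_v)
        attn = torch.softmax(q @ k.transpose(-2, -1) / c_v ** 0.5, dim=-1)  # A_m: (B, M, N_q, N_k)
        out = attn @ val                                              # (B, M, N_q, C_v)
        out = out.transpose(1, 2).reshape(b, n_q, c)                  # concatenate heads
        return self.w(out)                                            # (B, N_q, C)
```

Applying the single output projection `w` to the concatenated heads is equivalent to summing $W_m$ times the per-head outputs, as in the formula above.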

Multi-Head Attention Computational Complexity

Suppose each multiply-accumulate (MAC) operation takes $O(1)$ time. The computational complexity of the multi-head attention for $N_q$ queries and $N_k$ keys can be derived as follows:

$$\begin{aligned}
& O\Bigg(M \bigg( \underbrace{O(N_k C C_v) + O(N_k C_v C)}_{X W_m'^{\top} W_m^{\top}} + \underbrace{O(N_k C C_v) + O(N_q C C_v) + O(N_q N_k C_v) + O(N_q N_k)}_{\text{softmax}\left(\frac{Z U_m^{\top} V_m X^{\top}}{\sqrt{C_v}}\right)} + O(N_q N_k C_v) \bigg)\Bigg) \\
&= O\Big(M \big( O(N_k C C_v) + O(N_q C C_v) + O(N_q N_k C_v) \big)\Big) \\
&= O(N_k C^2) + O(N_q C^2) + O(N_q N_k C) \\
&= O(N_k C^2 + N_q C^2 + N_q N_k C)
\end{aligned}$$

In modern Transformer-based large language models, for both the encoder and the decoder, $N_q$ and $N_k$ can be very large and much larger than $C$; therefore, the computational complexity of the multi-head attention is dominated by the third term, $O(N_q N_k C)$.

In Transformer-based computer vision models, the values of $N_q$, $N_k$, and $C$ can be quite different between the encoder and the decoder, which affects the asymptotic computational complexity of the multi-head attention.

For the encoder, because all the attention layers are self-attention, $N_q = N_k = HW$, where $H$ and $W$ are the height and width of the feature map, respectively. Because usually $HW \gg C$, the computational complexity of the multi-head attention is $O(N_q N_k C) = O(H^2 W^2 C)$.

For the decoder, the number of queries $N_q$ is much smaller than in the encoder. In the self-attention layers, $N_q = N_k = N$, which might be comparable to $C$, so the computational complexity of the multi-head attention is $O(NC^2 + N^2 C)$. In the cross-attention layers, $N_q = N$ and $N_k = HW$. Because usually $HW \gg C$ and $HW \gg N$, the computational complexity of the multi-head attention is $O(HWC^2 + NHWC)$.

It's not hard to see that the multi-head attention is quite expensive in the encoder self-attention layers and the decoder cross-attention layers. Deformable attention was devised to reduce its computational complexity in these layers for computer vision models.
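
To get a feel for these asymptotics, here is a rough back-of-the-envelope MAC count using hypothetical but typical values ($H = W = 100$, $C = 256$, $N = 300$ decoder queries); the numbers are only meant to illustrate the dominant terms derived above.

```python
# Rough MAC-count comparison using the dominant terms derived above.
# The values of H, W, C, and N are hypothetical but typical for DETR-like models.
H, W, C, N = 100, 100, 256, 300
N_k = H * W  # number of keys from the feature map

encoder_self_attn = N_k * C ** 2 + N_k * C ** 2 + N_k * N_k * C   # N_q = N_k = HW
decoder_self_attn = N * C ** 2 + N * C ** 2 + N * N * C           # N_q = N_k = N
decoder_cross_attn = N_k * C ** 2 + N * C ** 2 + N * N_k * C      # N_q = N, N_k = HW

print(f"encoder self-attention  ~ {encoder_self_attn:.2e} MACs")   # dominated by H^2 W^2 C
print(f"decoder self-attention  ~ {decoder_self_attn:.2e} MACs")
print(f"decoder cross-attention ~ {decoder_cross_attn:.2e} MACs")  # dominated by N H W C
```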

Deformable Multi-Head Attention

Deformable Multi-Head Attention Formulation

Inspired by the deformable convolution, in computer vision models the deformable multi-head attention attends, for each query, only to a small set of keys sampled around a spatial reference point.

Given an input feature map $X \in \mathbb{R}^{C \times H \times W}$, let $q \in \Omega_q$ index a query element with representation feature $z_q \in \mathbb{R}^{C}$, and let $p_q$ denote the reference point for $q$. The deformable multi-head attention feature is computed as follows:

$$\text{DeformableMultiHeadAttention}(z_q, p_q, X) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{m,q,k} W'_m x(p_q + \Delta p_{m,q,k}) \right]$$

where $m$ indexes the attention head, $M$ is the number of attention heads, $W_m \in \mathbb{R}^{C \times C_v}$ and $W'_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, $C_v = \frac{C}{M}$ by default, $k$ indexes the sampled key element, $K$ is the number of sampled key elements, $\Delta p_{m,q,k}$ is the offset for the sampled key element $k$ for the $m$-th attention head, and $A_{m,q,k} \in \mathbb{R}$ is the attention weight between the query element $q$ and the sampled key element $k$ for the $m$-th attention head.

Figure: Deformable Multi-Head Attention

To effectively reduce the computational complexity of the multi-head attention, the number of sampled key elements $K$ should be much smaller than the number of all the key elements $N_k$, i.e., $K \ll N_k$. Similar to the deformable convolution, $x(p_q + \Delta p_{m,q,k})$ can be bilinearly interpolated from the neighboring pixels on the input feature map.
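
Because the sampling locations $p_q + \Delta p_{m,q,k}$ are fractional, the sampled key features have to be bilinearly interpolated from the feature map. Below is a minimal sketch of such a sampling helper using `torch.nn.functional.grid_sample`, which expects sampling coordinates normalized to $[-1, 1]$; the tensor shapes and the function name are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def sample_features(x: torch.Tensor, locations: torch.Tensor) -> torch.Tensor:
    """Bilinearly interpolate features at fractional sampling locations.

    x: (B, C, H, W) input feature map.
    locations: (B, N_q, K, 2) sampling locations in pixel coordinates (x, y).
    Returns: (B, N_q, K, C) sampled key features.
    """
    _, _, h, w = x.shape
    # Normalize pixel coordinates to [-1, 1] as required by grid_sample.
    grid_x = 2.0 * locations[..., 0] / max(w - 1, 1) - 1.0
    grid_y = 2.0 * locations[..., 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=-1)   # (B, N_q, K, 2)
    # grid_sample expects a grid of shape (B, H_out, W_out, 2); treat (N_q, K) as (H_out, W_out).
    sampled = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
    # sampled: (B, C, N_q, K) -> (B, N_q, K, C)
    return sampled.permute(0, 2, 3, 1)
```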

Although the offset $\Delta p_{m,q,k}$ is also learned, just like the one in the deformable convolution, it is usually predicted by a linear layer instead of a convolution. This might seem inferior because the offset prediction now depends only on the query element $q$ for each attention head $m$, and its neighboring pixels are not considered. The offset is computed as follows:

$$\Delta p_{m,q,k} = W''_{m,k} z_q$$

where $W''_{m,k} \in \mathbb{R}^{2 \times C}$ is the learnable weight matrix for the $k$-th sampled key element for the $m$-th attention head.

The attention weight $A_{m,q,k}$ is computed quite differently from the one in the multi-head attention. Instead of computing an attention score between the query element $q$ and the key element $k$ via a dot product, the attention weight depends only on the query element $q$ for each attention head $m$. The attention weight $A_{m,q,k}$ is computed as follows using a softmax function:

$$A_{m,q,k} = \frac{\exp\left(W'''_{m,k} z_q\right)}{\sum_{k'=1}^{K} \exp\left(W'''_{m,k'} z_q\right)}$$

where $W'''_{m,k} \in \mathbb{R}^{1 \times C}$ is the learnable weight vector for the $k$-th sampled key element for the $m$-th attention head.

The deformable multi-head attention feature for all the queries is computed as follows:

$$\text{DeformableMultiHeadAttention}(Z, P, X) = \sum_{m=1}^{M} A_m X(P + \Delta P_m) W_m'^{\top} W_m^{\top} = \sum_{m=1}^{M} \text{softmax}\left(Z W_m'''^{\top}\right) X\left(P + Z W_m''\right) W_m'^{\top} W_m^{\top}$$

where $Z \in \mathbb{R}^{N_q \times C}$ and $P \in \mathbb{R}^{N_q \times 2}$ are the query feature matrix and the reference point matrix, respectively, $\Delta P_m \in \mathbb{R}^{N_q \times K \times 2}$ is the offset matrix for the $m$-th attention head, $A_m \in \mathbb{R}^{N_q \times K}$ is the attention weight matrix for the $m$-th attention head, and $W''_m \in \mathbb{R}^{K \times 2 \times C}$ and $W'''_m \in \mathbb{R}^{K \times C}$ are the learnable weight matrices for the $m$-th attention head. $Z W_m'''^{\top} \in \mathbb{R}^{N_q \times K}$ is the attention score matrix for the $m$-th attention head, $X(P + Z W_m'') \in \mathbb{R}^{N_q \times K \times C}$ is the sampled key feature matrix for the $m$-th attention head, and the attention-weighted aggregation $\text{softmax}(Z W_m'''^{\top}) X(P + Z W_m'') \in \mathbb{R}^{N_q \times C}$ contracts over the $K$ sampled keys.
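
Putting the pieces together, a single-scale deformable multi-head attention forward pass could be sketched as follows. It reuses the `sample_features` helper sketched above, and all module and parameter names are illustrative; this is a simplified sketch rather than the reference Deformable DETR implementation.

```python
import torch
import torch.nn as nn


class DeformableAttention(nn.Module):
    # Minimal sketch of single-scale deformable multi-head attention. Offsets and
    # attention weights are both predicted from the query features by linear layers.
    def __init__(self, embed_dim: int, num_heads: int, num_points: int) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0
        self.m, self.k, self.c_v = num_heads, num_points, embed_dim // num_heads
        self.offset_proj = nn.Linear(embed_dim, num_heads * num_points * 2)  # W''_{m,k}
        self.weight_proj = nn.Linear(embed_dim, num_heads * num_points)      # W'''_{m,k}
        self.value_proj = nn.Linear(embed_dim, embed_dim)                    # W'_m stacked over heads
        self.output_proj = nn.Linear(embed_dim, embed_dim)                   # W_m stacked over heads

    def forward(self, z: torch.Tensor, p: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # z: (B, N_q, C) query features, p: (B, N_q, 2) reference points in pixel
        # coordinates, x: (B, C, H, W) input feature map.
        b, n_q, c = z.shape
        _, _, h, w = x.shape
        # Transform the whole feature map first (the cheaper branch of the min above
        # when N_q * K > H * W), then split the channels into heads.
        values = self.value_proj(x.flatten(2).transpose(1, 2))                  # (B, H*W, C)
        values = values.transpose(1, 2).reshape(b * self.m, self.c_v, h, w)     # (B*M, C_v, H, W)
        # Predict per-head, per-point sampling offsets and attention weights from the queries.
        offsets = self.offset_proj(z).view(b, n_q, self.m, self.k, 2)           # Delta p_{m,q,k}
        weights = torch.softmax(self.weight_proj(z).view(b, n_q, self.m, self.k), dim=-1)  # A_{m,q,k}
        # Sample the transformed values at the deformed locations for each head.
        locations = p.view(b, n_q, 1, 1, 2) + offsets                           # (B, N_q, M, K, 2)
        locations = locations.permute(0, 2, 1, 3, 4).reshape(b * self.m, n_q, self.k, 2)
        sampled = sample_features(values, locations)                            # (B*M, N_q, K, C_v)
        # Aggregate the sampled values with the attention weights and merge the heads.
        weights = weights.permute(0, 2, 1, 3).reshape(b * self.m, n_q, self.k, 1)
        out = (weights * sampled).sum(dim=2)                                    # (B*M, N_q, C_v)
        out = out.reshape(b, self.m, n_q, self.c_v).permute(0, 2, 1, 3).reshape(b, n_q, c)
        return self.output_proj(out)                                            # (B, N_q, C)
```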

Deformable Multi-Head Attention Computational Complexity

The computational complexity of the deformable multi-head attention for $N_q$ queries and $K$ sampled keys per query can be derived as follows:

$$\begin{aligned}
& O\Bigg(M \bigg( \underbrace{O\big(\min(N_q K C + N_q K C + N_q K C_v C,\ N_q K C + N_k C_v C + N_q K C_v)\big)}_{X(P + Z W_m'') W_m'^{\top}} + \underbrace{O(N_q K C + N_q K)}_{\text{softmax}(Z W_m'''^{\top})} + O(N_q K C_v) + O(N_q C C_v) \bigg)\Bigg) \\
&= O\big(N_q K C M + N_q K C + N_q C^2 + \min(N_q K C^2, N_k C^2)\big) \\
&= O\big(N_q C^2 + \min(N_q K C^2, N_k C^2)\big)
\end{aligned}$$

where $N_k$ is the number of all the key elements and $N_k = HW$. We have a $\min$ operation in the first term because we could either do feature sampling followed by feature transformation, or feature transformation followed by feature sampling, whichever is cheaper depending on the values of $N_q$, $N_k = HW$, and $K$.

For the Transformer-based computer vision model encoder, because all the attention layers are self-attention, $N_q = N_k = HW$, where $H$ and $W$ are the height and width of the feature map, respectively. The computational complexity of the deformable multi-head attention is $O(N_q C^2) = O(HWC^2)$. Compared to the $O(H^2 W^2 C)$ computational complexity of the multi-head attention, the deformable multi-head attention is much cheaper in most cases.

For the decoder, the computational complexity of the deformable multi-head attention is $O(NKC^2)$ for both the self-attention layers and the cross-attention layers. Compared to the $O(NC^2 + N^2 C)$ computational complexity of the multi-head attention in the self-attention layers, the computational complexity of the deformable multi-head attention is comparable. Compared to the $O(HWC^2 + NHWC)$ computational complexity of the multi-head attention in the cross-attention layers, the deformable multi-head attention is much cheaper in almost all cases.
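
Using the same hypothetical numbers as before ($H = W = 100$, $C = 256$, $N = 300$) and, say, $K = 4$ sampled keys per query, the decoder cross-attention comparison looks roughly like this:

```python
# Rough dominant-term comparison for decoder cross-attention,
# using the same hypothetical H, W, C, N as before and K = 4.
H, W, C, N, K = 100, 100, 256, 300, 4
vanilla = H * W * C ** 2 + N * H * W * C        # O(HWC^2 + NHWC)
deformable = N * K * C ** 2                     # O(NKC^2)
print(f"vanilla cross-attention    ~ {vanilla:.2e} MACs")
print(f"deformable cross-attention ~ {deformable:.2e} MACs")
```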

Deformable Multi-Head Attention vs. Deformable Convolution

The deformable multi-head attention degenerates to an $R \times S$ deformable convolution, where $R$ is the convolution kernel height and $S$ is the convolution kernel width, when $M = RS$, $K = 1$, and $W'_m = I$, where $I$ is an identity matrix.

Multi-Scale Deformable Multi-Head Attention

Multi-Scale Deformable Multi-Head Attention Formulation

The deformable multi-head attention can be naturally extended to multi-scale deformable multi-head attention to support the multi-scale feature maps produced by neural network modules such as a feature pyramid network (FPN). It allows the attention to be computed on multiple feature maps with different spatial resolutions.

Figure: Multi-Scale Deformable Multi-Head Attention in Deformable DETR

Given $L$ feature maps $\{X^l\}_{l=1}^{L}$ with different spatial resolutions, where the $l$-th feature map $X^l$ has spatial dimensions $H_l \times W_l$, let $q \in \Omega_q$ index a query element with representation feature $z_q \in \mathbb{R}^{C}$, and let $\hat{p}_q \in [0, 1]^2$ denote the normalized reference point for $q$. The multi-scale deformable multi-head attention feature is computed as follows:

$$\text{MultiScaleDeformableMultiHeadAttention}(z_q, \hat{p}_q, \{X^l\}_{l=1}^{L}) = \sum_{m=1}^{M} W_m \left[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{m,q,l,k} W'_m x^l(\hat{p}_q + \Delta p_{m,q,l,k}) \right]$$

where $m$ indexes the attention head, $M$ is the number of attention heads, $W_m \in \mathbb{R}^{C \times C_v}$ and $W'_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, $C_v = \frac{C}{M}$ by default, $l$ indexes the feature map, $L$ is the number of feature maps, $k$ indexes the sampled key element, $K$ is the number of sampled key elements per feature map, $\Delta p_{m,q,l,k}$ is the offset for the sampled key element $k$ on the $l$-th feature map for the $m$-th attention head, and $A_{m,q,l,k} \in \mathbb{R}$ is the attention weight between the query element $q$ and the sampled key element $k$ on the $l$-th feature map for the $m$-th attention head, normalized such that $\sum_{l=1}^{L} \sum_{k=1}^{K} A_{m,q,l,k} = 1$. The normalized reference point $\hat{p}_q$ is rescaled to the resolution of each feature map before the offsets are applied.
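
A sketch of how the multi-scale sampling could be organized is shown below. It reuses the `sample_features` helper from earlier and rescales the normalized reference point to each feature map's resolution before applying the per-level offsets; the shapes and names are illustrative assumptions.

```python
import torch


def multi_scale_sample(feature_maps, p_hat, offsets):
    """Sample per-level key features for multi-scale deformable attention.

    feature_maps: list of L tensors, each (B, C, H_l, W_l).
    p_hat: (B, N_q, 2) normalized reference points in [0, 1].
    offsets: (B, N_q, L, K, 2) per-level sampling offsets in pixel coordinates.
    Returns: (B, N_q, L, K, C) sampled key features.
    """
    sampled_per_level = []
    for l, x_l in enumerate(feature_maps):
        _, _, h_l, w_l = x_l.shape
        # Rescale the normalized reference point to this level's resolution.
        scale = torch.tensor([w_l - 1, h_l - 1], dtype=p_hat.dtype, device=p_hat.device)
        locations = p_hat.unsqueeze(2) * scale + offsets[:, :, l]     # (B, N_q, K, 2)
        sampled_per_level.append(sample_features(x_l, locations))     # (B, N_q, K, C)
    return torch.stack(sampled_per_level, dim=2)                      # (B, N_q, L, K, C)
```

The attention weights are then normalized jointly over the $L \times K$ sampled keys, as required by the formula above.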

Deformable Multi-Head Attention V2

Deformable Multi-Head Attention Drawbacks

Taking a closer look at the deformable multi-head attention described above, we can see that the attention weight $A_{m,q,k}$ is computed quite differently from the one in the vanilla multi-head attention. Instead of computing an attention score between the query element $q$ and the key element $k$ via a dot product, the attention weight depends only on the query element $q$ for each attention head $m$, which is a little bit strange in the context of Transformer attention. In fact, the deformable multi-head attention is closer to convolution than to attention.

The deformable multi-head attention also has the drawback of higher memory consumption than the vanilla multi-head attention, especially for the self-attention in the Transformer encoders. In the vanilla multi-head attention, the largest memory consumption comes from the feature map tensor $X \in \mathbb{R}^{H \times W \times C}$, which is usually very large. In the deformable multi-head attention, because of the newly sampled key elements, the dominant memory consumption becomes the intermediate sampled key tensor of size $N_q \times K \times C$. In the Transformer encoders for computer vision models, $N_q = N_k = HW$ is already very large, so a tensor of $N_q K C = HWKC$ elements is even more expensive than the one in the vanilla multi-head attention.

When the deformable multi-head attention was first proposed for Transformer-based computer vision models, the Vision Transformer (ViT) had not been developed yet. Therefore, in Deformable DETR, the deformable multi-head attention was only used in the task head, and the feature extraction backbone was still a convolutional neural network (CNN) that produced a feature map with reasonably sized $H$ and $W$, so the deformable multi-head attention would not be too expensive. In addition, because the deformable multi-head attention was only used in the task head, using a small $K$ would not hurt the performance much while reducing the memory consumption.

The ViT uses the multi-head attention in the Transformer encoders for feature extraction. In this case, because $H$ and $W$ become much larger, the deformable multi-head attention would be too expensive to use. In addition, using a small $K$ would hurt feature extraction in the backbone. Therefore, using the deformable multi-head attention described above in the Transformer encoders for feature extraction is not a good idea. The deformable multi-head attention v2 was developed to address these issues, specifically for the ViT.

Deformable Multi-Head Attention V2 Formulation

Built on top of the deformable multi-head attention, the key idea of the deformable multi-head attention v2 is to use a set of globally shifted keys shared among all the queries for each attention head. This design extends more naturally from the vanilla multi-head attention, and it also reduces the memory consumption.

Figure: Deformable Multi-Head Attention V2

In the deformable multi-head attention v2, given an input feature map $X \in \mathbb{R}^{C \times H \times W}$, let $q \in \Omega_q$ index a query element with representation feature $z_q \in \mathbb{R}^{C}$, and let $p_q$ denote the reference point for $q$. The deformable multi-head attention v2 feature is computed as follows:

$$\text{DeformableMultiHeadAttentionV2}(z_q, p_q, X) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{m,q,k} W'_m x(p_q + \Delta p_{m,k}) \right]$$

where $m$ indexes the attention head, $M$ is the number of attention heads, $W_m \in \mathbb{R}^{C \times C_v}$ and $W'_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, $C_v = \frac{C}{M}$ by default, $k$ indexes the sampled key element, $K$ is the number of sampled key elements, $\Delta p_{m,k}$ is the offset for the sampled key element $k$ for the $m$-th attention head, and $A_{m,q,k} \in \mathbb{R}$ is the attention weight between the query element $q$ and the sampled key element $k$ for the $m$-th attention head. Notice that unlike the offset $\Delta p_{m,q,k}$ used in the deformable multi-head attention, which is learned for each query element $q$, the offset $\Delta p_{m,k}$ used in the deformable multi-head attention v2 is shared among all the query elements.

The attention weight $A_{m,q,k}$, which is normalized, is computed almost the same way as the one in the vanilla multi-head attention:

$$A_{m,q,k} = \frac{\exp\left(\text{score}(z_q, x_k) + b_{m,k}\right)}{\sum_{k'=1}^{K} \exp\left(\text{score}(z_q, x_{k'}) + b_{m,k'}\right)} = \frac{\exp\left(\frac{(U_m z_q)^{\top} V_m x_k}{\sqrt{C_v}} + b_{m,k}\right)}{\sum_{k'=1}^{K} \exp\left(\frac{(U_m z_q)^{\top} V_m x_{k'}}{\sqrt{C_v}} + b_{m,k'}\right)} = \frac{\exp\left(\frac{z_q^{\top} U_m^{\top} V_m x_k}{\sqrt{C_v}} + b_{m,k}\right)}{\sum_{k'=1}^{K} \exp\left(\frac{z_q^{\top} U_m^{\top} V_m x_{k'}}{\sqrt{C_v}} + b_{m,k'}\right)}$$

where $x_k = x(p_q + \Delta p_{m,k})$ denotes the $k$-th sampled key feature, $U_m \in \mathbb{R}^{C_v \times C}$ and $V_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, $\text{score}(z_q, x_k) = \frac{z_q^{\top} U_m^{\top} V_m x_k}{\sqrt{C_v}}$ is the attention score between the query element $q$ and the sampled key element $k$ for the $m$-th attention head, and $b_{m,k} \in \mathbb{R}$ is the bias for the $k$-th sampled key element for the $m$-th attention head, originally developed for the Swin Transformer. The bias $b_{m,k}$ is interpolated from a relative position bias table, and its computational cost is negligible compared to the dominating attention score computation.

The offset $\Delta p_{m,k}$ depends on all the query elements for each attention head $m$, and it is computed as follows:

$$\Delta p_{m,k} = f_{m,k}(Z)$$

where $f_{m,k}(Z)$ is a function of the query feature matrix $Z \in \mathbb{R}^{N_q \times C}$ for the $m$-th attention head.
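
In the original deformable attention v2 design, the offsets are produced by a small convolutional offset network applied to the query features. The sketch below is a much simpler stand-in that matches the formulation above ($K$ offsets per head, shared among all the queries, predicted from the whole query feature matrix $Z$); the pooling-plus-linear design and the class name are my own illustrative assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn


class SharedOffsetPredictor(nn.Module):
    # Minimal sketch of f_m: predict K sampling offsets per head that are shared
    # among all the queries, as a function of the whole query feature matrix Z.
    # Pooling over the queries followed by a linear layer is an illustrative
    # choice, not the reference implementation.
    def __init__(self, embed_dim: int, num_heads: int, num_points: int) -> None:
        super().__init__()
        self.num_heads, self.num_points = num_heads, num_points
        self.proj = nn.Linear(embed_dim, num_heads * num_points * 2)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (B, N_q, C) query feature matrix.
        pooled = z.mean(dim=1)                                        # summarize all queries: (B, C)
        offsets = self.proj(pooled)                                   # (B, M*K*2)
        return offsets.view(-1, self.num_heads, self.num_points, 2)   # Delta p_{m,k}: (B, M, K, 2)
```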

The deformable multi-head attention v2 feature for all the queries is computed as follows:

$$\begin{aligned}
\text{DeformableMultiHeadAttentionV2}(Z, P, X) &= \sum_{m=1}^{M} A_m X(P + \Delta P_m) W_m'^{\top} W_m^{\top} \\
&= \sum_{m=1}^{M} \text{softmax}\left(\frac{Z U_m^{\top} V_m X(P + \Delta P_m)^{\top}}{\sqrt{C_v}} + b_m\right) X(P + \Delta P_m) W_m'^{\top} W_m^{\top} \\
&= \sum_{m=1}^{M} \text{softmax}\left(\frac{Z U_m^{\top} V_m X(P + f_m(Z))^{\top}}{\sqrt{C_v}} + b_m\right) X(P + f_m(Z)) W_m'^{\top} W_m^{\top}
\end{aligned}$$

where $Z \in \mathbb{R}^{N_q \times C}$ and $P \in \mathbb{R}^{N_q \times 2}$ are the query feature matrix and the reference point matrix, respectively, $\Delta P_m \in \mathbb{R}^{K \times 2}$ is the offset matrix for the $m$-th attention head, which is shared among all the queries, $A_m \in \mathbb{R}^{N_q \times K}$ is the attention weight matrix for the $m$-th attention head, $U_m \in \mathbb{R}^{C_v \times C}$ and $V_m \in \mathbb{R}^{C_v \times C}$ are the learnable weight matrices for the $m$-th attention head, and $b_m \in \mathbb{R}^{K}$ is the bias vector for the $m$-th attention head.
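
A minimal sketch of the vectorized v2 computation could look as follows. It reuses the `sample_features` helper and the `SharedOffsetPredictor` sketched above, takes the reference points to be shared among all the queries for simplicity, and replaces the interpolated relative position bias table with a plain learned bias $b_{m,k}$; all of these simplifications and names are assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn


class DeformableAttentionV2(nn.Module):
    # Minimal sketch of deformable multi-head attention v2: the queries attend to a
    # small set of shared sampled keys with ordinary dot-product attention plus a bias.
    def __init__(self, embed_dim: int, num_heads: int, num_points: int) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0
        self.m, self.k, self.c_v = num_heads, num_points, embed_dim // num_heads
        self.offset_net = SharedOffsetPredictor(embed_dim, num_heads, num_points)  # f_m
        # Per-head projection matrices U_m, V_m, W'_m, W_m.
        self.u = nn.Parameter(torch.empty(num_heads, self.c_v, embed_dim))
        self.v = nn.Parameter(torch.empty(num_heads, self.c_v, embed_dim))
        self.w_prime = nn.Parameter(torch.empty(num_heads, self.c_v, embed_dim))
        self.w = nn.Parameter(torch.empty(num_heads, embed_dim, self.c_v))
        for weight in (self.u, self.v, self.w_prime, self.w):
            nn.init.xavier_uniform_(weight)
        # Simplified learned bias b_{m,k}; the original design interpolates it from a
        # relative position bias table instead.
        self.bias = nn.Parameter(torch.zeros(num_heads, num_points))

    def forward(self, z: torch.Tensor, p: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # z: (B, N_q, C) query features, p: (B, K, 2) reference points shared by all
        # queries (pixel coordinates), x: (B, C, H, W) input feature map.
        offsets = self.offset_net(z)                                  # Delta p_{m,k}: (B, M, K, 2)
        locations = p.unsqueeze(1) + offsets                          # (B, M, K, 2)
        sampled = sample_features(x, locations)                       # shared sampled keys: (B, M, K, C)
        q = torch.einsum('mdc,bqc->bmqd', self.u, z)                  # U_m z_q: (B, M, N_q, C_v)
        k = torch.einsum('mdc,bmkc->bmkd', self.v, sampled)           # V_m x_k: (B, M, K, C_v)
        val = torch.einsum('mdc,bmkc->bmkd', self.w_prime, sampled)   # W'_m x_k: (B, M, K, C_v)
        scores = q @ k.transpose(-2, -1) / self.c_v ** 0.5            # (B, M, N_q, K)
        attn = torch.softmax(scores + self.bias.view(1, self.m, 1, self.k), dim=-1)  # A_{m,q,k}
        out = attn @ val                                              # (B, M, N_q, C_v)
        return torch.einsum('mcd,bmqd->bqc', self.w, out)             # sum over heads: (B, N_q, C)
```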

Deformable Multi-Head Attention V2 Computational Complexity

Ignoring the computational complexity of the offset computation $f_m(Z)$ and the feature map interpolation $X(P + \Delta P_m)$ for each head, the computational complexity of the deformable multi-head attention v2 for $N_q$ queries and $K$ shared sampled keys can be derived as follows:

$$\begin{aligned}
& O\Bigg(M \bigg( \underbrace{O(K C_v C)}_{X(P + \Delta P_m) W_m'^{\top}} + \underbrace{O(K C C_v) + O(N_q C C_v) + O(N_q K C_v) + O(N_q K) + O(K)}_{\text{softmax}\left(\frac{Z U_m^{\top} V_m X(P + \Delta P_m)^{\top}}{\sqrt{C_v}} + b_m\right)} + O(N_q K C_v) + O(N_q C C_v) \bigg)\Bigg) \\
&= O\Big(M \big( O(K C C_v) + O(N_q C C_v) + O(N_q K C_v) \big)\Big) \\
&= O(K C^2) + O(N_q C^2) + O(N_q K C) \\
&= O(K C^2 + N_q C^2 + N_q K C)
\end{aligned}$$

Given the computed offsets $\Delta P_m$ for each head, the computational complexity of the feature map interpolation $X(P + \Delta P_m)$ over all the heads is $O(MKC)$. The offset computation $f_m(Z)$ is usually a small neural network module. Assuming its computational cost is sufficiently small compared to the other components in the deformable multi-head attention v2, the computational complexity of the deformable multi-head attention v2 for $N_q$ queries and $K$ shared sampled keys is just $O(KC^2 + N_q C^2 + N_q K C + MKC)$.

For the Transformer-based computer vision model encoder, because all the attention layers are self-attention, $N_q = N_k = HW$, where $H$ and $W$ are the height and width of the feature map, respectively. The computational complexity of the deformable multi-head attention v2 is $O(KC^2 + N_q C^2 + N_q K C + MKC) = O(KC^2 + HWC^2 + HWKC + MKC)$. If $K \ll HW$, $M \ll C$, and $M \ll HW$, and $K$ is at most on the order of $C$, which is usually the case, the computational complexity of the deformable multi-head attention v2 becomes $O(HWC^2)$, which is the same as the one of the deformable multi-head attention and much cheaper than the one of the vanilla multi-head attention.

Although the deformable multi-head attention v2 was originally developed for the Transformer backbone in the ViT and might not have been used in the Transformer decoder, we could still derive its computational complexity in the Transformer decoder. In the self-attention layers, $N_q = N_k = N$, and the computational complexity of the deformable multi-head attention v2 is $O(KC^2 + NC^2 + N^2 C + MKC)$. This computational complexity might be comparable to the one of the vanilla multi-head attention and the one of the deformable multi-head attention, which is $O(NC^2 + N^2 C)$. In the cross-attention layers, the computational complexity of the deformable multi-head attention v2 is $O(KC^2 + NC^2 + NKC + MKC)$. Because $K \ll HW$ and $N \ll HW$, this computational complexity is much cheaper than the one of the vanilla multi-head attention, which is $O(HWC^2 + NHWC)$, and comparable to the one of the deformable multi-head attention, which is $O(NKC^2)$. Note that the magnitude of $K$ used in the deformable multi-head attention and the deformable multi-head attention v2 can be quite different.
