CUDA Matrix Multiplication Optimization
Introduction
General matrix multiplication (GEMM) is a fundamental operation in linear algebra and a key building block in many scientific computing applications, including machine learning and deep learning.
In this article, we will discuss how to optimize the performance of FP32 GEMM on NVIDIA GPUs using CUDA and how to extend the FP32 GEMM optimizations to FP16 GEMM using NVIDIA Tensor Cores.
General Matrix Multiplication
GEMM operation computes $D = AB + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$. In computer programs, usually $A$ and $B$ are constant input matrices and $C$ will be overwritten by the output matrix $D$.
In our implementations, we assume all the matrices, $A$, $B$, $C$ and $D$, are stored in row-major order in memory, with the leading dimension padded to a multiple of 64 bytes for FP32 matrices and 32 bytes for FP16 matrices.
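As a concrete illustration, a hypothetical helper that computes such a padded leading dimension could look like the following; the function name and the exact padding scheme used in the repository are assumptions for this sketch.

```cuda
#include <cstddef>

// Hypothetical helper (illustration only): round the number of columns of a
// row-major matrix up so that every row starts at the given byte alignment,
// e.g. 64 bytes for FP32 matrices and 32 bytes for FP16 matrices.
template <typename T>
constexpr size_t padded_leading_dimension(size_t num_columns,
                                          size_t alignment_bytes)
{
    size_t const elements_per_unit{alignment_bytes / sizeof(T)};
    return (num_columns + elements_per_unit - 1U) / elements_per_unit *
           elements_per_unit;
}

// Example: a 1000-column FP32 matrix padded to 64 bytes (16 floats) gets a
// leading dimension of 1008 elements.
static_assert(padded_leading_dimension<float>(1000U, 64U) == 1008U,
              "Unexpected padded leading dimension.");
```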
Naive Implementation with Non-Coalesced Memory Access
The naive implementation is to use 2D blocks, where each thread is responsible for computing one element of the output matrix. Concretely, each thread with global thread index $(t_m, t_n)$, where $t_m \in [1, m]$ and $t_n \in [1, n]$, computes $D_{t_m, t_n} = \sum_{t_k=1}^{k} A_{t_m, t_k} B_{t_k, t_n} + C_{t_m, t_n}$.
The following code snippet shows the naive implementation.
```cuda
template <typename T>
```
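A minimal sketch of such a naive kernel is shown below. The kernel name, parameter list, and launch configuration are assumptions for illustration and may differ from the kernel in the repository; the important detail is that the fast thread index selects the row.

```cuda
#include <cuda_runtime.h>

// Hypothetical naive GEMM kernel (illustration only). Each thread computes
// one element of C = A * B + C. The fast thread index threadIdx.x selects
// the row of A and C, so the 32 threads of a warp read down a column of the
// row-major matrix A and write down a column of C: neither access is
// coalesced.
template <typename T>
__global__ void gemm_naive_non_coalesced(size_t m, size_t n, size_t k,
                                         T const* A, size_t lda,
                                         T const* B, size_t ldb,
                                         T* C, size_t ldc)
{
    size_t const row{blockIdx.x * blockDim.x + threadIdx.x}; // fast index -> row
    size_t const col{blockIdx.y * blockDim.y + threadIdx.y}; // slow index -> column
    if (row < m && col < n)
    {
        T acc{static_cast<T>(0)};
        for (size_t i{0U}; i < k; ++i)
        {
            acc += A[row * lda + i] * B[i * ldb + col];
        }
        C[row * ldc + col] = acc + C[row * ldc + col];
    }
}
```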
In addition to the other drawbacks of the naive algorithm, this implementation has one major problem: the memory accesses for both reading from and writing to the global memory are not coalesced. Specifically, because of the oversight that the fast thread index is used for indexing the rows of $A$ and $C$, the threads in the same warp read elements from the same column of $A$, which is stored in row-major order in memory, so the reads are strided rather than consecutive and therefore not coalesced. The same problem occurs when the warp overwrites the elements of $C$. The threads in the same warp read the same element of $B$, which results in a broadcast memory access and is not affected by the oversight.
The performance of this FP32 GEMM implementation is only 0.27 TFLOPS on an NVIDIA GeForce RTX 3090 GPU, which is very poor.
Naive Implementation with Coalesced Memory Access
The fix for the non-coalesced memory access is to use the fast thread index for indexing the column of the matrices, which are stored in row-major order in memory, so that the threads in the same warp read or overwrite consecutive elements from the same row of the matrices and the accesses become coalesced. In our implementation, we just need to swap the fast thread index and the slow thread index in the kernel function.
The following code snippet shows the naive implementation with coalesced memory access.
```cuda
template <typename T>
```
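A minimal sketch of the coalesced variant, under the same assumptions as the previous sketch, differs only in which thread index selects the row and which selects the column.

```cuda
#include <cuda_runtime.h>

// Hypothetical coalesced variant of the naive kernel above (illustration
// only). The fast thread index threadIdx.x now selects the column, so a warp
// reads a contiguous row segment of B and writes a contiguous row segment of
// C, and both accesses are coalesced.
template <typename T>
__global__ void gemm_naive_coalesced(size_t m, size_t n, size_t k,
                                     T const* A, size_t lda,
                                     T const* B, size_t ldb,
                                     T* C, size_t ldc)
{
    size_t const col{blockIdx.x * blockDim.x + threadIdx.x}; // fast index -> column
    size_t const row{blockIdx.y * blockDim.y + threadIdx.y}; // slow index -> row
    if (row < m && col < n)
    {
        T acc{static_cast<T>(0)};
        for (size_t i{0U}; i < k; ++i)
        {
            acc += A[row * lda + i] * B[i * ldb + col];
        }
        C[row * ldc + col] = acc + C[row * ldc + col];
    }
}
```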
Now, because of the fix, the threads in the same warp read the elements from the same row of $B$, which is stored in row-major order in memory, resulting in coalesced memory accesses. The same is true when the warp overwrites the elements of $C$. The threads in the same warp read the same element of $A$, resulting in a broadcast memory access. Therefore, this implementation should perform much better than the one with non-coalesced memory access.
The performance of this FP32 GEMM implementation becomes 1.72 TFLOPS on an NVIDIA GeForce RTX 3090 GPU, which is much better than the previous implementation. However, considering the theoretical peak performance of the GPU is 35.58 TFLOPS, the performance of this implementation is still very poor.
Implementation with 2D Block Tiling
Because the previous implementation accesses the global memory frequently, the GEMM kernel is memory-bound. Since the shared memory is much faster to access than the global memory, we can improve the performance by caching the input matrices $A$ and $B$ in the shared memory for data reuse.
However, because the shared memory size is limited, we cannot cache the entire input matrices $A$ and $B$ in the shared memory. Instead, we can cache a 2D tile of $A$ and $B$ in the shared memory and use the 2D tile to compute a 2D tile of the output matrix $D$. Then, we can load the next 2D tile of $A$ and $B$ to the shared memory and compute the next 2D tile of $D$.
Mathematically, given a GEMM operation $D = AB + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$, the matrices could be divided into smaller matrices.
$$
A =
\begin{bmatrix}
A_{1,1}^{d_{bm} \times d_{bk}} & A_{1,2}^{d_{bm} \times d_{bk}} & \cdots & A_{1,k/d_{bk}}^{d_{bm} \times d_{bk}} \\
A_{2,1}^{d_{bm} \times d_{bk}} & A_{2,2}^{d_{bm} \times d_{bk}} & \cdots & A_{2,k/d_{bk}}^{d_{bm} \times d_{bk}} \\
\vdots & \vdots & \ddots & \vdots \\
A_{m/d_{bm},1}^{d_{bm} \times d_{bk}} & A_{m/d_{bm},2}^{d_{bm} \times d_{bk}} & \cdots & A_{m/d_{bm},k/d_{bk}}^{d_{bm} \times d_{bk}} \\
\end{bmatrix}
$$
$$
B =
\begin{bmatrix}
B_{1,1}^{d_{bk} \times d_{bn}} & B_{1,2}^{d_{bk} \times d_{bn}} & \cdots & B_{1,n/d_{bn}}^{d_{bk} \times d_{bn}} \\
B_{2,1}^{d_{bk} \times d_{bn}} & B_{2,2}^{d_{bk} \times d_{bn}} & \cdots & B_{2,n/d_{bn}}^{d_{bk} \times d_{bn}} \\
\vdots & \vdots & \ddots & \vdots \\
B_{k/d_{bk},1}^{d_{bk} \times d_{bn}} & B_{k/d_{bk},2}^{d_{bk} \times d_{bn}} & \cdots & B_{k/d_{bk},n/d_{bn}}^{d_{bk} \times d_{bn}} \\
\end{bmatrix}
$$
$$
C =
\begin{bmatrix}
C_{1,1}^{d_{bm} \times d_{bn}} & C_{1,2}^{d_{bm} \times d_{bn}} & \cdots & C_{1,n/d_{bn}}^{d_{bm} \times d_{bn}} \\
C_{2,1}^{d_{bm} \times d_{bn}} & C_{2,2}^{d_{bm} \times d_{bn}} & \cdots & C_{2,n/d_{bn}}^{d_{bm} \times d_{bn}} \\
\vdots & \vdots & \ddots & \vdots \\
C_{m/d_{bm},1}^{d_{bm} \times d_{bn}} & C_{m/d_{bm},2}^{d_{bm} \times d_{bn}} & \cdots & C_{m/d_{bm},n/d_{bn}}^{d_{bm} \times d_{bn}} \\
\end{bmatrix}
$$
$$
D =
\begin{bmatrix}
D_{1,1}^{d_{bm} \times d_{bn}} & D_{1,2}^{d_{bm} \times d_{bn}} & \cdots & D_{1,n/d_{bn}}^{d_{bm} \times d_{bn}} \\
D_{2,1}^{d_{bm} \times d_{bn}} & D_{2,2}^{d_{bm} \times d_{bn}} & \cdots & D_{2,n/d_{bn}}^{d_{bm} \times d_{bn}} \\
\vdots & \vdots & \ddots & \vdots \\
D_{m/d_{bm},1}^{d_{bm} \times d_{bn}} & D_{m/d_{bm},2}^{d_{bm} \times d_{bn}} & \cdots & D_{m/d_{bm},n/d_{bn}}^{d_{bm} \times d_{bn}} \\
\end{bmatrix}
$$
Each small matrix in $D$ is computed as multiple small matrix multiplications and accumulations.
$$
D_{b_m,b_n}^{d_{bm} \times d_{bn}} = \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}}
$$
In this implementation, each 2D block with block index $(b_m, b_n)$, where $b_m \in [1, m/d_{bm}]$ and $b_n \in [1, n/d_{bn}]$, is responsible for computing one small matrix $D_{b_m,b_n}^{d_{bm} \times d_{bn}}$. The shared memory is used to cache a 2D tile of $A$ and $B$ with size $d_{bm} \times d_{bk}$ and $d_{bk} \times d_{bn}$, respectively. The 2D tile of $A$ is indexed by $(b_m, b_k)$, where $b_m \in [1, m/d_{bm}]$ and $b_k \in [1, k/d_{bk}]$. The 2D tile of $B$ is indexed by $(b_k, b_n)$, where $b_k \in [1, k/d_{bk}]$ and $b_n \in [1, n/d_{bn}]$. This caching and small matrix multiplication process is repeated $k/d_{bk}$ times until the entire small matrix $D_{b_m,b_n}^{d_{bm} \times d_{bn}}$ has been accumulated.
Similar to the previous implementations, each block requires $d_{bm} \times d_{bn}$ threads to compute the small matrix $D_{b_m, b_n}^{d_{bm} \times d_{bn}}$ and each thread with block thread index $(t_m, t_n)$, where $t_m \in [1, d_{bm}]$ and $t_n \in [1, d_{bn}]$, is responsible for computing one element of the small matrix.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{t_m,t_n}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_m,t_n} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{t_k=1}^{d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m,t_k} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n} \\
\end{aligned}
$$
The following code snippet shows the implementation with 2D block tiling.
```cuda
template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
```
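A minimal sketch of this idea is shown below. For simplicity, it assumes a square block tile $d_{bm} = d_{bn} = d_{bk} = \text{TILE}$, a thread block of $\text{TILE} \times \text{TILE}$ threads, and matrix sizes that are multiples of the tile size; the repository's kernel templates the three block tile sizes independently and handles the boundary conditions.

```cuda
#include <cuda_runtime.h>

// Hypothetical 2D block tiling GEMM kernel (illustration only). Each thread
// block computes one TILE x TILE tile of C = A * B + C, and each thread
// computes one element of that tile. Launch with blockDim = (TILE, TILE).
template <typename T, size_t TILE>
__global__ void gemm_2d_block_tiling(size_t m, size_t n, size_t k,
                                     T const* A, size_t lda,
                                     T const* B, size_t ldb,
                                     T* C, size_t ldc)
{
    __shared__ T A_tile[TILE][TILE]; // d_bm x d_bk tile of A
    __shared__ T B_tile[TILE][TILE]; // d_bk x d_bn tile of B

    size_t const row{blockIdx.y * TILE + threadIdx.y};
    size_t const col{blockIdx.x * TILE + threadIdx.x};

    T acc{static_cast<T>(0)};
    // Iterate over the k / d_bk block tiles along the k dimension.
    for (size_t tile_idx{0U}; tile_idx < k / TILE; ++tile_idx)
    {
        // Each thread loads one element of each tile. The fast thread index
        // threadIdx.x walks along a row, so the global memory loads are
        // coalesced.
        A_tile[threadIdx.y][threadIdx.x] =
            A[row * lda + tile_idx * TILE + threadIdx.x];
        B_tile[threadIdx.y][threadIdx.x] =
            B[(tile_idx * TILE + threadIdx.y) * ldb + col];
        __syncthreads();
        // Accumulate the small matrix product from the shared memory tiles.
        for (size_t i{0U}; i < TILE; ++i)
        {
            acc += A_tile[threadIdx.y][i] * B_tile[i][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * ldc + col] = acc + C[row * ldc + col];
}
```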
The performance of this FP32 GEMM implementation becomes 2.66 TFLOPS on an NVIDIA GeForce RTX 3090 GPU, which is much better than the previous implementation. However, it is still far from the theoretical peak performance of the GPU.
The problem with this implementation is that the shared memory is still accessed very frequently. Even though accessing the shared memory is much faster than accessing the global memory, each value read from the shared memory is used for only a small amount of simple arithmetic, so the kernel issues shared memory load instructions far more intensively than compute instructions, which is not efficient. Therefore, the performance of this implementation is still limited by memory bandwidth, this time from the shared memory.
Implementation with 2D Block Tiling and 1D Thread Tiling
To further improve the performance, we can alleviate the shared memory bandwidth problem by additionally caching even smaller tiles of the input matrices $A$ and $B$ from the shared memory into the registers of the threads. This time, each thread is responsible for computing a small tile of the output matrix $D$ instead of one single element. Because registers are the fastest memory to access, the performance of this implementation should be much better than the previous one.
We start with only caching the data of matrix $B$ from the shared memory to the registers. Each thread with block thread index $(t_m, t_n)$, where $t_m \in [1, d_{bm} / d_{tm}]$ and $t_n \in [1, d_{bn}]$, is now responsible for computing $d_{tm}$ elements of the small matrix, where $d_{tm}$ is the thread tile size.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{t_m : t_m + d_{tm},t_n}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{t_k=1}^{d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m : t_m + d_{tm},t_k} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n} \\
\end{aligned}
$$
In our previous implementation without thread tiling, to compute one element of the small matrix, we need to read $d_{bk}$ values from the cached matrix $A$ in the shared memory and $d_{bk}$ values from the cached matrix $B$ in the shared memory. In total, we need to read $2d_{bk}$ values from the shared memory.
Now, with 1D thread tiling, to compute $d_{tm}$ elements of the small matrix, we only need to read $d_{bk} \times d_{tm}$ values from the cached matrix $A$ in the shared memory and $d_{bk}$ values from the cached matrix $B$ in the shared memory. Specifically, in each inner loop, $\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n}$ is cached in a register and reused $d_{tm}$ times. In total, we need to read $d_{bk} \times d_{tm} + d_{bk}$ values from the shared memory. On average, to compute one element of the small matrix, we need to read $d_{bk} + d_{bk} / d_{tm}$ values from the shared memory.
Because $d_{bk} + d_{bk} / d_{tm} < 2d_{bk}$ for $d_{tm} > 1$, the shared memory is accessed less frequently and the shared memory bandwidth problem is alleviated. For example, with $d_{bk} = 8$ and $d_{tm} = 8$, the number of shared memory reads per output element drops from $16$ to $9$.
The following code snippet shows the implementation with 2D block tiling and 1D thread tiling.
```cuda
template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
```
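The key part of this implementation is the inner loop that caches a value of the $B$ tile in a register and reuses it. The following device function sketches that loop, assuming the shared memory tiles have already been loaded; the function and parameter names are hypothetical, and the repository's kernel inlines this logic directly.

```cuda
#include <cstddef>

// Hypothetical 1D thread tiling inner loop (illustration only). The thread
// accumulates a column of THREAD_TILE_M output elements in registers.
// A_tile is the d_bm x d_bk tile of A and B_tile is the d_bk x d_bn tile of
// B, both already resident in shared memory.
template <typename T, size_t BLOCK_TILE_K, size_t BLOCK_TILE_N,
          size_t THREAD_TILE_M>
__device__ void thread_tile_1d_accumulate(T (*A_tile)[BLOCK_TILE_K],
                                          T (*B_tile)[BLOCK_TILE_N],
                                          size_t thread_tile_row_idx,
                                          size_t thread_col_idx,
                                          T (&acc)[THREAD_TILE_M])
{
    for (size_t t_k{0U}; t_k < BLOCK_TILE_K; ++t_k)
    {
        // Read one value of B from shared memory into a register and reuse
        // it THREAD_TILE_M times, so the shared memory traffic per output
        // element drops from 2 * d_bk to d_bk + d_bk / d_tm.
        T const b_val{B_tile[t_k][thread_col_idx]};
        for (size_t t_m{0U}; t_m < THREAD_TILE_M; ++t_m)
        {
            acc[t_m] +=
                A_tile[thread_tile_row_idx * THREAD_TILE_M + t_m][t_k] * b_val;
        }
    }
}
```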
The performance of this FP32 GEMM implementation becomes 8.91 TFLOPS on an NVIDIA GeForce RTX 3090 GPU. It seems that we have been making good progress.
Implementation with 2D Block Tiling and 2D Thread Tiling
If the number of registers is not a bottleneck for the performance, we can further improve the performance by caching the data of both matrix $A$ and $B$ from the shared memory to the registers. Each thread with block thread index $(t_m, t_n)$, where $t_m \in [1, d_{bm} / d_{tm}]$ and $t_n \in [1, d_{bn} / d_{tn}]$, is now responsible for computing $d_{tm} \times d_{tn}$ elements of the small matrix, where $d_{tm}$ and $d_{tn}$ are the thread tile sizes for the row and column, respectively.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{t_m : t_m + d_{tm},t_n : t_n + d_{tn}}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n : t_n + d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n : t_n + d_{tn}} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n : t_n + d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{t_k=1}^{d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m : t_m + d_{tm},t_k} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n : t_n + d_{tn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m : t_m + d_{tm},t_n : t_n + d_{tn}} \\
\end{aligned}
$$
In our previous implementation with 1D thread tiling, to compute one element of the small matrix, we need to read $d_{bk} + d_{bk} / d_{tm}$ values from the shared memory on average.
Now, with 2D thread tiling, to compute $d_{tm} \times d_{tn}$ elements of the small matrix, we only need to read $d_{bk} \times (d_{tm} + d_{tn})$ values from the shared memory. Specifically, in each inner loop, $\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m : t_m + d_{tm},t_k}$ and $\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n : t_n + d_{tn}}$ are cached in the register to be reused for computing the matrix multiplication $\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m : t_m + d_{tm},t_k} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n : t_n + d_{tn}}$. In total, we need to read $d_{bk} \times (d_{tm} + d_{tn})$ values from the shared memory. On average, to compute one element of the small matrix, we need to read $d_{bk} / d_{tm} + d_{bk} / d_{tn}$ values from the shared memory.
Because $d_{bk} / d_{tm} + d_{bk} / d_{tn} < d_{bk} + d_{bk} / d_{tm}$, the shared memory is accessed even less frequently and the shared memory bandwidth problem is further alleviated. For example, with $d_{bk} = 8$ and $d_{tm} = d_{tn} = 8$, the number of shared memory reads per output element drops from $9$ to $2$.
There is an alternative way to describe the 2D thread tiling implementation.
Mathematically, given a matrix multiplication and accumulation operation $D_{b_m,b_n}^{d_{bm} \times d_{bn}} = \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}}$, where $D_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, $A_{b_m,b_k} \in \mathbb{R}^{d_{bm} \times d_{bk}}$, $B_{b_k,b_n} \in \mathbb{R}^{d_{bk} \times d_{bn}}$, $C_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, the matrices could be divided into smaller matrices.
$$
A_{b_m,b_k}^{d_{bm} \times d_{bk}} =
\begin{bmatrix}
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,1}^{d_{tm} \times d_{tk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,2}^{d_{tm} \times d_{tk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,d_{bk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,1}^{d_{tm} \times d_{tk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,2}^{d_{tm} \times d_{tk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,d_{bk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{tm},1}^{d_{tm} \times d_{tk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{tm},2}^{d_{tm} \times d_{tk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{tm},d_{bk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\end{bmatrix}
$$
$$
B_{b_k,b_n}^{d_{bk} \times d_{bn}} =
\begin{bmatrix}
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,1}^{d_{tk} \times d_{tn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,2}^{d_{tk} \times d_{tn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,d_{bn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,1}^{d_{tk} \times d_{tn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,2}^{d_{tk} \times d_{tn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,d_{bn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{tk},1}^{d_{tk} \times d_{tn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{tk},2}^{d_{tk} \times d_{tn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{tk},d_{bn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\end{bmatrix}
$$
$$
C_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{tm} \times d_{tn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{tm} \times d_{tn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{tm} \times d_{tn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{tm} \times d_{tn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},1}^{d_{tm} \times d_{tn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},2}^{d_{tm} \times d_{tn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\end{bmatrix}
$$
$$
D_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{tm} \times d_{tn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{tm} \times d_{tn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{tm} \times d_{tn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{tm} \times d_{tn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},1}^{d_{tm} \times d_{tn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},2}^{d_{tm} \times d_{tn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{tm},d_{bn}/d_{tn}}^{d_{tm} \times d_{tn}} \\
\end{bmatrix}
$$
Each small matrix in $D_{b_m,b_n}^{d_{bm} \times d_{bn}}$ is computed as multiple small matrix multiplications and accumulations.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{t_m,t_n}^{d_{tm} \times d_{tn}}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n}^{d_{tm} \times d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_m,t_n}^{d_{tm} \times d_{tn}} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n}^{d_{tm} \times d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{t_k=1}^{d_{bk} / d_{tk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{t_m,t_k}^{d_{tm} \times d_{tk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{t_k,t_n}^{d_{tk} \times d_{tn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{t_m,t_n}^{d_{tm} \times d_{tn}} \\
\end{aligned}
$$
Each thread with block thread index $(t_m, t_n)$, where $t_m \in [1, d_{bm} / d_{tm}]$ and $t_n \in [1, d_{bn} / d_{tn}]$, in the block with block index $(b_m, b_n)$, where $b_m \in [1, m/d_{bm}]$ and $b_n \in [1, n/d_{bn}]$, is responsible for computing one small matrix multiplication and accumulation $\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{t_m,t_n}^{d_{tm} \times d_{tn}}$.
In the case of 1D thread tiling described in this article, we have $d_{tm} > 1$, $d_{tk} = 1$ and $d_{tn} = 1$. In the case of 2D thread tiling in this article, we have $d_{tm} > 1$, $d_{tk} = 1$ and $d_{tn} > 1$. It is also technically feasible to do thread tiling with $d_{tk} > 1$. In the case of no thread tiling, which is actually a special case of the thread tiling, we have $d_{tm} = 1$, $d_{tk} = 1$ and $d_{tn} = 1$.
The following code snippet shows the implementation with 2D block tiling and 2D thread tiling.
```cuda
// GEMM kernel v04.
```
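The key change relative to 1D thread tiling is the inner loop, which now caches a small column fragment of the $A$ tile and a small row fragment of the $B$ tile in registers and computes their outer product. The following device function sketches that loop under the same assumptions as the 1D thread tiling sketch; the names are hypothetical.

```cuda
#include <cstddef>

// Hypothetical 2D thread tiling inner loop (illustration only). The thread
// accumulates a THREAD_TILE_M x THREAD_TILE_N tile of output elements in
// registers from the shared memory tiles A_tile (d_bm x d_bk) and B_tile
// (d_bk x d_bn).
template <typename T, size_t BLOCK_TILE_K, size_t BLOCK_TILE_N,
          size_t THREAD_TILE_M, size_t THREAD_TILE_N>
__device__ void thread_tile_2d_accumulate(T (*A_tile)[BLOCK_TILE_K],
                                          T (*B_tile)[BLOCK_TILE_N],
                                          size_t thread_tile_row_idx,
                                          size_t thread_tile_col_idx,
                                          T (&acc)[THREAD_TILE_M][THREAD_TILE_N])
{
    T a_frag[THREAD_TILE_M];
    T b_frag[THREAD_TILE_N];
    for (size_t t_k{0U}; t_k < BLOCK_TILE_K; ++t_k)
    {
        // Cache a d_tm x 1 fragment of A and a 1 x d_tn fragment of B in
        // registers.
        for (size_t t_m{0U}; t_m < THREAD_TILE_M; ++t_m)
        {
            a_frag[t_m] =
                A_tile[thread_tile_row_idx * THREAD_TILE_M + t_m][t_k];
        }
        for (size_t t_n{0U}; t_n < THREAD_TILE_N; ++t_n)
        {
            b_frag[t_n] =
                B_tile[t_k][thread_tile_col_idx * THREAD_TILE_N + t_n];
        }
        // Outer product: each cached value of A is reused d_tn times and
        // each cached value of B is reused d_tm times.
        for (size_t t_m{0U}; t_m < THREAD_TILE_M; ++t_m)
        {
            for (size_t t_n{0U}; t_n < THREAD_TILE_N; ++t_n)
            {
                acc[t_m][t_n] += a_frag[t_m] * b_frag[t_n];
            }
        }
    }
}
```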
The performance of this FP32 GEMM implementation becomes 13.02 TFLOPS on an NVIDIA GeForce RTX 3090 GPU.
Implementation with 2D Block Tiling and 2D Thread Tiling and Vectorized Memory Access
In my previous article “CUDA Vectorized Memory Access”, I showed how to use vectorized memory access to improve the performance of a trivial memory copy kernel. Vectorized memory access reduces the number of memory transactions and therefore improves the memory bandwidth utilization. The same trick can be applied to this GEMM kernel to accelerate the data loading from global memory to the shared memory and the data loading from the shared memory to the registers.
In the previous implementation, to compute the matrix multiplication, each thread has to read a column of the cached matrix $A$ and a row of the cached matrix $B$ from the shared memory and cache them in the registers. Because reading a column of the matrix $A$, which is stored in row-major order, prevents vectorized memory access, we transpose the matrix $A$ when loading the data from the global memory to the shared memory, so that each thread can access a row of the transposed matrix $A$ and a row of the matrix $B$ from the shared memory in a vectorized fashion and cache them in the registers.
The following code snippet shows the implementation with 2D block tiling and 2D thread tiling and vectorized memory access.
```cuda
template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
```
Except for the data loading using vectorized memory access, the rest of the kernel is the same as the previous implementation with 2D block tiling and 2D thread tiling. There is, however, a caveat for vectorized memory access in our use case that does not exist in the previous implementation. When we load the data from the global memory to the shared memory and from the shared memory to the registers, considering the matrices are 2D, we need to make sure the data alignment is correct for the vectorized memory access data type; otherwise, undefined behavior will happen. For example, if we use int4 as the vectorized memory access data type, we need to make sure the accesses are aligned to a multiple of 16 bytes. This is why we have to pad the leading dimensions of the matrices $A$ and $B$ in the global memory and why the shared memory dimensions have to be carefully chosen.
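As a minimal sketch of the idea, the following hypothetical device function copies one row of FP32 data using float4 vectorized accesses. It assumes the source and destination pointers are 16-byte aligned and that the number of elements is a multiple of 4, which is exactly what the padding above guarantees.

```cuda
#include <cuda_runtime.h>

// Hypothetical vectorized row copy (illustration only). Each thread of the
// block moves 4 consecutive FP32 values per memory instruction, so the
// number of memory transactions is reduced by a factor of 4 compared to
// scalar loads and stores.
__device__ void copy_row_vectorized(float const* __restrict__ src,
                                    float* __restrict__ dst,
                                    size_t num_elements)
{
    size_t const num_vectors{num_elements / 4U};
    for (size_t i{threadIdx.x}; i < num_vectors; i += blockDim.x)
    {
        // One 128-bit load and one 128-bit store instead of four 32-bit
        // loads and stores. Both pointers must be 16-byte aligned.
        reinterpret_cast<float4*>(dst)[i] =
            reinterpret_cast<float4 const*>(src)[i];
    }
}
```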
The performance of this FP32 GEMM implementation becomes 19.66 TFLOPS on an NVIDIA GeForce RTX 3090 GPU.
Implementation with 2D Block Tiling and 2D Warp Tiling and 2D Thread Tiling and Vectorized Memory Access
In the CUDA programming model, a warp, which consists of 32 threads, is the smallest unit of scheduling and execution. Shared memory bank conflicts can happen when the threads in a warp access different addresses that fall into the same shared memory bank. In our previous implementation, because the GEMM CUDA kernel was not organized in a warp-centric way, it is less obvious how to avoid shared memory bank conflicts.
In this implementation, we will organize the GEMM CUDA kernel in a warp-centric way and use 2D warp tiling together with 2D thread tiling so that the shared memory bank conflicts can be anticipated and optimized much more easily.
Understanding warp tiling is almost exactly the same as understanding thread tiling.
Mathematically, given a matrix multiplication and accumulation operation $D_{b_m,b_n}^{d_{bm} \times d_{bn}} = \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}}$, where $D_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, $A_{b_m,b_k} \in \mathbb{R}^{d_{bm} \times d_{bk}}$, $B_{b_k,b_n} \in \mathbb{R}^{d_{bk} \times d_{bn}}$, $C_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, the matrices could be divided into smaller matrices.
$$
A_{b_m,b_k}^{d_{bm} \times d_{bk}} =
\begin{bmatrix}
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\end{bmatrix}
$$
$$
B_{b_k,b_n}^{d_{bk} \times d_{bn}} =
\begin{bmatrix}
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\end{bmatrix}
$$
$$
C_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\end{bmatrix}
$$
$$
D_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\end{bmatrix}
$$
Each small matrix in $D_{b_m,b_n}^{d_{bm} \times d_{bn}}$ is computed as multiple small matrix multiplications and accumulations.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
\end{aligned}
$$
Each warp with block warp index $(w_m, w_n)$, where $w_m \in [1, d_{bm} / d_{wm}]$ and $w_n \in [1, d_{bn} / d_{wn}]$, in the block with block index $(b_m, b_n)$, where $b_m \in [1, m/d_{bm}]$ and $b_n \in [1, n/d_{bn}]$, is responsible for computing one small matrix multiplication and accumulation $\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}$.
So far, everything looks the same as the mathematical descriptions for the 2D thread tiling, except that the thread indices and thread tile sizes are replaced by the warp indices and warp tile sizes.
The remaining question is how to use all the 32 threads in the warp with block warp index $(w_m, w_n)$ to compute $\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)^{d_{wm} \times d_{wn}}$. There is no unique way to do this. The way we chose is to use 2D thread tiling. Suppose the number of threads per warp along the row dimension is $m_{t}$ and the number of threads per warp along the column dimension is $n_{t}$; we must have $m_{t} \times n_{t} = 32$. Each thread in the warp is responsible for computing $\left(d_{wm} / m_{t}\right) \times \left(d_{wn} / n_{t}\right)$ values of the output matrix $\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)^{d_{wm} \times d_{wn}}$. We then set the thread tile sizes to be $d_{tm}$ for the row and $d_{tn}$ for the column, such that $\left(d_{wm} / m_{t} \right) \mod d_{tm} = 0$ and $\left(d_{wn} / n_{t} \right) \mod d_{tn} = 0$. Each thread in the warp will have to compute $\left(\left(d_{wm} / m_{t} \right) / d_{tm} \right) \times \left(\left(d_{wn} / n_{t} \right) / d_{tn} \right)$ tiles of size $d_{tm} \times d_{tn}$ of the output matrix $\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)^{d_{wm} \times d_{wn}}$.
Suppose the thread tile index is $(t_m, t_n)$, where $t_m \in [1, d_{wm} / d_{tm}]$ and $t_n \in [1, d_{wn} / d_{tn}]$. The thread responsible for computing the tile has the warp thread index $(t_m \mod m_t, t_n \mod n_t)$. The matrix $\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}}$ can be divided along the row dimension into $d_{wm} / d_{tm}$ fragments, and the matrix $\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}$ can be divided along the column dimension into $d_{wn} / d_{tn}$ fragments. We have
$$
\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} =
\begin{bmatrix}
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\vdots & \vdots & \ddots & \vdots \\
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\end{bmatrix}
$$
$$
\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} =
\begin{bmatrix}
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\end{bmatrix}
$$
Each thread with warp thread index $(t_m \mod m_t, t_n \mod n_t)$ is responsible for computing one small matrix multiplication and accumulation $\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}$.
The thread tile $\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}$ can be computed as follows.
$$
\begin{aligned}
\left(\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}
&= \sum_{t_k=1}^{d_{wk} / d_{tk}} \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{t_m,t_k}^{d_{tm} \times d_{tk}} \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{t_k,t_n}^{d_{tk} \times d_{tn}} \\
\end{aligned}
$$
Taken together, the thread tile $\left(\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}$ can be computed as follows.
$$
\begin{aligned}
\left(\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}
&= \left(\sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( \sum_{t_k=1}^{d_{wk} / d_{tk}} \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{t_m,t_k}^{d_{tm} \times d_{tk}} \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{t_k,t_n}^{d_{tk} \times d_{tn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}} \\
\end{aligned}
$$
In this implementation, we set $d_{wk} = d_{tk}$ to make the thread tiling algorithm simpler.
The following code snippet shows the implementation with 2D block tiling and 2D warp tiling and 2D thread tiling and vectorized memory access.
```cuda
// GEMM kernel v06.
```
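As a sketch of the warp-centric index bookkeeping, the following hypothetical device function decomposes the linear thread index within a block into the block warp index $(w_m, w_n)$ and the warp thread index; the names and the row-major warp ordering are assumptions and may differ from the repository's kernel.

```cuda
#include <cstddef>

// Hypothetical warp tiling index decomposition (illustration only).
// NUM_WARPS_N is the number of warps per block along the column dimension,
// i.e. d_bn / d_wn, and NUM_THREADS_PER_WARP_N is n_t, the number of threads
// per warp along the column dimension, with m_t * n_t = 32.
template <size_t NUM_WARPS_N, size_t NUM_THREADS_PER_WARP_N>
__device__ void warp_tiling_indices(size_t linear_thread_idx,
                                    size_t& warp_row_idx,
                                    size_t& warp_col_idx,
                                    size_t& thread_row_idx_in_warp,
                                    size_t& thread_col_idx_in_warp)
{
    size_t const warp_idx{linear_thread_idx / 32U};
    size_t const lane_idx{linear_thread_idx % 32U};
    warp_row_idx = warp_idx / NUM_WARPS_N;                      // w_m
    warp_col_idx = warp_idx % NUM_WARPS_N;                      // w_n
    thread_row_idx_in_warp = lane_idx / NUM_THREADS_PER_WARP_N; // row in [0, m_t)
    thread_col_idx_in_warp = lane_idx % NUM_THREADS_PER_WARP_N; // column in [0, n_t)
}
```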
The performance of this FP32 GEMM implementation becomes 20.16 TFLOPS on an NVIDIA GeForce RTX 3090 GPU. Compared to the cuBLAS FP32 GEMM performance, which is 24.59 TFLOPS, this implementation has been optimized reasonably well.
Implementation with 2D Block Tiling and 2D Warp Tiling and Tensor Core and Vectorized Memory Access
Because we have already organized the GEMM CUDA kernel in a warp-centric way, and the NVIDIA Tensor Core instructions are interfaced at the warp level, it is very straightforward to utilize the NVIDIA Tensor Core WMMA APIs to further accelerate the GEMM computation. Because the NVIDIA Tensor Core does not support IEEE FP32 computation, we will make this CUDA kernel run FP16 GEMM instead.
Compared to the implementation with 2D block tiling and 2D warp tiling and 2D thread tiling and vectorized memory access, the implementation with 2D block tiling and 2D warp tiling and Tensor Core and vectorized memory access is simpler, because the thread tiling process is abstracted away by the NVIDIA Tensor Core warp-level WMMA APIs.
Mathematically, given a matrix multiplication and accumulation operation $D_{b_m,b_n}^{d_{bm} \times d_{bn}} = \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}}$, where $D_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, $A_{b_m,b_k} \in \mathbb{R}^{d_{bm} \times d_{bk}}$, $B_{b_k,b_n} \in \mathbb{R}^{d_{bk} \times d_{bn}}$, $C_{b_m,b_n} \in \mathbb{R}^{d_{bm} \times d_{bn}}$, the matrices could be divided into smaller matrices.
$$
A_{b_m,b_k}^{d_{bm} \times d_{bk}} =
\begin{bmatrix}
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{1,d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{2,d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wk}} & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wk}} & \cdots & \left(A_{b_m,b_k}^{d_{bm} \times d_{bk}}\right)_{d_{bm}/d_{wm},d_{bk}/d_{wk}}^{d_{wm} \times d_{wk}} \\
\end{bmatrix}
$$
$$
B_{b_k,b_n}^{d_{bk} \times d_{bn}} =
\begin{bmatrix}
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},1}^{d_{wk} \times d_{wn}} & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},2}^{d_{wk} \times d_{wn}} & \cdots & \left(B_{b_k,b_n}^{d_{bk} \times d_{bn}}\right)_{d_{bk}/d_{wk},d_{bn}/d_{wn}}^{d_{wk} \times d_{wn}} \\
\end{bmatrix}
$$
$$
C_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wn}} & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wn}} & \cdots & \left(C_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\end{bmatrix}
$$
$$
D_{b_m,b_n}^{d_{bm} \times d_{bn}} =
\begin{bmatrix}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{1,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{2,d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},1}^{d_{wm} \times d_{wn}} & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},2}^{d_{wm} \times d_{wn}} & \cdots & \left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{d_{bm}/d_{wm},d_{bn}/d_{wn}}^{d_{wm} \times d_{wn}} \\
\end{bmatrix}
$$
Each small matrix in $D_{b_m,b_n}^{d_{bm} \times d_{bn}}$ is computed as multiple small matrix multiplications and accumulations.
$$
\begin{aligned}
\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}
&= \left( \sum_{b_k=1}^{k/d_{bk}} A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} + C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}} \\
\end{aligned}
$$
Each warp with block warp index $(w_m, w_n)$, where $w_m \in [1, d_{bm} / d_{wm}]$ and $w_n \in [1, d_{bn} / d_{wn}]$, in the block with block index $(b_m, b_n)$, where $b_m \in [1, m/d_{bm}]$ and $b_n \in [1, n/d_{bn}]$, is responsible for computing one small matrix multiplication and accumulation $\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}$.
Suppose the Tensor Core WMMA GEMM size is $d_{tm} \times d_{tn} \times d_{tk}$. The matrix $\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}}$ can be divided along the row dimension into $d_{wm} / d_{tm}$ fragments, and the matrix $\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}}$ can be divided along the column dimension into $d_{wn} / d_{tn}$ fragments. We have
$$
\left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} =
\begin{bmatrix}
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{1,d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{2,d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\vdots & \vdots & \ddots & \vdots \\
\left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},1}^{d_{tm} \times d_{tk}} & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},2}^{d_{tm} \times d_{tk}} & \cdots & \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{d_{wm}/d_{tm},d_{wk}/d_{tk}}^{d_{tm} \times d_{tk}} \\
\end{bmatrix}
$$
$$
\left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} =
\begin{bmatrix}
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{1,d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{2,d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\vdots & \vdots & \ddots & \vdots \\
\left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},1}^{d_{tk} \times d_{tn}} & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},2}^{d_{tk} \times d_{tn}} & \cdots & \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{d_{wk}/d_{tk},d_{wn}/d_{tn}}^{d_{tk} \times d_{tn}} \\
\end{bmatrix}
$$
Instead of issuing thread-level instructions, each warp calls the warp-level WMMA Tensor Core APIs to iteratively compute all the $\left(\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}$ tiles of $\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}$.
$$
\begin{aligned}
\left(\left(D_{b_m,b_n}^{d_{bm} \times d_{bn}}\right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}}
&= \left(\sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}} \\
&= \sum_{b_k=1}^{k/d_{bk}} \left( \sum_{w_k=1}^{d_{bk} / d_{wk}} \left( \sum_{t_k=1}^{d_{wk} / d_{tk}} \left( \left( A_{b_m,b_k}^{d_{bm} \times d_{bk}} \right)_{w_m,w_k}^{d_{wm} \times d_{wk}} \right)_{t_m,t_k}^{d_{tm} \times d_{tk}} \left( \left( B_{b_k,b_n}^{d_{bk} \times d_{bn}} \right)_{w_k,w_n}^{d_{wk} \times d_{wn}} \right)_{t_k,t_n}^{d_{tk} \times d_{tn}} \right) + \left( C_{b_m,b_n}^{d_{bm} \times d_{bn}} \right)_{w_m,w_n}^{d_{wm} \times d_{wn}}\right)_{t_m, t_n}^{d_{tm} \times d_{tn}} \\
\end{aligned}
$$
In this implementation, following the GEMM shapes supported by the WMMA Tensor Core APIs, we use $d_{tm} = 16$, $d_{tn} = 16$, and $d_{tk} = 16$.
The following code snippet shows the implementation with 2D block tiling and 2D warp tiling and Tensor Core and vectorized memory access.
```cuda
// GEMM kernel v07.
```
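As a minimal sketch of the warp-level WMMA usage (not the repository's exact kernel), the following device function has one warp compute a single $16 \times 16 \times 16$ FP16 tile $D = AB + C$ from row-major operands; the pointers and leading dimensions are assumed to satisfy the WMMA alignment requirements.

```cuda
#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

// Hypothetical single-tile WMMA computation (illustration only). All 32
// threads of the warp cooperate in each of these calls. a, b, and c point to
// 16 x 16 row-major matrices (for example, locations inside the shared
// memory tiles), with leading dimensions lda, ldb, and ldc.
__device__ void wmma_tile_16x16x16(__half const* a, unsigned int lda,
                                   __half const* b, unsigned int ldb,
                                   __half* c, unsigned int ldc)
{
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> acc_frag;

    wmma::load_matrix_sync(a_frag, a, lda);
    wmma::load_matrix_sync(b_frag, b, ldb);
    wmma::load_matrix_sync(acc_frag, c, ldc, wmma::mem_row_major);
    // D = A * B + C, accumulated by the Tensor Cores.
    wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    wmma::store_matrix_sync(c, acc_frag, ldc, wmma::mem_row_major);
}
```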
Because the fundamental WMMA size is $16 \times 16 \times 16$, all the 32 threads in the same warp have to cooperatively access the shared memory where the WMMA fragment is cached, so shared memory bank conflicts are very likely to happen. To avoid them, we pad the leading dimension of the shared memory tiles with a skew size.
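A hypothetical illustration of such skew padding is shown below; the tile and skew sizes are example values, not necessarily the ones used in the repository.

```cuda
#include <cuda_fp16.h>

// Hypothetical skew padding illustration. Padding each shared memory row
// with SKEW_SIZE extra __half elements shifts successive rows across
// different shared memory banks, which reduces bank conflicts when a warp
// loads a WMMA fragment from the tile.
__global__ void skew_padding_illustration()
{
    constexpr unsigned int BLOCK_TILE_SIZE_Y{64U};
    constexpr unsigned int BLOCK_TILE_SIZE_K{32U};
    constexpr unsigned int SKEW_SIZE{16U}; // 16 halves = 32 bytes of skew
    __shared__ __half
        A_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K + SKEW_SIZE];
    // The tile would be filled from global memory and consumed by
    // wmma::load_matrix_sync with leading dimension
    // BLOCK_TILE_SIZE_K + SKEW_SIZE.
    (void)A_tile;
}
```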
The performance of this FP16 GEMM implementation becomes 46.78 TFLOPS on an NVIDIA GeForce RTX 3090 GPU. Compared to the cuBLAS FP16 GEMM performance, which is 138.95 TFLOPS, this implementation only achieves 33.7% of the cuBLAS FP16 GEMM performance. We will leave the performance optimization of this implementation as future work.
Conclusions
The optimizations we performed on the GEMM CUDA kernels mainly follow the diagrams in the article “CUTLASS: Fast Linear Algebra in CUDA C++”.
With optimization techniques such as 2D block tiling, 2D warp tiling, 2D thread tiling, and vectorized memory access, we achieved 20.16 TFLOPS of FP32 GEMM performance on an NVIDIA GeForce RTX 3090 GPU, which is 80% to 90% of the cuBLAS FP32 GEMM performance.
Source Code
The source code of the GEMM CUDA kernels can be found in my GitHub repository “CUDA GEMM Optimization”.
References
- CUTLASS: Fast Linear Algebra in CUDA C++
- CUDA GEMM Optimization - GitHub
- CUDA Matrix Multiplication
- CUDA Vectorized Memory Access
- CUDA Data Alignment
- CUDA Shared Memory Bank
- NVIDIA Tensor Core Programming
- CUDA Tensor Core GEMM
- How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog