Matrix multiply-accumulate (MMA) is a key operation in general matrix multiplication (GEMM). In CUTLASS, CuTe provides APIs for composing MMA atoms into tiled MMAs so that larger MMA problems can be solved.
In this blog post, I would like to discuss CuTe tiled MMA configurations, layouts, and API usage, using an example.
CuTe Tiled MMA Preview Example
The following CuTe tiled MMA preview example does not actually perform any MMA computation, because it is entirely a host program. Instead, it demonstrates how to configure the MMA atom, the MMA tile, and the MMA layout using CuTe APIs.
int main(int argc, const char** argv)
{
    // Tiled MMA requires everything to be static.
    // Therefore, this preview program does not allow user to configure
    // dynamically. To preview a new tiled MMA configuration, the user has to
    // modify this program and recompile.

    // Configure data type.
    using TA = cute::half_t;
    using TB = cute::half_t;
    using TC = cute::half_t;

    // Configure static "shared memory".
    // The "shared memory" is actually on host for preview purpose.
    // For tiled mma, the shared memory layout has to be static.
    constexpr int bM{128 * 2 / sizeof(TA)};
    constexpr int bN{128 * 2 / sizeof(TB)};
    constexpr int bK{32};
    auto const blk_M = cute::Int<bM>{};
    auto const blk_N = cute::Int<bN>{};
    auto const blk_K = cute::Int<bK>{};

    auto h_A = thrust::host_vector<TA>(size_a);
    auto h_B = thrust::host_vector<TB>(size_b);
    auto h_C = thrust::host_vector<TC>(size_c);

    // Make tensor for smem_A and smem_B.
    auto smem_tensor_A{cute::make_tensor(h_A.data(), smem_layout_A)};
    auto smem_tensor_B{cute::make_tensor(h_B.data(), smem_layout_B)};
    auto smem_tensor_C{cute::make_tensor(h_C.data(), smem_layout_C)};

    // Configure tiled MMA.
    using MmaTraits = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
    using MmaAtomShape = MmaTraits::Shape_MNK;
    auto const mma_atom = cute::MMA_Atom<MmaTraits>{};
    auto const mma_atom_shape = MmaAtomShape{};
    // Repeating the mma atom along the M, N, and K dimensions.
    // This increases the number of threads to process the tiled MMA.
    constexpr int MMA_LAYOUT_M{2};
    constexpr int MMA_LAYOUT_N{2};
    constexpr int MMA_LAYOUT_K{1};
    auto mma_layout{cute::make_layout(
        cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{},
                         cute::Int<MMA_LAYOUT_K>{}))};
    // Repeating the mma processing along the M, N, and K dimensions.
    // This does not increase the number of threads to process the tiled MMA.
    // But the number of registers required for processing the tiled MMA
    // increases.
    constexpr int NUM_MMA_TILE_M{1};
    constexpr int NUM_MMA_TILE_N{2};
    constexpr int NUM_MMA_TILE_K{1};
    constexpr int MMA_TILE_M{cute::get<0>(mma_atom_shape) * MMA_LAYOUT_M *
                             NUM_MMA_TILE_M};
    constexpr int MMA_TILE_N{cute::get<1>(mma_atom_shape) * MMA_LAYOUT_N *
                             NUM_MMA_TILE_N};
    constexpr int MMA_TILE_K{cute::get<2>(mma_atom_shape) * MMA_LAYOUT_K *
                             NUM_MMA_TILE_K};
    auto mma_tile{cute::make_tile(cute::Int<MMA_TILE_M>{},
                                  cute::Int<MMA_TILE_N>{},
                                  cute::Int<MMA_TILE_K>{})};
    auto tiled_mma{cute::make_tiled_mma(mma_atom, mma_layout, mma_tile)};

    // Partition via MMA.
    // Set an arbitrary thread index.
    constexpr int THREAD_IDX{0};
    CUTE_STATIC_ASSERT(THREAD_IDX < NUM_THREADS);
    CUTE_STATIC_ASSERT(THREAD_IDX >= 0);

    auto thread_mma{tiled_mma.get_slice(THREAD_IDX)};
    // Register tensors used for MMA.
    auto thread_layout_C_register_tensor_A{
        thread_mma.partition_fragment_A(smem_tensor_A)}; // (MMA, MMA_M, MMA_K)
    auto thread_layout_C_register_tensor_B{
        thread_mma.partition_fragment_B(smem_tensor_B)}; // (MMA, MMA_N, MMA_K)
    auto thread_layout_C_register_tensor_C{
        thread_mma.partition_fragment_C(smem_tensor_C)}; // (MMA, MMA_M, MMA_N)

    // Use no tiled copy from shared memory to register.
    auto thread_layout_C_smem_tensor_A_no_tiled_copy{
        thread_mma.partition_A(smem_tensor_A)}; // (MMA, MMA_M, MMA_K)
    auto thread_layout_C_smem_tensor_B_no_tiled_copy{
        thread_mma.partition_B(smem_tensor_B)}; // (MMA, MMA_N, MMA_K)
    auto thread_layout_C_smem_tensor_C_no_tiled_copy{
        thread_mma.partition_C(smem_tensor_C)}; // (MMA, MMA_M, MMA_N)

    // thread_layout_C_smem_tensor_A_no_tiled_copy and
    // thread_layout_C_register_tensor_A shall have the same shape.
    CUTE_STATIC_ASSERT_V(
        cute::shape(thread_layout_C_smem_tensor_A_no_tiled_copy) ==
        cute::shape(thread_layout_C_register_tensor_A));

    // Use tiled copy from shared memory to register.
    auto copy_atom_A = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, TA>{};
    auto copy_atom_B = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, TB>{};

    auto smem_tiled_copy_A{cute::make_tiled_copy_A(copy_atom_A, tiled_mma)};
    auto smem_tiled_copy_B{cute::make_tiled_copy_B(copy_atom_B, tiled_mma)};

    auto smem_thread_copy_A{smem_tiled_copy_A.get_slice(THREAD_IDX)};
    auto smem_thread_copy_B{smem_tiled_copy_B.get_slice(THREAD_IDX)};

    auto thread_layout_C_smem_tensor_A_tiled_copy{
        smem_thread_copy_A.partition_S(smem_tensor_A)};
    auto thread_layout_C_smem_tensor_B_tiled_copy{
        smem_thread_copy_B.partition_S(smem_tensor_B)};

    auto thread_layout_C_register_tensor_A_copy_view{
        smem_thread_copy_A.retile_D(thread_layout_C_register_tensor_A)};
    auto thread_layout_C_register_tensor_B_copy_view{
        smem_thread_copy_B.retile_D(thread_layout_C_register_tensor_B)};
A high-performance example that uses almost the same tiled MMA configurations to perform the GEMM computation can be found on my GitHub.
CuTe Tiled MMA Configurations and Layouts
MMA Problem Size and Shared Memory Configuration
In one thread block, per one main loop iteration, the MMA problem size is $M \times N \times K = 128 \times 128 \times 32$. The static shared memory is used to store the $M \times K = 128 \times 32$ sub-matrix of matrix $A$ in a column-major layout and the $K \times N = 32 \times 128$ sub-matrix of matrix $B$ in a row-major layout. Using the convention of MMA, we typically describe the sub-matrix of matrix $B$ as $N \times K = 128 \times 32$ column-major.
smem_tensor_A
ptr[16b](0x57b7b93248c0) o (_128,_32):(_1,_128)
smem_tensor_B
ptr[16b](0x57b7b93268d0) o (_128,_32):(_1,_128)
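The definitions of smem_layout_A, smem_layout_B, and smem_layout_C are omitted from the snippet above. As a minimal sketch (my own reconstruction, not the original listing), static layouts that reproduce the printed shapes and strides could be built from the blk_M, blk_N, and blk_K constants as follows; the layout for matrix C is an assumption since it is not printed here.

// Hedged sketch: static column-major layouts matching the printouts above.
auto const smem_layout_A = cute::make_layout(
    cute::make_shape(blk_M, blk_K), cute::GenColMajor{}); // (_128,_32):(_1,_128)
auto const smem_layout_B = cute::make_layout(
    cute::make_shape(blk_N, blk_K), cute::GenColMajor{}); // (_128,_32):(_1,_128)
auto const smem_layout_C = cute::make_layout(
    cute::make_shape(blk_M, blk_N), cute::GenColMajor{}); // assumed (_128,_128):(_1,_128)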
The shared memory configuration has to be compatible with the tiled MMA we configure later. Otherwise, the tiled MMA will not be able to process the MMA problem correctly, because the cute::gemm API takes no predicates (for performance reasons).
MMA Atom Configuration
The MMA atom processes an MMA problem of size $M^{\prime} \times N^{\prime} \times K^{\prime}$ using a certain number of threads.
In our case, the MMA atom cute::SM80_16x8x16_F16F16F16F16_TN processes an MMA problem of size $M^{\prime} \times N^{\prime} \times K^{\prime} = 16 \times 8 \times 16$, and this MMA atom consists of a warp of 32 threads. The MMA atom is responsible for processing the $M^{\prime} \times K^{\prime} = 16 \times 16$ sub-matrix of matrix $A$ and the $K^{\prime} \times N^{\prime} = 16 \times 8$ sub-matrix of matrix $B$.
To process the MMA problem of size $M \times N \times K = 128 \times 128 \times 32$, theoretically we could tile the MMA atoms in one of the following configurations:
One single MMA atom processes $\frac{M}{M^{\prime}} \times \frac{N}{N^{\prime}} \times \frac{K}{K^{\prime}} = \frac{128}{16} \times \frac{128}{8} \times \frac{32}{16} = 8 \times 16 \times 2 = 256$ times.
$\frac{M}{M^{\prime}} \times \frac{N}{N^{\prime}} = \frac{128}{16} \times \frac{128}{8} = 8 \times 16 = 128$ MMA atoms to process in parallel, each MMA atom processes $\frac{K}{K^{\prime}} = \frac{32}{16} = 2$ times.
Something between the above two configurations.
Note that we did not configure parallelism in the $K$ dimension in the above configurations. But theoretically it’s also possible to do it, especially when $\frac{K}{K^{\prime}}$ is large.
The MMA atom also defines the layouts of the MMA matrices it works on. The thread-value layouts of the MMA matrices are usually very complicated. Fortunately, we can usually visualize them using the CuTe cute::print_latex or cute::print_svg functions.
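As a minimal sketch (not part of the preview program), a single-atom tiled MMA can be constructed from the MMA operation and printed; the emitted LaTeX can then be compiled into a figure of the thread-value layouts.

// Hedged sketch: emit a LaTeX/TikZ figure of the atom's A, B, and C thread-value layouts.
cute::print_latex(cute::make_tiled_mma(cute::SM80_16x8x16_F16F16F16F16_TN{}));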
In our case, visualizing the thread-value layouts defined by the MMA atom cute::SM80_16x8x16_F16F16F16F16_TN shows that, per atom execution, each thread accesses $2 \times 2 \times 2 = 8$ elements of the sub-matrix of matrix $A$ and $2 \times 2 = 4$ elements of the sub-matrix of matrix $B$, and produces $2 \times 2 = 4$ elements of the sub-matrix of matrix $C$.
Tiled MMA Configuration
The MMA tile configuration is a trade-off between resources and performance. Using more MMA atoms increases parallelism, at the cost of more threads and higher memory access pressure; using fewer MMA atoms saves threads and relieves memory access pressure, at the cost of lower parallelism. To achieve the best performance, we need to find the sweet spot between the two extremes.
// Configure tiled MMA.
using MmaTraits = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
using MmaAtomShape = MmaTraits::Shape_MNK;
auto const mma_atom = cute::MMA_Atom<MmaTraits>{};
auto const mma_atom_shape = MmaAtomShape{};
// Repeating the mma atom along the M, N, and K dimensions.
// This increases the number of threads to process the tiled MMA.
constexpr int MMA_LAYOUT_M{2};
constexpr int MMA_LAYOUT_N{2};
constexpr int MMA_LAYOUT_K{1};
auto mma_layout{cute::make_layout(
    cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{},
                     cute::Int<MMA_LAYOUT_K>{}))};
// Repeating the mma processing along the M, N, and K dimensions.
// This does not increase the number of threads to process the tiled MMA.
// But the number of registers required for processing the tiled MMA
// increases.
constexpr int NUM_MMA_TILE_M{1};
constexpr int NUM_MMA_TILE_N{2};
constexpr int NUM_MMA_TILE_K{1};
constexpr int MMA_TILE_M{cute::get<0>(mma_atom_shape) * MMA_LAYOUT_M *
                         NUM_MMA_TILE_M};
constexpr int MMA_TILE_N{cute::get<1>(mma_atom_shape) * MMA_LAYOUT_N *
                         NUM_MMA_TILE_N};
constexpr int MMA_TILE_K{cute::get<2>(mma_atom_shape) * MMA_LAYOUT_K *
                         NUM_MMA_TILE_K};
auto mma_tile{cute::make_tile(cute::Int<MMA_TILE_M>{},
                              cute::Int<MMA_TILE_N>{},
                              cute::Int<MMA_TILE_K>{})};
auto tiled_mma{cute::make_tiled_mma(mma_atom, mma_layout, mma_tile)};
In our tiled MMA configuration, we have 2 MMA atoms along the $M$ dimension, 2 MMA atoms along the $N$ dimension, and 1 MMA atom along the $K$ dimension, i.e., no parallelism in the $K$ dimension. Therefore, we have $2 \times 2 \times 1 = 4$ MMA atoms in total. Because each MMA atom consists of a warp of 32 threads, the tiled MMA has ThrLayoutVMNK = (_32,_2,_2,_1):(_1,_32,_64,_0), and the number of threads involved is $32 \times 2 \times 2 \times 1 = 128$. In one pass, this tiled MMA solves an MMA problem of size $(2 \times M^{\prime}) \times (2 \times N^{\prime}) \times (1 \times K^{\prime}) = 32 \times 16 \times 16$. The same tiled MMA can be repeated along different dimensions to solve larger MMA problems. In our case, we configured the tiled MMA to repeat 2 times along the $N$ dimension. As a result, with such a permutation, the tiled MMA has PermutationMNK: (_32,_32,_16) and solves an MMA problem of size $32 \times 32 \times 16$.
The tiled MMA layouts for the MMA matrices can also be visualized using the CuTe cute::print_latex or cute::print_svg functions.
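As a small, hedged sketch, the same information can be inspected directly from the preview program; the exact text format of cute::print may vary across CUTLASS versions.

// Hedged sketch: inspect the tiled MMA configured above.
cute::print(tiled_mma);        // shows ThrLayoutVMNK and PermutationMNK
cute::print_latex(tiled_mma);  // emits a LaTeX figure of the tiled A, B, and C layouts
CUTE_STATIC_ASSERT_V(cute::size(tiled_mma) == cute::Int<128>{}); // 128 threads in total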
Tiled MMA Memory Copy Partition
The tiled MMA can then be used as the building block for solving MMA problems of even larger sizes. Given large MMA matrix tensors, the tiled MMA can be decomposed into thread MMAs, and each thread MMA provides the methods partition_A, partition_B, and partition_C to partition the MMA matrix tensors into the per-thread sub-tensors that the tiled MMA processes. The partitioned MMA matrix tensors are then used as the input to the tiled MMA, and larger MMA problems are solved by repeating the tiled MMA along different dimensions via the cute::gemm API.
auto thread_mma{tiled_mma.get_slice(THREAD_IDX)};
// Use no tiled copy from shared memory to register.
auto thread_layout_C_smem_tensor_A_no_tiled_copy{
    thread_mma.partition_A(smem_tensor_A)}; // (MMA, MMA_M, MMA_K)
auto thread_layout_C_smem_tensor_B_no_tiled_copy{
    thread_mma.partition_B(smem_tensor_B)}; // (MMA, MMA_N, MMA_K)
auto thread_layout_C_smem_tensor_C_no_tiled_copy{
    thread_mma.partition_C(smem_tensor_C)}; // (MMA, MMA_M, MMA_N)
In our case, the tiled MMA partitions the matrix A, matrix B, and matrix C as follows:
thread_layout_C_smem_tensor_A_no_tiled_copy
ptr[16b](0x5c6c073e78c0) o ((_2,_2,_2),_4,_2):((_128,_8,_1024),_32,_2048)
thread_layout_C_smem_tensor_B_no_tiled_copy
ptr[16b](0x5c6c073e98d0) o ((_2,_2),_8,_2):((_128,_1024),_16,_2048)
thread_layout_C_smem_tensor_C_no_tiled_copy
ptr[16b](0x5c6c073eb8e0) o ((_2,_2),_4,_8):((_128,_8),_32,_2048)
Note that, again, per atom execution, each thread in the MMA atom accesses $2 \times 2 \times 2 = 8$ elements of the sub-matrix of matrix $A$ and $2 \times 2 = 4$ elements of the sub-matrix of matrix $B$, and produces $2 \times 2 = 4$ elements of the sub-matrix of matrix $C$. Such patterns are repeated $4$ times along the $M$ dimension and $2$ times along the $K$ dimension for matrix $A$, $8$ times along the $N$ dimension and $2$ times along the $K$ dimension for matrix $B$, and $4$ times along the $M$ dimension and $8$ times along the $N$ dimension for matrix $C$. The cute::gemm API is responsible for accessing the desired data with the correct indices across these MMA iterations.
Instead of accessing data from global or shared memory, sometimes we would like to access data from registers for data reuse and better performance. The thread MMA also provides the methods partition_fragment_A, partition_fragment_B, and partition_fragment_C, which allocate the minimum number of registers needed in each thread for the tiled MMA operations.
// Register tensors used for MMA.
auto thread_layout_C_register_tensor_A{
    thread_mma.partition_fragment_A(smem_tensor_A)}; // (MMA, MMA_M, MMA_K)
auto thread_layout_C_register_tensor_B{
    thread_mma.partition_fragment_B(smem_tensor_B)}; // (MMA, MMA_N, MMA_K)
auto thread_layout_C_register_tensor_C{
    thread_mma.partition_fragment_C(smem_tensor_C)}; // (MMA, MMA_M, MMA_N)
We can see that the shapes of the register MMA tensors are exactly the same as the shapes of the corresponding tensors in shared memory or global memory. However, because their strides are compact, unlike those of the shared or global memory tensors, no registers are wasted. This is important because the number of registers per thread is limited, and performance can be compromised if too many registers are configured or if registers are wasted (the compiler might not be able to identify the wasted registers and optimize them out).
thread_layout_C_register_tensor_A
ptr[16b](0x7ffc34e465f0) o ((_2,_2,_2),_4,_2):((_1,_2,_4),_8,_32)
thread_layout_C_register_tensor_B
ptr[16b](0x7ffc34e46670) o ((_2,_2),_8,_2):((_1,_2),_4,_32)
thread_layout_C_register_tensor_C
ptr[16b](0x7ffc34e466f0) o ((_2,_2),_4,_8):((_1,_2),_4,_16)
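As a hedged sketch (for a real device kernel rather than this host-only preview), the no-tiled-copy path would zero the accumulator fragment once, copy the partitioned shared memory tensors into the register fragments with plain elementwise copies, and then issue the MMA with cute::gemm.

// Hedged sketch of the no-tiled-copy path in a device kernel.
cute::clear(thread_layout_C_register_tensor_C); // zero the accumulators once
cute::copy(thread_layout_C_smem_tensor_A_no_tiled_copy,
           thread_layout_C_register_tensor_A); // elementwise smem -> register copy
cute::copy(thread_layout_C_smem_tensor_B_no_tiled_copy,
           thread_layout_C_register_tensor_B);
cute::gemm(tiled_mma,
           thread_layout_C_register_tensor_A,
           thread_layout_C_register_tensor_B,
           thread_layout_C_register_tensor_C); // C += A * B over all MMA iterations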
Tiled MMA Memory Tiled Copy Partition
In this case, there are performance issues when the threads try to access the data in shared memory or global memory. When each thread in the MMA atom accesses its $2 \times 2 \times 2 = 8$ elements of the sub-matrix of matrix $A$, its $2 \times 2 = 4$ elements of the sub-matrix of matrix $B$, and its $2 \times 2 = 4$ elements of the sub-matrix of matrix $C$, multiple memory transactions have to be performed, because those elements are not contiguous in memory (no two of them are adjacent).
CUDA has a special warp-level matrix load instruction, ldmatrix, that specifically addresses this problem, and it is wrapped by the CuTe copy atoms. In our case, for the cute::half_t data type, the copy atom we used is cute::SM75_U16x8_LDSM_T for both matrix $A$ and matrix $B$. This copy atom consists of a warp of 32 threads. From its ValLayoutSrc: (_32,_8):(_8,_1), we learn that, at the abstraction level, each thread copies $8$ contiguous elements from the source memory to the destination memory. It is only an abstraction because, under the hood, ldmatrix does not work this way. In total, one copy atom execution copies $32 \times 8 = 256$ elements.
The copy atoms can be combined with the tiled MMA to build tiled copies for the matrix $A$ and matrix $B$ sub-matrices using the cute::make_tiled_copy_A and cute::make_tiled_copy_B functions. In our case, because the tiled MMA (with the permutation) solves an MMA problem of size $32 \times 32 \times 16$ (rather than $32 \times 16 \times 16$), the tiled copy is responsible for a tile of size $32 \times 16$ for matrix $A$ and $32 \times 16$ for matrix $B$. In fact, without the permutation in the tiled MMA, the tile for matrix $B$ would shrink to $16 \times 16$, so each of the two warps along the $N$ dimension would cover only $8 \times 16 = 128$ elements of matrix $B$, fewer than the $256$ elements that one cute::SM75_U16x8_LDSM_T copy atom copies per warp, and CuTe would reject such a tiled copy.
// Use tiled copy from shared memory to register.
auto copy_atom_A = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, TA>{};
auto copy_atom_B = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, TB>{};

auto smem_tiled_copy_A{cute::make_tiled_copy_A(copy_atom_A, tiled_mma)};
auto smem_tiled_copy_B{cute::make_tiled_copy_B(copy_atom_B, tiled_mma)};
The tiled copy layouts can be printed using the cute::print_latex function as well. In our particular case, the left (source) layout is actually misleading: the value mapping between the left source layout and the right destination layout is incorrect. For example, the $T0V1$ value in the left layout will not be copied to the $T0V1$ value in the right layout. The left source layout shows each thread copying $8$ elements from the same column because the threads need to be located at the beginning of the columns so that the column addresses can be passed correctly to ldmatrix.
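As a hedged sketch, the figures can be generated from the preview program like this.

// Hedged sketch: emit LaTeX figures of the tiled copies for matrix A and matrix B.
cute::print_latex(smem_tiled_copy_A); // left: source layout, right: destination layout
cute::print_latex(smem_tiled_copy_B);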
The tiled copy can be decomposed into (abstracted) thread copies, and each thread copy provides the method partition_S to partition the source tensor. Similarly, the thread copy source layout is misleading, as one thread will not actually load $8$ contiguous elements from the source memory when using the cute::SM75_U16x8_LDSM_T copy atom. But the CuTe tiled copy abstraction will at least ensure that the resulting copy behavior is correct.
thread_layout_C_smem_tensor_A_tiled_copy
ptr[16b](0x57b7b93248c0) o ((_8,_1),_4,_2):((_1,_0),_32,_2048)
thread_layout_C_smem_tensor_B_tiled_copy
ptr[16b](0x57b7b93268d0) o ((_8,_1),_4,_2):((_1,_0),_32,_2048)
To perform the tiled copy, there is still one problem: the layouts of the destination register tensors used for the tiled MMA are not immediately compatible with the source tensor layouts.
thread_layout_C_register_tensor_A
ptr[16b](0x7ffc34e465f0) o ((_2,_2,_2),_4,_2):((_1,_2,_4),_8,_32)
thread_layout_C_register_tensor_B
ptr[16b](0x7ffc34e46670) o ((_2,_2),_8,_2):((_1,_2),_4,_32)
However, we realize that the sub-layout of the destination register tensor A, (_2,_2,_2):(_1,_2,_4), is equivalent to the sub-layout of the source shared memory tensor A, (_8,_1):(_1,_0). Similarly, the sub-layout of the destination register tensor B, ((_2,_2),_8):((_1,_2),_4), is equivalent to the sub-layout of the source shared memory tensor B, ((_8,_1),_4):((_1,_0),_8). So before running the tiled copy using the cute::copy API, the destination register tensors should be retiled using the retile_D method of the thread copy.
After retiling the destination register tensors, the layouts become compatible with the source shared memory tensors for tiled copy.
thread_layout_C_register_tensor_A_copy_view
ptr[16b](0x7ffd7a129350) o ((_8,_1),_4,_2):((_1,_0),_8,_32)
thread_layout_C_register_tensor_B_copy_view
ptr[16b](0x7ffd7a1293d0) o ((_8,_1),_4,_2):((_1,_0),_8,_32)
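With the retiled copy views, the shared-memory-to-register loads can be issued with the cute::copy API. A hedged sketch (for a real device kernel rather than this host-only preview):

// Hedged sketch: ldmatrix-based loads from shared memory into the register fragments.
cute::copy(smem_tiled_copy_A, thread_layout_C_smem_tensor_A_tiled_copy,
           thread_layout_C_register_tensor_A_copy_view);
cute::copy(smem_tiled_copy_B, thread_layout_C_smem_tensor_B_tiled_copy,
           thread_layout_C_register_tensor_B_copy_view);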
Without tiled copy, loading the $8$ elements of matrix $A$ needed by each thread for the tiled MMA takes $8$ memory access instructions. With tiled copy, only $1$ memory access instruction is needed. Thus, tiled copy from shared memory to registers can usually improve performance.
Of course, when the tiled copy is performed, there could be shared memory bank conflicts. We could try minimizing the shared memory bank conflicts using CuTe swizzle.
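As a hedged sketch (the swizzle parameters are illustrative and would need to be tuned for the actual kernel), a swizzled shared memory layout can be built by composing a cute::Swizzle with the plain column-major layout.

// Hedged sketch: compose a swizzle with the column-major layout to reduce bank conflicts.
auto const swizzled_smem_layout_A = cute::composition(
    cute::Swizzle<3, 3, 3>{},
    cute::make_layout(cute::make_shape(blk_M, blk_K), cute::GenColMajor{}));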