CuTe is a C++ template library that provides a high-level abstraction for layout and tensor operations in CUDA kernels. CUTLASS 3.0 and later versions adopt CuTe throughout the GEMM hierarchy in their templates, making the implementations more readable and maintainable.
Previously, I wrote an article, “CuTe Layout Algebra”, on the mathematical foundations of CuTe. In this blog post, we will get some hands-on experience with CuTe and develop a better understanding of it by implementing matrix transpose CUDA kernels.
CuTe Matrix Transpose
Matrix transpose is probably the example CUDA kernel that I have implemented most often. In my previous examples, the thread and data index mappings in the CUDA kernels were completely manual. There were also some hard-coded assumptions, such as each CUDA thread processing only one element, to keep the implementation simple and human-readable. To support more complex configurations while keeping the implementation human-readable, we can create matrix transpose CUDA kernels using CuTe.
To transpose a matrix in a CUDA kernel, strided memory reads or writes within a warp are inevitable, and they lead to uncoalesced memory accesses and performance degradation. To mitigate this, the strided memory reads or writes can be performed on shared memory instead of global memory. When they are performed on shared memory, special optimizations are also required to avoid shared memory bank conflicts.
All the CuTe matrix transpose CUDA kernels implemented in this article and their unit tests can be found in my CUTLASS Examples GitHub repository.
CuTe Naive Matrix Transpose
In the CuTe naive matrix transpose CUDA kernel implementation, we will not use shared memory. Two slightly different CUDA kernel variants have been implemented. One performs coalesced global memory reads and strided global memory writes, and the other performs strided global memory reads and coalesced global memory writes. It turns out that the difference between the implementations of the two variants is just one line of code.
CuTe Naive Matrix Transpose Implementation
The CuTe naive matrix transpose CUDA kernel implementation can also be found in my CUTLASS Examples GitHub repository.
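template <class TENSOR_SRC, class TENSOR_DST, class THREAD_LAYOUT>
__global__ void transpose_naive(TENSOR_SRC tensor_src,
                                TENSOR_DST tensor_dst_transposed,
                                THREAD_LAYOUT)
{
    using Element = typename TENSOR_SRC::value_type;

    auto global_tile_src{tensor_src(cute::make_coord(cute::_, cute::_),
                                    blockIdx.y,
                                    blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)
    auto global_tile_dst_transposed{
        tensor_dst_transposed(cute::make_coord(cute::_, cute::_), blockIdx.y,
                              blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)

    auto thread_global_tile_src{cute::local_partition(
        global_tile_src, THREAD_LAYOUT{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)
    auto thread_global_tile_dst_transposed{cute::local_partition(
        global_tile_dst_transposed, THREAD_LAYOUT{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)

    // A 2D array of tuples that maps (x, y) to (x, y).
    auto const identity_tensor{cute::make_identity_tensor(cute::make_shape(
        cute::size<0>(global_tile_src), cute::size<1>(global_tile_src)))};
    auto const thread_identity_tensor{
        cute::local_partition(identity_tensor, THREAD_LAYOUT{}, threadIdx.x)};
    auto fragment{cute::make_tensor_like(thread_global_tile_src)};
    auto predicator{cute::make_tensor<bool>(
        cute::make_shape(cute::size<0>(fragment), cute::size<1>(fragment)))};

    // Query the matrix sizes from the strides of the tiles:
    // the source tile is (TILE_SIZE_Y, TILE_SIZE_X) : (N, 1) and
    // the transposed destination tile is (TILE_SIZE_Y, TILE_SIZE_X) : (1, M).
    auto const num_max_columns{cute::stride<0>(global_tile_src)};         // N
    auto const num_max_rows{cute::stride<1>(global_tile_dst_transposed)}; // M

    CUTE_UNROLL
    for (unsigned int i{0}; i < cute::size<0>(predicator); ++i)
    {
        CUTE_UNROLL
        for (unsigned int j{0}; j < cute::size<1>(predicator); ++j)
        {
            auto const thread_identity{thread_identity_tensor(i, j)};
            bool const is_row_in_bound{
                cute::get<0>(thread_identity) +
                    blockIdx.y * cute::size<0>(global_tile_src) <
                num_max_rows};
            bool const is_column_in_bound{
                cute::get<1>(thread_identity) +
                    blockIdx.x * cute::size<1>(global_tile_src) <
                num_max_columns};
            predicator(i, j) = is_row_in_bound && is_column_in_bound;
        }
    }

    cute::copy_if(predicator, thread_global_tile_src, fragment);
    cute::copy_if(predicator, fragment, thread_global_tile_dst_transposed);

    // Alternatively, we could just do the following instead.
    // cute::copy_if(predicator, thread_global_tile_src,
    //               thread_global_tile_dst_transposed);
}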
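template <typename T>
cudaError_t launch_transpose_naive_base(
    T const* input_matrix, T* output_matrix, unsigned int M, unsigned int N,
    GlobalMemoryCoalescedAccessMode coalesced_access_mode, cudaStream_t stream)
{
    auto const tensor_shape{cute::make_shape(M, N)};
    auto const tensor_shape_transposed{cute::make_shape(N, M)};

    // Input matrix: row-major M x N matrix.
    auto const global_memory_layout_src{cute::make_layout(
        tensor_shape, cute::GenRowMajor{})}; // (M, N) : (N, 1)
    // Output matrix: row-major N x M matrix.
    auto const global_memory_layout_dst{cute::make_layout(
        tensor_shape_transposed, cute::GenRowMajor{})}; // (N, M) : (M, 1)
    // Same output matrix, but different view: column-major M x N matrix.
    auto const global_memory_layout_dst_transposed{cute::make_layout(
        tensor_shape, cute::GenColMajor{})}; // (M, N) : (1, M)

    auto const tensor_src{cute::make_tensor(cute::make_gmem_ptr(input_matrix),
                                            global_memory_layout_src)};
    auto const tensor_dst_transposed{
        cute::make_tensor(cute::make_gmem_ptr(output_matrix),
                          global_memory_layout_dst_transposed)};

    using TILE_SIZE_X = cute::Int<64>; // bN
    using TILE_SIZE_Y = cute::Int<32>; // bM
    constexpr auto block_shape{cute::make_shape(TILE_SIZE_Y{}, TILE_SIZE_X{})};

    auto const tiled_tensor_src{cute::tiled_divide(
        tensor_src, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
    auto const tiled_tensor_dst_transposed{cute::tiled_divide(
        tensor_dst_transposed, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)

    using THREAD_BLOCK_SIZE_X = cute::Int<32>; // tN
    using THREAD_BLOCK_SIZE_Y = cute::Int<8>;  // tM
    // Row-major (tM, tN) thread layout: coalesced global memory reads.
    constexpr auto thread_layout{cute::make_layout(
        cute::make_shape(THREAD_BLOCK_SIZE_Y{}, THREAD_BLOCK_SIZE_X{}),
        cute::GenRowMajor{})};
    // Column-major (tN, tM) thread layout: coalesced global memory writes.
    constexpr auto thread_layout_transposed{cute::make_layout(
        cute::make_shape(THREAD_BLOCK_SIZE_X{}, THREAD_BLOCK_SIZE_Y{}),
        cute::GenColMajor{})};

    // blockIdx.x walks the tile columns and blockIdx.y walks the tile rows.
    dim3 const grid_dim{cute::size<2>(tiled_tensor_src),
                        cute::size<1>(tiled_tensor_src)};
    dim3 const thread_dim{THREAD_BLOCK_SIZE_X::value *
                          THREAD_BLOCK_SIZE_Y::value};

    if (coalesced_access_mode == GlobalMemoryCoalescedAccessMode::Read)
    {
        CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_X::value == 0,
                           "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_X");
        CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_Y::value == 0,
                           "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_Y");
        transpose_naive<<<grid_dim, thread_dim, 0, stream>>>(
            tiled_tensor_src, tiled_tensor_dst_transposed, thread_layout);
    }
    else
    {
        CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_Y::value == 0,
                           "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_Y");
        CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_X::value == 0,
                           "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_X");
        transpose_naive<<<grid_dim, thread_dim, 0, stream>>>(
            tiled_tensor_src, tiled_tensor_dst_transposed,
            thread_layout_transposed);
    }

    return cudaGetLastError();
}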
template <typename T>
cudaError_t launch_transpose_naive_coalesced_read(T const* input_matrix,
                                                  T* output_matrix,
                                                  unsigned int M,
                                                  unsigned int N,
                                                  cudaStream_t stream)
{
    return launch_transpose_naive_base(input_matrix, output_matrix, M, N,
                                       GlobalMemoryCoalescedAccessMode::Read,
                                       stream);
}

template <typename T>
cudaError_t launch_transpose_naive_coalesced_write(T const* input_matrix,
                                                   T* output_matrix,
                                                   unsigned int M,
                                                   unsigned int N,
                                                   cudaStream_t stream)
{
    return launch_transpose_naive_base(input_matrix, output_matrix, M, N,
                                       GlobalMemoryCoalescedAccessMode::Write,
                                       stream);
}

// Explicit instantiation.
template cudaError_t launch_transpose_naive_coalesced_read<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_naive_coalesced_read<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);

template cudaError_t launch_transpose_naive_coalesced_write<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_naive_coalesced_write<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
Matrix Layout
There are typically two ways to describe a matrix stored in linear storage: row-major and column-major. In row-major layout, the elements of each row are contiguous in memory, while in column-major layout, the elements of each column are contiguous in memory. Given an $M \times N$ input matrix $A$ with $M$ rows and $N$ columns in row-major layout, we typically want to transpose it to an $N \times M$ output matrix $A^{\top}$ with $N$ rows and $M$ columns that is also in row-major layout. In this case, the output matrix $A^{\top}$ can be viewed not only as a row-major $N \times M$ matrix but also as a column-major $M \times N$ matrix.
The matrix transpose operation maps the element $A_{i, j}$ of the input matrix $A$ to the element $A^{\top}_{j, i}$ of the output matrix $A^{\top}$. On one hand, if the matrices are described using row-major layouts, the input matrix $A$ is of shape $(M, N)$, and the element $A_{i, j}$ is stored at the coordinate $(i, j)$ in the row-major layout of $A$. The output matrix $A^{\top}$ is of shape $(N, M)$, and the element $A^{\top}_{j, i}$ is stored at the coordinate $(j, i)$ in the row-major layout of $A^{\top}$. On the other hand, if the matrices are described using column-major layouts, the same input matrix $A$ is of shape $(N, M)$, and the element $A_{i, j}$ is stored at the coordinate $(j, i)$ in the column-major layout of $A$. The output matrix $A^{\top}$ is of shape $(M, N)$, and the element $A^{\top}_{j, i}$ is stored at the coordinate $(i, j)$ in the column-major layout of $A^{\top}$.
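Writing each description out as an offset function, using the strides of the corresponding row-major or column-major layout, makes this concrete. Each pair of lines below addresses the same memory location:

```latex
\begin{aligned}
A \text{ as row-major } (M, N) : (N, 1): &\quad (i, j) \mapsto i \cdot N + j \\
A \text{ as column-major } (N, M) : (1, N): &\quad (j, i) \mapsto j + i \cdot N \\
A^{\top} \text{ as row-major } (N, M) : (M, 1): &\quad (j, i) \mapsto j \cdot M + i \\
A^{\top} \text{ as column-major } (M, N) : (1, M): &\quad (i, j) \mapsto i + j \cdot M
\end{aligned}
```

Comparing the first and the last lines, the same coordinate $(i, j)$ addresses $A_{i, j}$ in the row-major view of the input and $A^{\top}_{j, i}$ in the column-major view of the output.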
Although it is a little brain-twisting, matrix transpose maps the element at the coordinate $(i, j)$ in the row-major layout of the input matrix to the element at the coordinate $(i, j)$ in the column-major layout of the output matrix. In CuTe, if the input matrix and the output matrix both have a shape of $(M, N)$, a given 1D coordinate is always mapped to the same natural coordinate in both of them. Therefore, when CuTe iterates over the $M \times N$ 1D coordinates, the corresponding elements of the input matrix and the output matrix are transposes of each other. This is the key reason why we describe the input matrix and the output matrix using a row-major layout and a column-major layout, respectively, in CuTe. If the input matrix and the output matrix were both described using the same layout, the corresponding elements visited during the iteration would not be transposes of each other, and the operation would be a plain copy.
// Input matrix: row-major M x N matrix.
auto const global_memory_layout_src{cute::make_layout(
    tensor_shape, cute::GenRowMajor{})}; // (M, N) : (N, 1)
// Output matrix: row-major N x M matrix.
auto const global_memory_layout_dst{cute::make_layout(
    tensor_shape_transposed, cute::GenRowMajor{})}; // (N, M) : (M, 1)
// Same output matrix, but different view: column-major M x N matrix.
auto const global_memory_layout_dst_transposed{cute::make_layout(
    tensor_shape, cute::GenColMajor{})}; // (M, N) : (1, M)
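The following host-side C++ sketch illustrates this argument without CuTe. `Layout2D` is a hypothetical stand-in for a rank-2 CuTe layout, not part of the CuTe API: it maps a natural coordinate $(i, j)$ to the offset $i \cdot d_0 + j \cdot d_1$. Copying element $(i, j)$ of the row-major $(M, N) : (N, 1)$ view to element $(i, j)$ of the column-major $(M, N) : (1, M)$ view should produce the row-major $N \times M$ transpose.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a rank-2 layout (shape0, shape1) : (stride0, stride1).
struct Layout2D
{
    std::size_t shape0, shape1;
    std::size_t stride0, stride1;
    // Natural coordinate (i, j) -> linear memory offset.
    std::size_t operator()(std::size_t i, std::size_t j) const
    {
        return i * stride0 + j * stride1;
    }
};

// Copy element (i, j) of the source view to element (i, j) of the destination
// view for every coordinate of the common shape (first mode fastest).
template <typename T>
void copy_by_coordinate(std::vector<T> const& src, Layout2D const& src_layout,
                        std::vector<T>& dst, Layout2D const& dst_layout)
{
    for (std::size_t j{0}; j < src_layout.shape1; ++j)
        for (std::size_t i{0}; i < src_layout.shape0; ++i)
            dst[dst_layout(i, j)] = src[src_layout(i, j)];
}

// Transpose a row-major M x N matrix into a row-major N x M matrix by viewing
// the output as a column-major M x N matrix.
template <typename T>
std::vector<T> transpose_by_views(std::vector<T> const& input, std::size_t M,
                                  std::size_t N)
{
    Layout2D const src_layout{M, N, N, 1};            // (M, N) : (N, 1)
    Layout2D const dst_transposed_layout{M, N, 1, M}; // (M, N) : (1, M)
    std::vector<T> output(M * N);
    copy_by_coordinate(input, src_layout, output, dst_transposed_layout);
    return output;
}
```

For a $2 \times 3$ input `{0, 1, 2, 3, 4, 5}`, `transpose_by_views(input, 2, 3)` returns `{0, 3, 1, 4, 2, 5}`. Replacing the destination layout with $(M, N) : (N, 1)$ would turn the same loop into a plain copy.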
To accelerate matrix transpose for large problems, we will have to divide the input matrix and the output matrix into smaller tiles and compute the transpose of each tile in parallel. In this example, we divide the input matrix and the output matrix into tiles of shape $(bM, bN)$, where $bM$ and $bN$ are the number of rows and columns in each tile, respectively. Both the input matrix and the output matrix will be divided into $\left\lceil \frac{M}{bM} \right\rceil \times \left\lceil \frac{N}{bN} \right\rceil$ tiles. The matrix transpose in each tile is independent and can be processed in parallel.
The divided input matrix and the divided output matrix now have new layouts, whose shapes are both $\left((bM, bN), \left\lceil \frac{M}{bM} \right\rceil, \left\lceil \frac{N}{bN} \right\rceil\right)$. The row-major and column-major notations are no longer applicable to the divided matrices, because the shapes now have 3 modes, i.e., a rank of 3. To describe the storage layout of a tensor of any rank, CuTe uses strides. In our particular problem, this detail is not too important because CuTe handles the strides of the divided tensors automatically, but it might matter in other problems.
using TILE_SIZE_X = cute::Int<64>; // bN
using TILE_SIZE_Y = cute::Int<32>; // bM

constexpr auto block_shape{cute::make_shape(TILE_SIZE_Y{}, TILE_SIZE_X{})};

auto const tiled_tensor_src{cute::tiled_divide(
    tensor_src, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
auto const tiled_tensor_dst_transposed{cute::tiled_divide(
    tensor_dst_transposed, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
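To make the tiling arithmetic concrete, the plain C++ sketch below (hypothetical helper names, not CuTe calls) computes the shape produced by tiling an $(M, N)$ matrix with $(bM, bN)$ tiles and maps an element of a tile back to its global coordinate. It assumes, as in the kernel, that `blockIdx.y` selects the tile row and `blockIdx.x` selects the tile column, so the grid would have $\left\lceil \frac{N}{bN} \right\rceil \times \left\lceil \frac{M}{bM} \right\rceil$ blocks. Tiles on the bottom and right edges may extend past the matrix, which is why predicates are needed later.

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t ceil_div(std::size_t a, std::size_t b)
{
    return (a + b - 1) / b;
}

// Shape of the tiled matrix: ((bM, bN), ceil(M / bM), ceil(N / bN)).
struct TiledShape
{
    std::size_t tile_m, tile_n, num_tiles_m, num_tiles_n;
};

constexpr TiledShape tiled_shape(std::size_t M, std::size_t N, std::size_t bM,
                                 std::size_t bN)
{
    return {bM, bN, ceil_div(M, bM), ceil_div(N, bN)};
}

// Global coordinate of element (r, c) inside tile (tile_row, tile_col).
struct Coord
{
    std::size_t row, col;
};

constexpr Coord tiled_to_global(std::size_t r, std::size_t c,
                                std::size_t tile_row, std::size_t tile_col,
                                std::size_t bM, std::size_t bN)
{
    return {tile_row * bM + r, tile_col * bN + c};
}
```

For example, a $100 \times 130$ matrix with $(bM, bN) = (32, 64)$ gives $4 \times 3$ tiles, and element $(5, 7)$ of tile $(2, 1)$ is the global element $(69, 71)$.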
CUDA Thread Block Layout and Coalesced Memory Access
Each tile of the input matrix and the output matrix is processed by a CUDA thread block that consists of multiple CUDA threads. In our case, we use a thread block of shape $(tM, tN)$ with a row-major layout, or of shape $(tN, tM)$ with a column-major layout. The number of threads in a CUDA thread block is $tM \times tN$. The number of CUDA thread blocks to issue is simply the number of tiles, i.e., $\left\lceil \frac{M}{bM} \right\rceil \times \left\lceil \frac{N}{bN} \right\rceil$. This design relies on $bM$, $bN$, $tM$, and $tN$ all being compile-time constants, so that the tile shape and the thread layout are static CuTe types that the kernel can reconstruct, e.g., via THREAD_LAYOUT{}.
Because the input matrix and its tiles are of row-major layout, and the output matrix and its tiles are of column-major layout, when the thread block is of row-major layout, each warp in the thread block will read from the input matrix on global memory in a coalesced fashion but write to the output matrix on global memory in a strided fashion. Similarly, when the thread block is of column-major layout, each warp in the thread block will read from the input matrix on global memory in a strided fashion but write to the output matrix on global memory in a coalesced fashion.
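To see which thread layout yields coalesced accesses, the sketch below emulates how we understand `cute::local_partition` to assign data to threads: thread `tid` gets the coordinate $(a, b)$ for which the thread layout evaluates to `tid`, and after an outer-partition its first element sits at tile coordinate $(a, b)$. The struct and function names are hypothetical. With $(tM, tN) = (8, 32)$, it lists the memory offsets of the first elements touched by the first warp in the source tile (strides $(N, 1)$) and in the column-major destination view (strides $(1, M)$).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// A compact rank-2 thread layout (shape0, shape1) : (stride0, stride1).
struct ThreadLayout
{
    std::size_t shape0, shape1, stride0, stride1;
};

// Coordinate (a, b) of thread `tid`, i.e. the (a, b) with
// a * stride0 + b * stride1 == tid (brute-force inverse of the layout).
std::pair<std::size_t, std::size_t> thread_coord(ThreadLayout const& t,
                                                 std::size_t tid)
{
    for (std::size_t b{0}; b < t.shape1; ++b)
        for (std::size_t a{0}; a < t.shape0; ++a)
            if (a * t.stride0 + b * t.stride1 == tid)
                return {a, b};
    return {t.shape0, t.shape1}; // tid is outside the thread layout
}

// Memory offsets (in elements) of the first value owned by each of the 32
// threads of the first warp, in a tile with strides (tile_stride0, tile_stride1).
std::vector<std::size_t> warp_first_offsets(ThreadLayout const& t,
                                            std::size_t tile_stride0,
                                            std::size_t tile_stride1)
{
    std::vector<std::size_t> offsets;
    for (std::size_t tid{0}; tid < 32; ++tid)
    {
        std::pair<std::size_t, std::size_t> const c{thread_coord(t, tid)};
        offsets.push_back(c.first * tile_stride0 + c.second * tile_stride1);
    }
    return offsets;
}
```

With the row-major thread layout $(8, 32) : (32, 1)$, the 32 source offsets are contiguous while the destination offsets are $M$ elements apart; the column-major thread layout $(32, 8) : (1, 32)$ swaps the two.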
We have already performed an inner-partition when dividing the input matrix and the output matrix into tiles. Inner-partition is usually performed at the CUDA thread block level, distributing a large problem into smaller problems that can each be solved by a single CUDA thread block.
Outer-partition is usually performed at the CUDA thread level, distributing a smaller problem into even smaller problems that can each be solved by a single CUDA thread. The two partitions behave differently, and the implementation cannot work correctly without understanding the difference.
Suppose we have a CuTe layout $(8, 4) : (4, 1)$ and a tile layout $(4, 2) : (2, 1)$. Inner-partition will result in a layout of shape $\left((4, 2), \frac{8}{4}, \frac{4}{2}\right) = \left((4, 2), 2, 2\right)$, and outer-partition will result in a layout of shape $\left(\left(\frac{8}{4}, \frac{4}{2}\right), 4, 2\right) = \left((2, 2), 4, 2\right)$. Assuming the partitions are indexed by the last two modes, the inner-partition layout has 4 partitions whereas the outer-partition layout has 8 partitions. The starting coordinates of the partitions also differ. In this case, the starting coordinates of the inner-partition are $(0, 0)$, $(4, 0)$, $(0, 2)$, and $(4, 2)$, whereas the starting coordinates of the outer-partition are $(0, 0)$, $(1, 0)$, $(2, 0)$, $(3, 0)$, $(0, 1)$, $(1, 1)$, $(2, 1)$, and $(3, 1)$. Outer-partition is usually performed at the CUDA thread level because consecutive threads in a warp then access contiguous data together, which performs better thanks to coalesced memory access, especially on global memory.
In each partition, the partition tensor will follow the layout algebra and apply the correct strides to access the data during iteration.
auto global_tile_src{tensor_src(cute::make_coord(cute::_, cute::_),
                                blockIdx.y,
                                blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)
auto global_tile_dst_transposed{
    tensor_dst_transposed(cute::make_coord(cute::_, cute::_), blockIdx.y,
                          blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)

auto thread_global_tile_src{cute::local_partition(
    global_tile_src, THREAD_LAYOUT{},
    threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)
auto thread_global_tile_dst_transposed{cute::local_partition(
    global_tile_dst_transposed, THREAD_LAYOUT{},
    threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)
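The inner-partition and outer-partition example with the $(8, 4)$ layout and the $(4, 2)$ tile can be checked with a small host-side sketch. The helper functions are hypothetical and only model the coordinate arithmetic of the two partitions, enumerating partitions in CuTe's colexicographic order (first mode fastest).

```cpp
#include <cassert>
#include <utility>
#include <vector>

using Coord = std::pair<int, int>;

// Inner-partition: partition (p, q) is the contiguous block
// [p * tM, (p + 1) * tM) x [q * tN, (q + 1) * tN); return its starting coordinate.
std::vector<Coord> inner_partition_starts(int M, int N, int tM, int tN)
{
    std::vector<Coord> starts;
    for (int q{0}; q < N / tN; ++q)
        for (int p{0}; p < M / tM; ++p)
            starts.push_back({p * tM, q * tN});
    return starts;
}

// Outer-partition: partition (a, b) starts at (a, b); there are tM x tN of them.
std::vector<Coord> outer_partition_starts(int tM, int tN)
{
    std::vector<Coord> starts;
    for (int b{0}; b < tN; ++b)
        for (int a{0}; a < tM; ++a)
            starts.push_back({a, b});
    return starts;
}

// Elements owned by outer-partition (a, b): strided by the tile shape.
std::vector<Coord> outer_partition_elements(int M, int N, int tM, int tN, int a,
                                            int b)
{
    std::vector<Coord> elements;
    for (int k{0}; k < N / tN; ++k)
        for (int i{0}; i < M / tM; ++i)
            elements.push_back({a + tM * i, b + tN * k});
    return elements;
}
```

The outer-partition $(1, 1)$, for instance, owns the elements $(1, 1)$, $(5, 1)$, $(1, 3)$, and $(5, 3)$, which is the strided pattern that lets consecutive threads touch neighboring elements.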
Predicates and Boundary Checking
CUDA memory access boundary checking is critical in practice when the problem size is not evenly divisible by the tile size. In CuTe, boundary checking is performed via predicates. In our particular case, we can query the matrix sizes $M$ and $N$ from the strides of the tiled tensors. It is also common to simply pass these two values to the CUDA kernel.
During the iteration over the CuTe tensor, each thread has to know the 2D coordinate of the element it is about to access and check whether that coordinate is within the boundary. So we create a 2D identity tensor (actually a 2D array of tuples) whose shape is exactly the same as the tile tensor from global memory. The 2D identity tensor takes a 2D coordinate as input and produces the same 2D coordinate as output. If the tile tensor and the identity tensor are iterated together, the iterator knows its current coordinate within the tile tensor, which makes boundary checking possible. At the CUDA thread level, the 2D identity tensor is further partitioned into a 2D thread identity tensor according to the same thread layout that is used for partitioning the data. The predicates for accessing the partitioned input and output tensors can then be prepared from the 2D coordinates in the thread identity tensor, the tile index (i.e., the thread block index), the tile shape, and the shape of the original full tensor.
// A 2D array of tuples that maps (x, y) to (x, y).
auto const identity_tensor{cute::make_identity_tensor(cute::make_shape(
    cute::size<0>(global_tile_src), cute::size<1>(global_tile_src)))};
auto const thread_identity_tensor{
    cute::local_partition(identity_tensor, THREAD_LAYOUT{}, threadIdx.x)};
auto fragment{cute::make_tensor_like(thread_global_tile_src)};
auto predicator{cute::make_tensor<bool>(
    cute::make_shape(cute::size<0>(fragment), cute::size<1>(fragment)))};
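The predicate computation can be emulated on the host. The sketch below assumes a row-major $(tM, tN)$ thread layout, as in the coalesced read variant, so thread `tid` owns the tile coordinates $(\text{tid} / tN + tM \cdot i, \text{tid} \bmod tN + tN \cdot k)$. The function name and parameters are hypothetical; each identity coordinate inside the tile is offset by the tile origin `(block_y * bM, block_x * bN)` and compared with the matrix size $(M, N)$.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-thread predicates for one (bM, bN) tile, enumerated first mode fastest.
std::vector<bool> thread_predicates(std::size_t M, std::size_t N,
                                    std::size_t bM, std::size_t bN,
                                    std::size_t tM, std::size_t tN,
                                    std::size_t block_y, std::size_t block_x,
                                    std::size_t tid)
{
    // Thread coordinate in the row-major (tM, tN) : (tN, 1) thread layout.
    std::size_t const a{tid / tN};
    std::size_t const b{tid % tN};
    std::vector<bool> predicates;
    for (std::size_t k{0}; k < bN / tN; ++k)
        for (std::size_t i{0}; i < bM / tM; ++i)
        {
            // Identity-tensor coordinate of this value inside the tile...
            std::size_t const row_in_tile{a + tM * i};
            std::size_t const col_in_tile{b + tN * k};
            // ...shifted by the tile origin gives the global coordinate.
            std::size_t const row{block_y * bM + row_in_tile};
            std::size_t const col{block_x * bN + col_in_tile};
            predicates.push_back(row < M && col < N);
        }
    return predicates;
}
```

For a $40 \times 70$ matrix with $(bM, bN) = (32, 64)$ and $(tM, tN) = (8, 32)$, thread 0 of the bottom-right block owns 8 values, but only the one at global coordinate $(32, 64)$ is in bounds, while every value of the top-left block is in bounds.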
Using predicates and performing boundary checking can degrade CUDA kernel performance, because a warp instruction, such as a load from global memory, has to wait until the predicates of all the threads in the warp have been evaluated. To achieve the best performance, specialized kernels are usually used for each problem configuration so that boundary checking can be eliminated from the CUDA kernel.
In that case, cute::copy should be used instead of cute::copy_if.
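On the host, one simple way to choose between such specialized kernels is to check whether the matrix divides evenly into tiles. The helper below is a hypothetical sketch, not part of the original implementation; it only encodes the condition under which the predicated variant is required.

```cpp
#include <cassert>

// Hypothetical host-side helper: boundary checks (predicates) are only needed
// when the M x N matrix does not divide evenly into bM x bN tiles.
constexpr bool needs_boundary_check(unsigned int M, unsigned int N,
                                    unsigned int bM, unsigned int bN)
{
    return (M % bM != 0) || (N % bN != 0);
}

enum class KernelVariant
{
    Predicated,  // uses cute::copy_if
    Unpredicated // uses cute::copy
};

constexpr KernelVariant select_variant(unsigned int M, unsigned int N,
                                       unsigned int bM, unsigned int bN)
{
    return needs_boundary_check(M, N, bM, bN) ? KernelVariant::Predicated
                                              : KernelVariant::Unpredicated;
}
```

With $(bM, bN) = (32, 64)$, a $4096 \times 8192$ matrix can use the unpredicated kernel, whereas $4001 \times 8192$ cannot.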
CuTe Matrix Transpose Using Shared Memory
In the CuTe matrix transpose CUDA kernel implementation using shared memory, we will perform the strided memory reads or writes on shared memory instead of global memory. Using shared memory naively results in shared memory bank conflicts when performing strided memory reads or writes on it, which degrades performance. To mitigate the shared memory bank conflicts, we will also apply special optimizations, such as shared memory padding and swizzling.
CuTe Matrix Transpose Using Shared Memory Implementation
The CuTe matrix transpose using shared memory CUDA kernel implementation can also be found in my CUTLASS Examples GitHub repository.
template <class TENSOR_SRC, class TENSOR_DST, class SHARED_MEMORY_LAYOUT_SRC,
          class SHARED_MEMORY_LAYOUT_DST, class THREAD_LAYOUT_SRC,
          class THREAD_LAYOUT_DST>
__global__ void transpose_shared_memory(TENSOR_SRC tensor_src,
                                        TENSOR_DST tensor_dst,
                                        SHARED_MEMORY_LAYOUT_SRC,
                                        SHARED_MEMORY_LAYOUT_DST,
                                        THREAD_LAYOUT_SRC, THREAD_LAYOUT_DST)
{
    using Element = typename TENSOR_SRC::value_type;
    CUTE_STATIC_ASSERT(cute::size(SHARED_MEMORY_LAYOUT_SRC{}) ==
                           cute::size(SHARED_MEMORY_LAYOUT_DST{}),
                       "SHARED_MEMORY_LAYOUT_SRC and SHARED_MEMORY_LAYOUT_DST "
                       "must have the same size.");
    __shared__ Element shared_memory[cute::size(SHARED_MEMORY_LAYOUT_SRC{})];

    auto tensor_cache_src{cute::make_tensor(cute::make_smem_ptr(shared_memory),
                                            SHARED_MEMORY_LAYOUT_SRC{})};
    auto tensor_cache_dst{cute::make_tensor(cute::make_smem_ptr(shared_memory),
                                            SHARED_MEMORY_LAYOUT_DST{})};

    auto global_tile_src{tensor_src(cute::make_coord(cute::_, cute::_),
                                    blockIdx.y,
                                    blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)
    auto global_tile_dst{tensor_dst(cute::make_coord(cute::_, cute::_),
                                    blockIdx.y,
                                    blockIdx.x)}; // (TILE_SIZE_Y, TILE_SIZE_X)

    auto thread_global_tile_src{cute::local_partition(
        global_tile_src, THREAD_LAYOUT_SRC{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)
    auto thread_global_tile_dst{cute::local_partition(
        global_tile_dst, THREAD_LAYOUT_DST{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_X, THREAD_VALUE_SIZE_Y)

    auto thread_shared_tile_src{cute::local_partition(
        tensor_cache_src, THREAD_LAYOUT_SRC{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_Y, THREAD_VALUE_SIZE_X)
    auto thread_shared_tile_dst{cute::local_partition(
        tensor_cache_dst, THREAD_LAYOUT_DST{},
        threadIdx.x)}; // (THREAD_VALUE_SIZE_X, THREAD_VALUE_SIZE_Y)

    // A 2D array of tuples that maps (x, y) to (x, y).
    auto const identity_tensor_src{cute::make_identity_tensor(cute::make_shape(
        cute::size<0>(global_tile_src), cute::size<1>(global_tile_src)))};
    auto const thread_identity_tensor_src{cute::local_partition(
        identity_tensor_src, THREAD_LAYOUT_SRC{}, threadIdx.x)};
    auto predicator_src{cute::make_tensor<bool>(
        cute::make_shape(cute::size<0>(thread_global_tile_src),
                         cute::size<1>(thread_global_tile_src)))};
template <typename T>
cudaError_t launch_transpose_shared_memory_bank_conflict_base(
    T const* input_matrix, T* output_matrix, unsigned int M, unsigned int N,
    SharedMemoryBankConflictAccessMode bank_conflict_access_mode,
    cudaStream_t stream)
{
    auto const tensor_shape{cute::make_shape(M, N)};
    auto const tensor_shape_transposed{cute::make_shape(N, M)};

    // Input matrix: row-major M x N matrix.
    auto const global_memory_layout_src{cute::make_layout(
        tensor_shape, cute::GenRowMajor{})}; // (M, N) : (N, 1)
    // Output matrix: row-major N x M matrix.
    auto const global_memory_layout_dst{cute::make_layout(
        tensor_shape_transposed, cute::GenRowMajor{})}; // (N, M) : (M, 1)
    // Same output matrix, but different view: column-major M x N matrix.
    auto const global_memory_layout_dst_transposed{cute::make_layout(
        tensor_shape, cute::GenColMajor{})}; // (M, N) : (1, M)

    auto const tiled_tensor_src{cute::tiled_divide(
        tensor_src, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
    auto const tiled_tensor_dst{cute::tiled_divide(
        tensor_dst, block_shape_transposed)}; // ((TILE_SIZE_X, TILE_SIZE_Y), N / TILE_SIZE_X, M / TILE_SIZE_Y)
    auto const tiled_tensor_dst_transposed{cute::tiled_divide(
        tensor_dst_transposed, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)

    using THREAD_BLOCK_SIZE_X = cute::Int<32>; // tN
    using THREAD_BLOCK_SIZE_Y = cute::Int<8>;  // tM

    CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_X::value == 0,
                       "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_X");
    CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_Y::value == 0,
                       "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_Y");

template <typename T>
cudaError_t launch_transpose_shared_memory_bank_conflict_read(
    T const* input_matrix, T* output_matrix, unsigned int M, unsigned int N,
    cudaStream_t stream)
{
    return launch_transpose_shared_memory_bank_conflict_base(
        input_matrix, output_matrix, M, N,
        SharedMemoryBankConflictAccessMode::Read, stream);
}

template <typename T>
cudaError_t launch_transpose_shared_memory_bank_conflict_write(
    T const* input_matrix, T* output_matrix, unsigned int M, unsigned int N,
    cudaStream_t stream)
{
    return launch_transpose_shared_memory_bank_conflict_base(
        input_matrix, output_matrix, M, N,
        SharedMemoryBankConflictAccessMode::Write, stream);
}
    // Input matrix: row-major M x N matrix.
    auto const global_memory_layout_src{cute::make_layout(
        tensor_shape, cute::GenRowMajor{})}; // (M, N) : (N, 1)
    // Output matrix: row-major N x M matrix.
    auto const global_memory_layout_dst{cute::make_layout(
        tensor_shape_transposed, cute::GenRowMajor{})}; // (N, M) : (M, 1)
    // Same output matrix, but different view: column-major M x N matrix.
    auto const global_memory_layout_dst_transposed{cute::make_layout(
        tensor_shape, cute::GenColMajor{})}; // (M, N) : (1, M)

    auto const tiled_tensor_src{cute::tiled_divide(
        tensor_src, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
    auto const tiled_tensor_dst{cute::tiled_divide(
        tensor_dst, block_shape_transposed)}; // ((TILE_SIZE_X, TILE_SIZE_Y), N / TILE_SIZE_X, M / TILE_SIZE_Y)
    auto const tiled_tensor_dst_transposed{cute::tiled_divide(
        tensor_dst_transposed, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)

    using THREAD_BLOCK_SIZE_X = cute::Int<32>; // tN
    using THREAD_BLOCK_SIZE_Y = cute::Int<8>;  // tM

    CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_X::value == 0,
                       "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_X");
    CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_Y::value == 0,
                       "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_Y");

    // Input matrix: row-major M x N matrix.
    auto const global_memory_layout_src{cute::make_layout(
        tensor_shape, cute::GenRowMajor{})}; // (M, N) : (N, 1)
    // Output matrix: row-major N x M matrix.
    auto const global_memory_layout_dst{cute::make_layout(
        tensor_shape_transposed, cute::GenRowMajor{})}; // (N, M) : (M, 1)
    // Same output matrix, but different view: column-major M x N matrix.
    auto const global_memory_layout_dst_transposed{cute::make_layout(
        tensor_shape, cute::GenColMajor{})}; // (M, N) : (1, M)

    // Inspect if the swizzling reduces the shared memory bank conflicts.
    // print_shared_memory_bank_ids(shared_memory_layout_swizzled_src);

    auto const tiled_tensor_src{cute::tiled_divide(
        tensor_src, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)
    auto const tiled_tensor_dst{cute::tiled_divide(
        tensor_dst, block_shape_transposed)}; // ((TILE_SIZE_X, TILE_SIZE_Y), N / TILE_SIZE_X, M / TILE_SIZE_Y)
    auto const tiled_tensor_dst_transposed{cute::tiled_divide(
        tensor_dst_transposed, block_shape)}; // ((TILE_SIZE_Y, TILE_SIZE_X), M / TILE_SIZE_Y, N / TILE_SIZE_X)

    using THREAD_BLOCK_SIZE_X = cute::Int<32>; // tN
    using THREAD_BLOCK_SIZE_Y = cute::Int<8>;  // tM

    CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_X::value == 0,
                       "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_X");
    CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_Y::value == 0,
                       "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_Y");
// Explicit instantiation.
template cudaError_t launch_transpose_shared_memory_bank_conflict_read<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_shared_memory_bank_conflict_read<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);

template cudaError_t launch_transpose_shared_memory_bank_conflict_write<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_shared_memory_bank_conflict_write<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);

template cudaError_t launch_transpose_shared_memory_padded<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_shared_memory_padded<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);

template cudaError_t launch_transpose_shared_memory_swizzled<float>(
    float const* input_matrix, float* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
template cudaError_t launch_transpose_shared_memory_swizzled<double>(
    double const* input_matrix, double* output_matrix, unsigned int M,
    unsigned int N, cudaStream_t stream);
Shared Memory Layout and CUDA Thread Block Layout
Because the strided memory accesses will be performed on shared memory, the global memory reads and writes can be fully coalesced. We then have two options for where to perform the strided accesses on shared memory. The first option is to transpose the tile when copying it from global memory to shared memory and then copy it unchanged from shared memory to global memory, resulting in strided memory writes on shared memory. The second option is to copy the tile unchanged from global memory to shared memory and then transpose it when copying it from shared memory to global memory, resulting in strided memory reads on shared memory.
To implement the first option, the shared memory layout has to be column-major if the input matrix layout is row-major. Two different CUDA thread block layouts are used for reading from global memory to shared memory and writing from shared memory to global memory. The first CUDA thread block layout is row-major if the input matrix layout is row-major, resulting in coalesced memory reads from global memory and strided memory writes to shared memory. The second CUDA thread block layout is column-major if the input matrix layout is row-major, which is the same as the output matrix layout, resulting in coalesced memory reads from shared memory and coalesced memory writes to global memory.
To implement the second option, the shared memory layout has to be row-major if the input matrix layout is row-major. Two different CUDA thread block layouts are used for reading from global memory to shared memory and writing from shared memory to global memory. The first CUDA thread block layout is row-major if the input matrix layout is row-major, resulting in coalesced memory reads from global memory and coalesced memory writes to shared memory. The second CUDA thread block layout is column-major if the input matrix layout is row-major, which is the same as the output matrix layout, resulting in strided memory reads from shared memory and coalesced memory writes to global memory.
The strided reads or writes on shared memory can result in up to 32-way shared memory bank conflicts, which can significantly reduce performance on some platforms.
using THREAD_BLOCK_SIZE_X = cute::Int<32>; // tN
using THREAD_BLOCK_SIZE_Y = cute::Int<8>;  // tM

CUTE_STATIC_ASSERT(TILE_SIZE_X::value % THREAD_BLOCK_SIZE_X::value == 0,
                   "TILE_SIZE_X must be divisible by THREAD_BLOCK_SIZE_X");
CUTE_STATIC_ASSERT(TILE_SIZE_Y::value % THREAD_BLOCK_SIZE_Y::value == 0,
                   "TILE_SIZE_Y must be divisible by THREAD_BLOCK_SIZE_Y");
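The 32-way figure can be reproduced with a short sketch, assuming 32 shared memory banks of 4 bytes each and 4-byte `float` elements in a $(32, 64) : (64, 1)$ row-major shared memory tile. The helper names are hypothetical. Threads that access the same word are served by a broadcast and do not conflict, so the sketch counts distinct words per bank.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kNumBanks{32}; // assumed: 32 banks, 4 bytes wide

// Degree of bank conflict for one warp: the maximum number of distinct 4-byte
// words mapped to the same bank. offsets[t] is the word offset of thread t.
std::size_t bank_conflict_ways(std::array<std::size_t, 32> const& offsets)
{
    std::array<std::size_t, kNumBanks> words_per_bank{};
    for (std::size_t t{0}; t < 32; ++t)
    {
        bool duplicate{false}; // same word as an earlier thread: broadcast
        for (std::size_t u{0}; u < t; ++u)
            duplicate = duplicate || offsets[u] == offsets[t];
        if (!duplicate)
            ++words_per_bank[offsets[t] % kNumBanks];
    }
    return *std::max_element(words_per_bank.begin(), words_per_bank.end());
}

// Word offsets of a warp walking down column `col` of a tile with `row_stride`.
std::array<std::size_t, 32> column_access(std::size_t row_stride, std::size_t col)
{
    std::array<std::size_t, 32> offsets{};
    for (std::size_t t{0}; t < 32; ++t)
        offsets[t] = t * row_stride + col;
    return offsets;
}

// Word offsets of a warp walking along row `row` of a tile with `row_stride`.
std::array<std::size_t, 32> row_access(std::size_t row_stride, std::size_t row)
{
    std::array<std::size_t, 32> offsets{};
    for (std::size_t t{0}; t < 32; ++t)
        offsets[t] = row * row_stride + t;
    return offsets;
}
```

Reading a row of the tile is conflict-free, whereas reading a column maps all 32 threads to the same bank, i.e., a 32-way conflict.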
One mistake I tend to make in the implementation is to use the same predicate both for copying from global memory to shared memory and for copying from shared memory to global memory, because the global memory input matrix tile, the global memory output matrix tile, and the shared memory tile all have the same shape. We cannot reuse the predicate because the thread layouts for the two copies are different. Therefore, even the same thread is assigned different identity tuples for the two copies, and we have to use two sets of predicates.
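The sketch below shows why the predicates cannot be shared. It assumes, for illustration, a row-major $(8, 32) : (32, 1)$ thread layout for the copy into shared memory and a column-major $(32, 8) : (1, 32)$ thread layout for the copy out of shared memory, as in the naive kernel. `thread_coord` and `in_bounds` are hypothetical helpers.

```cpp
#include <cassert>
#include <utility>

// Coordinate (a, b) of thread `tid` in a compact thread layout (s0, s1) : (d0, d1),
// i.e. the (a, b) with a * d0 + b * d1 == tid.
std::pair<int, int> thread_coord(int s0, int s1, int d0, int d1, int tid)
{
    for (int b{0}; b < s1; ++b)
        for (int a{0}; a < s0; ++a)
            if (a * d0 + b * d1 == tid)
                return {a, b};
    return {-1, -1}; // tid is outside the thread layout
}

// True if tile coordinate `coord` lies inside the valid_rows x valid_cols region.
bool in_bounds(std::pair<int, int> coord, int valid_rows, int valid_cols)
{
    return coord.first < valid_rows && coord.second < valid_cols;
}
```

Thread 40 starts at tile coordinate $(1, 8)$ under the first layout but at $(8, 1)$ under the second. In an edge tile where only the first 5 rows are valid, its first read is in bounds while its first write is not, so a single predicate tensor would be wrong for one of the two copies.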
Because shared memory is used as a cache for the intermediate matrix tile, and all the threads in the thread block cooperatively copy the tile from global memory to shared memory, we have to make sure that all the threads have finished this copy before any of them copies the tile from shared memory to global memory. In addition to the commonly used __syncthreads(), cute::cp_async_fence() and cute::cp_async_wait<0>() are also used in CuTe for this synchronization, because cute::copy_if and cute::copy can be asynchronous operations on SM80 and later platforms when a cp.async copy atom is used. cute::cp_async_fence() and cute::cp_async_wait<0>() only wait for the calling thread's own asynchronous copies, so __syncthreads() is still needed to make the data visible to the whole thread block. On platforms earlier than SM80, cute::cp_async_fence() and cute::cp_async_wait<0>() are no-ops.
Shared Memory Padding
Shared memory padding is a common trick to avoid shared memory bank conflicts when a warp of threads accesses shared memory.
In our case, suppose the strided memory reads are performed on shared memory. Instead of using the shared memory layout $(bM, bN) : (bN, 1)$, the padded shared memory layout should be $(bM, bN) : (bN + 1, 1)$. With 4-byte elements and a row stride of $bN + 1 = 65$, the elements of one column in consecutive rows fall into consecutive banks, so a warp reading a column no longer hits a single bank. Notice that the shared memory shape remains unchanged, but the stride of the shared memory layout changes, so the shared memory cosize, i.e., the amount of shared memory that needs to be allocated, also changes. Using the shared memory layout $(bM, bN + 1) : (bN + 1, 1)$ is incorrect, because it changes the shape so that it no longer matches the tile shape.
using TILE_SIZE_X = cute::Int<64>;        // bN
using TILE_SIZE_X_PADDED = cute::Int<65>; // bN + 1
using TILE_SIZE_Y = cute::Int<32>;        // bM
Because the shared memory shape remains the same, the CUDA kernel implemented previously can simply be reused.
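The padding arithmetic can be checked with a small sketch, assuming $(bM, bN) = (32, 64)$, 4-byte elements, and 32 four-byte banks; the helper names are hypothetical. The cosize is the offset of the last element plus one, which is why padding the stride, rather than the shape, changes how much memory must be allocated.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Offset of (row, col) in a (bM, bN) : (row_stride, 1) shared memory layout.
constexpr std::size_t tile_offset(std::size_t row, std::size_t col,
                                  std::size_t row_stride)
{
    return row * row_stride + col;
}

// cosize = offset of the last coordinate + 1, i.e. the number of elements to allocate.
constexpr std::size_t tile_cosize(std::size_t bM, std::size_t bN,
                                  std::size_t row_stride)
{
    return tile_offset(bM - 1, bN - 1, row_stride) + 1;
}

// Distinct banks (32 banks, 4-byte elements) hit by 32 threads reading column `col`.
std::size_t distinct_banks_for_column(std::size_t col, std::size_t row_stride)
{
    std::set<std::size_t> banks;
    for (std::size_t row{0}; row < 32; ++row)
        banks.insert(tile_offset(row, col, row_stride) % 32);
    return banks.size();
}
```

With the unpadded stride of 64, a column read hits a single bank; with the padded stride of 65, it hits all 32 banks. The padded cosize is $31 \cdot 65 + 63 + 1 = 2079$ elements rather than $32 \cdot 65 = 2080$, because the padding after the last row is never addressed.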
Shared Memory Swizzling
Shared memory swizzling is another common trick to avoid shared memory bank conflicts when a warp of threads accesses shared memory. Compared to shared memory padding, swizzling does not allocate extra shared memory that goes unused, and is therefore a more favorable approach. However, the swizzling formula is very brain-twisting and its implementation can be error-prone. Fortunately, in CuTe, shared memory swizzling is implemented as a simple template class, and it can be easily applied to the shared memory layout via CuTe layout composition. After verifying that the swizzled shared memory bank IDs meet our requirements, we can simply reuse the CUDA kernel implemented previously with the swizzled shared memory layout.
using TILE_SIZE_X = cute::Int<64>; // bN
using TILE_SIZE_Y = cute::Int<32>; // bM
constexpr int NUM_BITS_X{6};
constexpr int NUM_BITS_Y{5};
In our case, given the shared memory layout $(bM, bN) : (bN, 1) = (32, 64) : (64, 1)$, we can print the shared memory bank IDs of the elements before and after applying swizzling and compare them.
The swizzled bank IDs show that when a warp of threads reads or writes a column of the row-major shared memory tile, there are no shared memory bank conflicts.
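The exact swizzle is not spelled out in this excerpt. Assuming it is `cute::Swizzle<NUM_BITS_Y, 0, NUM_BITS_X>`, i.e., `Swizzle<5, 0, 6>`, composed with the $(32, 64) : (64, 1)$ layout, the offset transformation XORs bits $[6, 11)$ of the offset (the row index) into bits $[0, 5)$. The sketch below emulates that transformation in plain C++ rather than calling CuTe (hypothetical helper names, 4-byte elements, 32 banks) and checks both the bank IDs of column and row accesses and that swizzling only permutes the offsets within the 2048-element tile.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Emulation (not the CuTe API) of Swizzle<5, 0, 6> applied to an element offset:
// bits [6, 11) are XORed into bits [0, 5). In a row-major (32, 64) tile the
// offset is row * 64 + col, so bits [6, 11) hold the row index.
constexpr std::size_t swizzle_5_0_6(std::size_t offset)
{
    return offset ^ ((offset & (std::size_t{0x1F} << 6)) >> 6);
}

// Offset of (row, col) in the (32, 64) : (64, 1) layout, optionally swizzled.
constexpr std::size_t tile_offset(std::size_t row, std::size_t col, bool swizzled)
{
    return swizzled ? swizzle_5_0_6(row * 64 + col) : row * 64 + col;
}

// Distinct banks hit by a warp reading column `col` (rows 0..31).
std::size_t distinct_banks_for_column(std::size_t col, bool swizzled)
{
    std::set<std::size_t> banks;
    for (std::size_t row{0}; row < 32; ++row)
        banks.insert(tile_offset(row, col, swizzled) % 32);
    return banks.size();
}

// Distinct banks hit by a warp reading columns 0..31 of row `row`.
std::size_t distinct_banks_for_row(std::size_t row, bool swizzled)
{
    std::set<std::size_t> banks;
    for (std::size_t col{0}; col < 32; ++col)
        banks.insert(tile_offset(row, col, swizzled) % 32);
    return banks.size();
}
```

Under this assumption, a column read goes from a single bank to 32 distinct banks, row reads stay conflict-free, and because the swizzle is a permutation of the offsets, no extra shared memory is needed.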
Performance
The following table shows the performance measurements of the matrix transpose CUDA kernels on an NVIDIA GeForce RTX 3090.
| Kernel Name | Latency (ms) | Effective Bandwidth (GB/s) | Peak Bandwidth Percentage (%) |
|---|---|---|---|
| Naive Coalesced Read | 7.75992 | 257.734 | 27.5329 |
| Naive Coalesced Write | 3.12904 | 639.174 | 68.2809 |
| Shared Memory Bank Conflict Read | 2.98797 | 669.351 | 71.5045 |
| Shared Memory Bank Conflict Write | 2.9763 | 671.976 | 71.7849 |
| Shared Memory Padded | 2.98273 | 670.527 | 71.6302 |
| Shared Memory Swizzled | 2.92828 | 682.994 | 72.962 |
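For reference, the table columns can be related through the usual definition of effective bandwidth for a transpose, which reads and writes each element once: $2 \cdot M \cdot N \cdot \text{sizeof}(T)$ bytes divided by the latency. The matrix size used for the measurements is not restated here, so the sketch below only encodes the formulas. The helper names are hypothetical, and the peak of about 936 GB/s is an assumption based on the RTX 3090's theoretical memory bandwidth, which is consistent with the percentages in the table.

```cpp
#include <cassert>
#include <cmath>

// Effective bandwidth (GB/s) of a transpose that reads and writes every element once.
double effective_bandwidth_gb_per_s(double M, double N, double bytes_per_element,
                                    double latency_ms)
{
    double const bytes{2.0 * M * N * bytes_per_element};
    return bytes / (latency_ms * 1.0e-3) / 1.0e9;
}

// Percentage of the device's peak memory bandwidth.
double peak_percentage(double effective_gb_per_s, double peak_gb_per_s)
{
    return 100.0 * effective_gb_per_s / peak_gb_per_s;
}
```

For example, a $1000 \times 1000$ `float` transpose finishing in 1 ms corresponds to 8 GB/s, and 257.734 GB/s against a 936.096 GB/s peak gives the 27.53% shown for the naive coalesced read kernel.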
It is somewhat surprising that, except for the naive coalesced read CUDA kernel, all the CUDA kernels have similar effective bandwidth, very close to what can be achieved in practice. Whether or not these CUDA kernels have shared memory bank conflicts does not affect performance significantly, because the performance bottleneck is global memory access.
After profiling with NVIDIA Nsight Compute, we could confirm that global memory access is not fully coalesced in the naive coalesced read and naive coalesced write CUDA kernels, that shared memory bank load conflicts are present in the shared memory bank conflict read CUDA kernel, that shared memory bank store conflicts are present in the shared memory bank conflict write CUDA kernel, and that the shared memory padded and shared memory swizzled CUDA kernels are free of shared memory bank conflicts.