CuTe ldmatrix
Introduction
The CUDA PTX instruction ldmatrix is a collective load instruction that loads one or more matrices from shared memory to registers for the mma instruction. To facilitate its use, CuTe provides a set of wrappers of the ldmatrix instruction.
In this blog post, I would like to discuss the original ldmatrix instruction and its CuTe wrappers.
CUDA PTX ldmatrix
CUDA PTX ldmatrix Instruction
The ldmatrix instruction is documented in the “Parallel Thread Execution ISA” documentation. I would like to slightly elaborate on and emphasize some details from the original documentation.
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];
The destination operand r is a brace-enclosed vector expression consisting of 1, 2, or 4 32-bit registers, as per the value of .num. Each component of the vector expression holds a fragment from the corresponding matrix.
The source operand p is a 32-bit integer register containing the address of the first value of a row of the row-major matrix, or of the first value of a column of the column-major matrix. The address must be 16-byte aligned.
.num | Threads 0-7 | Threads 8-15 | Threads 16-23 | Threads 24-31 |
---|---|---|---|---|
.x1 | addr0–addr7 | - | - | - |
.x2 | addr0–addr7 | addr8–addr15 | - | - |
.x4 | addr0–addr7 | addr8–addr15 | addr16–addr23 | addr24–addr31 |
Starting from each specified address, 8 contiguous 16-bit values, i.e., 16 bytes, are loaded and distributed across 4 threads, one 32-bit register (two 16-bit values) per thread. If .num is .x1, only 8 rows or columns are loaded into a warp of 32 threads; that is why the .shape is always .m8n8. If .num is .x2, 16 rows or columns are loaded into a warp of 32 threads, and each thread now needs two 32-bit registers to hold the loaded data. If .num is .x4, 32 rows or columns are loaded into a warp of 32 threads, and each thread now needs four 32-bit registers to hold the loaded data.
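To make the addressing concrete, below is a minimal sketch (not from the original post) of issuing the raw instruction with .num = .x1 from CUDA C++ inline assembly; it assumes an SM75 or newer GPU, and the kernel and variable names are made up for illustration.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

__global__ void ldmatrix_x1_example()
{
    // One 8x8 row-major matrix of 16-bit elements; each 16-byte row is aligned.
    __shared__ __align__(16) __half smem_tile[8][8];

    const int lane_id = threadIdx.x % 32;

    // Fill the tile with recognizable values.
    for (int i = lane_id; i < 64; i += 32) {
        smem_tile[i / 8][i % 8] = __float2half(static_cast<float>(i));
    }
    __syncwarp();

    // Threads 0-7 supply the addresses of rows 0-7; with .x1 the addresses of
    // the remaining lanes are not used, but every lane must execute the load.
    const int row = lane_id % 8;
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(&smem_tile[row][0]));

    uint32_t frag;  // receives two 16-bit values packed into one 32-bit register
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
                 : "=r"(frag)
                 : "r"(smem_addr));

    // frag now holds the two 16-bit elements assigned to this lane, following
    // the thread-value mapping described in the next section.
}
```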
CUDA PTX ldmatrix Thread-Value Mapping
The ldmatrix fragment layout for one, two, and four 8x8 row-major matrices with 16-bit elements can be illustrated as follows.
When .num is .x1, one 8x8 row-major matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the matrix will be specified by threads 0, 1, 2, …, 7, respectively. In the 8x8 row-major matrix, the two 16-bit values at (0,0) and (0,1), V0 and V1, will be loaded by thread 0 (T0) into the 32-bit register R0 owned by T0; the two 16-bit values at (0,2) and (0,3), V0 and V1, will be loaded by thread 1 (T1) into the 32-bit register R0 owned by T1; and so on.
Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
0 | T0V0:R0 | T0V1:R0 | T1V0:R0 | T1V1:R0 | T2V0:R0 | T2V1:R0 | T3V0:R0 | T3V1:R0 |
1 | T4V0:R0 | T4V1:R0 | T5V0:R0 | T5V1:R0 | T6V0:R0 | T6V1:R0 | T7V0:R0 | T7V1:R0 |
2 | T8V0:R0 | T8V1:R0 | T9V0:R0 | T9V1:R0 | T10V0:R0 | T10V1:R0 | T11V0:R0 | T11V1:R0 |
3 | T12V0:R0 | T12V1:R0 | T13V0:R0 | T13V1:R0 | T14V0:R0 | T14V1:R0 | T15V0:R0 | T15V1:R0 |
4 | T16V0:R0 | T16V1:R0 | T17V0:R0 | T17V1:R0 | T18V0:R0 | T18V1:R0 | T19V0:R0 | T19V1:R0 |
5 | T20V0:R0 | T20V1:R0 | T21V0:R0 | T21V1:R0 | T22V0:R0 | T22V1:R0 | T23V0:R0 | T23V1:R0 |
6 | T24V0:R0 | T24V1:R0 | T25V0:R0 | T25V1:R0 | T26V0:R0 | T26V1:R0 | T27V0:R0 | T27V1:R0 |
7 | T28V0:R0 | T28V1:R0 | T29V0:R0 | T29V1:R0 | T30V0:R0 | T30V1:R0 | T31V0:R0 | T31V1:R0 |
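To connect the table to actual register contents, the two 16-bit values packed into a lane's 32-bit destination register can be unpacked as in the following sketch (not from the original post); the function name is made up, and half-precision elements are assumed.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Unpack the two 16-bit elements held in a lane's 32-bit ldmatrix destination
// register: V0 sits in the low 16 bits and V1 in the high 16 bits. With .x1,
// lane l receives the elements at (l / 4, 2 * (l % 4)) and (l / 4, 2 * (l % 4) + 1).
__device__ void unpack_fragment(uint32_t frag, float& v0, float& v1)
{
    const __half2 pair = *reinterpret_cast<__half2 const*>(&frag);
    v0 = __half2float(__low2half(pair));
    v1 = __half2float(__high2half(pair));
}
```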
When .num is .x2, two 8x8 row-major matrices will be loaded. In addition to the first 8x8 matrix loaded as described above, a second 8x8 matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the second matrix will be specified by threads 8, 9, 10, …, 15, respectively. In the second 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V2 and V3, will be loaded by thread 0 (T0) into the 32-bit register R1 owned by T0; the two 16-bit values at (0,2) and (0,3), V2 and V3, will be loaded by thread 1 (T1) into the 32-bit register R1 owned by T1; and so on.
Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
0 | T0V2:R1 | T0V3:R1 | T1V2:R1 | T1V3:R1 | T2V2:R1 | T2V3:R1 | T3V2:R1 | T3V3:R1 |
1 | T4V2:R1 | T4V3:R1 | T5V2:R1 | T5V3:R1 | T6V2:R1 | T6V3:R1 | T7V2:R1 | T7V3:R1 |
2 | T8V2:R1 | T8V3:R1 | T9V2:R1 | T9V3:R1 | T10V2:R1 | T10V3:R1 | T11V2:R1 | T11V3:R1 |
3 | T12V2:R1 | T12V3:R1 | T13V2:R1 | T13V3:R1 | T14V2:R1 | T14V3:R1 | T15V2:R1 | T15V3:R1 |
4 | T16V2:R1 | T16V3:R1 | T17V2:R1 | T17V3:R1 | T18V2:R1 | T18V3:R1 | T19V2:R1 | T19V3:R1 |
5 | T20V2:R1 | T20V3:R1 | T21V2:R1 | T21V3:R1 | T22V2:R1 | T22V3:R1 | T23V2:R1 | T23V3:R1 |
6 | T24V2:R1 | T24V3:R1 | T25V2:R1 | T25V3:R1 | T26V2:R1 | T26V3:R1 | T27V2:R1 | T27V3:R1 |
7 | T28V2:R1 | T28V3:R1 | T29V2:R1 | T29V3:R1 | T30V2:R1 | T30V3:R1 | T31V2:R1 | T31V3:R1 |
When .num is .x4, four 8x8 row-major matrices will be loaded. In addition to the first two 8x8 matrices loaded as described above, a third and a fourth 8x8 matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the third matrix will be specified by threads 16, 17, 18, …, 23, respectively, and the addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the fourth matrix will be specified by threads 24, 25, 26, …, 31, respectively. In the third 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V4 and V5, will be loaded by thread 0 (T0) into the 32-bit register R2 owned by T0; the two 16-bit values at (0,2) and (0,3), V4 and V5, will be loaded by thread 1 (T1) into the 32-bit register R2 owned by T1; and so on. In the fourth 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V6 and V7, will be loaded by thread 0 (T0) into the 32-bit register R3 owned by T0; the two 16-bit values at (0,2) and (0,3), V6 and V7, will be loaded by thread 1 (T1) into the 32-bit register R3 owned by T1; and so on.
Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
0 | T0V4:R2 | T0V5:R2 | T1V4:R2 | T1V5:R2 | T2V4:R2 | T2V5:R2 | T3V4:R2 | T3V5:R2 |
1 | T4V4:R2 | T4V5:R2 | T5V4:R2 | T5V5:R2 | T6V4:R2 | T6V5:R2 | T7V4:R2 | T7V5:R2 |
2 | T8V4:R2 | T8V5:R2 | T9V4:R2 | T9V5:R2 | T10V4:R2 | T10V5:R2 | T11V4:R2 | T11V5:R2 |
3 | T12V4:R2 | T12V5:R2 | T13V4:R2 | T13V5:R2 | T14V4:R2 | T14V5:R2 | T15V4:R2 | T15V5:R2 |
4 | T16V4:R2 | T16V5:R2 | T17V4:R2 | T17V5:R2 | T18V4:R2 | T18V5:R2 | T19V4:R2 | T19V5:R2 |
5 | T20V4:R2 | T20V5:R2 | T21V4:R2 | T21V5:R2 | T22V4:R2 | T22V5:R2 | T23V4:R2 | T23V5:R2 |
6 | T24V4:R2 | T24V5:R2 | T25V4:R2 | T25V5:R2 | T26V4:R2 | T26V5:R2 | T27V4:R2 | T27V5:R2 |
7 | T28V4:R2 | T28V5:R2 | T29V4:R2 | T29V5:R2 | T30V4:R2 | T30V5:R2 | T31V4:R2 | T31V5:R2 |
Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
0 | T0V6:R3 | T0V7:R3 | T1V6:R3 | T1V7:R3 | T2V6:R3 | T2V7:R3 | T3V6:R3 | T3V7:R3 |
1 | T4V6:R3 | T4V7:R3 | T5V6:R3 | T5V7:R3 | T6V6:R3 | T6V7:R3 | T7V6:R3 | T7V7:R3 |
2 | T8V6:R3 | T8V7:R3 | T9V6:R3 | T9V7:R3 | T10V6:R3 | T10V7:R3 | T11V6:R3 | T11V7:R3 |
3 | T12V6:R3 | T12V7:R3 | T13V6:R3 | T13V7:R3 | T14V6:R3 | T14V7:R3 | T15V6:R3 | T15V7:R3 |
4 | T16V6:R3 | T16V7:R3 | T17V6:R3 | T17V7:R3 | T18V6:R3 | T18V7:R3 | T19V6:R3 | T19V7:R3 |
5 | T20V6:R3 | T20V7:R3 | T21V6:R3 | T21V7:R3 | T22V6:R3 | T22V7:R3 | T23V6:R3 | T23V7:R3 |
6 | T24V6:R3 | T24V7:R3 | T25V6:R3 | T25V7:R3 | T26V6:R3 | T26V7:R3 | T27V6:R3 | T27V7:R3 |
7 | T28V6:R3 | T28V7:R3 | T29V6:R3 | T29V7:R3 | T30V6:R3 | T30V7:R3 | T31V6:R3 | T31V7:R3 |
Transposed and Non-Transposed
Because ldmatrix is usually used to load matrix fragments from shared memory to registers for MMA operations, and MMA operations provide instructions for transposed and non-transposed matrices as specified below, the matrix has to be transposed while being loaded from shared memory to registers whenever its layout in shared memory differs from the layout required by the MMA operation.
BLAS / CuTe | A Majorness | A Layout | B Majorness | B Layout |
---|---|---|---|---|
NT | M-major | (M,K):(1,ldA) | N-major | (N,K):(1,ldB) |
TN | K-major | (M,K):(ldA,1) | K-major | (N,K):(ldB,1) |
NN | M-major | (M,K):(1,ldA) | K-major | (N,K):(ldB,1) |
TT | K-major | (M,K):(ldA,1) | N-major | (N,K):(1,ldB) |
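As a side note, these layouts map directly to CuTe layout expressions; the following is a minimal sketch (the matrix sizes and the leading dimension ldA are made up for illustration).

```cuda
#include <cute/layout.hpp>

// A-matrix layouts for the NT (M-major) and TN (K-major) cases from the table,
// using made-up sizes M = 128, K = 64 and leading dimension ldA.
void layout_examples()
{
    using namespace cute;

    const int ldA = 128;

    auto layout_A_NT = make_layout(make_shape(128, 64),   // (M, K)
                                   make_stride(1, ldA));  // M-major: (1, ldA)
    auto layout_A_TN = make_layout(make_shape(128, 64),   // (M, K)
                                   make_stride(ldA, 1));  // K-major: (ldA, 1)
    (void)layout_A_NT;
    (void)layout_A_TN;
}
```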
For example, suppose we want to use the cute::SM80_16x8x16_F16F16F16F16_TN MMA operation, which requires the matrix A to be $M \times K$ row-major and the matrix B to be $N \times K$ row-major. However, the matrix A is stored as $M \times K$ column-major in shared memory, i.e., with the memory image of a $K \times M$ row-major matrix, and the matrix B is stored as $N \times K$ column-major in shared memory, i.e., with the memory image of a $K \times N$ row-major matrix. In this case, we have to transpose the matrices A and B while loading them from shared memory to registers so that the register fragments follow the row-major layouts required by the MMA operation. Consequently, trans should be specified when loading both matrices A and B from shared memory to registers using ldmatrix.
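In CuTe, this choice shows up as picking the transposing copy atom when building the tiled copies for the MMA. The following is a minimal sketch under these assumptions (the kernel name and the tiling are illustrative only, not from the original post):

```cuda
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm80.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/copy_traits_sm75.hpp>

__global__ void tiled_copy_selection_example()
{
    using namespace cute;

    // MMA atom that expects A as M x K row-major and B as N x K row-major.
    auto tiled_mma = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{});

    // A and B are assumed to be stored column-major (M-major / N-major) in
    // shared memory, so the transposing ldmatrix wrapper is chosen for both.
    auto tiled_copy_A = make_tiled_copy_A(Copy_Atom<SM75_U16x8_LDSM_T, half_t>{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, half_t>{}, tiled_mma);

    // In a real kernel, tiled_copy_A.get_slice(threadIdx.x) would partition the
    // shared memory tile and the register fragments before calling cute::copy.
    (void)tiled_copy_A;
    (void)tiled_copy_B;
}
```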
Fundamentally, trans only affects how ldmatrix maps the threads to the values in the matrix.
Without trans, the thread-value mapping of the ldmatrix instruction is as follows.
Address / Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
Addr0 | T0V0:R0 | T0V1:R0 | T1V0:R0 | T1V1:R0 | T2V0:R0 | T2V1:R0 | T3V0:R0 | T3V1:R0 |
Addr1 | T4V0:R0 | T4V1:R0 | T5V0:R0 | T5V1:R0 | T6V0:R0 | T6V1:R0 | T7V0:R0 | T7V1:R0 |
Addr2 | T8V0:R0 | T8V1:R0 | T9V0:R0 | T9V1:R0 | T10V0:R0 | T10V1:R0 | T11V0:R0 | T11V1:R0 |
Addr3 | T12V0:R0 | T12V1:R0 | T13V0:R0 | T13V1:R0 | T14V0:R0 | T14V1:R0 | T15V0:R0 | T15V1:R0 |
Addr4 | T16V0:R0 | T16V1:R0 | T17V0:R0 | T17V1:R0 | T18V0:R0 | T18V1:R0 | T19V0:R0 | T19V1:R0 |
Addr5 | T20V0:R0 | T20V1:R0 | T21V0:R0 | T21V1:R0 | T22V0:R0 | T22V1:R0 | T23V0:R0 | T23V1:R0 |
Addr6 | T24V0:R0 | T24V1:R0 | T25V0:R0 | T25V1:R0 | T26V0:R0 | T26V1:R0 | T27V0:R0 | T27V1:R0 |
Addr7 | T28V0:R0 | T28V1:R0 | T29V0:R0 | T29V1:R0 | T30V0:R0 | T30V1:R0 | T31V0:R0 | T31V1:R0 |
With trans, the thread-value mapping of the ldmatrix instruction becomes the following, which can be viewed as the transpose of the thread-value mapping table above.
Address / Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
Addr0 | T0V0:R0 | T4V0:R0 | T8V0:R0 | T12V0:R0 | T16V0:R0 | T20V0:R0 | T24V0:R0 | T28V0:R0 |
Addr1 | T0V1:R0 | T4V1:R0 | T8V1:R0 | T12V1:R0 | T16V1:R0 | T20V1:R0 | T24V1:R0 | T28V1:R0 |
Addr2 | T1V0:R0 | T5V0:R0 | T9V0:R0 | T13V0:R0 | T17V0:R0 | T21V0:R0 | T25V0:R0 | T29V0:R0 |
Addr3 | T1V1:R0 | T5V1:R0 | T9V1:R0 | T13V1:R0 | T17V1:R0 | T21V1:R0 | T25V1:R0 | T29V1:R0 |
Addr4 | T2V0:R0 | T6V0:R0 | T10V0:R0 | T14V0:R0 | T18V0:R0 | T22V0:R0 | T26V0:R0 | T30V0:R0 |
Addr5 | T2V1:R0 | T6V1:R0 | T10V1:R0 | T14V1:R0 | T18V1:R0 | T22V1:R0 | T26V1:R0 | T30V1:R0 |
Addr6 | T3V0:R0 | T7V0:R0 | T11V0:R0 | T15V0:R0 | T19V0:R0 | T23V0:R0 | T27V0:R0 | T31V0:R0 |
Addr7 | T3V1:R0 | T7V1:R0 | T11V1:R0 | T15V1:R0 | T19V1:R0 | T23V1:R0 | T27V1:R0 | T31V1:R0 |
Although I have never verified it, I believe the performance of ldmatrix with trans should be exactly the same as that without trans.
CuTe ldmatrix
CuTe ldmatrix Implementations
CuTe provides a set of wrappers of the ldmatrix instruction for the SM75 architecture in the cute/arch/copy_sm75.hpp file, including SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, SM75_U32x4_LDSM_N, SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, and SM75_U16x8_LDSM_T.
struct SM75_U32x1_LDSM_N { /* ... */ };
Taking SM75_U16x8_LDSM_T as an example, its copy function takes a reference to a uint128_t variable smem_src, whose address can be converted to the starting address of 8 contiguous 16-bit values in shared memory, and four references to uint32_t variables dst0, dst1, dst2, and dst3, i.e., the destination registers uint32_t[4], which hold the eight 16-bit values that the calling thread receives from the four 8x8 matrices of 16-bit elements loaded from shared memory.
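For reference, the wrapper is essentially a thin inline-assembly shim around the PTX instruction. The following is a sketch of what SM75_U16x8_LDSM_T looks like, paraphrased from cute/arch/copy_sm75.hpp; the guard macro name, the PTX modifier order, and other minor details may differ from the actual source, and the CuTe-internal helpers cast_smem_ptr_to_uint and CUTE_INVALID_CONTROL_PATH are assumed to be available in that file.

```cuda
// Paraphrased sketch of the CuTe SM75_U16x8_LDSM_T wrapper.
struct SM75_U16x8_LDSM_T
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[4];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // Convert the generic pointer to a 32-bit shared memory address.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    // ldmatrix with .x4 and .trans: four 8x8 matrices of 16-bit elements,
    // loaded with the transposed thread-value mapping.
    asm volatile ("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
        :  "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
```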
CuTe ldmatrix Naming Conventions
It is relatively straightforward to understand the wrapper names SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, and SM75_U16x8_LDSM_T. SM75 indicates that they are for the SM75 architecture. U16 indicates that the data type of the elements in the matrix is a 16-bit unsigned integer; of course, elements of any 16-bit data type, e.g., half, bfloat16, int16_t, and uint16_t, can be loaded. x2, x4, and x8 indicate how many 16-bit elements each thread will load, which corresponds to .num being .x1, .x2, and .x4, respectively, in the ldmatrix instruction. LDSM indicates that it is a wrapper of the ldmatrix instruction. T indicates that the trans version of the ldmatrix instruction will be used.
The user might expect that there would also be SM75_U16x2_LDSM_N, SM75_U16x4_LDSM_N, and SM75_U16x8_LDSM_N wrappers for the non-trans version of the ldmatrix instruction. They are actually named SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, and SM75_U32x4_LDSM_N instead, counting the number of 32-bit registers each thread receives rather than the number of 16-bit elements. The user could use the SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, and SM75_U32x4_LDSM_N wrappers to load 8x8 matrices with 16-bit elements or, equivalently, 8x4 matrices with 32-bit elements from shared memory to registers, since the non-transposed load simply moves 32-bit values and is agnostic to the element type.
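Because the non-transposed load simply deposits 32-bit values into each thread's registers, the same wrapper can serve different element types; a minimal sketch, assuming the CuTe headers (the type aliases below are made up for illustration):

```cuda
#include <cute/tensor.hpp>                 // half_t, tfloat32_t, ...
#include <cute/atom/copy_atom.hpp>         // Copy_Atom
#include <cute/atom/copy_traits_sm75.hpp>  // traits for the SM75 ldmatrix wrappers

// The same non-transposed ldmatrix wrapper can be used whether the elements are
// 16-bit (e.g., half) or 32-bit (e.g., tf32).
using CopyAtomHalf = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
using CopyAtomTf32 = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::tfloat32_t>;
```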
In addition, we note that there are no SM75_U32x1_LDSM_T, SM75_U32x2_LDSM_T, and SM75_U32x4_LDSM_T wrappers. This is because there is no trans version of the ldmatrix instruction for matrices of 32-bit elements. The trans version of the ldmatrix instruction transposes an 8x8 matrix of 16-bit elements at 16-bit granularity; applying it to a matrix of 32-bit elements would split each 32-bit element into two 16-bit halves that land in two different threads, which is not the intended behavior.
Conclusions
The trans and non-trans versions of the ldmatrix instruction provide the flexibility to load matrices with different layouts from shared memory to registers for MMA operations. CuTe provides a set of wrappers of the ldmatrix instruction whose names are chosen carefully to prevent the user from using them incorrectly.