CuTe ldmatrix

Introduction

The CUDA PTX instruction ldmatrix is a warp-collective load instruction that loads one or more matrices from shared memory to registers for mma instructions. To facilitate the use of the ldmatrix instruction, CuTe provides a set of wrappers around it.

In this blog post, I would like to discuss the original ldmatrix instruction and its CuTe wrappers.

CUDA PTX ldmatrix

CUDA PTX ldmatrix Instruction

The ldmatrix instruction is documented in the “Parallel Thread Execution ISA” documentation. I would like to elaborate on and emphasize some of the details from the original documentation.

ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape = {.m8n8};
.num = {.x1, .x2, .x4};
.ss = {.shared{::cta}};
.type = {.b16};

The destination operand r is a brace-enclosed vector expression consisting of 1, 2, or 4 32-bit registers as per the value of .num. Each component of the vector expression holds a fragment from the corresponding matrix.

The source operand p is a 32-bit integer register containing the address of the first value of a row of the row-major matrix, or of the first value of a column of the column-major matrix. The address must be 16-byte aligned.

| .num | Threads 0–7 | Threads 8–15 | Threads 16–23 | Threads 24–31 |
|------|-------------|--------------|---------------|---------------|
| .x1  | addr0–addr7 | -            | -             | -             |
| .x2  | addr0–addr7 | addr8–addr15 | -             | -             |
| .x4  | addr0–addr7 | addr8–addr15 | addr16–addr23 | addr24–addr31 |

Starting from each specified address, 8 contiguous 16-bit values, i.e., one 16-byte value, are loaded and distributed across 4 threads, each receiving one 32-bit register. If .num is .x1, only 8 rows or columns are loaded by a warp of 32 threads; that is why .shape is always .m8n8. If .num is .x2, 16 rows or columns are loaded by a warp of 32 threads, and each thread now needs two 32-bit registers to hold the loaded data. If .num is .x4, 32 rows or columns are loaded by a warp of 32 threads, and each thread now needs four 32-bit registers to hold the loaded data.
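As a concrete illustration, here is a minimal sketch, not taken from CUTLASS, of issuing ldmatrix with .num being .x1 by hand for an 8x8 row-major matrix of half elements in shared memory; the function name and the address clamping for lanes 8-31 are my own.

#include <cstdint>
#include <cuda_fp16.h>

__device__ uint32_t ldmatrix_x1(__half const* smem_matrix)
{
    int const lane_id = threadIdx.x % 32;
    // For .x1, only lanes 0-7 supply meaningful row addresses; reusing the
    // same 8 rows for lanes 8-31 keeps their addresses valid.
    __half const* row_ptr = smem_matrix + (lane_id % 8) * 8;
    // Convert the generic pointer to a 32-bit shared memory address.
    uint32_t const smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(row_ptr));
    uint32_t frag;
    asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
                 : "=r"(frag)
                 : "r"(smem_addr));
    // Every lane now holds two contiguous 16-bit values packed into frag.
    return frag;
}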

CUDA PTX ldmatrix Thread-Value Mapping

The ldmatrix fragment layouts for one, two, and four 8x8 row-major matrices with 16-bit elements can be illustrated as follows.

When .num is .x1, one 8x8 row-major matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the matrix will be specified by threads 0, 1, 2, …, 7, respectively. In the 8x8 row-major matrix, the two 16-bit values at (0,0) and (0,1), V0 and V1, will be loaded by thread 0 (T0) into the 32-bit register R0 owned by T0, the two 16-bit values at (0,2) and (0,3), V0 and V1, will be loaded by thread 1 (T1) into the 32-bit register R0 owned by T1, and so on.

| Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 0 | T0V0:R0 | T0V1:R0 | T1V0:R0 | T1V1:R0 | T2V0:R0 | T2V1:R0 | T3V0:R0 | T3V1:R0 |
| 1 | T4V0:R0 | T4V1:R0 | T5V0:R0 | T5V1:R0 | T6V0:R0 | T6V1:R0 | T7V0:R0 | T7V1:R0 |
| 2 | T8V0:R0 | T8V1:R0 | T9V0:R0 | T9V1:R0 | T10V0:R0 | T10V1:R0 | T11V0:R0 | T11V1:R0 |
| 3 | T12V0:R0 | T12V1:R0 | T13V0:R0 | T13V1:R0 | T14V0:R0 | T14V1:R0 | T15V0:R0 | T15V1:R0 |
| 4 | T16V0:R0 | T16V1:R0 | T17V0:R0 | T17V1:R0 | T18V0:R0 | T18V1:R0 | T19V0:R0 | T19V1:R0 |
| 5 | T20V0:R0 | T20V1:R0 | T21V0:R0 | T21V1:R0 | T22V0:R0 | T22V1:R0 | T23V0:R0 | T23V1:R0 |
| 6 | T24V0:R0 | T24V1:R0 | T25V0:R0 | T25V1:R0 | T26V0:R0 | T26V1:R0 | T27V0:R0 | T27V1:R0 |
| 7 | T28V0:R0 | T28V1:R0 | T29V0:R0 | T29V1:R0 | T30V0:R0 | T30V1:R0 | T31V0:R0 | T31V1:R0 |
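Summarizing the table above in formula form: for a non-transposed load of an 8x8 matrix, the value at row $r$ and column $c$ is held by thread $T = 4r + \lfloor c / 2 \rfloor$ as its value $V_{c \bmod 2}$.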

When .num is .x2, two 8x8 row-major matrices will be loaded. In addition to the first 8x8 matrix loaded as described above, a second 8x8 matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the second matrix will be specified by threads 8, 9, 10, …, 15, respectively. In the second 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V2 and V3, will be loaded by thread 0 (T0) into the 32-bit register R1 owned by T0, the two 16-bit values at (0,2) and (0,3), V2 and V3, will be loaded by thread 1 (T1) into the 32-bit register R1 owned by T1, and so on.

| Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 0 | T0V2:R1 | T0V3:R1 | T1V2:R1 | T1V3:R1 | T2V2:R1 | T2V3:R1 | T3V2:R1 | T3V3:R1 |
| 1 | T4V2:R1 | T4V3:R1 | T5V2:R1 | T5V3:R1 | T6V2:R1 | T6V3:R1 | T7V2:R1 | T7V3:R1 |
| 2 | T8V2:R1 | T8V3:R1 | T9V2:R1 | T9V3:R1 | T10V2:R1 | T10V3:R1 | T11V2:R1 | T11V3:R1 |
| 3 | T12V2:R1 | T12V3:R1 | T13V2:R1 | T13V3:R1 | T14V2:R1 | T14V3:R1 | T15V2:R1 | T15V3:R1 |
| 4 | T16V2:R1 | T16V3:R1 | T17V2:R1 | T17V3:R1 | T18V2:R1 | T18V3:R1 | T19V2:R1 | T19V3:R1 |
| 5 | T20V2:R1 | T20V3:R1 | T21V2:R1 | T21V3:R1 | T22V2:R1 | T22V3:R1 | T23V2:R1 | T23V3:R1 |
| 6 | T24V2:R1 | T24V3:R1 | T25V2:R1 | T25V3:R1 | T26V2:R1 | T26V3:R1 | T27V2:R1 | T27V3:R1 |
| 7 | T28V2:R1 | T28V3:R1 | T29V2:R1 | T29V3:R1 | T30V2:R1 | T30V3:R1 | T31V2:R1 | T31V3:R1 |

When .num is .x4, four 8x8 row-major matrices will be loaded. In addition to the first two 8x8 matrices loaded as described above, a third and a fourth 8x8 matrix will be loaded. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the third matrix will be specified by threads 16, 17, 18, …, 23, respectively. The addresses of the values at (0,0), (1,0), (2,0), …, (7,0) in the fourth matrix will be specified by threads 24, 25, 26, …, 31, respectively. In the third 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V4 and V5, will be loaded by thread 0 (T0) into the 32-bit register R2 owned by T0, the two 16-bit values at (0,2) and (0,3), V4 and V5, will be loaded by thread 1 (T1) into the 32-bit register R2 owned by T1, and so on. In the fourth 8x8 matrix, the two 16-bit values at (0,0) and (0,1), V6 and V7, will be loaded by thread 0 (T0) into the 32-bit register R3 owned by T0, the two 16-bit values at (0,2) and (0,3), V6 and V7, will be loaded by thread 1 (T1) into the 32-bit register R3 owned by T1, and so on.

| Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 0 | T0V4:R2 | T0V5:R2 | T1V4:R2 | T1V5:R2 | T2V4:R2 | T2V5:R2 | T3V4:R2 | T3V5:R2 |
| 1 | T4V4:R2 | T4V5:R2 | T5V4:R2 | T5V5:R2 | T6V4:R2 | T6V5:R2 | T7V4:R2 | T7V5:R2 |
| 2 | T8V4:R2 | T8V5:R2 | T9V4:R2 | T9V5:R2 | T10V4:R2 | T10V5:R2 | T11V4:R2 | T11V5:R2 |
| 3 | T12V4:R2 | T12V5:R2 | T13V4:R2 | T13V5:R2 | T14V4:R2 | T14V5:R2 | T15V4:R2 | T15V5:R2 |
| 4 | T16V4:R2 | T16V5:R2 | T17V4:R2 | T17V5:R2 | T18V4:R2 | T18V5:R2 | T19V4:R2 | T19V5:R2 |
| 5 | T20V4:R2 | T20V5:R2 | T21V4:R2 | T21V5:R2 | T22V4:R2 | T22V5:R2 | T23V4:R2 | T23V5:R2 |
| 6 | T24V4:R2 | T24V5:R2 | T25V4:R2 | T25V5:R2 | T26V4:R2 | T26V5:R2 | T27V4:R2 | T27V5:R2 |
| 7 | T28V4:R2 | T28V5:R2 | T29V4:R2 | T29V5:R2 | T30V4:R2 | T30V5:R2 | T31V4:R2 | T31V5:R2 |

| Row / Col Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 0 | T0V6:R3 | T0V7:R3 | T1V6:R3 | T1V7:R3 | T2V6:R3 | T2V7:R3 | T3V6:R3 | T3V7:R3 |
| 1 | T4V6:R3 | T4V7:R3 | T5V6:R3 | T5V7:R3 | T6V6:R3 | T6V7:R3 | T7V6:R3 | T7V7:R3 |
| 2 | T8V6:R3 | T8V7:R3 | T9V6:R3 | T9V7:R3 | T10V6:R3 | T10V7:R3 | T11V6:R3 | T11V7:R3 |
| 3 | T12V6:R3 | T12V7:R3 | T13V6:R3 | T13V7:R3 | T14V6:R3 | T14V7:R3 | T15V6:R3 | T15V7:R3 |
| 4 | T16V6:R3 | T16V7:R3 | T17V6:R3 | T17V7:R3 | T18V6:R3 | T18V7:R3 | T19V6:R3 | T19V7:R3 |
| 5 | T20V6:R3 | T20V7:R3 | T21V6:R3 | T21V7:R3 | T22V6:R3 | T22V7:R3 | T23V6:R3 | T23V7:R3 |
| 6 | T24V6:R3 | T24V7:R3 | T25V6:R3 | T25V7:R3 | T26V6:R3 | T26V7:R3 | T27V6:R3 | T27V7:R3 |
| 7 | T28V6:R3 | T28V7:R3 | T29V6:R3 | T29V7:R3 | T30V6:R3 | T30V7:R3 | T31V6:R3 | T31V7:R3 |
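The same pattern generalizes across all four matrices: for the $m$-th matrix, where $m = 0, 1, 2, 3$, the value at row $r$ and column $c$ is held by thread $T = 4r + \lfloor c / 2 \rfloor$ as its value $V_{2m + (c \bmod 2)}$ in register $R_m$.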

Transposed and Non-Transposed

ldmatrix is usually used for loading matrix fragments from shared memory to registers for MMA operations, and MMA operations provide instructions for matrices that are either transposed or non-transposed, as specified in the following table. Therefore, if the layout of the matrix in shared memory differs from the layout required by the MMA operation, we need to transpose the matrix while loading it from shared memory to registers.

| BLAS / CuTe | A Majorness | A Layout | B Majorness | B Layout |
|---|---|---|---|---|
| NT | M-major | (M,K):(1,ldA) | N-major | (N,K):(1,ldB) |
| TN | K-major | (M,K):(ldA,1) | K-major | (N,K):(ldB,1) |
| NN | M-major | (M,K):(1,ldA) | K-major | (N,K):(ldB,1) |
| TT | K-major | (M,K):(ldA,1) | N-major | (N,K):(1,ldB) |

For example, suppose we want to use the cute::SM80_16x8x16_F16F16F16F16_TN MMA operation, which requires the matrix A to be $M \times K$ row-major and the matrix B to be $N \times K$ row-major. However, the matrix A is stored as $M \times K$ column-major in shared memory and the matrix B is stored as $N \times K$ column-major in shared memory. In this case, we have to transpose the matrix A while loading it from shared memory to registers, so that the fragments in registers correspond to A in $M \times K$ row-major order, and likewise transpose the matrix B so that its fragments correspond to B in $N \times K$ row-major order. Consequently, trans should be specified when loading both matrices A and B from shared memory to registers using ldmatrix.
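In CuTe, this pairing can be expressed by selecting the transposed ldmatrix copy atom for both operands. Here is a hedged sketch using the CuTe utilities Copy_Atom, make_tiled_mma, make_tiled_copy_A, and make_tiled_copy_B; the single-atom tiling is illustrative rather than a tuned configuration.

#include <cute/tensor.hpp>

__device__ void setup_copies()
{
    using namespace cute;
    // Both A and B are column-major in shared memory, so the transposed
    // ldmatrix variant delivers the K-major fragments the TN MMA expects.
    auto tiled_mma = make_tiled_mma(MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>{});
    auto tiled_copy_A = make_tiled_copy_A(Copy_Atom<SM75_U16x8_LDSM_T, half_t>{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, half_t>{}, tiled_mma);
    // tiled_copy_A and tiled_copy_B would then partition the shared memory
    // tensors and register fragments for cute::copy.
}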

Fundamentally, trans only affects how ldmatrix maps the threads to the values in the matrix.

Without trans, the thread-value mapping of the ldmatrix instruction is as follows.

| Address / Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| Addr0 | T0V0:R0 | T0V1:R0 | T1V0:R0 | T1V1:R0 | T2V0:R0 | T2V1:R0 | T3V0:R0 | T3V1:R0 |
| Addr1 | T4V0:R0 | T4V1:R0 | T5V0:R0 | T5V1:R0 | T6V0:R0 | T6V1:R0 | T7V0:R0 | T7V1:R0 |
| Addr2 | T8V0:R0 | T8V1:R0 | T9V0:R0 | T9V1:R0 | T10V0:R0 | T10V1:R0 | T11V0:R0 | T11V1:R0 |
| Addr3 | T12V0:R0 | T12V1:R0 | T13V0:R0 | T13V1:R0 | T14V0:R0 | T14V1:R0 | T15V0:R0 | T15V1:R0 |
| Addr4 | T16V0:R0 | T16V1:R0 | T17V0:R0 | T17V1:R0 | T18V0:R0 | T18V1:R0 | T19V0:R0 | T19V1:R0 |
| Addr5 | T20V0:R0 | T20V1:R0 | T21V0:R0 | T21V1:R0 | T22V0:R0 | T22V1:R0 | T23V0:R0 | T23V1:R0 |
| Addr6 | T24V0:R0 | T24V1:R0 | T25V0:R0 | T25V1:R0 | T26V0:R0 | T26V1:R0 | T27V0:R0 | T27V1:R0 |
| Addr7 | T28V0:R0 | T28V1:R0 | T29V0:R0 | T29V1:R0 | T30V0:R0 | T30V1:R0 | T31V0:R0 | T31V1:R0 |

With trans, the thread-value mapping of the ldmatrix instruction becomes the following, which can be viewed as the transpose of the above thread-value mapping table.

| Address / Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| Addr0 | T0V0:R0 | T4V0:R0 | T8V0:R0 | T12V0:R0 | T16V0:R0 | T20V0:R0 | T24V0:R0 | T28V0:R0 |
| Addr1 | T0V1:R0 | T4V1:R0 | T8V1:R0 | T12V1:R0 | T16V1:R0 | T20V1:R0 | T24V1:R0 | T28V1:R0 |
| Addr2 | T1V0:R0 | T5V0:R0 | T9V0:R0 | T13V0:R0 | T17V0:R0 | T21V0:R0 | T25V0:R0 | T29V0:R0 |
| Addr3 | T1V1:R0 | T5V1:R0 | T9V1:R0 | T13V1:R0 | T17V1:R0 | T21V1:R0 | T25V1:R0 | T29V1:R0 |
| Addr4 | T2V0:R0 | T6V0:R0 | T10V0:R0 | T14V0:R0 | T18V0:R0 | T22V0:R0 | T26V0:R0 | T30V0:R0 |
| Addr5 | T2V1:R0 | T6V1:R0 | T10V1:R0 | T14V1:R0 | T18V1:R0 | T22V1:R0 | T26V1:R0 | T30V1:R0 |
| Addr6 | T3V0:R0 | T7V0:R0 | T11V0:R0 | T15V0:R0 | T19V0:R0 | T23V0:R0 | T27V0:R0 | T31V0:R0 |
| Addr7 | T3V1:R0 | T7V1:R0 | T11V1:R0 | T15V1:R0 | T19V1:R0 | T23V1:R0 | T27V1:R0 | T31V1:R0 |
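Summarizing the two tables in formula form: without trans, the 16-bit value at index $c$ of the 16-byte unit at Addr$a$ goes to thread $T = 4a + \lfloor c / 2 \rfloor$ as $V_{c \bmod 2}$; with trans, it instead goes to thread $T = 4c + \lfloor a / 2 \rfloor$ as $V_{a \bmod 2}$.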

Although I have never verified it, I believe the performance of ldmatrix with trans should be exactly the same as that without trans.

CuTe ldmatrix

CuTe ldmatrix Implementations

CuTe provides a set of wrappers of the ldmatrix instruction for the SM75 architecture in the cute/arch/copy_sm75.hpp file, including SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, SM75_U32x4_LDSM_N, SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, and SM75_U16x8_LDSM_T.

// Loads one 8x8 matrix of 16-bit elements, non-transposed (ldmatrix.x1).
struct SM75_U32x1_LDSM_N
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[1];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
        : "=r"(dst)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

// Loads two 8x8 matrices of 16-bit elements, non-transposed (ldmatrix.x2).
struct SM75_U32x2_LDSM_N
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[2];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
        : "=r"(dst0), "=r"(dst1)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

// Loads four 8x8 matrices of 16-bit elements, non-transposed (ldmatrix.x4).
struct SM75_U32x4_LDSM_N
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[4];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

// Loads one 8x8 matrix of 16-bit elements, transposed (ldmatrix.x1.trans).
struct SM75_U16x2_LDSM_T
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[1];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
        : "=r"(dst)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

// Loads two 8x8 matrices of 16-bit elements, transposed (ldmatrix.x2.trans).
struct SM75_U16x4_LDSM_T
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[2];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
        : "=r"(dst0), "=r"(dst1)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

// Loads four 8x8 matrices of 16-bit elements, transposed (ldmatrix.x4.trans).
struct SM75_U16x8_LDSM_T
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint32_t[4];

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
        : "r"(smem_int_ptr));
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
  }
};

Taking SM75_U16x8_LDSM_T as an example, the copy function takes a reference to a uint128_t variable smem_src, whose address is the starting address of 8 contiguous 16-bit values in shared memory, and four references to uint32_t variables dst0, dst1, dst2, and dst3, which hold the eight 16-bit values (two per matrix) that each thread receives from the four 8x8 matrices with 16-bit elements loaded from shared memory to registers.
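For illustration, here is a hedged sketch of calling this wrapper directly; the function name and the per-lane addressing are hypothetical, since in practice the offset of each 16-byte row or column depends on the tile layout and any swizzling applied to shared memory.

#include <cute/arch/copy_sm75.hpp>

__device__ void load_four_8x8_fragments(cute::uint128_t const* smem_rows,
                                        uint32_t (&frag)[4])
{
    int const lane_id = threadIdx.x % 32;
    // Hypothetical addressing: lane i supplies the address of the i-th
    // 16-byte row/column among the four 8x8 matrices.
    cute::SM75_U16x8_LDSM_T::copy(smem_rows[lane_id],
                                  frag[0], frag[1], frag[2], frag[3]);
}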

CuTe ldmatrix Naming Conventions

It is relatively straightforward to understand the wrapper names SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, and SM75_U16x8_LDSM_T. SM75 indicates that they are for the SM75 architecture. U16 indicates that the data type of the elements in the matrix is a 16-bit unsigned integer; of course, loading elements of any 16-bit data type, e.g., half, bfloat16, int16_t, and uint16_t, is supported. x2, x4, and x8 indicate how many 16-bit elements each thread will load, corresponding to .num being .x1, .x2, and .x4, respectively, in the ldmatrix instruction. LDSM indicates that it is a wrapper of the ldmatrix instruction. T indicates that the trans version of the ldmatrix instruction will be used.

The user might expect that there would also be SM75_U16x1_LDSM_N, SM75_U16x2_LDSM_N, and SM75_U16x4_LDSM_N wrappers for the non-trans version of the ldmatrix instruction. They are actually named SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, and SM75_U32x4_LDSM_N instead, because the non-transposed load simply delivers one 32-bit register per thread per matrix without rearranging its 16-bit halves, and is therefore agnostic to the element type. The user could use the SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, and SM75_U32x4_LDSM_N wrappers to load 8x8 matrices with 16-bit elements or 8x4 matrices with 32-bit elements from shared memory to registers.

In addition, we also note that there are no SM75_U32x1_LDSM_T, SM75_U32x2_LDSM_T, and SM75_U32x4_LDSM_T wrappers. This is because there is no trans version of the ldmatrix instruction for loading 8x4 matrices with 32-bit elements. The trans version of the ldmatrix instruction transposes an 8x8 matrix with 16-bit elements into another 8x8 matrix with 16-bit elements, but the same instruction cannot be used to transpose an 8x4 matrix with 32-bit elements. Otherwise, a 32-bit element would be split into two 16-bit parts and loaded into two different threads, which is not the intended behavior.

Conclusions

The trans and non-trans versions of the ldmatrix instruction provide the flexibility to load matrices with different layouts from shared memory to registers for MMA operations. CuTe provides a set of wrappers of the ldmatrix instruction whose names are carefully chosen to prevent the user from using them incorrectly.
