CuTe Local Partition

Introduction

The CuTe outer_partition function is often used at the thread level to partition a tensor into tiles and slice the tile(s) based on the coordinate of the thread in the thread block. The CuTe local_partition function is a convenient wrapper around outer_partition that allows the developer to partition a tensor into tiles and slice the tile(s) based on the index of the thread in the thread block.

In this article, I would like to discuss the local_partition function in CuTe and how exactly it works.
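
As a quick preview, the following is a minimal host-side sketch of local_partition in action. It mirrors the (_16,_64) data shape and (_2,_16) thread layout used in the CuTe documentation comment quoted later in this article; the host buffer and the hard-coded thread index are made up purely for illustration.

#include <vector>

#include <cute/tensor.hpp>

int main()
{
    using namespace cute;

    // A (_16,_64) column-major tensor backed by a host buffer.
    auto data_layout = make_layout(Shape<_16, _64>{});
    std::vector<float> buffer(size(data_layout));
    auto data = make_tensor(buffer.data(), data_layout);

    // 32 threads arranged as a (_2,_16) column-major tile.
    auto thr_layout = Layout<Shape<_2, _16>>{};

    // Pretend we are thread 5 of the thread block.
    int const thr_idx = 5;

    // Each thread owns an (_8,_4) slice of the (_16,_64) tensor.
    auto thr_data = local_partition(data, thr_layout, thr_idx);
    print(thr_data.layout());
    print("\n");
}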

CuTe Local Partition

CuTe Local Partition Implementation

The implementation of local_partition in CuTe is as follows. Essentially, local_partition flattens the shape of each mode in the tile layout, converts the index back to a coordinate by inverting the mapping of the tile layout function, and calls outer_partition to partition and slice the tile(s). outer_partition, in turn, performs a zipped_divide and slices into the first tile mode using the coordinate. That is why it is called an "outer" partition.

// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes.
// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor.
// Split the modes of tensor according to the Tiler
// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
// Then slice into the first mode (the "Tile" mode) with Coord
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
outer_partition(Tensor    && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tensor_tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;

  // The coord slices into the first mode (the "tile" mode), flatten the second
  if constexpr (is_tuple<Coord>::value) {
    // Append trailing modes if coord is tuple
    constexpr int R0 = decltype(rank<0>(tensor_tiled))::value;
    return tensor_tiled(append<R0>(coord,_), repeat<R1>(_));
  } else {
    // Flat indexing if coord is not tuple
    return tensor_tiled(coord, repeat<R1>(_));
  }
}

// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index.
// This is typical at the Thread level where data is partitioned across repeated patterns of threads:
// Tensor data = ... // (_16,_64)
// Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx); // ( _8, _4)
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index)  // index to slice for
{
  static_assert(is_integral<Index>::value);
  return outer_partition(static_cast<Tensor&&>(tensor),
                         product_each(shape(tile)),
                         tile.get_flat_coord(index));
}

The question is, given outer_partition, why do we still need local_partition? Usually, given a layout function, we compute the index from the coordinate. Computing the coordinate from the index, however, is not straightforward mathematically, especially when the layout is not "tight", i.e., when the size of the domain and the size of the codomain of the layout function are not the same. In what circumstances would we slice tile(s) using an index instead of a coordinate?

Threads Slicing Different Elements In Tile

In a thread block, given a tensor in global memory and a tensor in shared memory, we would like to use all the threads in the thread block to cooperatively and efficiently copy the data from the global memory tensor to the shared memory tensor. The layout of threads in the thread block is the tile layout, and we use an outer partition to partition and slice the tile(s). The goal is that, when iterating through the tile(s), each thread in the thread block slices one element from the global memory tensor, and adjacent threads slice adjacent elements from the global memory tensor, so that the global memory tensor is accessed efficiently in a coalesced fashion. The adjacency of the elements in the global memory tensor is determined by the index instead of the coordinate. If 8 adjacent threads, whose indices are $i$, $i+1$, $i+2$, $i+3$, $i+4$, $i+5$, $i+6$, $i+7$, access 8 elements in the global memory tensor whose indices are $j$, $j+1$, $j+2$, $j+3$, $j+4$, $j+5$, $j+6$, $j+7$, then the global memory access is coalesced.

If the layout of the first mode in the zipped division, which is the mode we are slicing into, happens to be the same as the (compact) tile layout, then adjacent thread indices, once converted to coordinates via the tile layout, map to adjacent data indices, and the access to the global memory tensor is guaranteed to be coalesced.

Depending on the implementation, the layout of threads in the thread block can be column-major, row-major, or something more complicated. For example, a thread block consisting of 32 threads can have a column-major layout of $(8, 4) : (1, 8)$ or a row-major layout of $(4, 8) : (8, 1)$. In the column-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$, which are adjacent indices. However, in the row-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $8$, $16$, $24$, $1$, $9$, $17$, $25$, which are not adjacent indices. Therefore, without local_partition, the developer would have to very carefully compute, for each thread, the coordinate that corresponds to the thread index by inverting the mapping of the tile (thread) layout function, i.e., tile.get_flat_coord(index), which is exactly what local_partition already does for us.

Assuming the layout of threads is "tight", the mathematics of how the tile.get_flat_coord(index) function computes the coordinate from the index is not very complicated, and its derivation was discussed in my previous article "CuTe Index To Coordinate".

Using the local_partition function, no matter whether the tile layout is column-major or row-major, as long as it is a compact layout, we can simply use the index of the thread in the thread block to slice the tile(s) and access the global memory tensor. In our previous case, the thread indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ map to coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ in the column-major layout of $(8, 4) : (1, 8)$ and to coordinates $0$, $4$, $8$, $12$, $16$, $20$, $24$, $28$ in the row-major layout of $(4, 8) : (8, 1)$. The developer does not have to worry about how those coordinates are computed when using the local_partition function.
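
To make the two directions of the mapping concrete, the following is a small host-side sketch, added purely for illustration. Calling a layout with a 1D coordinate applies the coordinate-to-index mapping, while get_flat_coord applies the index-to-coordinate mapping.

#include <cute/tensor.hpp>

int main()
{
    using namespace cute;

    auto col_major = make_layout(Shape<_8, _4>{}, LayoutLeft{});   // (_8,_4):(_1,_8)
    auto row_major = make_layout(Shape<_4, _8>{}, LayoutRight{});  // (_4,_8):(_8,_1)

    for (int i = 0; i < 8; ++i)
    {
        // Coordinate -> index: adjacent 1D coordinates map to adjacent indices
        // only in the column-major layout.
        print(col_major(i)); print(" ");
        print(row_major(i)); print("    ");
        // Index -> coordinate: this is what tile.get_flat_coord(index) computes.
        print(col_major.get_flat_coord(i)); print(" ");
        print(row_major.get_flat_coord(i)); print("\n");
    }
    // Expected: coordinates 0 through 7 map to indices 0 through 7 in the column-major
    // layout and to indices 0, 8, 16, 24, 1, 9, 17, 25 in the row-major layout.
    // Conversely, thread indices 0 through 7 map to coordinates (0,0), (1,0), ..., (7,0)
    // (flat coordinates 0 through 7) in the column-major layout and to (0,0), (0,1), ...,
    // (0,7) (flat coordinates 0, 4, 8, ..., 28) in the row-major layout.
}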

However, this still does not explain why we need local_partition, since the outer_partition function can simply take the thread index as a 1D coordinate. In fact, even if we use local_partition with the thread index as input, there is no guarantee that the data access from adjacent threads will be coalesced. Conversely, using outer_partition with the thread index as the coordinate might also happen to result in coalesced data access from adjacent threads. One still has to engineer or search for the combination of tensor layout and tile layout that results in coalesced data access from adjacent threads.

The following example illustrates the stride between the data elements accessed by adjacent threads in an outer partition, using either the coordinate or the index as the input.

#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>

#include <cute/layout.hpp>
#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>


int main(int argc, const char** argv)
{
using BM = cute::Int<256>;
using BK = cute::Int<32>;
using TM = cute::Int<32>;
using TK = cute::Int<8>;

auto gmem_layout_1 = cute::make_layout(cute::Shape<BM, BK>{}, cute::LayoutLeft{});
auto gmem_layout_2 = cute::make_layout(cute::Shape<BM, BK>{}, cute::LayoutRight{});
auto gmem_layout_3 = cute::make_layout(cute::Shape<BK, BM>{}, cute::LayoutLeft{});
auto gmem_layout_4 = cute::make_layout(cute::Shape<BK, BM>{}, cute::LayoutRight{});

auto thread_layout_1 = cute::make_layout(cute::Shape<TM, TK>{}, cute::LayoutLeft{});
auto thread_layout_2 = cute::make_layout(cute::Shape<TM, TK>{}, cute::LayoutRight{});
auto thread_layout_3 = cute::make_layout(cute::Shape<TK, TM>{}, cute::LayoutLeft{});
auto thread_layout_4 = cute::make_layout(cute::Shape<TK, TM>{}, cute::LayoutRight{});

auto thread_1_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_1);
auto thread_1_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_1);
auto thread_1_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_1);
auto thread_1_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_1);
auto thread_2_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_2);
auto thread_2_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_2);
auto thread_2_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_2);
auto thread_2_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_2);
auto thread_3_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_3);
auto thread_3_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_3);
auto thread_3_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_3);
auto thread_3_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_3);
auto thread_4_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_4);
auto thread_4_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_4);
auto thread_4_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_4);
auto thread_4_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_4);

std::cout << "gmem_layout_1: " << std::endl;
cute::print(gmem_layout_1);
std::cout << std::endl;
std::cout << "gmem_layout_2: " << std::endl;
cute::print(gmem_layout_2);
std::cout << std::endl;
std::cout << "gmem_layout_3: " << std::endl;
cute::print(gmem_layout_3);
std::cout << std::endl;
std::cout << "gmem_layout_4: " << std::endl;
cute::print(gmem_layout_4);
std::cout << std::endl;

std::cout << "thread_layout_1: " << std::endl;
cute::print(thread_layout_1);
std::cout << std::endl;
std::cout << "thread_layout_2: " << std::endl;
cute::print(thread_layout_2);
std::cout << std::endl;
std::cout << "thread_layout_3: " << std::endl;
cute::print(thread_layout_3);
std::cout << std::endl;
std::cout << "thread_layout_4: " << std::endl;
cute::print(thread_layout_4);
std::cout << std::endl;

std::cout << "thread_1_gmem_1_layout: " << std::endl;
cute::print(thread_1_gmem_1_layout);
std::cout << std::endl;
std::cout << "thread_1_gmem_2_layout: " << std::endl;
cute::print(thread_1_gmem_2_layout);
std::cout << std::endl;
std::cout << "thread_1_gmem_3_layout: " << std::endl;
cute::print(thread_1_gmem_3_layout);
std::cout << std::endl;
std::cout << "thread_1_gmem_4_layout: " << std::endl;
cute::print(thread_1_gmem_4_layout);
std::cout << std::endl;
std::cout << "thread_2_gmem_1_layout: " << std::endl;
cute::print(thread_2_gmem_1_layout);
std::cout << std::endl;
std::cout << "thread_2_gmem_2_layout: " << std::endl;
cute::print(thread_2_gmem_2_layout);
std::cout << std::endl;
std::cout << "thread_2_gmem_3_layout: " << std::endl;
cute::print(thread_2_gmem_3_layout);
std::cout << std::endl;
std::cout << "thread_2_gmem_4_layout: " << std::endl;
cute::print(thread_2_gmem_4_layout);
std::cout << std::endl;
std::cout << "thread_3_gmem_1_layout: " << std::endl;
cute::print(thread_3_gmem_1_layout);
std::cout << std::endl;
std::cout << "thread_3_gmem_2_layout: " << std::endl;
cute::print(thread_3_gmem_2_layout);
std::cout << std::endl;
std::cout << "thread_3_gmem_3_layout: " << std::endl;
cute::print(thread_3_gmem_3_layout);
std::cout << std::endl;
std::cout << "thread_3_gmem_4_layout: " << std::endl;
cute::print(thread_3_gmem_4_layout);
std::cout << std::endl;
std::cout << "thread_4_gmem_1_layout: " << std::endl;
cute::print(thread_4_gmem_1_layout);
std::cout << std::endl;
std::cout << "thread_4_gmem_2_layout: " << std::endl;
cute::print(thread_4_gmem_2_layout);
std::cout << std::endl;
std::cout << "thread_4_gmem_3_layout: " << std::endl;
cute::print(thread_4_gmem_3_layout);
std::cout << std::endl;
std::cout << "thread_4_gmem_4_layout: " << std::endl;
cute::print(thread_4_gmem_4_layout);
std::cout << std::endl;

std::cout << "The stride between coordinate 0 and 1: " << std::endl;
std::cout << "thread_1_gmem_1_layout: ";
std::cout << thread_1_gmem_1_layout(1) - thread_1_gmem_1_layout(0, 0) << std::endl;
std::cout << "thread_1_gmem_2_layout: ";
std::cout << thread_1_gmem_2_layout(1) - thread_1_gmem_2_layout(0, 0) << std::endl;
std::cout << "thread_1_gmem_3_layout: ";
std::cout << thread_1_gmem_3_layout(1) - thread_1_gmem_3_layout(0, 0) << std::endl;
std::cout << "thread_1_gmem_4_layout: ";
std::cout << thread_1_gmem_4_layout(1) - thread_1_gmem_4_layout(0, 0) << std::endl;
std::cout << "thread_2_gmem_1_layout: ";
std::cout << thread_2_gmem_1_layout(1) - thread_2_gmem_1_layout(0, 0) << std::endl;
std::cout << "thread_2_gmem_2_layout: ";
std::cout << thread_2_gmem_2_layout(1) - thread_2_gmem_2_layout(0, 0) << std::endl;
std::cout << "thread_2_gmem_3_layout: ";
std::cout << thread_2_gmem_3_layout(1) - thread_2_gmem_3_layout(0, 0) << std::endl;
std::cout << "thread_2_gmem_4_layout: ";
std::cout << thread_2_gmem_4_layout(1) - thread_2_gmem_4_layout(0, 0) << std::endl;
std::cout << "thread_3_gmem_1_layout: ";
std::cout << thread_3_gmem_1_layout(1) - thread_3_gmem_1_layout(0, 0) << std::endl;
std::cout << "thread_3_gmem_2_layout: ";
std::cout << thread_3_gmem_2_layout(1) - thread_3_gmem_2_layout(0, 0) << std::endl;
std::cout << "thread_3_gmem_3_layout: ";
std::cout << thread_3_gmem_3_layout(1) - thread_3_gmem_3_layout(0, 0) << std::endl;
std::cout << "thread_3_gmem_4_layout: ";
std::cout << thread_3_gmem_4_layout(1) - thread_3_gmem_4_layout(0, 0) << std::endl;
std::cout << "thread_4_gmem_1_layout: ";
std::cout << thread_4_gmem_1_layout(1) - thread_4_gmem_1_layout(0, 0) << std::endl;
std::cout << "thread_4_gmem_2_layout: ";
std::cout << thread_4_gmem_2_layout(1) - thread_4_gmem_2_layout(0, 0) << std::endl;
std::cout << "thread_4_gmem_3_layout: ";
std::cout << thread_4_gmem_3_layout(1) - thread_4_gmem_3_layout(0, 0) << std::endl;
std::cout << "thread_4_gmem_4_layout: ";
std::cout << thread_4_gmem_4_layout(1) - thread_4_gmem_4_layout(0, 0) << std::endl;

std::cout << "The stride between index 0 and 1: " << std::endl;
std::cout << "thread_1_gmem_1_layout: ";
std::cout << thread_1_gmem_1_layout(thread_layout_1.get_flat_coord(1), 0) -
thread_1_gmem_1_layout(thread_layout_1.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_1_gmem_2_layout: ";
std::cout << thread_1_gmem_2_layout(thread_layout_2.get_flat_coord(1), 0) -
thread_1_gmem_2_layout(thread_layout_2.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_1_gmem_3_layout: ";
std::cout << thread_1_gmem_3_layout(thread_layout_3.get_flat_coord(1), 0) -
thread_1_gmem_3_layout(thread_layout_3.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_1_gmem_4_layout: ";
std::cout << thread_1_gmem_4_layout(thread_layout_4.get_flat_coord(1), 0) -
thread_1_gmem_4_layout(thread_layout_4.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_2_gmem_1_layout: ";
std::cout << thread_2_gmem_1_layout(thread_layout_1.get_flat_coord(1), 0) -
thread_2_gmem_1_layout(thread_layout_1.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_2_gmem_2_layout: ";
std::cout << thread_2_gmem_2_layout(thread_layout_2.get_flat_coord(1), 0) -
thread_2_gmem_2_layout(thread_layout_2.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_2_gmem_3_layout: ";
std::cout << thread_2_gmem_3_layout(thread_layout_3.get_flat_coord(1), 0) -
thread_2_gmem_3_layout(thread_layout_3.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_2_gmem_4_layout: ";
std::cout << thread_2_gmem_4_layout(thread_layout_4.get_flat_coord(1), 0) -
thread_2_gmem_4_layout(thread_layout_4.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_3_gmem_1_layout: ";
std::cout << thread_3_gmem_1_layout(thread_layout_1.get_flat_coord(1), 0) -
thread_3_gmem_1_layout(thread_layout_1.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_3_gmem_2_layout: ";
std::cout << thread_3_gmem_2_layout(thread_layout_2.get_flat_coord(1), 0) -
thread_3_gmem_2_layout(thread_layout_2.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_3_gmem_3_layout: ";
std::cout << thread_3_gmem_3_layout(thread_layout_3.get_flat_coord(1), 0) -
thread_3_gmem_3_layout(thread_layout_3.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_3_gmem_4_layout: ";
std::cout << thread_3_gmem_4_layout(thread_layout_4.get_flat_coord(1), 0) -
thread_3_gmem_4_layout(thread_layout_4.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_4_gmem_1_layout: ";
std::cout << thread_4_gmem_1_layout(thread_layout_1.get_flat_coord(1), 0) -
thread_4_gmem_1_layout(thread_layout_1.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_4_gmem_2_layout: ";
std::cout << thread_4_gmem_2_layout(thread_layout_2.get_flat_coord(1), 0) -
thread_4_gmem_2_layout(thread_layout_2.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_4_gmem_3_layout: ";
std::cout << thread_4_gmem_3_layout(thread_layout_3.get_flat_coord(1), 0) -
thread_4_gmem_3_layout(thread_layout_3.get_flat_coord(0), 0)
<< std::endl;
std::cout << "thread_4_gmem_4_layout: ";
std::cout << thread_4_gmem_4_layout(thread_layout_4.get_flat_coord(1), 0) -
thread_4_gmem_4_layout(thread_layout_4.get_flat_coord(0), 0)
<< std::endl;
}
gmem_layout_1:
(_256,_32):(_1,_256)
gmem_layout_2:
(_256,_32):(_32,_1)
gmem_layout_3:
(_32,_256):(_1,_32)
gmem_layout_4:
(_32,_256):(_256,_1)
thread_layout_1:
(_32,_8):(_1,_32)
thread_layout_2:
(_32,_8):(_8,_1)
thread_layout_3:
(_8,_32):(_1,_8)
thread_layout_4:
(_8,_32):(_32,_1)
thread_1_gmem_1_layout:
((_32,_8),_32):((_1,_32),_256)
thread_1_gmem_2_layout:
((_32,_8),_32):((_32,_1024),_1)
thread_1_gmem_3_layout:
((_32,_8),_32):((_1,_32),_256)
thread_1_gmem_4_layout:
((_32,_8),_32):((_256,_1),_8)
thread_2_gmem_1_layout:
((_32,_8),_32):((_8,_1),_256)
thread_2_gmem_2_layout:
((_32,_8),_32):((_256,_32),_1)
thread_2_gmem_3_layout:
((_32,_8),_32):((_8,_1),_256)
thread_2_gmem_4_layout:
(((_4,_8),_8),_32):(((_2048,_1),_256),_8)
thread_3_gmem_1_layout:
((_8,_32),_32):((_1,_8),_256)
thread_3_gmem_2_layout:
((_8,_32),_32):((_32,_256),_1)
thread_3_gmem_3_layout:
((_8,_32),_32):((_1,_8),_256)
thread_3_gmem_4_layout:
((_8,(_4,_8)),_32):((_256,(_2048,_1)),_8)
thread_4_gmem_1_layout:
((_8,_32),_32):((_32,_1),_256)
thread_4_gmem_2_layout:
((_8,_32),_32):((_1024,_32),_1)
thread_4_gmem_3_layout:
((_8,_32),_32):((_32,_1),_256)
thread_4_gmem_4_layout:
((_8,_32),_32):((_1,_256),_8)
The stride between coordinate 0 and 1:
thread_1_gmem_1_layout: 1
thread_1_gmem_2_layout: 32
thread_1_gmem_3_layout: 1
thread_1_gmem_4_layout: 256
thread_2_gmem_1_layout: 8
thread_2_gmem_2_layout: 256
thread_2_gmem_3_layout: 8
thread_2_gmem_4_layout: 2048
thread_3_gmem_1_layout: 1
thread_3_gmem_2_layout: 32
thread_3_gmem_3_layout: 1
thread_3_gmem_4_layout: 256
thread_4_gmem_1_layout: 32
thread_4_gmem_2_layout: 1024
thread_4_gmem_3_layout: 32
thread_4_gmem_4_layout: 1
The stride between index 0 and 1:
thread_1_gmem_1_layout: 1
thread_1_gmem_2_layout: 1024
thread_1_gmem_3_layout: 1
thread_1_gmem_4_layout: 1
thread_2_gmem_1_layout: 8
thread_2_gmem_2_layout: 32
thread_2_gmem_3_layout: 8
thread_2_gmem_4_layout: 256
thread_3_gmem_1_layout: 1
thread_3_gmem_2_layout: 256
thread_3_gmem_3_layout: 1
thread_3_gmem_4_layout: 2048
thread_4_gmem_1_layout: 32
thread_4_gmem_2_layout: 32
thread_4_gmem_3_layout: 32
thread_4_gmem_4_layout: 256

From the stride between the data elements that threads 0 and 1 access, we can see that using the index as input does not result in a higher probability of coalesced data access from adjacent threads. This makes it even more puzzling why local_partition is needed. In the next example, we will see why local_partition is useful.

Threads Slicing Same Element In Tile (Broadcast)

Suppose we want to perform a GEMM operation, with two input tensors dataA and dataB and an output tensor dataC. The problem is partitioned in such a way that each thread in the thread block computes a tile of the output tensor dataC by accumulating the products of the corresponding tiles of the input tensors dataA and dataB. Suppose there are $2 \times 16 = 32$ threads in the thread block, the output tensor dataC is partitioned into $2 \times 16 = 32$ two-dimensional tiles, and the input tensors dataA and dataB are partitioned into $2 \times 1 = 2$ and $16 \times 1 = 16$ two-dimensional tiles, respectively. This means that the 32 threads have to access not only the tiles of the output tensor dataC but also the tiles of the input tensors dataA and dataB correctly, using the thread index. Slicing the tiles of the output tensor dataC is straightforward using the local_partition function described above. However, we need to be a little more careful when slicing the tiles of the input tensors dataA and dataB. Because there are only $2$ tiles of the input tensor dataA, 16 threads must access the first tile of the input tensor dataA in the same way, and the other 16 threads must access the second tile of the input tensor dataA in the same way. Suppose the thread layout is $(2, 16) : (16, 1)$, i.e., row-major, for the 32 threads. Then the threads whose indices are from $0$ to $15$ have to access the first tile of the input tensor dataA, and the threads whose indices are from $16$ to $31$ have to access the second tile of the input tensor dataA. How can we ensure that such partitioning and tile slicing using the index is done correctly?

For convenience, there is an overload of the local_partition function that takes an additional projection parameter to strip out unwanted tiling modes and then calls the local_partition function described previously. It can be used in this situation.

// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
// when using projections of the same tiler.
// This is typical at the Thread level where data is partitioned across projected layouts of threads:
// Tensor dataA = ... // (M,K)
// Tensor dataB = ... // (N,K)
// Tensor dataC = ... // (M,N)
// auto thr_layout = Layout<Shape<_2,_16,_1>, Stride<_16,_1,_0>>{};
// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1)
// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1)
// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16)
template <class Tensor, class LShape, class LStride, class Index, class Projection,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index,  // index to slice for
                Projection             const& proj)
{
  return local_partition(static_cast<Tensor&&>(tensor),
                         dice(proj, tile),
                         index);
}

This overloaded local_partition function calls the original local_partition function with a tile layout whose unwanted tiling modes have been stripped out by the dice function at the X positions. For example, $\text{dice}((2, 16, 1):(16, 1, 0), (1, X, 1)) = (2, 1):(16, 0)$, $\text{dice}((2, 16, 1):(16, 1, 0), (X, 1, 1)) = (16, 1):(1, 0)$, and $\text{dice}((2, 16, 1):(16, 1, 0), (1, 1, X)) = (2, 16):(16, 1)$. This allows the partition, i.e., the zipped_divide called in outer_partition, to work correctly for the input tensors dataA and dataB and the output tensor dataC in the GEMM operation. But what about the tile slicing? Since the index remains unchanged, it seems to be incompatible with the tile layout after the mode removal. Will this affect the correct coordinate conversion using tile.get_flat_coord(index)? The answer is no.
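
To make the effect of dice concrete, the following is a small host-side sketch, added purely for illustration, that prints the three diced layouts listed above.

#include <cute/tensor.hpp>

int main()
{
    using namespace cute;

    // The thread layout from the GEMM example: (2, 16, 1) : (16, 1, 0).
    auto thr_layout = Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>{};

    // dice keeps only the modes whose projection is not X.
    print(dice(Step<_1,  X, _1>{}, thr_layout)); print("\n");  // (_2,_1):(_16,_0)   used for dataA
    print(dice(Step< X, _1, _1>{}, thr_layout)); print("\n");  // (_16,_1):(_1,_0)   used for dataB
    print(dice(Step<_1, _1,  X>{}, thr_layout)); print("\n");  // (_2,_16):(_16,_1)  used for dataC
}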

As mentioned in the previous article "CuTe Index To Coordinate", given a tile layout with shape $(M_0, M_1, \ldots, M_{\alpha})$ and stride $(d_0, d_1, \ldots, d_{\alpha})$, the coordinate $\mathbf{x}$ corresponding to the index $f_L(\mathbf{x})$ is computed as follows:

$$
\begin{align}
\mathbf{x} &= (x_0, x_1, \ldots, x_{\alpha}) \\
&= \left(\left\lfloor \frac{f_L{(\mathbf{x})}}{d_0} \right\rfloor \mod M_0, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_1} \right\rfloor \mod M_1, \ldots, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_{\alpha}} \right\rfloor \mod M_{\alpha}\right) \\
\end{align}
$$
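
For example, for the tile layout $(2, 16) : (16, 1)$ used below and the thread index $19$, the formula gives

$$
\begin{align}
x_0 &= \left\lfloor \frac{19}{16} \right\rfloor \mod 2 = 1 \\
x_1 &= \left\lfloor \frac{19}{1} \right\rfloor \mod 16 = 3 \\
\end{align}
$$

so thread $19$ maps to the coordinate $(1, 3)$, which agrees with the index-to-coordinate table shown later in this section.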

Thus, even if some of the modes are stripped out, the coordinate is still computed correctly for the remaining modes, which ensures that the tile slicing is correct for all the threads in the thread block.

Mathematically, to compute the coordinate, it does not matter whether the tile layout is diced first and the index is then converted to a coordinate, or the index is first converted to a coordinate and the unwanted coordinate modes are then diced out. In other words, we have the following equivalence:

dice(proj, tile).get_flat_coord(index) == dice(proj, tile.get_flat_coord(index))
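
As a quick sanity check of this equivalence, the following is a small host-side sketch, added purely for illustration. It uses a simplified two-mode thread layout $(2, 16) : (16, 1)$, i.e., the GEMM thread layout with its trailing size-1 mode dropped, and the projection that keeps only the first mode.

#include <cute/tensor.hpp>

int main()
{
    using namespace cute;

    // Simplified two-mode thread layout (2, 16) : (16, 1).
    auto thr_layout = Layout<Shape<_2, _16>, Stride<_16, _1>>{};
    // Projection that keeps the first mode and strips out the second mode.
    auto proj = Step<_1, X>{};

    for (int thr_idx = 0; thr_idx < 32; ++thr_idx)
    {
        // Dice the tile layout first, then convert the index to a coordinate.
        auto lhs = dice(proj, thr_layout).get_flat_coord(thr_idx);
        // Convert the index to a coordinate first, then dice the coordinate.
        auto rhs = dice(proj, thr_layout.get_flat_coord(thr_idx));
        print(lhs); print(" "); print(rhs); print("\n");
    }
    // Both columns print the same coordinate: threads 0 through 15 map to
    // coordinate 0 and threads 16 through 31 map to coordinate 1.
}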

In our case, the original tile layout is essentially $(2, 16):(16, 1)$, because the trailing size-1, stride-0 mode contributes nothing, and it has the following index-to-coordinate mapping:

0  -> (0, 0)
1 -> (0, 1)
2 -> (0, 2)
3 -> (0, 3)
4 -> (0, 4)
5 -> (0, 5)
6 -> (0, 6)
7 -> (0, 7)
8 -> (0, 8)
9 -> (0, 9)
10 -> (0, 10)
11 -> (0, 11)
12 -> (0, 12)
13 -> (0, 13)
14 -> (0, 14)
15 -> (0, 15)
16 -> (1, 0)
17 -> (1, 1)
18 -> (1, 2)
19 -> (1, 3)
20 -> (1, 4)
21 -> (1, 5)
22 -> (1, 6)
23 -> (1, 7)
24 -> (1, 8)
25 -> (1, 9)
26 -> (1, 10)
27 -> (1, 11)
28 -> (1, 12)
29 -> (1, 13)
30 -> (1, 14)
31 -> (1, 15)

Each thread in the thread block slices a different element in the tile of the output tensor dataC, which is exactly the "threads slicing different elements in the tile" scenario discussed in the previous section.

After the second mode is stripped out for the input tensor dataA, the tile layout essentially becomes $2:16$, and the index-to-coordinate mapping becomes the following:

0  -> (0,)
1 -> (0,)
2 -> (0,)
3 -> (0,)
4 -> (0,)
5 -> (0,)
6 -> (0,)
7 -> (0,)
8 -> (0,)
9 -> (0,)
10 -> (0,)
11 -> (0,)
12 -> (0,)
13 -> (0,)
14 -> (0,)
15 -> (0,)
16 -> (1,)
17 -> (1,)
18 -> (1,)
19 -> (1,)
20 -> (1,)
21 -> (1,)
22 -> (1,)
23 -> (1,)
24 -> (1,)
25 -> (1,)
26 -> (1,)
27 -> (1,)
28 -> (1,)
29 -> (1,)
30 -> (1,)
31 -> (1,)

This is as if the unwanted modes of the coordinate computed using the original tile layout were stripped out. Multiple threads in the thread block slice the same element in the tile of the input tensor dataA. This is called a broadcast in CUDA and it is usually not a costly operation. For example, if the data is in shared memory, accessing the same element in the tile from different threads in a warp will not result in shared memory bank conflicts.

In this example, we can see that the local_partition function significantly simplifies the partitioning code that the user has to write.

Using the local_partition function with the thread index as input, the user only has to write the following code to partition the output tensor dataC and the input tensors dataA and dataB:

Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});  // (M/2,K/1)
Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1)
Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16)

Using the outer_partition function, the user would have to implement the following code to partition the output tensor dataC and the input tensors dataA and dataB:

Tensor thrA = outer_partition(dataA, product_each(shape(dice(Step<_1, X,_1>{}, thr_layout))), dice(Step<_1, X,_1>{}, thr_layout).get_flat_coord(thr_idx));  // (M/2,K/1)
Tensor thrB = outer_partition(dataB, product_each(shape(dice(Step< X,_1,_1>{}, thr_layout))), dice(Step< X,_1,_1>{}, thr_layout).get_flat_coord(thr_idx));  // (N/16,K/1)
Tensor thrC = outer_partition(dataC, product_each(shape(dice(Step<_1,_1, X>{}, thr_layout))), dice(Step<_1,_1, X>{}, thr_layout).get_flat_coord(thr_idx));  // (M/2,N/16)

Here, the developer cannot use the thread index directly as a 1D coordinate for the outer_partition function. Instead, the developer has to manage the coordinate conversion and strip out the unwanted modes explicitly. This seems to be overkill and is less elegant from the perspective of API design. Therefore, the local_partition function was introduced to simplify the implementation for developers. There is, of course, a trade-off for the abstraction. Although the implementation becomes much more concise, it becomes less clear to the user what is happening under the hood.
