CuTe Local Partition

Introduction

The CuTe outer_partition function is often used at the thread level to partition a tensor into tiles and slice the tile(s) based on the coordinate of the thread in the thread block. The CuTe local_partition function is a convenient wrapper around outer_partition that allows the developer to partition a tensor into tiles and slice the tile(s) based on the index of the thread in the thread block.

In this article, I would like to discuss the local_partition function in CuTe and how exactly it works.

CuTe Local Partition

CuTe Local Partition Implementation

The implementation of local_partition in CuTe is as follows. Essentially, local_partition flattens the shape of each mode of the tile layout and keeps only the shape, so that the tiler it builds spans the tile along each mode with a stride of 1; it then converts the thread index back to a coordinate by inverting the tile layout function, and calls outer_partition to partition and slice the tile(s). outer_partition, in turn, performs zipped_divide and slices into the first mode (the “Tile” mode) with the coordinate, keeping everything in the remaining, outer mode. That’s why it’s called an “outer” partition.

// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes.
// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor.
// Split the modes of tensor according to the Tiler
//   zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
// Then slice into the first mode (the "Tile" mode) with Coord
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
outer_partition(Tensor     && tensor,
                Tiler  const& tiler,
                Coord  const& coord)
{
  auto tensor_tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;

  // The coord slices into the first mode (the "tile" mode), flatten the second
  if constexpr (is_tuple<Coord>::value) {
    // Append trailing modes if coord is tuple
    constexpr int R0 = decltype(rank<0>(tensor_tiled))::value;
    return tensor_tiled(append<R0>(coord,_), repeat<R1>(_));
  } else {
    // Flat indexing if coord is not tuple
    return tensor_tiled(coord, repeat<R1>(_));
  }
}

// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index.
// This is typical at the Thread level where data is partitioned across repeated patterns of threads:
//   Tensor data = ...                                                          // (_16,_64)
//   Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx); // ( _8, _4)
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index)  // index to slice for
{
  static_assert(is_integral<Index>::value);
  return outer_partition(static_cast<Tensor&&>(tensor),
                         product_each(shape(tile)),
                         tile.get_flat_coord(index));
}
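To make this concrete, the example from the documentation comment above can be reproduced on the host. The following is a minimal sketch of my own, assuming the CuTe headers are available and a C++17 host compiler is used; the buffer and variable names are arbitrary.

#include <cute/tensor.hpp>

#include <vector>

int main()
{
  using namespace cute;

  // A (16,64) column-major tensor backed by host memory.
  std::vector<float> buf(16 * 64);
  Tensor data = make_tensor(buf.data(), make_layout(make_shape(Int<16>{}, Int<64>{})));  // (_16,_64)

  int thr_idx = 3;  // any thread index in [0, 32)
  Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx);

  // Each thread owns an ( _8, _4) slice of the (_16,_64) tensor.
  print(layout(thr_data));  // (_8,_4):(_2,_256)
  print("\n");

  return 0;
}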

The tile argument is processed with product_each(shape(tile)) before being passed to outer_partition because tile is a layout, whereas outer_partition expects a tiler: product_each(shape(tile)) flattens the shape of each mode into a single extent, and the resulting shape is then used as a tiler for by-mode division. The tiling performed by by-mode division with a tiler constructed from a plain shape is relatively straightforward to picture.
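To see what product_each(shape(tile)) produces, here is a small host-side sketch of my own (not from the article) that prints the tiler obtained from a flat thread layout and from a nested one; the layouts are arbitrary examples.

#include <cute/tensor.hpp>

int main()
{
  using namespace cute;

  // Flat thread layout: product_each simply returns its shape.
  auto flat_tile = make_layout(make_shape(Int<2>{}, Int<16>{}),
                               make_stride(Int<16>{}, Int<1>{}));  // (2,16):(16,1)
  print(product_each(shape(flat_tile)));  // (_2,_16)
  print("\n");

  // Nested thread layout: each mode's shape is flattened to a single extent,
  // so the result can be used as a by-mode tiler.
  auto nested_tile = make_layout(make_shape(make_shape(Int<2>{}, Int<2>{}), Int<8>{}));
  print(product_each(shape(nested_tile)));  // (_4,_8)
  print("\n");

  return 0;
}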

Given that the domain and the codomain of a thread layout, i.e., the tile used for local_partition, usually have the same size, it might seem like a waste of effort to call tile.get_flat_coord(index) to recover the coordinate from the index. However, local_partition never requires the tile to be compact. If the tile is not compact, using the index directly as a coordinate for outer_partition would go outside the coordinate space.

Threads Slicing Different Elements In Tile

In a thread block, given a tensor in global memory and a tensor in shared memory, we would like all the threads in the thread block to cooperatively and efficiently copy the data from the global memory tensor to the shared memory tensor. The layout of threads in the thread block is the tile layout, and we will use outer partition to partition and slice the tile(s). The goal is that, when iterating through the tile(s), each thread in the thread block slices one element from the global memory tensor and adjacent threads slice adjacent elements from the global memory tensor, so that the global memory tensor can be accessed efficiently in a coalesced fashion. The adjacency of the elements in the global memory tensor is determined by the index instead of the coordinate. If 8 adjacent threads, whose indices are $i$, $i+1$, $i+2$, $i+3$, $i+4$, $i+5$, $i+6$, $i+7$, access 8 elements of the global memory tensor whose indices are $j$, $j+1$, $j+2$, $j+3$, $j+4$, $j+5$, $j+6$, $j+7$, then the global memory access is coalesced.

Depending on the implementation, the layout of threads in the thread block can be column-major, row-major, or something more complicated. For example, a thread block consisting of 32 threads can have a column-major layout of $(8, 4) : (1, 8)$ or a row-major layout of $(4, 8) : (8, 1)$. In the column-major layout, the 1D coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$, which are adjacent indices. However, in the row-major layout, the 1D coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $8$, $16$, $24$, $1$, $9$, $17$, $25$, which are not adjacent indices. Therefore, without local_partition, the developer would have to carefully compute, for each thread, the coordinate that corresponds to its index by inverting the thread (tile) layout function, i.e., tile.get_flat_coord(index), which is exactly what local_partition already does for us.
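The contrast can be verified on the host. The sketch below is my own illustration (assuming the CuTe headers are available): it prints, for both thread layouts above, the index that each 1D coordinate maps to, as well as the coordinate that get_flat_coord recovers from each thread index.

#include <cute/tensor.hpp>

#include <cstdio>

int main()
{
  using namespace cute;

  auto col_major = make_layout(make_shape(Int<8>{}, Int<4>{}),
                               make_stride(Int<1>{}, Int<8>{}));  // (8,4):(1,8)
  auto row_major = make_layout(make_shape(Int<4>{}, Int<8>{}),
                               make_stride(Int<8>{}, Int<1>{}));  // (4,8):(8,1)

  for (int k = 0; k < 8; ++k)
  {
    // layout(k) maps the 1D (colexicographic) coordinate k to an index.
    printf("coord %d -> col-major index %d, row-major index %d, row-major get_flat_coord(%d) = ",
           k, int(col_major(k)), int(row_major(k)), k);
    // get_flat_coord(k) maps the index k back to a 2D coordinate.
    print(row_major.get_flat_coord(k));
    print("\n");
  }

  return 0;
}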

Assuming the thread layout is “tight”, i.e., compact, how the tile.get_flat_coord(index) function computes the coordinate from the index is mathematically not very complicated, and its derivation was discussed in my previous article “CuTe Index To Coordinate”.

Whether the global memory access is coalesced is determined by both the tensor layout and the thread layout. Threads that have adjacent indices will not necessarily access adjacent elements in the global memory tensor; one has to examine tensor_tiled to know definitively. In most cases, however, if the tensor layout is column-major and the thread layout is also column-major, or the tensor layout is row-major and the thread layout is also row-major, then the threads with adjacent indices will access adjacent elements in the global memory tensor.
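One way to check this is to simply print the global memory offsets that the first few threads touch. The sketch below is my own (the tensor sizes and the column-major thread layout are arbitrary choices): the data values are set to their linear offsets, so the first element of each thread's partition directly reveals where that thread starts reading.

#include <cute/tensor.hpp>

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
  using namespace cute;

  // A column-major (16,64) "global memory" tensor whose values are their own linear offsets.
  std::vector<int> buf(16 * 64);
  std::iota(buf.begin(), buf.end(), 0);
  Tensor data = make_tensor(buf.data(), make_layout(make_shape(Int<16>{}, Int<64>{})));

  // Column-major thread layout (8,4):(1,8) for 32 threads.
  auto thr_layout = make_layout(make_shape(Int<8>{}, Int<4>{}),
                                make_stride(Int<1>{}, Int<8>{}));

  for (int thr_idx = 0; thr_idx < 8; ++thr_idx)
  {
    Tensor thr_data = local_partition(data, thr_layout, thr_idx);  // (2,16) per thread
    // Adjacent thread indices starting at adjacent offsets means the first
    // access of each tile iteration is coalesced.
    printf("thread %d -> first element at offset %d\n", thr_idx, thr_data(0));
  }

  return 0;
}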

Threads Slicing Same Element In Tile (Broadcast)

Suppose we want to perform a GEMM operation and we have two input tensors dataA and dataB, and an output tensor dataC. The problem is partitioned in a way such that each thread in the thread block computes a tile of the output tensor dataC by accumulating the products of the corresponding tiles of the input tensors dataA and dataB. Suppose there are $2 \times 16 = 32$ threads in the thread block, the output tensor dataC is partitioned into $2 \times 16 = 32$ two-dimensional tiles, and the input tensors dataA and dataB are partitioned into $2 \times 1 = 2$ and $16 \times 1 = 16$ two-dimensional tiles, respectively. This means the 32 threads have to access not only the tiles of the output tensor dataC but also the tiles of the input tensors dataA and dataB correctly, using the thread index. Slicing the tiles of the output tensor dataC is straightforward using the local_partition function mentioned above. However, we need to be a little more careful when slicing the tiles of the input tensors dataA and dataB. Because there are only $2$ tiles of the input tensor dataA, 16 threads have to access the first tile of the input tensor dataA in the same way, and the other 16 threads have to access the second tile of the input tensor dataA in the same way. Suppose the thread layout for the 32 threads is $(2, 16) : (16, 1)$, i.e., row-major; then the threads whose indices are from $0$ to $15$ have to access the first tile of the input tensor dataA and the threads whose indices are from $16$ to $31$ have to access the second tile of the input tensor dataA. How can we ensure that such a partition and tile slice using the index is done correctly?

There is an overloaded local_partition function that takes an additional projection parameter to strip out unwanted tiling modes and then calls the local_partition function described previously. It is exactly what we need in this situation.

// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
// when using projections of the same tiler.
// This is typical at the Thread level where data is partitioned across projected layouts of threads:
//   Tensor dataA = ...                                                          // (M,K)
//   Tensor dataB = ...                                                          // (N,K)
//   Tensor dataC = ...                                                          // (M,N)
//   auto thr_layout = Layout<Shape<_2,_16,_1>, Stride<_16,_1,_0>>{};
//   Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1)
//   Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1)
//   Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16)
template <class Tensor, class LShape, class LStride, class Index, class Projection,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index,  // index to slice for
                Projection             const& proj)
{
  return local_partition(static_cast<Tensor&&>(tensor),
                         dice(proj, tile),
                         index);
}

This overloaded local_partition function calls the original local_partition function with a tile layout whose unwanted tiling modes, marked by X in the projection, have been stripped out by the dice function. For example, $\text{dice}((1, X, 1), (2, 16, 1):(16, 1, 0)) = (2, 1):(16, 0)$, $\text{dice}((X, 1, 1), (2, 16, 1):(16, 1, 0)) = (16, 1):(1, 0)$, and $\text{dice}((1, 1, X), (2, 16, 1):(16, 1, 0)) = (2, 16):(16, 1)$. This allows the partition, i.e., the zipped_divide called in outer_partition, to work correctly for the input tensors dataA and dataB and the output tensor dataC in the GEMM operation. But what about the tile slicing? Since the index remains unchanged, it might seem incompatible with the tile layout after the mode removal. Will this affect the correctness of the coordinate conversion via tile.get_flat_coord(index)? The answer is no.
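These diced layouts can be printed directly. The sketch below is my own; it uses cute::Underscore explicitly, which is the type that the X alias in the CuTe comments refers to.

#include <cute/tensor.hpp>

int main()
{
  using namespace cute;

  auto thr_layout = Layout<Shape<_2,_16,_1>, Stride<_16,_1,_0>>{};  // (2,16,1):(16,1,0)

  print(dice(Step<_1, Underscore, _1>{}, thr_layout));  // (_2,_1):(_16,_0), used for dataA
  print("\n");
  print(dice(Step<Underscore, _1, _1>{}, thr_layout));  // (_16,_1):(_1,_0), used for dataB
  print("\n");
  print(dice(Step<_1, _1, Underscore>{}, thr_layout));  // (_2,_16):(_16,_1), used for dataC
  print("\n");

  return 0;
}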

As mentioned in the previous article “CuTe Index To Coordinate”, given a tile layout with shape $(M_0, M_1, \ldots, M_{\alpha})$ and stride $(d_0, d_1, \ldots, d_{\alpha})$, the coordinate $\mathbf{x}$ corresponding to the index $f_L(\mathbf{x})$ is computed as follows:

$$
\begin{align}
\mathbf{x} &= (x_0, x_1, \ldots, x_{\alpha}) \\
&= \left(\left\lfloor \frac{f_L{(\mathbf{x})}}{d_0} \right\rfloor \mod M_0, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_1} \right\rfloor \mod M_1, \ldots, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_{\alpha}} \right\rfloor \mod M_{\alpha}\right) \\
\end{align}
$$
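
As a quick sanity check of the formula (a worked example of my own), take the tile layout $(2, 16):(16, 1)$ and the index $19$:

$$
\mathbf{x} = \left( \left\lfloor \frac{19}{16} \right\rfloor \mod 2, \left\lfloor \frac{19}{1} \right\rfloor \mod 16 \right) = (1, 3)
$$

Indeed, $1 \times 16 + 3 \times 1 = 19$. For the diced layout $2:16$, the same index gives $\left\lfloor 19 / 16 \right\rfloor \mod 2 = 1$, which is exactly the first component of $(1, 3)$.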

Thus, even if some of the modes are stripped out, the coordinate is still computed correctly for the remaining modes, because each component of the coordinate depends only on the shape and stride of its own mode. This ensures that the tile slicing is correct for all the threads in the thread block.

Mathematically, to compute the coordinate, it does not matter whether the unwanted modes are stripped from the tile layout first and the coordinate is then computed, or the coordinate is computed first and its unwanted modes are then stripped out. In other words, we have the following equivalence:

dice(proj, tile).get_flat_coord(index) == dice(proj, tile.get_flat_coord(index))
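
This equivalence can be checked numerically. The sketch below is my own; it uses the 2D simplification $(2, 16):(16, 1)$ of the thread layout and the projection Step<_1, X> (with X written as cute::Underscore), and prints both sides of the equivalence for every thread index.

#include <cute/tensor.hpp>

int main()
{
  using namespace cute;

  auto tile = make_layout(make_shape(Int<2>{}, Int<16>{}),
                          make_stride(Int<16>{}, Int<1>{}));  // (2,16):(16,1)
  auto proj = Step<_1, Underscore>{};

  const int num_threads = size(tile);  // 32
  for (int index = 0; index < num_threads; ++index)
  {
    // Dice the layout first, then convert the index to a coordinate ...
    auto coord_a = dice(proj, tile).get_flat_coord(index);
    // ... or convert the index to a coordinate first, then dice the coordinate.
    auto coord_b = dice(proj, tile.get_flat_coord(index));
    print(coord_a); print(" == "); print(coord_b); print("\n");
  }

  return 0;
}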

In our case, ignoring the trailing size-1, stride-0 mode, the original tile layout is $(2, 16):(16, 1)$, and it has the following index-to-coordinate mapping:

0  -> (0, 0)
1 -> (0, 1)
2 -> (0, 2)
3 -> (0, 3)
4 -> (0, 4)
5 -> (0, 5)
6 -> (0, 6)
7 -> (0, 7)
8 -> (0, 8)
9 -> (0, 9)
10 -> (0, 10)
11 -> (0, 11)
12 -> (0, 12)
13 -> (0, 13)
14 -> (0, 14)
15 -> (0, 15)
16 -> (1, 0)
17 -> (1, 1)
18 -> (1, 2)
19 -> (1, 3)
20 -> (1, 4)
21 -> (1, 5)
22 -> (1, 6)
23 -> (1, 7)
24 -> (1, 8)
25 -> (1, 9)
26 -> (1, 10)
27 -> (1, 11)
28 -> (1, 12)
29 -> (1, 13)
30 -> (1, 14)
31 -> (1, 15)

Each thread in the thread block will slice a different element in the tile of the output tensor dataC, which is exactly the “threads slicing different elements in the tile” case discussed in the previous section.

After the second mode is stripped out, the tile layout becomes $2:16$, and the index to coordinate mapping becomes as follows:

0  -> (0,)
1 -> (0,)
2 -> (0,)
3 -> (0,)
4 -> (0,)
5 -> (0,)
6 -> (0,)
7 -> (0,)
8 -> (0,)
9 -> (0,)
10 -> (0,)
11 -> (0,)
12 -> (0,)
13 -> (0,)
14 -> (0,)
15 -> (0,)
16 -> (1,)
17 -> (1,)
18 -> (1,)
19 -> (1,)
20 -> (1,)
21 -> (1,)
22 -> (1,)
23 -> (1,)
24 -> (1,)
25 -> (1,)
26 -> (1,)
27 -> (1,)
28 -> (1,)
29 -> (1,)
30 -> (1,)
31 -> (1,)

This is as if the unwanted modes of the coordinate computed using the original tile layout were stripped out. Multiple threads in the thread block will slice the same element in the tile of the input tensor dataA. This is called a broadcast in CUDA and is usually not a costly operation. For example, if the data resides in shared memory, different threads in a warp accessing the same element in the tile will not cause shared memory bank conflicts.
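The broadcast behavior can also be observed on the host. The sketch below is my own; it uses a small $(4, 8)$ dataA whose values are their own linear offsets and the 2D simplification $(2, 16):(16, 1)$ of the thread layout, so that the first element of each thread's partition shows which tile of dataA the thread reads.

#include <cute/tensor.hpp>

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
  using namespace cute;

  // A column-major (M,K) = (4,8) dataA whose values are their own linear offsets.
  std::vector<int> buf(4 * 8);
  std::iota(buf.begin(), buf.end(), 0);
  Tensor dataA = make_tensor(buf.data(), make_layout(make_shape(Int<4>{}, Int<8>{})));

  // 2D simplification of the thread layout: (2,16):(16,1).
  auto thr_layout = make_layout(make_shape(Int<2>{}, Int<16>{}),
                                make_stride(Int<16>{}, Int<1>{}));

  int thr_indices[] = {0, 5, 15, 16, 31};
  for (int thr_idx : thr_indices)
  {
    // Project out the second mode: all threads with the same M coordinate get the same slice.
    Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, Underscore>{});
    printf("thread %2d -> first element of its dataA slice is at offset %d\n",
           thr_idx, thrA(0));
  }

  return 0;
}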

In this example, we can see that the local_partition function actually simplifies the partitioning code that the user has to write.

Using the local_partition function with the thread index as input, the user only has to write the following code to partition the output tensor dataC and the input tensors dataA and dataB:

Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});  // (M/2,K/1)
Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1)
Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16)

Using the outer_partition function, the user would have to implement the following code to partition the output tensor dataC and the input tensors dataA and dataB:

Tensor thrA = outer_partition(dataA, product_each(shape(dice(Step<_1, X,_1>{}, thr_layout))), dice(Step<_1, X,_1>{}, thr_layout.get_flat_coord(thr_idx)));  // (M/2,K/1)
Tensor thrB = outer_partition(dataB, product_each(shape(dice(Step< X,_1,_1>{}, thr_layout))), dice(Step< X,_1,_1>{}, thr_layout.get_flat_coord(thr_idx))); // (N/16,K/1)
Tensor thrC = outer_partition(dataC, product_each(shape(dice(Step<_1,_1, X>{}, thr_layout))), dice(Step<_1,_1, X>{}, thr_layout.get_flat_coord(thr_idx))); // (M/2,N/16)

Here, the developer cannot pass the thread index directly as a 1D coordinate to the outer_partition function. Instead, the developer has to perform the index-to-coordinate conversion and strip the unwanted modes explicitly. This is verbose and less elegant from the perspective of API design, which is why the local_partition function was introduced to simplify the developer's implementation. There is, of course, a trade-off for abstraction: although the code becomes much more concise, it becomes less clear to the user what is happening under the hood.
