CuTe Local Partition
Introduction
CuTe outer partition is often used at the thread level to partition a tensor into tiles and slice the tile(s) based on the coordinate of the thread in the thread block. The CuTe local_partition function is a convenient wrapper around the outer_partition function that allows the developer to partition a tensor into tiles and slice the tile(s) based on the index of the thread in the thread block.
In this article, I would like to discuss the local_partition function in CuTe and how exactly it works.
CuTe Local Partition
CuTe Local Partition Implementation
The implementation of local_partition in CuTe is as follows. Essentially, local_partition flattens the shape of each mode in the tile layout and keeps only the shape, so that the tile along each mode has a stride of 1, converts the index back to the coordinate by inverting the mapping of the tile layout function, and calls outer_partition to partition and slice the tile(s). outer_partition, on the other hand, performs zipped_divide and slices the tile(s) based on the coordinate in the first tile mode. That's why it's called an "outer" partition.
```cpp
// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes.
```
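The behavior described above can be summarized with the following condensed sketch. The _sketch suffixes are mine and the code is a simplified reconstruction based on the description, not the verbatim CUTLASS implementation; in particular, it assumes the slicing coordinate has the same rank as the tiler:

```cpp
#include <cute/tensor.hpp>

// Condensed sketch of outer_partition and local_partition (simplified; not the verbatim
// CUTLASS implementation).
template <class Tensor, class Tiler, class Coord>
CUTE_HOST_DEVICE constexpr auto
outer_partition_sketch(Tensor&& tensor, Tiler const& tiler, Coord const& coord)
{
  // ((TileM, TileN, ...), (RestM, RestN, ...))
  auto tensor_tiled = cute::zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  // Slice the first ("Tile") mode with the coordinate and keep the "Rest" modes whole,
  // so the slice owns one element of every tile. Hence the name "outer" partition.
  return tensor_tiled(coord, cute::_);
}

template <class Tensor, class LShape, class LStride, class Index>
CUTE_HOST_DEVICE constexpr auto
local_partition_sketch(Tensor&& tensor,
                       cute::Layout<LShape, LStride> const& tile,  // thread coord -> thread index
                       Index const& index)                         // thread index to slice for
{
  // Keep only the (flattened) shape of the thread layout as the tiler, and invert the
  // thread index into a thread coordinate before delegating to the outer partition.
  return outer_partition_sketch(static_cast<Tensor&&>(tensor),
                                cute::product_each(cute::shape(tile)),
                                tile.get_flat_coord(index));
}
```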
The reason why a tile is processed using product_each(shape(tile)) before being passed to outer_partition is probably that tile is a layout, whereas product_each(shape(tile)) is a shape and will be converted to a tiler for by-mode division. The tiling of by-mode division using a tiler constructed from a shape is relatively straightforward to picture.
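For instance, the following snippet of mine (not from the article; the nested thread layout and the $128 \times 128$ tensor layout are made up for illustration) shows how product_each collapses each top-level mode of a layout's shape into a plain extent that can then serve as a by-mode tiler:

```cpp
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    // A thread layout whose first mode is itself nested: ((2,2),8):((1,16),2)
    auto thr_layout = make_layout(make_shape (make_shape (_2{}, _2{}), _8{}),
                                  make_stride(make_stride(_1{}, _16{}), _2{}));
    // product_each collapses each top-level mode of the shape to its total extent: (4,8)
    auto tiler = product_each(shape(thr_layout));
    // A shape such as (4,8) acts as a by-mode tiler: mode 0 is divided by 4, mode 1 by 8.
    auto tiled = zipped_divide(make_layout(make_shape(_128{}, _128{})), tiler);  // shape ((4,8),(32,16))
    print(tiler);  print("\n");
    print(tiled);  print("\n");
    return 0;
}
```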
Given that the domain and the codomain of a thread layout, i.e., the tile used for local_partition, are usually of the same size, it might seem like a waste of effort to use tile.get_flat_coord(index) to compute the coordinate from the index. However, the local_partition function never requires that only a compact tile be used for local_partition. If tile is not compact, using index directly as a coordinate for outer_partition would go outside the coordinate space.
Threads Slicing Different Elements In Tile
In a thread block, given a tensor on global memory and a tensor on shared memory, we would like to use all the threads in the thread block to cooperatively and efficiently copy the data from the global memory tensor to the shared memory tensor. The layout of threads in the thread block is the tile layout and we will use outer partition to partition and slice the tile(s). The goal is that, when iterating through the tile(s), each thread in the thread block slices one element from the global memory tensor and adjacent threads slice adjacent elements from the global memory tensor, so that the global memory tensor can be accessed efficiently in a coalesced fashion. The adjacency of the elements in the global memory tensor is determined by the index instead of the coordinate. If 8 adjacent threads, whose indices are $i$, $i+1$, $\ldots$, $i+7$, access 8 elements in the global memory tensor whose indices are $j$, $j+1$, $\ldots$, $j+7$, then the global memory access is coalesced.
Depending on the implementation, the layout of threads in the thread block can be column-major, row-major, or something even more complicated. For example, a thread block consisting of 32 threads can have a column-major layout of $(8, 4) : (1, 8)$ or a row-major layout of $(4, 8) : (8, 1)$. In the column-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$, which are adjacent indices. However, in the row-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $8$, $16$, $24$, $1$, $9$, $17$, $25$, which are not adjacent indices. Therefore, without local_partition, the developer would have to very carefully compute the coordinates that correspond to the adjacent indices in the thread layout for each thread by inverting the mapping of the tile (thread) layout function, i.e., tile.get_flat_coord(index), which is something that local_partition already does for us.
Assuming the layout of threads is "tight", how the tile.get_flat_coord(index) function computes the coordinate from the index is mathematically not very complicated, and we have discussed its derivation in my previous article "CuTe Index To Coordinate".
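As a quick self-contained illustration of mine (not from this article or the referenced one), the following host-side snippet calls get_flat_coord on the two 32-thread layouts above. For the column-major layout, adjacent indices advance along the first mode, whereas for the row-major layout they advance along the second mode:

```cpp
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
    using namespace cute;
    auto col_major = make_layout(make_shape(_8{}, _4{}), make_stride(_1{}, _8{}));  // (8,4):(1,8)
    auto row_major = make_layout(make_shape(_4{}, _8{}), make_stride(_8{}, _1{}));  // (4,8):(8,1)
    for (int idx = 0; idx < 8; ++idx) {
        auto c = col_major.get_flat_coord(idx);  // (idx, 0) for idx < 8: coordinates along mode 0
        auto r = row_major.get_flat_coord(idx);  // (0, idx) for idx < 8: coordinates along mode 1
        std::printf("index %d -> column-major coord (%d, %d), row-major coord (%d, %d)\n",
                    idx, int(get<0>(c)), int(get<1>(c)), int(get<0>(r)), int(get<1>(r)));
    }
    return 0;
}
```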
Whether the global memory access is coalesced is determined by both the tensor layout and the thread layout. Threads that have adjacent indices will not necessarily access adjacent elements in the global memory tensor; one would have to examine the tensor_tiled to know definitively. But in most cases, if the tensor layout is column-major and the thread layout is also column-major, or the tensor layout is row-major and the thread layout is also row-major, then the threads with adjacent indices will access adjacent elements in the global memory tensor.
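To make this concrete, here is a minimal sketch of such a thread-block copy. The names cooperative_copy, gmem_tile, and smem_tile, as well as the $(32, 8) : (1, 32)$ thread layout, are hypothetical and only for illustration; this is not code from the article:

```cpp
#include <cute/tensor.hpp>

// Minimal sketch (illustrative names; not code from the article): every thread in the
// block copies the elements assigned to it by thr_layout from global to shared memory.
template <class GTensor, class STensor, class ThrLayout>
__device__ void cooperative_copy(GTensor gmem_tile,    // e.g. a (BLK_M, BLK_N) tile in global memory
                                 STensor smem_tile,    // e.g. a (BLK_M, BLK_N) tile in shared memory
                                 ThrLayout thr_layout) // e.g. (32, 8):(1, 32), column-major threads
{
    using namespace cute;
    // Each thread slices one element from every tile produced by dividing with the thread shape.
    Tensor thr_gmem = local_partition(gmem_tile, thr_layout, threadIdx.x);  // (BLK_M/32, BLK_N/8)
    Tensor thr_smem = local_partition(smem_tile, thr_layout, threadIdx.x);  // (BLK_M/32, BLK_N/8)
    // When both the tensor layout and the thread layout are column-major (or both row-major),
    // adjacent threads touch adjacent global memory elements and the loads are coalesced.
    copy(thr_gmem, thr_smem);
}
```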
Threads Slicing Same Element In Tile (Broadcast)
Suppose we want to perform a GEMM operation and we have two input tensors dataA and dataB, and an output tensor dataC. The problem is partitioned in such a way that each thread in the thread block will compute a tile of the output tensor dataC by accumulating the products of the corresponding tiles of the input tensors dataA and dataB. Suppose there are $2 \times 16 = 32$ threads in the thread block, the output tensor dataC is partitioned into $2 \times 16 = 32$ two-dimensional tiles, and the input tensors dataA and dataB are partitioned into $2 \times 1 = 2$ and $16 \times 1 = 16$ two-dimensional tiles, respectively. This means that the 32 threads have to access not only the tiles of the output tensor dataC, but also the tiles of the input tensors dataA and dataB correctly, using the thread index. Slicing the tiles of the output tensor dataC is straightforward using the local_partition function mentioned above. However, we need to be a little more careful when slicing the tiles of the input tensors dataA and dataB. Because there are only $2$ tiles of the input tensor dataA, 16 threads access the first tile of the input tensor dataA in the same way and the other 16 threads access the second tile of the input tensor dataA in the same way. Suppose the thread layout for the 32 threads is $(2, 16) : (16, 1)$, i.e., row-major. Then the threads whose indices are from $0$ to $15$ will have to access the first tile of the input tensor dataA and the threads whose indices are from $16$ to $31$ will have to access the second tile of the input tensor dataA. How could we ensure that such partition and tile slicing using the index is done correctly?
There is an overloaded local_partition function that takes an additional projection parameter to strip out unwanted tiling modes and, for convenience, calls the local_partition function described previously. It can be used in this situation.
```cpp
// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
```
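In condensed form, this overload essentially does the following. This is again a simplified reconstruction of mine (building on the two-argument local_partition_sketch shown earlier), not the verbatim CUTLASS code:

```cpp
#include <cute/tensor.hpp>

// Condensed sketch of the projection overload (simplified; not the verbatim CUTLASS code):
// strip the unwanted modes of the thread layout with dice() and forward to the two-argument
// local_partition_sketch shown earlier.
template <class Tensor, class LShape, class LStride, class Index, class Projection>
CUTE_HOST_DEVICE constexpr auto
local_partition_sketch(Tensor&& tensor,
                       cute::Layout<LShape, LStride> const& tile,  // thread coord -> thread index
                       Index const& index,                         // thread index to slice for
                       Projection const& proj)                     // e.g. Step<_1, X, _1>{}
{
  return local_partition_sketch(static_cast<Tensor&&>(tensor), cute::dice(proj, tile), index);
}
```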
This overloaded local_partition function calls the original local_partition function with the tile layout whose unwanted tiling modes are stripped out by the dice function at the positions marked X in the projection. For example, $\text{dice}((1, X, 1), (2, 16, 1):(16, 1, 0)) = (2, 1):(16, 0)$, $\text{dice}((X, 1, 1), (2, 16, 1):(16, 1, 0)) = (16, 1):(1, 0)$, and $\text{dice}((1, 1, X), (2, 16, 1):(16, 1, 0)) = (2, 16):(16, 1)$. This allows the partition, i.e., zipped_divide, called in outer_partition to work correctly for the input tensors dataA and dataB and the output tensor dataC in the GEMM operation. But what about the tile slicing? Since the index remains unchanged, it seems to be incompatible with the tile layout after mode removal. Will this affect the correct coordinate conversion using tile.get_flat_coord(index)? The answer is no.
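For reference, the three diced thread layouts above can be reproduced in code with the following snippet (an illustration of mine, not from the article):

```cpp
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    auto thr_layout = make_layout(make_shape (_2{}, _16{}, _1{}),
                                  make_stride(_16{}, _1{}, _0{}));   // (2,16,1):(16,1,0)
    auto tileA = dice(Step<_1, X, _1>{}, thr_layout);   // (2,1):(16,0)   -> 2 thread groups for dataA
    auto tileB = dice(Step< X, _1, _1>{}, thr_layout);  // (16,1):(1,0)   -> 16 thread groups for dataB
    auto tileC = dice(Step<_1, _1, X>{}, thr_layout);   // (2,16):(16,1)  -> all 32 threads for dataC
    print(tileA); print("\n");
    print(tileB); print("\n");
    print(tileC); print("\n");
    return 0;
}
```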
As mentioned in the previous article "CuTe Index To Coordinate", given a tile layout with shape $(M_0, M_1, \ldots, M_{\alpha})$ and stride $(d_0, d_1, \ldots, d_{\alpha})$, the coordinate $\mathbf{x}$ corresponding to the index $f_L(\mathbf{x})$ is computed as follows:
$$
\begin{align}
\mathbf{x} &= (x_0, x_1, \ldots, x_{\alpha}) \\
&= \left(\left\lfloor \frac{f_L{(\mathbf{x})}}{d_0} \right\rfloor \mod M_0, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_1} \right\rfloor \mod M_1, \ldots, \left\lfloor \frac{f_L{(\mathbf{x})}}{d_{\alpha}} \right\rfloor \mod M_{\alpha}\right) \\
\end{align}
$$
Thus, even if some of the modes are stripped out, the coordinate will still be computed correctly for the remaining modes, which ensures that the tile slicing is correct for all the threads in the thread block. For example, for the tile layout $(2, 16):(16, 1)$ and index $17$, the first coordinate is $\lfloor 17 / 16 \rfloor \bmod 2 = 1$ whether or not the second mode is kept. Mathematically, to compute the coordinate it does not matter whether the unwanted modes of the tile layout are stripped out before the coordinate is computed, or the coordinate is computed first and its unwanted modes are stripped out afterwards. In other words, we have the following equivalence:
```cpp
dice(proj, tile).get_flat_coord(index) == dice(proj, tile.get_flat_coord(index))
```
In our case, because the original tile layout is $(2, 16):(16, 1)$, it has the index to coordinate mapping as follows:
```
 0 -> (0, 0)
 1 -> (0, 1)
 2 -> (0, 2)
...
15 -> (0, 15)
16 -> (1, 0)
17 -> (1, 1)
...
31 -> (1, 15)
```
Each thread in the thread block will slice different elements in the tile of the output tensor dataC, which is exactly the case of threads slicing different elements in the tile that we discussed in the previous section.
After the second mode is stripped out, the tile layout becomes $2:16$, and the index to coordinate mapping becomes as follows:
```
 0 -> (0,)
 1 -> (0,)
...
15 -> (0,)
16 -> (1,)
17 -> (1,)
...
31 -> (1,)
```
This is as if the unwanted modes of the coordinate computed using the original tile layout were stripped out. Multiple threads in the thread block will slice the same element in the tile of the input tensor dataA. This is called a broadcast in CUDA and is usually not a costly operation. For example, if the data resides in shared memory, then accessing the same element in the tile from different threads in a warp will not result in shared memory bank conflicts.
In this example, we can see that the local_partition function actually simplifies the partition functions that the user has to call.
Using the local_partition function and the thread index as input, the user only has to implement the following code to partition the output tensor dataC and the input tensors dataA and dataB:
```cpp
Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});  // (M/2,K/1)
Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{});  // (N/16,K/1)
Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{});  // (M/2,N/16)
```
Using the outer_partition function instead, the user would have to implement the following code to partition the output tensor dataC and the input tensors dataA and dataB:
```cpp
Tensor thrA = outer_partition(dataA, product_each(shape(dice(Step<_1, X,_1>{}, thr_layout))), dice(Step<_1, X,_1>{}, thr_layout.get_flat_coord(thr_idx)));  // (M/2,K/1)
Tensor thrB = outer_partition(dataB, product_each(shape(dice(Step< X,_1,_1>{}, thr_layout))), dice(Step< X,_1,_1>{}, thr_layout.get_flat_coord(thr_idx)));  // (N/16,K/1)
Tensor thrC = outer_partition(dataC, product_each(shape(dice(Step<_1,_1, X>{}, thr_layout))), dice(Step<_1,_1, X>{}, thr_layout.get_flat_coord(thr_idx)));  // (M/2,N/16)
```
Here, the developer cannot use the thread index directly as a 1D coordinate for the outer_partition function. Instead, the developer has to manage the coordinate conversion and strip the unwanted modes explicitly. This seems like overkill and is less elegant from the perspective of API design. Therefore, the local_partition function was introduced to simplify the implementation for developers. There is, of course, a trade-off for abstractions: although the implementation becomes much more concise, it becomes less clear to the user what is happening under the hood.
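For completeness, here is a rough sketch of how the three partitioned views might be used together in a naive thread-level GEMM loop. The function name thread_level_gemm and the absence of any tiling over K or shared memory staging are illustrative simplifications of mine, not the article's implementation:

```cpp
#include <cute/tensor.hpp>

// Rough sketch (illustrative only; assumed tile shapes and names): each thread accumulates
// its (M/2, N/16) sub-tile of dataC from its (M/2, K) slice of dataA and its (N/16, K)
// slice of dataB, where dataA is (M,K), dataB is (N,K), and dataC is (M,N).
template <class TA, class TB, class TC, class ThrLayout>
__device__ void thread_level_gemm(TA dataA, TB dataB, TC dataC, ThrLayout thr_layout, int thr_idx)
{
    using namespace cute;
    Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});  // (M/2, K)
    Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{});  // (N/16, K)
    Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{});  // (M/2, N/16)

    // Naive triple loop over this thread's sub-tile; each thread owns a disjoint piece of dataC.
    for (int m = 0; m < size<0>(thrC); ++m) {
        for (int n = 0; n < size<1>(thrC); ++n) {
            for (int k = 0; k < size<1>(thrA); ++k) {
                thrC(m, n) += thrA(m, k) * thrB(n, k);
            }
        }
    }
}
```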
References
CuTe Local Partition