CuTe outer_partition is often used at the thread level to partition a tensor into tiles and slice the tile(s) based on the coordinate of the thread in the thread block. The CuTe local_partition function is a convenient wrapper around outer_partition that allows the developer to partition a tensor into tiles and slice the tile(s) based on the index of the thread in the thread block.
In this article, I would like to discuss the local_partition function in CuTe and how exactly it works.
CuTe Local Partition
CuTe Local Partition Implementation
The implementation of local_partition in CuTe is as follows. Essentially, local_partition replaces each mode of the tile layout by its total size (product_each(shape(tile))), converts the index back to a coordinate by inverting the mapping of the tile layout function, and calls outer_partition to partition and slice the tile(s). outer_partition, in turn, performs a zipped_divide and slices into the first ("tile") mode using that coordinate, keeping everything that is outside the tiler. That's why it's called an "outer" partition.
// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes.
// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor.
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
outer_partition(Tensor    && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  // Split the modes of tensor according to the Tiler
  //   zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
  // Then slice into the first mode (the "Tile" mode) with Coord
  auto tensor_tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;

  // The coord slices into the first mode (the "tile" mode), flatten the second
  if constexpr (is_tuple<Coord>::value) {
    // Append trailing modes if coord is tuple
    constexpr int R0 = decltype(rank<0>(tensor_tiled))::value;
    return tensor_tiled(append<R0>(coord,_), repeat<R1>(_));
  } else {
    // Flat indexing if coord is not tuple
    return tensor_tiled(coord, repeat<R1>(_));
  }
}

// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index.
// This is typical at the Thread level where data is partitioned across repeated patterns of threads:
//   Tensor data = ...                                                          // (_16,_64)
//   Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx); // ( _8, _4)
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index)  // index to slice for
{
  static_assert(is_integral<Index>::value);

  return outer_partition(static_cast<Tensor&&>(tensor),
                         product_each(shape(tile)),
                         tile.get_flat_coord(index));
}
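To make the usage concrete, here is a minimal host-side sketch based on the example in the source comment above: a $(16, 64)$ column-major tensor is partitioned across $32$ "threads" arranged in a $(2, 16)$ column-major layout, so each thread index owns an $(8, 4)$ sub-tensor. The buffer contents and the chosen thread index are made up purely for illustration; in device code the index would typically be threadIdx.x.

#include <cute/tensor.hpp>

#include <vector>

int main()
{
    using namespace cute;

    // A (16, 64) column-major tensor backed by host memory.
    std::vector<float> buffer(16 * 64, 0.0f);
    auto data = make_tensor(buffer.data(),
                            make_layout(make_shape(Int<16>{}, Int<64>{}), LayoutLeft{}));

    // A (2, 16) column-major layout of 32 "threads".
    auto thr_layout = make_layout(make_shape(Int<2>{}, Int<16>{}), LayoutLeft{});

    // The sub-tensor owned by the thread with index 3.
    int const thr_idx = 3;
    auto thr_data = local_partition(data, thr_layout, thr_idx);

    print(thr_data.layout());  // (_8,_4):(_2,_256)
    print("\n");

    return 0;
}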
The question is, given outer_partition, why do we still need local_partition? Usually, given a layout function, we compute the index from the coordinate. Computing the coordinate from the index, however, is not straightforward mathematically, especially when the layout is not "tight", i.e., when the domain and the codomain of the layout function do not have the same size. In what circumstances would we slice the tile(s) using an index instead of a coordinate?
Threads Slicing Different Elements In Tile
In a thread block, given a tensor in global memory and a tensor in shared memory, we would like to use all the threads in the thread block to synergistically and efficiently copy the data from the global memory tensor to the shared memory tensor. The layout of threads in the thread block is the tile layout, and we will use the outer partition to partition and slice the tile(s). The goal is that, when iterating through the tile(s), each thread in the thread block slices one element from the global memory tensor, and adjacent threads slice adjacent elements from the global memory tensor, so that the global memory tensor can be accessed efficiently in a coalesced fashion. The adjacency of the elements in the global memory tensor is determined by the index instead of the coordinate. If 8 adjacent threads, whose indices are $i$, $i+1$, $i+2$, $i+3$, $i+4$, $i+5$, $i+6$, $i+7$, access 8 elements in the global memory tensor whose indices are $j$, $j+1$, $j+2$, $j+3$, $j+4$, $j+5$, $j+6$, $j+7$, then the global memory access is coalesced.
Because the first mode of the zipped division, which is the mode we slice into, has the same shape as the tile layout, threads that are adjacent through the tile layout slice elements that are adjacent in the first mode of the divided tensor. Whether those elements are also adjacent in memory, and therefore whether the access from the 8 adjacent threads is coalesced, still depends on the layout of the global memory tensor, as we will see below.
Depending on the implementation, the layout of threads in the thread block can be column-major, row-major, or something more complicated. For example, a thread block consisting of 32 threads can have a column-major layout of $(8, 4) : (1, 8)$ or a row-major layout of $(4, 8) : (8, 1)$. In the column-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$, which are adjacent indices. However, in the row-major layout, the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ correspond to the indices $0$, $8$, $16$, $24$, $1$, $9$, $17$, $25$, which are not adjacent indices. Therefore, without local_partition, the developer would have to very carefully compute, for each thread, the coordinate that corresponds to its index in the thread layout by inverting the mapping of the tile (thread) layout function, i.e., tile.get_flat_coord(index), which is exactly what local_partition already does for us.
Assuming the layout of threads is "tight", the mathematics of how the tile.get_flat_coord(index) function computes the coordinate from the index is not very complicated, and we have discussed its derivation in my previous article "CuTe Index To Coordinate".
Using the local_partition function, no matter whether the tile layout is column-major or row-major, as long as it is a compact layout, we can simply use the index of the thread in the thread block to slice the tile(s) and access the global memory tensor. In our previous case, the thread indices $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ map to the coordinates $0$, $1$, $2$, $3$, $4$, $5$, $6$, $7$ in the column-major layout of $(8, 4) : (1, 8)$ and to the coordinates $0$, $4$, $8$, $12$, $16$, $20$, $24$, $28$ in the row-major layout of $(4, 8) : (8, 1)$. The developer does not have to worry about how those coordinates are computed when using the local_partition function.
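The coordinate mapping above can be checked directly with get_flat_coord. The following host-side sketch prints, for the first 8 thread indices, the flat coordinate under both thread layouts; the conversion of the per-mode coordinate to a 1D coordinate is done by hand so that the output matches the numbers above.

#include <cute/tensor.hpp>

#include <iostream>

int main()
{
    using namespace cute;

    auto col_major = make_layout(make_shape(Int<8>{}, Int<4>{}), LayoutLeft{});   // (8,4):(1,8)
    auto row_major = make_layout(make_shape(Int<4>{}, Int<8>{}), LayoutRight{});  // (4,8):(8,1)

    for (int idx = 0; idx < 8; ++idx)
    {
        // get_flat_coord inverts the layout function: index -> (i, j) coordinate.
        auto col_coord = col_major.get_flat_coord(idx);
        auto row_coord = row_major.get_flat_coord(idx);
        // The 1D coordinate is colexicographical with respect to the shape: i + j * size<0>.
        int const col_coord_1d = get<0>(col_coord) + get<1>(col_coord) * 8;
        int const row_coord_1d = get<0>(row_coord) + get<1>(row_coord) * 4;
        std::cout << "index " << idx << " -> coordinate " << col_coord_1d
                  << " (column-major), coordinate " << row_coord_1d << " (row-major)" << std::endl;
    }

    return 0;
}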
However, this still has not explained why we need local_partition, since the outer_partition function can just take the thread index as a 1D coordinate. In fact, even if we are using local_partition with the thread index as input, there is no guarantee that the data accesses from adjacent threads will be coalesced. Conversely, using outer_partition with the thread index as the coordinate might also happen to produce coalesced data accesses from adjacent threads. One still has to engineer or search for the combination of tensor layout and tile layout that results in coalesced data access from adjacent threads.
The following example illustrates the stride between the data elements accessed by adjacent threads in an outer partition, using either the coordinate or the index as input.
#include <cute/tensor.hpp>

#include <iostream>

int main(int argc, const char** argv)
{
    using BM = cute::Int<256>;
    using BK = cute::Int<32>;
    using TM = cute::Int<32>;
    using TK = cute::Int<8>;

    auto gmem_layout_1 = cute::make_layout(cute::Shape<BM, BK>{}, cute::LayoutLeft{});
    auto gmem_layout_2 = cute::make_layout(cute::Shape<BM, BK>{}, cute::LayoutRight{});
    auto gmem_layout_3 = cute::make_layout(cute::Shape<BK, BM>{}, cute::LayoutLeft{});
    auto gmem_layout_4 = cute::make_layout(cute::Shape<BK, BM>{}, cute::LayoutRight{});

    auto thread_layout_1 = cute::make_layout(cute::Shape<TM, TK>{}, cute::LayoutLeft{});
    auto thread_layout_2 = cute::make_layout(cute::Shape<TM, TK>{}, cute::LayoutRight{});
    auto thread_layout_3 = cute::make_layout(cute::Shape<TK, TM>{}, cute::LayoutLeft{});
    auto thread_layout_4 = cute::make_layout(cute::Shape<TK, TM>{}, cute::LayoutRight{});

    auto thread_1_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_1);
    auto thread_1_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_1);
    auto thread_1_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_1);
    auto thread_1_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_1);
    auto thread_2_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_2);
    auto thread_2_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_2);
    auto thread_2_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_2);
    auto thread_2_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_2);
    auto thread_3_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_3);
    auto thread_3_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_3);
    auto thread_3_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_3);
    auto thread_3_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_3);
    auto thread_4_gmem_1_layout = cute::zipped_divide(gmem_layout_1, thread_layout_4);
    auto thread_4_gmem_2_layout = cute::zipped_divide(gmem_layout_2, thread_layout_4);
    auto thread_4_gmem_3_layout = cute::zipped_divide(gmem_layout_3, thread_layout_4);
    auto thread_4_gmem_4_layout = cute::zipped_divide(gmem_layout_4, thread_layout_4);
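    // One way to inspect the access pattern (an illustrative sketch): for each combination of
    // global memory layout and thread layout, print the offset difference between the elements
    // accessed by thread 0 and thread 1, slicing once with the thread coordinate (as
    // outer_partition does) and once with the thread index (as local_partition does internally).
    auto print_adjacent_thread_strides = [](auto const& gmem_layout, auto const& thread_layout)
    {
        // Coordinate as input: the thread layout itself is the tiler.
        auto coord_tiled = cute::zipped_divide(gmem_layout, thread_layout);
        auto const coord_stride = coord_tiled(cute::make_coord(1, 0)) - coord_tiled(cute::make_coord(0, 0));
        // Index as input: tile by the flat shape of the thread layout and convert the
        // thread index to a coordinate first, as local_partition does.
        auto index_tiled = cute::zipped_divide(gmem_layout, cute::product_each(cute::shape(thread_layout)));
        auto const index_stride = index_tiled(cute::make_coord(thread_layout.get_flat_coord(1), 0)) -
                                  index_tiled(cute::make_coord(thread_layout.get_flat_coord(0), 0));
        std::cout << "coordinate input stride: " << coord_stride
                  << "    index input stride: " << index_stride << std::endl;
    };

    cute::for_each(cute::make_tuple(thread_layout_1, thread_layout_2, thread_layout_3, thread_layout_4),
                   [&](auto const& thread_layout)
                   {
                       cute::for_each(cute::make_tuple(gmem_layout_1, gmem_layout_2, gmem_layout_3, gmem_layout_4),
                                      [&](auto const& gmem_layout)
                                      { print_adjacent_thread_strides(gmem_layout, thread_layout); });
                   });

    return 0;
}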
From the stride between the data elements that threads 0 and 1 access, we can see that using the index as input does not result in a higher probability of coalesced data access from adjacent threads. This makes it even more puzzling why local_partition is needed. In the next example, we will see why local_partition is useful.
Threads Slicing Same Element In Tile (Broadcast)
Suppose we want to perform a GEMM operation and we have two input tensors dataA and dataB, and an output tensor dataC. The problem is partitioned in a way such that each thread in the thread block computes a tile of the output tensor dataC by accumulating the products of the corresponding tiles of the input tensors dataA and dataB. Suppose there are $2 \times 16 = 32$ threads in the thread block, the output tensor dataC is partitioned into $2 \times 16 = 32$ two-dimensional tiles, and the input tensors dataA and dataB are partitioned into $2 \times 1 = 2$ and $16 \times 1 = 16$ two-dimensional tiles, respectively. This means the 32 threads have to not only access the tiles of the output tensor dataC, but also access the tiles of the input tensors dataA and dataB correctly, using the thread index. Slicing the tiles of the output tensor dataC is straightforward using the local_partition function mentioned above. However, we need to be a little more careful when slicing the tiles of the input tensors dataA and dataB. Because there are only $2$ tiles of the input tensor dataA, 16 threads access the first tile of dataA in the same way and the other 16 threads access the second tile of dataA in the same way. Suppose the layout of the 32 threads is $(2, 16) : (16, 1)$, i.e., row-major. Then the threads whose indices are from $0$ to $15$ have to access the first tile of the input tensor dataA, and the threads whose indices are from $16$ to $31$ have to access the second tile. How can we ensure that such partitioning and tile slicing using the index is done correctly?
There is an overloaded local_partition function that takes an additional projection parameter to strip out unwanted tiling modes and then calls the local_partition function described previously. It can be used in this situation.
// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
//   when using projections of the same tiler.
// This is typical at the Thread level where data is partitioned across projected layouts of threads:
//   Tensor dataA = ...                                                                // (M,K)
//   Tensor dataB = ...                                                                // (N,K)
//   Tensor dataC = ...                                                                // (M,N)
//   auto thr_layout = Layout<Shape<_2,_16,_1>, Stride<_16,_1,_0>>{};
//   Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});      // (M/2,K/1)
//   Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{});      // (N/16,K/1)
//   Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{});      // (M/2,N/16)
template <class Tensor, class LShape, class LStride, class Index, class Projection,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index,  // index to slice for
                Projection             const& proj)
{
  return local_partition(static_cast<Tensor&&>(tensor),
                         dice(proj, tile),
                         index);
}
This overloaded local_partition function calls the original local_partition function with a tile layout whose unwanted tiling modes have been stripped out by the dice function at the X positions. For example, $\text{dice}((2, 16, 1):(16, 1, 0), (1, X, 1)) = (2, 1):(16, 0)$, $\text{dice}((2, 16, 1):(16, 1, 0), (X, 1, 1)) = (16, 1):(1, 0)$, and $\text{dice}((2, 16, 1):(16, 1, 0), (1, 1, X)) = (2, 16):(16, 1)$. This allows the partition, i.e., zipped_divide, called in outer_partition to work correctly for the input tensors dataA and dataB and the output tensor dataC in the GEMM operation. But what about the tile slicing? Since the index remains unchanged, it seems to be incompatible with the tile layout after the mode removal. Will this affect the correct coordinate conversion using tile.get_flat_coord(index)? The answer is no.
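These diced layouts can be checked with a few lines of host code (a small sketch for illustration; note that in the actual implementation the projection is the first argument of dice):

#include <cute/tensor.hpp>

int main()
{
    using namespace cute;

    // The (2, 16, 1) : (16, 1, 0) thread layout from the GEMM example.
    auto thr_layout = Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>{};

    // dice keeps the modes where the projection is _1 and removes the modes where it is X.
    print(dice(Step<_1,  X, _1>{}, thr_layout)); print("\n");  // (_2,_1):(_16,_0)
    print(dice(Step< X, _1, _1>{}, thr_layout)); print("\n");  // (_16,_1):(_1,_0)
    print(dice(Step<_1, _1,  X>{}, thr_layout)); print("\n");  // (_2,_16):(_16,_1)

    return 0;
}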
As mentioned in the previous article “CuTe Index To Coordinate”, given a tile layout with shape $(d_0, d_1, \ldots, d_{\alpha})$ and stride $(M_0, M_1, \ldots, M_{\alpha})$, the coordinate $\mathbf{x}$ corresponding to the index $f_L(\mathbf{x})$ is computed as follows:
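Concretely, assuming the tile layout is compact ("tight") so that the inversion is well defined, each component of the coordinate can be recovered independently from the index:

$$x_i = \left\lfloor \frac{f_L(\mathbf{x})}{M_i} \right\rfloor \bmod d_i, \quad i = 0, 1, \ldots, \alpha$$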
Thus, even if some of the modes are stripped out, the coordinate will be computed correctly for the remaining modes and it will ensure that the tile slicing is correct for all the threads in the thread block.
Mathematically, to compute the coordinate, it does not matter whether the unwanted modes are stripped from the tile layout before converting the index to a coordinate, or the coordinate is computed from the full tile layout first and the unwanted modes of the coordinate are stripped out afterwards. In other words, we have the following equivalence:
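Suppose the projection keeps the modes $i_0 < i_1 < \cdots < i_{\beta}$, so that the diced layout has shape $(d_{i_0}, d_{i_1}, \ldots, d_{i_{\beta}})$ and stride $(M_{i_0}, M_{i_1}, \ldots, M_{i_{\beta}})$. Because each coordinate component depends only on the shape and stride of its own mode, for any index $n$ the $j$-th component of the coordinate computed from the diced layout is

$$x^{\prime}_{j} = \left\lfloor \frac{n}{M_{i_j}} \right\rfloor \bmod d_{i_j} = x_{i_j}, \quad j = 0, 1, \ldots, \beta,$$

which is exactly the $i_j$-th component of the coordinate computed from the full tile layout.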
With the projection for dataC, the diced tile layout keeps both thread modes, so each thread in the thread block slices a different element in the tile of the output tensor dataC, which is exactly the "threads slicing different elements in the tile" case we discussed in the previous section.
For the input tensor dataA, after the second mode is stripped out, the tile layout effectively becomes $2 : 16$, and the index-to-coordinate mapping becomes as follows:
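For a thread index $n \in \{0, 1, \ldots, 31\}$:

$$x_0 = \left\lfloor \frac{n}{16} \right\rfloor \bmod 2 = \begin{cases} 0, & 0 \leq n \leq 15 \\ 1, & 16 \leq n \leq 31 \end{cases}$$

so threads $0$ through $15$ all slice the first tile of dataA and threads $16$ through $31$ all slice the second tile.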
This is as if the unwanted modes of the coordinate computed using the original tile layout were stripped out. Multiple threads in the thread block will slice the same element in the tile of the input tensor dataA. This is called a broadcast in CUDA and is usually not a costly operation. For example, if the data is in shared memory, then accessing the same element in the tile from different threads in a warp will not result in shared memory bank conflicts.
In this example, we can see that the local_partition function noticeably simplifies the partitioning code that the user has to write.
Using the local_partition function with the thread index as input, the user only has to write code like the following to partition the output tensor dataC and the input tensors dataA and dataB (a sketch based on the usage shown in the CuTe source comment above; the kernel wrapper and tensor parameters are only for illustration):
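#include <cute/tensor.hpp>

// A hypothetical device kernel for illustration: dataA is (M,K), dataB is (N,K), dataC is (M,N).
template <class TensorA, class TensorB, class TensorC>
__global__ void gemm_thread_partition(TensorA dataA, TensorB dataB, TensorC dataC)
{
    using namespace cute;

    int const thr_idx = threadIdx.x;  // 0 .. 31

    // The (2, 16, 1) : (16, 1, 0) thread layout from the example above.
    auto thr_layout = Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>{};

    // Partition the tensors across the 32 threads using only the thread index.
    Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1,  X, _1>{});  // (M/2,  K/1)
    Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X, _1, _1>{});  // (N/16, K/1)
    Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1, _1,  X>{});  // (M/2,  N/16)

    // ... accumulate thrC += thrA * thrB ...
}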
Using the outer_partition function directly, the user would instead have to write something like the following (again a sketch, placed inside the same hypothetical kernel) to partition the output tensor dataC and the input tensors dataA and dataB:
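// Inside the same hypothetical kernel, without local_partition:
int const thr_idx = threadIdx.x;  // 0 .. 31

// Manually convert the 1D thread index to the (m, n) coordinate of the
// (2, 16) : (16, 1) thread layout.
int const thr_m = thr_idx / 16;   // 0 or 1
int const thr_n = thr_idx % 16;   // 0 .. 15

// Manually choose the tiler shape and the (already diced) coordinate for each tensor.
Tensor thrA = outer_partition(dataA, Shape< _2,  _1>{}, make_coord(thr_m, 0));      // (M/2,  K/1)
Tensor thrB = outer_partition(dataB, Shape<_16,  _1>{}, make_coord(thr_n, 0));      // (N/16, K/1)
Tensor thrC = outer_partition(dataC, Shape< _2, _16>{}, make_coord(thr_m, thr_n));  // (M/2,  N/16)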
Here, the developer cannot use the thread index as a 1D coordinate directly for the outer_partition function. Instead, the developer has to manage the 2D coordinate conversion and strip the unwanted modes explicitly. This seems to be overkill and is less elegant from the perspective of API design. Therefore, the local_partition function was introduced to simplify the implementation for developers. There is, of course, a trade-off with abstractions: although the implementation becomes much more concise, it becomes less clear to the user what is happening under the hood.