CuTe Local Tile
Introduction
CuTe inner partition is often used at the thread group level to partition a tensor into tiles and slice out the tile(s) based on the coordinate of the thread group, typically a thread block, in the thread grid.
In this article, I would like to discuss the local_tile function in CuTe and how exactly it works.
CuTe Local Tile
Compared to the local_partition function we discussed in the previous article "CuTe Local Partition", local_tile is more straightforward to understand because no mathematics or index-to-coordinate conversion is involved.
CuTe Local Tile Implementation
The implementation of local_tile in CuTe is straightforward: local_tile is just a different name for inner_partition, which applies a tiler to the tensor, partitioning it into tiles, and then slices out the tile(s) by indexing the second tile mode, the "Rest" mode, with the given coordinate.
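A sketch of the implementation, following the CuTe headers shipped with CUTLASS (the exact template constraints and macros may differ between versions):

// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes.
template <class Tensor, class Tiler, class Coord>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor    && tensor,
           Tiler const& tiler,   // tiler to apply
           Coord const& coord)   // coord to slice into the "Rest" modes
{
  // local_tile simply forwards to inner_partition.
  return inner_partition(static_cast<Tensor&&>(tensor),
                         tiler,
                         coord);
}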
The local_tile function is typically used at the thread block level to partition a large problem into multiple smaller problems that can be processed by the threads within a block. The tile size is typically the smaller problem size statically defined for each thread block to process. The thread grid size is then determined by the problem size divided by the tile size, and the coordinate for local_tile is usually the thread block index in the thread grid. There is no need to compute a coordinate from an index, as there is in the local_partition function.
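As a hypothetical illustration of this usage, the kernel fragment below tiles an $M \times N$ output tensor into $128 \times 64$ blocks and slices out the block owned by the current thread block. The names ptr_C and tile_c_kernel and the $128 \times 64$ tile size are made up for this sketch and are not part of the CuTe API.

#include <cute/tensor.hpp>
using namespace cute;

// Hypothetical kernel fragment: tile an (M, N) tensor into (128, 64) blocks
// and slice out the block owned by this thread block.
__global__ void tile_c_kernel(float* ptr_C, int M, int N)
{
  Tensor dataC = make_tensor(make_gmem_ptr(ptr_C),
                             make_shape(M, N),
                             make_stride(Int<1>{}, M));        // (M, N), column-major
  auto cta_tiler = make_shape(Int<128>{}, Int<64>{});          // (BLK_M, BLK_N)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y);         // thread block coordinate in the grid
  Tensor blkC    = local_tile(dataC, cta_tiler, cta_coord);    // (BLK_M, BLK_N) = (128, 64)
  // ... the threads in this block now cooperate on blkC ...
}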
In some problems, such as GEMM, multiple thread blocks will access the same tile of data. There is also an overloaded local_tile function that takes an additional projection parameter to strip out unwanted tiling modes for convenience.
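A sketch of this overload, again following the CuTe headers: the projection dices the tiler and the coordinate, removing the modes marked with X (an Underscore), before forwarding to the three-argument local_tile.

// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience.
template <class Tensor, class Tiler, class Coord, class Proj>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor    && tensor,
           Tiler const& tiler,   // tiler to apply
           Coord const& coord,   // coord to slice into the "Rest" modes
           Proj  const& proj)    // projection to apply to the tiler and the coord
{
  // dice() keeps only the modes where proj is not an Underscore (X), so the
  // unwanted tiling modes are stripped from both the tiler and the coordinate.
  return local_tile(static_cast<Tensor&&>(tensor),
                    dice(proj, tiler),
                    dice(proj, coord));
}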
Suppose we have a smaller GEMM problem to process for each thread block, whose $M = 32$, $N = 64$, and $K$ is the same as that of the original large GEMM problem. The tile for the output tensor dataC will be of size $32 \times 64$, and the tiles for the input tensors dataA and dataB will be of size $32 \times K$ and $64 \times K$, respectively. Because this $K$ might be very large and is not static, we also partition $K$ into smaller tiles of static size $4$ in advance, so the tile for dataA will be of size $32 \times 4 \times k$ and the tile for dataB will be of size $64 \times 4 \times k$, where $k = \lceil K / 4 \rceil$ and can be iterated over for accumulation by the threads in the thread block. Because the partition of $K$ is not at the thread block level, the coordinates used for slicing the tiles for dataA and dataB will be (blockIdx.x, :) and (blockIdx.y, :), respectively.
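Putting this example into code, a hypothetical kernel fragment for the thread block level tiling might look as follows. The pointer names and column-major strides are assumptions for this sketch, the $32 \times 64 \times 4$ tile sizes come from the example above, and the Step projections mirror the usage pattern in the CuTe GEMM tutorials.

#include <cute/tensor.hpp>
using namespace cute;

// Hypothetical kernel fragment for the GEMM tiling described above.
__global__ void gemm_tile_kernel(float const* ptr_A, float const* ptr_B, float* ptr_C,
                                 int M, int N, int K)
{
  Tensor dataA = make_tensor(make_gmem_ptr(ptr_A), make_shape(M, K), make_stride(Int<1>{}, M)); // (M, K)
  Tensor dataB = make_tensor(make_gmem_ptr(ptr_B), make_shape(N, K), make_stride(Int<1>{}, N)); // (N, K)
  Tensor dataC = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N), make_stride(Int<1>{}, M)); // (M, N)

  auto cta_tiler = make_shape(Int<32>{}, Int<64>{}, Int<4>{});  // (BLK_M, BLK_N, BLK_K)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);       // (m, n, :) -- keep every K tile

  // The Step<...> projections strip the tiler/coord modes a tensor does not have:
  // dataA has no N mode, dataB has no M mode, and dataC has no K mode.
  Tensor blkA = local_tile(dataA, cta_tiler, cta_coord, Step<_1,  X, _1>{});  // (32, 4, k)
  Tensor blkB = local_tile(dataB, cta_tiler, cta_coord, Step< X, _1, _1>{});  // (64, 4, k)
  Tensor blkC = local_tile(dataC, cta_tiler, cta_coord, Step<_1, _1,  X>{});  // (32, 64)
  // blkA is sliced with coordinate (blockIdx.x, :) and blkB with (blockIdx.y, :),
  // matching the description above; the k mode can be iterated over for accumulation.
}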
References
CuTe Local Tile