CuTe Local Tile

Introduction

The CuTe inner partition is often used at the thread group level to partition a tensor into tiles and slice out the tile(s) at the coordinate of the thread group, typically a thread block, in the thread grid.

In this article, I would like to discuss the local_tile function in CuTe and how exactly it works.

CuTe Local Tile

Compared to the local_partition function we discussed in the previous article “CuTe Local Partition”, local_tile is more straightforward to understand because there are no index-to-coordinate conversions or other mathematics involved.

CuTe Local Tile Implementation

The implementation of local_tile in CuTe is as follows. The local_tile function is just a different name for inner_partition, which partitions the tensor into tiles and slices the tile(s) at the coordinate given for the second, tile-index mode.

// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes.
// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to.
// Split the modes of tensor according to the Tiler
//   zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
// Then slice into the second mode (the "Rest" mode) with Coord
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
inner_partition(Tensor&&     tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tensor_tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R0 = decltype(rank<0>(tensor_tiled))::value;

  // The coord slices into the second mode (the "rest" mode), flatten the first
  if constexpr (is_tuple<Coord>::value) {
    // Append trailing modes if coord is tuple
    constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;
    return tensor_tiled(repeat<R0>(_), append<R1>(coord,_));
  } else {
    // Flat indexing if coord is not tuple
    return tensor_tiled(repeat<R0>(_), coord);
  }
}

// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile.
// This is typical at the CTA level where tiles of data are extracted:
//   Tensor data = ...                                                                         // ( M, N)
//   Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y));  // (_32,_64)
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor&&     tensor,
           Tiler const& tiler,   // tiler to apply
           Coord const& coord)   // coord to slice into "remainder"
{
  return inner_partition(static_cast<Tensor&&>(tensor),
                         tiler,
                         coord);
}

The local_tile function is typically used at the thread block level to partition a large problem into multiple smaller problems that can be processed by the threads within each block. The tile size is typically the statically defined smaller problem size that each thread block processes. The thread grid size is then determined by the problem size divided by the tile size, and the coordinate for local_tile is usually the thread block index in the thread grid. There is no need to compute a coordinate from an index, as there is in the local_partition function.
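
As a concrete illustration, the following is a minimal kernel sketch, with a hypothetical name copy_tile_kernel, assuming the CuTe headers are available and that $M$ and $N$ are multiples of $32$ and $64$, respectively. Each thread block slices out its own $32 \times 64$ tile using its block coordinate, and the threads within the block cooperatively copy the tile.

#include <cute/tensor.hpp>

template <class T>
__global__ void copy_tile_kernel(T const* src, T* dst, int M, int N)
{
  using namespace cute;

  // View the global memory buffers as M x N column-major tensors.
  Tensor mSrc = make_tensor(make_gmem_ptr(src), make_shape(M, N), make_stride(Int<1>{}, M));
  Tensor mDst = make_tensor(make_gmem_ptr(dst), make_shape(M, N), make_stride(Int<1>{}, M));

  // Each thread block slices out its own (_32,_64) tile at its block coordinate.
  Tensor gSrc = local_tile(mSrc, Shape<_32, _64>{}, make_coord(blockIdx.x, blockIdx.y));
  Tensor gDst = local_tile(mDst, Shape<_32, _64>{}, make_coord(blockIdx.x, blockIdx.y));

  // The threads in the block cooperatively copy the tile, each taking a
  // strided subset of the tile elements.
  for (int i = threadIdx.x; i < size(gSrc); i += blockDim.x)
  {
    gDst(i) = gSrc(i);
  }
}

The kernel would be launched with a grid of size (M / 32, N / 64) so that the thread block coordinates exactly cover all the tiles.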

In some problems, such as GEMM, multiple different thread blocks will access the same tile of data. There is also an overloaded local_tile function that takes an additional projection parameter to strip out unwanted tiling modes for convenience.

// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
//   when using projections of the same tiler.
// This is typical at the CTA level where tiles of data are extracted as projections:
//   Tensor dataA = ...                                                        // (M,K)
//   Tensor dataB = ...                                                        // (N,K)
//   Tensor dataC = ...                                                        // (M,N)
//   auto cta_tiler = Shape<_32, _64, _4>{};
//   auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);
//   Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (_32,_4,k)
//   Tensor ctaB = local_tile(dataB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (_64,_4,k)
//   Tensor ctaC = local_tile(dataC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (_32,_64)
template <class Tensor, class Tiler, class Coord, class Proj,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_tile(Tensor&&     tensor,
           Tiler const& tiler,   // tiler to apply
           Coord const& coord,   // coord to slice into "remainder"
           Proj  const& proj)    // projection to apply to tiler and coord
{
  return local_tile(static_cast<Tensor&&>(tensor),
                    dice(proj, tiler),
                    dice(proj, coord));
}

Suppose we have a smaller GEMM problem for each thread block to process, whose $M = 32$, $N = 64$, and $K$ is the same as that of the original large GEMM problem. The tile for the output tensor dataC will be of size $32 \times 64$, and the tiles for the input tensors dataA and dataB will be of size $32 \times K$ and $64 \times K$, respectively. Because this $K$ might be very large and it is not static, we also partition $K$ into smaller tiles of static size $4$ in advance, so the partitioned tensor for dataA will be of size $32 \times 4 \times k$ and the partitioned tensor for dataB will be of size $64 \times 4 \times k$, where $k = \lceil K / 4 \rceil$ and can be iterated over for accumulation by the threads in the thread block. Because the partition of $K$ is not at the thread block level, after the projections are applied, the coordinates used for slicing the tiles for dataA and dataB will be (blockIdx.x, _) and (blockIdx.y, _), respectively.
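
Putting the pieces together, the following is a minimal kernel sketch, with a hypothetical name gemm_tile_kernel, that shows how the projection overload produces the per-block tiles described above. Here X is an alias for the CuTe Underscore type, following the convention in the CUTLASS examples, and the main loop that iterates over the k mode for accumulation is omitted.

#include <cute/tensor.hpp>

template <class TA, class TB, class TC>
__global__ void gemm_tile_kernel(TA const* A, TB const* B, TC* C, int M, int N, int K)
{
  using namespace cute;
  using X = Underscore;

  // View the global memory buffers as column-major tensors.
  Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(Int<1>{}, M)); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(Int<1>{}, N)); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(Int<1>{}, M)); // (M,N)

  auto cta_tiler = Shape<_32, _64, _4>{};
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);

  // dice(Step<_1, X,_1>{}, cta_tiler) == (_32,_4) and
  // dice(Step<_1, X,_1>{}, cta_coord) == (blockIdx.x, _), so gA is (_32,_4,k).
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (_32,_4,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X, _1, _1>{}); // (_64,_4,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (_32,_64)

  // The last mode of gA and gB, of size k = ceil(K / 4), would be iterated
  // over for accumulation into gC by the threads in the thread block.
}

The dice function keeps only the tiler and coordinate modes whose projection is _1 and drops those marked with X, which is how the $N$ mode is stripped from the tiling of dataA and the $M$ mode from the tiling of dataB.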
