CuTe Arithmetic Tuple Tensor

Introduction

A CuTe Tensor is parameterized by two template parameters: an Engine and a Layout. The Engine holds an iterator that can be dereferenced to access the data.

When we print a Tensor instance from a CuTe program, in most cases we will see a printout like the following, which consists of an iterator and a layout.

ptr[16b](0x5ded6f122010) o (_128,_32):(_1,_128)

In this case, the iterator is a pointer, described as ptr[16b](0x5ded6f122010), which points to 16-bit elements. The layout is (_128,_32):(_1,_128), which means the tensor has a shape of $(128, 32)$ and a stride of $(1, 128)$.
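
For instance, the following minimal sketch (my own illustration, not from the original article) produces a printout of this form on the host:

#include <cute/tensor.hpp>

#include <thrust/host_vector.h>

int main()
{
    // A (128, 32) column-major tensor of 16-bit half-precision elements.
    auto h_data = thrust::host_vector<cute::half_t>(128 * 32);
    auto const layout{cute::make_layout(
        cute::make_shape(cute::Int<128>{}, cute::Int<32>{}),
        cute::make_stride(cute::Int<1>{}, cute::Int<128>{}))};
    auto const tensor{cute::make_tensor(h_data.data(), layout)};
    // Expected to print something like:
    // ptr[16b](0x...) o (_128,_32):(_1,_128)
    cute::print(tensor);
    return 0;
}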

However, sometimes, we might encounter a different printout from a CuTe Tensor instance, such as:

ArithTuple(0,0) o (_128,_128):(_1@0,_1@1)

In this case, the iterator is an ArithmeticTuple, described as ArithTuple(0,0), and the layout is (_128,_128):(_1@0,_1@1), both of which look quite different from the common case above.

In the CuTe official documentation, such CuTe Tensors are referred to as “CuTe TMA Tensors”. However, because this kind of tensor is not only used for TMA operations in CuTe, and it already existed before TMA became available on NVIDIA Hopper GPUs, I personally don’t like the name “CuTe TMA Tensor”. Instead, I call it “CuTe Arithmetic Tuple Tensor”, because the iterator used in the CuTe Tensor is an ArithmeticTuple, as opposed to a “CuTe Data Tensor”, whose iterator is a pointer to data.

In this article, I would like to quickly discuss the CuTe Arithmetic Tuple Tensor.

CuTe Arithmetic Tuple Tensor Example

CuTe Arithmetic Tuple Identity Tensor

In addition to CuTe TMA operations, the CuTe Arithmetic Tuple Tensor is commonly used for computing, from a partitioned tensor, the coordinates of each element in the original tensor for data access boundary checking. More specifically, the cute::make_identity_tensor function is often used to create a CuTe Arithmetic Tuple Tensor that represents the coordinates of each element in the original CuTe Data Tensor. This coordinate tensor then follows the same problem partitioning as the corresponding CuTe Data Tensor.
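
As a quick warm-up before the full example below, here is a minimal sketch (my own illustration) of what cute::make_identity_tensor produces for a small 4 x 4 shape:

#include <cute/tensor.hpp>

int main()
{
    // The element at coordinate (i, j) is the coordinate (i, j) itself,
    // generated on the fly rather than stored in memory.
    auto const identity_tensor{
        cute::make_identity_tensor(cute::make_shape(4, 4))};
    // Expected to print the coordinates (0,0), (1,0), ..., (3,3).
    cute::print_tensor(identity_tensor);
    return 0;
}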

The following is a simple host preview example of using cute::make_identity_tensor and CuTe partition functions to compute the coordinates of each element in a partitioned tensor for an MMA problem.

#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>

#include <cute/layout.hpp>
#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>

#include <thrust/host_vector.h>

int main(int argc, const char** argv)
{
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{512};

    constexpr int bM{128};
    constexpr int bN{128};
    constexpr int bK{32};

    // Configure tiled MMA.
    using MmaTraits = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
    using MmaAtomShape = MmaTraits::Shape_MNK;
    auto const mma_atom = cute::MMA_Atom<MmaTraits>{};
    auto const mma_atom_shape = MmaAtomShape{};
    // Repeating the MMA atom along the M, N, and K dimensions.
    // This increases the number of threads to process the tiled MMA.
    constexpr int MMA_LAYOUT_M{2};
    constexpr int MMA_LAYOUT_N{2};
    constexpr int MMA_LAYOUT_K{1};
    auto mma_layout{cute::make_layout(
        cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{},
                         cute::Int<MMA_LAYOUT_K>{}))};
    // Repeating the MMA processing along the M, N, and K dimensions.
    // This does not increase the number of threads to process the tiled MMA,
    // but it increases the number of registers required for processing the
    // tiled MMA.
    constexpr int NUM_MMA_TILE_M{1};
    constexpr int NUM_MMA_TILE_N{2};
    constexpr int NUM_MMA_TILE_K{1};
    constexpr int MMA_TILE_M{cute::get<0>(mma_atom_shape) * MMA_LAYOUT_M *
                             NUM_MMA_TILE_M};
    constexpr int MMA_TILE_N{cute::get<1>(mma_atom_shape) * MMA_LAYOUT_N *
                             NUM_MMA_TILE_N};
    constexpr int MMA_TILE_K{cute::get<2>(mma_atom_shape) * MMA_LAYOUT_K *
                             NUM_MMA_TILE_K};
    auto mma_tile{cute::make_tile(cute::Int<MMA_TILE_M>{},
                                  cute::Int<MMA_TILE_N>{},
                                  cute::Int<MMA_TILE_K>{})};
    auto tiled_mma{cute::make_tiled_mma(mma_atom, mma_layout, mma_tile)};

    constexpr auto NUM_THREADS{cute::size(tiled_mma)};
    CUTE_STATIC_ASSERT(NUM_THREADS ==
                       MMA_LAYOUT_M * MMA_LAYOUT_N * MMA_LAYOUT_K *
                           cute::size(decltype(mma_atom)::ThrID{}));

    std::cout << "mma_atom" << std::endl;
    cute::print(mma_atom);
    std::cout << std::endl;

    std::cout << "tiled_mma" << std::endl;
    cute::print(tiled_mma);
    std::cout << std::endl;

    // Partition via MMA.
    // Set an arbitrary thread index.
    constexpr int THREAD_IDX{1};
    CUTE_STATIC_ASSERT(THREAD_IDX < NUM_THREADS);
    CUTE_STATIC_ASSERT(THREAD_IDX >= 0);

    auto thread_mma{tiled_mma.get_slice(THREAD_IDX)};

    // Set an arbitrary block index.
    auto const block_coord{cute::make_coord(1, 1)};
    auto const global_identity_tensor{
        cute::make_identity_tensor(cute::make_shape(M, N))};
    auto const block_identity_tensor{cute::local_tile(
        global_identity_tensor,
        cute::make_tile(cute::Int<bM>{}, cute::Int<bN>{}), block_coord)};
    auto const thread_identity_tensor{
        thread_mma.partition_C(block_identity_tensor)};

    auto h_C = thrust::host_vector<cute::half_t>(M * N);
    auto global_layout_C{
        cute::make_layout(cute::make_shape(cute::Int<M>{}, cute::Int<N>{}),
                          cute::make_stride(cute::Int<1>{}, cute::Int<M>{}))};
    auto const global_tensor{cute::make_tensor(h_C.data(), global_layout_C)};
    auto const block_tensor{cute::local_tile(
        global_tensor, cute::make_tile(cute::Int<bM>{}, cute::Int<bN>{}),
        block_coord)};
    auto const thread_tensor{thread_mma.partition_C(block_tensor)};

    std::cout << "global_identity_tensor" << std::endl;
    cute::print(global_identity_tensor);
    std::cout << std::endl;
    std::cout << "block_identity_tensor" << std::endl;
    cute::print(block_identity_tensor);
    std::cout << std::endl;
    // cute::print_tensor(block_identity_tensor);
    // std::cout << std::endl;
    std::cout << "thread_identity_tensor" << std::endl;
    // cute::print(thread_identity_tensor);
    // std::cout << std::endl;
    cute::print_tensor(thread_identity_tensor);
    std::cout << std::endl;
    std::cout << "global_tensor" << std::endl;
    cute::print(global_tensor);
    std::cout << std::endl;
    std::cout << "block_tensor" << std::endl;
    cute::print(block_tensor);
    std::cout << std::endl;
    std::cout << "thread_tensor" << std::endl;
    cute::print(thread_tensor);
    std::cout << std::endl;

    return 0;
}

The coordinate tensors created by cute::make_identity_tensor and partitioned by the CuTe partition functions, including global_identity_tensor, block_identity_tensor, and thread_identity_tensor, can be printed using cute::print or cute::print_tensor. We can see that they are all CuTe Arithmetic Tuple Tensors.

mma_atom
MMA_Atom
ThrID: _32:_1
Shape_MNK: (_16,_8,_16)
LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

tiled_mma
TiledMMA
ThrLayoutVMNK: (_32,_2,_2,_1):(_1,_32,_64,_0)
PermutationMNK: (_32,_32,_16)
MMA_Atom
ThrID: _32:_1
Shape_MNK: (_16,_8,_16)
LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

global_identity_tensor
ArithTuple(_0,_0) o (512,512):(_1@0,_1@1)
block_identity_tensor
ArithTuple(128,128) o (_128,_128):(_1@0,_1@1)
thread_identity_tensor
ArithTuple(128,130) o ((_2,_2),_4,_8):((_1@1,_8@0),_32@0,_16@1):
(128,130) (160,130) (192,130) (224,130)
(128,131) (160,131) (192,131) (224,131)
(136,130) (168,130) (200,130) (232,130)
(136,131) (168,131) (200,131) (232,131)
--------------------
(128,146) (160,146) (192,146) (224,146)
(128,147) (160,147) (192,147) (224,147)
(136,146) (168,146) (200,146) (232,146)
(136,147) (168,147) (200,147) (232,147)
--------------------
(128,162) (160,162) (192,162) (224,162)
(128,163) (160,163) (192,163) (224,163)
(136,162) (168,162) (200,162) (232,162)
(136,163) (168,163) (200,163) (232,163)
--------------------
(128,178) (160,178) (192,178) (224,178)
(128,179) (160,179) (192,179) (224,179)
(136,178) (168,178) (200,178) (232,178)
(136,179) (168,179) (200,179) (232,179)
--------------------
(128,194) (160,194) (192,194) (224,194)
(128,195) (160,195) (192,195) (224,195)
(136,194) (168,194) (200,194) (232,194)
(136,195) (168,195) (200,195) (232,195)
--------------------
(128,210) (160,210) (192,210) (224,210)
(128,211) (160,211) (192,211) (224,211)
(136,210) (168,210) (200,210) (232,210)
(136,211) (168,211) (200,211) (232,211)
--------------------
(128,226) (160,226) (192,226) (224,226)
(128,227) (160,227) (192,227) (224,227)
(136,226) (168,226) (200,226) (232,226)
(136,227) (168,227) (200,227) (232,227)
--------------------
(128,242) (160,242) (192,242) (224,242)
(128,243) (160,243) (192,243) (224,243)
(136,242) (168,242) (200,242) (232,242)
(136,243) (168,243) (200,243) (232,243)

global_tensor
ptr[16b](0x7a06afbba010) o (_512,_512):(_1,_512)
block_tensor
ptr[16b](0x7a06afbda110) o (_128,_128):(_1,_512)
thread_tensor
ptr[16b](0x7a06afbda910) o ((_2,_2),_4,_8):((_512,_8),_32,_8192)

For example, the CuTe Arithmetic Tuple Tensor thread_identity_tensor has an iterator ArithTuple(128,130) and a layout of ((_2,_2),_4,_8):((_1@1,_8@0),_32@0,_16@1). The cute::print_tensor function iterates through each element in the tensor, computes the coordinate of each element from the iterator and the layout, and prints the coordinates. In this case, the stride of the layout is ((_1@1,_8@0),_32@0,_16@1), which is different from an integer stride like (_1,_128) in a CuTe Data Tensor, yet CuTe layout algebra, such as logical division, still seems to work. We will discuss how CuTe layout algebra works for layouts whose strides are arithmetic tuples in the next section.

In fact, before I learned the difference between the CuTe Arithmetic Tuple Tensor and the CuTe Data Tensor, I was fooled by the name of cute::make_identity_tensor. I thought cute::make_identity_tensor would produce a CuTe Data Tensor whose storage holds the coordinates. However, if that were the case, the coordinate data for large tensors would often not fit into the storage, especially when the storage is registers, which makes no sense at all for high performance computing. Instead, the coordinates are generated on the fly from the CuTe Arithmetic Tuple Tensor without taking additional storage. This is also consistent with how we compute the coordinates of each element in a tensor when we write CUDA kernels without using CuTe.
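
As an illustration, the following is a hedged sketch (my own, not from the example code above) of the typical boundary-checking pattern built on top of an identity tensor, reusing the thread_tensor and thread_identity_tensor names from the example:

// For each element this thread owns, recover its (m, n) coordinate in the
// original tensor from the identity tensor and predicate the data access.
for (int i{0}; i < cute::size(thread_tensor); ++i)
{
    // thread_identity_tensor(i) yields the element's original coordinate.
    if (cute::elem_less(thread_identity_tensor(i), cute::make_coord(M, N)))
    {
        thread_tensor(i) = cute::half_t{1.0f};
    }
}

In the host example above every coordinate happens to be in bounds, so the predicate always passes; the pattern matters when the problem shape is not divisible by the tile shape.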

CuTe Layout Algebra In CuTe Arithmetic Tuple Tensor

CuTe Layout Coordinate Mapping With Arithmetic Tuple Stride

In my previous article “CuTe Layout Algebra”, I discussed how CuTe layout algebra works for layouts whose shapes and strides are all integers. In a CuTe Arithmetic Tuple Tensor, the stride of the layout is not an integer tuple but an arithmetic tuple, such as (_1@0,_1@1).

So what are 1@0, 1@1, etc. in the non-integer stride? The descriptions can be found in the CuTe official documentation. Basically, they represent basis elements of an infinite-dimensional vector space.

String Representation    Description
1                        1
1@0                      (1,0,...)
1@1                      (0,1,0,...)
1@0@0                    ((1,0,...),0,...)
1@1@0                    ((0,1,0,...),0,...)
1@0@1                    (0,(1,0,...),0,...)
1@1@1                    (0,(0,1,0,...),0,...)

The basis elements can be nested, which is why we can see multiple @ in the string representation. For example, 1@0@1 represents the basis element (0,(1,0,...),0,...).

The basis elements can be scaled by integers. For example, 3@1 represents the basis element (0,3,0,...).

The basis elements can also be added together. For example, 1@0 + 2@1 represents (1,2,0,...).
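
In CuTe code, basis elements can be constructed with cute::E, which the composition example later in this article also uses. The following is a minimal sketch (my own illustration) of the scaling and addition properties; the exact print format is an assumption:

#include <iostream>

#include <cute/tensor.hpp>

int main()
{
    auto const e0{cute::E<0>{}};                  // The basis element 1@0.
    auto const e1{cute::E<1>{}};                  // The basis element 1@1.
    auto const scaled{cute::Int<3>{} * e1};       // Scaling: 3@1 = (0,3,...).
    auto const summed{e0 + cute::Int<2>{} * e1};  // Addition: (1,2,...).
    cute::print(scaled);  // Expected to print something like _3@1.
    std::cout << std::endl;
    cute::print(summed);  // Expected to print something like (_1,_2).
    std::cout << std::endl;
    return 0;
}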

With the above properties of basis elements defined, we can compute the coordinate of each element in a CuTe Arithmetic Tuple Tensor from its iterator and layout, which is essentially an inner product between the input coordinate and the stride.

Taking the thread_identity_tensor from the previous section as an example, its iterator is ArithTuple(128,130) and its layout is ((_2,_2),_4,_8):((_1@1,_8@0),_32@0,_16@1). If we feed the input coordinate ((1,1),2,3) to the layout, the inner product is computed as

(1,1) x (_1@1,_8@0) + 2 x _32@0 + 3 x _16@1
= 1 x (0,1,0,...) + 1 x (8,0,0,...) + 2 x (32,0,0,...) + 3 x (0,16,0,...)
= (0,1,0,...) + (8,0,0,...) + (64,0,0,...) + (0,48,0,...)
= (72,49,0,...)

Adding the iterator ArithTuple(128,130) to this output coordinate gives the final coordinate of the element in the original tensor: (72+128, 49+130) = (200,179), which matches one of the coordinates in the printout above.
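
As a quick check, the same coordinate can be obtained programmatically by indexing the identity tensor (a hedged sketch that could be appended to the earlier example code):

// Index thread_identity_tensor at the hierarchical coordinate ((1,1),2,3).
// This is expected to print (200,179), matching the hand computation above.
auto const coord{cute::make_coord(cute::make_coord(1, 1), 2, 3)};
cute::print(thread_identity_tensor(coord));
std::cout << std::endl;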

CuTe Layout Algebra With Arithmetic Tuple Stride

CuTe layout coordinate mapping with arithmetic tuple strides seems straightforward. The next question is how CuTe layout algebra, such as composition, complement, logical division, and logical product, can be applied in this context. It turns out that CuTe layout algebra is still applicable to layouts whose strides are arithmetic tuples.

Without loss of generality, suppose we have a 2D layout with arithmetic tuple strides L1 = (S1,S2):(T1@i,T2@j). Its CuTe layout algebra is exactly the same as that of the 2D layout with integer strides L2 = (S1,S2):(T1,T2). For example, (_10,_2):(_4@0,_5@1) behaves in layout algebra just like (_10,_2):(_4,_5). In many cases, L2 would not be a valid layout for a CuTe Data Tensor because the index mapping might not be injective. However, L2 can still be used for CuTe layout algebra.

For example, CuTe composition is an essential operation for both logical division and logical product. In the following example, we can see that CuTe composition and CuTe logical division behave the same for layouts whose strides are arithmetic tuples as for layouts whose strides are integers.

#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>

#include <cute/layout.hpp>
#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>

#include <thrust/host_vector.h>

int main(int argc, const char** argv)
{
    auto const tiler_layout{cute::make_layout(
        cute::make_shape(cute::Int<4>{}), cute::make_stride(cute::Int<5>{}))};
    std::cout << "tiler_layout" << std::endl;
    cute::print(tiler_layout);
    std::cout << std::endl;

    auto const arithmetic_tuple_strided_layout{
        cute::make_layout(cute::make_shape(cute::Int<10>{}, cute::Int<2>{}),
                          cute::make_stride(cute::Int<4>{} * cute::E<0>{},
                                            cute::Int<5>{} * cute::E<1>{}))};
    auto const arithmetic_tuple_strided_composed_layout{
        cute::composition(arithmetic_tuple_strided_layout, tiler_layout)};
    auto const arithmetic_tuple_strided_divided_layout{
        cute::logical_divide(arithmetic_tuple_strided_layout, tiler_layout)};
    std::cout << "arithmetic_tuple_strided_layout" << std::endl;
    cute::print(arithmetic_tuple_strided_layout);
    std::cout << std::endl;
    std::cout << "arithmetic_tuple_strided_composed_layout" << std::endl;
    cute::print(arithmetic_tuple_strided_composed_layout);
    std::cout << std::endl;
    std::cout << "arithmetic_tuple_strided_divided_layout" << std::endl;
    cute::print(arithmetic_tuple_strided_divided_layout);
    std::cout << std::endl;

    auto const integer_strided_layout{
        cute::make_layout(cute::make_shape(cute::Int<10>{}, cute::Int<2>{}),
                          cute::make_stride(cute::Int<4>{}, cute::Int<5>{}))};
    auto const integer_strided_composed_layout{
        cute::composition(integer_strided_layout, tiler_layout)};
    auto const integer_strided_divided_layout{
        cute::logical_divide(integer_strided_layout, tiler_layout)};
    std::cout << "integer_strided_layout" << std::endl;
    cute::print(integer_strided_layout);
    std::cout << std::endl;
    std::cout << "integer_strided_composed_layout" << std::endl;
    cute::print(integer_strided_composed_layout);
    std::cout << std::endl;
    std::cout << "integer_strided_divided_layout" << std::endl;
    cute::print(integer_strided_divided_layout);
    std::cout << std::endl;

    return 0;
}
tiler_layout
(_4):(_5)
arithmetic_tuple_strided_layout
(_10,_2):(_4@0,_5@1)
arithmetic_tuple_strided_composed_layout
((_2,_2)):((_20@0,_5@1))
arithmetic_tuple_strided_divided_layout
(((_2,_2)),_5):(((_20@0,_5@1)),_4@0)
integer_strided_layout
(_10,_2):(_4,_5)
integer_strided_composed_layout
((_2,_2)):((_20,_5))
integer_strided_divided_layout
(((_2,_2)),_5):(((_20,_5)),_4)
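
As a sanity check, we can derive arithmetic_tuple_strided_composed_layout by hand. The composition of A = (_10,_2):(_4@0,_5@1) with B = (_4):(_5) maps k to A(B(k)). B sends k in {0,1,2,3} to the indices {0,5,10,15}. A converts an index i to the natural coordinate (i mod 10, i / 10), which gives (0,0), (5,0), (0,1), and (5,1), and then applies its strides, m x _4@0 + n x _5@1, which gives (0,0), (20,0), (0,5), and (20,5). These are exactly the outputs produced by the layout ((_2,_2)):((_20@0,_5@1)) at the coordinates (0,0), (1,0), (0,1), and (1,1), matching the printout above.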

We can also verify the idea using the example from the previous section. Because the arithmetic tuple stride of global_identity_tensor in that example is (_1@0,_1@1), we will use the integer stride (_1,_1) for global_layout_C instead of the original integer stride (_1,_512).

// auto global_layout_C{
//     cute::make_layout(cute::make_shape(cute::Int<M>{}, cute::Int<N>{}),
//                       cute::make_stride(cute::Int<1>{}, cute::Int<M>{}))};
auto global_layout_C{cute::make_layout(
    cute::make_shape(cute::Int<M>{}, cute::Int<N>{}),
    cute::make_stride(cute::Int<1>{}, cute::Int<1>{}))};

The layouts of the integer strided tensors, including global_tensor, block_tensor, and thread_tensor, and the layouts of the arithmetic tuple strided tensors, including global_identity_tensor, block_identity_tensor, and thread_identity_tensor, are printed out as follows.

mma_atom
MMA_Atom
ThrID: _32:_1
Shape_MNK: (_16,_8,_16)
LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

tiled_mma
TiledMMA
ThrLayoutVMNK: (_32,_2,_2,_1):(_1,_32,_64,_0)
PermutationMNK: (_32,_32,_16)
MMA_Atom
ThrID: _32:_1
Shape_MNK: (_16,_8,_16)
LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))

global_identity_tensor
ArithTuple(_0,_0) o (512,512):(_1@0,_1@1)
block_identity_tensor
ArithTuple(128,128) o (_128,_128):(_1@0,_1@1)
thread_identity_tensor
ArithTuple(128,130) o ((_2,_2),_4,_8):((_1@1,_8@0),_32@0,_16@1):
(128,130) (160,130) (192,130) (224,130)
(128,131) (160,131) (192,131) (224,131)
(136,130) (168,130) (200,130) (232,130)
(136,131) (168,131) (200,131) (232,131)
--------------------
(128,146) (160,146) (192,146) (224,146)
(128,147) (160,147) (192,147) (224,147)
(136,146) (168,146) (200,146) (232,146)
(136,147) (168,147) (200,147) (232,147)
--------------------
(128,162) (160,162) (192,162) (224,162)
(128,163) (160,163) (192,163) (224,163)
(136,162) (168,162) (200,162) (232,162)
(136,163) (168,163) (200,163) (232,163)
--------------------
(128,178) (160,178) (192,178) (224,178)
(128,179) (160,179) (192,179) (224,179)
(136,178) (168,178) (200,178) (232,178)
(136,179) (168,179) (200,179) (232,179)
--------------------
(128,194) (160,194) (192,194) (224,194)
(128,195) (160,195) (192,195) (224,195)
(136,194) (168,194) (200,194) (232,194)
(136,195) (168,195) (200,195) (232,195)
--------------------
(128,210) (160,210) (192,210) (224,210)
(128,211) (160,211) (192,211) (224,211)
(136,210) (168,210) (200,210) (232,210)
(136,211) (168,211) (200,211) (232,211)
--------------------
(128,226) (160,226) (192,226) (224,226)
(128,227) (160,227) (192,227) (224,227)
(136,226) (168,226) (200,226) (232,226)
(136,227) (168,227) (200,227) (232,227)
--------------------
(128,242) (160,242) (192,242) (224,242)
(128,243) (160,243) (192,243) (224,243)
(136,242) (168,242) (200,242) (232,242)
(136,243) (168,243) (200,243) (232,243)

global_tensor
ptr[16b](0x73332bde7010) o (_512,_512):(_1,_1)
block_tensor
ptr[16b](0x73332bde7210) o (_128,_128):(_1,_1)
thread_tensor
ptr[16b](0x73332bde7214) o ((_2,_2),_4,_8):((_1,_8),_32,_16)

The layouts of global_identity_tensor, block_identity_tensor, and thread_identity_tensor are (512,512):(_1@0,_1@1), (_128,_128):(_1@0,_1@1), and ((_2,_2),_4,_8):((_1@1,_8@0),_32@0,_16@1), respectively. The layouts of global_tensor, block_tensor, and thread_tensor are (_512,_512):(_1,_1), (_128,_128):(_1,_1), and ((_2,_2),_4,_8):((_1,_8),_32,_16), respectively. The strides of each pair agree once the basis tags @0 and @1 are dropped, which verifies that CuTe layout algebra works the same for layouts whose strides are arithmetic tuples as for layouts whose strides are integers.

To build better intuition about how CuTe layout algebra works for layouts whose strides are arithmetic tuples, try deriving the CuTe composition of the layout of global_identity_tensor and tiler_layout manually, similar to the derivation of arithmetic_tuple_strided_composed_layout above.

Conclusions

The CuTe Arithmetic Tuple Tensor is just like a Python generator expression: it generates coordinates on the fly from the arithmetic tuple iterator and the layout.
