In CUDA programming, local memory is private storage for an executing thread, and is not visible outside of that thread. The local memory space resides in device memory, so local memory accesses have the same high latency and low bandwidth as global memory accesses and are subject to the same requirements for memory coalescing.
An automatic variable declared without the __device__, __shared__, or __constant__ memory space specifiers is placed by the compiler either in registers or in local memory. It is likely to be placed in local memory if it is one of the following:
Arrays for which it cannot determine that they are indexed with constant quantities,
Large structures or arrays that would consume too much register space,
Any variable if the kernel uses more registers than available (this is also known as register spilling).
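As a hypothetical sketch (the kernel and variable names are mine, and the actual placement decision always belongs to the compiler and the target architecture), the three cases could look like this:

```cuda
// Hypothetical sketch of the three situations in which an automatic variable
// is likely to end up in local memory (illustrative only; the compiler makes
// the final placement decision).
__global__ void local_memory_candidates(float const* input, float* output,
                                        int n)
{
    int const i{static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x)};
    if (i >= n)
    {
        return;
    }
    // 1. An array indexed with a quantity the compiler cannot resolve at
    //    compile time: likely placed in local memory.
    float dynamically_indexed[4]{};
    dynamically_indexed[i % 4] = input[i];
    // 2. A large array that would consume too much register space: likely
    //    placed in local memory.
    float large_array[1024]{};
    large_array[i % 1024] = input[i];
    // 3. Under high register pressure, any variable may be spilled to local
    //    memory (register spilling); nothing in the source code marks it.
    output[i] = dynamically_indexed[i % 4] + large_array[i % 1024];
}
```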
The second and third points are straightforward to understand. The first point, however, is a little tricky: it implies that even very small arrays can be placed in local memory rather than in registers, and most of the time we would like those small arrays to be placed in registers for better performance.
In this blog post, I would like to show an example of how the compiler decides to place an array in local memory rather than in registers and discuss the general rules that a user can follow to avoid small arrays being placed in local memory.
CUDA Local Memory
In the following example, I created two CUDA kernels that compute the running mean of an input array given a fixed window size. Both kernels declare a local array window whose size is known at compile time. The implementations of the two kernels are almost exactly the same, except that the first kernel uses straightforward indexing to access the window array, while the second kernel uses an index that is less trivial.
```cuda
template <int WindowSize>
__global__ void running_mean_register_array(float const* input, float* output,
                                            int n)
{
    float window[WindowSize];
    int const thread_idx{
        static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x)};
    int const stride{static_cast<int>(blockDim.x * gridDim.x)};
    for (int i{thread_idx}; i < n; i += stride)
    {
        // Read data into the window.
        for (int j{0}; j < WindowSize; ++j)
        {
            int const idx{i - WindowSize / 2 + j};
            window[j] = (idx < 0 || idx >= n) ? 0 : input[idx];
        }
        // Compute the mean from the window.
        float sum{0};
        for (int j{0}; j < WindowSize; ++j)
        {
            sum += window[j];
        }
        float const mean{sum / WindowSize};
        // Write the mean to the output.
        output[i] = mean;
    }
}
```
```cuda
template <int WindowSize>
__global__ void running_mean_local_memory_array(float const* input,
                                                float* output, int n)
{
    float window[WindowSize];
    int const thread_idx{
        static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x)};
    int const stride{static_cast<int>(blockDim.x * gridDim.x)};
    for (int i{thread_idx}; i < n; i += stride)
    {
        // Read data into the window.
        for (int j{0}; j < WindowSize; ++j)
        {
            int const idx{i - WindowSize / 2 + j};
            window[j] = (idx < 0 || idx >= n) ? 0 : input[idx];
        }
        // Compute the mean from the window.
        float sum{0};
        for (int j{0}; j < WindowSize; ++j)
        {
            // This index into the window array cannot be resolved at compile
            // time by the compiler, even though such indexing does not affect
            // the correctness of the kernel. The consequence is that the
            // compiler will place the window array in local memory rather
            // than in the register file.
            int const idx{(j + n) % WindowSize};
            sum += window[idx];
        }
        float const mean{sum / WindowSize};
        // Write the mean to the output.
        output[i] = mean;
    }
}
```
```cuda
// Verify the correctness of the kernel given a window size and a launch
// function.
template <int WindowSize>
void verify_running_mean(int n,
                         cudaError_t (*launch_func)(float const*, float*, int,
                                                    cudaStream_t))
{
    std::vector<float> h_input_vec(n, 0.f);
    std::vector<float> h_output_vec(n, 1.f);
    std::vector<float> h_output_vec_ref(n, 2.f);
    // Fill the input vector with values.
    for (int i{0}; i < n; ++i)
    {
        h_input_vec[i] = static_cast<float>(i);
    }
    // Compute the reference output vector.
    for (int i{0}; i < n; ++i)
    {
        float sum{0};
        for (int j{0}; j < WindowSize; ++j)
        {
            int const idx{i - WindowSize / 2 + j};
            float const val{(idx < 0 || idx >= n) ? 0 : h_input_vec[idx]};
            sum += val;
        }
        h_output_vec_ref[i] = sum / WindowSize;
    }
    // Allocate device memory.
    float* d_input;
    float* d_output;
    CHECK_CUDA_ERROR(cudaMalloc(&d_input, n * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_output, n * sizeof(float)));
    // Copy data to the device.
    CHECK_CUDA_ERROR(cudaMemcpy(d_input, h_input_vec.data(), n * sizeof(float),
                                cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(d_output, h_output_vec.data(),
                                n * sizeof(float), cudaMemcpyHostToDevice));
    // Launch the kernel.
    cudaStream_t stream;
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));
    CHECK_CUDA_ERROR(launch_func(d_input, d_output, n, stream));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    // Copy the result back to the host.
    CHECK_CUDA_ERROR(cudaMemcpy(h_output_vec.data(), d_output,
                                n * sizeof(float), cudaMemcpyDeviceToHost));
    // Check the result.
    for (int i{0}; i < n; ++i)
    {
        if (h_output_vec.at(i) != h_output_vec_ref.at(i))
        {
            std::cerr << "Mismatch at index " << i << ": "
                      << h_output_vec.at(i)
                      << " != " << h_output_vec_ref.at(i) << std::endl;
            std::exit(EXIT_FAILURE);
        }
    }
    // Free device memory.
    CHECK_CUDA_ERROR(cudaFree(d_input));
    CHECK_CUDA_ERROR(cudaFree(d_output));
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
}
```
```cuda
int main()
{
    // Try different window sizes from small to large.
    constexpr int WindowSize{32};
    int const n{8192};
    verify_running_mean<WindowSize>(
        n, launch_running_mean_register_array<WindowSize>);
    verify_running_mean<WindowSize>(
        n, launch_running_mean_local_memory_array<WindowSize>);
    return 0;
}
```
To build and run the example, please run the following commands. There should be no error messages when running the example.
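As a sketch of what those commands could look like (the source file name, output names, and target architecture are my assumptions, not from the original; adjust them to your setup):

```shell
# Compile the example (file name and -arch value are assumptions).
nvcc running_mean.cu -o running_mean -arch=sm_80 --std=c++14
# Run it; it should exit silently if all verifications pass.
./running_mean
# Optionally emit PTX so the local memory usage can be inspected.
nvcc running_mean.cu -ptx -o running_mean.ptx -arch=sm_80 --std=c++14
grep -F ".local" running_mean.ptx
```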
In the PTX of the two kernels, we can see that the first kernel has nothing declared with the .local directive, while the second kernel has a local array __local_depot1 declared with the .local directive. This confirms that the first kernel has the array window placed in registers, while the second kernel has the array window placed in local memory. Even though the local arrays declared in the two kernels are of the same size, the array in the second kernel is placed in local memory because the compiler cannot determine that it is indexed with constant quantities.
To avoid small arrays being placed in local memory, we should avoid indexing so complex that the compiler cannot determine whether the indices are constant quantities. But how do we know whether the compiler can make that determination?
It turns out that registers cannot be indexed, and neither can an array placed in registers. Consequently, if a small array can be placed in registers, an equivalent program that accesses the array only with constant indices can also be written.
For example, the following implementation from the first kernel running_mean_register_array,
```cuda
constexpr int WindowSize{4};
float window[WindowSize];
float sum{0};
for (int j{0}; j < WindowSize; ++j)
{
    sum += window[j];
}
```
has an equivalent form in which the declaration of the array window is unnecessary,
```cuda
float window0, window1, window2, window3;
float sum{0};
sum += window0;
sum += window1;
sum += window2;
sum += window3;
```
whereas the corresponding loop from the second kernel running_mean_local_memory_array, which accesses the window array with the index (j + n) % WindowSize, has no such equivalent form: the declaration of the array window is necessary because the value of n is only known at run time.
Mathematically, it is also equivalent to the following form, because for j = 0, ..., WindowSize - 1 the index (j + n) % WindowSize simply visits 0 through WindowSize - 1 in a rotated order, but figuring this out is a non-trivial task for the compiler.
```cuda
float window0, window1, window2, window3;
float sum{0};
sum += window0;
sum += window1;
sum += window2;
sum += window3;
```
In fact, this is also the case for CUDA Tensor Core MMA PTX, because Tensor Core MMA needs to read its data from registers for the best performance. For example, in the CUTLASS implementation of the SM80_16x8x8_F16F16F16F16_TN MMA, the MMA PTX only ever accesses registers, even though the operand buffers are declared as arrays.
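A sketch of that CUTLASS atom, along the lines of cute/arch/mma_sm80.hpp (exact details may differ across CUTLASS versions, and the real code uses a CUTLASS host/device macro instead of plain __device__): every operand is an individual uint32_t register variable, and the "r" inline-assembly constraints require each one to live in a register.

```cuda
// Sketch of CUTLASS's SM80_16x8x8_F16F16F16F16_TN MMA atom. The operand
// buffers are declared as arrays via the *Registers typedefs, but the fma
// function takes each element as a separate register variable, so the PTX
// never indexes into memory.
struct SM80_16x8x8_F16F16F16F16_TN
{
    using DRegisters = uint32_t[2];
    using ARegisters = uint32_t[2];
    using BRegisters = uint32_t[1];
    using CRegisters = uint32_t[2];

    __device__ static void fma(uint32_t& d0, uint32_t& d1,
                               uint32_t const& a0, uint32_t const& a1,
                               uint32_t const& b0,
                               uint32_t const& c0, uint32_t const& c1)
    {
        asm volatile(
            "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
            "{%0, %1},"
            "{%2, %3},"
            "{%4},"
            "{%5, %6};\n"
            : "=r"(d0), "=r"(d1)
            : "r"(a0), "r"(a1), "r"(b0), "r"(c0), "r"(c1));
    }
};
```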