Lei Mao

Introduction

Using shared memory in CUDA can potentially improve the performance of your program. However, when I tried to use shared memory in templated CUDA kernels, I got strange errors from the compiler. It turns out that CUDA does not directly allow a dynamically sized shared memory array of the template type in a templated kernel. After searching for a while, I found the reason behind this and some ways to get around the problem.

Problem

Let’s say we want to use shared memory in the following templated kernel function.

template <typename T>
__global__ void kernel(T * d_out, const T * d_in, const unsigned int n)
{
    // Declare shared memory
    extern __shared__ T s_data[];

    // Do things in the kernel
    // ...
}

If the kernel is instantiated for more than one type, which is usually the reason for templating it, compilation fails with an error like the following.

$ make
/home/leimao/Workspace/GPU_Algorithms_CUDA/reduce/reduce.cu(37): error: declaration is incompatible with previous "s_data"

Problem Causes

The root cause is simple. To use dynamically sized shared memory, we declare an array with the extern keyword inside the kernel, and its size is supplied at kernel launch. This works without any problem when the kernel is not templated. However, when the kernel is templated and instantiated with different types, each instantiation declares the same extern variable s_data with a different type, and these declarations conflict. Therefore a shared memory array of the template type cannot be declared directly in a templated kernel.
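
To see why the variable name and type matter, it helps to know that, according to the CUDA programming guide, all extern __shared__ arrays declared in a kernel start at the same address in dynamic shared memory. The following is a minimal sketch that checks this on the device. The names alias_check and shared_arrays_alias are made up for this illustration and do not come from any library.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Two dynamically sized shared arrays with different names and types.
// Both are views onto the same block of dynamic shared memory.
__global__ void alias_check(int* d_same)
{
    extern __shared__ int s_int[];
    extern __shared__ float s_float[];

    if (threadIdx.x == 0)
    {
        // Compare the base addresses of the two arrays.
        *d_same = (static_cast<void*>(s_int) == static_cast<void*>(s_float)) ? 1 : 0;
    }
}

// Hypothetical helper: returns 1 if both arrays share a base address,
// 0 if they do not, and -1 if any CUDA call fails.
int shared_arrays_alias()
{
    int* d_same = nullptr;
    int h_same = -1;
    if (cudaMalloc(&d_same, sizeof(int)) != cudaSuccess)
    {
        return -1;
    }
    // The third launch parameter is the dynamic shared memory size in bytes.
    alias_check<<<1, 32, 32 * sizeof(float)>>>(d_same);
    if (cudaMemcpy(&h_same, d_same, sizeof(int), cudaMemcpyDeviceToHost) != cudaSuccess)
    {
        h_same = -1;
    }
    cudaFree(d_same);
    return h_same;
}

int main()
{
    std::printf("same address: %d\n", shared_arrays_alias());
    return 0;
}
```

Because every such array is just a name for the same memory, two declarations with the same name but different types are a genuine conflict, while different names for different types are not.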

Solutions

Use CUDPP Header

One solution is to use the SharedMemory struct defined in the open source CUDPP library. You can simply copy the sharedmem.h file to your source directory and use the following code to declare shared memory. Then everything compiles!

#include "sharedmem.h"

template <typename T>
__global__ void kernel(T * d_out, const T * d_in, const unsigned int n)
{
    // Declare shared memory
    // The following dynamically allocated shared memory declaration does not work in templated kernels
    // extern __shared__ T s_data[];

    // To get around
    SharedMemory<T> smem;
    T * s_data = smem.getPointer();

    // Do things in the kernel
    // ...
}

How does it work? Let us check the source code.

template <typename T>
struct SharedMemory
{
    //! @brief Return a pointer to the runtime-sized shared memory array.
    //! @returns Pointer to runtime-sized shared memory array
    __device__ T* getPointer() 
    { 
        extern __device__ void Error_UnsupportedType(); // Ensure that we won't compile any un-specialized types
        Error_UnsupportedType();
        return (T*)0;
    }
    // TODO: Use operator overloading to make this class look like a regular array
};

// Following are the specializations for the following types.
// int, uint, char, uchar, short, ushort, long, ulong, bool, float, and double
// One could also specialize it for user-defined types.

template <>
struct SharedMemory <int>
{
    __device__ int* getPointer() { extern __shared__ int s_int[]; return s_int; }      
};

template <>
struct SharedMemory <unsigned int>
{
    __device__ unsigned int* getPointer() { extern __shared__ unsigned int s_uint[]; return s_uint; }    
};

// ...

The source code shows that each specialization of SharedMemory declares its extern shared array under a different variable name (s_int, s_uint, and so on), so the declarations no longer conflict. C++ template specialization is what allows the variable name to differ per type. The un-specialized primary template calls an undefined function, so using an unsupported type fails at build time.
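
The comment in the source mentions that the struct can also be specialized for user-defined types. Below is a minimal sketch of what that might look like. It does not include sharedmem.h; instead it uses a simplified stand-in for SharedMemory (a declaration-only primary template, which is my simplification and not the CUDPP code). Vec2, reverse_block, and reverse_on_device are hypothetical names invented for this example.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-in for the CUDPP primary template: declared but not
// defined, so un-specialized types fail to compile.
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<float>
{
    __device__ float* getPointer() { extern __shared__ float s_float[]; return s_float; }
};

// Hypothetical user-defined type.
struct Vec2
{
    float x;
    float y;
};

// Specialization for the user-defined type, with its own variable name.
template <>
struct SharedMemory<Vec2>
{
    __device__ Vec2* getPointer() { extern __shared__ Vec2 s_vec2[]; return s_vec2; }
};

// Reverses n elements within a single block through shared memory.
template <typename T>
__global__ void reverse_block(T* d_out, const T* d_in, const unsigned int n)
{
    SharedMemory<T> smem;
    T* s_data = smem.getPointer();

    const unsigned int i = threadIdx.x;
    if (i < n) s_data[i] = d_in[i];
    __syncthreads();
    if (i < n) d_out[i] = s_data[n - 1 - i];
}

// Hypothetical host helper: returns false if any CUDA call fails.
template <typename T>
bool reverse_on_device(const T* h_in, T* h_out, const unsigned int n)
{
    T* d_in = nullptr;
    T* d_out = nullptr;
    const size_t bytes = n * sizeof(T);
    bool ok = cudaMalloc(&d_in, bytes) == cudaSuccess &&
              cudaMalloc(&d_out, bytes) == cudaSuccess &&
              cudaMemcpy(d_in, h_in, bytes, cudaMemcpyHostToDevice) == cudaSuccess;
    if (ok)
    {
        // One block of n threads, with n * sizeof(T) bytes of dynamic shared memory.
        reverse_block<T><<<1, n, bytes>>>(d_out, d_in, n);
        ok = cudaMemcpy(h_out, d_out, bytes, cudaMemcpyDeviceToHost) == cudaSuccess;
    }
    cudaFree(d_in);
    cudaFree(d_out);
    return ok;
}

int main()
{
    const float f_in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float f_out[4] = {};
    const Vec2 v_in[2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    Vec2 v_out[2] = {};

    // The same templated kernel is instantiated for two different types.
    if (reverse_on_device(f_in, f_out, 4u) && reverse_on_device(v_in, v_out, 2u))
    {
        std::printf("%g %g %g %g\n", f_out[0], f_out[1], f_out[2], f_out[3]);
        std::printf("(%g, %g) (%g, %g)\n", v_out[0].x, v_out[0].y, v_out[1].x, v_out[1].y);
    }
    return 0;
}
```

The cost of this approach is one specialization per type, which is why the pointer-based alternative can be more convenient when many types are involved.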

Use Pointer Casting

Another simple solution is to use pointer casting.

template <typename T>
__global__ void kernel(T * d_out, const T * d_in, const unsigned int n)
{
    // Declare shared memory
    // The following dynamically allocated shared memory declaration does not work in templated kernels
    // extern __shared__ T s_data[];

    // To get around
    extern __shared__ char smem[];
    T * s_data = reinterpret_cast<T *>(smem);

    // Do things in the kernel
    // ...
}

This solution declares the shared memory once, under a single variable name and a single type (char), so every instantiation sees an identical declaration. The pointer is then cast to the template type T inside the kernel.
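
To show the pointer casting approach end to end, here is a minimal sketch of a templated single-block sum that is instantiated for both int and float. The names block_sum and sum_on_device are hypothetical, the block size of 64 is an arbitrary power of two chosen for the tree reduction, and the sketch assumes the start of dynamic shared memory is suitably aligned for T, which I expect to hold for the built-in scalar types used here.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Sums up to blockDim.x elements in one block. blockDim.x must be a power of two.
template <typename T>
__global__ void block_sum(T* d_out, const T* d_in, const unsigned int n)
{
    // Same name and type in every instantiation, so there is no conflict.
    extern __shared__ char smem[];
    T* s_data = reinterpret_cast<T*>(smem);

    const unsigned int tid = threadIdx.x;
    // Pad with zeros so the tree reduction can ignore the bounds.
    s_data[tid] = (tid < n) ? d_in[tid] : T(0);
    __syncthreads();

    for (unsigned int stride = blockDim.x / 2; stride > 0; stride /= 2)
    {
        if (tid < stride) s_data[tid] += s_data[tid + stride];
        __syncthreads();
    }
    if (tid == 0) *d_out = s_data[0];
}

// Hypothetical host helper: returns false if n is too large or a CUDA call fails.
template <typename T>
bool sum_on_device(const T* h_in, const unsigned int n, T* h_sum)
{
    const unsigned int threads = 64;
    if (n == 0 || n > threads) return false;

    T* d_in = nullptr;
    T* d_sum = nullptr;
    bool ok = cudaMalloc(&d_in, n * sizeof(T)) == cudaSuccess &&
              cudaMalloc(&d_sum, sizeof(T)) == cudaSuccess &&
              cudaMemcpy(d_in, h_in, n * sizeof(T), cudaMemcpyHostToDevice) == cudaSuccess;
    if (ok)
    {
        // The shared memory size in bytes depends on T and is set at launch.
        block_sum<T><<<1, threads, threads * sizeof(T)>>>(d_sum, d_in, n);
        ok = cudaMemcpy(h_sum, d_sum, sizeof(T), cudaMemcpyDeviceToHost) == cudaSuccess;
    }
    cudaFree(d_in);
    cudaFree(d_sum);
    return ok;
}

int main()
{
    const int i_in[4] = {1, 2, 3, 4};
    const float f_in[3] = {0.5f, 1.5f, 2.0f};
    int i_sum = 0;
    float f_sum = 0.0f;

    if (sum_on_device(i_in, 4u, &i_sum) && sum_on_device(f_in, 3u, &f_sum))
    {
        std::printf("int sum: %d, float sum: %g\n", i_sum, f_sum);
    }
    return 0;
}
```

Note that the kernel itself has no idea how large the array is. The host computes the byte count from sizeof(T) and passes it as the third launch parameter.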
