"BankConflictsCuda"

Avoiding Memory Bank Conflicts in Cuda Programs

Global memory access latencies are very high (relative to compute latencies) and a common source of under-utilization of CPUs and GPUs. Accessing memory efficiently is particularly important for harnessing the power of vector processors, such as SIMD processors and GPUs. Memory banks have been used for a long time to hide access latency on vector processors. One of the earliest vector machines, the CRAY-1, already had 16 memory banks, and to hide memory latency, the banks need to be accessed conflict-free; the CRAY-1 manual already instructed its users how to avoid bank conflicts. Today, bank conflicts are of prime importance for CUDA kernels and can seriously impair the throughput of a GPU.

As this post shows, taking good care of memory banks in shared memory can improve the speed of CUDA kernels by 50%. Despite its importance, shared memory bank access is not documented well in the CUDA programming guide. After introducing shared memory, the blog post presents several undocumented conflict-free access patterns. The final section of the blog post shows that recent additions to the CUDA programming model (requiring compute capability 8.0 and above) help avoid many bank conflicts, while the older suggestion of padding the memory buffer turns out to be sub-optimal.

The Importance of Shared Memory (L1 Cache)

Shared memory is the programmer-managed part of the L1 cache of a Streaming Multiprocessor (which resembles a collection of SIMD processors). In contrast to CPUs, this part of the L1 cache is explicitly programmable and not controlled entirely by the hardware. Shared memory is much faster but much smaller than global memory. For the V100 architecture, the measured size and bandwidth are:

              Global Memory    L1 Cache
Size          16 GB            96 KB
Bandwidth     750 GB/sec       12,080 GB/sec

The same qualitative differences hold for CPUs: global memory (RAM) comes in all sizes, while the L1 cache is often 32 KB or 48 KB. RAM bandwidth is about 10 GB/sec, while L1 bandwidth is around 80 GB/sec. Note, however, that the ratio of L1 to global-memory bandwidth is much larger on the GPU than on the CPU.

Shared memory is partitioned into banks. Successive 32-bit words are assigned to successive banks, and there are 32 different memory banks, as described in the Cuda Programming Guide. Since shared memory is much larger than 32 words, different shared memory positions lie on the same bank. The threads in a warp can access different banks without any conflict. If threads access different memory positions that fall on the same bank, their accesses are serialized, which means that the threads stall. A tabular representation of the banks can be seen below:

Byte Position   0...3   4...7   ...   124...127   128...131   132...135   ...
Memory Bank     0       1       ...   31          0           1           ...
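
Equivalently, the bank of an access can be computed directly. The two helpers below only illustrate the mapping; they are hypothetical and not part of the kernels discussed in this post:

#include <cstddef>

// Bank of the 32-bit word at index i in an array of 4-byte elements (e.g., float):
__host__ __device__ constexpr unsigned BankOfWord(size_t i) { return i % 32; }

// Bank of an arbitrary byte offset into shared memory:
__host__ __device__ constexpr unsigned BankOfByte(size_t byte) { return (byte / 4) % 32; }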

To illustrate the memory bank conflicts in a simple example, consider the following two kernels:

__global__ void ConflictFreeAccess(float* sum) {
    constexpr size_t sz = 32 * 32;
    __shared__ float shmem[sz];
    // Thread i writes to word i, i.e., to bank i % 32: every thread hits a different bank.
    shmem[threadIdx.x] = threadIdx.x;
    [...CODE NOT SHOWN…]
}

__global__ void ConflictAccess(float* sum) {
    constexpr size_t sz = 32 * 32;
    __shared__ float shmem[sz];
    // Thread i writes to word i * 32, which always maps to bank 0: a 31-way conflict.
    shmem[threadIdx.x * 32] = threadIdx.x;
    [...CODE NOT SHOWN…]
}

Both kernels allocate a shared buffer for 1024 floats. Threads in ConflictFreeAccess write their IDs into successive memory positions (and hence successive banks). Threads in ConflictAccess, in contrast, all write into the same bank, because their indices are multiples of 32. The hidden code prevents the compiler from optimizing the kernels away. Bank conflicts can be identified with the Nsight Compute profiler, which reports a 31-way conflict for the second kernel but no conflict for the first:

  ConflictFreeAccess(float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 7.5
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                             Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                        0
    -------------------------------------------------------- ----------- ------------

  ConflictAccess(float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 7.5
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                             Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                    31
    -------------------------------------------------------- ----------- ------------
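
These metrics come from the Nsight Compute command-line profiler; an invocation along the following lines produces them (the binary name bank_conflicts is a placeholder):

  ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum ./bank_conflicts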

Undocumented Features of Memory Banks

Bank conflicts were originally described for floats, with each bank spanning four bytes (one float). However, highly performant GEMM kernels neither use floats nor load values individually. Most commonly, they use half floats and load, in SIMD style, packed 128-bit values. At first glance, both patterns look like they should incur memory bank conflicts. In tabular form, the two examples are:

Bank                                 0      1      ...   15       16    ...   31    0 (wrap)
16-bit loads (threads per bank)      0, 1   2, 3   ...   30, 31   -     ...   -     -
128-bit loads of 8 halfs (thread)    0      0      ...   3        4     ...   7     8
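
A minimal sketch of the two access patterns could look as follows; the kernel names HalfAccess and VectorizedAccess are placeholders, not the names used in the linked implementation:

#include <cuda_fp16.h>

__global__ void HalfAccess(__half* sum) {
    __shared__ __half shmem[32 * 32];
    // Two 16-bit values share one 4-byte bank: threads 0 and 1 both write to bank 0.
    shmem[threadIdx.x] = __float2half((float)threadIdx.x);
    [...CODE NOT SHOWN…]
}

__global__ void VectorizedAccess(__half* sum) {
    __shared__ __align__(16) __half shmem[32 * 32];
    // Each thread stores 128 bits (8 half floats) at once, spanning four consecutive banks.
    int4 packed;                                            // 16 bytes = 8 half floats
    __half* vals = reinterpret_cast<__half*>(&packed);
    for (int i = 0; i < 8; ++i) vals[i] = __float2half((float)threadIdx.x);
    reinterpret_cast<int4*>(shmem)[threadIdx.x] = packed;
    [...CODE NOT SHOWN…]
}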

These two kernels (see here for the complete implementation) can be profiled again with Nsight. Profiling shows that neither kernel incurs any bank conflicts. Although this is great for performance (and spares the programmer from trading vectorized access off against bank conflicts), it is not documented in the Cuda programming guide and makes kernels hard to reason about: memory bank conflicts primarily need to be identified with a profiler. The following section highlights how crucial shared memory bank conflicts are in a high-performance GEMM implementation.

Addressing Memory Bank Conflicts with Async Memory Access

As described, shared memory has very high throughput, and bank conflicts need to be identified empirically. Several avenues exist to circumvent them. Some suggest [padding](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/) of the memory buffer. However, I usually find async memory access the fastest and most reliable solution. The section below documents that avoiding bank conflicts increases the performance of a GEMM kernel by 50%, enabled by async memory accesses using modern PTX intrinsics.
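
Padding adds an unused element to each row of a shared tile so that consecutive rows start in different banks. The linked NVIDIA post illustrates this with a transpose; a minimal, single-block 32x32 version of that idea (hypothetical code, not taken from this post's kernels) looks like:

__global__ void PaddedTranspose32(const float* in, float* out) {
    // Assumes a launch with a single (32, 32) thread block.
    // The extra column (33 instead of 32 floats per row) shifts each row by one bank,
    // so the column-wise read tile[threadIdx.x][threadIdx.y] below is conflict-free.
    __shared__ float tile[32][33];
    tile[threadIdx.y][threadIdx.x] = in[threadIdx.y * 32 + threadIdx.x];
    __syncthreads();
    out[threadIdx.y * 32 + threadIdx.x] = tile[threadIdx.x][threadIdx.y];
}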

Consider the GEMM kernel below. For simplicity, the kernel uses floats and does not use any Tensor Core intrinsics.

template <typename T, typename accum_type, typename tb, typename wb>
__global__ void GEMM(T *__restrict__ A, T *__restrict__ B, T *__restrict__ C,
                     size_t M, size_t N, size_t K, accum_type alpha,
                     accum_type beta) {
    assert(M % tb::kM == 0);
    assert(N % tb::kN == 0);
    assert(K % tb::kK == 0);

    const size_t block_base_x = blockIdx.x * tb::kN;
    const size_t block_base_y = blockIdx.y * tb::kM;

    constexpr size_t skew = 8;
    __shared__ T As[tb::kM][tb::kK + skew];  // long and skinny
    __shared__ T Bs[tb::kK][tb::kN + skew];  // short and wide
    // non quadratic to max load instruction usage

    constexpr size_t n_gld =
        CUDA_VECTORIZED_BITS_LOAD / (sizeof(T) * 8);  // elements per vectorized load (sizeof(T) * 8 converts bytes to bits)

    const size_t total_threadId = blockDim.x * threadIdx.y + threadIdx.x;
    // assert(total_threadId < 256);
    const size_t thread_num = blockDim.x * blockDim.y;

    const size_t stride_a = thread_num * n_gld / tb::kK;
    assert(stride_a <= tb::kM);  // each thread needs to load at least one row

    const size_t stride_b = thread_num * n_gld / tb::kN;
    assert(stride_b <= tb::kK);  // each thread needs to load at least one row

    Loader<T, tb::kM, tb::kK + skew, Index<tb::kM, tb::kK, n_gld>> LoaderA{
        A, As, total_threadId, blockIdx.y * tb::kM * K, K, stride_a};
    Loader<T, tb::kK, tb::kN + skew, Index<tb::kK, tb::kN, n_gld>> LoaderB{
        B, Bs, total_threadId, blockIdx.x * tb::kN, N, stride_b};

    accum_type tmp[wb::kM][wb::kN] = {0.};
    T registerA[wb::kM] = {0.};
    T registerB[wb::kN] = {0.};

    // Needed to access the variables from tmp
    const size_t tx = threadIdx.x;
    const size_t ty = threadIdx.y;

    for (size_t bk = 0; bk < K; bk += tb::kK) {
        LoaderA.load();
        LoaderB.load();
        LoaderA.next(tb::kK);
        LoaderB.next(tb::kK * N);
        __syncthreads();
#pragma unroll
        for (size_t k = 0; k < tb::kK; k++) {
#pragma unroll
            for (size_t i = 0; i < wb::kM; i++) {
                registerA[i] = As[(ty + i * 16)][k];
            };
#pragma unroll
            for (size_t i = 0; i < wb::kN; ++i) {
                registerB[i] = Bs[k][tx + i * 16];
            }
#pragma unroll
            for (int i = 0; i < wb::kM; ++i) {
#pragma unroll
                for (int j = 0; j < wb::kN; ++j) {
                    tmp[i][j] += (accum_type)(registerA[i] * registerB[j]);
                }
            }
        }
        __syncthreads();
    }
#pragma unroll
    for (size_t i = 0; i < wb::kM; i++) {
        for (size_t j = 0; j < wb::kN; j++) {
            const size_t out = (ty + i * 16) * N + tx + j * 16 +
                               block_base_y * N + block_base_x;
            accum_type res = alpha * tmp[i][j] + beta * (accum_type)C[out];
            C[out] = (T)res;
        }
    }
}
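
For orientation, tb carries the thread-block tile sizes and wb the per-thread register tile sizes. A hypothetical instantiation that matches the hard-coded strides of 16 in the kernel could look like this, assuming the surrounding definitions (Index, CUDA_VECTORIZED_BITS_LOAD) from the linked repository; the actual tile values live there as well:

// Hypothetical tile configuration; the real values are defined in the linked repository.
struct ThreadBlockTile { static constexpr size_t kM = 128, kN = 128, kK = 16; };
struct WarpTile        { static constexpr size_t kM = 8,   kN = 8; };

void LaunchGemm(float* A, float* B, float* C, size_t M, size_t N, size_t K) {
    dim3 block(16, 16);                                           // 256 threads, matching the "+ i * 16" indexing
    dim3 grid(N / ThreadBlockTile::kN, M / ThreadBlockTile::kM);  // one block per 128x128 output tile
    GEMM<float, float, ThreadBlockTile, WarpTile>
        <<<grid, block>>>(A, B, C, M, N, K, 1.0f, 0.0f);
}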

The entire code can be found here. Memory bank conflicts can occur when storing to or loading from shared memory. The important parts for memory throughput relate to the Loader class and the scheduling of data loads:

LoaderA.load();
LoaderB.load();
LoaderA.next(tb::kK);
LoaderB.next(tb::kK * N);

Loading Data from Global to Shared Memory

Loading data from global to shared memory is done in several steps. First, a shared memory buffer is created (for example, the shared memory buffer for matrix A is called As). The Loader class encapsulates the memory transfer and looks as follows:

template <typename T, size_t rows, size_t cols, typename ThreadOffset>
struct Loader {
    T (&shmem_)[rows][cols];
    T *global_ptr_;
    const ThreadOffset offset_;
    const size_t stride_;
    const size_t ld_;
    __host__ __device__ Loader(T *global_ptr, T (&shmem)[rows][cols],
                               size_t threadId, size_t blockOffset, size_t ld,
                               size_t stride)
        : global_ptr_(global_ptr + blockOffset),
          shmem_(shmem),
          offset_(threadId),
          ld_(ld),
          stride_(stride){};
    __device__ void load() {
        const size_t global_idx = offset_.row * ld_ + offset_.col;
#pragma unroll
        for (size_t row = 0; row < rows; row += stride_) {
            const T *src = global_ptr_ + row * ld_ + global_idx;
            T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
            int4 t = reinterpret_cast<const int4 *>(src)[0];
            reinterpret_cast<int4 *>(dst)[0] = t;
        }
    }
};

The constructor calculates an offset into the global matrix and stores a reference to the shared memory buffer. Data is moved from global to shared memory in the load function: first, each thread calculates its entry point into global memory. The actual loading happens inside a loop: the source and destination pointers for the thread are calculated, and each thread moves 128 bits (the maximum payload of a single load) from global to shared memory. It is instructive to look at the PTX generated for the load instructions:

[Cuda:]
int4 t = reinterpret_cast<const int4 *>(src)[0];
reinterpret_cast<int4 *>(dst)[0] = t;
[PTX:]
ld.global.nc.v4.u32     {%r26, %r27, %r28, %r29}, [%rd73];
st.shared.v4.u32        [%r25], {%r26, %r27, %r28, %r29};

The global load carries the .nc qualifier: it is a non-coherent load through the read-only data cache, which the compiler can emit because it can prove (helped by the __restrict__ qualifiers) that the loaded data is not written by the kernel. The load is vectorized and moves four 32-bit words at once: the data lands in registers %r26-%r29, while register %rd73 holds the load address. The values are then written to shared memory with a single 128-bit st.shared instruction.

The entire program above achieves a throughput of 8 TFlops on an A5000, which represents roughly 40% of the available throughput of the GPU. When storing to shared memory, the threads may hit the same bank, and their accesses get serialized. In fact, the profiler identifies 1.8 million bank conflicts, causing the threads to stall. The following subsection shows how to avoid these stalls.

Addressing Shared Memory Conflicts

One solution to avoid bank conflicts is to use async memory transfers. Asynchronous transfers between CPU and GPU memory have existed for a long time; async global-to-shared memory transfers were introduced with PTX 7.0 and require compute capability 8.0 or above (the Ampere microarchitecture). Async copies allow the threads within a warp to keep making progress after issuing a memory load. PTX splits a load into three instructions, illustrated with a minimal standalone example after the list:

  • [Initiate] Done with cp.async: initiates a memory transfer and queues up the work. There is no guaranteed ordering between different calls to cp.async, and threads are not informed when the data is ready.
  • [Commit] Done with cp.async.commit_group: associates all prior calls to cp.async with a group.
  • [Wait] Done with cp.async.wait_group: a barrier at which threads wait until the memory operations associated with the group have completed. It serves as coordination and informs the threads that the data is ready.
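
The following toy kernel shows the three steps in isolation. It is a hypothetical illustration, not part of the GEMM code, and requires compute capability 8.0:

#include <cstdint>

__global__ void AsyncCopyExample(const float* __restrict__ src, float* dst) {
    // Assumes a launch with a single warp (32 threads) and a 16-byte-aligned src
    // (cudaMalloc guarantees this alignment).
    __shared__ __align__(16) float buf[32 * 4];
    // [Initiate] every thread queues one 16-byte copy from global to shared memory.
    uint32_t smem_addr = static_cast<uint32_t>(
        __cvta_generic_to_shared(&buf[threadIdx.x * 4]));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
                 "l"(src + threadIdx.x * 4));
    // [Commit] group all copies issued so far.
    asm volatile("cp.async.commit_group;\n" ::);
    // [Wait] block until that group has completed, then make the data visible to the block.
    asm volatile("cp.async.wait_group 0;\n" ::);
    __syncthreads();
    dst[threadIdx.x] = buf[threadIdx.x * 4];
}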

The function load of the loader class is modified as follows:

#define CP_ASYNC_CG(dst, src, Bytes)                                       \
    asm volatile(                                                          \
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), \
        "l"(src), "n"(Bytes))

__device__ void load() {
    const size_t global_idx = offset_.row * ld_ + offset_.col;
    constexpr size_t load_bytes = 16;
#pragma unroll
    for (size_t row = 0; row < rows; row += stride_) {
        const T *src = global_ptr_ + row * ld_ + global_idx;
        T *dst = &shmem_[offset_.row][offset_.col] + row * cols;
        uint32_t pos_in_ss =
            __cvta_generic_to_shared(reinterpret_cast<int4 *>(dst));
        CP_ASYNC_CG(pos_in_ss, src, load_bytes);
    }
}

The copy-async instruction requires the destination operand to be an address in the shared state space, which is why the generic pointer is converted with __cvta_generic_to_shared. As before, 16 bytes are loaded per thread. The generated PTX for this function is:

cp.async.cg.shared.global.L2::128B [%r17], [%rd67], 16;

First, no ld or st instruction is present anymore. The .cg qualifier caches the loaded data only at the global (L2) level, bypassing L1, on its way from global to shared memory; the L2::128B suffix is merely a prefetch hint that the hardware may ignore. The instruction only initiates the load; the commit and wait instructions follow once all loads have been issued:

LoaderA.load();
LoaderB.load();
LoaderA.next(tb::kK);
LoaderB.next(tb::kK * N);
asm volatile("cp.async.commit_group");
asm volatile("cp.async.wait_group 0");
__syncthreads();

Instead of blocking on the memory load inside the loop of the load function, the threads continue their work. After issuing all loads, they commit them to a single group and wait for that group to complete. Using only one group keeps the code simple, although, strictly speaking, the loads for matrix B do not need to be ready at the same time as the loads for matrix A; a two-group variant is sketched below.
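
A hypothetical two-group variant (not the version benchmarked here) would commit A and B separately and wait for them at different points:

LoaderA.load();
asm volatile("cp.async.commit_group;\n" ::);   // group containing the A tile
LoaderB.load();
asm volatile("cp.async.commit_group;\n" ::);   // group containing the B tile
LoaderA.next(tb::kK);
LoaderB.next(tb::kK * N);
asm volatile("cp.async.wait_group 1;\n" ::);   // at most one group still pending: the A tile has arrived
// work that only needs As could start here
asm volatile("cp.async.wait_group 0;\n" ::);   // all groups done: the B tile has arrived as well
__syncthreads();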

With these few changes, the throughput of the kernel increases to 12 TFlops: a 50% improvement. The profiler also no longer identifies any bank conflicts. Avoiding bank conflicts through async memory loads is a core element of high-performance CUDA kernels.
