Warp Specialization

In general, branching in GPU code is considered bad. When you write a kernel, it’s very easy to write the same kind of logic as you would on a CPU. However, GPU kernels execute as blocks of threads scheduled on streaming multiprocessors (SMs), and they are optimized for vectorized (or parallel) computation. This optimization relies on the idea that large groups of threads can execute the same instructions on different data in lockstep. In practice, Nvidia GPUs schedule threads in “warps” of 32 at a time. AMD’s equivalent, the wavefront, is typically 64 threads.
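
The Python sketch below makes the warp grouping concrete. It runs as plain Python with no GPU and shows how `threadIdx.x` values in a block map onto warps and lanes on Nvidia hardware. The helper names are illustrative only. The sketch assumes a one-dimensional block. For 2D and 3D blocks, the linear thread index is `threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y`.

```python
WARP_SIZE = 32  # Nvidia; AMD wavefronts are typically 64 threads

def warp_and_lane(thread_idx, warp_size=WARP_SIZE):
    """Map a 1D threadIdx.x to (warp index within the block, lane within the warp)."""
    return thread_idx // warp_size, thread_idx % warp_size

def warps_per_block(block_dim, warp_size=WARP_SIZE):
    """A block is split into whole warps, so a partial warp still counts as one."""
    return (block_dim + warp_size - 1) // warp_size

print(warp_and_lane(0))       # (0, 0)
print(warp_and_lane(33))      # (1, 1)
print(warps_per_block(100))   # 4: the last warp has only 4 active threads
```

A block of 100 threads still occupies four warps, and its last warp has only 4 active lanes. This is one reason block sizes are usually chosen as multiples of 32.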

To work through an example, take this naive approach to processing an array that contains both positive and negative values:

__global__ void naive_kernel(const float* input, float* output, int N) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < N) {
        if (input[idx] > 0) {
            // Some operation for positive values
            output[idx] = input[idx] * 2.0f;
        } else {
            // Some operation for negative or zero values
            output[idx] = input[idx] + 1.0f;
        }
    }
}

On a GPU, branching between positive and negative values can cause warp divergence. When threads in the same warp take different branches, the warp executes each path in turn with the other threads masked off. At any given moment, only part of the warp does useful work, which lowers utilization. Instead, you can rewrite this logic to effectively remove the branching:

__global__ void improved_kernel(const float* input, float* output, int N) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < N) {
        // Compute the same math in a unified way
        float val = input[idx];
        // Evaluate the transforms without branching
        float val_pos = val * 2.0f;
        float val_neg = val + 1.0f;
        // Use a conditional assignment
        output[idx] = (val > 0) ? val_pos : val_neg;
    }
}

This rewritten version still makes a choice, but every thread in the warp executes the same instructions. The compiler typically turns the ternary into a select instruction rather than a branch, so the warp does not diverge. The basic idea in GPU programming is to use thread and block IDs to develop kernels that operate cooperatively (for example, splitting data among threads).
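
One rough way to see the utilization difference is to count lane slots. The Python model below assumes the simple SIMT picture described above: a warp runs every branch that at least one of its lanes takes, and the other lanes are masked off. It deliberately ignores that the branch-free version evaluates both expressions. It only measures how many lanes sit idle. The function names are made up for this sketch.

```python
WARP_SIZE = 32

def branch_utilization(values, warp_size=WARP_SIZE):
    """Fraction of issued lane slots doing useful work for the naive kernel,
    assuming each warp runs every branch taken by at least one of its lanes."""
    useful = issued = 0
    for start in range(0, len(values), warp_size):
        warp = values[start:start + warp_size]
        n_pos = sum(1 for v in warp if v > 0)      # lanes taking the `> 0` branch
        for n_active in (n_pos, len(warp) - n_pos):
            if n_active:                           # a branch no lane takes is skipped
                useful += n_active
                issued += warp_size                # masked-off lanes still occupy slots
    return useful / issued

def select_utilization(values, warp_size=WARP_SIZE):
    """Branch-free kernel: each warp issues a single path with all its lanes active."""
    num_warps = (len(values) + warp_size - 1) // warp_size
    return len(values) / (num_warps * warp_size)

alternating = [1.0 if i % 2 == 0 else -1.0 for i in range(64)]
print(branch_utilization(alternating))          # 0.5
print(branch_utilization(sorted(alternating)))  # 1.0
print(select_utilization(alternating))          # 1.0
```

With alternating signs, every warp diverges and utilization drops to 50%. If the same values are sorted so that each warp sees only one sign, the branchy kernel runs at 100%. The branch matters only when it splits a warp.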

An additional idea is to branch on the warp itself, which is referred to as warp specialization. Kernels that deal with irregular data access often end up branching. If every thread in a warp takes the same side of each branch, though, the warp never diverges. By grouping specialized tasks into whole warps, you can keep thread utilization high.
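
What matters is whether a branch is uniform within each warp. The hypothetical helper below reports which warps would diverge for a given predicate on `threadIdx.x`. It is plain Python and assumes a 1D block of threads.

```python
WARP_SIZE = 32

def divergent_warps(block_dim, predicate, warp_size=WARP_SIZE):
    """Indices of warps in a 1D block whose threads disagree on
    predicate(threadIdx.x), i.e. warps that would diverge on that branch."""
    num_warps = (block_dim + warp_size - 1) // warp_size
    diverging = []
    for w in range(num_warps):
        lanes = range(w * warp_size, min((w + 1) * warp_size, block_dim))
        if len({bool(predicate(tid)) for tid in lanes}) > 1:
            diverging.append(w)
    return diverging

print(divergent_warps(160, lambda tid: tid < 32))  # []  (split on a warp boundary)
print(divergent_warps(160, lambda tid: tid < 48))  # [1] (split inside warp 1)
print(divergent_warps(160, lambda tid: (tid // WARP_SIZE) % 2 == 0))  # []
```

Splitting at thread 32 lines up with a warp boundary, so no warp diverges. Splitting at thread 48 cuts warp 1 in half.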

For example, by branching on the thread ID and using barriers, you can give warps specialized roles. One warp can be dedicated to loading data while the others process it:

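#define TILE 128  // elements per block; launch with blockDim.x == 32 + TILE

__global__ void warp_specialization_kernel(int* global_mem, int N) {
    __shared__ int data[TILE];
    int tile_start = blockIdx.x * TILE;

    if (threadIdx.x < 32) {
        // Producer warp: each of the 32 lanes loads TILE / 32 elements
        for (int i = threadIdx.x; i < TILE; i += 32) {
            int g = tile_start + i;
            data[i] = (g < N) ? global_mem[g] : 0;
        }
        // Arrive at named barrier 1 without waiting (barrier 0 is used by
        // __syncthreads); blockDim.x threads in total participate in it
        asm volatile("bar.arrive 1, %0;" :: "r"(blockDim.x) : "memory");
    } else {
        // Consumer warps: block on named barrier 1 until the producer arrives
        asm volatile("bar.sync 1, %0;" :: "r"(blockDim.x) : "memory");
        int i = threadIdx.x - 32;  // consumer thread i handles element i of the tile
        int g = tile_start + i;
        if (g < N) {
            // ... process data ...
            global_mem[g] = data[i] + 42;
        }
    }
}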

Here, the first warp (threads 0–31) acts as the producer and loads data into shared memory. The remaining threads (in warps 1, 2, etc.) act as consumers and wait for the producer to finish before processing the data. The named-barrier calls keep the producer and consumer warps synchronized. The producer signals its arrival without waiting, and the consumers block until that signal arrives, so they never read shared memory before it has been filled. This sample kernel is simplified to illustrate the concept, but the same pattern is useful for more specialized tasks.
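
The producer/consumer hand-off can also be sketched on the CPU. This is only an analogy written with Python's `threading` module. A `threading.Event` stands in for the named barrier: the producer sets it without waiting, and the consumers wait on it. `TILE` and `run_block` are made-up names. Unlike a real named barrier, an `Event` is one-shot and does not count participating threads.

```python
import threading

TILE = 8  # hypothetical tile size, far smaller than a real kernel's

def run_block(global_mem):
    """CPU analogy of one block: a producer thread fills a shared tile and
    signals once; consumer threads wait for that signal, then each
    processes one element."""
    shared = [None] * TILE
    data_ready = threading.Event()  # stands in for the named barrier

    def producer():
        for i in range(TILE):
            shared[i] = global_mem[i]   # "load from global memory"
        data_ready.set()                # arrive: signal without waiting

    def consumer(i):
        data_ready.wait()               # wait until the producer has arrived
        global_mem[i] = shared[i] + 42  # "process data"

    # Start the consumers first to show that they block until the data is ready
    threads = [threading.Thread(target=consumer, args=(i,)) for i in range(TILE)]
    threads.append(threading.Thread(target=producer))
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return global_mem

print(run_block(list(range(TILE))))  # [42, 43, 44, 45, 46, 47, 48, 49]
```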

With recent changes, Triton lets you specify warp groups in the autotuning parameters:

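import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,     # > 0 enables warp specialization
            num_buffers_warp_spec=3,   # shared memory buffers between warp groups
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(  # hypothetical kernel; the tiled matmul body is omitted
    a_ptr, b_ptr, c_ptr, M, N, K,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
):
    pass  # ... tiled matmul body ...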

Setting num_consumer_groups to a value greater than zero enables warp specialization and sets how many consumer warp groups are used. num_buffers_warp_spec specifies how many shared memory buffers are used for transfers between the warp groups. More buffers let the producer run further ahead of the consumers, at the cost of more shared memory. The Triton compiler can then optimize kernels based on the available warps, grouping threads and applying warp-level optimizations. The PyTorch blog post on warp specialization covers these optimizations in more detail.
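
The Python analogy below gives a feel for what a buffer count controls. A producer feeds a consumer through a ring of `num_buffers` slots, and two semaphores track free and filled slots. This sketch illustrates multi-buffering in general and is not how Triton implements `num_buffers_warp_spec`. The function and variable names are hypothetical.

```python
import threading

def pipeline(tiles, num_buffers):
    """CPU analogy of a producer feeding one consumer through `num_buffers`
    shared slots. Returns (per-tile results, max slots filled at once)."""
    slots = [None] * num_buffers
    free = threading.Semaphore(num_buffers)  # slots the producer may fill
    filled = threading.Semaphore(0)          # slots the consumer may read
    lock = threading.Lock()
    state = {"in_flight": 0, "max_in_flight": 0}
    results = []

    def producer():
        for k, tile in enumerate(tiles):
            free.acquire()                    # wait for a free slot
            slots[k % num_buffers] = tile     # "load" tile k into its slot
            with lock:
                state["in_flight"] += 1
                state["max_in_flight"] = max(state["max_in_flight"], state["in_flight"])
            filled.release()                  # tell the consumer tile k is ready

    def consumer():
        for k in range(len(tiles)):
            filled.acquire()                  # wait until tile k has been loaded
            tile = slots[k % num_buffers]
            with lock:
                state["in_flight"] -= 1
            free.release()                    # hand the slot back to the producer
            results.append(sum(tile))         # "compute" on the tile

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results, state["max_in_flight"]

results, max_in_flight = pipeline([[i, i + 1] for i in range(10)], num_buffers=3)
print(results)        # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
print(max_in_flight)  # between 1 and 3, depending on thread scheduling
```

The producer can never get more than `num_buffers` tiles ahead of the consumer. More buffers allow more overlap between loading and computing, and on a GPU each extra buffer costs shared memory.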

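One reason this technique is getting more visibility now is the Hopper architecture. A Hopper SM, like an Ampere SM, is divided into four sub-partitions, each with one warp scheduler. Hopper adds hardware that suits producer and consumer warp groups, however. The Tensor Memory Accelerator (TMA) copies tiles into shared memory asynchronously, and asynchronous transaction barriers let consumers wait on those copies. Warp-group-level (wgmma) matrix instructions also operate on groups of four warps at once. Together, these features make synchronization between warp groups pretty cheap.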

Discover more from Ian’s Blog
