How does Triton do Warp Spec?

Kapil Sharma from the PyTorch team has a great series of posts diving into the Triton compiler process: 1, 2, 3. As covered there, Triton lowers through a series of intermediate representations, and each level has a set of transformational passes that implement optimizations. TTIR is the generic Triton IR; it leverages a number of standard MLIR passes like common subexpression elimination, as well as some Triton-specific passes like managing broadcast ops. That’s then lowered to TTGIR, a GPU-specific IR.[1]
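
If you want to poke at these stages yourself, the handle returned by a kernel launch exposes the text of each IR, and setting MLIR_ENABLE_DUMP=1 in the environment dumps the IR around every pass. Exactly where the .asm dict lives varies a little between Triton versions, so treat this as a sketch:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
handle = add_kernel[(4,)](x, y, out, 1024, BLOCK=256)  # launch returns the compiled kernel

print(handle.asm["ttir"])   # generic Triton IR
print(handle.asm["ttgir"])  # GPU-specific IR, where the warp-spec passes run
print(handle.asm["ptx"])    # final PTX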

triton/third_party/nvidia/backend/compiler.py at rc/3.3.x · triton-lang/triton

The different backends configure the passes appropriate for their targets, so the Nvidia TTGIR configuration above details the passes for Nvidia hardware. Some are gated on the specific backend targeted, like warp specialization:

passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
                                     opt.reg_dec_producer, opt.reg_inc_consumer)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
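
Those opt.* values are compile options. In this 3.3-era branch they surface as extra triton.Config knobs (num_consumer_groups, num_buffers_warp_spec), so turning warp specialization on for a kernel looks roughly like the snippet below; whether your build accepts these kwargs depends on the branch you’re running:

import triton

# Knob names mirror the pass arguments above; on mainline Triton these kwargs
# may not exist at all, so treat this as branch-specific.
warp_spec_config = triton.Config(
    {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64},
    num_stages=3,
    num_warps=4,
    num_consumer_groups=2,     # two consumer warp groups alongside the producer group
    num_buffers_warp_spec=3,   # depth of the circular SMEM buffer between them
)

# Then used as usual:
#   @triton.autotune(configs=[warp_spec_config], key=["M", "N", "K"])
#   @triton.jit
#   def matmul_kernel(...): ...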

Another example is using the Tensor Memory Accelerator (TMA) for async loading on Hopper and newer:

if capability // 10 >= 9:
   nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
   nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
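
capability here is the compute capability packed as major * 10 + minor, so the check reads as “sm_90 (Hopper) or newer”. You can see where your own card lands with PyTorch:

import torch

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor   # e.g. 80 on A100, 90 on H100
print(capability // 10 >= 9)      # True means the TMA lowering passes run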

For a quick recap of the why and how of Warp Specialization, check Colfax’s guide to optimizing a GEMM:

The most basic method by which GPU programmers can create overlapping is via excess warps (warps are groupings of 32 contiguous threads). Nvidia GPUs allow a large number of warps per SM (streaming multiprocessors), and can switch between them with minimal overhead. In particular, the warp schedulers can simply switch to another warp if one warp encounters a slow memory fetch. In order to give the warp schedulers more opportunity to hide latency, a technique called warp-specialization was introduced circa 2011 [1, 2]. With warp-specialization, some warps are dedicated to memory fetches (producers), while others are dedicated to compute (consumers), and named barriers are used for synchronization between them. The idea is that the warp schedulers can then more easily hide the latency of copy operations within compute (and vice-versa).

Overlapping isn’t limited to memory loads; you can also overlap different kinds of compute. An SM has one Tensor Core for every 32 ALUs (CUDA cores), times four on recent hardware. That means there’s plenty of work that isn’t dot products to overlap: it’s really common to load memory, do a matmul, then apply a pointwise function like a ReLU or other activation. You want to keep the Tensor Core as busy as possible with a stream of matmuls, and warp specialization lets you do that.
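
To make that concrete, here’s a sketch of the kind of kernel these passes are aimed at: a standard tutorial-style Triton GEMM inner loop with a ReLU epilogue (assuming M, N and K are multiples of the block sizes, so no masking). The tl.load calls are the work a producer group would own; the tl.dot is what keeps the Tensor Cores busy in the consumer groups:

import triton
import triton.language as tl

@triton.jit
def matmul_relu_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)          # producer-style work: global memory -> on-chip
        b = tl.load(b_ptrs)
        acc = tl.dot(a, b, acc)      # consumer-style work: Tensor Core MMA
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    acc = tl.maximum(acc, 0.0)       # pointwise epilogue (ReLU)
    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc)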

The transforms that implement this live in triton/lib/Dialect/TritonGPU/Transforms at rc/3.3.x · triton-lang/triton

The first pass, add_ws_task_partition, partitions the ops in the kernel into tasks. It looks for load ops and dot ops, and splits them into producer (loads) and consumer (compute) groups:

// Step 1. Select loads into the first task, which is the producer task by
// default. Place dots into the second task, which is the consumer.
// Only consider loads that are connected to a dot op in a loop.
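
For a loop like the matmul sketch above, that first cut conceptually looks something like this (illustration only; the real pass attaches task IDs to the TTGIR ops):

task_partition = {
    0: ["tl.load(a_ptrs)", "tl.load(b_ptrs)"],   # producer task: memory traffic
    1: ["tl.dot(a, b, acc)"],                    # consumer task: Tensor Core MMA
}
# The pointer bumps, the ReLU and the store get no label from this pass.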

The next pass is a bookkeeping one that propagates the task IDs, so that any ops left unlabelled get attached to one of the partitions.

The WSDataPartition transform partitions the dot operations across the consumer groups. It splits the dot products inside loops along the M or N dimension so that each slice can be processed by a single warp group, ensuring all dot operations are sliced and each slice is labelled with its task_id.

Just to look at some of the numbers: A warp consists of 32 threads, and a sync op (for loading) uses a whole warp. A warp group is a set of 4 warps: this is pertinent because the TensorCore MMA prefers 4 warps working on 64x(64/128/256)xK tiles. Triton already has a WarpGroupDotOp that tries to set this up, and that’s one of the operations targeted in this pass. The pass splits a Triton CTA tile, which may be 128×256, so that each consumer warp group has a 64 row (or 256 column) chunk.

if (sliceSizeM >= 64) {
  LLVM_DEBUG({ LDBG("partition along M\n"); });
  partitionDim = 0;
  partitionSize = sliceSizeM;
  partitionOperand = opndA;
} else if (sliceSizeN >= 256) {
  LLVM_DEBUG({ LDBG("partition along N\n"); });
  partitionDim = 1;
  partitionSize = sliceSizeN;
  partitionOperand = opndB;
} else {
  LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN);
  return false;
}
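
In Python terms the decision is roughly the following. This is my own rendering, and it assumes sliceSizeM/sliceSizeN are the per-consumer-group slice sizes, which is how the 128×256 example above works out:

def pick_partition(block_m, block_n, num_consumer_groups=2):
    # Prefer splitting the tile along M, as long as each consumer group
    # still gets at least 64 rows to feed its warp-group MMA.
    slice_m = block_m // num_consumer_groups
    slice_n = block_n // num_consumer_groups
    if slice_m >= 64:
        return ("M", slice_m)
    if slice_n >= 256:
        return ("N", slice_n)
    return None                       # tile too small to data-partition

print(pick_partition(128, 256))       # ('M', 64): two 64x256 slices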

The next pass is WSCodePartition. This is a big transform. It takes the task-sliced IR from the data partition and sets up the producer warp group to copy from global GPU memory to SMEM (or, on Blackwell, TMEM). It also drops in barriers, using producer.acquire/commit and consumer.wait/release, to ensure proper ordering between the groups. The transform identifies “channels”: places where data is loaded (using load, or descriptor_load for TMA on Hopper+), and associates the producer task_id with all the consumer task IDs that need to process that data.

Conceptually, the transform is turning the loops in the original kernel into something like this:

for k: # K-dimension tiles
    ##### PRODUCER #####
    producer.acquire(token, idx, phase) # reserve smem[idx]
    async_copy_global_to_local(smem[idx]) # GMEM → SMEM[idx]
    producer.commit(token, idx) # make slot visible

    #####  CONSUMERS (run in parallel warps) #####
    ## Consumer 0 ##
    consumer.wait(token, idx, phase) # sleep until slot ready
    mma_sync(accum0, smem[idx], …) # read-only, do matmul
    consumer.release(token, idx) 

    ## Consumer 1 ##
    consumer.wait(token, idx, phase)
    mma_sync(accum1, smem[idx], …)
    consumer.release(token, idx)

    # Repeat for extra consumers. 

    # increment circular buffer
    idx   = (idx + 1) % numBuffers 
    # Toggle each time we hit producer, indicate old vs new data.
    phase = phase ^ (idx == 0)
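
The acquire/commit/wait/release protocol is just a bounded buffer. Here’s a plain-Python analogy with one producer thread and one consumer thread standing in for the warp groups, and semaphores standing in for the hardware barriers (the real thing tracks a phase bit per slot instead of counting):

import threading

NUM_BUFFERS = 3                              # circular buffer depth (num_buffers_warp_spec)
K_TILES = 8
smem = [None] * NUM_BUFFERS                  # stand-in for the SMEM slots
empty = threading.Semaphore(NUM_BUFFERS)     # slots the producer is allowed to fill
full = threading.Semaphore(0)                # slots the consumer is allowed to read

def producer():
    for k in range(K_TILES):
        empty.acquire()                      # producer.acquire: wait for a free slot
        smem[k % NUM_BUFFERS] = f"tile {k}"  # the async GMEM -> SMEM copy
        full.release()                       # producer.commit: make the slot visible

def consumer():
    for k in range(K_TILES):
        full.acquire()                       # consumer.wait: block until the slot is ready
        print("mma on", smem[k % NUM_BUFFERS])
        empty.release()                      # consumer.release: slot can be reused

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads: t.start()
for t in threads: t.join()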

The next pass is a generic pipeline pass. Each op is assigned a stage based on latency: slower ops like loads go into stage 0, for example. The loop is then transformed with modulo scheduling: any sync ops are converted to async variants (in lowerLoops) before the loop is rewritten (in expandLoops) into prologue, kernel and epilogue loops that contain all the relevant ops.
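
Conceptually, for a two-stage pipeline, the rewrite does something like this (plain-Python stand-ins, not the pass output):

K_TILES = 8
gmem = [f"tile {k}" for k in range(K_TILES)]  # stand-in for global memory

def async_load(k):                            # stage 0: the slow op, issued early
    return gmem[k]

def mma(acc, tile):                           # stage 1: compute on an already-loaded tile
    return acc + 1

acc = 0
tile = async_load(0)                          # prologue: get the first load in flight
for k in range(K_TILES - 1):                  # kernel loop: load k+1 overlaps compute on k
    next_tile = async_load(k + 1)
    acc = mma(acc, tile)
    tile = next_tile
acc = mma(acc, tile)                          # epilogue: drain the last in-flight tile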

Finally the WSLowering pass takes the various operators we have (like producer.acquire) and replaces them with the hardware-specific variants (e.g. wait_barrier). It also handles bookkeeping like deriving the warp group and task_ids from the warp ID:

OpBuilder builder(op);
Value _4 = builder.create<arith::ConstantIntOp>(loc, WARPS_PER_TASK, 32);
Value warpId = builder.create<ttng::GetCanonicalWarpIdOp>(loc);
Value asyncTaskId = builder.create<arith::DivUIOp>(loc, warpId, _4);
op.getResult().replaceAllUsesWith(asyncTaskId);

This is a wordy IR way of saying

warpId = gpu.thread_id().x / 32 
task_id = warpId / 4
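
With WARPS_PER_TASK = 4, a CTA of, say, 12 warps (one producer group plus two consumer groups) maps out as:

WARPS_PER_TASK = 4
for warp_id in range(12):
    task_id = warp_id // WARPS_PER_TASK
    print(warp_id, "->", task_id)
# warps 0-3  -> task 0 (the producer group)
# warps 4-7  -> task 1 (first consumer group)
# warps 8-11 -> task 2 (second consumer group)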

Now the code is ready for lowering to regular PTX, and all of the warp-specific stuff is captured explicitly!

[1] Note: Because of some project weirdness, warp specialization is quite different in the release branches from main, so I’ll refer to 3.3 from here on. It’s in very active development (by teams at Meta, OpenAI and Nvidia!) so the specifics are quite likely to change over coming releases!
