Tag: gpu

  • Profiling Triton

    There are a couple of different options to profile a Triton kernel.

    Proton

    Proton is the profiler that ships with Triton (hence the name: a profiler for triton). You can enable it and (optionally) activate/deactivate it around specific regions you want to profile, and you can annotate functions with specific metrics as well.

    session = proton.start()  # Start profiling session

    bias = torch.rand((256,), device='cuda', dtype=torch.float16)  # Bias vector
    flops = 2 * M * N * K
    bytes_accessed = A.element_size() * M*K + B.element_size() * K*N + C.element_size() * M*N  # rough bytes
    with proton.scope(f"fused_gemm_bias_relu [M={M}, N={N}, K={K}]", {"flops": flops, "bytes": bytes_accessed}):
        fused_gemm_bias_relu[grid](
            A, B, C, bias,
            M, N, K,
            A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
        )

    proton.finalize()

    The output can be visualized with the built-in viewer:

    proton-viewer -m time/ms,tflop/s ./proton.hatchet

    0.040 6.645 ROOT
    ├─ 0.004 nan _ZN2at6native55_GLOBAL__N__11f7a751_22_DistributionUniform_[...]_15PhiloxCudaStateESH_SI_
    └─ 0.037 7.284 fused_gemm_bias_relu [M=1024, N=256, K=512]
       └─ 0.037 nan fused_gemm_bias_relu

    In this case you can see both the (trimmed!) generated name for the bias tensor setup as well as the name of my custom kernel.

    nsight-compute

    Nvidia also has a good range of tools for looking at performance. Note that you will need to enable access to the GPU performance counters for this:

    NVIDIA Development Tools Solutions – ERR_NVGPUCTRPERM: Permission issue with Performance Counters | NVIDIA Developer

    On the off chance you’re doing this on WSL, https://peterchng.com/blog/2024/03/02/profiling-cuda-programs-on-wsl-2/ walks through the setup!

    Nvidia ships Nsight Systems, which tracks larger system-wide metrics, and Nsight Compute, which is more focused on profiling kernel execution. You can run the latter against a script like so:

    ncu -o profile_results python test.py

    The tool comes with a nice GUI for inspecting the results. It can show you the PTX or SASS source for the kernels, offers metrics like actively used registers (good for checking on register spilling), and calls out warnings on poor utilization or memory clashes.

    Upcoming intra-kernel profiler

    [tool][proton] Intra kernel profiling support by fywkevin · Pull Request #4861 · triton-lang/triton

    There is an extension coming for Proton that enables profiling within kernels. It reserves a pre-allocated buffer on device and logs metrics locally, for reading out at the end of execution. It outputs a Chrome trace for use within a wide range of dev tools. While this isn’t merged into mainline yet, you can see an example of the usage in the dev repo.

  • Colfax on Blackwell GEMMs

    CUTLASS Tutorial: Writing GEMM Kernels Using Tensor Memory For NVIDIA® Blackwell GPUs – Colfax Research

    Dives deep into TMEM in particular, and the trend over the last few Nvidia generations of special-casing GEMMs in hardware:

    Tensor Memory and UMMA do for MMA just what TMA did for copy, making it a single-threaded, asynchronous operation that does not consume registers. As a result, registers can primarily be used for other tasks like scheduling and fused epilogue operations.

    Edit: link no longer seems to be working! It was a great post though, so hopefully comes back! Edit edit: it did!

  • Bank Conflicts in Shared Memory

    When data is in global memory on a GPU it’s usually in row-major or column-major order. Loading from global memory is quite slow though, so for performance we want to move the data to shared memory for the threads in a warp to work on.

    To make that load from global memory performant we want memory reads to be coalesced, meaning we read contiguous chunks of memory at a time. Shared memory, on the other hand, is divided into banks, typically 32 banks that are each 4 bytes wide. If multiple threads in the same warp try to access different addresses in the same bank then the requests are processed sequentially, slowing things down while the threads wait on each other. Nsight and other profiling tools will helpfully point this out to you!

    For example, let’s say we’re loading a row major and column major tensor, and will be doing a multiplication between them (this is naive, to demonstrate the issue):

    __shared__ float Asub[TILE_DIM][TILE_DIM];
    __shared__ float Bsub[TILE_DIM][TILE_DIM];  // (No padding in this naive version)
    int lane = threadIdx.x;  // 0..31 (warp lane index)
    int tileRow = blockIdx.y * TILE_DIM;
    int tileCol = blockIdx.x * TILE_DIM;
    // A is row-major: the warp reads a contiguous row (coalesced) and writes a
    // row of Asub, so consecutive lanes hit consecutive banks.
    Asub[0][lane] = A[tileRow * N + tileCol + lane];
    // B is column-major: to keep the global read coalesced, each lane instead
    // writes a *column* of Bsub.
    Bsub[lane][0] = B[(tileRow + lane) + tileCol * N];

    Now when we fill Bsub, every lane writes to the same shared memory bank, significantly slowing things down. One easy fix is just to add padding:

    __shared__ float Asub[TILE_DIM][TILE_DIM];           // A tile (row-major, no conflict in our case)
    __shared__ float Bsub[TILE_DIM][TILE_DIM + PAD];     // B tile (extra column to prevent conflicts)
       

    With PAD as 1 (and TILE_DIM as 32) each row is 33 floats, or 132 bytes, offsetting consecutive rows by one bank and ensuring that each thread writes to its own bank.
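    To convince yourself of the arithmetic, you can model the bank mapping in plain Python (a toy model assuming 32 banks of 4-byte words, as above):

```python
BANKS = 32
WORD_BYTES = 4

def bank(addr_bytes: int) -> int:
    """Shared memory bank for a byte address (32 banks, 4 bytes wide)."""
    return (addr_bytes // WORD_BYTES) % BANKS

# 32 lanes each write float element [lane][0] of a 32x32 tile: addresses
# are 32 floats apart, so every write lands in bank 0.
unpadded = {bank(lane * 32 * WORD_BYTES) for lane in range(32)}

# With one float of padding the row stride becomes 33 floats, so the
# writes spread across all 32 banks.
padded = {bank(lane * 33 * WORD_BYTES) for lane in range(32)}

print(len(unpadded), len(padded))  # 1 32
```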

    The downside is that this wastes shared memory, a scarce resource, so an alternative approach is swizzling: changing the layout so that consecutive thread accesses don’t cause bank conflicts. That’s what Bert implemented to get performance in his recent GEMM walkthrough, but it’s easy to get wrong.

    To make life easier than writing it in raw CUDA, CUTLASS has a system called CuTe, a set of templates for expressing the layout of data:

    auto tileLayout    = make_layout(make_shape(Int<32>{}, Int<32>{}), GenRowMajor{});
    auto swizzledLayout = composition(Swizzle<5, 0, 5>{}, tileLayout);

    Here you specify how the data is laid out in global memory with the shape and stride, then make_layout and the copy operation take care of translating from the row-major layout in global memory to the swizzled layout in shared memory.
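    As a toy model of what swizzling buys you (this is not CuTe’s actual Swizzle semantics, just an illustrative XOR of the column index by the row index):

```python
BANKS = 32

def bank_of(row: int, col: int, stride: int = 32) -> int:
    """Bank of float element [row][col] in a row-major 32-wide tile."""
    return (row * stride + col) % BANKS

def swizzle(row: int, col: int) -> int:
    """Toy XOR swizzle: permute the physical column by the row index."""
    return col ^ (row % BANKS)

# Reading logical column 0 of a plain 32x32 tile hits a single bank...
plain = {bank_of(r, 0) for r in range(32)}

# ...but with the swizzled layout the physical column differs per row,
# so the same logical column spreads across all 32 banks.
swizzled = {bank_of(r, swizzle(r, 0)) for r in range(32)}

print(len(plain), len(swizzled))  # 1 32
```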

    From a Triton perspective, Lei Zhang has a great post on memory access, and how it works in Triton, specifically the LinearLayout class that allows the language to similarly handle swizzling and layouts for you:

    Indeed the whole point of LLs is that they allow us to specify transposed and swizzled layouts as a “general case”. Instead of a layout class for registers in a thread, and another layout for registers in a thread but in MMAv2 order, and so on, all of these can be represented by different LLs. This gets rid of special cases and lets us write more general code.

    There’s a great Colfax report on building GEMMs that covers shared memory bank conflicts, and Lei Mao has a post with a nice illustration. Axel Feldman also has a post on benchmarking different approaches, identifying bank conflicts, and some more efficient loading techniques.

  • Ping Pong GEMM from Scratch

    bertmaher/simplegemm

    Following in the tradition of worked kernel examples, Bert, of the PyTorch and Triton teams at Meta, writes up his experience developing a fast Ping-Pong kernel with TMA (fast loading on Hopper/H100) from scratch. As you might expect there are some good insights from debugging and working through the problems.

    You know what actually made it super obvious? Programming. I filled a shared memory buffer with consecutive integers — basically the smem equivalent of torch.arange(64*128).bfloat16().reshape(64, 128), and then TMA-transferred that to GMEM with 128B swizzling, cudaMemcpyed it back to the host, and printed it out. This actually made it crystal clear! I wrote the swizzle function correctly on my first try 😄.

    All the code, and the walk through, are in the repo!

  • Write more kernels

    The GPU Mode Discord has emerged as the preeminent hub for current and aspiring GPU kernel hackers, and several of the folks there have kicked off a project to make it easier to write and benchmark kernels. https://gpu-mode.github.io/discord-cluster-manager/docs/intro/ goes over the idea: it’s a series of leaderboards and runners for different kernel types so you can easily find (and beat!) the state of the art:

    We designed this leaderboard as a central and open-source resource for people to find the fastest kernels for the devices they are using. Furthermore, these open-community kernels will be useful in the future for designing automated methods for optimized kernel generation.

    The latter part there is one of the interesting points. Fundamentally, custom kernels are an optimization on a model architecture, and like any optimization it’s natural to look for a system that automatically creates it for you. ML compilers do a good job of certain graph optimizations, autotuning (searching for good kernel choices), and building specific versions from templates, but those templates are generally based on hand-written, high-performance kernels for specific needs and shapes. It’s natural to see how LLMs do with this problem, and up to now the answer has been “pretty mid”.

    To that end, Sakana recently wrote about their efforts to build a system to generate high performance kernels from PyTorch model code with an agentic system: https://sakana.ai/ai-cuda-engineer/ – it has generated a lot of kernels (17k!)

    They chose to output CUDA, rather than CUTLASS, Triton, or another higher-level framework, and they use an LLM to functionalize the PyTorch code rather than using torch.compile and working on the exported graph:

    Functional Conversion: We first evaluate the LLMs’ ability to convert torch modules into parameterized function calls (stage 1). Our analysis of 250 KernelBench tasks (fig. 6) reveals distinct performance patterns across complexity levels. All tested LLMs successfully generate equivalent functional implementations for basic operations and simple fused operations (level 1, 2). However, for complex composed architectures (level 3), reasoning models (o1-high, o1-preview, o3-mini-high) demonstrate superior robustness, converting more than 45 tasks compared to sonnet3.5’s 42 tasks.

    One nice trick was they self-improved generation through adding in examples from similar, previously generated kernels, which improved the success rate:

    Retrieval-Augmented CUDA Kernel Translation & Optimization: Building upon these results, we enhanced our system’s capabilities through RAG. By leveraging our growing ‘innovation archive’ of translated and optimized kernels, RAG significantly improved both translation and optimization capabilities.

  • DeepSeek’s Flash MLA Kernel

    https://github.com/deepseek-ai/FlashMLA/

    DeepSeek’s infra team are having a great week open-sourcing some of their components, to the benefit of everyone! The first day was their Multi-head Latent Attention kernel, which takes the FlashAttention approach of leveraging shared memory heavily as an additional level of tiling during the attention computation, uses split-K to divide work along the K dimension, and so on.

    On top of that they add a latent K/V vector: the input to K and V is first projected to a lower latent dimension, then the K and V matrices project back into the full per-head dimension. Only the latent vector is cached for past tokens. While this does use a little more compute than a traditional KV cache, choosing a sufficiently small latent dimension means significant savings in memory and memory bandwidth, which is typically the constraint.
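    The saving is easy to see with some back-of-the-envelope numbers (made up for illustration, not DeepSeek’s actual dimensions):

```python
# Per-token cache cost, counting stored values (toy dimensions):
n_heads, head_dim, d_latent = 32, 128, 512

kv_per_token = 2 * n_heads * head_dim  # standard cache: full K and V per head
mla_per_token = d_latent               # MLA: only the latent vector is cached

print(kv_per_token, mla_per_token, kv_per_token // mla_per_token)  # 8192 512 16
```

    A 16x reduction per cached token in this toy setup, at the cost of an extra projection back to the per-head dimension at attention time.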

    MLA was introduced back in DeepSeek v2, if you want to read the full breakdown!

  • Warp Specialization

    In general, branching in GPU code is considered bad. When you write a kernel, it’s very easy to write the same kind of logic as you would on a CPU. However, GPU kernels execute on blocks of threads scheduled on streaming multiprocessors (SMs), which are optimized for vectorized (or parallel) computation. This relies on the idea that large groups of threads can execute the same instructions on different data in lockstep. Practically, these are scheduled as “warps” of 32 threads at a time (on Nvidia; the AMD equivalent, a wavefront, is 64 threads).

    To work through an example, take this naive approach to processing an array that contains both positive and negative values:

    __global__ void naive_kernel(const float* input, float* output, int N) {
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
        if (idx < N) {
            if (input[idx] > 0) {
                // Some operation for positive values
                output[idx] = input[idx] * 2.0f;
            } else {
                // Some operation for negative or zero values
                output[idx] = input[idx] + 1.0f;
            }
        }
    }
    

    On a GPU, this branching between positive and negative values can lead to warp divergence – you end up using a small number of the threads in the warp, getting worse utilization. Instead, you can rewrite this logic to effectively remove the branching:

    __global__ void improved_kernel(const float* input, float* output, int N) {
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
        if (idx < N) {
            // Compute the same math in a unified way
            float val = input[idx];
            // Evaluate the transforms without branching
            float val_pos = val * 2.0f;
            float val_neg = val + 1.0f;
            // Use a conditional assignment
            output[idx] = (val > 0) ? val_pos : val_neg;
        }
    }
    

    This rewritten version still makes a choice, but it does so in a way that can be handled more concurrently. The basic idea in GPU programming is to use thread and block IDs to develop kernels that operate cooperatively (for example, splitting data among threads).
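    One way to see the cost: model a warp as paying one pass through the conditional per distinct path its lanes take (a toy model of SIMT execution, not real scheduling):

```python
def warp_passes(lane_values):
    """Passes through the conditional: one per distinct path taken.
    A divergent warp runs both sides serially; a predicated select
    (like the ternary in the improved kernel) always costs one pass."""
    paths = {"pos" if v > 0 else "neg" for v in lane_values}
    return len(paths)

mixed = [float(i) for i in range(-16, 16)]  # lanes disagree: both paths run
uniform = [float(i) for i in range(1, 33)]  # all lanes agree: single pass

print(warp_passes(mixed), warp_passes(uniform))  # 2 1
```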

    A further idea is branching on the warp itself, which is referred to as warp specialization. It’s very common to have kernels that deal with irregular data access, leading to branching, but by grouping specialized tasks into warps you can still maintain high utilization of threads.

    For example, by branching on the thread ID and using barriers, you can specialize roles in warps, and have one set of threads dedicated to data loading and another to processing the data:

    __global__ void warp_specialization_kernel(int* global_mem) {
        __shared__ int data[32];  // staging buffer filled by the producer warp
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
    
        if (threadIdx.x < 32) {
            // Producer warp: load data from global memory into shared memory
            data[threadIdx.x] = global_mem[idx];
            __namedBarrierArrival("data_ready", 1);  // illustrative named barrier
        } else {
            // Consumer warp: wait for the producer, then process the data
            __namedBarrierWait("data_ready", 1);
            int value = data[threadIdx.x - 32];
            global_mem[idx] = value + 42;
        }
    }
    

    Here, the first warp (threads 0–31) acts as the producer, loading data into shared memory. The remaining threads (in warps 1, 2, etc.) act as consumers, waiting for the producer to finish before processing the data. The named-barrier calls ensure the producer and consumer warps are synchronized. This sample kernel is simplified to illustrate the concept (the barrier calls are pseudocode; real CUDA would use PTX bar.arrive/bar.sync or cuda::barrier), but the pattern is useful for specialized tasks.
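    The same hand-off can be sketched on the host with Python threads, with threading.Barrier standing in for the named barrier (a model of the pattern, not GPU code):

```python
import threading

data = [None] * 32                 # stands in for the shared memory buffer
result = [None] * 32
data_ready = threading.Barrier(2)  # stands in for the named barrier

def producer(global_mem):
    for i in range(32):            # "producer warp": fill the shared buffer
        data[i] = global_mem[i]
    data_ready.wait()              # signal data_ready

def consumer():
    data_ready.wait()              # wait until the producer has arrived
    for i in range(32):            # "consumer warp": process the staged data
        result[i] = data[i] + 42

gmem = list(range(32))
threads = [threading.Thread(target=producer, args=(gmem,)),
           threading.Thread(target=consumer)]
for t in threads: t.start()
for t in threads: t.join()

print(result[:4])  # [42, 43, 44, 45]
```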

    Triton, with the new changes that landed recently, allows you to specify warp groups in the autotuning parameters:

    @triton.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 256,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=4,
                num_consumer_groups=2,
                num_buffers_warp_spec=3,
            ),
        ],
        key=["M", "N", "K"],
    )
    

    num_consumer_groups greater than zero enables warp specialization and sets how many consumer groups will be available. num_buffers_warp_spec specifies how many shared memory buffers are used for transfer between the warp groups. The Triton compiler can then optimize kernels based on available warps, grouping threads intelligently and applying warp-level optimizations, which you can read about in the PyTorch blog post on warp specialization.

    One of the reasons for the visibility of this technique now is that the Hopper architecture has 8 independent schedulers per SM (up from 4 on Ampere), which enables more concurrent execution of warp groups, plus added support for warp-group level instructions, which makes synchronization between warp groups pretty cheap.

  • Ping-Pong Kernels on Hopper

    Deep Dive on CUTLASS Ping-Pong GEMM Kernel | PyTorch

    A useful deep dive on this performance technique. The TL;DR: using warp specialization, set up a producer group that loads data (using TMA) and two consumer groups executing matmuls on the Tensor Cores. When a consumer group finishes, it executes the epilogue (e.g. copying the results elsewhere, but you could imagine doing something else on a CUDA core) while the other consumer group takes over the Tensor Core. Hence, I presume, the name, as Tensor Core usage ping-pongs between the two consumers!

    The producer warp group focuses on producing data movement to fill the shared memory buffers (via TMA). Two other warp groups are dedicated consumers that process the math (MMA) portion with tensor cores, and then do any follow up work and write their results back to global memory (epilogue).
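    A toy timeline (with made-up cycle costs, not real hardware numbers) shows why alternating the two consumers pays off:

```python
# Made-up costs: MMA occupies the tensor core, the epilogue runs elsewhere.
MMA, EPILOGUE, TILES = 4, 2, 4

def single_consumer():
    # One consumer: the tensor core idles during every epilogue.
    return TILES * (MMA + EPILOGUE)

def ping_pong():
    # Two consumers alternate: each epilogue hides behind the other
    # consumer's MMA (epilogue <= MMA here), so only the final epilogue
    # is exposed and the tensor core stays busy back-to-back.
    return TILES * MMA + EPILOGUE

print(single_consumer(), ping_pong())  # 24 18
```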

  • ThunderKittens / GPUs go brr

    Just got round to reading the intro post to the (now improved) ThunderKittens kernel DSL.

    https://hazyresearch.stanford.edu/blog/2024-05-12-tk

    Many good nuggets on kernel writing in general and on Hopper in particular.

    But to us a “register” is a 16×16 tile of data. We think AI wants this — after all this time, it’s still just matrix multiplies, reductions, and reshapes. And we think the hardware wants this, too — small matrix multiplies are just begging for hardware support beyond just the systolic mma.

    In fact, more broadly we believe we should really reorient our ideas of AI around what maps well onto the hardware. How big should a recurrent state be? As big can fit onto an SM. How dense should the compute be? No less so than what the hardware demands. 

  • Better performance on GPUs

    https://www.nvidia.com/content/gtc-2010/pdfs/2238_gtc2010.pdf

    This is the 2010 Nvidia presentation that really helped set the path for GPU performance: focus on memory bandwidth to get high FLOPS, do lots of work per core, and manage latency.

  • How AMD may get across the CUDA moat

    https://www.hpcwire.com/2023/10/05/how-amd-may-get-across-the-cuda-moat/

    CUDA is a huge advantage for Nvidia, and is really baked into a lot of workflows (PyTorch being a part of that!)

    Having a good quality ROCm backend makes porting significantly easier, and AMD have made significant efforts on testing and support too. They have also generally structured their software to mirror CUDA, which makes switching fairly seamless. It’ll be interesting to see how folks react to the MI300, and the opportunities given that much HBM per card!