Tag: gpu

  • TileIR

    There are a lot of things folks do on GPUs (including, sometimes, graphics), so I have an approximately-correct taxonomy of operations to group them into:

    1. Dense compute: A matmul or a convolution.
    2. Map: Elementwise/pointwise work on each value of a tensor.
    3. Reduce: Process a tensor into fewer dimensions, like a sum.
    4. Transforms: Change data structure. Easy ones like transpose, annoying ones like scatter/gather.
    5. Synchronize / Communicate: Move data, or wait for it (copies, collectives, fences/barriers).
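    To make the taxonomy concrete, here’s a minimal numpy sketch of categories 1–4 (category 5 has no single-process analogue; think cudaMemcpy, collectives and barriers):

    ```python
    import numpy as np

    # A toy tensor to run each category of operation on.
    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    w = np.ones((4, 2), dtype=np.float32)

    dense = x @ w                      # 1. Dense compute: a matmul
    mapped = np.maximum(x, 0.0)        # 2. Map: elementwise (here, ReLU)
    reduced = x.sum(axis=1)            # 3. Reduce: 3x4 -> 3, a row sum
    transformed = x.T                  # 4. Transform: transpose (easy case)
    gathered = x[[2, 0, 1], :]         # 4. Transform: gather (annoying case)
    # 5. Synchronize / Communicate only shows up across devices or streams:
    #    copies, all-reduces, fences. Nothing to demo on one host process.

    assert dense.shape == (3, 2)
    assert reduced.shape == (3,)
    ```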

    At the moment people are pouring billions of dollars into hardware that primarily does 1. And, at the same time, many of the greatest minds of our generation are attempting to ensure that the hardware spends as much time doing 1 as possible.

    The biggest barrier to doing a lot of dense compute is 5: moving things in and out of memory. Launching kernels, transferring data between host and device (or between devices), moving data between global memory and registers and so on. It’s like there’s a box defined by Data × Footprint1 × Time, and everyone is trying to keep it as full as possible.

    This involves tradeoffs. You want to mul as many mats as you can, but you only have so much room to store accumulators. Fetching new data from memory also takes a while. You can keep many in-flight fetches around, but each one expands the kernel Footprint, lowering occupancy.

    There are 3 tricks that we can use to help fill up the box by stitching different operations together:

    • Epilogue fusion: take an elementwise op and fuse it onto the end of a dense op, so that when the MMA produces output, the elementwise op can be run while the output data is still in registers. A classic example: fuse the activation after the dense compute in a feed-forward net.
    • Vertical fusion: take two subsequent operations and chain them together to avoid running a loop for one, writing it back, then running a loop for the other2. A classic example is Fused LayerNorm: normally you’d need to add elements, then collect stats for the normalization. You can fuse the two to collect the stats as you add the residual.
    • Horizontal fusion: doing different things over the same data, in parallel. The Q, K, and V projections in a transformer all need the exact same input, so are good candidates to fuse horizontally.
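    Vertical fusion is the easiest to sketch. Here’s a toy fused-LayerNorm-style pass in numpy: the fused loop adds the residual and accumulates the stats in the same sweep, so the intermediate never round-trips through memory. A real kernel would do this per tile in registers; this just shows the shape of the idea:

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 8)).astype(np.float32)
    residual = rng.standard_normal((4, 8)).astype(np.float32)

    # Unfused: two passes over the data (two "kernels", with a write-back between).
    h = x + residual                       # pass 1: residual add
    mean = h.mean(axis=-1, keepdims=True)  # pass 2: collect stats
    var = h.var(axis=-1, keepdims=True)
    out_unfused = (h - mean) / np.sqrt(var + 1e-5)

    # "Fused": one loop that adds the residual and accumulates the stats
    # as it goes, so the sum never round-trips between the two steps.
    out_fused = np.empty_like(x)
    for i in range(x.shape[0]):
        s = sq = 0.0
        row = np.empty(x.shape[1], dtype=np.float32)
        for j in range(x.shape[1]):
            v = x[i, j] + residual[i, j]   # residual add...
            row[j] = v
            s += v                         # ...and stats, in the same sweep
            sq += v * v
        m = s / x.shape[1]
        var_i = sq / x.shape[1] - m * m
        out_fused[i] = (row - m) / np.sqrt(var_i + 1e-5)

    assert np.allclose(out_unfused, out_fused, atol=1e-4)
    ```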

    You rely on the design of the hardware to enable some of this. For example, an epilogue fusion is beneficial because it’s one kernel launch instead of two, and because the work doesn’t need to be written back to global memory, but also because the epilogue can overlap with other work.

    It’s not always obvious how to put these fusions together. Flash Attention was such a breakthrough because it made dense-op fusion possible. The naive attention block has a softmax in the middle: Softmax(QK^T / √d) · V. That softmax is a reduction op, which means it needs all of QK^T, a pretty large matrix, to be computed first. Tri Dao and colleagues realized that if you used online softmax you could calculate the softmax for subsets of the QK^T matrix, and avoid materializing the whole thing. That let them fuse the QK^T matmul, the softmax, and the multiplication by V into one kernel, at the tile level.
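    Online softmax itself fits in a few lines of numpy: keep a running max and a rescaled running sum as you stream over blocks, and you never need the whole row at once (the block size is arbitrary, for illustration):

    ```python
    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    def online_softmax(x, block=4):
        # Process x in blocks, keeping only a running max and a running
        # rescaled sum, never materializing exp(x) for the whole row.
        m = -np.inf          # running max
        s = 0.0              # running sum of exp(x - m)
        for start in range(0, len(x), block):
            chunk = x[start:start + block]
            m_new = max(m, chunk.max())
            # Rescale the old sum to the new max, then add the chunk.
            s = s * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
            m = m_new
        # A second pass produces the normalized values.
        return np.exp(x - m) / s

    x = np.random.default_rng(1).standard_normal(10).astype(np.float32)
    assert np.allclose(softmax(x), online_softmax(x), atol=1e-5)
    ```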

    Tiles are the subsection of a matrix you’re working on at any given time. In a matmul, tiles from both input matrices are loaded and multiplied, to produce an output tile. There’s a useful image of this in the Nvidia blog post on cuTile, Nvidia’s most recent entrant into the kernel-development landscape. To side-step concerns of plagiarism, I had nanobanana plagiarize it for me:

    Illustration depicting three matrices labeled A, B, and C, showing a looping process with arrows indicating active rows and columns in a computational context.

    cuTile is built on a well-specified intermediate representation called TileIR. There’s an experimental backend for Triton that lowers to TileIR too. While Triton is block-oriented rather than tile-oriented, in practice what you mostly work on in a thread-block is… a tile. TileIR elevates the tile to a first-class concept.

    You can see this by generating the same kernel against the regular backend and the TileIR backend. Triton’s intermediate representation (TTIR) uses pointer arithmetic: generating offsets, computing masks, loading from explicit addresses. Here’s a bit of an inner loop of a matmul. It groups up which data it wants, loads the tiles a and b by pointer, and computes the dot product:

    %offs_m = tt.make_range {end = 128, start = 0} : tensor<128xi32>
    %a_ptrs = tt.expand_dims %offs_m {axis = 1} : tensor<128xi32> -> tensor<128x1xi32>
    %a_ptrs_1 = tt.splat %stride_am : i32 -> tensor<128x1xi32>
    %a_ptrs_2 = arith.muli %a_ptrs, %a_ptrs_1 : tensor<128x1xi32>
    ...
    scf.for %k = %c0 to %num_k step %c1 iter_args(%acc = %zero) -> tensor<128x128xf32>:
    %a = tt.load %a_ptrs, %mask : tensor<128x64x!tt.ptr<f16>>
    %b = tt.load %b_ptrs, %mask : tensor<64x128x!tt.ptr<f16>>
    %acc_new = tt.dot %a, %b, %acc : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32>

    TileIR on the other hand preserves the tile as a semantic object. This snippet is doing exactly the same thing, but this representation elides the pointer math and masking:

    $33: Tile[float32,(128,128)] = typed_const(value=0)
    accumulator.2: Tile[float32,(128,128)] = for k.1 in $36 (with accumulator.0 = $33)
    do
    $50: Tile[float16,(128,64)] = tile_load_token_ordered(array=A, index=($7, k.1), shape=(128,64))
    $64: Tile[float16,(64,128)] = tile_load_token_ordered(array=B, index=(k.1, $10), shape=(64,128))
    $72: Tile[float32,(128,128)] = tile_mma(x=$50, y=$64, acc=accumulator.0)
    continue $72

    This is a nice IR (compact!), but from my perspective the most interesting part is that load function: tile_load_token_ordered.

    The “time” dimension of the Data × Footprint × Time box is the hardest one to manage. Time questions separate performant kernels from slow ones: when to prefetch, how to overlap loads, and so on. Since the advent of warp specialization, the Triton compiler has been exploring pipelining options through heuristics and autotuning, and kernel engineers have been going straight to the hardware with explicit barrier API extensions like TLX3 and Gluon4.

    TileIR goes a somewhat different route. It assumes an unordered memory model: the order your code is written in does not determine when data is actually available. Instead, each memory operation returns a token and you attach semantics to it: read-only, write-only, read-write and so on.

    By being explicit about memory dependencies you give the compiler freedom to manage the Time dimension. Where accesses don’t overlap the compiler can freely reorder them. Where they do, the token chain tells the compiler exactly what depends on what. The kernel expresses intent; the compiler maps that to the hardware.
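    A toy way to see what tokens buy the compiler: treat each op’s consumed tokens as edges in a dependency graph, and any topological order becomes a legal schedule. The op names here are illustrative, not actual TileIR:

    ```python
    # Each op names the tokens it consumes; the "compiler" may run ops in
    # any order that respects those edges, and nothing else constrains it.
    from graphlib import TopologicalSorter

    ops = {
        "load_a":  [],                    # no dependencies: free to reorder
        "load_b":  [],
        "mma":     ["load_a", "load_b"],  # consumes both load tokens
        "store_c": ["mma"],               # consumes the mma token
    }

    order = list(TopologicalSorter(ops).static_order())
    # Any valid order puts both loads before the mma and the mma before the
    # store; beyond that the scheduler is free (e.g. it may hoist load_b).
    assert order.index("mma") > order.index("load_a")
    assert order.index("mma") > order.index("load_b")
    assert order.index("store_c") > order.index("mma")
    ```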

    TileIR is (mostly) targeting Blackwell right now, and the experimental backend is still early. The open question is whether we can express this smoothly enough in the syntax of kernels to actually enable taking the same kernel across hardware, or whether we are just adding syntactic sugar over hardware-specific tuning.

    That said, the idea feels pretty right? The tile is the unit with which we can express what we actually mean about memory, ordering, and fusion. The CUDA programming model was always about bounded-linearity within a massively parallel framework, and this loosens the bounds that little bit more.

    1. Use of available registers and shared memory, for example ↩︎
    2. This is loop fusion, in compiler terms. There are other things you can do, but this is the big one. ↩︎
    3. Triton Language Extensions, from Meta. As a disclaimer, these are the folks I work with. ↩︎
    4. From OpenAI ↩︎
  • Megacores

    Megacore - Systole as an 80s metal album cover.

    What we do in machine learning owes a lot to the history of computer graphics. Folks like Kurt Akeley, one of the founders of SGI, identified that 3D graphics have a naturally pipelined structure. You have a high volume of similar operations, such as applying pixel-y soldier textures to a mesh of triangles, and by pipelining them you can find an opportunity for a high degree of parallelism.

    Akeley was one of the drivers of OpenGL, which provided a standard interface to that pipeline, and later worked with Nvidia on Cg, a realtime shader language and compiler. Shader languages, as used in Pixar’s RenderMan and other non-realtime 3D use cases, introduced an approach where you could manage lighting programmatically by describing the transforms to each individual element. The shader would be run in parallel across all the geometry or pixels it was addressing.

    With CUDA, Ian Buck and others at Nvidia helped formalize what had been true in the hardware for a while: GPUs were massively parallel processing machines, not just polygon factories. CUDA was part of a move from the supercomputer approach of Single Instruction Multiple Data (SIMD) to Single Instruction Multiple Thread (SIMT). On a Cray or other vector oriented processor you had to pack the work into a vector. CUDA let programmers familiar with CPU threads think in those terms instead. Under the hood, the threads in a warp were executed in lockstep, but they could be masked off to allow for divergence. It was flexible, fast, and attracted the attention of the machine learning community. Because so much of ML is large matmuls, Nvidia bolted on Tensor Cores as specialized co-processors that handled blocks of matrix math efficiently. This combination of performant hardware and flexible software helped make Nvidia the most valuable company in the world, and drive up house prices across the Bay Area.

    But, it transpires, not everyone loved shoveling their margin to Jensen, and they looked for more cost-efficient ways to run ML workloads. The flexibility for threads to branch, pause or switch requires infrastructure and silicon. You need big register files per core, multiple levels of memory to cache, and logic to manage swapping in and out warps.

    If you look at the “do the math” parts of a chip, a CPU probably only spends about 10% of silicon on that, with the rest managing the chaos of running an operating system: branch prediction, caching, data movement. A GPU, in contrast, is a wildly efficient machine, with maybe 30-40% of the silicon dedicated to mathing effectively.

    When Google looked at the problem of running inference at their scale back in the dark ages of 2016, they wanted to spend as much of their budget as possible doing the math, to keep the costs as low as they could. The chip they created, the Tensor Processing Unit (TPU), recently hit its 7th iteration, and SemiAnalysis published an extensive breakdown on it: TPU v7 Ironwood, quickly followed up with a deep dive on Amazon’s Trainium v3.

    Trainium3 takes a similar approach to Trainium2 and Google’s TPU and builds the chip out of a small number of large NeuronCores. This contrasts with GPU architectures like Nvidia’s and AMD’s, which instead use a large number of smaller tensor cores. Large cores are typically better for GenAI workloads since they have less control overhead.

    Dylan and his team are touting these as the first chips to genuinely threaten Nvidia’s moat. The big frontier labs seem interested, with deals and investigation from Anthropic, OpenAI, Meta and others. As the piece repeatedly points out, if you want to understand the dominance of Nvidia you have to focus on the system, and not the microarchitecture. So, of course, I want to talk exclusively about the microarchitecture here.

    TPU, Trainium, as well as other custom approaches like Meta’s MTIA1 lean on an approach called Systolic Arrays. As a recap, Nvidia’s Streaming Multiprocessors (SMs), AMD’s compute units, and so on are cooperative multiprocessors. They access registers, talk to caches and handle the flow of data. Threads can stall waiting for data that isn’t ready, and the hardware warp schedulers will swap in another piece of work to keep the chip humming.

    Systolic arrays are different. The name comes from systole, the phase where your heart pumps blood. In a systolic array, you load your data once and fire it through a grid of Processing Elements (PEs). Each element maths its math then passes the result to its neighbor on the next clock tick.
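    You can simulate the pumping in a few lines. This is a sketch of an output-stationary array: PE (i, j) owns one accumulator, A values flow along rows, B values flow down columns with the classic diagonal skew, and everything advances one hop per clock tick:

    ```python
    import numpy as np

    def systolic_matmul(A, B):
        """Cycle-by-cycle toy simulation of an output-stationary systolic array."""
        n, k = A.shape
        k2, m = B.shape
        assert k == k2
        acc = np.zeros((n, m))
        # With the diagonal skew, PE (i, j) sees A[i, t - i - j] and
        # B[t - i - j, j] at clock tick t; the last PE finishes its last
        # multiply at tick (n - 1) + (m - 1) + (k - 1).
        for t in range(n + m + k - 2):
            for i in range(n):
                for j in range(m):
                    s = t - i - j        # which element arrives this tick
                    if 0 <= s < k:
                        acc[i, j] += A[i, s] * B[s, j]
        return acc

    rng = np.random.default_rng(0)
    A, B = rng.standard_normal((3, 4)), rng.standard_normal((4, 5))
    assert np.allclose(systolic_matmul(A, B), A @ B)
    ```

    The loops here are just the simulator; in silicon every PE fires on the same tick, which is where all that efficiency comes from.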

    This was very much in line with the needs of the original TPU: load a set of model weights up, then pump user requests through as efficiently as possible. TPUv1 only supported int8: it was a low-bit, high-efficiency matmul machine. The data flow needed to be pre-determined: you set it up and make it go, which made it incredibly silicon efficient. You don’t need lots of caches or schedulers, and in fact the original TPU didn’t have any at all!

    The con of course was that you have to get it right! If the data isn’t there to pump in, the whole thing just waits. There is no backup plan of another warp, no other threads. Not only that, but because the systolic arrays are generally a lot bigger (say 256×256 vs the Tensor Cores’ 16×16), you have fewer of them. While an Nvidia GPU might have more than 100 SMs, a Trainium v3 has 8 cores, and a TPU has just 2. Each core is a lot larger, and wasting it gets a lot more expensive.

    Presumably Jeff Dean just programmed these right the first time, but for the rest of Google (and later the world) they spent years building XLA (Accelerated Linear Algebra), a full-graph compiler. In GPU kernel programming the challenge is hiding memory latency and managing register pressure. On a TPU-type approach there is no deep memory hierarchy: one massive VMEM fills a role similar to the registers. But you can’t rely on the hardware to swap between jobs, so XLA needs to know exactly how the graph works so that it can schedule the right data at the right time.

    TPUs used a VLIW architecture: Very Long Instruction Words. Rather than a traditional instruction set with diverse instructions, VLIW lets you bundle Very Long packages of instructions into single units (kind of a silicon equivalent of German) which execute operations on each of the different units of the core at the same time. This was introduced in TPU v2, and it’s where the pressure on the compiler really multiplied.

    To draw a GPU analogy, if you think about something like ReLU(A×B + C) you have a graph of operations: A×B → Result, Result + C → Result2, ReLU(Result2). To optimize that you could use a CUDA graph to compile it into a single kernel dispatch, cutting down the CPU/GPU communication. One step further would be kernel fusion: keep all the intermediate results in registers and write one kernel that avoids the back and forth to higher-tier memory. That lets you bundle up even more, but you have to have even higher confidence in the sizes involved to avoid running out of registers.

    VLIW is like parallel kernel fusions: a TPU v2 had 2 matrix units, 2 vector units, 2 scalar units and 2 memory load/store units2. To keep them busy every step the compiler needs to plan ahead enough to give each of them something useful to do. VLIW instructions bundle those ops along with any constants needed into a single instruction. Fusion goes from being an optimization to being a necessity. Once you get it though, you can spend more like 50-60% of your silicon on the part you care most about, and that translates into an excellent total cost of ownership.
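    A toy packer shows the pressure. Each bundle here has one slot per functional unit (the unit names mirror the TPU v2 description above; everything else is made up), and ready ops are placed greedily. Even for a tiny fused graph, most slots ride empty unless the compiler finds independent work to fill them:

    ```python
    # op name -> (functional unit it needs, ops it depends on)
    program = {
        "ld_a":  ("memory", []),
        "ld_b":  ("memory", []),
        "mm":    ("matrix", ["ld_a", "ld_b"]),
        "relu":  ("vector", ["mm"]),
        "st":    ("memory", ["relu"]),
    }

    done, bundles = set(), []
    while len(done) < len(program):
        bundle, used = {}, set()
        for op, (unit, deps) in program.items():
            # An op can issue if its unit's slot is free and its deps are done.
            if op not in done and unit not in used and all(d in done for d in deps):
                bundle[unit] = op
                used.add(unit)
        done |= set(bundle.values())
        bundles.append(bundle)

    # With one slot per unit, ld_a and ld_b serialize on the memory slot, so
    # this dependency chain takes 5 bundles, and in every bundle at least
    # three of the four units sit idle: exactly the "give each unit something
    # useful to do" pressure on the compiler.
    assert len(bundles) == 5
    ```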

    Does this mean we should all be cancelling our Rubin orders and buying TPUs? I mean, no. But there is some nuance. Choosing between flexible streaming processors or efficient systolic megacores feels drastic, but I think it might not matter quite as much as it seems.

    Research still overwhelmingly benefits from flexibility. You are running experiments, solving bottlenecks and debugging. Nvidia tends to be the big lab tool of choice thanks to the flexibility, the depth of tooling and the general CUDA ecosystem3.

    If you are mainly serving a massive model, it’s worth the investment to lock down all the weirdness and optimize it. That’s where the megacore chips have proved their mettle first, with TPU, Inferentia4, MTIA and others all starting on that side of the house.

    Folks like Akeley and Buck realized that when you’re building a chip you’re really building a programming model. Get that right, and the model can long outlast the hardware. Balancing expressivity with performance is the thing that lets a platform win: who best lets researchers and engineers define the future without fighting the silicon.

    What seems to be emerging isn’t quite the SIMT/CUDA architecture: it’s something around expressing the dataflow of tiles in the critical kernels5, while relying on a compiler to optimize the larger graph and compute.

    Making sure that you have access to the right software might be more important than trying to perfectly identify which hardware platform is the once and future king. But also, look, the world moves fast and if you get a Prime Day deal on Trainium instances, you should probably just take it. The hardware can and will change and it can always be adopted, as the frontier labs are showing. If we keep hunting for the expressivity we need, as OpenGL, CUDA, Triton and others have over the years, we will keep unlocking the possibilities in whatever hardware is available.

    1. Disclosure: I work at Meta and like these chips a lot, though no one would let me anywhere near any chip design, luckily enough ↩︎
    2. Newer versions have others too, like the sparse cores in TPU v6 and v7 which are basically dedicated embedding management processors ↩︎
    3. With the notable exception of Google themselves, though the Jax-XLA-TPU ecosystem is very rich internally ↩︎
    4. Amazon remain undefeated at naming things ↩︎
    5. From system to VMEM on megacore approaches, from SMEM to registers on GPUs ↩︎

  • Let’s all switch to FP16?

    Serious scientists use FP64 – 64 bit floating point numbers – for high precision simulations, but in the world of machine learning we got by for the longest time with FP32. The perennial quest for increased FLOPS, particularly when memory bound, made even that seem too expensive though.

    FP16 offered a reduced numeric range, but at half the size. Training with it in practice meant embracing automatic loss scaling1, which kept values within the range FP16 could represent. Then Google developed BF16: it moved bits from the mantissa to the exponent, so it offered the same numeric range as FP32, but with reduced precision.
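    You can see both halves of that trade with a few lines of numpy, emulating BF16 by zeroing the low 16 bits of a float32 (truncation rather than proper rounding, but enough to show the step sizes):

    ```python
    import numpy as np

    def to_bf16(x):
        # Emulate BF16 by zeroing the low 16 bits of a float32: the 8
        # exponent bits survive intact, but only 7 mantissa bits remain.
        bits = np.asarray(x, dtype=np.float32).view(np.uint32)
        return (bits & np.uint32(0xFFFF0000)).view(np.float32)

    # Precision: FP16's 10 mantissa bits resolve finer steps near 1.0...
    assert np.float16(1.0 + 2**-10) != np.float16(1.0)
    assert to_bf16(1.0 + 2**-7) != to_bf16(1.0)
    assert to_bf16(1.0 + 2**-8) == to_bf16(1.0)   # ...BF16 truncates this away

    # Range: FP16 overflows where BF16 (with FP32's exponent) is fine.
    assert np.float16(70000.0) == np.inf
    assert np.isfinite(to_bf16(70000.0))
    ```

    The mantissa buys resolution, the exponent buys range; BF16 spends its bits on range.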

    Since TPUv3 back in 2018 and Ampere in 2020 it’s been finding its way into hardware and has become the go-to format for training for many models. Life was good, and training in FP16 was mainly discussed as a memory of hard winters past.

    Last week [2510.26788] Defeating the Training-Inference Mismatch via FP16 dropped and threw ML Twitter into a tizzy by arguing that everyone was doing Reinforcement Learning wrong and the solution… was FP16.

    “In this work, we take a step back from the complex algorithmic fixes and investigate the root cause of the numerical mismatch: floating-point precision. We identify that the modern standard for mixed-precision training, BFloat16 (BF16), is the primary culprit. While BF16 has a wide dynamic range which is excellent for stable pre-training, its low precision makes it highly susceptible to rounding errors that accumulate and eventually cause the training and inference policies to diverge.”

    The process for RL generally looks like:

    • Get a problem in a prompt
    • Do inference on the model to generate complete responses (a rollout)
    • Get a reward score for the response(s)
    • Run a training loop on the model to update weights based on the reward

    If you want to be on-policy (which generally trains better) you need the “model” in steps 2 and 4 to be identical, but the actual code running around the model in the two steps is different: for example, you don’t use a KV cache in training and you don’t store gradients in inference. But you do want to keep the weights and numerics of the model the same, else your on-policy training becomes a little bit off-policy.
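    A toy numpy illustration (not the paper’s setup) of why small per-token numeric differences matter: the sequence-level importance ratio between the two engines is a product over tokens, so tiny disagreements compound with rollout length:

    ```python
    import numpy as np

    # The same "policy" scored by two engines that disagree by a tiny
    # rounding error per token.
    rng = np.random.default_rng(0)
    T = 2000                                            # tokens in a long rollout
    logp_train = rng.normal(-2.0, 0.5, T)               # per-token log-probs, trainer
    logp_infer = logp_train + rng.normal(0.0, 1e-3, T)  # same, plus numeric noise

    per_token_ratio = np.exp(logp_train - logp_infer)
    sequence_ratio = per_token_ratio.prod()

    # Every token is within ~0.1% of agreeing...
    assert np.all(np.abs(per_token_ratio - 1) < 0.01)
    # ...but the whole-rollout ratio has drifted away from exactly 1,
    # quietly turning "on-policy" training a little off-policy.
    assert sequence_ratio != 1.0
    assert 0.5 < sequence_ratio < 2.0
    ```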

    The last year of LLM research has been about scaling this up, which requires managing the training and inference flows efficiently. This ongoing pressure to optimize the two paths independently creates a risk of divergence. The paper finds that this absolutely happens, and that the divergence collapses the effectiveness of the learning. Unless, that is, you use FP16:

    “This is precisely why switching to FP16 provides a fundamental solution. With its 10 mantissa bits, FP16 offers 8 times more precision (2^10 values vs. 2^7 values) than BF16. This higher fidelity means that the outputs of the training and inference engines are much more likely to be numerically identical. The increased precision creates a buffer that absorbs the minor implementation differences between the two engines, preventing rounding errors from accumulating and causing a policy divergence”

    The paper does an excellent job of breaking down the many reasons why this happens, but it’s pretty clear that FP16 is a patch: if you can’t get your numerics perfectly matched, then having more precision gives you more wiggle room.

    About a month before this the ByteDance folks posted a fantastic deep dive into RL collapse from discrepancies between training and inference: When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

    They identify a range of concerns, including straight up bugs:

    “According to this GitHub issue, we set disable_cascade_attn=True when initializing the vLLM engine and found that it significantly helps reduce the training-inference mismatch in experiments conducted on A100 GPUs.”

    Many of the experiments in the FP16 vs BF16 paper were run on A100s2, so some backlash emerged suggesting that perhaps this whole thing is just a kernel error. But as ByteDance showed, there really is a lot going on that can make things worse.

    Another example is Horace He’s recent work at Thinking Machines around a related problem: Defeating Nondeterminism in LLM Inference – Thinking Machines Lab

    “As mentioned above, one common explanation for why kernels add numbers in different orders is the “concurrency + floating point” hypothesis. The hypothesis states that if the order in which concurrent threads finish is nondeterministic and the accumulation order depends on the order in which concurrent threads finish (such as with an atomic add), our accumulation order will be nondeterministic as well.”

    Horace calls out variance in batching as the primary cause of non-determinism, and hence another quite plausible cause of inference/training mismatch:

    “In other words, the primary reason nearly all LLM inference endpoints are nondeterministic is that the load (and thus batch-size) nondeterministically varies! This nondeterminism is not unique to GPUs — LLM inference endpoints served from CPUs or TPUs will also have this source of nondeterminism.”
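    The non-associativity at the bottom of all this is easy to demonstrate deterministically; what varies with batch size or atomic completion order is merely which grouping you get:

    ```python
    import numpy as np

    # Floating-point addition is not associative, so the order a reduction
    # runs in changes the answer. Same three numbers, two schedules:
    a, b, c = np.float32(1e8), np.float32(1.0), np.float32(-1e8)

    schedule_1 = (a + b) + c   # 1e8 + 1 rounds back to 1e8; the 1 is lost
    schedule_2 = (a + c) + b   # cancellation happens first; the 1 survives

    assert schedule_1 == 0.0
    assert schedule_2 == 1.0
    ```

    A kernel that splits a sum across a different number of tiles, or lets atomics land in a different order, is effectively picking between these schedules at random.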

    The meta-point is that despite being a field fundamentally based in mathematical precision we have been sloppy with numerics, pretty much everywhere.

    Ed Yang’s session in the PyTorch Conference keynote3 a couple of weeks back called this problem out from the perspective of scaling up ML infrastructure. He presented a number of solutions to try and address it, which often come down to giving folks control over precisely how the numerics work in different parts of their model.

    While the focus here was on RL and FP16, the reality is we deal with this for training → inference in much simpler cases, as well as when moving models between different hardware. Even within generations this can be hard: one of the fun infra problems when the H100 came out was everyone discovering that the FP8 tensor cores in Hopper used a 22-bit accumulator for intermediate calculations, which wasn’t really documented!

    The balance between speed and accuracy is often effectively made empirically: if something is faster, and works, then at some level it’s right! Reinforcement Learning mixes together different evolutionary chains of optimizations, so maybe those serious scientists with their FP64 were onto something. Not because they absolutely needed the precision, but because they needed to know they had the precision.

    We’re probably not going to switch industry wide back to FP164, but getting a better numerical grounding into the tools we use is going to make everyone’s lives easier, eventually!

    1. torch.cuda.amp and friends ↩︎
    2. Though they did verify on Hopper some as well, which some people seemed to miss ↩︎
    3. Check out the recording: Keynote: PyTorch Technical Deep Dive – Alban Desmaison, Peng Wu, Mark Saroufim & Edward Yang, Meta ↩︎
    4. Especially since most labs are doing so much with FP8 or less these days, and it would probably annoy a bunch of chip designers ↩︎
  • Helion and the evolving GPU programming model

    Helion: A High-Level DSL for Performant and Portable ML Kernels – PyTorch

    Lots of announcements around the Triton and PyTorch Conferences this week, including the 1.0 of Helion, a high-level kernel authoring DSL:

     It establishes a new layer of abstraction that bridges the user-friendly simplicity of PyTorch with the performance of a lower level language. By automating tedious and error-prone tasks like tensor indexing, memory management, and hardware-specific tuning, Helion empowers developers to focus on algorithmic logic rather than hardware-specific implementation details. Helion achieves this balance by pairing a familiar, PyTorch-centric syntax with a powerful autotuning engine that automates the complex search for optimal kernel configurations. This results in a system that delivers performance portability across hardware architectures while drastically reducing development effort. 

    There has been a bit of an explosion in kernel-authoring options recently with CuTe-DSL and CuTile from Nvidia, TileLang (as featured in recent DeepSeek releases), Gluon and TLX1 as well as evolutions to core Triton, Thunderkittens, Pallas, and others.

    There are a couple of different axes of progress occurring in GPU authoring. The first is between iterative, researcher-friendly declarative code and tightly written, hardware-friendly imperative code.

    It’s a classic developer-experience trade-off: you let people tell you what they want to do (matmul these things then apply a softmax) or you let people tell you precisely how to do it (run this dot product on these SMs then aggregate the result).

    In general you want to stay as high-level as possible, particularly if you are experimenting with lots of different variants in a research type setting, but you may have a bound on the performance hit you can accept. A common example is you want to iterate on some attention variant, but don’t want to completely give up on the performance wins of Flash Attention.2

    Triton and others provided an interesting middle ground: it was easy enough to iterate with thanks to being embedded in Python, and performant enough as it leveraged a compiler to automatically apply some optimizations. You are still much more imperative than in a PyTorch program, but you work at a higher level of abstraction: rather than writing programs which own a thread of data, as in CUDA, you think about a tile of data. The ThunderKittens docs put this well:

    A GPU is not really a 1000×1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16×16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16×16 values.

    The next abstraction that frameworks developed was how to represent data across the memory hierarchy. To take advantage of the tensor cores you have to have data laid out in a specific way in registers. But you are better off loading data in a different order in global or shared memory. CuTe offered a big benefit by giving you types to represent layouts that could be composed, making it easier to keep track of the data movement required. Triton and others leaned on the compiler to infer the right layouts and offered higher-level APIs to copy data between stages.
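    The layout-as-a-type idea can be sketched with a toy Layout that maps a logical coordinate to a linear offset. This is only the flavor of the concept, not CuTe’s actual layout algebra (which also handles composition, swizzles and hierarchy):

    ```python
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Layout:
        shape: tuple
        stride: tuple

        def __call__(self, *coord):
            # A layout is just a function from logical coords to an offset.
            return sum(c * s for c, s in zip(coord, self.stride))

    row_major = Layout((4, 8), (8, 1))   # contiguous rows
    col_major = Layout((4, 8), (1, 4))   # column-major view of the same elements

    assert row_major(1, 2) == 10   # 1*8 + 2*1
    assert col_major(1, 2) == 9    # 1*1 + 2*4

    # Because layouts are plain functions, properties like "a copy between
    # these layouts touches every element exactly once" become checkable:
    offsets = {col_major(i, j) for i in range(4) for j in range(8)}
    assert len(offsets) == 32
    ```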

    This started to get challenging on Hopper, thanks to TMA3 and the limitations of memory bandwidth, which gets to the second evolution happening in GPU kernels: how do you orchestrate the movement of data between memory layers while keeping the tensor cores saturated? This involved techniques like warp specialization, where individual warps do different operations towards a shared goal. That means carefully allocating ownership over registers to avoid warps stepping on each other. Blackwell4 made this even trickier with the addition of TMEM, 2-CTA mode and other features that offered more performance but required even more careful orchestration.

    In compiler terms this is a scheduling problem and in general the industry is quite good at it! CPUs give compilers a lot of leeway to schedule operations efficiently because they have a great deal of support for out-of-order execution, well documented ops, and substantial caches. GPUs process groups of threads5 in lockstep and demand strict timing about when to insert barriers, issue async operations and so on.

    A GPU scheduler has to tag operations to specific warp-slots in advance, assign numbers of registers to them to avoid conflicts, and sync them with barriers. It’s a lot more brittle: if we guess wrong, we can idle the Tensor cores and tank efficiency. The actual execution model is a bit of a black box too: the target for compilers (PTX) is itself further compiled to SASS by ptxas.

    Across the industry we’ve been exploring ways to be more explicit without giving away all of the operational and developer efficiency gains of higher-level languages. CuTe-DSL offers a very close-to-hardware model but in a Pythonic package6, Gluon (OpenAI) and TLX (Meta) add extensions to allow modelling pipelines in code without getting rid of the Triton compiler, and TileLang builds on TVM with explicit pipeline declarations.

    One of the reasons for this variety is we don’t quite know how to express a warp-group pipelined execution model. For example, TileLang has a pipelined construct:

    for k in T.Pipelined(loop_range, num_stages=num_stages):
        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)  # Q @ K^T
        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
        Rescale(acc_o, scores_scale)  # Apply correction
        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)  # P @ V

    Gluon has a descriptor that allocates resources like registers explicitly to warps:

    gl.warp_specialize(
            (config, chnls, descs, M, STAGE),     # Args to correction stage
            _attn_fwd_correction,                  # Trunk task (1 warp, 192 regs)
            (config, chnls, descs, M, STAGE),     # Args to specialized tasks
            [
                _attn_fwd_softmax0,    # 4 warps, 192 registers - Softmax tile 0
                _attn_fwd_softmax1,    # 4 warps, 192 registers - Softmax tile 1
                _attn_fwd_mma,         # 1 warp, 24 registers  - Matrix multiplies
                _attn_fwd_load,        # 1 warp, 24 registers  - TMA loads
                _attn_fwd_epilogue,    # 1 warp, 24 registers  - Store results
            ],
            [4, 4, 1, 1, 1],          # Warps per stage
            [192, 192, 24, 24, 24]    # Registers per stage
        )

    And TLX tags sections of code with contexts to indicate groupings, and also allocates resources:

    with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
                        registers=232,
                        replicate=NUM_MMA_GROUPS):

    They can all work, and finding the best trade-off is a good goal, but in all cases they force a lot of decisions. As an example, that allocation of how many registers to use is not only operation-dependent, it’s hardware-dependent, and that makes portability between hardware (even different generations from the same vendor) expensive. Manual controls are necessary: it takes time to develop the compiler passes and heuristics to optimally divide work, so handing explicit control over7 is beneficial, particularly when serving at scale. The cost is complexity and portability. This is where Helion takes a different tack.

    Anyway, so what about Helion?

    Helion instead takes a point on the line above Triton, but below the ML frameworks. It focuses on expressing just what you want to happen from the tile perspective.

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc
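
    Squinting past the hl.* calls, that loop nest is just a tiled matmul. Here is a plain-Python sketch of the same structure (my own illustration, not Helion code), with the tile sizes standing in for whatever the autotuner would pick:

```python
def tiled_matmul(x, y, tile_m=2, tile_n=2, tile_k=2):
    # x is M x K, y is K x N, both as lists of lists
    m, k, n = len(x), len(y), len(y[0])
    out = [[0.0] * n for _ in range(m)]
    for m0 in range(0, m, tile_m):          # the hl.tile([m, n]) loop
        for n0 in range(0, n, tile_n):
            # out[m0:, n0:] plays the role of the hl.zeros accumulator
            for k0 in range(0, k, tile_k):  # the hl.tile(k) loop
                for i in range(m0, min(m0 + tile_m, m)):
                    for j in range(n0, min(n0 + tile_n, n)):
                        for kk in range(k0, min(k0 + tile_k, k)):
                            out[i][j] += x[i][kk] * y[kk][j]
    return out
```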

    Under the hood, this compiles down to Triton. You might think that would be a bit of a no-op on performance, but in practice it’s often better. The reason is search: Helion can autotune across a wide number of parameters, then let you bake them into your kernel once you’ve identified good ones for your specific setup. The example in the blog post shows how many dimensions of search need to occur:

    @helion.kernel(config=helion.Config(
        block_sizes=[64, 64, 64],
        loop_orders=[[0, 1]],
        l2_groupings=[4],
        range_unroll_factors=[0, 1],
        range_warp_specializes=[None, False],
        range_num_stages=[0, 3],
        range_multi_buffers=[None, False],
        range_flattens=[None, None],
        num_warps=8,
        num_stages=6,
        indexing='block_ptr',
        pid_type='flat'
    ))
    def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

    This makes moving to different hardware as simple as redoing the search process, and offers a much more comprehensive exploration than most folks would do when hand-rolling a lower-level kernel. It’s a very interesting idea, and I’m glad to see more people kicking the tires!

    Low-level optimizations aren’t going away any time soon, but I’m glad to have more exploration in the kernel development space. Finding the right abstractions and right compiler approaches to keep scaling kernel development will help make it accessible to more and more people and ensure that we can evolve our kernels with the hardware.

    1. Also a Meta thing, disclaimer. ↩︎
    2. This is the logic behind FlexAttention, which was one of the lights that guided the way towards Helion. ↩︎
    3. Fully async copies – a separate execution engine to move data ↩︎
    4. Well, datacenter blackwell. Consumer blackwell lacks TMEM and 2-CTA, so is a bit more Hopper-like programming model. I’m not sure yet what the DGX Sparks have! ↩︎
    5. Warps – 32 threads on Nvidia, or waves, 64 threads on AMD. The important distinction is that all these threads are doing the same thing: you can mask some of them out, but they have a fairly simple march through the instruction. ↩︎
    6. With a JIT! ↩︎
    7. Without making people write templated C++, sorry Ben ↩︎
  • Layouts

    You could have invented CuTe hierarchical layout (but maybe not the rest of it?): ezyang’s blog

    Ed posted the best intro to CuTe layouts I have seen, by showing how to extrapolate them from PyTorch striding1.

    Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if it’s an N-D object.)
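
    Ed’s framing can be condensed into a few lines. Here is a toy version of the idea (my own sketch, not CuTe’s actual implementation): shapes and strides mirror each other as nested tuples, and a 1-D linear coordinate is peeled apart mode by mode:

```python
from math import prod

def size(shape):
    """Total number of elements addressed by a (possibly nested) shape."""
    return shape if isinstance(shape, int) else prod(size(s) for s in shape)

def offset(idx, shape, stride):
    """Map a 1-D linear coordinate to a memory offset."""
    if isinstance(shape, int):
        return (idx % shape) * stride
    off = 0
    for s, d in zip(shape, stride):
        off += offset(idx % size(s), s, d)
        idx //= size(s)
    return off

# A 4x4 matrix: shape (4, 4) with stride (1, 4) walks memory in order,
# while stride (4, 1) visits it transposed. Nesting a mode, e.g. shape
# ((2, 2), 4), expresses internal structure you can still address linearly.
```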

    Relatedly, Simon Veitner put together a nicely visual explanation of layouts. https://veitner.bearblog.dev/intuition-behind-hierarchical-layouts/ – the graphics are helpful once you have the baseline intuition from Ed’s post!

    1. If you’re not familiar with striding, Ed’s PyTorch Internals talk/post remains the best intro! ↩︎
  • The TPU book, on GPUs

    How to Think About GPUs | How To Scale Your Model

    The Jax “How To Scale Your Model” book is one of my favorite references for folks trying to get their head round pretraining1. It breaks down the performance characteristics of model training (often using Llama 3 as an example) in an incredibly clear way. The only slight limitation is that it is primarily focused on scaling LLMs on TPUs: interesting, but probably not your main platform target (unless you work at Deepmind). They just released a new chapter covering GPUs, and it’s also a great summary2.

    There are also plenty of mildly snarky comments about design choices to leaven the reading too:

    Takeaway: in theory, NVIDIA SHARP (available on most NVIDIA switches) should reduce the cost of an AllReduce on B bytes from about 2 * B / W to B / W. However, in practice we only see a roughly 30% improvement in bandwidth. Since pure AllReduces are fairly rare in LLMs, this is not especially useful.

    1. Though they include a chapter on inference too! ↩︎
    2. Though if you haven’t read the rest of the book it moves pretty fast – definitely best to read through the whole thing and treat this as the appendix it is intended to be! ↩︎
  • Quack CuteDSL Kernels

    Dao-AILab/quack: A Quirky Assortment of CuTe Kernels

    Tri Dao & co have a fun repo up called Quack: A Quirky Assortment of CuTe Kernels, all leveraging the CuTe-DSL. These are Hopper- and Blackwell-oriented kernels for a variety of common needs like softmax, LayerNorm and RMSNorm.

    On top of that, they wrote a post on how to get speed of light (memory bound) kernels in CuTe-DSL. It goes through how to implement a reduction op across multiple tiers of memory using TensorSSA for thread level reductions, warp reduction with shuffle_sync_bfly and block reduction with shared memory. Even if you’re not writing CuTe, this is about as good an introduction to architecting memory bound ops as I have seen!
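
    The tiering they describe has a simple shape. A toy CPU analogy (mine, with miniature “warps”, not their CuTe code): each thread reduces a strided slice, each warp combines its threads’ partials, then one block-level combine of the per-warp results:

```python
def hierarchical_sum(values, threads_per_warp=4, warps_per_block=2):
    n_threads = threads_per_warp * warps_per_block
    # Tier 1 (thread level): each thread sums a strided slice of the input.
    per_thread = [sum(values[t::n_threads]) for t in range(n_threads)]
    # Tier 2 (warp level; shuffle_sync_bfly on GPU): combine within a warp.
    per_warp = [
        sum(per_thread[w * threads_per_warp:(w + 1) * threads_per_warp])
        for w in range(warps_per_block)
    ]
    # Tier 3 (block level; shared memory on GPU): combine per-warp partials.
    return sum(per_warp)
```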

    They also cover clustered reduction, leveraging multiple SMs:

    In cluster reduction, we first send the current warp’s reduced value to all the peer thread block’s reduction buffer in peer’s SMEM. Such sending is conducted via a dedicated SM-to-SM fabric (as DSMEM). Then each warp fetches all warp’s values from their local reduction buffer, and reduces these values.

    This does seem to help the kernels scale well to larger sizes:

    We believe our outstanding performance at >= 65k input is due to our successful utilization of cluster reduction in H100. When the size of inputs are ultra long and depleting the SM’s registers and shared memory, without cluster reduction, we would have to switch to an online algorithm (like online softmax) otherwise we may get a massive register spilling that leads to significant throughput degradation.

    I also really appreciate this note of reality in their conclusion:

    Hitting “speed-of-light” model memory throughput confirms that a carefully hand-crafted CuTe kernel can squeeze every byte across all memory hierarchies in the hardware. But that efficiency comes at the price of per-operator and even per input-shape tuning, which imposes a natural tradeoff between efficiency and development efforts

  • Cute-DSL

    In May Nvidia shipped CuTe‑DSL, the Python library they teased at GTC earlier in the year that mirrors CUTLASS’s C++ tensor‑layout model. Then, at the start of June, the ‑dev label disappeared (so presumably it’s production-ready now). The pitch is simple: write speed‑of‑light kernels from the comfort of Python.

    Of course, nothing about CUDA is ever really simple. CuTe‑DSL gives the full Cutlass experience1, wrapped in an ever so slightly more approachable interface.

    Getting Cute: Transpose

    Matrix transpose felt like a reasonable ‘hello world’: B[j,i] = A[i,j]. The PyTorch version is simple: torch.transpose(input, 0, 1).

    To get a baseline, here is a simple transpose kernel in Triton. We tl.load, flip the coordinates and tl.store it back.

    @triton.jit
    def triton_transpose_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        # 2D block coordinates
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        
        # Calculate offsets
        offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        
        # Load with masking
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        a = tl.load(input_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)
        
        # Store transposed (swap coordinates)
        tl.store(output_ptr + offs_n[:, None] * M + offs_m[None, :], a, mask=mask)
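
    The index arithmetic is easier to see on the CPU. Here’s a plain-Python rendering (my sketch, not Triton) of what one program instance does, using the flat 1-D buffers that the Triton pointers imply:

```python
def transpose_block(inp, out, M, N, pid_m, pid_n, block):
    """One (pid_m, pid_n) program: read a block of A, write it transposed."""
    for i in range(block):
        for j in range(block):
            m = pid_m * block + i
            n = pid_n * block + j
            if m < M and n < N:                  # the Triton mask
                out[n * M + m] = inp[m * N + n]  # swap coordinates on store

def transpose(inp, M, N, block=2):
    out = [0] * (M * N)
    # The grid: one program per block of the input.
    for pid_m in range((M + block - 1) // block):
        for pid_n in range((N + block - 1) // block):
            transpose_block(inp, out, M, N, pid_m, pid_n, block)
    return out
```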

    Here’s the same idea in CuTe‑DSL. CuTe leverages a decorator and Python’s ability to integrate with JITs. Anything decorated with @jit runs host-side (on the CPU), while @kernel is used for device-side code (on the GPU). There are both AST-based and tracing-based options depending on the presence of dynamic shapes or control flow.

        @cute.kernel
        def transpose_kernel(self, mA: cute.Tensor, mB: cute.Tensor):
            tidx = cute.arch.thread_idx()[0]
            tidy = cute.arch.thread_idx()[1]
            bidx = cute.arch.block_idx()[0]
            bidy = cute.arch.block_idx()[1]
            # This might all be unnecessary
            # but I was fearful of the compiler
            tile_start_m = cutlass.Int32(0)
            tile_start_n = cutlass.Int32(0)
            global_m = cutlass.Int32(0)
            global_n = cutlass.Int32(0)
            M = cutlass.Int32(0)
            N = cutlass.Int32(0)
            val = cutlass.Float32(0.0)
            # Calculate tile starting positions
            tile_start_m = bidy * self._tile_size
            tile_start_n = bidx * self._tile_size
            # Calculate global coordinates for this thread
            global_m = tile_start_m + tidy
            global_n = tile_start_n + tidx
            # Get matrix dimensions at runtime
            M = mA.shape[0]
            N = mA.shape[1]
            # Bounds checking and transpose operation
            if global_m < M and global_n < N:
                val = mA[global_m, global_n]
                # Transpose: B[n, m] = A[m, n]
                mB[global_n, global_m] = val

    What just happened?

    • Thread and block indices come straight from CUDA (thread_idx, block_idx), vs the Triton block abstraction
    • No explicit loads or stores: CuTe uses overloaded [] to generate them.

    Launching isn’t a million miles away from Triton:

    @cute.jit   # host side
    def launch(self, A: cute.Tensor, B: cute.Tensor):
        M, N = A.shape
        grid = ((N + self.T - 1)//self.T,
                (M + self.T - 1)//self.T, 1)
        self.transpose_kernel(A, B).launch(
            grid=grid,
            block=[self.T, self.T, 1],
        )

    Because CuTe‑DSL speaks DLPack, you can hand it a PyTorch tensor directly. If you wanted to cache the conversion, it looks like this:

    A_cute = from_dlpack(A).mark_layout_dynamic()

    The mark_layout_dynamic is used to trigger the dynamic shape support, and avoid shape specialization. The one place where this went a bit funky in my testing was dealing with singular leading dimensions: there you need to be more explicit about the shape to satisfy the compiler.

    Layouts and Memory

    This kernel isn’t really leveraging the fundamental value of CuTe though, which is composable tensor layouts and memory management. CuTe‑DSL exposes the full memory hierarchy: global, shared, register (and tmem for those with blackwells), and lets you tile, copy, and pipeline data between them. Common primitives:

    • make_layout / make_layout_tv: describe how a tensor is laid out.
    • cute.zipped_divide(tensor, tiler): tile a tensor.
    • cute.copy(src_layout, dst_layout, pred=mask): async copy.
    • cute.arch.sync_threads(): explicit barrier.

    HGEMMony2

    CuTe ships with some example kernels, so I grabbed one — an HGEMM (half-precision, FP16, batched GEMM) — and compared to an example implementation in Triton.

    To express the same thing in PyTorch, we can unleash our inner Jeremy Howard and use einsum notation: torch.einsum("mkl,nkl->mnl", a, b). Take L batches of an MxK matrix, L batches of an NxK matrix, and return L batches of an MxN matrix.

    Here is the Triton:

    @triton.jit
    def triton_batched_hgemm_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K, L,
        stride_am, stride_ak, stride_al,
        stride_bn, stride_bk, stride_bl,
        stride_cm, stride_cn, stride_cl,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        """Triton batched half-precision GEMM kernel: C[m,n,l] = sum_k A[m,k,l] * B[n,k,l]"""
        pid = tl.program_id(axis=0)
        pid_batch = tl.program_id(axis=1)  # Batch dimension
        
        # Calculate batch offsets
        batch_offset_a = pid_batch * stride_al
        batch_offset_b = pid_batch * stride_bl  
        batch_offset_c = pid_batch * stride_cl
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        
        # Include batch offsets in pointer calculations
        a_ptrs = a_ptr + batch_offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        # transpose for GEMM (load B[K, N] pattern)
        b_ptrs = b_ptr + batch_offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        
        # We accumulate into fp32 for higher accuracy.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Load the next block of A and B, mask in K
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Convert back to FP16 for output
        c = accumulator.to(tl.float16)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + batch_offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    The core loop here is:

    • Divide into block_size_k chunks of K
    • For each do a masked load (so when we hit the edges we don’t bring in garbage)
    • Do the dot product for the tile into an accumulator for the result matrix
    • Advance the A and B pointers
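
    The masked-load trick is worth seeing in isolation. A plain-Python sketch (mine) of one output element’s K loop, where out-of-range reads become 0.0 exactly as mask=... other=0.0 does:

```python
def k_loop_dot(a_row, b_col, K, BLOCK_K):
    acc = 0.0                                 # fp32 accumulator
    for k0 in range(0, K, BLOCK_K):
        for kk in range(k0, k0 + BLOCK_K):    # full block, even at the edge
            a = a_row[kk] if kk < K else 0.0  # masked load -> other=0.0
            b = b_col[kk] if kk < K else 0.0
            acc += a * b                      # no garbage enters the sum
    return acc
```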

    The CuTe kernel is, uhhh… a bit more involved. The full kernel is several hundred lines long. You can see the source in the Cutlass repo, which demonstrates some cool features like the ability to pass in an epilogue function for fusion.

    For now, let’s focus on the main loop. As before, we are looping over tiles of K3.

    The first thing we do is advance the pointers. The kernel is doing explicit pipelining of transfers from global memory, to shared memory, to registers, so we need to set wait groups to ensure all the loading has completed before we advance. We’re not actually doing any loading in this section, just prepping the ground:

    for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
        for k_block in range(num_k_block):
            if k_block == num_k_block - 1:
                tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
                tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
                cute.arch.cp_async_wait_group(num_smem_stages - 2)
                cute.arch.sync_threads()

    Next we kick off a copy from shared memory (smem) to registers (rmem) using cute.copy for the (future) A and B tiles.

    k_block_next = (k_block + 1) % num_k_block  # static
    cute.copy(
        tiled_copy_s2r_A,
        tCsA_p[None, None, k_block_next],
        tCrA_copy_view[None, None, k_block_next],
    )
    cute.copy(
        tiled_copy_s2r_B,
        tCsB_p[None, None, k_block_next],
        tCrB_copy_view[None, None, k_block_next],
    )

    Finally, we interleave the transfers of the next A and B tiles from global to shared memory with the actual gemm operation (the equivalent of tl.dot). I will trust the folks at Nvidia that this is an optimal pattern. The pred= in there is the equivalent of masking in Triton.

    if k_block == 0:
        if k_tile + num_smem_stages - 1 < k_tile_count:
            cute.copy(
                tiled_copy_B,
                tBgB[None, None, None, k_tile_index],
                tBsB[None, None, None, smem_pipe_write],
                pred=tBpB,
            )
            k_tile_index = k_tile_index + 1
        cute.arch.cp_async_commit_group()
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = smem_pipe_read + 1
        if smem_pipe_read == num_smem_stages:
            smem_pipe_read = 0

    The pipelining is explicit, which is nice for debuggability and optimization, but very manual.
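
    Stripped of the CuTe machinery, the shape of that pipeline is a circular buffer of smem stages: prefetch the first N tiles, then each iteration consumes one slot and refills it with the tile N steps ahead. A toy Python rendering (my sketch, not the CUTLASS code):

```python
def pipelined(tiles, num_stages, load, compute):
    smem = [None] * num_stages
    for i in range(min(num_stages, len(tiles))):
        smem[i] = load(tiles[i])                     # prologue: fill all stages
    for i in range(len(tiles)):
        compute(smem[i % num_stages])                # consume the current slot
        nxt = i + num_stages
        if nxt < len(tiles):
            smem[i % num_stages] = load(tiles[nxt])  # refill the freed slot
```

    On the GPU the “refill” is an async copy kicked off long before the slot is needed, with cp_async_wait_group standing guard between the two.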

    Debugging Tips

    export CUTE_DSL_LOG_TO_CONSOLE=1
    export CUTE_DSL_LOG_LEVEL=10   # up to 100
    export CUTE_DSL_PRINT_IR=1     # dump MLIR
    • cute.printf() gives you a GPU‑side printf.
    • Kernels are aggressively cached; rm ~/.cache/cutedsl if things look stale.
    • Multiple @cute.jit host functions in the same Python scope can confuse MLIR (mainly for launching kernels).
    • The control‑flow rules are strict: no return inside a kernel; initialize everything.

    If you’re exploring GPU kernels for the first time, I strongly recommend starting with Triton. When you need to really get into the weeds, or want to reuse CUTLASS building blocks, it’s great to have CuTe‑DSL as an option in Python (provided you’re comfortable spelunking in GPU internals).

    1. I spent a lot of time holding it wrong. Arguably, still holding it wrong. ↩︎
    2. No one knows what it means, but it’s provocative ↩︎
    3. Note the explicit unroll tag. When you really want #pragma but can’t. ↩︎
  • How to build unmaintainable kernels

    What do you need to do to get better performance and GPU efficiency out of your model? The GPU-oriented folks at Stanford recently published an early preview of the work they have been doing on the LLM generation of kernels: Surprisingly Fast AI-Generated Kernels We Didn’t Mean to Publish (Yet) – and they have a list:

    • Memory Access Optimization: improving the efficiency of data movement between different memory hierarchies (global memory, shared memory, registers) and ensuring data is accessed in a way that maximizes bandwidth and minimizes conflicts.
    • Asynchronous Operations & Latency Hiding: hide the latency of slow operations (like global memory access) by overlapping them with computation or other memory transfers
    • Data Type & Precision Optimization: using lower-precision data types (like FP16 or BF16) where possible to reduce memory bandwidth requirements, increase cache effectiveness, and potentially leverage specialized hardware units.
    • Compute & Instruction Optimization: making the arithmetic computations themselves more efficient, reducing instruction count, or leveraging specialized hardware instructions
    • Parallelism & Occupancy Enhancement: maximize the number of active warps on the Streaming Multiprocessors (SMs) to better hide latencies and improve overall throughput
    • Control Flow & Loop Optimization: reducing the overhead associated with loops, branches, and indexing calculations

    That’s a good list! In this case though, it emerged not from (just) talking with kernel experts, but also from developing a model to generate kernels:

    We have some very fast AI-generated kernels in pure CUDA-C without using libraries and DSLs such as CUTLASS and Triton. They are performing close to or in some cases even beating the standard expert-optimized production kernels shipped in PyTorch.

    They developed a very straightforward but smart pattern for structuring test-time compute. They reason about optimizations in natural language before generating code. Then, they branch out into a tree structure of refinements for each optimization idea, to avoid getting stuck in loops of investigation.
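
    The skeleton of that search is small. A deliberately toy version (mine; the real system refines natural-language optimization ideas with an LLM, not integers): each round, every surviving candidate branches into several refinements, and only the best-scoring few survive to the next round:

```python
def tree_search(seed, refine, score, branching=3, rounds=2, keep=2):
    frontier = [seed]
    for _ in range(rounds):
        # Branch: every candidate spawns `branching` refinements...
        children = [refine(c) for c in frontier for _ in range(branching)]
        # ...prune: keep only the best-scoring ones (lower score is better).
        frontier = sorted(children, key=score)[:keep]
    return frontier[0]
```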

    The kernels they generated were somewhere between fast, and very fast:

    Conv2D: 179.9% performance of FP32 torch.nn.Conv2D; problem size: (100, 3, 224, 224) input tensor, conv(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2)

    The team aren’t claiming this is a general solution, just an interesting proof of possibility, which it certainly is! The walk-through of how they got to the final Conv2D kernel is fascinating, both in terms of human intervention and the chain of optimizations.

    The final code sample for the Conv2D kernel is included in the appendix. It uses advanced CUDA techniques that we find challenging to write ourselves! We also have more example kernels in this Github repo

    The kernel is very fast for specific shapes, on the L40S, in FP32. It’s also a kernel that, by the sounds of it, the team themselves struggled a bit with. It’s very, very specialized. It’s not that a human couldn’t have built it, it’s that (in most cases) they wouldn’t: it’s not a priority kernel, and all that clever CUDA comes with operational overhead, ties to specific hardware and shapes, and so on.

    That in itself isn’t new. If you compile with PyTorch or XLA you’ll get a lot of kernels you probably wouldn’t write by hand, but this adds a new (and weird!) layer of non-determinism to everything. Elsewhere at Stanford, they have been looking at one of the other killers of GPU efficiency: kernel launch overhead. Most models are represented by hundreds of kernels, each of which has to be scheduled from the CPU. LLMs are generally memory-bound, small ones particularly so, and the gaps between kernel executions can end up dominating performance:

    Look Ma, No Bubbles! Designing a Low-Latency Megakernel for Llama-1B:

    In this post, we show how we can bypass this problem by merging the entire Llama-1B forward pass into a single “megakernel” that eliminates kernel boundaries altogether. Doing this achieves brr – on an H100, we use 78% of memory bandwidth and outperform existing systems by over 1.5x. (To our knowledge, this is the lowest-latency forward pass for Llama-1B in bfloat16!) In the rest of this post, we’ll walk through how and why one would do this.

    The idea of megakernels that handle all of the operations is not new, but the complexity of fusing everything together is high. Persistent kernels were popularized in the tail end of the CUDA 11 series, thanks to the residency and async-copy support in Ampere. They allow leaving kernels resident on an SM and having them pull a series of tiles to work on, rather than scheduling repeated launches of the same kernel. The megakernel takes this idea even further, with multiple operations within the kernel pulling from a stream of different problems. One issue with this approach (traditionally) is register spilling: you only have so many registers available, up to 255 per thread, with a fairly high overall limit of 64K 32-bit registers per SM (on Hopper). That means you need to keep some data in shared memory, and efficient use of shared memory ends up being the bottleneck. The team at Stanford developed a paging scheme for shared memory, with a separate reserved block for managing allocations of shared memory to individual tasks.
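
    The paged-shared-memory idea can be sketched in a few lines (a toy analogy, mine; the real thing runs on-GPU with atomics in the reserved block): shared memory is carved into fixed-size pages, and tasks check pages out of a free list and return them when done:

```python
class SmemPager:
    """Toy page allocator over a fixed shared-memory budget."""
    def __init__(self, smem_bytes, page_bytes):
        self.free = list(range(smem_bytes // page_bytes))

    def alloc(self):
        # None means "no page free": the task has to wait its turn.
        return self.free.pop() if self.free else None

    def release(self, page):
        self.free.append(page)
```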

    This gets the CPU completely out the picture for the forward pass, but is incredibly specific to the model (in this case Llama 3.2 1B).

    Another collaboration was clearly thinking in the same direction, as they recently posted about their Mirage Persistent Kernel: a compiler for megakernels.

    our team from CMU, UW, Berkeley, NVIDIA, and Tsinghua developed Mirage Persistent Kernel (MPK) — a compiler and runtime system that automatically transforms multi-GPU LLM inference into a high-performance megakernel. MPK unlocks the benefits of end-to-end GPU fusion while requiring minimal manual effort from developers.

    The system works by building a task-graph of what they call LAX1 fragments, which in practice is a very short list of 14 operators. This is actually too small to represent everything they need, meaning they have to manually decompose some common ops like ReLU, but this level of decomposition gives them the ability to do some pretty complex fusions.

    The actual ops are generated thanks to Mirage’s Kernel Superoptimizer (a great name), which I think is a very intense autotuner:

    Mirage automatically searches for possible GPU kernels for attention. The search space includes existing manually designed attention kernels (e.g., FlashAttention and FlashDecoding) as special cases. It also includes other implementations that outperform today’s handwritten ones by up to 3.5x for certain use cases. The GPU kernels generated by Mirage can directly operate on PyTorch tensors and be called in your PyTorch program.

    The search is not cheap though:

    In our evaluation, Mirage takes up to 4 hours to optimize a Lax program. This optimization is a one-time cost before deployment on the target hardware.

    The aggressive decomposition allows them to have a clever verification scheme where they validate kernels on random inputs to get confidence in (approximate) numerical correctness.

    They then build a worker kernel with all the relevant operations, and schedule the optimized graph via dedicated scheduler warps. Workers are scheduled on the SMs and report back status. The scheduler warps then decide when tasks can be enqueued for execution.
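
    The scheduler-warp logic is essentially dependency counting. A CPU sketch of the shape of it (mine, single-threaded where the real thing is warps polling events): a task becomes runnable once every prerequisite has reported completion:

```python
from collections import deque

def run_graph(tasks, deps):
    """tasks: name -> callable; deps: name -> list of prerequisite names."""
    remaining = {t: len(deps.get(t, [])) for t in tasks}
    dependents = {}
    for t, ds in deps.items():
        for d in ds:
            dependents.setdefault(d, []).append(t)
    ready = deque(t for t, n in remaining.items() if n == 0)
    order = []
    while ready:
        t = ready.popleft()
        tasks[t]()                       # a "worker" executes the task
        order.append(t)
        for u in dependents.get(t, []):  # worker reports completion...
            remaining[u] -= 1
            if remaining[u] == 0:        # ...scheduler enqueues dependents
                ready.append(u)
    return order
```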

    They’ve got a code example that walks through setting it up for Qwen. They recreate the model structure explicitly, generate a task graph from it, and kick off the search for optimal kernels and fusions. This avoids the need to solve the Dynamo-style problem of tracing the model!

    The resulting kernel is again heavily tied to the specific hardware and model. One thing we have found useful for investigating production problems is the ability to ablate different parts of the compile process, running models in (basically) PyTorch eager mode. This approach leaves the darkest of black boxes to work with, and I would imagine even more terrifying PTX than the complex CUDA that the LLM kernel-generation team came up with.

    Between these projects though, it feels like we are exploring the edges of what running a program on GPUs should actually look like: a combination of kernel generation and multi-level search seems almost guaranteed to yield optimizations that would be far outside the cost-benefit curve for manual implementation. What we don’t have yet is a known way to operationalize this kind of approach, but it’s an exciting area to watch!

    Thanks to Matt for nudging me to actually read through these papers, they’d been on my todolist for a bit!


    1. I am not sure what this stands for, but the basic ops in Jax are in jax.lax, so I presume it’s the same source! ↩︎

  • GPU Driven

    Traditionally, GPU collective network operations were issued from the framework on a separate CUDA stream from the local computation kernel launches. This allowed overlapping comms and hiding most or all of the network latency. NCCL exposes collectives as fully implemented kernels, and there have been various derivatives such as AMD’s RCCL or Berkeley’s new UCCL project, which aims to be a drop-in replacement better suited for large-scale GPU workloads1. Earlier versions of this sent the networking via the host, later developing towards GPU-to-GPU peer-to-peer over connections like NVLink, and direct communication between GPU and NIC for scale-out. But the actual coordination was driven by launches from the CPU.

    The increasing capacity of GPUs, particularly in the Hopper/MI300X era, the use of micro-batching, and the rise of Mixture-of-Experts models put a lot more pressure on this arrangement: rather than a large chunk of comms at the end of each distributed-data-parallel pass, you are doing thousands of small exchanges per step. Each step could require each CPU rank to launch thousands of tiny All-to-Alls: millions of collective calls across the cluster, while still running the rest of the training loop. Each one forces a host interrupt, collective calls and GPU triggers for the data transfer.

    A paper from last year, The Landscape of GPU Driven Communication, gives a broad overview of this shift:

    In the last decade, several advancements, broadly referred to as GPU-centric communication, have sought to challenge the CPU’s hegemony on multi-GPU execution. At a high level, these advancements reduce the CPU’s involvement in the critical path of execution, give the GPU more autonomy in initiating and synchronizing communication and attempt to address the semantic mismatch between multi-GPU communication and computation.

    Getting the right kind of abstractions in here to make this more accessible and flexible is an active area of development. The main reference point is the shmem abstraction, which allows writing and reading bytes in remote memory (originally developed for supercomputers) and adding barriers around usage. Nvidia’s NVSHMEM (and AMD’s rocSHMEM) library implements this directly for GPUs.
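
    The programming model is easy to mimic on the CPU (a toy, mine; a real NVSHMEM put moves bytes over NVLink or RDMA): every rank exposes a symmetric buffer, any rank can put into a peer’s buffer, and a barrier orders the accesses:

```python
class ShmemWorld:
    """Toy symmetric-memory world: one buffer per rank, remote puts, barrier."""
    def __init__(self, n_ranks, buf_len):
        self.bufs = [[0] * buf_len for _ in range(n_ranks)]

    def put(self, dst_rank, index, value):
        self.bufs[dst_rank][index] = value  # write directly into peer memory

    def barrier(self):
        pass  # single-threaded toy: nothing to wait for

# All-gather by hand: every rank writes its value into every peer's slot.
world = ShmemWorld(n_ranks=4, buf_len=4)
for rank in range(4):
    for peer in range(4):
        world.put(peer, rank, rank * 10)
world.barrier()
```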

    The next big step up in functionality was GPUDirect Async, a transport that allows access to NIC doorbells directly from the GPU, extending NVSHMEM to RoCE/InfiniBand (Remote Direct Memory Access, RDMA). In addition, NVLS (NVLink SHARP) was added to NVLink switches, allowing much more bandwidth-efficient switch-managed multicast for the broadcast and reduce cases, the fundamental operations used in all-gather and all-reduce collectives. GPU-initiation of collectives gives us the ingredients to get the CPU out of the networking path completely, and to fuse networking alongside compute operations.

    This was one of the things DeepSeek did really well, covered in their DeepEP work: fusing MoE GEMMs with GPU-initiated RDMA cut single-node latency by 2–3x. The tradeoff is that the kernels handle a lot of complexity: choosing between NVLink and GPUDirect for nodes, polling, flow management and so on is tuned to their specific needs and hardware.

    ByteDance has also spent a lot of time looking at this problem. Their Flux paper looks at a more general approach for doing tile-based fusion of comms and compute:

    Flux overdecomposes computation and communication into tiles. Here, since the computation operation is GEMM, and most high-performance GEMM kernels on GPUs are written with tiling, such as thread block tiling or warp tiling, our decomposition can naturally map into existing tiling in the kernels. Flux fuses dependent communication and/or wait logic into a GEMM kernel, and launches only one fused kernel, compared to the prior methods launching multiple split GEMM kernels.

    PyTorch has an experimental feature and RFC for SymmetricMemory:

    Then, innovative block-wise compute/communication overlapping techniques started to use copy-engine to drive the P2P traffic in order to minimize contention. Now, we are seeing techniques where NVLink communication is directly issued from “compute” kernels.

    […]

     Just as Triton allows average users to modify matmuls for their needs (fusion, quantization, etc.), we hope that SymmetricMemory will enable average users to modify NVLink communication algorithms for their requirements, whether it’s implementing alternate collective algorithms (one-shot allreduce), using different quantization approaches (stochastic rounding), or fusing collectives with other kernels (all-gather interleaved with matmuls).

    This project is still fairly manual, though it abstracts much of the plumbing required and makes it accessible at a high level.
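    To make the “one-shot allreduce” idea from the RFC concrete, here’s a toy, single-process simulation (my own sketch, not the SymmetricMemory API): with symmetric buffers every rank can read every peer’s memory directly, so each rank pulls all the data and reduces locally in one step, with no ring of send/recv rounds.

```python
import numpy as np

# Simulate 4 ranks, each exposing a "symmetric" buffer that every peer can
# read directly (as they could over NVLink). Rank r's buffer holds the value r.
world_size = 4
symmetric_buffers = [np.full(8, rank, dtype=np.float32) for rank in range(world_size)]

def one_shot_allreduce(rank):
    # Each rank reads every peer's buffer and reduces locally, in a single step.
    return sum(symmetric_buffers[peer] for peer in range(world_size))

results = [one_shot_allreduce(r) for r in range(world_size)]
# Every rank ends up with the same reduced vector: 0 + 1 + 2 + 3 = 6.
```

The real thing of course has to worry about synchronization (peers must have finished writing before you read), which is exactly what the barrier primitives around the shmem-style operations provide.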

    In a similar vein, ByteDance recently released their implementation of Triton-Distributed: ByteDance-Seed/Triton-distributed: Distributed Compiler Based on Triton for Parallel Systems. This adds a small Triton DSL wrapping the shmem operations, supporting both Nvidia and AMD hardware. It exposes some of the comms knobs to autotuning, allowing tuning across compute and collectives, focusing on the tile-based approach they documented in their TileLink paper.

    This is unlikely to make its way into upstream Triton, in part because the general approach of Triton is to hide much of the hardware information in the compiler passes, while this approach makes the topology fairly explicit.

    The PyTorch team has been working on traceable collectives for the PyTorch compiler. The idea of a compiler that can look at the overall task graph and decide where to fuse comms and which transport options to use is appealing, as it lets kernels stay agnostic to the specifics of the cluster.

    The move to GPU-initiated, fine-grain comms that fuse into compute kernels is real and continuing; the tooling is still early, but the gap will continue to close.

    1. This is currently host-side controlled, but they have plans for GPU-driven comms as well ↩︎
  • Linear Layouts in Triton

    [2505.23819] Linear Layouts: Robust Code Generation of Efficient Tensor Computation

    Paper from the Triton folks at OpenAI on their solution to the layouts/data movement problem. Data often needs to be laid out in a specific way to maximize performance on a GPU. This includes certain instructions, and also avoidance of bank conflicts in shared memory. You might have data stored nicely in global memory, need to permute it to load, then permute it again for execution.

    Part of the appeal of CuTe is expressing these layouts and providing a relatively simple algebra to transform between these domains. This works, but the Triton approach is to hide this type of complexity, particularly hardware-specific complexity, in the compiler.

    While both CUTE and linear layouts aim to address the challenge of flexible task mapping on emerging architectures, they differ in several key aspects. First and foremost, CUTE is primarily designed for users to manually describe layouts, whereas linear layouts are integrated into a compiler. Second, the linear algebra framework of linear layouts enables compilers to generate efficient code for layout conversion and code lowering for many common operators, which is absent in CUTE. Third, swizzling is inherently defined within linear layouts, whereas in CUTE, it is treated as a separate step

    The clever insight is that you can represent any of the layouts as a binary matrix over F₂, which means you can use XOR/AND for arithmetic. You can compose those binary matrices freely, and it’s also easy to replace the transform matrix with a new one for hardware that requires a different permutation.

    To give a step-by-step example (as I’m not totally sure how well I grok this myself!) let’s say we are working on an MMA for a 16×8 tile:

    • We start with our data, say in row major order (0,0), (0,1), …, (0,7), (1,0). Each value is stored in its own register
    • We have 32 threads, each managing their own section of the block: in this case 4 registers
    • So we have a hardware location for each value: the thread (0..31) and the register (0..3). You can imagine this as 7 bits of data, thread ID (5 bits), and register ID (2 bits)
    • Equivalently we can imagine tracking the tensor location for each value: 4 bits for rows 0..15, 3 bits for columns 0..7
    • We can have a map which translates between tensor location and hardware location: block location row 1 col 0 is in thread 2 register 0. This would be a 7 by 7 binary matrix
    • We can define a matrix that transforms the hardware map to the one needed for our ldmatrix tensorcore call.
    • For example, we might need thread 0 to manage tensor values (0,0), (4,0), (8,0), (12,0)
    • If the mapping requires moving a value to a different register in the same thread we can use a prmt (permute) instruction
    • If the mapping requires moving values between thread’s registers, we can use a warp shuffle like shfl.sync that allows swapping registers between threads without using shared memory1
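    To make the binary-matrix view concrete, here is a toy model (my own sketch, not Triton’s implementation) of a layout as a matrix over F₂: the layout stores the image of each basis bit, applying it to an index is an XOR of the selected columns, and composing two layouts is just a matrix product over F₂:

```python
# A layout is a list L where L[i] is the image of basis vector 2**i,
# encoded as an integer bit pattern. Applying the layout to an index is a
# matrix-vector product over F2: XOR together the columns selected by the
# set bits of the index.
def apply_layout(L, index):
    out = 0
    for i, col in enumerate(L):
        if (index >> i) & 1:
            out ^= col
    return out

def compose(L2, L1):
    # (L2 ∘ L1) is just the matrix product over F2.
    return [apply_layout(L2, col) for col in L1]

# Identity layout on 7 bits (our 16×8 tile: 3 column bits + 4 row bits).
identity = [1 << i for i in range(7)]
# A permutation layout that swaps the 3 column bits with the low 3 row bits.
swap = [1 << 3, 1 << 4, 1 << 5, 1 << 0, 1 << 1, 1 << 2, 1 << 6]

assert apply_layout(identity, 42) == 42
assert compose(swap, swap) == identity  # swapping twice is the identity
```

Swizzle patterns (XORing row bits into column bits to dodge bank conflicts) fit the same representation, which is part of why the paper can treat them uniformly.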

    Triton has layouts for standard block level storage, and for MMAs and other operations. By multiplying through the required mappings it can automatically work out how best to optimize movement, versus the manual transforms you do in CuTe!

    It also has versions of these mappings for different hardware, so for many operations only the layouts need to be swapped out when moving from Ampere to Hopper or Blackwell!

    1. Mostly: if there would be bank conflicts, it will spill to shared memory. ↩︎
  • Analyzing Modern GPU Cores

    [2503.20481] Analyzing Modern NVIDIA GPU cores

    Filing this under interesting work I will probably never use. The authors try to construct a more accurate simulation of the Ampere (3090/A100-type GPUs) microarchitecture, backed by extensive testing on real hardware. It’s a good reminder that, in part because of how good some of the abstractions are, there is quite a lot about Nvidia GPUs that isn’t really known outside Nvidia. My main takeaway was that the compiler is very deeply coupled to the hardware performance: an Nvidia chip is not really a complete unit without taking into account the software driving it, and recognizing that explains why Nvidia has done such a good job of building a solid stack with CUDA.

    One of the things I found interesting was the use of a Stall counter: the compiler notes fixed-latency instructions (which seem to be a preferred design choice) and adds a counter to the instruction’s control bits that specifies how many cycles the warp should wait before issuing its next instruction, during which other warps are selected for execution. This means the hardware doesn’t have to dynamically check for data dependencies.

    For example, an addition whose latency is four cycles and its first consumer is the following instruction encodes a four in the Stall counter. Using the methodology explained in section 3, we have verified that if the Stall counter is not properly set, the result of the program is incorrect since the hardware does not check for RAW hazards, and simply relies on these compiler-set counters. In addition, this mechanism has benefits in terms of area and energy wiring. Keep in mind that wires from fixed-latency units to the dependence handling components are not needed, in contrast to a traditional scoreboard approach where they are required.
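    The mechanism can be sketched as a toy issue model (my own illustration, not the paper’s simulator): each instruction carries a compiler-set stall count, and the scheduler simply waits that many cycles before issuing the warp’s next instruction, with no hardware hazard checks.

```python
def issue_cycles(instrs):
    """Issue each (name, stall) instruction in order; the stall count tells
    the scheduler how many cycles to wait before the warp's next issue."""
    cycle, schedule = 0, []
    for name, stall in instrs:
        schedule.append((cycle, name))
        cycle += max(stall, 1)  # at least one cycle per issue slot
    return schedule

# The add has 4-cycle latency and its result is consumed by the very next
# instruction, so the compiler encodes Stall=4 on it.
prog = [("FADD R0, R1, R2", 4), ("FMUL R3, R0, R0", 1), ("EXIT", 1)]
print(issue_cycles(prog))
# → [(0, 'FADD R0, R1, R2'), (4, 'FMUL R3, R0, R0'), (5, 'EXIT')]
```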

    There are also variable-latency instructions, like memory loads; those instead carry a Dependence counter, which is decremented when the data arrives.

    In the vein of handing off to the compiler, the scheduler uses a Compiler Guided Greedy Then Youngest policy: it will keep issuing instructions from the same warp (greedy), with guidance from the Stall counter (and an explicit Yield bit), and otherwise will switch to the youngest ready warp. Older GPUs (apparently!) used Greedy Then Oldest instead, which more often selected a warp that was still stalled waiting on memory or similar, whereas the youngest warp more likely has useful work to do.

    The scheduler starts issuing instructions from the youngest warp, which is W3, until it misses in the I-cache. As a result of the miss, W3 does not have any valid instruction, so the scheduler switches to issue instructions from W2. W2 hits in the I-cache since it reuses the instructions brought by W3, and when it reaches the point where W3 missed, the miss has already been served, and all remaining instructions are found in the I-cache, so the scheduler greedily issues that warp until the end. Later, the scheduler proceeds to issue instructions from W3 (the youngest warp) until the end, since now all instructions are present in the I-cache.
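    The selection policy itself is simple enough to sketch (my own toy model): stay greedy on the current warp while it can issue, otherwise pick the youngest (most recently launched, here highest-numbered) warp that is ready.

```python
def select_warp(warps, current, stalled):
    """Greedy Then Youngest, as a pure selection function."""
    if current is not None and current not in stalled:
        return current                    # greedy: keep issuing the same warp
    ready = [w for w in warps if w not in stalled]
    return max(ready) if ready else None  # otherwise: youngest ready warp

warps = [0, 1, 2, 3]                                    # W3 is the youngest
assert select_warp(warps, current=3, stalled={3}) == 2  # W3 stalls -> switch to W2
assert select_warp(warps, current=2, stalled={3}) == 2  # stay greedy on W2
assert select_warp(warps, current=2, stalled=set()) == 2
```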

    Similarly, the paper points out that the instruction prefetch cache is a stream buffer (probably 16 instructions deep) rather than any kind of complex branch prediction logic, because we generally don’t do that kind of thing on GPUs!

    a straightforward prefetcher, such as a stream buffer, behaves close to a perfect instruction cache in GPUs. This is because the different warps in each sub-core usually execute the same code region and the code of typical GPGPUs applications do not have a complex control flow, so prefetching N subsequent lines usually performs well. Note that since GPUs do not predict branches, it is not worth implementing a Fetch Directed Instruction prefetcher [76] because it would require the addition of a branch predictor.

  • Keeping a GPU busy is a lot about tiling

    File this under the “gross oversimplifications” category. The basic approach to keeping GPUs busy is dividing the work into tiles, smaller sub-problems that make up the larger result. For a GEMM you might break the matrix into 128×128 or 128×64 tiles and let each CUDA thread block (CTA) own one tile. The GPU has many streaming multiprocessors (an A100 has 108) and every SM picks up one CTA at a time. If you want to know how many SMs your own card has you can call:

    props = torch.cuda.get_device_properties(0)
    print(f"SMs: {props.multi_processor_count}")

    Tiles are launched in waves. A full wave is the moment when every SM is busy with exactly one CTA. If the total number of tiles isn’t a multiple of the SM count, the final wave is only partly full and some SMs sit idle; Nvidia calls that wave quantization. There is a similar problem at the edge of the matrix: if the dimensions aren’t multiples of the tile size the right-most or bottom-most tiles are partly empty, wasting threads (tile quantization). Sometimes a smaller tile size (for example 64 × 64) gives higher overall throughput because it leaves less unused space at the edges.
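    The wave arithmetic is easy to check by hand. Here is a hypothetical helper (my own sketch, not any library’s API) for a 4096×4096 output on an A100’s 108 SMs:

```python
import math

def wave_stats(M, N, tile_m, tile_n, num_sms=108):  # 108 SMs on an A100
    tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    full_waves, tail = divmod(tiles, num_sms)
    waves = full_waves + (1 if tail else 0)
    # Fraction of SM-slots doing useful work across all waves.
    utilization = tiles / (waves * num_sms)
    return tiles, waves, utilization

tiles, waves, util = wave_stats(4096, 4096, 128, 128)
print(tiles, waves, f"{util:.0%}")  # → 1024 10 95%
```

The 52-tile tail wave here leaves half the SMs idle at the end, which is exactly the effect a persistent kernel smooths over.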

    The usual cure for poor wave utilization is a persistent kernel. Instead of launching one CTA per tile, you launch (roughly) one CTA per SM and have each CTA pull tiles from a global queue until the queue is empty. Because each CTA pulls a new tile as soon as it’s ready, the SMs rarely go idle and the tail effect is reduced.

    Inside an SM the main performance lever for GEMMs is the Tensor Core, which executes matrix multiply-accumulate (MMA) instructions efficiently. On Ampere you use WMMA instructions: one warp (32 threads) computes a 16 × 16 fragment at a time. Hopper introduces WGMMA instructions where four warps acting as a warp-group (128 threads) execute a larger matrix multiply (up to 64 × 256 for FP16/FP8). To issue WGMMA you must place the right-hand operand B in shared memory; A can sit in either registers or shared memory. The operation is asynchronous, so while a warp-group is processing one tile the same CTA can be pre-loading the next tile.

    Blackwell pushes the idea further. A pair of CTAs on neighbouring SMs can cooperate on a single, paired MMA, letting two SMs’ tensor cores process an even larger tile.

    To make that possible Hopper introduced thread-block clusters, and Blackwell extends them. When you launch a kernel you can group CTAs into clusters such that the scheduler guarantees to place them on SMs inside the same GPC (Graphics Processing Cluster), so they share a fast interconnect and can access shared memory across SMs. If the grid doesn’t divide cleanly into whole clusters you also lower efficiency on the tail (is this cluster quantization? stick with the trend Nvidia!) so Blackwell has a Cluster Launch Control that can shrink the last cluster to better fit the work.

    Loading Data

    All of this only works if data is present in shared memory. The first optimization is making sure (global) memory accesses are coalesced. Each thread in a 32-thread warp issues its own request, but a single fetch from DRAM is much wider. For example, if the 32 threads request consecutive 4-byte words (addresses 0, 4, 8, …, 124), the memory controller can coalesce these into a single 128-byte read. If the addresses are strided (e.g. hopping across rows), only a fraction of each 128-byte fetch is useful, so the load takes far longer. Getting this right is about ensuring the memory layout is set up for the kernel, and doing any transforms needed in shared memory before executing.
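    A toy transaction counter (my own sketch; the real memory controller is more subtle) shows why strided access hurts: just count how many 128-byte segments one warp’s addresses touch.

```python
def transactions(addresses, segment=128):
    # One memory transaction per distinct 128-byte segment touched.
    return len({addr // segment for addr in addresses})

warp = range(32)
contiguous = [4 * t for t in warp]         # thread t loads 4-byte word t
strided    = [4 * 1024 * t for t in warp]  # thread t loads the start of row t

print(transactions(contiguous))  # → 1 transaction for the whole warp
print(transactions(strided))     # → 32 transactions
```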

    In older GPUs the warp had to wait on the copy operation. Ampere enabled cp.async plus non-blocking wait/arrive barriers so a warp can initiate a copy from global to shared memory and immediately continue with arithmetic. Hopper adds the Tensor Memory Accelerator: with TMA, a single thread in the CTA can describe a multidimensional block to copy and the TMA hardware streams it to shared memory while the threads do something else. Blackwell goes one step further and can multicast a single TMA load into every SM of a cluster, which is helpful when multiple CTAs are about to reuse the same B tile.

    In practice you hide latency by organizing the main loop so that it double-buffers: while the tensor cores work on tile k, the TMA or cp.async engine is fetching tile k + 1 into the other half of shared memory; then you swap buffers and repeat. As long as copy time and compute time overlap well, the tensor cores and the copy engines stay saturated.
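    As a sketch of the resulting schedule (toy code, not a real kernel), the double-buffered loop interleaves “copy tile k + 1” with “compute tile k”:

```python
def pipeline_schedule(num_tiles, num_buffers=2):
    """Emit the (action, tile, buffer) sequence for a double-buffered loop."""
    events = [("copy", 0, 0)]  # prologue: prefetch tile 0
    for k in range(num_tiles):
        if k + 1 < num_tiles:
            # Issue the async copy of the next tile into the other buffer...
            events.append(("copy", k + 1, (k + 1) % num_buffers))
        # ...while the tensor cores consume the current one.
        events.append(("compute", k, k % num_buffers))
    return events

for event in pipeline_schedule(4):
    print(event)
```

Deeper pipelines (num_stages in Triton) are the same idea with more buffers in flight, trading shared memory for more latency-hiding slack.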

    Choosing the right tile size

    Choosing the right tile size (often expressed in Triton as BLOCK_M × BLOCK_N) is a balance between all of these: enough threads to issue a warp-group MMA, tiles small enough that the matrix edges aren’t mostly padding, enough shared-memory space to double-buffer, and a grid size that fills whole waves or is run via a persistent kernel. Autotuning in Triton or CUTLASS can empirically test different options on the hardware, but it helps to have the right mental model about which sets of sizes they should consider. One good clue that you’re missing an option is a sudden drop in achieved TFLOP/s for a particular shape.
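    One part of that mental model is the shared-memory budget. An illustrative back-of-envelope calculation (my own sketch, not any library’s actual heuristic) for a double-buffered FP16 GEMM tile:

```python
def smem_bytes(block_m, block_n, block_k, stages=2, dtype_bytes=2):
    # Per-stage A and B input tiles, times the number of pipeline stages.
    a_tile = block_m * block_k * dtype_bytes
    b_tile = block_k * block_n * dtype_bytes
    return stages * (a_tile + b_tile)

# A 128×256 tile with BLOCK_K=64 in FP16, double buffered:
print(smem_bytes(128, 256, 64))  # → 98304 bytes
```

That fits Hopper’s shared memory per SM with room to spare; bump the stages or the tile and you can quickly exceed the budget, which caps occupancy or rules the config out entirely.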

    AMD

    AMD’s MI300X hardware takes a somewhat different route. The GPU is divided into chiplets, where each chiplet has its own compute units and multiple schedulers that schedule wavefronts (AMD’s equivalent of warps, with 64 threads rather than 32) independently, so the hardware load-balances multiple kernels by itself. Matrix instructions run at the wavefront level; there is no cross-CU equivalent to WGMMA. Latency hiding relies on launching a large grid of workgroups and letting the hardware interleave them, rather than on explicitly scheduling async copies. On AMD the guidance is to mostly focus on high occupancy and coalesced memory access, whereas on Nvidia there is value in crafting (by hand or compiler) the copy–compute pipeline.

  • How does Triton do Warp Spec?

    Kapil Sharma from the PyTorch team has a great series of posts diving into the Triton compiler process: 1, 2, 3. As covered there, Triton lowers to a series of intermediate representations, and each level has a set of transformational passes that implement optimizations. TTIR is the generic Triton IR and leverages a number of standard MLIR passes like common subexpression elimination, as well as some Triton specific passes like managing broadcast ops. That’s then lowered to TTGIR, a GPU-specific IR1

    triton/third_party/nvidia/backend/compiler.py at rc/3.3.x · triton-lang/triton

    The different backends configure the passes appropriate for their targets, so the Nvidia TTGIR configuration above details the passes for Nvidia hardware. Some are gated on the specific backend targeted, like warp specialization:

    passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
                                         opt.reg_dec_producer, opt.reg_inc_consumer)
    passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
    passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)

    Another example is using the Tensor Memory Accelerator for async loading on Hopper+:

    if capability // 10 >= 9:
       nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
       nvidia.passes.ttnvgpuir.add_fence_insertion(pm)

    For a quick recap of the why and how of Warp Specialization, check Colfax’s guide to optimizing a GEMM:

    The most basic method by which GPU programmers can create overlapping is via excess warps (warps are groupings of 32 contiguous threads). Nvidia GPUs allow a large number of warps per SM (streaming multiprocessors), and can switch between them with minimal overhead. In particular, the warp schedulers can simply switch to another warp if one warp encounters a slow memory fetch. In order to give the warp schedulers more opportunity to hide latency, a technique called warp-specialization was introduced circa 2011 [1, 2]. With warp-specialization, some warps are dedicated to memory fetches (producers), while others are dedicated to compute (consumers), and named barriers are used for synchronization between them. The idea is that the warp schedulers can then more easily hide the latency of copy operations within compute (and vice-versa).

    Even more generally than overlapping memory loads, you can overlap other kinds of work. SMs have 1 Tensor Core per 32 ALUs (CUDA cores), with 4 such partitions on recent hardware. This means you can overlap work that isn’t dot products: it’s really common to want to load memory, do a matmul, then apply a pointwise function like a ReLU or other activation function. You aim to keep the Tensor Cores as busy as possible with a series of matmuls, and warp specialization lets you do that.

    The transforms to implement this are implemented in triton/lib/Dialect/TritonGPU/Transforms at rc/3.3.x · triton-lang/triton

    The first pass partitions ops in the kernel: it looks for load ops and dot-product ops, and partitions them into producer (loads) and consumer (compute) groups.

    // Step 1. Select loads into the first task, which is the producer task by
    // default. Place dots into the second task, which is the consumer.
    // Only consider loads that are connected to a dot op in a loop.

    The next pass is a bookkeeping one that propagates the task IDs, so any unlabeled ops are attached to one of the partitions.

    The WSDataPartition transform partitions the dot operations into consumer groups. It splits the dot products inside loops along M or N dimensions to be processable within a warp group, ensuring all dot operations are sliced and the slices labelled with task_ids.

    Just to look at some of the numbers: a warp consists of 32 threads, and a sync op (for loading) uses a whole warp. A warp group is a set of 4 warps: this is pertinent because the Tensor Core MMA prefers 4 warps working on 64×(64/128/256)×K tiles. Triton already has a WarpGroupDotOp that tries to set this up, and that’s one of the operations targeted in this pass. The pass splits a Triton CTA tile, which may be 128×256, so that each consumer warp group gets a 64-row (or 256-column) chunk.

    if (sliceSizeM >= 64) {
      LLVM_DEBUG({ LDBG("partition along M\n"); });
      partitionDim = 0;
      partitionSize = sliceSizeM;
      partitionOperand = opndA;
    } else if (sliceSizeN >= 256) {
      LLVM_DEBUG({ LDBG("partition along N\n"); });
      partitionDim = 1;
      partitionSize = sliceSizeN;
      partitionOperand = opndB;
    } else {
      LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN);
      return false;
    }

    The next pass is WSCodePartition. This is a big transform. It takes the task-sliced IR from the data partition and sets up the producer warp group to copy from global GPU memory to SMEM (or, on Blackwell, TMEM). It also drops in barriers using producer.acquire/commit and consumer.wait/release to ensure proper ordering between the groups. The transform identifies “channels”: places where data is loaded (using load, or descriptor_load for TMA on Hopper+), and associates the producer task_id with all the consumer task IDs that need to process that data.

    Conceptually, the transform is turning the loops in the original kernel into something like this:

    for k in range(num_k_tiles):  # loop over K-dimension tiles
        ##### PRODUCER #####
        producer.acquire(token, idx, phase)   # reserve smem[idx]
        async_copy_global_to_local(smem[idx]) # GMEM → SMEM[idx]
        producer.commit(token, idx)           # make slot visible
    
        ##### CONSUMERS (run in parallel warps) #####
        ## Consumer 0 ##
        consumer.wait(token, idx, phase)      # sleep until slot ready
        mma_sync(accum0, smem[idx], …)        # read-only, do matmul
        consumer.release(token, idx)
    
        ## Consumer 1 ##
        consumer.wait(token, idx, phase)
        mma_sync(accum1, smem[idx], …)
        consumer.release(token, idx)
    
        # Repeat for extra consumers.
    
        # advance the circular buffer
        idx = (idx + 1) % numBuffers
        # toggle the phase each time the buffer wraps, to distinguish old vs new data
        phase = phase ^ (idx == 0)

    The next pass is a generic pipeline pass. Each op is assigned a stage based on latency: e.g. slower ops like loads go into stage 0. The loop is then transformed with modulo scheduling. Any sync ops are converted to async variants (in lowerLoops), before prologue, kernel, and epilogue loops containing all the relevant ops are written out (in expandLoops).

    Finally the WSLowering pass takes the various operators we have (like producer.acquire) and replaces them with the hardware specific variants (e.g. wait_barrier). It also handles the bookkeeping like generating the warp group and task_ids from the warp ID:

    OpBuilder builder(op);
    Value _4 = builder.create<arith::ConstantIntOp>(loc, WARPS_PER_TASK, 32);
    Value warpId = builder.create<ttng::GetCanonicalWarpIdOp>(loc);
    Value asyncTaskId = builder.create<arith::DivUIOp>(loc, warpId, _4);
    op.getResult().replaceAllUsesWith(asyncTaskId);

    This is a wordy IR way of saying

    warpId = gpu.thread_id().x / 32 
    task_id = warpId / 4

    Now the code is ready for lowering to regular PTX, and all of the warp-specific stuff is captured explicitly!

    1. Note: Because of some project weirdness, warp specialization is quite different in the release branches from main, so I’ll refer to 3.3 from here on. It’s in very active development (by teams at Meta, OpenAI and Nvidia!) so the specifics are quite likely to change over coming releases! ↩︎
  • Autotuning in PyTorch & Triton

    torch.compile offers some knobs for controlling the trade-off of execution performance with longer compile times. This is particularly useful for inference, where the same model will be running for a long time.

    model_autotune = torch.compile(model, mode="max-autotune")

    Passing the max-autotune mode instructs the compiler to test more options for each operation. The compiler can use pre-built ATen kernels, leverage kernels from libraries like cuDNN or CUTLASS, or use templated Triton kernels. When autotuning, specific variants are tested on-device with the shape information identified during tracing, and the fastest options are selected. Thanks to Triton templates, it can also apply fusions, where pointwise ops are folded into a single kernel, saving kernel launch overhead.

    The downside is that testing the options takes more time, so using max-autotune can lead to very extended compile times. You also need a hefty enough GPU to get the benefit: is_big_gpu gates it on the number of SMs, so it works best on a 3090, V100 or above.

    You can see a lot of the autotuning options in _inductor/config.py. Backends that are considered are set separately for GEMMs and convolution ops:

    max_autotune_gemm_backends = os.environ.get(
        "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
    ).upper()

    Each kernel has implementations for the different backends, which are added to the list of possible choices. E.g. in _inductor/kernel/mm.py you can see calls to use_[backend]_template that verify whether the backend in question is a viable choice:

    if is_nonzero and use_cutlass_template(layout, m, n, k):
            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    _inductor/select_algorithm.py does the actual benchmarking through the choices.

    If you run autotuning, you’ll get some log output, and caches will be written to /tmp/torchinductor_yourusername.

    We can try this out on a simple MLP:

    import torch, time
    
    class SimpleMLP(torch.nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, hidden_features)
            self.relu = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(hidden_features, out_features)
        def forward(self, x):
            return self.linear2(self.relu(self.linear1(x)))
    
    # Set up device and model
    device = 'cuda'
    model = SimpleMLP(in_features=1024, hidden_features=1024, out_features=1024).to(device)
    x = torch.randn(256, 1024, device=device)  # batch of 256, 1024 features each
    
    # Compile the model in default mode and max-autotune mode
    model_default = torch.compile(model, mode="default")
    
    # Warm-up runs (to trigger compilation)
    torch.compiler.reset()
    with torch.no_grad():
        model_default(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of default compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_default(x)
        end.record()
    torch.cuda.synchronize()
    time_default_ms = start.elapsed_time(end) / 50.0
    torch.compiler.reset()
    
    model_autotune = torch.compile(model, mode="max-autotune")
    
    with torch.no_grad():
        model_autotune(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of max-autotune compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_autotune(x)
        end.record()
    torch.cuda.synchronize()
    time_autotune_ms = start.elapsed_time(end) / 50.0
    
    print(f"Average inference time - torch.compile default: {time_default_ms:.3f} ms")
    print(f"Average inference time - torch.compile max-autotune: {time_autotune_ms:.3f} ms")
    

    Disappointingly, this is the result:

    Average inference time - torch.compile default: 0.113 ms
    Average inference time - torch.compile max-autotune: 3.251 ms

    We can turn on logging with the TORCH_LOGS env variable: some useful options are inductor, autotuning, and perf_hints.

    TORCH_LOGS="perf_hints" python tune.py

    You can control many more autotune options via the options flags, though it’s incompatible with passing a mode value. We can recreate the max-autotune mode and turn on some useful tracing options like this (note that the options version uses an underscore, the mode a hyphen!):

    model_autotune = torch.compile(
        model,
        options={
            "max_autotune": True,
            "triton.cudagraphs": True,
            "coordinate_descent_tuning": True,
            "trace.enabled": True,
            "trace.graph_diagram": True,
        },
    )

    The options "trace.enabled": True and "trace.graph_diagram": True generate trace outputs, including a nice diagram of the captured graph. Cudagraphs turned out to be the culprit here, which is common enough that there is a no-cudagraphs mode available to save you having to remember all the options:

    model_autotune = torch.compile(model, mode="max-autotune-no-cudagraphs")

    As you can see in the captured graphs with and without cudagraphs, the slower version actually has an extra fusion performed!

    Captured graphs for the two runs

    Triton Autotuning

    Triton also conducts autotuning, but it’s a little more explicit. When authoring a Triton kernel you can specify configurations. At compile time each config variant will be tested, the most performant one picked and the choice stored for future calls. A key value can be provided to indicate when to re-autotune based on changing inputs:

    import os
    import torch
    import triton
    import triton.language as tl
    
    # Just to save passing this on the command line
    os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  
    
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 128}, num_warps=4,  num_stages=2),
            triton.Config({'BLOCK_SIZE': 256}, num_warps=8,  num_stages=2),
        ],
        key=['N']            # re‑tune only if the length N changes
    )
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        pid   = tl.program_id(0)
        offs  = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        x     = tl.load(x_ptr  + offs, mask=mask, other=0.0)
        y     = tl.load(y_ptr  + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, x + y, mask=mask)
    
    def vec_add(x: torch.Tensor, y: torch.Tensor):
        assert x.is_cuda and y.is_cuda
        N   = x.numel()
        out = torch.empty_like(x)
        grid = (triton.cdiv(N, 128),)         # 128 = smallest BLOCK_SIZE we declared
        vecadd_kernel[grid](x, y, out, N)    
        return out
    
    x = torch.randn(1 << 20, device="cuda")   # 1 048 576 elements
    y = torch.randn_like(x)
    
    _ = vec_add(x, y)  # first call → autotuning prints to stdout
    _ = vec_add(x, y)  # second call → no autotuning, uses the best config found
    

    Setting the env variable TRITON_PRINT_AUTOTUNING documents the process as it goes:

    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 256, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Triton autotuning for function vecadd_kernel finished after 0.44s; best config selected: BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;

    You can use the same do_bench tester that the autotuner does, and see how the performance varies yourself:

    import torch, triton, triton.testing as tt
    import triton.language as tl
    
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        offs  = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        tl.store(out_ptr + offs,
                 tl.load(x_ptr + offs, mask=mask) +
                 tl.load(y_ptr + offs, mask=mask),
                 mask=mask)
    
    # tensors
    N   = 1 << 20
    x   = torch.randn(N, device='cuda')
    y   = torch.randn_like(x)
    out = torch.empty_like(x)
    
    def bench(block_size, num_warps):
        grid = (triton.cdiv(N, block_size),)
        # tt.do_bench with quantiles returns [p50, p20, p80] timings in milliseconds
        return tt.do_bench(
            lambda: vecadd_kernel[grid](x, y, out, N, BLOCK_SIZE=block_size, num_warps=num_warps),
            warmup=5, rep=16, return_mode="all", quantiles=(0.5, 0.2, 0.8)
        )
    timings = {
        "128/4": bench(128, 4),
        "256/8": bench(256, 8),
    }
    
    print("timings:", timings)
    

    Running that shows that both configurations are basically equivalent, with the first slightly faster over the 16 runs.

    timings: {'128/4': [0.01945599913597107, 0.01945599913597107, 0.02028159946203232], '256/8': [0.01945599913597107, 0.01945599913597107, 0.020479999482631683]}