Category: note to self

References and aide memoires

  • Rubrics

    Rubrics

    Pre-training is about making AI correct, post-training is about making AI helpful1. That helpfulness is (primarily) shaped by reinforcement learning. RL for LLMs really took off with RLHF (RL from Human Feedback), which trained based on the score from a reward model.

    The reward model was designed to score responses based on how well they met certain preferences, and the preferences were inferred from a set of human ratings: the graders were told what to look for in pairs of responses, and the reward model was trained to predict what they would pick. This worked, but was gated on how much signal you could get into the reward model and hence how many humans you had to generate preference data.

    RLAIF (RL from AI Feedback) naturally extended this to using an LLM to make the preference picks rather than humans2. Folks also started to use LLMs in an LLM-as-Judge pattern for evaluation after training: give the model a list of criteria, and ask it to rate how well the responses meet them. 

    The next notable step was RLVR (RL with Verifiable Rewards), which uses ground-truth data to provide reward scores instead of a model. For example, a math problem might have a defined numeric answer, or a generated proof could be verified by a dedicated theorem prover program. This turned out to work very well for code and math and led to the O-series of OpenAI models3 and many open reasoners, particularly DeepSeek R1.
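
    As a minimal illustration (my own sketch, not from any particular paper), a verifiable reward for a math problem can be as simple as an exact-match check against the ground truth:

    def verifiable_reward(model_answer: str, ground_truth: str) -> float:
        # Normalize and compare against the known-correct answer. Real pipelines
        # extract the final answer first (e.g. from a \boxed{...} span) and may
        # call out to a checker such as a unit test runner or theorem prover.
        return 1.0 if model_answer.strip() == ground_truth.strip() else 0.0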

    It’s a pretty natural idea to take a verifiable reward pipeline and plug in AI scoring directly: rather than having a model generate preference pairs and training a separate reward model, give the model criteria and ask it how well the response satisfies them. This means instead of letting a model work out what “good code” looks like from pairs of different (but similar!) solutions to a problem, you have a model working through a checklist, asking things like “Does it have types? Does it have comments? Would your coworkers hate you if you landed this?”

    These checklists are referred to as rubrics, and Snorkel have started an interesting-looking blog series introducing rubrics, which offers a definition:

    A rubric is a structured guide that spells out what “good” looks like for each response from an AI system. 

    A rubric consists of:

    • A list of criteria: Does the code compile? Does it have comments?
    • How the model performed on each criterion: “Compiles” could be yes/no. It could also be more nuanced: yes/yes with warnings/no.
    • Scoring rules that turn performance into numbers: Clean = 0. Warnings = 1. No = 2.
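
    A rough sketch of that structure in code (my own illustration, not from the Snorkel series; lower scores are better, following the example above):

    # Hypothetical rubric: each criterion maps an observed performance level to a score.
    rubric = {
        "compiles": {"yes": 0, "yes with warnings": 1, "no": 2},
        "has_comments": {"yes": 0, "no": 2},
    }

    def score_response(performance: dict[str, str]) -> int:
        # `performance` records how the response did on each criterion,
        # e.g. {"compiles": "yes with warnings", "has_comments": "yes"}.
        return sum(rubric[criterion][level] for criterion, level in performance.items())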

    In Nathan Lambert’s recent interview with Ross Taylor, Taylor calls rubrics out as an underappreciated research opportunity, particularly for agentic training:

    Rubrics are underhyped on social media – they were [the] driving force behind projects like DeepResearch – and GenRMs are interesting but perhaps slightly overhyped.

    This caught my eye, as Moonshot leveraged rubric-based rewards heavily in Kimi K2, notably using the model they were training as the judge of itself:

    The framework operates using a Self-Critique Rubric Reward mechanism, where the model evaluates its own outputs to generate preference signals. To bootstrap K2 as a competent judge, we curated a mixture of open-source and in-house preference datasets and initialize its critic capability in the SFT stage.

    One of the core values of rubrics is that they work for both LLMs and humans. You can iterate on rubrics with people, scale them with LLMs, and spot-check LLM results with human raters to ensure reliability. 

    The paper [2507.17746] Rubrics as Rewards: Reinforcement Learning Beyond Verifiable Domains formalizes them as a full peer to Verifiable Rewards. The paper sets up rubrics so each criterion is a simple pass/fail check with a predefined importance weight. They normalize everything so the system can’t get gamed by just adding more criteria4, and then plug the resulting score into the RL loop5.
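
    A minimal sketch of that reward computation (my reading of the setup, with hypothetical criteria and weights):

    # Each rubric item is a pass/fail check with a predefined importance weight.
    criteria = [
        {"desc": "States the final answer explicitly", "weight": 3.0, "passed": True},
        {"desc": "Shows the intermediate steps",       "weight": 2.0, "passed": True},
        {"desc": "Avoids irrelevant tangents",         "weight": 1.0, "passed": False},
    ]

    # Normalizing by the total weight means adding more criteria can't inflate the reward.
    reward = sum(c["weight"] for c in criteria if c["passed"]) / sum(c["weight"] for c in criteria)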

    Of course, you actually have to write the rubrics, which leads to a specificity versus generality tradeoff: take more time to write more rubrics or rely on fewer, more general ones. The RaR paper makes it clear that more is better:

    predefined generic rubrics substantially underperform compared to prompt-specific ones, underscoring the importance of contextualization. Rubrics that include a broader range of criteria—both positive and negative—consistently outperform those limited to essential checks, suggesting that richer evaluation signals lead to better learning.

    As you might have guessed, the solution was more LLM: use a model to generate prompt-specific rubrics:  

    For each domain, the prompt (included in Appendix H) instructs the LLM to generate 7–20 rubric items based on the complexity of the input question. Each item is assigned a categorical weight (e.g., Essential Criteria, Important Criteria) to determine its importance to a correct answer. The rubrics are designed to be fully self-contained which means that non-expert readers should be able to evaluate response quality using only the rubric. 

    This particularly benefited from having a reference answer attached to the prompt. The models do a much better job of coming up with a good rubric if provided with a (human generated) “good” answer to judge against rather than just the question/prompt. This really opens the door to 1:1 rubrics: given questions and reference answers, you can generate a scoring checklist for each one and mix it with verifiable rewards during post-training. 

    The field continues to be turtles all the way down: using LLMs to write rubrics to have LLM judges evaluate LLM training outputs. At some point, someone’s going to suggest we use rubrics to evaluate how good our rubrics are, and honestly, I’m surprised that paper doesn’t already exist6.

    1. Correct in predicting the next token, and helpful, honest and harmless, specifically. ↩︎
    2. With humans still looped in to validate that the ratings were reasonable. The human graders went from generating ratings to rating the raters. ↩︎
    3. This is the part where everyone pretends they know exactly how O1 works, but actually we’re all just pattern-matching from breadcrumbs ↩︎
    4. Else we’d risk giving more focus to problems with more rubrics, and end up with something unthinkable like a coding model that liberally sprinkles emojis everywhere ↩︎
    5. In practice, they also tried a single LLM judge that took in all criteria and weights and generated a scalar reward, which seemed to work fine. ↩︎
    6. It probably does, I’m just scared to look ↩︎
  • Reinforcement Learning Continues To Be The Frontier

    Back in 2021, OpenAI nixed its robotics team, leading to comments on Hacker News like “Reinforcement learning itself is a dead-end on a road to AI”. Now, in 2025 we are surrounded by RL post-trained reasoning models and Mary Meeker is using the word “unprecedented” a lot. This kind of skepticism/hype overlap is very common right now, as Helen Toner breaks down in her excellent recent post/talk on unresolved questions in AI:

    Last year, we had coverage from the Wall Street Journal—really good reporting—about real challenges inside OpenAI with scaling up their pre-trained models and how difficult that was and how they weren’t happy with the results, and then on the literal same day we had the release of o3, the next generation of their reasoning model, and François Chollet—who’s famously skeptical—saying that it was a significant breakthrough on his ARC-AGI benchmark. So these very contradictory takes, both of which had some truth to them.

    The framing used in that post is really useful: it’s less about “are we making progress?” and more “are we on the right branch of the tech tree?”

    A lot of people thought RL was the wrong branch: after notable successes from DeepMind and OpenAI, RL had become a bit of a backwater, with some resurgence (in a limited form) from Reinforcement Learning with Human Feedback (RLHF) for preference tuning LLMs.

    The reason people keep coming back to reinforcement learning is the ability to discover new things. Supervised learning is somewhat inherently bound by the dataset. A reinforcement process can continue to explore and find new strategies, like the famous examples of AlphaGo choosing moves humans wouldn’t have. Tim Lee has an excellent non-technical introduction to the evolution of RL that mentions this: Reinforcement Learning Explained

    In short, imitation learning can rapidly teach a model to mimic the behaviors in its training data, but the model will easily get confused in unfamiliar environments. A model trained with reinforcement learning has a better chance of learning general principles that will be relevant in new and unfamiliar situations

    In this direction, a recent paper, [2507.00432] Does Math Reasoning Improve General LLM Capabilities? Understanding Transferability of LLM Reasoning, suggests1 that reasoning generalizes better from RL-driven learning than supervised fine-tuning.

    RL-tuned models achieve significant gains on math reasoning while preserving positive transfer to other reasoning tasks and non-reasoning tasks, whereas SFT often incurs negative transfer on non-reasoning benchmarks. Second, PCA analysis of latent space confirms that RL induces minimal drift from backbone representations thus maintaining feature stability, while SFT produces larger latent shifts, especially in non-reasoning domains. Third, token-distribution analysis shows that RL selectively adjusts only a handful of task-relevant tokens, whereas SFT perturbs many irrelevant tokens, indicating RL’s more targeted optimization.

    RLHF is implemented by first training a reward model based on human preference feedback: you give people two versions of an answer, they tell you which one they prefer, you then train a model to predict those ratings. That reward model becomes the scoring function during post-training.
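
    The reward model itself is typically trained with a pairwise (Bradley–Terry style) loss; a minimal sketch with random stand-in scores:

    import torch
    import torch.nn.functional as F

    # r_chosen / r_rejected: scalar scores the reward model assigns to the preferred
    # and non-preferred response in each pair (random stand-ins here).
    r_chosen = torch.randn(8, requires_grad=True)
    r_rejected = torch.randn(8, requires_grad=True)

    # Push the chosen score above the rejected score for each pair.
    loss = -F.logsigmoid(r_chosen - r_rejected).mean()
    loss.backward()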

    Designing good reward functions has been somewhat of a dark art in RL. The agent optimizes what you ask for, which is not always what you really want2. This “reward hacking” phenomenon makes RL agents somewhat brittle, prone to exploiting loopholes in environments in ways no one anticipated.

    The recent reasoning models did so well because their rewards were verifiable: reward scores based on some ground-truth validation, often just yes/no (does the code compile, does it pass a unit test, can a math proof be verified by a formal logic reasoner, or simply is the answer correct or not). Nathan Lambert did a breakdown on where RL goes next:

    The optimistic case for scaling current reinforcement learning with verifiable rewards (RLVR) techniques to next-generation language models, and maybe AGI or ASI depending on your religion, rests entirely on RL being able to learn on ever harder tasks. Where current methods are generating 10K-100K tokens per answer for math or code problems during training, the sort of problems people discuss applying next generation RL training to would be 1M-100M tokens per answer.

    Lambert makes the point that even the very long-range tasks we have now (coding agents, deep research) are based around learning to be better at tasks individually, then stringing those together:

    How to read this training method, which is likely similar for agents like Claude Code or Codex, is that current RL methods are helping the models get more robust at individual tasks that make up a longer trajectory rather than being trained on the end result of the trajectory itself. The final long-horizon behavior is put together with prompting and letting the model run longer, not sparse credit assignment. In the case of Deep Research the final measure of performance would actually look far closer to human preferences than verifiable rewards, and a large portion of that applies for Claude Code as well, where multiple solutions could solve a problem and it falls to human taste to say which is the best.

    This problem of having to learn to act over a long time-horizon is a recurring one in RL. The best algorithms we have for reinforcement learning are online: the model learns “live” while interacting with the environment. But sometimes it’s a lot easier to collect data than it is to run an experiment: for example, it’s much safer to get a large amount of sensor input from driving cars around than it is to have a model driving a real car around and making mistakes. This is offline RL (the extreme case of off-policy learning), and it offers the promise of learning from much larger data sets.

    Seohong Park recently wrote a great post breaking down how offline RL fails to scale up: Q-Learning Is Not Yet Scalable3. In the experiments there, the team at Berkeley generated 1,000x more data to try to scale offline RL, and still saw the process break down:

    Q-learning struggles to scale because the prediction targets are biased, and these biases accumulate over the horizon. The presence of bias accumulation is a fundamental limitation that is unique to Q-learning (TD learning). For example, there are no biases in prediction targets in other scalable objectives (e.g., next-token prediction, denoising diffusion, contrastive learning, etc.) or at least these biases do not accumulate over the horizon (e.g., BYOL, DINO, etc.).
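
    For reference, the bootstrapped target the quote refers to is the one in the classic tabular Q-learning update (a toy sketch, not the paper's setup):

    import numpy as np

    n_states, n_actions = 10, 4
    Q = np.zeros((n_states, n_actions))
    alpha, gamma = 0.1, 0.99

    def td_update(s, a, r, s_next):
        # The target bootstraps off Q itself (the max over next actions), so any
        # bias in that estimate feeds back into future targets; over long horizons
        # those biases accumulate.
        target = r + gamma * Q[s_next].max()
        Q[s, a] += alpha * (target - Q[s, a])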

    Noted LLM-branch skeptic (and technically a very distant colleague) Yann LeCun has spoken a lot about a version of this kind of planning and world modelling problem, which he sees as inherent to the autoregressive nature of LLMs: the accumulation of errors over long time horizons.

    One of his architectural bets is JEPA, and the recently released V-JEPA 2 paper is beginning to show how this could work. V-JEPA 2 is a self-supervised video world model trained on a million hours of YouTube video. The model learns by masking out parts of video frames and predicting them, in latent (embedding) space rather than pixel space. After the pre-training, they freeze the encoder, generate tokens with it for a video and prepend those to a query for a pretrained LLM4. They fine-tune that LLM on video question answering data, and were able to get state-of-the-art question answering with that setup, despite the JEPA part of it being totally task agnostic.
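
    Roughly, that alignment stage looks like the sketch below: freeze the video encoder, project its output tokens into the LLM's embedding space, and prepend them to the text tokens (toy stand-in modules, not the actual V-JEPA 2 code):

    import torch
    import torch.nn as nn

    hidden = 512
    video_encoder = nn.Linear(1024, 768)     # stand-in for the frozen JEPA encoder
    projector = nn.Linear(768, hidden)       # maps encoder tokens into the LLM hidden size
    llm_embed = nn.Embedding(32000, hidden)  # stand-in for the LLM's token embeddings

    for p in video_encoder.parameters():
        p.requires_grad = False  # the encoder stays frozen; only the LLM side is fine-tuned

    video_feats = video_encoder(torch.randn(1, 64, 1024))  # (batch, video tokens, dim)
    text_ids = torch.randint(0, 32000, (1, 16))            # question tokens
    inputs = torch.cat([projector(video_feats), llm_embed(text_ids)], dim=1)
    # `inputs` would then be fed to the LLM in place of its usual input embeddings.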

    Going a step further, they took the encoder and hooked it up to a small robot control model5. They trained it on some robot operator data for pick-and-place tasks. It learned to do a remarkably good job, without any reinforcement learning at all!

    This is interesting because robotics has traditionally been an area where we have seen a lot of exploration (with success and disappointment!) with long-range RL. Andrew Stokols’ excellent post on ChinaTalk makes a good case that while the west has focused on AI in a brain-in-a-jar type way, there has been a concerted push in Beijing for Embodied AI (with Chinese Characteristics). China has a very strong base in manufacturing. Robotics, drones, autonomous vehicles are all being developed and deployed in the country.

    One of the fundamental challenges robotics systems have to address is much tighter latency bounds: the world operates in real time, and running a big model may result in a smart robot that simply cannot respond quickly enough to be useful. The space has trended towards hierarchical models, where a controller model puts out actions chunked into higher-level concepts (like “pick up at x”) and lower-level models decode those chunked outputs into a series of motor commands. While sometimes transformers are used autoregressively (take sequences of state and action, predict the next action), many now use diffusion-based techniques that generate a whole trajectory at once. Physical Intelligence recently put out a paper on Real Time Chunking where they show you can start by generating a chunk, then continue the denoising process à la inpainting or fill-in-the-middle to generate the steps between the start and goal, allowing more real-time responses.

    China putting a lot of eggs in the embodied AI basket is indirectly also betting that methods to make those systems learn and adapt will mature. Some of those same techniques will invariably apply to the (disembodied) agents that are currently the focus of big labs in the west.

    1. One of the ways they corroborate this finding is by seeing there is less KL divergence in the RL trained model than the SFT model, but that’s usually a training objective on RL, and I’d imagine you could apply KL regularization to SFT as well if you wanted. ↩︎
    2. A classic example from OpenAI: A reinforcement learning agent in a boat race game was given points for hitting targets, so it happily learned to drive in circles hitting the same targets forever, instead of actually finishing the race. Faulty reward functions in the wild | OpenAI ↩︎
    3. Q-Learning is the most common class of algorithms for offline RL. ↩︎
    4. They unsquash it into the hidden dimension size, and depending on how the numbers work out add some pooling. ↩︎
    5. Much like with the LLM, they combine the video embeddings with model-specific tokens, in this case a state tracking input and the current state of the robot arm. ↩︎
  • GPU Driven

    Traditionally, GPU collective network operations were issued from the framework on a separate CUDA stream from the local computation kernel launches. This allowed overlapping comms and hiding most or all of the network latency. NCCL exposes collectives as fully implemented kernels, and there have been various derivatives such as AMD’s RCCL or Berkeley’s new UCCL project, which is aiming to be a drop-in replacement better suited for large-scale GPU workloads1. Earlier versions of this sent the networking via the host, and later developed towards GPU-to-GPU peer-to-peer over connections like NVLink, and direct communication between the GPU and the NIC for scale-out. But the actual coordination was driven by launches from the CPU.
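
    The CPU-driven pattern looks roughly like this from PyTorch (a sketch assuming a process group launched with torchrun and the NCCL backend):

    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")  # assumes torchrun has set up rank/world-size env vars
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())

    grads = torch.randn(1024, 1024, device=device)
    work = dist.all_reduce(grads, async_op=True)  # NCCL kernel enqueued on its own stream

    # The host call returns immediately, so local compute overlaps the collective...
    local = torch.mm(torch.randn(1024, 1024, device=device),
                     torch.randn(1024, 1024, device=device))

    work.wait()  # ...but it is still the CPU that issued and coordinates the communication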

    The increasing capacity of GPUs, particularly in the Hopper/MI300X era, the use of micro-batching, and the rise of Mixture-of-Experts models put a lot more pressure on this arrangement: now, rather than a large chunk of comms at the end of each distributed-data-parallel pass, you are doing thousands of small exchanges per step. Each step could require each CPU rank to launch thousands of tiny All-to-Alls: millions of collective calls across the cluster, while still running the rest of the training loop. Each one forces a host interrupt, a collective call launch, and a GPU trigger for the data transfer.

    A paper from last year, The Landscape of GPU Driven Communication, gives a broad overview of this shift:

    In the last decade, several advancements, broadly referred to as GPU-centric communication, have sought to challenge the CPU’s hegemony on multi-GPU execution. At a high level, these advancements reduce the CPU’s involvement in the critical path of execution, give the GPU more autonomy in initiating and synchronizing communication and attempt to address the semantic mismatch between multi-GPU communication and computation.

    Getting the right kind of abstractions in here to make this more accessible and flexible is an active area of development. The main reference point is the shmem abstraction, which allows writing and reading bytes from remote memory (originally for supercomputers) and adding barriers around usage. Nvidia’s NVSHMEM (and AMD’s ROCshmem) library implements this directly for GPUs.

    The next big step up in functionality was GPUDirect Async, a transport that allows access to NIC doorbells directly from the GPU, extending NVSHMEM to RoCE/InfiniBand (Remote Direct Memory Access, RDMA). In addition, NVLS (NVLink SHARP) was added to NVLink switches, allowing much more bandwidth-efficient switch-managed multicast for the broadcast and reduce cases, the fundamental operations used in all-gather and all-reduce collectives. This allows GPU initiation of collectives, which gives us the ingredients to get the CPU out of the networking path completely, and to fuse networking alongside compute operations.

    This was one of the things DeepSeek did really well, covered in their DeepEP work: fusing MoE GEMMs with GPU-initiated RDMA cut single-node latency by 2–3x. The tradeoff is that the kernels take on a lot of complexity: choosing between NVLink (intra-node) and GPUDirect RDMA (inter-node), polling, flow management and so on, all tuned to their specific needs and hardware.

    ByteDance has also spent a lot of time looking at this problem. Their Flux paper looks at a more general approach for doing tile-based fusion of the comms and compute:

    Flux overdecomposes computation and communication into tiles. Here, since the computation operation is GEMM, and most high-performance GEMM kernels on GPUs are written with tiling, such as thread block tiling or warp tiling, our decomposition can naturally map into existing tiling in the kernels. Flux fuses dependent communication and/or wait logic into a GEMM kernel, and launches only one fused kernel, compared to the prior methods launching multiple split GEMM kernels.

    PyTorch has an experimental feature and RFC for SymmetricMemory:

    Then, innovative block-wise compute/communication overlapping techniques started to use copy-engine to drive the P2P traffic in order to minimize contention. Now, we are seeing techniques where NVLink communication is directly issued from “compute” kernels.

    […]

     Just as Triton allows average users to modify matmuls for their needs (fusion, quantization, etc.), we hope that SymmetricMemory will enable average users to modify NVLink communication algorithms for their requirements, whether it’s implementing alternate collective algorithms (one-shot allreduce), using different quantization approaches (stochastic rounding), or fusing collectives with other kernels (all-gather interleaved with matmuls).

    This project is still fairly manual, though it abstracts much of the plumbing required and makes it accessible at a high level.

    In a similar vein, ByteDance recently released their implementation of Triton-Distributed: ByteDance-Seed/Triton-distributed: Distributed Compiler Based on Triton for Parallel Systems. This adds a small Triton DSL wrapping the shmem operations, supporting both Nvidia and AMD hardware. It exposes some of the comms knobs to autotuning, allowing tuning across compute and collectives, focusing on the tile-based approach they documented in their TileLink paper.

    This is unlikely to make its way into upstream Triton, in part because the general approach of Triton is to hide much of the hardware information in the compiler passes, while this approach makes the topology fairly explicit.

    The PyTorch team has been working on traceable collectives for the PyTorch compiler. The idea of a compiler being able to look at the overall task graph and decide where to fuse comms and which transport options to use is appealing, as it allows kernels to stay agnostic to the specifics of the cluster.

    The move to GPU-initiated, fine-grain comms that fuse into compute kernels is real and continuing; the tooling is still early, but the gap will continue to close.

    1. This is currently host-side controlled, but they have plans for GPU-driven comms as well ↩︎
  • Keeping a GPU busy is a lot about tiling

    File this under the “gross oversimplifications” category. The basic approach to keeping GPUs busy is dividing the work into tiles, smaller sub-problems that make up the larger result. For a GEMM you might break the matrix into 128×128 or 128×64 tiles and let each CUDA thread block (CTA) own one tile. The GPU has many streaming multiprocessors (an A100 has 108) and every SM picks up one CTA at a time. If you want to know how many SMs your own card has you can call:

    props = torch.cuda.get_device_properties(0)
    print(f"SMs: {props.multi_processor_count}")

    Tiles are launched in waves. A full wave is the moment when every SM is busy with exactly one CTA. If the total number of tiles isn’t a multiple of the SM count, the final wave is only partly full and some SMs sit idle; Nvidia calls that wave quantization. There is a similar problem at the edge of the matrix: if the dimensions aren’t multiples of the tile size the right-most or bottom-most tiles are partly empty, wasting threads (tile quantization). Sometimes a smaller tile size (for example 64 × 64) gives higher overall throughput because it leaves less unused space at the edges.
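
    A quick back-of-the-envelope check (my own sketch) for a given problem and tile size:

    import math

    def wave_stats(M, N, tile_m, tile_n, num_sms=108):  # 108 SMs on an A100
        tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
        waves = tiles / num_sms
        # A fractional final wave means idle SMs (wave quantization); padding inside
        # the edge tiles is the tile-quantization waste.
        last_wave_fill = (tiles % num_sms or num_sms) / num_sms
        return tiles, waves, last_wave_fill

    print(wave_stats(4096, 4096, 128, 128))  # 1024 tiles, ~9.5 waves, last wave ~48% full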

    The usual cure for poor wave utilization is a persistent kernel. Instead of launching one CTA per tile, you launch (roughly) one CTA per SM and have each CTA pull tiles from a global queue until the queue is empty. Because each CTA pulls a new tile whenever it is ready, the SMs rarely go idle and the tail effect is reduced.
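
    The pattern looks roughly like this (a toy persistent vector-add in Triton, just to show the tile-queue loop; real persistent GEMMs are considerably more involved):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def persistent_add(x_ptr, y_ptr, out_ptr, n_tiles, N, BLOCK: tl.constexpr, NUM_PROGRAMS: tl.constexpr):
        pid = tl.program_id(0)
        # Each program (CTA) strides through the tile queue instead of owning one tile.
        for tile in range(pid, n_tiles, NUM_PROGRAMS):
            offs = tile * BLOCK + tl.arange(0, BLOCK)
            mask = offs < N
            tl.store(out_ptr + offs,
                     tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
                     mask=mask)

    N, BLOCK = 1 << 20, 1024
    x, y = torch.randn(N, device="cuda"), torch.randn(N, device="cuda")
    out = torch.empty_like(x)
    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
    persistent_add[(num_sms,)](x, y, out, triton.cdiv(N, BLOCK), N, BLOCK=BLOCK, NUM_PROGRAMS=num_sms)  # ~one CTA per SM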

    Inside an SM the main performance levers for GEMMs are the Tensor Cores, which execute matrix multiply-accumulate (MMA) instructions efficiently. On Ampere you use WMMA instructions: one warp (32 threads) computes a 16 × 16 fragment at a time. Hopper introduces WGMMA instructions, where four warps acting as a warp-group (128 threads) execute a larger matrix multiply (up to 64 × 256 for FP16/FP8). To issue WGMMA you must place the right-hand operand B in shared memory; A can sit in either registers or shared memory. The operation is asynchronous, so while a warp-group is processing one tile the same CTA can be pre-loading the next tile.

    Blackwell pushes the idea further. A pair of CTAs on neighbouring SMs can cooperate on a single paired MMA, letting two SMs’ tensor cores process an even larger tile.

    To make that possible Hopper introduced thread-block clusters and Blackwell extends them. When you launch a kernel you can group CTAs into clusters such that the scheduler guarantees to place them on SMs inside the same GPC (GPU Processing Cluster), so they share a fast interconnect and can access shared memory across SMs. If the grid doesn’t divide cleanly into whole clusters you also lose efficiency on the tail (is this cluster quantization? stick with the trend, Nvidia!), so Blackwell has a Cluster Launch Control feature that can shrink the last cluster to better fit the work.

    Loading Data

    All of this only works if data is present in shared memory. The first optimization is making sure (global) memory access is coalesced. A 32-thread warp issues its memory requests together, and a single fetch from DRAM is wide: if consecutive threads request consecutive 4-byte addresses, the memory controller can coalesce these into a single 128-byte read. If the addresses are strided (e.g. hopping across rows) then only 32 bytes of the 128-byte fetch capacity is useful at a time, so the load takes longer. Getting this right is about ensuring the memory layout is set up for the kernel, and doing any transforms needed in shared memory before executing.

    In older GPUs the warp had to wait on the copy operation. Ampere enabled cp.async plus non-blocking wait/arrive barriers so a warp can initiate a copy from global to shared memory and immediately continue with arithmetic. Hopper adds the Tensor Memory Accelerator: with TMA, a single thread in the CTA can describe a multidimensional block to copy and the TMA hardware streams it to shared memory while the threads do something else. Blackwell goes one step further and can multicast a single TMA load into every SM of a cluster, which is helpful when multiple CTAs are about to reuse the same B tile.

    In practice you hide latency by organizing the main loop so that it double-buffers: while the tensor cores work on tile k the TMA or cp.async engine is fetching tile k + 1 into the other half of shared memory; then you swap buffers and repeat. As long as copy time and compute time overlap well, the tensor cores and the copy engines stay saturated.

    Choosing the right tile size

    Choosing the right tile size (often expressed in Triton as BLOCK_M × BLOCK_N) is a balance between all of these: enough threads to issue a warp-group MMA, small enough tiles that the matrix edges aren’t mostly padding, enough shared-memory space to double-buffer, and a grid size that fills whole waves or is run via a persistent kernel. Autotuning in Triton or CUTLASS can empirically test different options on the hardware, but it helps to have the right mental model about what sets of sizes they should consider. One good clue that you’re missing an option is when you see a sudden drop in achieved TFLOP/s for a particular shape.

    AMD

    AMD’s MI300X hardware takes a somewhat different route. The GPU is divided into chiplets, where each chiplet has its own compute units and multiple schedulers that schedule wavefronts (AMD’s term for warps; 64 threads rather than 32) independently, so the hardware load-balances multiple kernels by itself. Matrix instructions run at the wavefront level; there is no cross-CU equivalent to WGMMA. Latency hiding relies on launching a large grid of workgroups and letting the hardware interleave them, rather than on explicitly scheduling async copies. On AMD the guidance is to mostly focus on high occupancy and coalesced memory access, whereas on NVIDIA there is value in crafting (by hand or compiler) the copy–compute pipeline.

  • How does Triton do Warp Spec?

    Kapil Sharma from the PyTorch team has a great series of posts diving into the Triton compiler process: 1, 2, 3. As covered there, Triton lowers through a series of intermediate representations, and each level has a set of transformation passes that implement optimizations. TTIR is the generic Triton IR and leverages a number of standard MLIR passes like common subexpression elimination, as well as some Triton-specific passes like managing broadcast ops. That’s then lowered to TTGIR, a GPU-specific IR1.

    triton/third_party/nvidia/backend/compiler.py at rc/3.3.x · triton-lang/triton

    The different backends configure the passes appropriate for their targets, so the Nvidia TTGIR configuration above details the passes for Nvidia hardware. Some are gated on the specific backend targeted, like warp specialization:

    passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
                                         opt.reg_dec_producer, opt.reg_inc_consumer)
    passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
    passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)

    Another example is using the Tensor Memory Accelerator for async loading on Hopper+:

    if capability // 10 >= 9:
       nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
       nvidia.passes.ttnvgpuir.add_fence_insertion(pm)

    For a quick recap of the why and how of Warp Specialization, check Colfax’s guide to optimizing a GEMM:

    The most basic method by which GPU programmers can create overlapping is via excess warps (warps are groupings of 32 contiguous threads). Nvidia GPUs allow a large number of warps per SM (streaming multiprocessors), and can switch between them with minimal overhead. In particular, the warp schedulers can simply switch to another warp if one warp encounters a slow memory fetch. In order to give the warp schedulers more opportunity to hide latency, a technique called warp-specialization was introduced circa 2011 [1, 2]. With warp-specialization, some warps are dedicated to memory fetches (producers), while others are dedicated to compute (consumers), and named barriers are used for synchronization between them. The idea is that the warp schedulers can then more easily hide the latency of copy operations within compute (and vice-versa).

    Even more generally than overlapping memory loads, you can overlap other kinds of work. SMs have 1 Tensor Core per 32 ALUs (CUDA cores), x4 on recent hardware, which means you can overlap work that isn’t dot products. It’s really common to want to load memory, do a matmul, then apply a pointwise function like a ReLU or other activation function. You aim to keep the Tensor Cores as busy as possible with a series of matmuls, and warp specialization lets you do that.

    The transforms for this live in triton/lib/Dialect/TritonGPU/Transforms at rc/3.3.x · triton-lang/triton.

    The first pass partitions the ops in the kernel into tasks. It looks for load ops and dot-product ops, and partitions them into producer (loads) and consumer (compute) groups.

    // Step 1. Select loads into the first task, which is the producer task by
    // default. Place dots into the second task, which is the consumer.
    // Only consider loads that are connected to a dot op in a loop.

    The next pass is a bookkeeping one that propagates the task IDs, so that any unlabeled ops are attached to one of the partitions.

    The WSDataPartition transform partitions the dot operations into consumer groups. It splits the dot products inside loops along M or N dimensions to be processable within a warp group, ensuring all dot operations are sliced and the slices labelled with task_ids.

    Just to look at some of the numbers: A warp consists of 32 threads, and a sync op (for loading) uses a whole warp. A warp group is a set of 4 warps: this is pertinent because the TensorCore MMA prefers 4 warps working on 64x(64/128/256)xK tiles. Triton already has a WarpGroupDotOp that tries to set this up, and that’s one of the operations targeted in this pass. The pass splits a Triton CTA tile, which may be 128×256, so that each consumer warp group has a 64 row (or 256 column) chunk.

    if (sliceSizeM >= 64) {
      LLVM_DEBUG({ LDBG("partition along M\n"); });
      partitionDim = 0;
      partitionSize = sliceSizeM;
      partitionOperand = opndA;
    } else if (sliceSizeN >= 256) {
      LLVM_DEBUG({ LDBG("partition along N\n"); });
      partitionDim = 1;
      partitionSize = sliceSizeN;
      partitionOperand = opndB;
    } else {
      LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN);
      return false;
    }

    The next pass is WSCodePartition. This is a big transform. It takes the task-sliced IR from the DataPartition and sets up the producer warp group to copy from global GPU memory to SMEM (or, on Blackwell, TMEM). It also drops in barriers using producer.acquire/commit and consumer.wait/release to ensure proper ordering between the groups. The transform identifies “channels”, places where data is loaded (using load or descriptor_load for TMA on Hopper+), and associates the producer task_id with all the consumer task IDs that need to process that data.

    Conceptually, the transform is turning the loops in the original kernel into something like this:

    for k: # K-dimension tiles
        ##### PRODUCER #####
        producer.acquire(token, idx, phase) # reserve smem[idx]
        async_copy_global_to_local(smem[idx]) # GMEM → SMEM[idx]
        producer.commit(token, idx) # make slot visible
    
        #####  CONSUMERS (run in parallel warps) #####
        ## Consumer 0 ##
        consumer.wait(token, idx, phase) # sleep until slot ready
        mma_sync(accum0, smem[idx], …) # read-only, do matmul
        consumer.release(token, idx) 
    
        ## Consumer 1 ##
        consumer.wait(token, idx, phase)
        mma_sync(accum1, smem[idx], …)
        consumer.release(token, idx)
    
        # Repeat for extra consumers. 
    
        # increment circular buffer
        idx   = (idx + 1) % numBuffers 
        # Toggle each time we hit producer, indicate old vs new data.
        phase = phase ^ (idx == 0)

    The next pass is a generic pipeline pass. Each op is assigned a stage based on latency: e.g. slower ops like loads go into stage 0. This is then transformed with modulo scheduling. Any sync ops are converted to async variants (in lowerLoops), before writing out (in expandLoops) prologue, kernel and epilogue loops that contain all the relevant ops.

    Finally the WSLowering pass takes the various operators we have (like producer.acquire) and replaces them with the hardware-specific variants (e.g. wait_barrier). It also handles bookkeeping like generating the warp group and task_ids from the warp ID:

    OpBuilder builder(op);
    Value _4 = builder.create<arith::ConstantIntOp>(loc, WARPS_PER_TASK, 32);
    Value warpId = builder.create<ttng::GetCanonicalWarpIdOp>(loc);
    Value asyncTaskId = builder.create<arith::DivUIOp>(loc, warpId, _4);
    op.getResult().replaceAllUsesWith(asyncTaskId);

    This is a wordy IR way of saying

    warpId = gpu.thread_id().x / 32 
    task_id = warpId / 4

    Now the code is ready for lowering to regular PTX, and all of the warp-specific stuff is captured explicitly!

    1. Note: Because of some project weirdness, warp specialization is quite different in the release branches from main, so I’ll refer to 3.3 from here on. It’s in very active development (by teams at Meta, OpenAI and Nvidia!) so the specifics are quite likely to change over coming releases! ↩︎
  • Autotuning in PyTorch & Triton

    torch.compile offers some knobs for controlling the trade-off of execution performance with longer compile times. This is particularly useful for inference, where the same model will be running for a long time.

    model_autotune = torch.compile(model, mode="max-autotune")

    Passing the max-autotune option instructs the compiler to test more options for the operations. The compiler has the option to use pre-built ATen kernels, leverage kernels from libraries like cuDNN or CUTLASS, or use templated Triton kernels. When autotuning, specific variants are tested on device with the shape information identified during tracing, and the fastest options are selected. Thanks to Triton templates, it can also consider fusions, where pointwise ops are folded into a single kernel, saving kernel launch overhead.

    The downside of this is that testing the options takes more time, so using max-autotune can lead to some very extended compile times. You also need a hefty enough GPU to get the benefit: is_big_gpu gates it on the number of SMs, so it works best on a 3090, V100 or above.

    You can see a lot of the autotuning options in _inductor/config.py. Backends that are considered are set separately for GEMMs and convolution ops:

    max_autotune_gemm_backends = os.environ.get(
        "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
    ).upper()

    Each kernel has implementations for the different backends, which are added to the possible choices. E.g. in _inductor/kernels/mm.py you can see calls to use_[backend]_template that check whether the backend in question is a valid choice:

    if is_nonzero and use_cutlass_template(layout, m, n, k):
            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    _inductor/select_algorithm.py does the actual benchmarking through the choices.

    If you run autotuning, you’ll get some log output, and caches will be written to /tmp/torchinductor_yourusername.

    We can try this out on a simple MLP:

    import torch, time
    
    class SimpleMLP(torch.nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, hidden_features)
            self.relu = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(hidden_features, out_features)
        def forward(self, x):
            return self.linear2(self.relu(self.linear1(x)))
    
    # Set up device and model
    device = 'cuda'
    model = SimpleMLP(in_features=1024, hidden_features=1024, out_features=1024).to(device)
    x = torch.randn(256, 1024, device=device)  # batch of 256, 1024 features each
    
    # Compile the model in default mode and max-autotune mode
    model_default = torch.compile(model, mode="default")
    
    # Warm-up runs (to trigger compilation)
    torch.compiler.reset()
    with torch.no_grad():
        model_default(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of default compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_default(x)
        end.record()
    torch.cuda.synchronize()
    time_default_ms = start.elapsed_time(end) / 50.0
    torch.compiler.reset()
    
    model_autotune = torch.compile(model, mode="max-autotune")
    
    with torch.no_grad():
        model_autotune(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of max-autotune compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_autotune(x)
        end.record()
    torch.cuda.synchronize()
    time_autotune_ms = start.elapsed_time(end) / 50.0
    
    print(f"Average inference time - torch.compile default: {time_default_ms:.3f} ms")
    print(f"Average inference time - torch.compile max-autotune: {time_autotune_ms:.3f} ms")
    

    Disappointingly, this is the result:

    Average inference time - torch.compile default: 0.113 ms
    Average inference time - torch.compile max-autotune: 3.251 ms

    We can turn on logging with the TORCH_LOGS env variable: some useful options are inductor, autotuning, and perf_hints.

    TORCH_LOGS="perf_hints" python tune.py

    You can control many more autotune options via the options flags, though it’s incompatible with passing a mode value. We can recreate the max-autotune mode and turn on some useful tracing options like this (note that the options version uses an underscore, the mode a hyphen!):

    model_autotune = torch.compile(
        model,
        options={
            "max_autotune": True,
            "triton.cudagraphs": True,
            "coordinate_descent_tuning": True,
            "trace.enabled": True,
            "trace.graph_diagram": True,
        },
    )

    Options "trace.enabled": True, "trace.graph_diagram": True generate trace outputs, and output a nice diagram of the captured graph. Cudagraphs turned out to be the culprit here, which is common enough there is a non-cudagraph mode available to stop you having to remember all the options:

    model_autotune = torch.compile(model, mode="max-autotune-no-cudagraphs")

    As you can see in the captured graphs with and without cudagraphs, the slower version actually has an extra fusion performed!

    Captured graphs for the two runs

    Triton Autotuning

    Triton also conducts autotuning, but it’s a little more explicit. When authoring a Triton kernel you can specify configurations. At compile time each config variant will be tested, the most performant one picked and the choice stored for future calls. A key value can be provided to indicate when to re-autotune based on changing inputs:

    import os
    import torch
    import triton
    import triton.language as tl
    
    # Just to save passing this on the command line
    os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  
    
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 128}, num_warps=4,  num_stages=2),
            triton.Config({'BLOCK_SIZE': 256}, num_warps=8,  num_stages=2),
        ],
        key=['N']            # re-tune only if the length N changes
    )
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        pid   = tl.program_id(0)
        offs  = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        x     = tl.load(x_ptr  + offs, mask=mask, other=0.0)
        y     = tl.load(y_ptr  + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, x + y, mask=mask)
    
    def vec_add(x: torch.Tensor, y: torch.Tensor):
        assert x.is_cuda and y.is_cuda
        N   = x.numel()
        out = torch.empty_like(x)
        grid = (triton.cdiv(N, 128),)         # 128 = smallest BLOCK_SIZE we declared
        vecadd_kernel[grid](x, y, out, N)    
        return out
    
    x = torch.randn(1 << 20, device="cuda")   # 1 048 576 elements
    y = torch.randn_like(x)
    
    _ = vec_add(x, y)  # first call → autotuning prints to stdout
    _ = vec_add(x, y)  # second call → no autotuning, uses the best config found
    

    Setting the env variable TRITON_PRINT_AUTOTUNING documents the process as it goes:

    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 256, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Triton autotuning for function vecadd_kernel finished after 0.44s; best config selected: BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;

    You can use the same do_bench tester that the autotuner does, and see how the performance varies yourself:

    import torch, triton, triton.testing as tt
    import triton.language as tl
    
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        offs  = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        tl.store(out_ptr + offs,
                 tl.load(x_ptr + offs, mask=mask) +
                 tl.load(y_ptr + offs, mask=mask),
                 mask=mask)
    
    # tensors
    N   = 1 << 20
    x   = torch.randn(N, device='cuda')
    y   = torch.randn_like(x)
    out = torch.empty_like(x)
    
    def bench(block_size, num_warps):
        grid = (triton.cdiv(N, block_size),)
        # tt.do_bench returns [median, p20, p80] timings in milliseconds
        return tt.do_bench(
            lambda: vecadd_kernel[grid](x, y, out, N, BLOCK_SIZE=block_size, num_warps=num_warps),
            warmup=5, rep=16, return_mode="all", quantiles=(0.5, 0.2, 0.8)
        )
    timings = {
        "128/4": bench(128, 4),
        "256/8": bench(256, 8),
    }
    
    print("timings:", timings)
    

    Running that shows that both kernels are basically equivalent, with the first configuration coming out slightly faster.

    timings: {'128/4': [0.01945599913597107, 0.01945599913597107, 0.02028159946203232], '256/8': [0.01945599913597107, 0.01945599913597107, 0.020479999482631683]}

  • Profiling Triton

    There are a couple of different options to profile a Triton kernel.

    Proton

    Proton is the profiler that ships with Triton (profiler for Triton). You can enable it and (optionally) activate/deactivate it around specific regions you want to profile. You can also annotate functions with specific metrics.

    session = proton.start()  # Start profiling session

    # (A, B, C, M, N, K, grid, the block sizes and the fused_gemm_bias_relu kernel
    # are defined earlier in the script.)
    bias = torch.rand((256,), device='cuda', dtype=torch.float16)  # Bias vector
    flops = 2 * M * N * K
    bytes_accessed = A.element_size() * M*K + B.element_size() * K*N + C.element_size() * M*N  # rough bytes
    with proton.scope(f"fused_gemm_bias_relu [M={M}, N={N}, K={K}]", {"flops": flops, "bytes": bytes_accessed}):
        fused_gemm_bias_relu[grid](
            A, B, C, bias,
            M, N, K,
            A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
        )

    proton.finalize()

    The output can be visualized with the built-in viewer:

    proton-viewer -m time/ms,tflop/s ./proton.hatchet

    0.040 6.645 ROOT
    ├─ 0.004 nan _ZN2at6native55_GLOBAL__N__11f7a751_22_DistributionUniform_[...]_15PhiloxCudaStateESH_SI_
    └─ 0.037 7.284 fused_gemm_bias_relu [M=1024, N=256, K=512]
    └─ 0.037 nan fused_gemm_bias_relu

    In this case you can see both the (trimmed!) generated name for the bias tensor set up as well as the name of my custom kernel.

    nsight-compute

    Nvidia also has a good range of tools for looking at performance. Note that you will need to enable access to the performance counters on the device for this:

    NVIDIA Development Tools Solutions – ERR_NVGPUCTRPERM: Permission issue with Performance Counters | NVIDIA Developer

    On the off chance you’re doing this on WSL, https://peterchng.com/blog/2024/03/02/profiling-cuda-programs-on-wsl-2/ walks through the setup!

    Nvidia ships Nsight Systems, which tracks larger system-wide metrics, and Nsight Compute, which is more focused on profiling kernel execution. You can run the latter against a script like so:

    ncu -o profile_results python test.py

    The tool comes with a nice GUI for inspecting the results. It can show you the PTX or SASS source for the kernels, offers metrics like actively used registers (good for checking on register spilling), and calls out warnings on poor utilization or memory clashes.

    Upcoming intra-kernel profiler

    [tool][proton] Intra kernel profiling support by fywkevin · Pull Request #4861 · triton-lang/triton

    There is an extension coming for Proton that enables profiling within kernels. This reserves a pre-allocated buffer on device and logs metrics locally, for reading out at the end of the execution. It outputs as a chrome trace for use within a wide range of dev tools. While this isn’t merged into mainline yet, you can see an example of the usage in the dev repo.

  • JEPA

    JEPA is an example of a predictive coding architecture. Predictive coding is the idea that the brain works by predicting the next sensory input and learns from the error between the prediction and the actual input. Hierarchies of this kind of prediction allow higher level elements to predict the outputs of lower-level elements, building up deeper and more complex semantics.

    The core idea in JEPA is to take two related things (say consecutive frames of a video) x and y, encode each of them into a latent space (an embedding), and then predict the embedding s(y) of y from s(x). The encoders can be Transformer-based models — in practice models like I-JEPA train the x encoder directly and update the y (target) encoder as a moving average of its weights.

    The learning is not based on how well the end result predicts the target, e.g. how closely the pixels of the next frame are predicted. Instead, it’s based on how well the latent representation of the next frame is predicted.

    The advantage of working in the latent space for the prediction is that the model can choose what level of detail it wants to capture, discarding some aspects and focusing on more foundational concepts. This helps build a more robust world model, with the hope being that training in this way will then allow easier generalization to more tasks, with less data required.
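
    A minimal sketch of a training step (my own toy version, with simple MLP encoders standing in for the Transformer ones):

    import copy
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    encoder = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 64))
    target_encoder = copy.deepcopy(encoder)   # y-side encoder, updated only by EMA
    predictor = nn.Linear(64, 64)             # predicts s(y) from s(x)
    opt = torch.optim.Adam(list(encoder.parameters()) + list(predictor.parameters()), lr=1e-3)

    x, y = torch.randn(32, 256), torch.randn(32, 256)  # stand-ins for related inputs (e.g. frames)

    pred = predictor(encoder(x))
    with torch.no_grad():
        target = target_encoder(y)
    loss = F.mse_loss(pred, target)   # the error lives in latent space, not pixel space
    opt.zero_grad()
    loss.backward()
    opt.step()

    # EMA update of the target encoder
    with torch.no_grad():
        for p, tp in zip(encoder.parameters(), target_encoder.parameters()):
            tp.mul_(0.99).add_(p, alpha=0.01)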

    Similarities

    This is somewhat similar to autoencoders. Autoencoders take an input, compress it in a latent space, then reconstruct the original from the latent space and propagate back the error. JEPA does a similar process across two different items with separate encoders, and only cares about error within the latent space.

    Contrastive models embed two different items into the same space and try to increase similarity between the embeddings for things known to be similar and make them dissimilar to other items. This is used in CLIP and other multimodal text-image encoders, where the text and the image embed to the same space so that a text caption and a matching image are close in embedding space. This requires a lot of pairwise comparisons, while JEPA is a more straightforward s(x)->s(y) prediction in training.

    Challenges

    Because JEPA models leave you with a latent, they need to be paired with a generator to get an observable, human-viewable output, which is a per-domain challenge. This makes it harder to evaluate how well the model is learning, beyond measuring loss.

    Training stability can also be tricky — it is possible for the model to collapse and learn trivial representations to minimize prediction error. Even without complete collapse, it can require some experimentation to ensure the model is learning at a deep enough conceptual level. For example, I-JEPA, which worked in image space, found that using large enough masked patches was important to ensure the model captured sufficient detail.

  • Pydantic Evals

    https://ai.pydantic.dev/evals/

    Ed Yang was recently recommending keeping your own benchmark of LLM evals, so you can test newer models on problems that they have struggled with in the past. I have recommended similar things to people, but there is some barrier to entry in knowing how to start. Ed references (and forks) Nicolas Carlini’s personal benchmark repo, but it’s nice to have some light(ish)weight options too.

    Pydantic Evals is a powerful evaluation framework designed to help you systematically test and evaluate the performance and accuracy of the systems you build, especially when working with LLMs.

    You can install the library with uv or pip:

    uv add pydantic-evals
    

    I tried it out with a strawberry test, calling openrouter with different models. I needed a custom eval as the default Contains is a bit rigid, but the approach seems nice!

    import os
    import asyncio
    from dataclasses import dataclass
    from pydantic_evals import Case, Dataset
    from pydantic_evals.evaluators import Evaluator, EvaluationReason, EvaluatorContext
    from pydantic_evals.evaluators.common import _truncated_repr
    from openai import OpenAI
    from typing import Any, Optional, cast
    @dataclass
    class FlexibleContains(Evaluator[object, object, object]):
        """
        Check if the output contains any one of the expected options.
        """
        value: Any
        case_sensitive: bool = False
        def evaluate(
            self, ctx: EvaluatorContext[object, object, object]
        ) -> EvaluationReason:
            failure_reason: Optional[str] = None
            # Normalize value into a list of options if it isn't already a list or tuple.
            options = self.value if isinstance(self.value, (list, tuple)) else [self.value]
            output_str = str(ctx.output)
            if not self.case_sensitive:
                output_str = output_str.lower()
            match_found = False
            for opt in options:
                opt_str = str(opt)
                if not self.case_sensitive:
                    opt_str = opt_str.lower()
                if opt_str in output_str:
                    match_found = True
                    break
            if not match_found:
                failure_reason = (
                    f"Output string {_truncated_repr(output_str, max_length=100)} does not contain "
                    f"any of expected strings: {[str(opt) for opt in options]}"
                )
            return EvaluationReason(value=match_found, reason=failure_reason)
    strawberry = Case(
        name="strawberry",
        inputs="How many rs are in strawberry?",
        evaluators=[FlexibleContains(value=["3", "three"])],
        metadata={"difficulty": "easy"},
    )
    dataset = Dataset(cases=[strawberry])
    MODELS = [
        "anthropic/claude-3.5-sonnet",
        "openai/gpt-4o",
        "meta-llama/llama-4-maverick:free",
        "meta-llama/llama-4-scout:free",
        "openrouter/optimus-alpha",  # secret model!
    ]
    def generate_completion(inputs: str, model: str) -> str:
        """Generate a completion using OpenRouter with specified model"""
        client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": inputs},
            ],
            max_tokens=50,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()
    def evaluate_models():
        """Run evaluations across multiple models"""
        for model in MODELS:
            print(f"\nResults for model: {model}")
            print("=" * 50)
            # Wrap the synchronous generate_completion in an async function:
            async def model_specific_generate(inputs: str) -> str:
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, generate_completion, inputs, model)
            # Run evaluation for this model
            report = dataset.evaluate_sync(model_specific_generate)
            # Print results for this model
            report.print(include_input=True, include_output=True, include_durations=False)
    def main():
        evaluate_models()
    if __name__ == "__main__":
        main()
    

    To give a trimmed output:

    Results for model: openrouter/optimus-alpha
    ==================================================
                                          Evaluation Summary: model_specific_generate
    ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    ┃ Case ID    ┃ Inputs                         ┃ Outputs                                                   ┃ Assertions ┃
    ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    │ strawberry │ How many rs are in strawberry? │ The word **"strawberry"** contains **three** letter "r"s. │ ✔          │
    ├────────────┼────────────────────────────────┼───────────────────────────────────────────────────────────┼────────────┤
    │ Averages   │                                │                                                           │ 100.0% ✔   │
    └────────────┴────────────────────────────────┴───────────────────────────────────────────────────────────┴────────────┘
  • Dynamic Shapes in PyTorch

    Dynamic shapes are one of the more distinctive parts of torch.compile. Rather than specializing a graph to static shapes (which works in many cases!), PyTorch’s approach allows a single graph to work for a variety of sizes, so things like sequence length or batch size can vary. It does this by reasoning about shapes symbolically: instead of using fixed shape values, it uses placeholders and infers rules that constrain those shapes.

    Tracing & Symbolic Shapes

    PyTorch uses tracing in Python (via Dynamo) to capture the graph of operations. By default, it marks shapes as static during tracing. If a shape marked as static changes at runtime, it is marked as dynamic and treated symbolically in a recompilation. You can also proactively mark a dimension as dynamic to encourage symbolic treatment from the start:

    torch._dynamo.mark_dynamic(x, 0)  # Mark dim 0 as dynamic

    Under the hood, PyTorch uses SymPy to represent and manipulate symbolic shapes. Each dynamic shape is replaced with a SymInt and tracked by a central ShapeEnv.

    Every operation in PyTorch has a meta function — a lightweight implementation that computes metadata like shape changes without actually performing the computation. This lets PyTorch propagate symbolic shapes through the graph. For example, concatenating two tensors along dimension zero is represented symbolically as:

    s0 = s_x0 + s_y0

    To support branching logic, symbolic shapes carry a “hint” of their current concrete value. This allows specific branches of conditionals like if tensor.size(0) > 2 to be taken during tracing based on the hint. PyTorch adds a guard at this point to ensure that the resulting graph is only used if that branch is the correct one.
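
    As a rough illustration of that hint-plus-guard behaviour (a minimal sketch, not taken from the PyTorch docs):

    import torch

    @torch.compile
    def f(x):
        if x.size(0) > 2:  # resolved at trace time using the hint for x.size(0)
            return x * 2
        return x + 1

    a = torch.randn(5, 3)
    torch._dynamo.mark_dynamic(a, 0)  # treat dim 0 symbolically from the start
    f(a)                              # traces the "> 2" branch, guarded by roughly s0 > 2
    f(torch.randn(1, 3))              # guard fails, so the other branch is traced and compiled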

    Guards

    Guards are runtime checks inserted into the compiled graph to ensure the assumptions made during tracing still hold. For example, in the case of tensor.size(0) > 2, if the tensor is the result of concatenation, the guard will check that a.size(0) + b.size(0) > 2. If this fails, the code is retraced and a graph for the new branch is generated. Multiple graphs can be cached and selected at runtime based on guard validation.

    Guards don’t need to assert exact sizes; they can use symbolic constraints like x.size(0) > 2. This allows dimensions to vary within bounded ranges. The backend compiler (usually Inductor) can then compile code that operates over symbolic dimensions, as long as the variability is within the guarded constraints.

    For example, operations like broadcasting typically generalize well to symbolic shapes. In contrast, if an op specializes on a fixed shape (e.g., optimized path for 1D input), it may require conditional tracing and guards.

    What this means in practice is that most of the time compilation will follow this process:

    • Take a batch of data, assume all shapes are static, insert guards, and pass static sizes to the compiler
    • On the next batch, see which guards have been violated, mark those dimensions as dynamic, add appropriate guards, and pass symbolic dimensions to the compiler
    • Assuming no control flow changes, continue to reuse this dynamic graph without recompilation (sketched below)
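
    Here is a minimal sketch of that flow with torch.compile’s default behaviour (the comments describe the expected compilation behaviour, not guaranteed output):

    import torch

    @torch.compile
    def scale(x):
        return x * 2

    scale(torch.randn(4, 8))  # first call: compiled with static shapes (4, 8)
    scale(torch.randn(6, 8))  # dim 0 changed: recompiled with dim 0 as a symbolic size
    scale(torch.randn(9, 8))  # reuses the dynamic graph, no further recompilation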

    Backed vs. Unbacked SymInts

    Most symbolic shapes are backed, meaning they have an associated concrete value at trace time. These are usually derived from inputs and show up in traces as s0, s1, etc.

    Unbacked SymInts lack a concrete value. These arise from data-dependent operations, e.g.:

    n = (x > 1).nonzero().size(0)

    Here, n depends on the data in x, so its size cannot be known at trace time. It will be represented as an unbacked SymInt like u0.

    If a control flow decision depends on an unbacked SymInt, tracing cannot proceed, resulting in a graph break or a GuardOnDataDependentSymNode error (when fullgraph=True).

    However, you can guide the compiler with additional constraints, e.g.:

    torch._check(x.size(0) == y, lambda: f"size mismatch: {x.size(0)} != {y}")

    This lets PyTorch treat x.size(0) as equivalent to y throughout the graph. The check will be validated at runtime.

    There are other APIs to help mark unbacked SymInts as size-like to enable meta function compatibility (see Ed’s docs for more).

    Controlling Dynamic Shape Usage

    You can control dynamic behavior in torch.compile with the dynamic flag:

    • Not passed (the default): assume shapes are static, then mark dimensions as dynamic if they change on a later call
    • dynamic=False: force all shapes to be static
    • dynamic=True: treat all shapes as dynamic from the start

    The default is usually best, but dynamic=True can help in testing.

    Use fullgraph=True to attempt to generate a single, complete graph without breaks. This is often critical for performance, as graph breaks can drastically affect runtime, and it’s easy to make innocuous-looking code changes that trigger additional breaks!
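
    For reference, a minimal sketch of how these flags are passed (fn here is a stand-in for your own function):

    import torch

    def fn(x):
        return x.relu() + 1

    compiled_default = torch.compile(fn)                # infer dynamism on recompilation
    compiled_static = torch.compile(fn, dynamic=False)  # specialize on each new shape
    compiled_dynamic = torch.compile(fn, dynamic=True, fullgraph=True)  # symbolic shapes, no graph breaks allowed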

  • Bank Conflicts in Shared Memory

    When data is in global memory on a GPU it’s usually in row-major or column-major order. Loading from global memory is quite slow though, so for performance we want to move the data to shared memory for the threads in a warp to work on.

    To make that load from global memory performant, we want memory reads to be coalesced, meaning we read a contiguous chunk of memory at a time. Shared memory, on the other hand, is divided into banks – typically 32 banks, each 4 bytes wide. If multiple threads in the same warp try to access different addresses in the same bank, the requests are processed sequentially, slowing things down while the threads wait on each other. Nsight and other profiling tools will helpfully point this out to you!

    For example, let’s say we’re loading a row-major tensor and a column-major tensor, and will be doing a multiplication between them (this is naive, to demonstrate the issue):

    __shared__ float Asub[TILE_DIM][TILE_DIM];
    __shared__ float Bsub[TILE_DIM][TILE_DIM];  // (No padding in this naive version)
    int lane = threadIdx.x;  // 0...31 (warp lane index)
    int tileRow = blockIdx.y * TILE_DIM;
    int tileCol = blockIdx.x * TILE_DIM;
    int globalRow = tileRow + lane;
    int globalCol = tileCol + lane;
    Asub[lane][0] = A[globalRow * N + tileCol + 0];           // row-major load of A
    Bsub[lane][0] = B[(tileRow + lane) + (tileCol + 0) * N];  // column-major load of B

    Now when we fill Bsub we will be writing everything to the same shared memory bank, significantly slowing things down. One easy fix is just to add padding:

    __shared__ float Asub[TILE_DIM][TILE_DIM];           // A tile (row-major, no conflict in our case)
    __shared__ float Bsub[TILE_DIM][TILE_DIM + PAD];     // B tile (extra column to prevent conflicts)
       

    With PAD as 1 (and TILE_DIM as 32), each row of the tile is 33 floats – 132 bytes – so the writes are offset across banks and each thread gets its own bank.
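
    A quick back-of-the-envelope check (a small Python sketch, assuming the usual 32 banks of 4-byte words) shows which bank each lane’s Bsub[lane][0] write hits with and without the padding:

    BANKS, WORD_BYTES = 32, 4

    def bank(lane, row_stride_floats):
        """Bank hit by the first element of row `lane` in a 2D shared-memory tile."""
        byte_addr = lane * row_stride_floats * WORD_BYTES
        return (byte_addr // WORD_BYTES) % BANKS

    print([bank(lane, 32) for lane in range(4)])  # [0, 0, 0, 0] -> every lane hits bank 0
    print([bank(lane, 33) for lane in range(4)])  # [0, 1, 2, 3] -> one bank per lane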

    The downside is that this wastes shared memory, a scarce resource, so an alternative approach is swizzling: changing the layout such that consecutive thread accesses aren’t causing bank conflicts. That’s what Bert implemented to get performance in his recent GEMM walkthrough, but it’s easy to get it wrong.

    To make life easier than writing it in raw CUDA, Cutlass has a system called CuTe. CuTe is a set of templates for expressing the layout of data:

    auto tileLayout    = make_layout(make_shape(Int<32>{}, Int<32>{}), GenRowMajor{});
    auto swizzledLayout = composition(Swizzle<5, 0, 5>{}, tileLayout);

    Here you specify how the data is laid out in global memory with the shape and stride, then make_layout and the copy operation take care of translating from the row-major layout in global memory to the swizzled layout in shared memory.

    From a Triton perspective, Lei Zhang has a great post on memory access, and how it works in Triton, specifically the LinearLayout class that allows the language to similarly handle swizzling and layouts for you:

    Indeed the whole point of LLs is that they allow us to specify transposed and swizzled layouts as a “general case”. Instead of a layout class for registers in a thread, and another layout for registers in a thread but in MMAv2 order, and so on, all of these can be represented by different LLs. This gets rid of special cases and lets us write more general code.

    There’s a great Colfax report on building GEMMs that covers shared memory bank conflicts, and Lei Mao has a post with a nice illustration. Axel Feldman also has a post about benchmarking different approaches, identifying bank conflicts, and some more efficient loading techniques.

  • Muon optimizer

    Last year Keller Jordan at OpenAI beat some of the existing NanoGPT speedrun records thanks to some optimizer improvements. Towards the end of the year the work was formalized as the Muon optimizer, and it’s making waves in a bunch of areas now.

    Friendship ended with Adam, now Muon is my best friend.
    From Elie Bakouch’s great pretraining presentation

    Jeremy Bernstein has written up a great post on how Muon is derived:

    To handle individual layers, our idea is to normalize the weight updates in a clever way so that, given the structure of the inputs, the weight updates automatically induce a desirable effect on the outputs. As a community, we have invested so much effort into thinking about how to normalize the activations: think batch norm, layer norm, RMS norm, etc. Why not also consider how the weights and weight updates influence the activations?

    Keller also wrote a detailed blog post when introducing the optimizer, calling out some open questions (like does it scale to very large training).

    As the posts cover, the optimizer isn’t totally general – it was designed for linear layers (and flattened convs), so you need to pair it up with Adam for most usage.

    You can install the library from Github: pip install git+https://github.com/KellerJordan/Muon

    from muon import Muon

    # Muon is for matrix-shaped parameters; everything else stays on AdamW
    # (embeddings and output heads are typically kept on AdamW too)
    muon_params = [p for p in model.parameters() if p.ndim >= 2]
    muon_param_ids = {id(p) for p in muon_params}
    adamw_params = [p for p in model.parameters() if id(p) not in muon_param_ids]
    # Create the optimizers
    optimizers = [
        Muon(muon_params, lr=0.001),
        torch.optim.AdamW(adamw_params, lr=0.001),
    ]

    And step both optimizers in the training loop:

    for opt in optimizers:
        opt.step()

    It’s great to have innovation in this area, particularly with this kind of fundamental reasoning around why it works!

  • Warp Specialization

    In general, branching in GPU code is considered bad. When you write a kernel, it’s very easy to write the same kind of logic as you would on a CPU. However, GPU kernels execute on blocks of threads scheduled on streaming multiprocessors (SMs), which are optimized for vectorized (or parallel) computation. This optimization relies on the idea that large groups of threads can execute the same instructions on different data in lockstep. Practically, these are scheduled as “warps” of 32 threads at a time (on Nvidia; the AMD equivalent is 64 threads).

    To work through an example, take this naive approach to processing an array that contains both positive and negative values:

    __global__ void naive_kernel(const float* input, float* output, int N) {
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
        if (idx < N) {
            if (input[idx] > 0) {
                // Some operation for positive values
                output[idx] = input[idx] * 2.0f;
            } else {
                // Some operation for negative or zero values
                output[idx] = input[idx] + 1.0f;
            }
        }
    }
    

    On a GPU, this branching between positive and negative values can lead to warp divergence – the threads that take different paths execute serially rather than together, so you end up using only a fraction of the warp at a time and getting worse utilization. Instead, you can rewrite this logic to effectively remove the branching:

    __global__ void improved_kernel(const float* input, float* output, int N) {
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
        if (idx < N) {
            // Compute the same math in a unified way
            float val = input[idx];
            // Evaluate the transforms without branching
            float val_pos = val * 2.0f;
            float val_neg = val + 1.0f;
            // Use a conditional assignment
            output[idx] = (val > 0) ? val_pos : val_neg;
        }
    }
    

    This rewritten version still makes a choice, but it does so in a way that can be handled more concurrently. The basic idea in GPU programming is to use thread and block IDs to develop kernels that operate cooperatively (for example, splitting data among threads).

    An additional idea is to branch on the warp itself – which is referred to as warp specialization. It’s very common to have kernels that deal with irregular data access, leading to branching, but by grouping specialized tasks into warps, you can still maintain high utilization of threads.

    For example, by branching on the thread ID and using barriers, you can specialize roles in warps, and have one set of threads dedicated to data loading and another to processing the data:

    __global__ void warp_specialization_kernel(int* global_mem) {
        // Assumes a 64-thread block: one producer warp and one consumer warp
        __shared__ int data[32];
        int idx = threadIdx.x + blockIdx.x * blockDim.x;

        if (threadIdx.x < 32) {
            // Producer warp: load data from global memory into shared memory
            int value = global_mem[idx];
            data[threadIdx.x] = value;
            // Arrive at named barrier 1 (64 threads participate in total)
            asm volatile("bar.arrive 1, 64;" ::: "memory");
        } else {
            // Consumer warp: wait until the producer warp has arrived
            asm volatile("bar.sync 1, 64;" ::: "memory");
            int value = data[threadIdx.x - 32];
            // ... process data ...
            global_mem[idx] = value + 42;
        }
    }
    

    Here, the first warp (threads 0–31) acts as the producer, loading data into shared memory, and the second warp acts as the consumer, waiting for the producer to finish before processing the data. The named-barrier instructions (bar.arrive / bar.sync, expressed here as inline PTX) keep the producer and consumer warps synchronized. This sample kernel is simplified to illustrate the concept, but the pattern is useful for specialized tasks.

    Triton, with the new changes that landed recently, allows you to specify warp groups in the autotuning parameters:

    @triton.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 256,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=4,
                num_consumer_groups=2,
                num_buffers_warp_spec=3,
            ),
        ],
        key=["M", "N", "K"],
    )
    

    num_consumer_groups greater than zero enables warp specialization and sets how many consumer groups will be available. num_buffers_warp_spec specifies how many shared memory buffers are used for transfers between the warp groups. The Triton compiler can then optimize kernels based on the available warps, grouping threads intelligently and applying warp-level optimizations, which you can read about in the PyTorch blog post on warp specialization.

    One of the reasons for the visibility of this technique now is that the Hopper architecture enables more concurrent execution of warp groups and adds support for warp-group level instructions, which make synchronization between warp groups pretty cheap.

  • Streaming DiLoCo

    [2501.18512] Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch

    Every paper in this series has been required reading in (very) large language model training. The basic theme is that model training requires gang-semantics, where a large cluster of accelerators needs to do coordinated work together in order to make progress, which gets progressively more expensive to enable and harder to do reliably as the number of devices in the cluster increases.

    The prior papers explored ways of splitting up the training into an inner loop where the model trained fairly traditionally, and an outer optimization loop that aggregated the differences and updated based on them – the outer optimizer works on the deltas between parameter values at the sync point. The outer optimizer still runs on the same cluster as all the inner loops, but it means that only at the “outer” sync point do you need to do synchronization between all the devices. This loosens the coupling between devices and allows introducing failure domains.
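
    As a toy, single-process sketch of that inner/outer split (the replica count, hyperparameters, and model here are illustrative, not taken from the papers):

    import copy
    import torch

    H = 5  # inner steps between outer syncs
    global_model = torch.nn.Linear(8, 1)
    outer_opt = torch.optim.SGD(global_model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

    # Two "replicas" train independently from the same starting point
    replicas = [copy.deepcopy(global_model) for _ in range(2)]
    inner_opts = [torch.optim.AdamW(r.parameters(), lr=1e-3) for r in replicas]

    for replica, inner_opt in zip(replicas, inner_opts):
        for _ in range(H):  # inner loop: ordinary local training
            x = torch.randn(16, 8)
            loss = (replica(x) - x.sum(dim=1, keepdim=True)).pow(2).mean()
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()

    # Outer step: the averaged (global - local) delta acts as a pseudo-gradient
    for name, p in global_model.named_parameters():
        local_vals = torch.stack([dict(r.named_parameters())[name].detach() for r in replicas])
        p.grad = p.detach() - local_vals.mean(dim=0)
    outer_opt.step()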

    This paper addresses the challenge that when you do synchronize, you still have to send data for all the parameters, which requires a lot of bandwidth and can block forward progress. Streaming DiLoCo divides the model layers into different shards and syncs those at different times (in practice, every 5 inner optimizer steps), lowering the peak bandwidth required. They take shards in a strided fashion rather than sequentially to mildly improve stability and performance.

    To further reduce bandwidth, the communication between devices for the outer loop is done in 4-bit floating point! They still do the accumulation/optimization in 32-bit, but they didn’t see any performance loss when using the lower precision for comms. All of these comms are overlapped with the inner loop training, which helps minimize stalls.

  • Grouped GEMMs and MoE

    One of the challenges discussed in the Deepseek v3 paper is the availability of grouped GEMM kernels, which are used to hide the performance impact of many small kernel launches on GPUs. Deepseek uses many small experts (256!) rather than a few larger ones, which exacerbates this problem.

    Mixture of Experts models introduce multiple experts in the feed-forward portion of each transformer layer. Rather than having a single shared set of experts, each layer has its own. Each batch of tokens first passes through the standard attention block, followed by a lightweight linear layer with a softmax function1. This determines, for each token, which experts it should be sent to. Tokens designated for each expert are gathered and sent to the appropriate device via an all-to-all operation, as experts are typically distributed across different devices.
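
    A rough sketch of that routing step (illustrative shapes and names, not Deepseek’s implementation):

    import torch

    num_experts, d_model, top_k = 8, 16, 2
    router = torch.nn.Linear(d_model, num_experts)  # the lightweight gating layer

    tokens = torch.randn(64, d_model)                 # post-attention activations
    scores = torch.softmax(router(tokens), dim=-1)    # per-token expert probabilities
    weights, expert_ids = scores.topk(top_k, dim=-1)  # which experts each token is sent to
    # expert_ids then drives the all-to-all that moves tokens to the experts' devices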

    Once the tokens are on the device with the right expert(s), we need to execute the matrix multiplies for each expert over its set of tokens. The obvious solution is just to loop through and launch each GEMM, but because these are small (a small number of tokens, and smaller expert matrices), kernel launch overhead ends up being a large fraction of the runtime. A grouped GEMM allows you to do this process on-device, taking in a list of tokens and experts and executing all the GEMMs with a single kernel launch.
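
    A toy version of that naive per-expert loop (again with illustrative shapes) might look like:

    import torch

    num_experts, d_model, d_ff = 8, 16, 32
    tokens = torch.randn(64, d_model)                       # tokens routed to this device
    assignments = torch.randint(0, num_experts, (64,))      # chosen expert per token
    expert_weights = [torch.randn(d_model, d_ff) for _ in range(num_experts)]

    outputs = torch.empty(64, d_ff)
    for e in range(num_experts):
        idx = (assignments == e).nonzero(as_tuple=True)[0]  # this expert's tokens (count varies!)
        outputs[idx] = tokens[idx] @ expert_weights[e]      # one small GEMM launch per expert
    # A grouped GEMM kernel performs all of these multiplies in a single launch instead.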

    This differs from a batched GEMM in that the problem sizes can vary – different experts might receive different numbers of tokens.

    There are example implementations available, including a tutorial on TritonLang that walks through a simple grouped GEMM kernel, as well as an example in Cutlass.

    1. In switch MoEs at least, but there are similar gating networks elsewhere. ↩︎