Category: ML Infrastructure

  • Maybe the agents shouldn’t write the kernels

    A thing you can do is take the most performance- and correctness-sensitive part of your stack and just ask a chatbot to write it for you. They will sometimes get it right!

    Back towards the end of 2024, Ouyang et al. at Stanford attempted to benchmark how often that happened with KernelBench. DeepSeek R1 could one-shot 12% of simple ops, 36% of fused operators, and 2% of whole architectures. Things have moved on a bit in the past 18 months, though, and Han, Zhang et al.1 extended the idea in KernelBench-X. They found:

    1. Writing correct kernels and writing performant kernels are somewhat decoupled. Refining kernels mostly helps with correctness: more of the tasks compiled, but the average speedup dropped along the way.
    2. What you ask for matters more than how you ask. The task category explained 3x more of the variance than switching between agents or other method variations.

    “Together, these results indicate that the capability boundary of current LLM-based kernel generation is not a single wall but a sequence of distinct barriers – compilability, semantic correctness, hardware efficiency and performance portability – each requiring different mechanisms to clear”

    In one area they had the models write quantization kernels, a domain known for needing numerical precision. They got 0 out of 30. The models produced running kernels, just not kernels that were, you know, right.

    One thing that did stand out to me was that a lot of the baselines were eager PyTorch, so I decided to run an experiment myself. How do these models do against a compiler, not just eager?

    I took a popular model (Qwen, naturally), ran it through torch.compile with minimal settings on my DGX Spark and identified three kernels that were eating big chunks of the wall time: SwiGLU, residual+RMSNorm and the SDPA prelude. I then had ChatGPT, Claude and Kimi2 take a run at writing those kernels for the hardware.

    The results were an absolute blowout: SwiGLU 1.06x, RMSNorm 1.21x, SDPA prelude 3.91x. For the latter, Kimi stacked up three different weight matrices into one and fused multiple matmuls together. It was very impressive stuff.
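    For the curious, here is a minimal plain-Python sketch of that stacking trick. It assumes the three matrices were Q, K and V projections sharing one input (the usual candidate for this fusion); it is not a reconstruction of Kimi's kernel. Concatenating the weights column-wise turns three matmuls into one, and slicing the output recovers Q, K and V.

```python
from typing import List, Tuple

Matrix = List[List[int]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matmul: (T x D) @ (D x E) -> (T x E)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def hstack(*mats: Matrix) -> Matrix:
    """Concatenate matrices column-wise (they must have the same number of rows)."""
    return [sum((m[r] for m in mats), []) for r in range(len(mats[0]))]

def fused_qkv(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix) -> Tuple[Matrix, Matrix, Matrix]:
    d = len(wq[0])
    out = matmul(x, hstack(wq, wk, wv))   # one (T x D) @ (D x 3D) matmul instead of three
    q = [row[:d] for row in out]
    k = [row[d:2 * d] for row in out]
    v = [row[2 * d:] for row in out]
    return q, k, v

x = [[1, 2], [3, 4], [5, 6]]
wq, wk, wv = [[1, 0], [0, 1]], [[2, 1], [0, 3]], [[0, 1], [1, 0]]
q, k, v = fused_qkv(x, wq, wk, wv)
print(q == matmul(x, wq), k == matmul(x, wk), v == matmul(x, wv))  # True True True
```

    On a GPU, the appeal is one bigger GEMM that reads the input once and keeps more of the chip busy, instead of three smaller launches.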

    Then another suspiciously well-timed paper arrived: FASTKERNELS, from Snowflake. Rather than benchmarking against PyTorch eager or against single-operator references, they wanted to test how the models did on real model-serving problems, with a focus on end-to-end speedups. Their takeaway:

    “agents learn to generate kernels that score well in sandboxes but introduce interface incompatibilities, compilation-stack conflicts, and silent correctness degradation when integrated into real systems.”

    All of those issues hurt end-to-end performance when you put the kernels in a real model-serving context. Of the three strong kernel-generation agents3 they tried, none beat the production baselines,

    “in contrast to the supra-unity numbers these agents have reported on operator-level benchmarks whose reference is PyTorch eager.”

    When I took another look at my vibed-up experiment results, there was, as the FASTKERNELS folks might have predicted, a catch. Several catches.

    The baseline, it turned out, spent an awful lot of time doing… kernel dispatch. Even getting 3.91x speedup on SDPA prelude led to an end-to-end model speedup of… 1.007x. Not quite as exciting.
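    A back-of-envelope way to see why, assuming Amdahl's law and that nothing else in the model changed: if a kernel takes a fraction p of wall time and gets s times faster, the whole model speeds up by 1 / ((1 - p) + p / s). Inverting that with the two numbers above says the SDPA prelude was a bit under 1% of the wall time to begin with.

```python
def end_to_end_speedup(fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: only `fraction` of the wall time gets `kernel_speedup` faster."""
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)

def implied_fraction(overall_speedup: float, kernel_speedup: float) -> float:
    """Invert Amdahl's law: what share of wall time must the kernel have been?"""
    return (1.0 - 1.0 / overall_speedup) / (1.0 - 1.0 / kernel_speedup)

p = implied_fraction(overall_speedup=1.007, kernel_speedup=3.91)
print(f"SDPA prelude share of wall time: {p:.2%}")  # 0.93%
print(f"Ceiling even with an infinitely fast kernel: {1 / (1 - p):.4f}x")
```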

    You also had to be very, very careful about how the agents were getting speedups.

    For example, the initial correctness check accepted any output with cos_sim >= 0.95 against the reference kernel. Codex “won” the SwiGLU round by replacing sigmoid(x) with clamp(0.21*x + 0.5, 0, 1): a clamped straight line that only roughly follows sigmoid, with the wrong slope through the middle and hard saturation to exactly 0 or 1 at |x| ≈ 2.4, where the real sigmoid is still around 0.08 or 0.92.
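    To see how loose that check is, here is a small plain-Python sketch comparing the SiLU half of SwiGLU (x * sigmoid(x), leaving out the up-projection multiply for simplicity) against the clamped version on evenly spaced inputs in [-6, 6]. The elementwise tolerances (atol=1e-3, rtol=1e-2, in the spirit of an allclose check) are assumptions, not what the harness used. Cosine similarity comes out comfortably above 0.95, while the largest absolute error is roughly 0.2, near x ≈ -2.4 where the clamped version outputs exactly 0.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def hard_sigmoid(x: float) -> float:
    # The replacement from the text: clamp(0.21*x + 0.5, 0, 1)
    return min(max(0.21 * x + 0.5, 0.0), 1.0)

def silu(x: float, gate) -> float:
    # The SiLU/Swish half of SwiGLU; the up-projection multiply is left out here
    return x * gate(x)

def cosine_similarity(a, b) -> float:
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))

xs = [-6.0 + 0.01 * i for i in range(1201)]  # evenly spaced inputs in [-6, 6]
ref = [silu(x, sigmoid) for x in xs]
cand = [silu(x, hard_sigmoid) for x in xs]

cos_sim = cosine_similarity(ref, cand)
max_abs_err = max(abs(r - c) for r, c in zip(ref, cand))
# Elementwise check in the style of allclose (tolerance values are assumptions)
elementwise_ok = all(abs(r - c) <= 1e-3 + 1e-2 * abs(r) for r, c in zip(ref, cand))
print(f"cos_sim={cos_sim:.4f}  max_abs_err={max_abs_err:.3f}  elementwise_ok={elementwise_ok}")
```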

    It turns out this kind of thing is pretty common. The FASTKERNELS folks found a case where an agent needed to write an all-reduce kernel for cross-GPU synchronization. The test harness they were using was single GPU so the agent just no-op’d it, replacing the all-reduce with a straight tensor copy. This “passed its checker but produces the wrong sum on every scenario of our 4-rank NCCL+Gloo harness.”
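    A toy illustration of why a single-GPU harness can't catch this, simulating ranks as Python lists rather than using NCCL or Gloo: with one rank, a copy and an all-reduce sum produce identical results, so the no-op passes; with four ranks it doesn't.

```python
from typing import List

Tensor = List[float]

def allreduce_sum(per_rank: List[Tensor]) -> List[Tensor]:
    """Reference: every rank ends up with the elementwise sum over all ranks."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def allreduce_as_copy(per_rank: List[Tensor]) -> List[Tensor]:
    """The 'optimized' kernel: each rank just keeps a copy of its own tensor."""
    return [list(t) for t in per_rank]

def harness_passes(world_size: int) -> bool:
    per_rank = [[float(r + 1), float(2 * r)] for r in range(world_size)]
    return allreduce_as_copy(per_rank) == allreduce_sum(per_rank)

print(harness_passes(1))  # True: with one rank, the sum *is* the local tensor
print(harness_passes(4))  # False: the 4-rank sum is [10.0, 12.0] on every rank
```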

    Even when the kernels are right, and fast, it doesn’t mean they are… good? Several of the generated kernels in my experiment were somewhat unshippable due to hardcoded shapes or silent global mutations. FASTKERNELS found similar things:

    “L2 failures are dominated by syntactically valid kernels that respect the per-tensor signature but violate the surrounding production contract.”

    Which I think is the academic way of saying they wouldn’t ship those either.

    If you get the verification wrong in your harness, you will get a kernel optimized for the harness. Use the wrong contract and your kernel will be wrong in exactly the shape of your wrongness.

    Still, a small win is a win, right? My original run had the agents outperforming torch.compile by about 2.6%. At that point I had a friend take a look, and they immediately pointed out that I had hampered the compiler unrealistically and suggested using torch.compile's max-autotune mode. The comparison was especially unfair since the agents each got several cracks at the problem. It turns out that against that baseline the agents lost by 4.6%.

    And that’s pretty similar to what FASTKERNELS found. Across 88 tasks and three agents, the best of the three landed at about 0.94x4 the performance of the production stack.

    Fairly late in the day I decided to replicate the experiments I had run on the Spark on a 3090. That’s Ampere, sm_86, an elder statesman of consumer GPUs at this point. Once again, some of the wins turned out to be weak baselines. For example, Kimi tried the same SDPA-prelude matrix stacking as on the GB10, but on Ampere the 3.91x speedup turned into a 0.74x slowdown. The difference was cuBLAS: it was simply better tuned for the 3090 than for the GB10, and did a much better job of utilizing all 82 SMs. The baseline Kimi had to beat was (relatively) stronger.

    The question of “do agents beat compilers” is hard to answer because what we are (roughly!) measuring is compiler maturity. Agents are most useful in exactly the window where a compiler is weakest: new silicon, untuned heuristics, and libraries that are still evolving5.

    As libraries improve, hardware is better understood, and compilers mature, the value of exploratory search diminishes: there are “right ways” and it’s better to just use them than create custom solutions. If an agent is identifying patterns reliably and repeatably, it may as well author a compiler pass and spend more tokens on the areas that can’t be as cleanly captured.

    1. I think these folks are associated with Tsinghua, but to be honest I am not entirely sure! ↩︎
    2. Each model ran in its own coding harness. One fun takeaway was that the wall-time for generating the kernels was a factor too. Kimi took at least 3x longer than the other agents, spending a lot of tokens on the way, but also generated the most performant kernels of the three on every task on Blackwell, which was not what I was expecting. ↩︎
    3. Codex, KernelAgent and Dr. Kernel, the latter of which I hadn’t heard of but has by far the best name. ↩︎
    4. Codex landed at 0.94x. KernelAgent at 0.78x. Dr. Kernel got 0.53x, but still billed my insurance. ↩︎
    5. I suspect this is particularly true for the GB10, which is an unusual piece of hardware and, in particular, has a lowish number of SMs. ↩︎
  • The elusive order of things

    SIMT offered a fantastic bargain. You write a straight-line program, the machine runs a lot of copies of it, and when one waits for memory the hardware swaps in others. You look with disdain on the less enlightened thread programmers dealing with deadlocks and concurrency etc. etc.

    Choosing what to run where and when is a scheduling problem, and there have been three effective approaches to that so far.

    You can schedule statically: decide ahead of time what all the units should do each tick. You can schedule temporally: swapping in different phases of workers via a pipeline. Or you can schedule spatially: divide the resources of the machine into different roles.

    Which one you pick tends to be determined by the hardware. A chip like a TPU spends most of its silicon on math and fairly little on orchestrating work. That means static scheduling, and a compiler that can build you that schedule.

    Ampere and before1, and all the modern AMD chips, encourage temporal pipelining. The hardware will swap in warps (or waves) when one stalls, and by structuring your kernels into phases you can hide memory latency and keep the chips busy.

    Hopper and beyond are where spatial scheduling started mattering, in the form of warp specialization. Nvidia GPUs let you assign different register footprints to different warp groups. Add warp-group-scoped MMA for compute and TMA for issuing data moves from a single thread, and you have the ingredients to divide the pipeline between groups. Instead of the same worker doing load -> compute -> store, you have different workers each dedicated to one part of the pipeline. Blackwell made this… much harder. TMEM and UMMA added new instruction and memory types, so you now need to schedule movement between shared memory, tensor memory, registers, global memory, and a variety of compute units.

    The problem is: how do you do that?

    To stick with Nvidia for a moment, at the bottom of the stack are barriers. An mbarrier is a phase switch for a specific number of arrivals: one side waits, the other increases the arrival count. When the counter matches the expected number, the phase flips. It’s elegant and straightforward, and easy to get wrong. A classic example is the phase parity bug: if you screw up the wraparound the kernel can work perfectly at first, but then deadlock waiting on the wrong phase.
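    A toy Python model of the parity logic, assuming one barrier per ring-buffer stage with a single expected arrival, and ignoring the "empty" barriers a real pipeline also needs. The wait roughly mirrors the PTX parity wait (it succeeds once the phase with the given parity has completed), but this is a sketch, not PTX. A consumer that forgets to flip its parity on wraparound is fine for the whole first lap; on the second lap it reports "ready" before the producer has written (a stale read) and "not ready" after (in a real pipeline, where the producer is waiting on the consumer to free the slot, that is the deadlock).

```python
from typing import List, Tuple

class ToyMBarrier:
    """Toy model of an mbarrier: counts arrivals and flips phase when all have arrived."""

    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0  # number of completed phases; parity = phase & 1

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:        # last arrival completes the current phase
            self.phase += 1
            self.pending = self.expected

    def try_wait_parity(self, parity: int) -> bool:
        # True once the phase with this parity has completed, i.e. the
        # barrier's current (incomplete) phase has the opposite parity.
        return (self.phase & 1) != parity

def run(num_stages: int, laps: int, flip_on_wrap: bool) -> List[Tuple[bool, bool]]:
    """Record what the consumer's wait reports just before and just after the producer arrives."""
    barriers = [ToyMBarrier(expected_arrivals=1) for _ in range(num_stages)]
    parity = 0
    log = []
    for it in range(num_stages * laps):
        stage = it % num_stages
        bar = barriers[stage]
        ready_before = bar.try_wait_parity(parity)  # should be False: nothing written yet
        bar.arrive()                                # producer's data lands
        ready_after = bar.try_wait_parity(parity)   # should be True
        log.append((ready_before, ready_after))
        if flip_on_wrap and stage == num_stages - 1:
            parity ^= 1                             # wrapped around the ring: flip parity
    return log

good = run(num_stages=3, laps=4, flip_on_wrap=True)
buggy = run(num_stages=3, laps=2, flip_on_wrap=False)
print(buggy)  # first lap (False, True) x3, second lap (True, False) x3
```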

    Next up, libraries like CUTLASS, and newer ones like ThunderKittens, package the patterns you tend to write. The CUTLASS Pipeline combines buffers and synchronization into a unit and makes it easy to compose common structures. This is where much of the expert-kernel-writing time goes, but that time encodes a lot of hardware-specific behavior. Hopper wants one set of patterns, Blackwell another, and even within a generation there can be differences between variants of the hardware. The more explicit the schedule is for the developer, the more they own the portability problem.

    The subsequent step is to make the schedule less explicit, while still keeping the roles visible. AsyncGraphene’s ARef is a good example of this. An ARef is a reference to asynchronously produced data. Basically, a channel, with synchronization attached. A producer writes, a consumer reads, and both sides can know when the other is done. A compiler can then plan a schedule. Nvidia’s TAWA work does this explicitly for Triton, tagging producers and consumers and lowering to ARefs. TLX on the other hand, as well as systems like PipeThreader, allow defining subtasks in a kernel that a compiler can schedule.
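    A toy model of the idea using Python threads, with a hypothetical ToyARef class (not AsyncGraphene's API): a single-slot buffer guarded by a "full" and an "empty" signal. Real ARefs sit on top of hardware barriers and usually have multiple slots; the point here is only that the producer/consumer contract travels with the reference itself, which is what gives a compiler something to schedule around.

```python
import threading
from typing import Any, List

class ToyARef:
    """Hypothetical single-slot async reference: one buffer plus 'full' and 'empty'
    signals, so each side knows when the other has finished with the slot."""

    def __init__(self) -> None:
        self._value: Any = None
        self._full = threading.Semaphore(0)   # released by the producer after writing
        self._empty = threading.Semaphore(1)  # released by the consumer after reading

    def put(self, value: Any) -> None:
        self._empty.acquire()   # wait until the consumer is done with the slot
        self._value = value
        self._full.release()    # data is ready

    def get(self) -> Any:
        self._full.acquire()    # wait until the producer has written
        value = self._value
        self._empty.release()   # slot can be reused
        return value

def run_pipeline(n: int) -> List[int]:
    ref = ToyARef()
    results: List[int] = []

    def producer() -> None:     # think: loads landing in shared memory
        for i in range(n):
            ref.put(i * i)

    def consumer() -> None:     # think: compute consuming each tile
        for _ in range(n):
            results.append(ref.get())

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(run_pipeline(5))  # [0, 1, 4, 9, 16]
```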

    TileIR and cuTile also enable building an explicit graph, but by focusing on the data itself. Attaching information about how data is read or written gives the compiler room to bundle work into tasks and reschedule it.

    Getting the graph is the starting point, but then you need to identify what the right schedule actually is. In practice this involves exploring different shapes and combinations to work out which is best. You can either do that explicitly through heuristics and cost models of the hardware, or do it via searching across many different possible schedules to find the ones that work best. Most systems do both.
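    A minimal sketch of "both" in plain Python: hard constraints prune the space, a cost model ranks what's left, and only a shortlist gets measured. Everything here is hypothetical: the tile parameters, the 100 KB shared-memory budget, the made-up cost model, and the stand-in `measure` function, which in real life would launch and time the kernel on the hardware.

```python
from itertools import product
from typing import Callable, Dict, List

Config = Dict[str, int]

def smem_bytes(cfg: Config, elem_bytes: int = 2) -> int:
    """Shared memory for the A and B tiles across all pipeline stages (fp16 assumed)."""
    return cfg["stages"] * (cfg["bm"] * cfg["bk"] + cfg["bk"] * cfg["bn"]) * elem_bytes

def estimated_cost(cfg: Config, m: int, n: int) -> float:
    """Made-up cost model: fewer output tiles is better, deeper pipelines help up to 3 stages."""
    tiles = -(-m // cfg["bm"]) * -(-n // cfg["bn"])  # ceil-div on both dimensions
    return tiles / min(cfg["stages"], 3)

def search(m: int, n: int, measure: Callable[[Config], float],
           smem_limit: int = 100 * 1024, top_k: int = 3) -> Config:
    space: List[Config] = [
        {"bm": bm, "bn": bn, "bk": bk, "stages": s}
        for bm, bn, bk, s in product((64, 128, 256), (64, 128, 256), (32, 64), (2, 3, 4))
    ]
    # 1. Hard constraints: drop configs that cannot fit at all.
    feasible = [c for c in space if smem_bytes(c) <= smem_limit]
    # 2. Heuristics: rank by the cost model and keep a shortlist.
    shortlist = sorted(feasible, key=lambda c: estimated_cost(c, m, n))[:top_k]
    # 3. Search: "measure" the survivors and keep the fastest.
    return min(shortlist, key=measure)

# Stand-in timer for illustration only; a real one would benchmark on the GPU.
best = search(4096, 4096, measure=lambda c: c["bm"] * c["bn"] / c["stages"])
print(best, smem_bytes(best))
```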

    But what do we need in a kernel DSL?

    If you are building a DSL for writing kernels, the starting point is to reflect whatever the hardware does. This is not only the most direct option but a necessary one, because there are always smart people operating at the frontier who have a strong intuition for how to drive the most performance. They’re often targeting very new hardware that is not yet well understood (sometimes, even by the people who made it).

    Beyond that, deciding what else should be on offer means answering three questions:

    1. How do you think about portability?

    Portability doesn’t mean “write one generic kernel and get peak performance everywhere”. But it can mean: what’s the minimal amount I can express to get correctness and a particular level of performance across hardware. Projects like Helion are explicitly operating at a high level to enable rapid research iteration. Regardless of your view on where the line for “high performance” is, you need something to define what “correct” means.

    Having a good concept of a “task” seems to offer the flexibility to schedule statically, temporally or spatially, but there are a lot of edge-cases to consider.

    2. What do agents change?

    Humans are not going to be writing every, or even most, kernels. We have to figure out how much of that portability or performance is a deterministic search, versus how much is agentic loops exploring the space somewhat probabilistically. Agents make generating code massively cheaper. They can create candidates, run profiling on real hardware, test hypotheses and explore options.

    But we also need a sense of where and how the agents fail, particularly when it differs from the patterns of humans. That includes things like verbosity: more lines (generally!) means more bugs. Performance can be both spiky and somewhat subjective; sometimes small changes can reshape the kernel’s performance, and a faster kernel might only be “correct” within specific numerical accuracy bounds.

    3. How do you think about kernel boundaries?

    A lot of the discussion focuses on GEMMs, which is understandable. But almost all real-world kernel work crosses operator boundaries. FlashAttention didn’t make the matmuls in attention fast; it fused them despite the reduction in the middle.

    When we are writing programs we are expressing intent and providing direction, mixing the “what” and the “how”. This reflects a search vs expressiveness divide: search-oriented approaches want you to focus more on the what, while expressiveness leans more into the how. The more the units inside kernels can compose across kernel boundaries, the more we can optimize across models and discover patterns automatically2.

    The way I think about compilers is that they encode knowledge (in the form of rules and heuristics) about hardware. The more we can move that out of our heads, or the model’s parametric knowledge, the more we can focus our time or tokens on the parts we don’t yet understand.

    1. Mostly! cp.async was introduced in Ampere and it was very impactful in making temporal pipelining work, as it let the mechanism largely hide HBM latency ↩︎
    2. Whether via compiler, or agent! These problems tend to recurse, so you could have pretty much this whole discussion at the IR level too. ↩︎

  • Cutie Fly


    The FlashAttention 4 paper is out and is fascinating: you should read it! One of the things Tri called out on Twitter was that using a Python-based language (CuteDSL) significantly improved the dev loop, not just for him, but for Claude.

    CuTe’s layout algebra plus the quick iteration cycle of a Python DSL is a nice combination. Hence, it’s not too surprising that late last month, AMD dropped FlyDSL, which is, largely, CuteDSL for AMD. This is not a knock on FlyDSL! The project is very open about acknowledging CuTe and its provenance.

    To help navigate, here is a handy translation guide:

    • CuTeDSL: cute.make_layout.
      FlyDSL: flir.make_layout.
    • CuteDSL: cute.composition.
      FlyDSL: flir.composition.
    • CuteDSL: cute.zipped_divide.
      FlyDSL: flir.zipped_divide.
    • CuteDSL: cute.make_tiled_copy_tv.
      FlyDSL: flir.make_tiled_copy_tv.

    FlyDSL also calls out Colfax’s paper from earlier this year: Categorical Foundations for CuTe Layouts. This paper, along with the Integer Set Relations one from Nvidia last year, really started to establish a mathematical formalization of what had been going on in CuTe layouts. This kind of foundation enables verifying the approaches taken in fresh implementations, like FlyDSL’s.

    We can actually go and see that, as the whole compiler is open source. You can compare composition_impl in FlyDSL to the diagrammatic version in section 4.1.3 of the Colfax paper to understand why it works!1

    Given the blistering pace of layout algebra, we shouldn’t be surprised that just a few days after, Cris Cecka of Nvidia dropped a beastly preprint: CuTe Layout Representation and Algebra:

    Colfax Research [19] analyzes CUTE layouts and some operations on them in the context of category theory. In this paper, we intend to provide a more definitive and formal treatment of CUTE concepts and their applications.

    With this kind of thing, having the idea often matters less than the specific implementation that ends up defining the standard. I read this paper as Cecka planting his flag and saying “y’all, this is the real CuTe”. And he cuts no corners.2

    Cecka reframes layout algebra as a theory of loop transformations, showing that the objects you are transforming (Shapes and Strides) and the things you are transforming them with (Shapes and Strides) are the same.3

    One of the cleverest results of this is in Section 2.3.1. Cecka demonstrates that strides don’t have to be just… regular strides. If your stride is in fact a coordinate then each “step” in the stride moves in the coordinate dimension.

    This is, for example, what you need for TMA on Hopper or Blackwell: you tell it the logical position in the tensor and it figures out the physical address internally, handling tiling, swizzling and bank conflict avoidance in hardware. If you stride over coordinates, you can use exactly the same layout algebra as for your computations.
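    To make the "same object" point concrete, here is a flat, non-hierarchical sketch in plain Python. The function names are hypothetical rather than CuTe's API, and CuTe's nested shapes and closed-form composition are left out (here composition is just evaluated pointwise). The same function handles integer strides, which give memory offsets, and coordinate-vector strides, which give logical coordinates of the kind you hand to TMA.

```python
from typing import Callable, Sequence, Tuple, Union

Offset = Union[int, Tuple[int, ...]]

def _scale(c: int, d: Offset) -> Offset:
    return c * d if isinstance(d, int) else tuple(c * x for x in d)

def _add(x: Offset, y: Offset) -> Offset:
    return x + y if isinstance(x, int) else tuple(a + b for a, b in zip(x, y))

def make_layout(shape: Sequence[int], stride: Sequence[Offset]) -> Callable[[int], Offset]:
    """Flat shape:stride layout. A linear index is split into a coordinate
    colexicographically (first mode fastest), then combined with the strides."""
    def f(i: int) -> Offset:
        out = None
        for s, d in zip(shape, stride):
            term = _scale(i % s, d)
            out = term if out is None else _add(out, term)
            i //= s
        return out
    return f

def compose(a, b):
    """Pointwise a(b(i)); CuTe's algebra computes this as a new shape:stride layout."""
    return lambda i: a(b(i))

row_major_2x3 = make_layout((2, 3), (3, 1))
row_major_3x2 = make_layout((3, 2), (2, 1))
print([row_major_2x3(i) for i in range(6)])                          # [0, 3, 1, 4, 2, 5]
print([compose(row_major_2x3, row_major_3x2)(i) for i in range(6)])  # [0, 1, 2, 3, 4, 5]

# Coordinate strides: the same machinery yields logical (row, col) coordinates.
coord_layout = make_layout((2, 3), ((1, 0), (0, 1)))
print([coord_layout(i) for i in range(6)])  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
```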

    Another example was that if a Stride is a bitmask you get something very like Triton’s LinearLayouts!3 That lets you compose layouts with swizzling, using the same composition operators again.
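    The bitmask idea in miniature: treat each bit of the input index as contributing a basis vector over F₂, and XOR the contributions together. Below is a hypothetical Python sketch (not Triton's API) of the classic XOR swizzle for an 8×8 tile, offset = row * 8 + (col ^ row), written as such a linear layout. Treating each output column as a shared-memory bank is a simplifying assumption, but it shows why the swizzle helps: reading one logical column touches eight different physical columns.

```python
from typing import Callable, List

def linear_layout(basis: List[int]) -> Callable[[int], int]:
    """F2-linear layout: input bit k contributes basis[k]; contributions are XORed."""
    def f(idx: int) -> int:
        out = 0
        for k, b in enumerate(basis):
            if (idx >> k) & 1:
                out ^= b
        return out
    return f

# Input index = row * 8 + col, so col lives in bits 0-2 and row in bits 3-5.
col_basis = [1, 2, 4]                                  # column bits map straight through
row_basis = [(8 << k) | (1 << k) for k in range(3)]   # row bit k also flips column bit k
swizzle = linear_layout(col_basis + row_basis)

print([swizzle(r * 8 + 0) for r in range(8)])  # [0, 9, 18, 27, 36, 45, 54, 63]
```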

    The paper is full of these interesting, but also practical, results. Cecka gives guidance such as avoiding ranged slicing (e.g. a[2:4, 1:3]), as it mixes up tile size (an optimization knob) and thread ID (a runtime index)4, and using layouts to work out the vectorization of loads and stores algebraically rather than hard-coding it5.
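    Footnote 5 mentions max_common_vector. The real thing works algebraically (via composition with a right-inverse); this brute-force Python check, with hypothetical names and flat layouts only, just illustrates the question being answered: how many consecutive elements can move as one aligned, contiguous vector in both the source and destination layouts.

```python
from typing import Callable, Sequence

def make_layout(shape: Sequence[int], stride: Sequence[int]) -> Callable[[int], int]:
    """Flat shape:stride layout mapping a linear index to an offset (first mode fastest)."""
    def f(i: int) -> int:
        off = 0
        for s, d in zip(shape, stride):
            off += (i % s) * d
            i //= s
        return off
    return f

def common_vector_width(src: Callable[[int], int], dst: Callable[[int], int], size: int) -> int:
    """Largest power-of-two v such that every aligned group of v logical elements is
    contiguous and v-aligned in both layouts (brute force, not CuTe's algorithm)."""
    best, v = 1, 2
    while v <= size and size % v == 0:
        for layout in (src, dst):
            for base in range(0, size, v):
                start = layout(base)
                if start % v or any(layout(base + k) != start + k for k in range(v)):
                    return best
        best, v = v, v * 2
    return best

contiguous = make_layout((32,), (1,))
padded_rows = make_layout((4, 8), (1, 8))   # runs of 4 elements, each starting a new row of 8
transposed = make_layout((8, 4), (4, 1))    # consecutive indices land 4 apart
print(common_vector_width(contiguous, padded_rows, 32))  # 4
print(common_vector_width(contiguous, transposed, 32))   # 1
```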

    There is something satisfying about a paper on composition that itself pulls together ideas floating around CUTLASS internals, preprints, and alternative implementations, then shows they are all views of the same object. This will help projects like FlyDSL, Triton, and any number of other authoring libraries ground their handling of one of the most painful aspects of kernel dev, in a way that should make life easier for everyone.

    1. I think! My understanding of category theory is similar to my understanding of Skibidi Toilet: I get the idea, but I have so many questions. ↩︎
    2. As an example, Cecka provides a wider generalization than the Colfax paper, demonstrating that CuTe layouts are not strictly closed under group composition: you can’t always compose layouts however you want. But! The failures correspond to real errors, which is the kind of restriction you actually want. ↩︎
    3. Actually, you do a bit better: being strictly in F₂ means Linear Layouts are limited to powers of 2, which it turns out is a bit limiting. ↩︎
    4. This makes it harder for compilers to separate static and dynamic elements. CuTe, and Fly, do this in two stages: zipped_divide to tile, then slice by a dynamic bid, allowing the compiler to optimize (e.g. constant-fold) the static tile parameter. ↩︎
    5. By composing the layout with the right-inverse of the other, apparently! Or calling max_common_vector(src_layout, dst_layout) ↩︎
  • TileIR

    There are a lot of things folks do on GPUs (including, sometimes, graphics), so I have an approximately-correct taxonomy of operations to group them into:

    1. Dense compute: A matmul or a convolution.
    2. Map: Elementwise/pointwise work on each value of a tensor.
    3. Reduce: Process a tensor into fewer dimensions, like a sum.
    4. Transforms: Change data structure. Easy ones like transpose, annoying ones like scatter/gather.
    5. Synchronize / Communicate: Move data, or wait for it (copies, collectives, fences/barriers).

    At the moment people are pouring billions of dollars into hardware that primarily does 1. And, at the same time, many of the greatest minds of our generation are attempting to ensure that the hardware spends as much time doing 1 as possible.

    The biggest barrier to doing a lot of dense compute is 5: moving things in and out of memory. Launching kernels, transferring data between host and device (or between devices), moving data between global memory and registers and so on. It’s like there’s a box defined by Data × Footprint 1 × Time, and everyone is trying to keep it as full as possible.

    This involves tradeoffs. You want to mul as many mats as you can, but you only have so much room to store accumulators. Fetching new data from memory also takes a while. You can keep many in-flight fetches around, but each one expands the kernel Footprint, lowering occupancy.
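    A rough sketch of the Footprint tradeoff as an occupancy calculation in Python. The per-SM limits (65,536 registers, 228 KB of shared memory, 64 resident warps) are roughly Hopper-class numbers and are assumptions here; real occupancy also depends on allocation granularity and per-block limits, which this ignores. Doubling registers per thread halves how many warps can be resident.

```python
def occupancy(regs_per_thread: int, threads_per_block: int, smem_per_block: int,
              regs_per_sm: int = 65536, smem_per_sm: int = 228 * 1024,
              max_warps_per_sm: int = 64) -> float:
    """Fraction of the SM's warp slots that can be filled (simplified)."""
    warps_per_block = -(-threads_per_block // 32)                 # ceil-div
    by_regs = regs_per_sm // (regs_per_thread * threads_per_block)
    by_smem = smem_per_sm // smem_per_block if smem_per_block else by_regs
    by_warps = max_warps_per_sm // warps_per_block
    blocks = min(by_regs, by_smem, by_warps)
    return blocks * warps_per_block / max_warps_per_sm

print(occupancy(64, 256, 0))          # 0.5
print(occupancy(128, 256, 0))         # 0.25: bigger register footprint, fewer warps
print(occupancy(64, 256, 96 * 1024))  # 0.25: now shared memory is the limit
```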

    There are 3 tricks that we can use to help fill up the box by stitching different operations together:

    • Epilogue fusion: take an elementwise op and fuse it onto the end of a dense op, so that when the MMA produces output, the elementwise op can be run while the output data is still in registers. A classic example: fuse the activation after the dense compute in a feed-forward net.
    • Vertical fusion: take two subsequent operations and chain them together to avoid running a loop for one, writing it back, then running a loop for the other2. A classic example is Fused LayerNorm: normally you’d need to add elements, then collect stats for the normalization. You can fuse the two to collect the stats as you add the residual.
    • Horizontal fusion: doing different things over the same data, in parallel. The Q, K, and V projections in a transformer all need the exact same input, so are good candidates to fuse horizontally.
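    A plain-Python sketch of vertical fusion, using residual + RMSNorm (the kernel from the experiment earlier) rather than LayerNorm for brevity. Both versions compute the same thing; the unfused one materializes h and then reads it back for the statistic, while the fused one accumulates the sum of squares in the same loop that does the add. In a real kernel that means h stays in registers instead of round-tripping through global memory.

```python
import math
from typing import List, Tuple

Vec = List[float]

def add_rmsnorm_unfused(x: Vec, residual: Vec, weight: Vec, eps: float = 1e-6) -> Tuple[Vec, Vec]:
    h = [a + b for a, b in zip(x, residual)]        # pass 1: add, write h out
    mean_sq = sum(v * v for v in h) / len(h)        # pass 2: read h back for the statistic
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(h, weight)], h  # pass 3: normalize

def add_rmsnorm_fused(x: Vec, residual: Vec, weight: Vec, eps: float = 1e-6) -> Tuple[Vec, Vec]:
    h: Vec = []
    sum_sq = 0.0
    for a, b in zip(x, residual):   # one pass: add and accumulate the statistic together
        v = a + b
        h.append(v)                 # in a real kernel h would stay in registers
        sum_sq += v * v
    inv_rms = 1.0 / math.sqrt(sum_sq / len(h) + eps)
    return [v * inv_rms * w for v, w in zip(h, weight)], h

x, r, w = [0.5, -1.0, 2.0, 3.5], [1.0, 1.0, -0.5, 0.25], [1.0, 0.5, 2.0, 1.5]
print(add_rmsnorm_fused(x, r, w)[0])
```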

    You rely on the design of the hardware to enable some of this. For example, an epilogue fusion is beneficial because it’s one kernel launch instead of two, and because the work doesn’t need to be written back to global memory, but also because the epilogue can overlap with other work.

    It’s not always obvious how to put these fusions together. Flash Attention was such a breakthrough because it made it possible to fuse the two dense ops in attention across the reduction between them. The naive attention block has a softmax in the middle: Softmax(QK^T / √d) · V. That softmax is a reduction op, which means it needs a full row of QK^T to be computed first, and QK^T is a pretty large matrix. Tri Dao and colleagues realized that if you used online softmax you could calculate the softmax over subsets of the QK^T matrix and avoid materializing the whole thing. That let them fuse QK^T, the softmax, and the multiplication by V into one kernel, at the tile level.
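    Here is the online softmax trick sketched in plain Python for a single query row, assuming scalar values (head dimension 1) and that the 1/√d scaling has already been applied to the scores. The running max m, running denominator l and running weighted sum acc are rescaled whenever a new block raises the max, so no block ever needs the full row.

```python
import math
from typing import List

def naive_attention_row(scores: List[float], values: List[float]) -> float:
    m = max(scores)                                    # needs the whole row first
    weights = [math.exp(s - m) for s in scores]
    return sum(wt * v for wt, v in zip(weights, values)) / sum(weights)

def online_attention_row(scores: List[float], values: List[float], block: int = 4) -> float:
    m = float("-inf")  # running max of the scores seen so far
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running sum of exp(score - m) * value
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        correction = math.exp(m - m_new)  # rescales old l and acc (0.0 on the first block)
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * correction + sum(p)
        acc = acc * correction + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return acc / l

scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4, 3.0, 0.7, -2.0, 1.5]
values = [float(i + 1) for i in range(len(scores))]
print(naive_attention_row(scores, values), online_attention_row(scores, values))
```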

    Tiles are the subsection of a matrix you’re working on at any given time. In a matmul, tiles from both input matrices are loaded and multiplied, to produce an output tile. There’s a useful image of this in the Nvidia blog post on cuTile, Nvidia’s most recent entrant into the kernel-development landscape. To side-step concerns of plagiarism, I had nanobanana plagiarize it for me:

    Illustration: three matrices labeled A, B, and C, with arrows marking the rows and columns active at each step of the loop.
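    The same idea in a few lines of plain Python: each output tile gets its own accumulator, and the loop over k walks matching tiles of A and B into it. The bounds checks play the role of the masks in the Triton IR below; the tile size of 2 is arbitrary.

```python
from typing import List

Matrix = List[List[int]]

def naive_matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def tiled_matmul(a: Matrix, b: Matrix, tile: int = 2) -> Matrix:
    m, k_dim, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]
    for i0 in range(0, m, tile):                     # pick an output tile...
        for j0 in range(0, n, tile):
            acc = [[0] * tile for _ in range(tile)]  # ...with its own accumulator
            for k0 in range(0, k_dim, tile):         # walk matching tiles of A and B along K
                for i in range(tile):
                    for j in range(tile):
                        for k in range(tile):
                            # bounds checks stand in for the masks on edge tiles
                            if i0 + i < m and j0 + j < n and k0 + k < k_dim:
                                acc[i][j] += a[i0 + i][k0 + k] * b[k0 + k][j0 + j]
            for i in range(min(tile, m - i0)):       # write the finished tile back once
                for j in range(min(tile, n - j0)):
                    c[i0 + i][j0 + j] = acc[i][j]
    return c

A = [[i * 5 + k for k in range(5)] for i in range(3)]       # 3 x 5
B = [[k * 4 + j - 3 for j in range(4)] for k in range(5)]   # 5 x 4
print(tiled_matmul(A, B) == naive_matmul(A, B))  # True
```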

    cuTile is built on a well-specified intermediate representation called TileIR. There’s an experimental backend for Triton that lowers to TileIR too. While Triton is block-oriented rather than tile-oriented, in practice what you mostly work on in a thread-block is… a tile. TileIR elevates the tile to a first-class concept.

    You can see this by generating the same kernel against the regular backend and the TileIR backend. Triton’s intermediate representation (TTIR) uses pointer arithmetic: generating offsets, computing masks, loading from explicit addresses. Here’s a bit of an inner loop of a matmul. It groups up which data it wants, loads the tiles a and b by pointer, and computes the dot product:

    %offs_m = tt.make_range {end = 128, start = 0} : tensor<128xi32>
    %a_ptrs = tt.expand_dims %offs_m {axis = 1} : tensor<128xi32> -> tensor<128x1xi32>
    %a_ptrs_1 = tt.splat %stride_am : i32 -> tensor<128x1xi32>
    %a_ptrs_2 = arith.muli %a_ptrs, %a_ptrs_1 : tensor<128x1xi32>
    ...
    %acc_final = scf.for %k = %c0 to %num_k step %c1 iter_args(%acc = %zero) -> (tensor<128x128xf32>) {
      %a = tt.load %a_ptrs, %a_mask : tensor<128x64x!tt.ptr<f16>>
      %b = tt.load %b_ptrs, %b_mask : tensor<64x128x!tt.ptr<f16>>
      %acc_new = tt.dot %a, %b, %acc : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32>
      scf.yield %acc_new : tensor<128x128xf32>
    }

    TileIR, on the other hand, preserves the tile as a semantic object. This snippet does exactly the same thing, but the representation elides the pointer math and masking:

    $33: Tile[float32,(128,128)] = typed_const(value=0)
    accumulator.2: Tile[float32,(128,128)] = for k.1 in $36 (with accumulator.0 = $33)
    do
    $50: Tile[float16,(128,64)] = tile_load_token_ordered(array=A, index=($7, k.1), shape=(128,64))
    $64: Tile[float16,(64,128)] = tile_load_token_ordered(array=B, index=(k.1, $10), shape=(64,128))
    $72: Tile[float32,(128,128)] = tile_mma(x=$50, y=$64, acc=accumulator.0)
    continue $72

    This is a nice IR (compact!), but from my perspective the most interesting part is that load function: tile_load_token_ordered.

    The “time” dimension of the Data × Footprint × Time box is the hardest one to manage. Time questions separate performant kernels from slow ones: when to prefetch, how to overlap loads, and so on. Since the advent of warp specialization, the Triton compiler has been exploring pipelining options through heuristics and autotuning, and kernel engineers have been going straight to the hardware with explicit barrier API extensions like TLX3 and Gluon4.

    TileIR goes a somewhat different route. It assumes an unordered memory model: the order your code is written in does not determine when data is actually available. Instead, each memory operation returns a token and you attach semantics to it: read-only, write-only, read-write and so on.

    By being explicit about memory dependencies you give the compiler freedom to manage the Time dimension. Where accesses don’t overlap the compiler can freely reorder them. Where they do, the token chain tells the compiler exactly what depends on what. The kernel expresses intent; the compiler maps that to the hardware.
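    A toy model of that idea in Python, not TileIR's actual semantics or API: each op declares what it reads and writes, and an op only waits on earlier ops it actually conflicts with (read-after-write, write-after-read, or write-after-write). Everything else is free to move. In the hypothetical op list below, the bias load is written after the MMA but lands in the first group, alongside the A and B loads.

```python
from typing import Dict, List, Set, Tuple

Op = Tuple[str, Set[str], Set[str]]  # (name, reads, writes)

def schedule_levels(ops: List[Op]) -> List[List[str]]:
    """Group ops into levels; ops in the same level have no conflicts between them."""
    level: Dict[str, int] = {}
    for j, (name_j, reads_j, writes_j) in enumerate(ops):
        deps = [
            level[name_i]
            for name_i, reads_i, writes_i in ops[:j]
            if writes_i & (reads_j | writes_j) or reads_i & writes_j  # RAW/WAW or WAR
        ]
        level[name_j] = 1 + max(deps) if deps else 0
    out: List[List[str]] = [[] for _ in range(max(level.values()) + 1)]
    for name, _, _ in ops:
        out[level[name]].append(name)
    return out

ops: List[Op] = [
    ("load_a", {"A"}, {"a_tile"}),
    ("load_b", {"B"}, {"b_tile"}),
    ("mma", {"a_tile", "b_tile", "acc"}, {"acc"}),
    ("load_bias", {"bias"}, {"bias_tile"}),          # written after the MMA...
    ("epilogue", {"acc", "bias_tile"}, {"out_tile"}),
    ("store_c", {"out_tile"}, {"C"}),
]
print(schedule_levels(ops))
# [['load_a', 'load_b', 'load_bias'], ['mma'], ['epilogue'], ['store_c']]
```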

    TileIR is (mostly) targeting Blackwell right now, and the experimental backend is still early. The open question is whether we can express this smoothly enough in kernel syntax to actually take the same kernel across hardware, or whether we are just adding some syntactic sugar that gets bypassed anyway when doing hardware-specific tuning.

    That said, the idea feels pretty right? The tile is the unit with which we can express what we actually mean about memory, ordering, and fusion. The CUDA programming model was always about bounded-linearity within a massively parallel framework, and this loosens the bounds that little bit more.

    1. Use of available registers and shared memory, for example ↩︎
    2. This is loop fusion, in compiler terms. There are other things you can do, but this is the big one. ↩︎
    3. Triton Language Extensions, from Meta. As a disclaimer, these are the folks I work with. ↩︎
    4. From OpenAI ↩︎
  • Megacores

    Megacore - Systole as an 80s metal album cover.

    What we do in machine learning owes a lot to the history of computer graphics. Folks like Kurt Akeley, one of the founders of SGI, identified that 3D graphics have a naturally pipelined structure. You have a high volume of similar operations, such as applying pixel-y soldier textures to a mesh of triangles, and pipelining them exposes a high degree of parallelism.

    Akeley was one of the drivers of OpenGL, which provided a standard interface to that pipeline, and later worked with Nvidia on Cg, a realtime shader language and compiler. Shader languages, as used in Pixar’s RenderMan and other non-realtime 3D use cases, introduced an approach where you could manage lighting programmatically by describing the transforms to each individual element. The shader would be run in parallel across all the geometry or pixels it was addressing.

    With CUDA, Ian Buck and others at Nvidia helped formalize what had been true in the hardware for a while: GPUs were massively parallel processing machines, not just polygon factories. CUDA was part of a move from the supercomputer approach of Single Instruction Multiple Data (SIMD) to Single Instruction Multiple Thread (SIMT). On a Cray or other vector oriented processor you had to pack the work into a vector. CUDA let programmers familiar with CPU threads think in those terms instead. Under the hood, the threads in a warp were executed in lockstep, but they could be masked off to allow for divergence. It was flexible, fast, and attracted the attention of the machine learning community. Because so much of ML is large matmuls, Nvidia bolted on Tensor Cores as specialized co-processors that handled blocks of matrix math efficiently. This combination of performant hardware and flexible software helped make Nvidia the most valuable company in the world, and drive up house prices across the Bay Area.
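    A toy Python model of that masking, with an 8-lane "warp" for readability (real warps have 32 lanes). Both sides of a divergent branch are issued for the whole warp, and a per-lane mask decides which lanes keep the result; a branch that every lane agrees on costs only one pass.

```python
from typing import Callable, List, Tuple

def simt_branch(xs: List[int], cond: Callable[[int], bool],
                then_fn: Callable[[int], int], else_fn: Callable[[int], int]) -> Tuple[List[int], int]:
    """Run `if cond: then_fn else: else_fn` across a toy warp in lockstep.
    Returns the per-lane results and how many branch paths were issued."""
    mask = [cond(x) for x in xs]
    out: List[int] = [0] * len(xs)
    issued = 0
    if any(mask):                    # 'then' path: lanes with mask False sit idle
        issued += 1
        for lane, x in enumerate(xs):
            if mask[lane]:
                out[lane] = then_fn(x)
    if not all(mask):                # 'else' path: the other lanes sit idle
        issued += 1
        for lane, x in enumerate(xs):
            if not mask[lane]:
                out[lane] = else_fn(x)
    return out, issued

lanes = list(range(8))
print(simt_branch(lanes, lambda x: x % 2 == 0, lambda x: 10 * x, lambda x: -x))
# ([0, -1, 20, -3, 40, -5, 60, -7], 2)
```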

    But, it transpires, not everyone loved shoveling their margin to Jensen, and they looked for more cost-efficient ways to run ML workloads. The flexibility for threads to branch, pause or switch requires infrastructure and silicon: big register files per core, multiple levels of cache, and logic to manage swapping warps in and out.

    If you look at the “do the math” parts of a chip, a CPU probably only spends about 10% of silicon on that, with the rest managing the chaos of running an operating system: branch prediction, caching, data movement. A GPU, in contrast, is a wildly efficient machine, with maybe 30-40% of the silicon dedicated to mathing effectively.

    When Google looked at the problem of running inference at their scale back in the dark ages of 2016, they wanted to spend as much of their budget as possible doing the math, to keep the costs as low as they could. The chip they created, the Tensor Processing Unit (TPU), recently hit its 7th iteration, and SemiAnalysis published an extensive breakdown of it, TPU v7 Ironwood, quickly followed by a deep dive on Amazon’s Trainium v3.

    Trainium3 takes a similar approach to Trainium2 and Google’s TPU and builds the chip out of a small number of large NeuronCores. This contrasts with GPU architectures like Nvidia and AMD’s, which instead uses a large number of smaller tensor cores. Large cores are typically better for GenAI workloads since they have less control overhead.

    Dylan and his team are touting these as the first chips to genuinely threaten Nvidia’s moat. The big frontier labs seem interested, with deals and investigation from Anthropic, OpenAI, Meta and others. As the piece repeatedly points out, if you want to understand the dominance of Nvidia you have to focus on the system, and not the microarchitecture. So, of course, I want to talk exclusively about the microarchitecture here.

    TPU, Trainium, as well as other custom approaches like Meta’s MTIA1, lean on an approach called Systolic Arrays. As a recap, Nvidia’s Streaming Multiprocessors (SMs), AMD’s compute units, and so on are cooperative multiprocessors. They access registers, talk to caches and handle the flow of data. Threads can stall waiting on data, and the hardware warp schedulers will swap in another piece of work to keep the chip humming.

    Systolic arrays are different. The name comes from systole, the phase where your heart pumps blood. In a systolic array, you load your data once and fire it through a grid of Processing Elements (PEs). Each element maths its math then passes the result to its neighbor on the next clock tick.
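    Here is a minimal sketch of that pumping, assuming an output-stationary design for simplicity: each PE keeps one output element while the operands flow past it (the TPU's matrix unit is usually described as weight-stationary instead). `systolic_matmul` is a hypothetical NumPy toy, not how any real chip is programmed, but it shows the skewed feeds and the neighbor-to-neighbor handoff on every tick.

```python
import numpy as np

def systolic_matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Toy output-stationary systolic array: PE (i, j) accumulates C[i, j].
    Rows of A enter from the left edge and columns of B from the top edge,
    each skewed by one tick per row/column so matching operands meet."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    a_reg = np.zeros((M, N))   # value each PE passes to its right neighbor
    b_reg = np.zeros((M, N))   # value each PE passes to the neighbor below
    acc = np.zeros((M, N))     # stationary partial sums, one per PE
    for t in range(M + N + K - 2):
        # every PE hands its operands on at the same clock tick
        a_reg[:, 1:] = a_reg[:, :-1].copy()
        b_reg[1:, :] = b_reg[:-1, :].copy()
        # skewed feeds: row i sees A[i, t - i], column j sees B[t - j, j]
        for i in range(M):
            k = t - i
            a_reg[i, 0] = A[i, k] if 0 <= k < K else 0.0
        for j in range(N):
            k = t - j
            b_reg[0, j] = B[k, j] if 0 <= k < K else 0.0
        # each PE multiplies the operands that just arrived and accumulates
        acc += a_reg * b_reg
    return acc

A = np.arange(6, dtype=float).reshape(2, 3)
B = np.arange(12, dtype=float).reshape(3, 4)
print(np.allclose(systolic_matmul(A, B), A @ B))
```

    Nothing in the loop decides anything at runtime: the schedule is fixed by the shapes, which is exactly why the data has to be there on time.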

    This was very much in line with the needs of the original TPU: load a set of model weights up, then pump user requests through as efficiently as possible. TPUv1 only supported int8: it was a low-bit, high-efficiency matmul machine. The data flow needed to be pre-determined: you set it up and make it go, which made it incredibly silicon efficient. You don’t need lots of caches or schedulers, and in fact the original TPU didn’t have any at all!

    The con of course was that you have to get it right! If the data isn’t there to pump in, the whole thing just waits. There is no backup plan to another warp, no other threads. Not only that, but because the systolic arrays are generally a lot bigger (say 256×256 vs the Tensor Cores’ 16×16), you have fewer of them. While an Nvidia GPU might have more than 100 SMs, a Trainium v3 has 8 cores, and a TPU has just 2. Each core is a lot larger, and wasting it gets a lot more expensive.

    Presumably Jeff Dean just programmed these right the first time, but for the rest of Google (and later the world) they spent years building XLA (Accelerated Linear Algebra), a full-graph compiler. In GPU kernel programming the challenge is hiding memory latency and managing register pressure. On a TPU-type approach, there is one massive VMEM that fulfills a similar role to the registers and no cache hierarchy, but you can’t rely on the hardware to swap between jobs. XLA needs to know exactly how the graph works so that it can schedule the right data at the right time.

    TPUs used a VLIW architecture: Very Long Instruction Words. Rather than a traditional instruction set with diverse instructions, VLIW lets you bundle Very Long packages of instructions into single units (kind of a silicon equivalent of German), which execute operations on each of the different units of the core at the same time. This was introduced in TPU v2, and it’s where the pressure on the compiler really multiplied.

    To draw a GPU analogy, if you think about something like Relu(AxB+C), you have a graph of operations: AxB -> Result, Result + C -> Result2, Relu(Result2). To optimize that you could capture it in a CUDA graph, so the whole sequence is launched with a single dispatch and less back-and-forth between CPU and GPU. One step further would be kernel fusion: write one kernel that keeps all the intermediate results in registers and avoids the round trips to higher-tier memory. That lets you bundle up even more, but you need even higher confidence in the sizes involved to avoid running out of registers.
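    Here is a CPU-side sketch of that difference using NumPy, with made-up function names. The unfused version materializes every intermediate as a full array; the fused version walks output tiles and applies the bias add and ReLU while each tile's accumulator is still local, standing in for registers. The tile size of 2 is arbitrary, and on a real GPU the point is where the bytes live, not the Python loop structure.

```python
import numpy as np

def unfused(A, B, C):
    # Three separate "kernels": each one writes a full intermediate array
    result = A @ B                   # AxB -> Result
    result2 = result + C             # Result + C -> Result2
    return np.maximum(result2, 0.0)  # Relu(Result2)

def fused_tiled(A, B, C, tile=2):
    # One "kernel": walk output tiles, keep each tile's accumulator local,
    # and apply the epilogue (bias add + ReLU) before the single write
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and C.shape == (M, N)
    out = np.empty((M, N), dtype=np.result_type(A, B, C))
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=out.dtype)
            for k in range(0, K, tile):
                acc += A[i:i + tile, k:k + tile] @ B[k:k + tile, j:j + tile]
            out[i:i + tile, j:j + tile] = np.maximum(acc + C[i:i + tile, j:j + tile], 0.0)
    return out

A = np.arange(15.0).reshape(3, 5) - 7
B = np.arange(20.0).reshape(5, 4) - 9
C = np.full((3, 4), -3.0)
print(np.allclose(fused_tiled(A, B, C), unfused(A, B, C)))
```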

    VLIW is like parallel kernel fusions: a TPU v2 had 2 matrix units, 2 vector units, 2 scalar units and 2 memory load/store units2. To keep them busy every step the compiler needs to plan ahead enough to give each of them something useful to do. VLIW instructions bundle those ops along with any constants needed into a single instruction. Fusion goes from being an optimization to being a necessity. Once you get it though, you can spend more like 50-60% of your silicon on the part you care most about, and that translates into an excellent total cost of ownership.
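    To see why the compiler has to plan ahead, here is a toy greedy list scheduler that packs ready ops into bundles. The slot counts come from the TPU v2 description above; everything else is an assumption for illustration: every op takes one cycle, an op can only use results from earlier bundles, and `bundle` is a hypothetical helper, not anything XLA actually does. The example is Relu(AxB+C) over two independent tiles.

```python
from collections import Counter

# Slot counts per functional unit, taken from the TPU v2 description above
SLOTS = {"matrix": 2, "vector": 2, "scalar": 2, "mem": 2}

def bundle(ops, deps, slots=SLOTS):
    """Greedy list scheduling into VLIW-style bundles.

    ops:  {name: unit}, in priority order
    deps: {name: set of op names whose results it needs}
    Assumes every op finishes in one cycle, so an op can only use results
    from earlier bundles, never from the bundle it sits in.
    """
    done, remaining, bundles = set(), dict(ops), []
    while remaining:
        used, current = Counter(), []
        for name, unit in remaining.items():
            if deps.get(name, set()) <= done and used[unit] < slots[unit]:
                used[unit] += 1
                current.append(name)
        if not current:
            raise ValueError("cyclic or missing dependency")
        for name in current:
            del remaining[name]
        done.update(current)
        bundles.append(current)
    return bundles

# Relu(AxB + C) for two independent output tiles
ops = {
    "load_a0": "mem", "load_b0": "mem", "load_a1": "mem", "load_b1": "mem",
    "mm0": "matrix", "mm1": "matrix",
    "add0": "vector", "add1": "vector",
    "relu0": "vector", "relu1": "vector",
    "store0": "mem", "store1": "mem",
}
deps = {
    "mm0": {"load_a0", "load_b0"}, "mm1": {"load_a1", "load_b1"},
    "add0": {"mm0"}, "add1": {"mm1"},
    "relu0": {"add0"}, "relu1": {"add1"},
    "store0": {"relu0"}, "store1": {"relu1"},
}
for cycle, ops_in_bundle in enumerate(bundle(ops, deps)):
    print(cycle, ops_in_bundle)
```

    Twelve ops fit in six bundles, but only because the second tile's loads could be slotted in next to the first tile's matmul. With a single tile, most of the eight slots in every bundle would sit empty, which is the problem fusion and careful scheduling are there to fix.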

    Does this mean we should all be cancelling our Rubin orders and buying TPUs? I mean, no. But there is some nuance. Choosing between flexible streaming processors and efficient systolic megacores feels drastic, but I think it might not matter quite as much as it seems.

    Research still overwhelmingly benefits from flexibility. You are running experiments, solving bottlenecks and debugging. Nvidia tends to be the big lab tool of choice thanks to the flexibility, the depth of tooling and the general CUDA ecosystem3.

    If you are mainly serving a massive model, it’s worth the investment to lock down all the weirdness and optimize it. That’s where the megacore chips have proved their mettle first, with TPU, Inferentia4, MTIA and others all starting on that side of the house.

    Folks like Akeley and Buck realized that when you’re building a chip you’re really building a programming model. Get that right, and the model can long outlast the hardware. Balancing expressivity with performance is the thing that lets a platform win: who best lets researchers and engineers define the future without fighting the silicon.

    What seems to be emerging isn’t quite the SIMT/CUDA architecture: it’s something more like expressing the dataflow of tiles on the critical kernels5, while relying on a compiler to optimize the larger graph and compute.

    Making sure that you have access to the right software might be more important than trying to perfectly identify which hardware platform is the once and future king. But also, look, the world moves fast and if you get a Prime Day deal on Trainium instances, you should probably just take it. The hardware can and will change and it can always be adopted, as the frontier labs are showing. If we keep hunting for the expressivity we need, as OpenGL, CUDA, Triton and others have over the years, we will keep unlocking the possibilities in whatever hardware is available.

    1. Disclosure: I work at Meta and like these chips a lot, though no one would let me anywhere near any chip design, luckily enough ↩︎
    2. Newer versions have others too, like the sparse cores in TPU v6 and v7 which are basically dedicated embedding management processors ↩︎
    3. With the notable exception of Google themselves, though the Jax-XLA-TPU ecosystem is very rich internally ↩︎
    4. Amazon remain undefeated at naming things ↩︎
    5. From system to VMEM on megacore approaches, from SMEM to registers on GPUs ↩︎

  • Let’s all switch to FP16?

    Serious scientists use FP64 – 64 bit floating point numbers – for high precision simulations, but in the world of machine learning we got by for the longest time with FP32. The perennial quest for increased FLOPS, particularly when memory bound, made even that seem too expensive though.

    FP16 offered a reduced numeric range, but at half the size. Training with it in practice meant embracing automatic loss scaling1, which kept the values within the range FP16 could represent. Then Google developed BF16: it moved some of the bits from the mantissa to the exponent, so it offered the same numeric range as FP32, but with reduced precision.
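    The trade-offs are easy to poke at from plain Python. This sketch rounds values to FP16 with the standard library's `struct` half-precision format, and emulates BF16 by rounding a float32 to its top 16 bits (round-to-nearest-even, finite inputs only); `to_fp16` and `to_bf16` are helper names made up here. It shows FP16's small range, BF16's FP32-like range, the precision gap near 1.0, and why loss scaling rescued FP16 training: a tiny gradient that underflows to zero survives if it is scaled up first and unscaled afterwards.

```python
import math
import struct

def to_fp16(x: float) -> float:
    # struct's "e" format is IEEE half precision; packing rounds to nearest
    # even and raises OverflowError if x is beyond FP16's range
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    # BF16 is the top 16 bits of a float32: same 8 exponent bits, 7 mantissa bits.
    # Round to nearest even on the dropped low 16 bits (finite, non-NaN x only).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

# Range: FP16 tops out at 65504; BF16 keeps FP32's exponent range
print(to_fp16(65504.0))              # largest finite FP16 value
print(math.isfinite(to_bf16(3e38)))  # BF16 handles FP32-scale magnitudes

# Precision near 1.0: FP16 steps by 2**-10, BF16 by 2**-7
x = 1.0 + 2**-9
print(to_fp16(x) == x, to_bf16(x) == 1.0)  # FP16 keeps it, BF16 rounds it away

# Loss scaling: a small gradient underflows in FP16 unless scaled first
grad = 1e-8
scale = 1024.0  # a power of two, so unscaling adds no extra rounding
print(to_fp16(grad))  # 0.0: the gradient vanished
recovered = to_fp16(grad * scale) / scale
print(recovered)      # close to 1e-8 again
```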

    Since TPUv3 back in 2018 and Ampere in 2020 it’s been finding its way into hardware and has become the go-to format for training for many models. Life was good, and training in FP16 was mainly discussed as a memory of hard winters past.

    Last week [2510.26788] Defeating the Training-Inference Mismatch via FP16 dropped and threw ML Twitter into a tizzy by arguing that everyone was doing Reinforcement Learning wrong and the solution… was FP16.

    “In this work, we take a step back from the complex algorithmic fixes and investigate the root cause of the numerical mismatch: floating-point precision. We identify that the modern standard for mixed-precision training, BFloat16 (BF16), is the primary culprit. While BF16 has a wide dynamic range which is excellent for stable pre-training, its low precision makes it highly susceptible to rounding errors that accumulate and eventually cause the training and inference policies to diverge.”

    The process for RL generally looks like:

    • Get a problem in a prompt
    • Do inference on the model to generate complete responses (a rollout)
    • Get a reward score for the response(s)
    • Run a training loop on the model to update weights based on the reward

    If you want to be on-policy (which generally trains better) you need the “model” in steps 2 and 4 to be identical, but the actual code running around the model in the two steps is different: for example, you don’t use a KV cache in training and you don’t store gradients in inference. But you do want to keep the weights and numerics of the model the same, else your on-policy training becomes a little bit off-policy.

    The last year of LLM research has been scaling this up, which requires managing a training and inference flow efficiently. This ongoing pressure to optimize the two paths independently leads to a risk of divergence. The paper finds that absolutely happens, and the divergence collapses the effectiveness of the learning. Unless, that is, you use FP16:

    “This is precisely why switching to FP16 provides a fundamental solution. With its 10 mantissa bits, FP16 offers 8 times more precision (2¹⁰ values vs. 2⁷ values) than BF16. This higher fidelity means that the outputs of the training and inference engines are much more likely to be numerically identical. The increased precision creates a buffer that absorbs the minor implementation differences between the two engines, preventing rounding errors from accumulating and causing a policy divergence”

    The paper does an excellent job of breaking down the many reasons why this happens, but it’s pretty clear that FP16 is a patch: if you can’t get your numerics perfectly matched, then having more precision gives you more wiggle room.

    About a month before this the ByteDance folks posted a fantastic deep dive into RL collapse from discrepancies between training and inference: When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

    They identify a range of concerns, including straight up bugs:

    “According to this GitHub issue, we set disable_cascade_attn=True when initializing the vLLM engine and found that it significantly helps reduce the training-inference mismatch in experiments conducted on A100 GPUs.”

    Many of the experiments in the FP16 vs BF16 paper were run on A100s2, so some backlash emerged suggesting that perhaps this whole thing is just a kernel error. But as ByteDance showed, there really is a lot going on that can make things worse.

    Another example is Horace He’s recent work at Thinking Machines on a related problem: Defeating Nondeterminism in LLM Inference – Thinking Machines Lab

    “As mentioned above, one common explanation for why kernels add numbers in different orders is the “concurrency + floating point” hypothesis. The hypothesis states that if the order in which concurrent threads finish is nondeterministic and the accumulation order depends on the order in which concurrent threads finish (such as with an atomic add), our accumulation order will be nondeterministic as well.”

    Horace calls out variance in batching as the primary cause of nondeterminism, and hence another quite plausible cause of inference/training mismatch.

    “In other words, the primary reason nearly all LLM inference endpoints are nondeterministic is that the load (and thus batch-size) nondeterministically varies! This nondeterminism is not unique to GPUs — LLM inference endpoints served from CPUs or TPUs will also have this source of nondeterminism.”
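    Horace's point is easy to reproduce in miniature. The plain-Python sketch below (hypothetical names throughout) reduces the same row two ways: in one pass, or split into chunks whose partial sums are combined afterwards, which is the general shape of split-K style reductions. `choose_splits` is an invented heuristic standing in for a kernel that picks a different strategy when the batch is small. The row's values are chosen to make the rounding visible; real differences usually sit in the last few bits, but they compound.

```python
def sequential_sum(values):
    # Plain left-to-right float accumulation. (Avoids the built-in sum(),
    # which uses compensated summation for floats in Python 3.12+.)
    total = 0.0
    for v in values:
        total += v
    return total

def split_reduce(row, num_splits):
    # Split the row into contiguous chunks, reduce each chunk, then reduce
    # the partial results: a different association order for the same data
    chunk = -(-len(row) // num_splits)  # ceiling division
    partials = [sequential_sum(row[i:i + chunk]) for i in range(0, len(row), chunk)]
    return sequential_sum(partials)

def choose_splits(batch_size):
    # Hypothetical heuristic: with a small batch, split each row's reduction
    # across more workers to keep the hardware busy
    return 1 if batch_size >= 4 else 2

row = [1e16, 1.0, -1e16, 1.0]
print(split_reduce(row, choose_splits(batch_size=8)))  # 1.0
print(split_reduce(row, choose_splits(batch_size=1)))  # 0.0
```

    Nothing here is concurrent or random: the same batch size always gives the same answer. The variance comes from the batch size itself changing with server load.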

    The meta-point is that despite being a field fundamentally based in mathematical precision we have been sloppy with numerics, pretty much everywhere.

    Ed Yang’s session in the PyTorch Conference keynote3 a couple of weeks back called this problem out from the perspective of scaling up ML infrastructure. He presented a number of solutions to try to address it, which often come down to giving folks control over precisely how the numerics work in different parts of their model.

    While the focus here was on RL and FP16, the reality is we deal with this for training->inference in much simpler cases, as well as when moving models between different hardware. Even within generations this can be hard: one of the fun infra problems when the H100 came out was everyone discovering that the FP8 tensor cores in Hopper used a 22-bit accumulator for intermediate calculations, which wasn’t really documented!
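    To get a feel for why an undocumented accumulator width matters, here is a toy model. It assumes, purely for illustration, that the limited accumulator keeps 13 explicit mantissa bits and truncates after every add; that is a guess at what "22 bits" could mean, not a description of the real hardware. Summing 32,768 ones stalls at 16,384 once the running total needs more bits than the accumulator has. The hypothetical `promoted_dot` flushes a partial sum into a full-precision accumulator every 128 elements, one common style of mitigation; the interval is an arbitrary choice here.

```python
import math

def truncate_mantissa(x: float, mantissa_bits: int = 13) -> float:
    """Keep only `mantissa_bits` explicit mantissa bits, rounding toward zero."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (mantissa_bits + 1)  # +1 for the implicit leading bit
    return math.ldexp(math.trunc(m * scale) / scale, e)

def limited_dot(a, b, mantissa_bits=13):
    # every partial sum is squeezed back into the narrow accumulator
    acc = 0.0
    for x, y in zip(a, b):
        acc = truncate_mantissa(acc + x * y, mantissa_bits)
    return acc

def promoted_dot(a, b, interval=128, mantissa_bits=13):
    # short runs in the narrow accumulator, then add into a full-precision one
    total = 0.0
    for start in range(0, len(a), interval):
        total += limited_dot(a[start:start + interval], b[start:start + interval], mantissa_bits)
    return total

ones = [1.0] * 32768
print(limited_dot(ones, ones))   # 16384.0: the accumulator stops growing
print(promoted_dot(ones, ones))  # 32768.0
```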

    The balance between speed and accuracy is often effectively made empirically: if something is faster, and works, then at some level it’s right! Reinforcement Learning mixes together different evolutionary chains of optimizations, so maybe those serious scientists with their FP64 were onto something. Not because they absolutely needed the precision, but because they needed to know they had the precision.

    We’re probably not going to switch back to FP164 industry-wide, but building better numerical grounding into the tools we use is going to make everyone’s lives easier, eventually!

    1. torch.cuda.amp and friends ↩︎
    2. Though they did verify some of it on Hopper as well, which some people seemed to miss ↩︎
    3. Check out the recording: Keynote: PyTorch Technical Deep Dive – Alban Desmaison, Peng Wu, Mark Saroufim & Edward Yang, Meta ↩︎
    4. Especially since most labs are doing so much with FP8 or less these days, and it would probably annoy a bunch of chip designers ↩︎
  • Helion and the evolving GPU programming model

    Helion: A High-Level DSL for Performant and Portable ML Kernels – PyTorch

    Lots of announcements around the Triton and PyTorch Conferences this week, including the 1.0 release of Helion, a high-level kernel authoring DSL:

     It establishes a new layer of abstraction that bridges the user-friendly simplicity of PyTorch with the performance of a lower level language. By automating tedious and error-prone tasks like tensor indexing, memory management, and hardware-specific tuning, Helion empowers developers to focus on algorithmic logic rather than hardware-specific implementation details. Helion achieves this balance by pairing a familiar, PyTorch-centric syntax with a powerful autotuning engine that automates the complex search for optimal kernel configurations. This results in a system that delivers performance portability across hardware architectures while drastically reducing development effort. 

    There has been a bit of an explosion in kernel-authoring options recently with CuTe-DSL and CuTile from Nvidia, TileLang (as featured in recent DeepSeek releases), Gluon and TLX1 as well as evolutions to core Triton, Thunderkittens, Pallas, and others.

    There are a couple of different axes of progress occurring in GPU kernel authoring. The first is between researcher-friendly declarative code that is easy to iterate on and tightly written, hardware-friendly imperative code.

    It’s a classic developer-experience trade-off: you let people tell you what they want to do (matmul these things then apply a softmax) or you let people tell you precisely how to do it (run this dot product on these SMs then aggregate the result).

    In general you want to stay as high-level as possible, particularly if you are experimenting with lots of different variants in a research type setting, but you may have a bound on the performance hit you can accept. A common example is you want to iterate on some attention variant, but don’t want to completely give up on the performance wins of Flash Attention.2

    Triton and others provided an interesting middle ground: it was easy enough to iterate with thanks to being embedded in Python, and was performant enough as it leveraged a compiler to automatically apply some optimizations. You are still much more imperative than in a PyTorch program, but you work at a higher level of abstraction: rather than writing programs which own a thread of data, as in CUDA, you think about a tile of data. The ThunderKittens docs put this well:

    A GPU is not really a 1000×1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16×16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16×16 values.

    The next abstraction that frameworks developed was how to represent data across the memory hierarchy. To take advantage of the tensor cores you have to have data laid out in a specific way in registers. But you are better off loading data in a different order in global or shared memory. CuTe offered a big benefit by giving you types to represent layouts that could be composed, making it easier to keep track of the data movement required. Triton and others leaned on the compiler to infer the right layouts and offered higher-level APIs to copy data between stages.
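    The core CuTe idea, a layout as a shape plus strides that maps coordinates to memory offsets, can be sketched in a few lines. This `Layout` class is a hypothetical toy, not the CuTe-DSL API: it only shows that row-major, column-major and transposed views are just different strides over the same memory, and that tiling is another layout built by splitting each mode into an inside-the-tile part and a which-tile part. Real CuTe layouts are hierarchical and support much richer composition.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Layout:
    """Toy CuTe-style layout: a shape and a stride per mode.
    Calling it maps a logical coordinate to a linear memory offset."""
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]

    def __call__(self, *coord: int) -> int:
        assert len(coord) == len(self.shape)
        for c, s in zip(coord, self.shape):
            assert 0 <= c < s, "coordinate out of bounds"
        return sum(c * d for c, d in zip(coord, self.stride))

    def transpose(self) -> "Layout":
        # A transposed view: same memory, modes swapped, no data movement
        return Layout(self.shape[::-1], self.stride[::-1])

    def tile(self, tile_shape: Tuple[int, ...]) -> "Layout":
        # Split each mode into (inside tile, which tile); assumes even division
        outer = tuple(s // t for s, t in zip(self.shape, tile_shape))
        outer_stride = tuple(d * t for d, t in zip(self.stride, tile_shape))
        return Layout(tuple(tile_shape) + outer, self.stride + outer_stride)

row_major = Layout((4, 8), (8, 1))
col_major = Layout((4, 8), (1, 4))
print(row_major(1, 2), col_major(1, 2))      # 10 9
print(row_major.transpose()(2, 1))           # 10: same element, transposed view
tiled = row_major.tile((2, 4))               # (2, 4, 2, 2):(8, 1, 16, 4)
print(tiled(1, 3, 1, 1) == row_major(3, 7))  # True
```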

    This started to get challenging on Hopper, thanks to TMA3 and the limitations of memory bandwidth, which gets to the second evolution happening in GPU kernels: how do you orchestrate the movement of data between memory layers while keeping the tensor cores saturated? This involved techniques like warp specialization, where individual warps do different operations towards a shared goal. That means carefully allocating ownership over registers to avoid warps stepping on each other. Blackwell4 made this even trickier with the addition of TMEM, 2-CTA mode and other features that offered more performance but required even more careful orchestration.
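    A back-of-the-envelope model shows why this orchestration is worth the pain. The sketch below is a hypothetical timing model, not a GPU simulator: a producer (think of the TMA-load warp) fills one of `num_stages` buffers per tile, a consumer (think of the MMA warps) computes on it, and a buffer can only be refilled once the compute that read it has finished. Costs are abstract cycles.

```python
def pipeline_time(n_tiles: int, load: int, compute: int, num_stages: int) -> int:
    """Finish time of the last compute, in abstract cycles.
    A producer loads tile i into one of `num_stages` buffers; a consumer
    computes on tile i once it has landed. A buffer is reused only after
    the compute that read it has finished."""
    load_done = [0] * n_tiles
    compute_done = [0] * n_tiles
    for i in range(n_tiles):
        prev_load = load_done[i - 1] if i > 0 else 0
        # the buffer for tile i was last used by tile i - num_stages
        slot_free = compute_done[i - num_stages] if i >= num_stages else 0
        load_done[i] = max(prev_load, slot_free) + load
        prev_compute = compute_done[i - 1] if i > 0 else 0
        compute_done[i] = max(load_done[i], prev_compute) + compute
    return compute_done[-1]

print(pipeline_time(4, load=3, compute=2, num_stages=1))  # 20: fully serialized
print(pipeline_time(4, load=3, compute=2, num_stages=2))  # 14: loads overlap math
```

    With one buffer, loads and math serialize; with two, the total approaches the slower of the two streams, which is the overlap warp specialization is trying to buy.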

    In compiler terms this is a scheduling problem, and in general the industry is quite good at it! CPUs give compilers a lot of leeway to schedule operations efficiently because they have a great deal of support for out-of-order execution, well documented ops, and substantial caches. GPUs process groups of threads5 in lockstep and demand strict timing about when to insert barriers, issue async operations and so on.

    A GPU scheduler has to tag operations to specific warp slots in advance, assign register counts to them to avoid conflicts, and sync them with barriers. It’s a lot more brittle: if we guess wrong, we can idle the Tensor Cores and tank efficiency. The actual execution model is a bit of a black box too: the target for compilers (PTX) is itself further compiled to SASS by ptxas.

    Across the industry we’ve been exploring ways to be more explicit without giving away all of the operational and developer efficiency gains of higher-level languages. CuTe-DSL offers a very close-to-hardware model but in a Pythonic package6; Gluon (OpenAI) and TLX (Meta) add extensions to allow modelling pipelines in code without getting rid of the Triton compiler; and TileLang builds on TVM with explicit pipeline declarations.

    One of the reasons for this variety is we don’t quite know how to express a warp-group pipelined execution model. For example, TileLang has a pipelined construct:

    for k in T.Pipelined(loop_range, num_stages=num_stages):
        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)  # Q @ K^T
        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
        Rescale(acc_o, scores_scale)  # Apply correction
        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)  # P @ V

    Gluon has a descriptor that allocates resources like registers explicitly to warps:

    gl.warp_specialize(
            (config, chnls, descs, M, STAGE),     # Args to correction stage
            _attn_fwd_correction,                  # Trunk task (1 warp, 192 regs)
            (config, chnls, descs, M, STAGE),     # Args to specialized tasks
            [
                _attn_fwd_softmax0,    # 4 warps, 192 registers - Softmax tile 0
                _attn_fwd_softmax1,    # 4 warps, 192 registers - Softmax tile 1
                _attn_fwd_mma,         # 1 warp, 24 registers  - Matrix multiplies
                _attn_fwd_load,        # 1 warp, 24 registers  - TMA loads
                _attn_fwd_epilogue,    # 1 warp, 24 registers  - Store results
            ],
            [4, 4, 1, 1, 1],          # Warps per stage
            [192, 192, 24, 24, 24]    # Registers per stage
        )

    And TLX tags sections of code with contexts to indicate groupings, and also allocates resources:

    with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
                        registers=232,
                        replicate=NUM_MMA_GROUPS):
        ...  # the MMA warps' work goes here (elided)

    They can all work, and finding the best trade-off is a good goal, but in all cases they force a lot of decisions. As an example, the allocation of how many registers to use is not only operation dependent, it’s hardware dependent, and that makes portability between hardware (even different generations from the same vendor) expensive. Manual controls are necessary: it takes time to develop the compiler passes and heuristics to optimally divide work, so handing explicit control over7 is beneficial, particularly when serving at scale. The cost is complexity and portability. This is where Helion takes a different tack.

    Anyway, so what about Helion?

    Helion instead takes a point on the line above Triton, but below the ML frameworks. It focuses on just expressing what you want to happen from the tile perspective.

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc

    Under the hood, this compiles down to Triton. You might think that would make it a bit of a no-op on performance, but in practical terms it’s often better. The reason is search: Helion can autotune across a wide range of parameters, then let you bake them into your kernel once you’ve identified good ones for your specific setup. The example in the blog post shows how many dimensions of search need to occur:

    @helion.kernel(config=helion.Config(
        block_sizes=[64, 64, 64],
        loop_orders=[[0, 1]],
        l2_groupings=[4],
        range_unroll_factors=[0, 1],
        range_warp_specializes=[None, False],
        range_num_stages=[0, 3],
        range_multi_buffers=[None, False],
        range_flattens=[None, None],
        num_warps=8,
        num_stages=6,
        indexing='block_ptr',
        pid_type='flat'
    ))
    def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        ...  # body: the same tile loop as the example above (elided)

    This makes moving to different hardware as simple as redoing the search process, and offers a much more comprehensive exploration than most folks would do when hand-rolling a lower level kernel. It’s a very interesting idea, and I’m glad to see more people kicking the tires!
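    The search-then-bake-in workflow is simple to sketch, even though Helion's autotuner explores far more knobs than this. Below, a NumPy blocked matmul stands in for a kernel, `autotune` times every point in a tiny grid of block sizes and returns the fastest, and `functools.partial` bakes the winner in. All names are hypothetical, and CPU timings say nothing about GPU configs, so treat it as an illustration of the workflow only.

```python
import functools
import itertools
import time
import numpy as np

def blocked_matmul(x, y, block_m=32, block_n=32, block_k=32):
    # Stand-in "kernel" whose speed depends on its block sizes
    m, k = x.shape
    _, n = y.shape
    out = np.zeros((m, n), dtype=np.result_type(x, y))
    for i in range(0, m, block_m):
        for j in range(0, n, block_n):
            for kk in range(0, k, block_k):
                out[i:i + block_m, j:j + block_n] += (
                    x[i:i + block_m, kk:kk + block_k] @ y[kk:kk + block_k, j:j + block_n]
                )
    return out

def autotune(fn, args, search_space, reps=3):
    """Benchmark every config in the grid and return the fastest one."""
    keys = list(search_space)
    best_cfg, best_time = None, float("inf")
    for values in itertools.product(*(search_space[key] for key in keys)):
        cfg = dict(zip(keys, values))
        fn(*args, **cfg)  # warm-up run, not timed
        start = time.perf_counter()
        for _ in range(reps):
            fn(*args, **cfg)
        elapsed = (time.perf_counter() - start) / reps
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64))
y = rng.standard_normal((64, 64))
space = {"block_m": [16, 64], "block_n": [16, 64], "block_k": [16, 64]}
best = autotune(blocked_matmul, (x, y), space)
matmul_tuned = functools.partial(blocked_matmul, **best)  # bake the winner in
print(best)
```

    Moving to new hardware means rerunning `autotune` and baking in a new winner; the kernel source itself doesn't change.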

    Low-level optimizations aren’t going away any time soon, but I’m glad to have more exploration in the kernel development space. Finding the right abstractions and right compiler approaches to keep scaling kernel development will help make it accessible to more and more people and ensure that we can evolve our kernels with the hardware.

    1. Also a Meta thing, disclaimer. ↩︎
    2. This is the logic behind FlexAttention, which was one of the lights that guided the way towards Helion. ↩︎
    3. Fully async copies – a separate execution engine to move data ↩︎
    4. Well, datacenter Blackwell. Consumer Blackwell lacks TMEM and 2-CTA, so it has a more Hopper-like programming model. I’m not sure yet what the DGX Sparks have! ↩︎
    5. Warps – 32 threads on Nvidia, or waves, 64 threads on AMD. The important distinction is that all these threads are doing the same thing: you can mask some of them out, but they have a fairly simple march through the instruction. ↩︎
    6. With a JIT! ↩︎
    7. Without making people write templated C++, sorry Ben ↩︎
  • Constraints & Orchestrators

    I recently read a few posts that helped connect the dots on why Python is a) so successful as the lingua franca of ML and b) likely to stay successful in the future1.

    ML code reads like one program, but runs many: CUDA kernels, vectorized CPU loops, graph compilers and a bunch of glue code moving data around and tying things together. Python has continually improved at balancing two somewhat competing challenges: constraining the hot path so compilers can optimize it and structuring an orchestration path so humans can reason about it.

    Hot Path

    constrained languages are easier to optimize by Jynn Nelson touches on this:

    we should not be asking “what language can i use everywhere for every purpose”; we should build meta-languages that allow you to easily use the right tool for the job. this is already true for regular expressions and query languages; let’s go further. i want inline futhark; inline CSS selectors; inline datalog; ffi between python and C that’s trivially easy. the easier we make it to interop, the easier it becomes to pick the right tool for the job.

    Compilers are generally going to perform better if you have regular shapes, minimal side effects, predictable memory access and so on, but you want languages to be expressive and flexible, particularly when “research” is a big part of the work. In practice, that’s precisely what happens with ML: torch.compile lowers PyTorch graphs to an IR and (often) emits Triton kernels. Being able to hand off inner loops to specialized languages allows compilers and runtimes to optimize and target the use cases they are best at.

    While this is (somewhat) clear for GPUs or other accelerators with distinctive programming models, I think it’s also largely true for getting the best out of modern CPUs. Daniel Lemire’s SEA 2025 talk covers nearly a decade of performance work and sums it up: modern CPUs do nearly as many instructions per cycle as you can feed them. To really maximize performance you need to batch work, reduce instruction counts and vectorize. We can do some of that in the general Python2 runtime, but dynamic dispatch, aliasing and side effects all make the job a lot harder. We can add speculative guards, which can be hard to reason about, or give up and lose performance. By having DSLs3 that add additional constraints we can give ourselves the ability to get much, much higher performance without sacrificing the overall flow of our program.
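    A small illustration of the constraint trade, with NumPy standing in for the library-plus-execution-engine style of DSL: the same relu-then-sum-of-squares written as a general Python loop, and as whole-array operations over a single fixed dtype. The function names are made up. The second form gives up per-element branching and side effects, and in exchange the work runs in NumPy's compiled loops instead of the interpreter.

```python
import numpy as np

def relu_sq_sum_loop(xs):
    # General-purpose Python: per-element dispatch, a branch, boxed ints
    total = 0
    for v in xs:
        if v > 0:
            total += v * v
    return total

def relu_sq_sum_vectorized(xs):
    # Constrained form: one dtype, no side effects, whole-array operations
    x = np.asarray(xs, dtype=np.int64)
    return int(np.sum(np.where(x > 0, x * x, 0)))

values = list(range(-5, 6))
print(relu_sq_sum_loop(values), relu_sq_sum_vectorized(values))  # 55 55
```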

    Orchestration Path

    Python is unusually good as an orchestrator. From a readability perspective the language is very readable at baseline, and as long as libraries and DSLs stay Pythonic they tend to inherit that intelligibility. The challenge with orchestration is coordinating work in such a way that your most precious resources are well utilized. The investments in Free-Threaded Python make concurrency a lot cheaper, but they don’t magically fix the challenge of coordination.

    asyncio: a library with too many sharp corners covers some of the many failure modes the community has encountered with asyncio, and makes a case for Trio- or AnyIO-style structured concurrency, which keeps failure modes manageable.

    asyncio is not a good library. It is constantly full of sharp edges everywhere with implementation details leaking and poorly designed APIs forcing end users into odd code patterns to avoid fundamental flaws in the interfaces.

    This is very much a readability version of the constraints concern on the hot path. Threads are a bad application-level abstraction over shared mutable state: reasoning about races and cancellation is hard, and the primitives are always leaky. But threads are a perfectly fine implementation detail behind a more constrained API, like task groups, actors and so on.

    One area that I do think needs sustained improvement is how we debug and trace across this kind of set up: it’s been challenging even in a controlled environment to really understand how all the pieces interact in a reasonably scaled ML workload, and I imagine that problem will only get worse. But I also expect that the flexibility and breadth of Python will end up a boon there as well.

    1. Beyond just sheer momentum, of course. ↩︎
    2. Or any language! Certainly for some optimizations having a JIT for Python would (and does) make life easier. ↩︎
    3. Whether that is an embedded JIT like Triton or a library+execution engine like Polars. ↩︎
  • PyTorch Conference 2025

    The schedule is up for the 2025 edition of the PyTorch conference, which is now at the Moscone West in San Francisco.

    https://events.linuxfoundation.org/pytorch-conference/program/schedule/

    There are a lot of great sessions, but I’ll highlight some I personally find particularly interesting:

    Post-Training: Clearly a big theme this year, with some interesting talks from multiple groups:

    General Training

    Kernel development

    Compilers

    Inference

    I’m looking forward to October!

  • Cute-DSL

    In May Nvidia shipped CuTe‑DSL, the Python library they teased at GTC earlier in the year that mirrors CUTLASS’s C++ tensor‑layout abstractions. Then, at the start of June, the ‑dev label disappeared (so presumably it’s production-ready now). The pitch is simple: Write speed‑of‑light kernels from the comfort of Python.

    Of course, nothing about CUDA is ever really simple. CuTe‑DSL gives the full Cutlass experience1, wrapped in an ever so slightly more approachable interface.

    Getting Cute: Transpose

    Matrix transpose felt like a reasonable ‘hello world’: B[j,i] = A[i,j]. The PyTorch version is simple: torch.transpose(input, 0, 1).

    To get a baseline, here is a simple transpose kernel in Triton. We tl.load, flip the coordinates and tl.store it back.

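    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def triton_transpose_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        # 2D block coordinates
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        # Calculate offsets
        offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

        # Load a tile of the row-major (M, N) input, with masking
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        a = tl.load(input_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)

        # Store transposed: tile element (m, n) lands at row n, column m
        # of the row-major (N, M) output
        tl.store(output_ptr + offs_n[None, :] * M + offs_m[:, None], a, mask=mask)

    def triton_transpose(x: torch.Tensor, BLOCK_SIZE: int = 32) -> torch.Tensor:
        # Assumes a contiguous 2D tensor on the GPU
        M, N = x.shape
        out = torch.empty((N, M), device=x.device, dtype=x.dtype)
        grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE))
        triton_transpose_kernel[grid](x, out, M, N, BLOCK_SIZE=BLOCK_SIZE)
        return out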

    Here’s the same idea in CuTe‑DSL. CuTe leverages decorators and Python’s ability to integrate with JITs. Anything decorated with @jit runs host-side (on the CPU), while @kernel is used for device-side code (on the GPU). There are both AST-based and tracing-based options, depending on the presence of dynamic shapes or control flow.

        @cute.kernel
        def transpose_kernel(self, mA: cute.Tensor, mB: cute.Tensor):
            tidx = cute.arch.thread_idx()[0]
            tidy = cute.arch.thread_idx()[1]
            bidx = cute.arch.block_idx()[0]
            bidy = cute.arch.block_idx()[1]
            # This might all be unnecessary
            # but I was fearful of the compiler
            tile_start_m = cutlass.Int32(0)
            tile_start_n = cutlass.Int32(0)
            global_m = cutlass.Int32(0)
            global_n = cutlass.Int32(0)
            M = cutlass.Int32(0)
            N = cutlass.Int32(0)
            val = cutlass.Float32(0.0)
            # Calculate tile starting positions
            tile_start_m = bidy * self._tile_size
            tile_start_n = bidx * self._tile_size
            # Calculate global coordinates for this thread
            global_m = tile_start_m + tidy
            global_n = tile_start_n + tidx
            # Get matrix dimensions at runtime
            M = mA.shape[0]
            N = mA.shape[1]
            # Bounds checking and transpose operation
            if global_m < M and global_n < N:
                val = mA[global_m, global_n]
                # Transpose: B[n, m] = A[m, n]
                mB[global_n, global_m] = val

    What just happened?

    • Thread and block indices come straight from CUDA (thread_idx, block_idx), vs the Triton block abstraction
    • No explicit loads or stores: CuTe uses overloaded [] to generate them.

    Launching isn’t a million miles away from Triton:

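    @cute.jit   # host side
    def launch(self, A: cute.Tensor, B: cute.Tensor):
        M, N = A.shape
        # Same tile size the kernel uses: grid x covers N, grid y covers M
        grid = ((N + self._tile_size - 1) // self._tile_size,
                (M + self._tile_size - 1) // self._tile_size, 1)
        self.transpose_kernel(A, B).launch(
            grid=grid,
            block=[self._tile_size, self._tile_size, 1],
        )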

    Because CuTe‑DSL speaks DLPack, you can hand it a PyTorch tensor directly. If you wanted to cache the conversion, it looks like this:

    A_cute = from_dlpack(A).mark_layout_dynamic()

    mark_layout_dynamic() turns on dynamic-shape support, so the compiled kernel isn’t specialized to one shape. The one place where this went a bit funky in my testing was singleton (size-1) leading dimensions: there you need to be more explicit about the shape to satisfy the compiler.

    Layouts and Memory

    This kernel isn’t really leveraging the fundamental value of CuTe, though, which is composable tensor layouts and memory management. CuTe‑DSL exposes the full memory hierarchy: global memory, shared memory, registers (and tmem for those with Blackwells), and lets you tile, copy, and pipeline data between them. Common primitives:

    • make_layout / make_layout_tv: describe how a tensor is laid out.
    • cute.zipped_divide(tensor, tiler): tile a tensor.
    • cute.copy(copy_atom, src, dst, pred=mask): copy a tile between memory spaces; asynchronous when the copy atom uses cp.async, as in the GEMM below.
    • cute.arch.sync_threads(): explicit barrier.
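    To make the layout idea concrete, here is a plain-Python sketch (not the CuTe‑DSL API, and the helper names are hypothetical) of what a layout is at heart: a shape paired with a stride, written shape:stride, that maps a logical coordinate to a memory offset via a dot product. Coordinates are enumerated row by row here purely for readability. Note how the transpose from earlier becomes a change of layout rather than a change of data.

```python
from itertools import product


def offset(coord, stride):
    # A layout maps a logical coordinate to a memory offset: dot(coord, stride)
    return sum(c * s for c, s in zip(coord, stride))


def offsets(shape, stride):
    """Offsets for every coordinate in `shape`, enumerated row by row."""
    return [offset(c, stride) for c in product(*(range(n) for n in shape))]


row_major = offsets((2, 3), (3, 1))  # layout (2,3):(3,1)
col_major = offsets((2, 3), (1, 2))  # layout (2,3):(1,2)

# Transpose without moving data: view the row-major (2, 3) buffer
# as a (3, 2) tensor by swapping shape and stride.
data = list(range(6))  # row-major storage of [[0, 1, 2], [3, 4, 5]]
transposed = [data[o] for o in offsets((3, 2), (1, 3))]

print(row_major)   # [0, 1, 2, 3, 4, 5]
print(col_major)   # [0, 2, 4, 1, 3, 5]
print(transposed)  # [0, 3, 1, 4, 2, 5]
```

    The real library builds hierarchical (nested) shapes and the tiling primitives above on top of this same coordinate-to-offset mapping.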

    HGEMMony[2]

    CuTe ships with some example kernels, so I grabbed one, an HGEMM (a half-precision FP16 batched GEMM), and compared it to an example implementation in Triton.

    To express the same thing in PyTorch, we can unleash our inner Jeremy Howard and use einsum notation: torch.einsum("mkl,nkl->mnl", a, b). Take L batches of an MxK matrix and L batches of an NxK matrix, and return L batches of an MxN matrix. The batch dimension comes last, and B is stored as NxK, so each batch effectively computes A @ Bᵀ.
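    As a sanity check on the index bookkeeping, here is a plain-Python reference for what that einsum computes, using nested lists instead of tensors and a made-up function name: for every batch index l, sum over k, with B indexed [n][k].

```python
def batched_gemm_mkl_nkl(a, b):
    """Reference for einsum("mkl,nkl->mnl"): c[m][n][l] = sum_k a[m][k][l] * b[n][k][l]."""
    M, K, L = len(a), len(a[0]), len(a[0][0])
    N = len(b)
    return [[[sum(a[m][k][l] * b[n][k][l] for k in range(K))  # contract over k
              for l in range(L)]                              # batch index stays last
             for n in range(N)]
            for m in range(M)]


# M=1, K=2, L=2 and N=1: two independent batches of a 1x2 by 2x1 product
a = [[[1, 10], [2, 20]]]
b = [[[3, 1], [4, 1]]]
print(batched_gemm_mkl_nkl(a, b))  # [[[11, 30]]]
```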

    Here is the Triton:

    @triton.jit
    def triton_batched_hgemm_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K, L,
        stride_am, stride_ak, stride_al,
        stride_bn, stride_bk, stride_bl,
        stride_cm, stride_cn, stride_cl,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        """Triton batched half-precision GEMM kernel: C[m,n,l] = sum_k A[m,k,l] * B[n,k,l]"""
        pid = tl.program_id(axis=0)
        pid_batch = tl.program_id(axis=1)  # Batch dimension
        
        # Calculate batch offsets
        batch_offset_a = pid_batch * stride_al
        batch_offset_b = pid_batch * stride_bl  
        batch_offset_c = pid_batch * stride_cl
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        
        # Include batch offsets in pointer calculations
        a_ptrs = a_ptr + batch_offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        # transpose for GEMM (load B[K, N] pattern)
        b_ptrs = b_ptr + batch_offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        
        # We accumulate into fp32 for higher accuracy.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Load the next block of A and B, mask in K
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Convert back to FP16 for output
        c = accumulator.to(tl.float16)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + batch_offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    The core loop here is:

    • Walk K in chunks of BLOCK_SIZE_K
    • For each do a masked load (so when we hit the edges we don’t bring in garbage)
    • Do the dot product for the tile into an accumulator for the result matrix
    • Advance the A and B pointers
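    Stripped of pointers and program IDs, that loop computes one output tile roughly like this plain-Python sketch (the function name is hypothetical, and Python floats stand in for the FP32 accumulator). Padding out-of-range K positions with 0.0 mirrors the masked tl.load(..., other=0.0), which is why a K that isn’t a multiple of the block size still gives the right answer.

```python
def blocked_k_matmul(A, B, block_k):
    """C = A @ B^T, walking K in chunks of block_k like the Triton main loop.

    A is M x K and B is N x K (nested lists), matching the B[n, k] layout.
    """
    M, N, K = len(A), len(B), len(A[0])
    acc = [[0.0] * N for _ in range(M)]  # the kernel keeps this in FP32
    num_k_tiles = -(-K // block_k)  # ceiling division, like tl.cdiv(K, BLOCK_SIZE_K)
    for t in range(num_k_tiles):
        k0 = t * block_k  # "advance the pointers" by one K block
        for m in range(M):
            for n in range(N):
                for k in range(k0, k0 + block_k):
                    # Masked load: positions past K contribute 0.0
                    a = A[m][k] if k < K else 0.0
                    b = B[n][k] if k < K else 0.0
                    acc[m][n] += a * b
    return acc


print(blocked_k_matmul([[1, 2, 3]], [[4, 5, 6], [1, 1, 1]], block_k=2))
# [[32.0, 6.0]]
```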

    The CuTe kernel is, uhhh… a bit more involved. The full kernel is several hundred lines long. You can see the source in the CUTLASS repo, which demonstrates some cool features like the ability to pass in an epilogue function for fusion.

    For now, let’s focus on the main loop. As before, we are looping over tiles of K[3].

    The first thing we do is advance the pointers. The kernel is doing explicit pipelining of transfers from global memory, to shared memory, to registers, so we need to wait on the outstanding copy groups to ensure all the loading has completed before we advance. We’re not actually doing any loading in this section, just prepping the ground:

    for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
        for k_block in range(num_k_block):
            if k_block == num_k_block - 1:
                tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
                tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
                cute.arch.cp_async_wait_group(num_smem_stages - 2)
                cute.arch.sync_threads()

    Next we kick off a copy from shared memory (smem) to registers (rmem) using cute.copy for the (future) A and B tiles.

    k_block_next = (k_block + 1) % num_k_block  # static
    cute.copy(
        tiled_copy_s2r_A,
        tCsA_p[None, None, k_block_next],
        tCrA_copy_view[None, None, k_block_next],
    )
    cute.copy(
        tiled_copy_s2r_B,
        tCsB_p[None, None, k_block_next],
        tCrB_copy_view[None, None, k_block_next],
    )

    Finally, we interleave the transfers of the next A and B tiles from global to shared memory with the actual GEMM operation (the equivalent of tl.dot); the snippet below shows the B-tile half. I will trust the folks at NVIDIA that this is an optimal pattern. The pred= argument is the equivalent of masking in Triton.

    if k_block == 0:
        if k_tile + num_smem_stages - 1 < k_tile_count:
            cute.copy(
                tiled_copy_B,
                tBgB[None, None, None, k_tile_index],
                tBsB[None, None, None, smem_pipe_write],
                pred=tBpB,
            )
        k_tile_index = k_tile_index + 1
        cute.arch.cp_async_commit_group()
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = smem_pipe_read + 1

        if smem_pipe_read == num_smem_stages:
            smem_pipe_read = 0

    The pipelining is explicit, which is nice for debuggability and optimization, but very manual.
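    The smem_pipe_read / smem_pipe_write juggling is easier to follow outside the kernel. This host-side simulation is my own simplification, not CUTLASS code: it assumes the prologue prefetches num_smem_stages - 1 tiles, ignores the register double-buffering and the cp.async waits, and just tracks which k-tile lives in which shared-memory stage.

```python
def simulate_smem_pipeline(k_tile_count, num_smem_stages):
    """Return the k-tile index consumed at each main-loop iteration."""
    if num_smem_stages < 2:
        raise ValueError("need at least two stages to overlap loads and compute")
    stages = [None] * num_smem_stages  # which k-tile each smem stage holds
    next_tile = 0
    # Prologue (assumed): prefetch the first num_smem_stages - 1 tiles
    for s in range(num_smem_stages - 1):
        if next_tile < k_tile_count:
            stages[s] = next_tile
            next_tile += 1
    smem_pipe_read, smem_pipe_write = 0, num_smem_stages - 1
    consumed = []
    for k_tile in range(k_tile_count):
        # The stage being written must never be the one being read
        assert smem_pipe_write != smem_pipe_read
        # gmem -> smem: fetch a future tile into the free stage
        if k_tile + num_smem_stages - 1 < k_tile_count:
            stages[smem_pipe_write] = next_tile
            next_tile += 1
        # smem -> rmem + MMA: consume the stage being read
        consumed.append(stages[smem_pipe_read])
        # Rotate the ring buffer: the stage just read becomes the next write target
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = (smem_pipe_read + 1) % num_smem_stages
    return consumed


print(simulate_smem_pipeline(k_tile_count=5, num_smem_stages=3))
# [0, 1, 2, 3, 4]
```

    Every tile is consumed exactly once and in order, while the load for tile k + num_smem_stages - 1 is already in flight; that overlap is what the explicit pipeline buys you.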

    Debugging Tips

    export CUTE_DSL_LOG_TO_CONSOLE=1
    export CUTE_DSL_LOG_LEVEL=10   # up to 100
    export CUTE_DSL_PRINT_IR=1     # dump MLIR
    • cute.printf() gives you a GPU‑side printf.
    • Kernels are aggressively cached; rm ~/.cache/cutedsl if things look stale.
    • Multiple @cute.jit host functions in the same Python scope can confuse MLIR (mainly for launching kernels).
    • The control‑flow rules are strict: no return inside a kernel; initialize everything.

    If you’re exploring GPU kernels for the first time, I strongly recommend starting with Triton. When you need to really get into the weeds, or want to reuse CUTLASS building blocks, it’s great to have CuTe‑DSL as an option in Python (provided you’re comfortable spelunking in GPU internals).

    1. I spent a lot of time holding it wrong. Arguably, still holding it wrong. ↩︎
    2. No one knows what it means, but it’s provocative ↩︎
    3. Note the explicit unroll tag. When you really want #pragma but can’t. ↩︎
  • How to build unmaintainable kernels

    What do you need to do to get better performance and GPU efficiency out of your model? The GPU-oriented folks at Stanford recently published an early preview of the work they have been doing on the LLM generation of kernels: Surprisingly Fast AI-Generated Kernels We Didn’t Mean to Publish (Yet) – and they have a list:

    • Memory Access Optimization: improving the efficiency of data movement between different memory hierarchies (global memory, shared memory, registers) and ensuring data is accessed in a way that maximizes bandwidth and minimizes conflicts.
    • Asynchronous Operations & Latency Hiding: hide the latency of slow operations (like global memory access) by overlapping them with computation or other memory transfers
    • Data Type & Precision Optimization: using lower-precision data types (like FP16 or BF16) where possible to reduce memory bandwidth requirements, increase cache effectiveness, and potentially leverage specialized hardware units.
    • Compute & Instruction Optimization: making the arithmetic computations themselves more efficient, reducing instruction count, or leveraging specialized hardware instructions
    • Parallelism & Occupancy Enhancement: maximize the number of active warps on the Streaming Multiprocessors (SMs) to better hide latencies and improve overall throughput
    • Control Flow & Loop Optimization: reducing the overhead associated with loops, branches, and indexing calculations

    That’s a good list! In this case though, it emerged not from (just) talking with kernel experts, but also from developing a model to generate kernels:

    We have some very fast AI-generated kernels in pure CUDA-C without using libraries and DSLs such as CUTLASS and Triton. They are performing close to or in some cases even beating the standard expert-optimized production kernels shipped in PyTorch.

    They developed a very straightforward but smart pattern for structuring test-time compute. They reason about optimizations in natural language before generating code. Then they branch out into a tree structure of refinements for each optimization idea, to avoid going around in loops during the investigation.

    The kernels they generated were somewhere between fast, and very fast:

    Conv2D: 179.9% performance of FP32 torch.nn.Conv2D; problem size: (100, 3, 224, 224) input tensor, conv(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2)

    The team aren’t claiming this is a general solution, just an interesting proof of possibility, which it certainly is! The walkthrough of how they got to the final Conv2D kernel is fascinating, both in terms of human intervention and the chain of optimizations.

    The final code sample for the Conv2D kernel is included in the appendix. It uses advanced CUDA techniques that we find challenging to write ourselves! We also have more example kernels in this Github repo

    The kernel is very fast for specific shapes, on the L40S, in FP32. It’s also a kernel that, by the sounds of it, the team themselves struggled a bit with. It’s very, very specialized. It’s not that a human couldn’t have built it, it’s that (in most cases) they wouldn’t: it’s not a priority kernel, and all that clever CUDA comes with operational overhead and ties to specific hardware, shapes and so on.

    That in itself isn’t new. If you compile with PyTorch or XLA you’ll get a lot of kernels you probably wouldn’t write by hand, but this adds a new (and weird!) layer of non-determinism to everything. Elsewhere at Stanford, they have been looking at one of the other killers of GPU efficiency: kernel launch overhead. Most models are represented by hundreds of kernels, each of which has to be launched from the CPU. LLMs are generally memory-bound, small ones particularly so, and the gaps between kernel executions can end up dominating performance:

    Look Ma, No Bubbles! Designing a Low-Latency Megakernel for Llama-1B:

    In this post, we show how we can bypass this problem by merging the entire Llama-1B forward pass into a single “megakernel” that eliminates kernel boundaries altogether. Doing this achieves brr – on an H100, we use 78% of memory bandwidth and outperform existing systems by over 1.5x. (To our knowledge, this is the lowest-latency forward pass for Llama-1B in bfloat16!) In the rest of this post, we’ll walk through how and why one would do this.

    The idea of megakernels that handle all of the operations is not new, but the complexity of fusing everything together is high. Persistent kernels were popularized towards the end of the CUDA 11 series, thanks to the residency control and async copy support in Ampere. They leave kernels resident on an SM, pulling a series of tiles to work on, rather than relying on repeated launches of the same kernel. The megakernel takes this idea even further, with multiple operations within one kernel that pulls a stream of different problems. One issue with this approach (traditionally) is register spilling: you only have so many registers available, up to 255 per thread, out of a fairly high overall budget of 64k 32-bit registers per SM (on Hopper). That means you need to keep some data in shared memory, and efficient use of shared memory ends up being the bottleneck. The team at Stanford developed a paged shared memory scheme, with a separate reserved block for managing allocations of shared memory to individual tasks.
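    To put rough numbers on that register budget, here is a back-of-the-envelope calculation using H100 figures (64k 32-bit registers and at most 64 resident warps per SM). It ignores register allocation granularity and the shared-memory and block-count limits, so treat the result as an upper bound.

```python
REGISTERS_PER_SM = 64 * 1024  # 32-bit registers per SM
MAX_WARPS_PER_SM = 64         # 2048 resident threads per SM
WARP_SIZE = 32


def max_resident_warps(regs_per_thread: int) -> int:
    """Upper bound on warps per SM imposed by register usage alone."""
    by_registers = REGISTERS_PER_SM // (regs_per_thread * WARP_SIZE)
    return min(by_registers, MAX_WARPS_PER_SM)


for regs in (32, 64, 128, 255):
    print(f"{regs:>3} regs/thread -> {max_resident_warps(regs)} warps per SM")
```

    A kernel that needs all 255 registers per thread can keep at most 8 warps resident, which leaves little parallelism for hiding memory latency; the alternatives are spilling or moving state into shared memory.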

    This gets the CPU completely out of the picture for the forward pass, but is incredibly specific to the model (in this case Llama 3.2 1B).

    Another collaboration was clearly thinking in the same direction, as they recently posted about their Mirage Persistent Kernel: a compiler for megakernels.

    our team from CMU, UW, Berkeley, NVIDIA, and Tsinghua developed Mirage Persistent Kernel (MPK) — a compiler and runtime system that automatically transforms multi-GPU LLM inference into a high-performance megakernel. MPK unlocks the benefits of end-to-end GPU fusion while requiring minimal manual effort from developers.

    The system works by building a task graph of what they call LAX[1] fragments, which in practice is a very short list of 14 operators. This is actually too small to represent everything they need, meaning they have to manually decompose some common ops like ReLU, but this level of decomposition lets them do some pretty complex fusions.

    The actual ops are generated thanks to Mirage’s Kernel Superoptimizer (a great name), which I think is a very intense autotuner:

    Mirage automatically searches for possible GPU kernels for attention. The search space includes existing manually designed attention kernels (e.g., FlashAttention and FlashDecoding) as special cases. It also includes other implementations that outperform today’s handwritten ones by up to 3.5x for certain use cases. The GPU kernels generated by Mirage can directly operate on PyTorch tensors and be called in your PyTorch program.

    The search is not cheap though:

    In our evaluation, Mirage takes up to 4 hours to optimize a Lax program. This optimization is a one-time cost before deployment on the target hardware.

    The aggressive decomposition allows them to have a clever verification scheme where they validate kernels on random inputs to get confidence in (approximate) numerical correctness.
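    Here is a generic floating-point version of the random-input idea: run a candidate and a reference on seeded random inputs and compare within a tolerance. This is my own illustrative sketch, not Mirage’s verifier, and the softmax variants are hypothetical stand-ins for generated kernels.

```python
import math
import random


def softmax_ref(xs):
    m = max(xs)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_unshifted(xs):
    # Candidate that skips the max shift: fine for small inputs, overflows for large ones
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_unnormalised(xs):
    # Buggy candidate: forgets to divide by the sum
    m = max(xs)
    return [math.exp(x - m) for x in xs]


def agrees_on_random_inputs(candidate, reference, trials=100, size=16,
                            rtol=1e-6, atol=1e-9, seed=0):
    """True if candidate matches reference on `trials` seeded random inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        xs = [rng.uniform(-5.0, 5.0) for _ in range(size)]
        got, want = candidate(xs), reference(xs)
        if not all(math.isclose(g, w, rel_tol=rtol, abs_tol=atol)
                   for g, w in zip(got, want)):
            return False
    return True


print(agrees_on_random_inputs(softmax_unshifted, softmax_ref))     # True
print(agrees_on_random_inputs(softmax_unnormalised, softmax_ref))  # False
```

    The unshifted candidate passes even though it would overflow on large logits, because uniform(-5, 5) inputs never get there: random testing only buys confidence over the input distribution you actually sample.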

    They then build a worker kernel with all the relevant operations, and schedule the optimized graph via dedicated scheduler warps. Workers are scheduled on the SMs and report back their status. The scheduler warps then decide when tasks can be enqueued for execution.
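    The scheduling loop can be sketched as a host-side simulation: each task keeps a counter of unfinished prerequisites, completed tasks decrement their dependents, and anything that hits zero becomes ready. This is a simplified lock-step model with hypothetical task names, not MPK’s implementation.

```python
from collections import deque


def schedule_waves(tasks, deps, num_workers):
    """Group tasks into waves a fixed pool of workers could run in parallel.

    tasks: task names in a stable order; deps: task -> list of prerequisites.
    """
    pending = {t: len(deps.get(t, [])) for t in tasks}  # unfinished prerequisites
    dependents = {t: [] for t in tasks}
    for task, prereqs in deps.items():
        for p in prereqs:
            dependents[p].append(task)
    ready = deque(t for t in tasks if pending[t] == 0)
    waves = []
    while ready:
        # Dispatch up to num_workers ready tasks to idle workers
        wave = [ready.popleft() for _ in range(min(num_workers, len(ready)))]
        waves.append(wave)
        # Workers report completion; newly unblocked tasks become ready
        for task in wave:
            for d in dependents[task]:
                pending[d] -= 1
                if pending[d] == 0:
                    ready.append(d)
    if any(pending.values()):
        raise ValueError("dependency cycle: some tasks never became ready")
    return waves


tasks = ["rmsnorm", "q_proj", "k_proj", "v_proj", "attention", "o_proj"]
deps = {
    "q_proj": ["rmsnorm"], "k_proj": ["rmsnorm"], "v_proj": ["rmsnorm"],
    "attention": ["q_proj", "k_proj", "v_proj"], "o_proj": ["attention"],
}
for i, wave in enumerate(schedule_waves(tasks, deps, num_workers=2)):
    print(i, wave)
```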

    They’ve got a code example that walks through setting it up for Qwen. They recreate the model structure explicitly, generate a task graph from it, and kick off the search for optimal kernels and fusions. This avoids the need to solve the Dynamo-style problem of tracing the model!

    The resulting kernel is again heavily tied to the specific hardware and model. One thing we have found useful for investigating production problems is the ability to ablate different parts of the compile process, running models in (basically) PyTorch eager mode. This approach leaves you with the darkest of black boxes to work with, and I would imagine even more terrifying PTX than the complex CUDA that the LLM kernel generation team came up with.

    Between these projects, though, it feels like we are exploring the edges of what running a program on GPUs should actually look like: a combination of kernel generation and multi-level search seems almost guaranteed to yield optimizations that would be far outside the cost-benefit curve for manual implementation. What we don’t have yet are known ways to operationalize this kind of approach, but it’s an exciting area to watch!

    Thanks to Matt for nudging me to actually read through these papers; they’d been on my to-do list for a bit!


    1. I am not sure what this stands for, but the basic ops in JAX are in jax.lax, so I presume it’s the same source! ↩︎

  • Free-Threaded Python gets ‘supported’ status

    Huge congratulations to Thomas, Matt and Sam for shepherding through PEP 779, which moves the no-GIL/free-threaded Python mode from experimental to supported:

    https://discuss.python.org/t/pep-779-criteria-for-supported-status-for-free-threaded-python/84319/123

    With these recommendations and the acceptance of this PEP, we as the Python developer community should broadly advertise that free-threading is a supported Python build option now and into the future, and that it will not be removed without following a proper deprecation schedule

    Having confidence in the long term of this feature is great for anyone building on it. I’m very grateful to (and feel lucky to be working with!) the many folks who have been squashing bugs and improving performance, and to the people adding support for FT Python across the ecosystem!

    The Steering Council have laid out some solid documentation and performance expectations for the ongoing work, and are setting an expectation of broad compatibility for future CPython work:

    New experimental projects within CPython must be compatible with, and should be based on the free-threading build. The SC encourages this direction to reduce engineering complexity caused by supporting both GIL and free-threaded builds

    I also appreciate the call for building out more high-level concurrency primitives: I think there are a lot of exciting projects to come as we move more of this into production!
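    If you want to check which build you’re on, a small sketch like this works from Python 3.13 onward. The Py_GIL_DISABLED config variable says whether the interpreter was built free-threaded, and sys._is_gil_enabled() (a private helper, hence the getattr fallback for older versions) says whether the GIL is actually off at runtime, since it can be re-enabled, for example by importing an extension that hasn’t declared free-threading support. The prime-counting function is just an arbitrary CPU-bound workload.

```python
import sys
import sysconfig
from concurrent.futures import ThreadPoolExecutor


def is_free_threaded_build() -> bool:
    # Py_GIL_DISABLED is 1 on free-threaded builds (e.g. python3.13t)
    return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))


def gil_enabled_at_runtime() -> bool:
    # sys._is_gil_enabled() exists from 3.13; older interpreters always have the GIL
    check = getattr(sys, "_is_gil_enabled", None)
    return True if check is None else bool(check())


def count_primes(limit: int) -> int:
    # Deliberately CPU-bound pure-Python work
    return sum(all(n % d for d in range(2, int(n ** 0.5) + 1)) for n in range(2, limit))


if __name__ == "__main__":
    print(f"free-threaded build: {is_free_threaded_build()}, "
          f"GIL enabled: {gil_enabled_at_runtime()}")
    # On a free-threaded build with the GIL off, these threads can run in parallel
    with ThreadPoolExecutor(max_workers=4) as pool:
        print(list(pool.map(count_primes, [20_000] * 4)))
```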

  • Monarch: PyTorch Single Controller

    I’ve been excited for this to make it to OSS: The PyTorch team at Meta recently soft-launched Monarch on Github.

    pytorch-labs/monarch: PyTorch Single Controller

    Back in 2022, Google’s Pathways paper proposed (revisiting) a single-controller approach for managing machine learning runs. Typically, ML jobs use an SPMD (Single Program, Multiple Data) approach, distributing identical code across multiple hosts. Each host runs independently, synchronizing during collective operations. This works, as evidenced by the many large training runs in the world. It also introduces complexity, especially with pipeline parallelism, where conditional logic for different ranks can clutter up your training code. Even without that, subtle issues can arise: for example, slight differences in torch.compile optimization have (in the past!) led to deadlocks by placing collectives differently on separate nodes.

    The single-controller model simplifies this by centralizing program execution on one main node and using generic workers on the hosts that execute assigned tasks. This provides a consistent, global view of the entire computation, making it easier to get to a correct implementation of parallelisms and other distributed work. This doesn’t come for free though: the main node must efficiently manage (potentially!) thousands of GPUs without becoming a bottleneck, and existing code must adapt to this new centralized model.

    Monarch is the PyTorch team’s implementation of this single-controller concept. It provides a familiar PyTorch frontend, additional module wrappers, and a high-performance Rust-based actor system for distributing and managing work.

    The fundamental abstraction in Monarch is the Actor. Each Actor executes on its own accelerator and maintains its own state and behavior. Communication with other Actors is via async message passing to methods decorated with @endpoint. The nice thing about the programming model is that you can interact with all of the actors in your mesh in a consistent way.

    Monarch is appealing even if you’re not GPU-rich. For instance, at home, I have a machine equipped with two (mismatched) 3090s, and Monarch allows me to run and debug jobs directly in notebooks without relying on external services.

    Installation had minor hurdles because I built from source rather than using the available pip package. Although the README specifies Python 3.10, Python 3.13 worked fine. The dependencies reference dnf (reflecting Meta’s internal Linux distro choice), so adapting commands to other Linux distributions (Ubuntu in my case) was necessary. Additionally, I had to set BINDGEN_EXTRA_CLANG_ARGS="-I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11" to resolve Rust compilation issues.

    Once installed, running Monarch’s distributed data-parallel notebook was straightforward (see: monarch/examples/notebooks/spmd_ddp.ipynb). The notebook shows that minimal code changes to the standard DDP example are required, mainly subclassing Actor (e.g., class DDPActor(Actor)), while keeping the training loop virtually identical. Monarch handles the rest, including distributed execution and debugging across multiple GPUs.

    Setting up the environment means providing the mesh configuration and launching the actors, which can be done from a cell:

    # Spawn a process mesh
    local_proc_mesh = await proc_mesh(
        gpus=WORLD_SIZE,
        env={
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12455",
        },
    )
    # Spawn our actor mesh on top of the process mesh
    ddp_actor = await local_proc_mesh.spawn("ddp_actor", DDPActor)

    I didn’t have to manually start any other services; it all happened under the hood. Triggering the run is just:

    await ddp_actor.demo_basic.call()

    Which output:

    self.rank=0 Running basic DDP example
    self.rank=1 Running basic DDP example
    self.rank=1 Finished running basic DDP example
    self.rank=0 Finished running basic DDP example

    What I find really appealing is how easy it is to execute across ranks. For example, to query for system info:

    print("Gathering system info from all ranks...")
    system_info = await ddp_actor.get_system_info.call()
    
    print("\n SYSTEM INFORMATION ACROSS ALL RANKS:")
    print("=" * 60)
    
    for point, rank_info in system_info:
        print(f"Rank {rank_info['rank']}: PID={rank_info['process_id']}, "
              f"Device={rank_info['device_name']}, "
              f"GPU Memory={rank_info['gpu_memory_allocated']/1e6:.1f}MB")
    
    print(f"\nFound {len(system_info)} ranks in the mesh")
    
    all_rank_info = [value for point, value in system_info]
    print(f"Total GPU memory across all ranks: {sum(info['gpu_memory_allocated'] for info in all_rank_info)/1e6:.1f}MB")

    Outputting:

    Gathering system info from all ranks...
    [Rank 0] System Info: PID=10519, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    [Rank 1] System Info: PID=10520, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    
     SYSTEM INFORMATION ACROSS ALL RANKS:
    ============================================================
    Rank 0: PID=10519, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    Rank 1: PID=10520, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    
    Found 2 ranks in the mesh
    Total GPU memory across all ranks: 0.1MB

    I can also stop training and dump state if I need to, making it easy to check norms and debug:

    print("Running training steps...")
    for step in range(3):
        print(f"\n--- Step {step + 1} ---")
        
        step_results = await ddp_actor.train_step.call()
        
        all_results = [value for point, value in step_results]
        
        losses = [result['loss'] for result in all_results]
        grad_norms = [result['grad_norm'] for result in all_results]
        throughputs = [result['throughput'] for result in all_results]
        
        print(f"Losses across ranks: {[f'{l:.4f}' for l in losses]}")
        print(f"Gradient norms: {[f'{g:.4f}' for g in grad_norms]}")
        print(f"Avg throughput: {sum(throughputs)/len(throughputs):.1f} samples/sec")

    Outputting:

    --- Step 1 ---
    [Rank 1] Step 1: Loss=1.1128, GradNorm=0.3198, Time=0.241s
    [Rank 0] Step 1: Loss=1.0414, GradNorm=0.3198, Time=0.253s
    Losses across ranks: ['1.0414', '1.1128']
    Gradient norms: ['0.3198', '0.3198']
    Avg throughput: 129.8 samples/sec
    
    --- Step 2 ---
    [Rank 0] Step 2: Loss=1.1526, GradNorm=0.3096, Time=0.003s
    [Rank 1] Step 2: Loss=1.0546, GradNorm=0.3096, Time=0.003s
    Losses across ranks: ['1.1526', '1.0546']
    Gradient norms: ['0.3096', '0.3096']
    Avg throughput: 9800.9 samples/sec
    
    --- Step 3 ---
    [Rank 1] Step 3: Loss=0.9116, GradNorm=0.2243, Time=0.002s
    [Rank 0] Step 3: Loss=0.9662, GradNorm=0.2243, Time=0.002s
    Losses across ranks: ['0.9662', '0.9116']
    Gradient norms: ['0.2243', '0.2243']
    Avg throughput: 19977.5 samples/sec

    While the distributed stuff here is cool, it’s not wildly different from using a distributed framework like Ray and a little bit of setup (though I am pretty allergic to setup). What is most interesting is how this changes the programming model of PyTorch, and makes it really easy to build out distributed experiments.

    For example, if I were building a parameter server, the sync only requires an awaited read of the weights from another object, taking advantage of the RDMA support for an efficient copy[1]:

        @endpoint
        async def worker_sync_with_ps(self, param_server) -> bool:
            """Synchronize with parameter server and get RDMA handles"""
                
            self._log("Synchronizing with parameter server...")
            
            # Get RDMA buffer handles
            self.weight_buffers = await param_server.ps_get_weight_handles.call_one()
            self.gradient_buffers = await param_server.ps_get_gradient_handles.call_one()
            
            # Get metadata
            metadata = await param_server.ps_get_metadata.call_one()
            self.weight_metadata = metadata['weights']
            self.gradient_metadata = metadata['gradients']
            
            # Perform initial weight sync
            sync_time = await self._sync_weights_from_ps()
            
            self._log(f"Synchronized with parameter server (sync time: {sync_time:.3f}s)")
            return True

    Getting those weight buffers is as simple as creating the right Monarch object:

    def tensor_to_rdma_buffer(tensor: torch.Tensor) -> RDMABuffer:
        # RDMA requires 1D contiguous uint8 tensors
        byte_tensor = tensor.view(torch.uint8).flatten()
        return RDMABuffer(byte_tensor)

    For an early preview of a library, Monarch is surprisingly complete, and definitely worth a look.

    1. Not that this would do anything for my 3090s! ↩︎
  • Fused Linear Cross-Entropy

    Fused Linear Cross-Entropy is a popular optimization that combines the final linear projection and cross-entropy loss into a single operation. This fusion is very valuable for training large language models efficiently, as it can reduce memory usage significantly, particularly for larger vocabularies.

    If you look at an LLM training loop, you generally see something like:

    logits = model(input_ids)
    loss = cross_entropy(logits, targets)

    And if you look at the end of the model, you’ll see something like the below, where h is the hidden state so far and self.output is nn.Linear(embed_dim, vocab_size, bias=False):

    # shape: [b, seq_len, out_dim]
    output = self.output(h)

    That final logits tensor can be pretty big: sequence lengths are long and vocabularies are large (128k for Llama 3, 202k for Llama 4), so the logits can take GBs of memory. With a 16k context window and a 128k vocab, even at a batch size of 1, you get about 2bn entries (16k × 128k; the 4k embedding dimension sets the size of h and W, not of the logits). At BF16, that’s 4GB. You’ll also need to capture the gradient, which will give you another 4GB in the backward pass.
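    The arithmetic is worth spelling out, since the embedding dimension drops out entirely: the logits have one entry per (batch, position, vocab) triple. A quick sketch, using the round 16k and 128k figures (Llama 3’s actual vocabulary is 128,256 tokens) and 2 bytes per BF16 element:

```python
def logits_entries(batch: int, seq_len: int, vocab: int) -> int:
    # One logit per (batch, position, vocab entry); the embedding dim doesn't appear
    return batch * seq_len * vocab


GIB = 1024 ** 3
BF16_BYTES = 2

entries = logits_entries(batch=1, seq_len=16 * 1024, vocab=128 * 1024)
fwd_bytes = entries * BF16_BYTES
print(f"{entries:,} logits, {fwd_bytes / GIB:.1f} GiB forward, "
      f"{2 * fwd_bytes / GIB:.1f} GiB including the gradient")
# 2,147,483,648 logits, 4.0 GiB forward, 8.0 GiB including the gradient
```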

    That set of logits has a range of values that are a bit all over the place, one for each of the possible targets.

    Cross-entropy is a loss between two probability distributions. Jay Mody has an excellent blog post breaking down softmax and CE loss:

    Roughly speaking, cross entropy measures the similarity of two probability distributions. In the context of neural networks, it’s common to use cross entropy as a loss function for classification problems where:

    • q is our predicted probabilities vector (i.e. the softmax of our raw network outputs, also called logits, denoted as $\hat{y}$), that is $q = \mathrm{softmax}(\hat{y})$
    • p is a one-hot encoded vector of our label, that is a probability vector that assigns 100% probability to the position y (our label for the correct class): $p_i = \begin{cases} 1 & i = y \\ 0 & i \neq y \end{cases}$

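    Because p is one-hot, the cross-entropy $H(p, q) = -\sum_i p_i \log q_i$ collapses to $-\log q_y$, the negative log-probability of the correct class. In PyTorch terms, cross-entropy simplifies to F.nll_loss(F.log_softmax(x, 1), target).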

    Softmax turns our previously messy logits into a nice probability distribution where all the values are positive and sum to one. Log-softmax is usually used in LLMs, for numerical stability and efficiency.

    When we implement log-softmax, the naive implementation looks something like:

    out = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))

    This isn’t numerically stable, so you want to subtract the max to avoid overflows and underflows in the exp. This is the common log-sum-exp implementation:

    x_max = torch.max(x)
    shifted_x = x - x_max
    exp_shifted = torch.exp(shifted_x)
    out = shifted_x - torch.log(torch.sum(exp_shifted))

    In general the memory and compute cost of this grows with the size, which gets painful for our hefty logits. We can instead keep a running log-sum-exp so we don’t have to deal with the whole input at once.

    lse = xs[0]
    for x in xs[1:]:
        m = torch.max(torch.stack([lse, x]))
        lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))
    out = lse

    This is the online log-sum-exp approach, and it makes our life easier! We can now compute incrementally, but we are still generating the big logits beforehand.

    Fused Linear Cross-Entropy replaces the output projection, softmax and loss calculation with a single kernel that tiles across all of it.

    This is the core of the idea: instead of computing all logits at once (which creates a massive tensor), we can:

    1. Compute logits for small chunks of the vocabulary
    2. Compute the softmax incrementally
    3. Only store the logits we need for the loss calculation

    Quoting https://github.com/mgmalek/efficient_cross_entropy

    This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.

    Roughly, the loop looks like:

    For each token i in the sequence, with hidden state $h_i$, and for each tile $W_{\text{tile}}$ of the output-layer weights W:

    • Compute the partial logits $s_i = h_i W_{\text{tile}}^\top$
    • Reduce for a running max and exp-sum
    • Keep only the $s_i[\text{target}_i]$ needed for the loss.

    This gives you quite a lot of memory wins, which also reduce peak memory bandwidth needs. But this also introduces some potential pain!
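    For a sense of scale, here is the back-of-the-envelope arithmetic for the logits alone, using hypothetical sizes (16k tokens per step, a ~152k vocabulary, fp32 logits, 1024-wide vocabulary chunks):

```python
def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_elem: int) -> int:
    # One logit per (token, vocabulary entry) pair
    return num_tokens * vocab_size * bytes_per_elem

# Hypothetical sizes: 8 sequences of 2048 tokens, a 151,936-entry vocabulary
tokens, vocab = 8 * 2048, 151_936
full = logits_bytes(tokens, vocab, 4)    # fp32 logits for the whole vocabulary
chunk = logits_bytes(tokens, 1024, 4)    # fp32 logits for one 1024-wide chunk
print(f"full logits: {full / 2**30:.2f} GiB")   # ~9.27 GiB
print(f"one chunk:   {chunk / 2**20:.0f} MiB")  # 64 MiB
```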

    1. You’re fusing the final layer op into the loss, and the two might be defined in different places in your model code
    2. You’re accumulating, so you have to use fp32 or risk introducing numeric errors
    3. You have to write your own backward op as well, which will generally do some extra computation (such as recomputing the logits), so you pay some extra FLOPs
    4. Debugging can be harder, so you want a good recipe prior to swapping in the op
    5. It may require some futzing to get the best implementation on different hardware.

    Actually implementing it is pretty straightforward.

    # forward() of a torch.autograd.Function subclass; assumes `import torch`
    @staticmethod
    def forward(ctx, h, W, target):
        B, D = h.shape
        V, _ = W.shape

        chunk_size = min(1024, V)

        # Initialize online softmax accumulators in fp32
        max_logits = torch.full((B,), -float('inf'), device=h.device, dtype=torch.float32)
        sum_exp = torch.zeros(B, device=h.device, dtype=torch.float32)
        target_logits = torch.zeros(B, device=h.device, dtype=torch.float32)

        # Process vocabulary in chunks
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)

            # Compute logits for this chunk only, upcast to fp32 for accumulation
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = (h @ W_chunk.T).float()  # [B, chunk_size]

            # Update running max
            chunk_max = logits_chunk.max(dim=1).values
            new_max = torch.maximum(max_logits, chunk_max)

            # Adjust previous sum_exp by exp(old_max - new_max)
            sum_exp *= torch.exp(max_logits - new_max)

            # Add this chunk's contribution to sum_exp
            sum_exp += torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)

            # Update max
            max_logits = new_max

            # Extract target logits if target is in this chunk
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
            is_target_in_chunk = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            target_logits += (logits_chunk * is_target_in_chunk).sum(dim=1)

        # Compute loss: -log(p_target) = -(target_logit - log_sum_exp)
        log_sum_exp = max_logits + torch.log(sum_exp)
        loss = log_sum_exp - target_logits

        # Save for backward
        ctx.save_for_backward(h, W, target, max_logits, sum_exp)
        ctx.chunk_size = chunk_size

        return loss.mean()

    Here we chunk the vocabulary, compute the partial projection h @ W_chunk.T for each chunk, update the online softmax statistics, and accumulate the target logits.
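    As a compact cross-check of the forward math, here is a sketch of the same loss written as a plain function (chunked_linear_cross_entropy is a hypothetical name). It swaps the hand-written max/rescale bookkeeping for torch.logaddexp, which merges each chunk's log-sum-exp into the running total and handles the max-shifting internally, and it reads each target logit directly as h_i · W[target_i]:

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(h, W, target, chunk_size=1024):
    """Mean cross-entropy of (h @ W.T) vs target, never building the full [B, V] logits."""
    B = h.shape[0]
    lse = torch.full((B,), float("-inf"), device=h.device)
    for start in range(0, W.shape[0], chunk_size):
        logits = (h @ W[start:start + chunk_size].T).float()  # [B, <=chunk_size]
        # logaddexp(-inf, a) == a, so the first chunk just initializes lse
        lse = torch.logaddexp(lse, torch.logsumexp(logits, dim=1))
    # The target logit is h_i . W[target_i], so gather those rows directly
    target_logits = (h.float() * W[target].float()).sum(dim=1)
    return (lse - target_logits).mean()

torch.manual_seed(0)
h = torch.randn(16, 64)
W = torch.randn(5000, 64)  # 5000 is not a multiple of 1024, so the last chunk is short
target = torch.randint(0, 5000, (16,))
print(chunked_linear_cross_entropy(h, W, target), F.cross_entropy(h @ W.T, target))
```

    Left to autograd, this version would still save every chunk's logits for the backward pass, which is why the custom backward matters for the memory story.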

    The backward calculates the gradients:

    @staticmethod
    def backward(ctx, grad_output):
        h, W, target, max_logits, sum_exp = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        B, D = h.shape
        V, _ = W.shape

        # Scale gradient by batch size (since we use mean reduction)
        grad_scale = grad_output / B

        # Initialize gradient accumulators (grad_h sums over chunks, so keep it in fp32)
        grad_h = torch.zeros(B, D, device=h.device, dtype=torch.float32)
        grad_W = torch.zeros_like(W)

        # Process vocabulary in chunks (same as forward)
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)

            # Recompute logits for this chunk (in fp32, as in the forward)
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = (h @ W_chunk.T).float()  # [B, chunk_size]

            # Compute softmax probabilities for this chunk
            # p_i = exp(logit_i - max) / sum_exp
            probs_chunk = torch.exp(logits_chunk - max_logits.unsqueeze(1)) / sum_exp.unsqueeze(1)

            # Gradient w.r.t. logits: grad_logits = p - 1_{y=i}
            grad_logits_chunk = probs_chunk * grad_scale

            # Subtract 1 from target positions
            is_target = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            grad_logits_chunk -= is_target.float() * grad_scale

            # Cast back to the input dtype so the matmuls see matching dtypes
            grad_logits_chunk = grad_logits_chunk.to(h.dtype)

            # Accumulate gradients
            grad_h += grad_logits_chunk @ W_chunk
            grad_W[chunk_start:chunk_end, :] = grad_logits_chunk.T @ h

        return grad_h.to(h.dtype), grad_W, None

    In the backward we recompute the logits for each chunk, turn them back into softmax probabilities using the saved max and sum, and accumulate the gradients.
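    The backward leans on the standard result that the gradient of mean cross-entropy with respect to the logits is (softmax(logits) − one_hot(target)) / B. A quick sanity check of that formula against autograd on small random tensors (sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, V = 8, 100
logits = torch.randn(B, V, requires_grad=True)
target = torch.randint(0, V, (B,))

# Autograd gradient of the mean cross-entropy w.r.t. the logits
F.cross_entropy(logits, target).backward()

# Closed form used in the backward: (softmax - one_hot) / B
manual = (torch.softmax(logits.detach(), dim=1) - F.one_hot(target, V).float()) / B
print(torch.allclose(logits.grad, manual, atol=1e-6))
```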

    This is a very simplified implementation that pays for a lot of extra kernel launches, so it gives up a fair amount of performance, but you can see the memory savings:

    Regular:
    Time: 285.18 ms
    Memory (total): 3072.0 MB
    Loss: 11.142737
    Chunked online softmax:
    Time: 470.27 ms
    Memory (total): 356.0 MB
    Loss: 11.142738

    For a more sophisticated implementation, look at the repo mentioned above, or at Liger, which has a good-quality kernel with further optimizations. These compute the gradients during the forward pass, so the backward only has to scale them. That trades a bit more memory for a smaller compute hit. In general there are a few options for choosing the right point on that trade-off.
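    Here is a minimal sketch of the gradients-in-the-forward idea in plain PyTorch. It is not the repo's or Liger's code: the class name and the chunk_size argument are made up, and it chunks over tokens rather than vocabulary, so each chunk sees its full softmax at once. grad_h and grad_W are held from forward to backward (that's the extra memory), and the backward only rescales them by grad_output:

```python
import torch
import torch.nn.functional as F

class ChunkedLinearCEGradInForward(torch.autograd.Function):
    """Hypothetical sketch: compute dL/dh and dL/dW during the forward pass.

    Each token chunk builds its [chunk, V] logits, turns them into loss and
    gradients, and then lets them go out of scope."""

    @staticmethod
    def forward(ctx, h, W, target, chunk_size):
        B = h.shape[0]
        W32 = W.float()
        loss_sum = torch.zeros((), device=h.device)
        grad_h = torch.empty(h.shape, device=h.device)   # fp32, filled chunk by chunk
        grad_W = torch.zeros(W.shape, device=W.device)   # fp32 accumulator
        for start in range(0, B, chunk_size):
            h_chunk = h[start:start + chunk_size].float()
            t_chunk = target[start:start + chunk_size]
            logits = h_chunk @ W32.T                     # [chunk, V]
            loss_sum += F.cross_entropy(logits, t_chunk, reduction="sum")
            # d(mean loss)/d(logits) = (softmax - one_hot) / B, built in place
            grad_logits = torch.softmax(logits, dim=1)
            rows = torch.arange(t_chunk.shape[0], device=h.device)
            grad_logits[rows, t_chunk] -= 1.0
            grad_logits /= B
            grad_h[start:start + chunk_size] = grad_logits @ W32
            grad_W += grad_logits.T @ h_chunk
        ctx.save_for_backward(grad_h.to(h.dtype), grad_W.to(W.dtype))
        return loss_sum / B

    @staticmethod
    def backward(ctx, grad_output):
        grad_h, grad_W = ctx.saved_tensors
        # The loss is a scalar, so the chain rule is just a rescale
        return grad_h * grad_output, grad_W * grad_output, None, None
```

    Usage is ChunkedLinearCEGradInForward.apply(h, W, target, 256). Note the backward no longer needs h or W, only the precomputed gradients.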

  • Pyrefly

    https://pyrefly.org

    I’m at PyCon (mildly awkward photo thanks to Simon Willison!) and earlier had to steal some extra chairs for the Typing Summit as it was full up! There is a lot of energy and interest around type checking, thanks to Astral’s Ty and Meta’s Pyrefly projects coming into the space recently.

    While the playground is a great way to try it out, I wanted to see what it was like on a larger codebase I was familiar with. I decided to try TorchTune, which uses type hints but doesn’t configure a type checker explicitly in CI (as far as I know!), relying on the LSP’s squiggles as the main type-hinting feedback (which is reasonable!).

    I tried running mypy over it with a very basic config, timing it with time mypy:

    [mypy]
    python_version = 3.13
    ignore_missing_imports = True
    warn_unused_ignores = True
    strict_optional = True
    files = .

    Unsurprisingly, there are quite a few errors!

    Found 1361 errors in 211 files (checked 485 source files)
    real 10m45.596s
    user 0m12.906s
    sys 0m0.929s

    I installed pyrefly and init’d it:

    pip install pyrefly
    pyrefly init

    This created a pyrefly.toml containing a very minimal config:

    project_includes = ["."]
    python_version = "3.13.0"

    Running pyrefly check then gave me:

    INFO 2,966 errors shown, 7 errors ignored, 485 modules, 7,364 transitive dependencies, 3,522,743 lines, took 47.94s (checking 34.88s; reporting 12.98s), peak memory physical 863.6 MiB

    It’s impressively fast: over 10 minutes for mypy vs under 50 seconds for pyrefly. There are also a lot more errors, and it’s tricky to tell whether they’re false positives from pyrefly, errors mypy skipped, or something else. I scoped it down to TorchTune’s KV cache module in torchtune/modules/kv_cache.py to get a better look.

    There, pyrefly returns 9 errors and mypy 17, but from what I can see that’s mostly the two tools reporting some of the same problems in slightly different ways. For example:

    k_out[:, :, self.cache_pos[:seq_len]] = k_val

    This code is doing a bit of Tensor slicing:

    • cache_pos is a max_seq_len-long tensor holding absolute write positions
    • k_out is the key cache, with shape batch_size x num_heads x max_seq_len x head_dim
    • Here we’re indexing the cache positions we want to update along the sequence dimension and writing the latest values into them in place
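    A toy version of that write, with hypothetical tiny sizes and a cache_pos that simply advances by the number of tokens written (a simplification, not necessarily how kv_cache.py manages it):

```python
import torch

# Hypothetical sizes for illustration
batch_size, num_heads, max_seq_len, head_dim = 1, 2, 8, 4
k_out = torch.zeros(batch_size, num_heads, max_seq_len, head_dim)
cache_pos = torch.arange(max_seq_len)  # absolute write positions 0..max_seq_len-1

# Prefill with 3 tokens: write them at positions 0, 1, 2
seq_len = 3
k_val = torch.ones(batch_size, num_heads, seq_len, head_dim)
k_out[:, :, cache_pos[:seq_len]] = k_val  # in-place indexed write along dim 2

# Advance the positions, then write one more token at position 3
cache_pos = cache_pos + seq_len
k_out[:, :, cache_pos[:1]] = 2 * torch.ones(batch_size, num_heads, 1, head_dim)
print(k_out[0, 0, :, 0])  # tensor([1., 1., 1., 2., 0., 0., 0., 0.])
```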

    mypy gives these errors:

    torchtune/modules/kv_cache.py:104: error: "Tensor" not callable [operator]
    torchtune/modules/kv_cache.py:104: error: Value of type "Tensor | Module" is not indexable [index]

    While pyrefly gives:

    torchtune/modules/kv_cache.py:104:9-46: Item assignment is not supported on Module | Tensor   Expected __setitem__ to be a callable, got BoundMethod...

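    In this module, cache_pos and k_cache are created by calling PyTorch’s register_buffer, which stores tensors as part of the module’s state (and its state_dict) without treating them as trainable parameters. Because they aren’t declared as regular attributes, type checkers resolve self.cache_pos and self.k_cache through nn.Module.__getattr__, which is annotated to return Tensor | Module; that’s the Module | Tensor showing up in both error messages. Adding explicit type declarations in the class body fixes the errors in both mypy and pyrefly.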

    cache_pos: torch.Tensor
    k_cache: torch.Tensor
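    A minimal, self-contained sketch of the same pattern (TinyKVCache is a hypothetical stand-in, not the torchtune class): the buffers are still registered in __init__, and the class-level annotations are what the type checkers should pick up instead of the Tensor | Module fallback:

```python
import torch
from torch import nn

class TinyKVCache(nn.Module):
    """Hypothetical minimal module mirroring the register_buffer pattern."""

    # Annotation-only declarations: no runtime attribute is created, so
    # register_buffer still works, but type checkers now see Tensors
    k_cache: torch.Tensor
    cache_pos: torch.Tensor

    def __init__(self, batch_size: int, num_heads: int, max_seq_len: int, head_dim: int) -> None:
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(batch_size, num_heads, max_seq_len, head_dim))
        self.register_buffer("cache_pos", torch.arange(max_seq_len))

    def update(self, k_val: torch.Tensor) -> torch.Tensor:
        seq_len = k_val.shape[2]
        k_out = self.k_cache
        k_out[:, :, self.cache_pos[:seq_len]] = k_val
        self.cache_pos.add_(seq_len)  # no overflow check in this sketch
        return k_out

cache = TinyKVCache(1, 2, 8, 4)
out = cache.update(torch.ones(1, 2, 3, 4))
print(out[0, 0, :, 0])
```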