Author: Ian

  • Attention, Compression & Predicting the next token

    Language modelling is one of the great ideas in ML: if you train a model to accurately predict the next word in a sequence of text1, you are forcing it to learn a deep structure for human language. Because language is how we map reality, hopefully then you can do many useful things. This turned out to be right!

    The challenge with actually, you know, doing this is that text is messy. It’s sequential, variable length, and has structure, but the structure is kind of weird: the phrase “the cat, a mellow long-haired persian, sat on the mat” very clearly associates “sat” with “cat”, but the actual words are quite far away2.

    Dealing with sequential, variable-length data with a fixed network is a bit of an inherent mismatch. In training you often know the sizes you’re dealing with, but at inference time it’s variable. One elegant solution to that was the Recurrent Neural Net (RNN): start at the beginning, read one word at a time and keep a “hidden state” as a scratch pad to provide memory of what has come before.

    Training RNNs was painful, because now you have to backpropagate over multiple steps, and it was a minefield of vanishing and exploding gradients. The hidden state was also doing two different jobs: holding the long-term memory of the whole sequence and serving as the key to predicting the next word.

    Getting to Attention

    The architecture that really addressed this was the LSTM: instead of a single memory they split short and long-term memory and added activation functions to keep the gradient updates sane. They also made updating the memory a function of the input rather than of the weights by adding learnable gates that let the model decide which parts of the input to remember, and what information from the memory to forget. This unlocked real sequence-to-sequence models, which proved immediately useful in areas like machine translation: one model reads a sequence and compresses it to a hidden state (the encoder), another generates new output based on it (the decoder).

    This solved the training stability bottleneck, and introduced a new one: compression. The entire sequence got compressed to a single hidden state, which limited how much complexity could be captured.

    Bahdanau et al. addressed that with the idea of attention in 2014. The hidden state gets updated in the encoder with each new word, so why not keep all the hidden states around? Then, have a small network score which hidden states are relevant to the current decoder state, and make a new contextualized input to the decoder that is a weighted sum of the encoder states. This was called “attention” as it allowed the model to put different amounts of focus on different parts of the input sequence.
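
    A minimal sketch of that scoring-and-mixing step (the function and parameter names here are my own, not the paper’s notation):

    import torch
    import torch.nn.functional as F

    def additive_attention(decoder_state, encoder_states, W_dec, W_enc, v):
        # decoder_state: (d,), encoder_states: (seq_len, d)
        # W_dec, W_enc: (d, hidden), v: (hidden,) -- the small scoring network
        scores = torch.tanh(encoder_states @ W_enc + decoder_state @ W_dec) @ v  # (seq_len,)
        weights = F.softmax(scores, dim=0)    # how much focus each input word gets
        return weights @ encoder_states       # contextualized input for the decoder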

    The new bottleneck though was throughput: to generate hidden state n, you first needed hidden state n-1. That made it hard to parallelize, which made it hard to take advantage of emerging accelerators. Luong et al first showed that you could simplify the state scoring to make it more hardware friendly, then Attention Is All You Need in 2017 stripped away the recurrent part entirely. In the Transformer architecture they got rid of the RNN and hidden state, replacing it with another version of the attention mechanism: self-attention.

    Rather than a stack of hidden states that progressively encode the state of the sequence, each incoming word is transformed at once into a contextualized representation that carries information about it and its surroundings. This was really parallelizable; you don’t need to care about previous time steps to make decisions, so you can scale the computation on GPUs and other accelerators.

    In regular attention you can think of the current decoder3 state as a query, and the various encoder hidden states as keys: the scoring function would generate a score for each query–key pair. In self-attention, all the tokens were projected through key and query networks, and the query for each token was compared to the key of all the others. The transformer also added a value projection: in the older attention the “key” from the hidden state was both “what makes a good match” and “what information the token provides”, but in the transformer the two were decoupled.
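
    A single-head, unmasked sketch of that in PyTorch (a toy, not the full multi-head version):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SelfAttention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            # q/k decide "what makes a good match"; v is "what information the token provides"
            self.q = nn.Linear(dim, dim)
            self.k = nn.Linear(dim, dim)
            self.v = nn.Linear(dim, dim)

        def forward(self, x):  # x: (batch, seq_len, dim)
            q, k, v = self.q(x), self.k(x), self.v(x)
            # Every token's query is compared to every other token's key
            scores = q @ k.transpose(-2, -1) / x.shape[-1] ** 0.5  # (batch, seq_len, seq_len)
            return F.softmax(scores, dim=-1) @ v                   # contextualized representations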

    The new bottleneck that emerged was performance, particularly during inference. Comparing everything to everything else is an O(n²) operation. During training you can ameliorate some of that through batching, but you’re directly exposed in inference. And, unlike an RNN, increasing the sequence length (aka context length) gives you a quadratic increase in time, not linear.

    There were various attempts at addressing this one too. In “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention” back in 2020, Katharopoulos et al showed that the quadratic aspect of self-attention comes from having to materialize a big matrix to calculate the softmax for scoring. If you replace the softmax with a map-type function you can chunk the computation and get linear time performance. This was mathematically elegant, but didn’t actually work very well, so more engineering-oriented approaches like KV caching and FlashAttention were the mainstay for tackling the bottleneck.
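
    A rough sketch of the reassociation trick (non-causal, single head; the elu(x) + 1 feature map is the one used in the Katharopoulos et al. paper):

    import torch
    import torch.nn.functional as F

    def linear_attention(q, k, v, eps=1e-6):
        # q, k, v: (batch, seq_len, dim)
        phi = lambda x: F.elu(x) + 1   # feature map standing in for the softmax
        q, k = phi(q), phi(k)
        # Reassociate (q @ k^T) @ v as q @ (k^T @ v): no (seq_len x seq_len) matrix is materialized
        kv = k.transpose(-2, -1) @ v                           # (batch, dim, dim)
        z = q @ k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # normalizer, (batch, seq_len, 1)
        return (q @ kv) / (z + eps)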

    So why talk about this now? Because of Moonshot AI, and their excellent Kimi models. Moonshot are perhaps the frontier-est of the Chinese tiger labs, and their recent model releases include Kimi Linear: An Expressive, Efficient Attention Architecture

    The architecture mixes regular self-attention layers with Kimi Delta Attention. And Kimi Delta Attention is just the latest in a thread of evolution which goes back (sorta!) to RNNs.

    State space models

    For a long time, folks modelled control systems using state-space models. These return both an output and a state, and have a linear update function. RNNs such as LSTMs weren’t strictly state-space models in part because of their use of non-linearities: when updating the memory LSTMs used a tanh activation, for example. If you hand-wave a bit and ignore that, you’re looking at a very similar process.
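
    Stripped down, a discrete state-space step is just two linear maps (a sketch of the general form, not any particular published parameterization):

    import torch

    def ssm_step(h, x_t, A, B, C):
        # h: (state_dim,), x_t: (input_dim,)
        h = A @ h + B @ x_t   # linear state update: no tanh in the recurrence
        y = C @ h             # output for this step
        return h, y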

    But there is a gap between hand-waving and science, and luckily someone crossed it. The benefit of that activation function was that it squashed the state into a known range and avoided the vanishing gradient issue that plagued RNNs. The key realization was that you can drop the non-linearity entirely4 as long as the weight matrix that multiplies the hidden state is well behaved (specifically, has eigenvalues close to, but less than, one).

    Much of this is in the HiPPO and S4 papers, from Albert Gu, Chris Ré and Tri Dao. This was another neat idea, which included a clever bit of linear algebra with a technique called Diagonal+Low Rank to make the state updates relatively efficient, but didn’t perform as well as regular transformer models. Gu and Dao identified the challenge as those well-behaved weights that update the hidden state. Much like with RNNs prior to LSTMs, they were adding a fixed amount of information from the input to the state. In Mamba they reused the same kind of trick: adding a small network to gate the updates so the model can learn to remember more, or less, depending on the specific input5.
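
    In caricature, the gating looks something like this (this is not Mamba’s actual parameterization, which selects over a structured state and a learned discretization step, but it shows the input dependence):

    import torch

    def selective_update(h, x_t, W_gate, W_in):
        # How much to keep and what to write both depend on the current input
        keep = torch.sigmoid(W_gate @ x_t)   # per-channel "remember" factor in (0, 1)
        write = W_in @ x_t                   # candidate information from this input
        return keep * h + (1 - keep) * write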

    Then, in the Mamba 2 paper from 2024, Gu and Dao brought everything together. They showed that the 2020-style linear attention, with a decay mask, was the same as a structured state space model like Mamba 1. That meant they could apply the same chunking tricks from linear attention to get much better scaling and training throughput, while keeping the ability to handle long sequences that Mamba had.

    The slow recreation of LSTM features in more scalable forms continued with Gated DeltaNet. The Mamba approach ‘faded’ old memories via a decay, but it couldn’t explicitly subtract information like the LSTM forget gate. Gated DeltaNet also calculated the difference (the delta) between the expected and actual state, allowing it to effectively edit the memory rather than just overwriting it6.
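
    A sketch of the delta rule on an associative, matrix-valued memory (following the generic DeltaNet-style update, not Gated DeltaNet’s exact gated form):

    import torch

    def delta_rule_step(S, k, v, beta):
        # S: (d_k, d_v) memory, k: (d_k,) key, v: (d_v,) value, beta: write strength
        v_pred = S.T @ k                         # what the memory currently returns for this key
        delta = v - v_pred                       # error between the actual and expected value
        return S + beta * torch.outer(k, delta)  # edit the memory toward the correct association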

    Kimi Linear sped this up, and improved the fading mechanism to be per-dimension rather than a single rate across the memory:

    “Crucially, KDA parameterizes its transition dynamics with a specialized variant of the Diagonal-Plus-Low-Rank (DPLR) matrices [30, 71], enabling a bespoke chunkwise-parallel algorithm that substantially reduces computation relative to general DPLR formulations while remaining consistent with the classical delta rule. Kimi Linear interleaves KDA with periodic full attention layers in a uniform 3:1 ratio.”

    They manage to kill two birds with one stone of linear algebra: the DPLR trick from S4 lets you take a diagonal vector for the update rate and apply it across the matrix product of a low-rank approximation for the state transition. Moonshot realized that you could replace the approximation with the K and V matrices directly, which is much more efficient, and that the diagonal could come from a vector of the same dimension, so you get per-channel forgetting.

    Compression & Recall

    It seems likely we will see more sophisticated mixing of different types of attention in models as labs continue improving architectures. We started with recursive models as a natural expression of the problem, moved to transformers for scale, and have been slowly integrating the two expressions together. We are still just trying to predict the next word, but it turns out the best way to do it is to remember some things, forget most things, and accept that the map is not the territory.

    Reading through the papers on this journey really highlighted how the field moves between compression and breadth of recall. Sometimes researchers get a bad rap from their engineering brethren for being disconnected from reality, but this chain of evolutions is a pragmatic one.

    You want to get as much intelligence into the model as possible. That’s done by compressing the training data into efficient, useful and general representations, but finding those representations is hard! If you hit a limit in finding them, then one approach is to simply add more knowledge: add more parameters, consider more training data, and build more of the imperfect representations to give you more options to choose from.

    MoEs, synthetic data, and various other aspects of modern model training are playing with this same trade off: represent better or represent more. After his recent HotChips talk, Noam Shazeer was asked how we can find more efficient ways of encoding knowledge into parameters, closer to how the brain does it. He responded first by asking the questioner: “why are you limited on parameters?”

    1. The idea dates back to Jeff Elman, I think, who showed that training a network on this objective caused the network to learn grammar categories and other features of English. ↩︎
    2. This kind of thing is even hard for humans at sufficient lengths of text: there is a version of War & Peace in English that is largely the original (translated, natch), but normalizes all the character names as they were such a common point of confusion ↩︎
    3. In the original paper they kept the same encoder/decoder set up as with earlier models, as it’s eminently sensible for translation tasks. The GPT models and others demonstrated you could go decoder-only effectively. What we tend to call “prefill” these days is effectively a (causal) encoder step within the decoder model that contextualizes the input, then the “decoder” is the autoregressive generation process after. ↩︎
    4. There actually still is non-linearity, as you need it for neural networks in general but rather than doing it in the loop memory update, it happens in projection MLPs after the layer. Then in Mamba it moved into the gating, so it’s only dependent on input, not the h_{t-1} state! ↩︎
    5. And it was Orvieto and the DeepMind folks that showed that you can get the same results in an RNN without the non-linearities if you can set up the matrix right. ↩︎
    6. Part of this reason was recall, which Jamba addressed. Because the RNN approach is inherently compression based it was harder to just cut and paste sections of the context when they were relevant. Jamba mixed regular attention layers with Mamba layers, giving back the global context while still providing better scaling. The specific recall problem is really emphasized by the fact that one of the standard long context evals is the “needle in a haystack” task, where a relevant fact is hidden in a long doc and needs to be pulled out. ↩︎
  • Bulls in the bazaar

    I don’t think even the most perceptive forecaster would have identified a 90s LucasArts video format becoming a flashpoint for a discussion of the state of open source security. We live in an age of generative AI agents rampaging through OSS though, and that seems to be what has happened.

    Open source is one of the great triumphs in loose, global coordination. In most meaningful ways, proprietary software… lost. The scale and effectiveness of open source projects consistently outstripped closed source components, across the stack, leaving proprietary software mainly existing at the application level.

    This also had the effect of shifting open source from being in contrast to corporate, top-down development of proprietary software to being deeply intertwined with it. The expectations and requirements intermingled volunteer-ish communities and profit-seeking businesses, leading to tension in several areas, including security.

    Luckily, the megacorps, in their loving grace, invested in things like Google’s Project Zero to provide the type of security work that needs corporate-scale backing.

    The flow for things like Project Zero looks like:

    • Investigate popular projects and find real security risks before the bad guys do
    • Share a report with the project, and give them time to fix it before disclosing it
    • If the project doesn’t fix it in a certain time, disclose it so that folks can work around the issue rather than being vulnerable to it.

    That’s their mission: “make the discovery and exploitation of security vulnerabilities more difficult, and to significantly improve the safety and security of the Internet for everyone.”

    Inherently, that’s a pretty good idea as the incentive for various bad actors is:

    • Investigate popular projects and find a real security risk
    • Tell no one
    • Use it (or sell it to the national intelligence agency of choice)

    That seems worse!

    Something, however, was rotten in the state of Stallman. The folks who maintain some of the most popular package repositories recently published an open letter: Open Infrastructure is Not Free: A Joint Statement on Sustainable Stewardship that starts:

    “Not long ago, maintaining an open source project meant uploading a tarball from your local machine to a website. Today, expectations are very different”

    Today’s expectations include complex distribution infra, signed packages, deterministic builds, CI coverage across many types of hardware, and resilience against security concerns. These expectations aren’t unfounded: the PyPitfall paper, [2507.18075] PyPitfall: Dependency Chaos and Software Supply Chain Vulnerabilities in Python, released earlier this year, took an extensive look into one particular community:

    “By analyzing the dependency metadata of 378,573 PyPI packages, we quantified the extent to which packages rely on versions with known vulnerabilities. Our study reveals that 4,655 packages have guaranteed dependencies on known vulnerabilities, and 141,044 packages allow for the use of vulnerable versions. Our findings underscore the need for enhanced security awareness in the Python software supply chain.”

    As the world centralized around open source, some aspects of the infrastructure have scaled up, but the support and investment model really didn’t.

    It’s very easy for the corporations building on OSS to treat it like an infinitely available good, especially when they don’t have to deal with the impact of their usage. Again, from the letter:

    “Automated CI systems, large-scale dependency scanners, and ephemeral container builds, which are often operated by companies, place enormous strain on infrastructure. These commercial-scale workloads often run without caching, throttling, or even awareness of the strain they impose. The rise of Generative and Agentic AI is driving a further explosion of machine-driven, often wasteful automated usage, compounding the existing challenges.”

    Because this code ends up in production for some very large products, maintainers end up as unpaid on-call. Folks with good intentions want to keep a library in healthy shape and feel the pressure of knowing that perhaps millions of people are (indirectly) depending on it. Then we mixed in AI.

    The Big Sleep

    The FFmpeg project is at the center of a storm right now about the demands from security research teams:

    Google have spent billions of dollars training Gemini, and a hefty chunk more on a project called Big Sleep: an agent to do the security research work at scale. That tool is exactly what the FFmpeg developers are reacting to, with issues like this use-after-free write in SANM process_ftch [440183164]

    The vulnerability is in a codec for the LucasArts SMUSH format, which was used in games like Grim Fandango: a security risk targeting a very narrow group of people in their 40s. In a world of human researchers, I suspect that neither attacker nor researcher would have spent much time on that codec.

    For an AI agent, it’s feasible to scale up the search if you have the compute and model resources, which Google do. So now that (very real!) vulnerability is documented1. That also scales up the demands on maintainers, who don’t have the equivalent billions to do research into generative AI security patch systems.

    Security has always been asymmetric, in that it’s easier to break than to build. Scaling up discovery tips the scales even further. The bulls are in the bazaar, finding vulnerabilities in code for rendering 1995 Rebel Assault 2 cutscenes, and the maintainers just want someone to help clean up after them. Global-scale coordination on global-scale problems remains hard.

    1. and, to be clear, fixed! ↩︎
  • Let’s all switch to FP16?

    Serious scientists use FP64 – 64 bit floating point numbers – for high precision simulations, but in the world of machine learning we got by for the longest time with FP32. The perennial quest for increased FLOPS, particularly when memory bound, made even that seem too expensive though.

    FP16 offered a reduced numeric range, but at half the size. Training with it in practice meant embracing automatic loss scaling1, which ensured the values stayed within the range FP16 could represent. Then, Google developed BF16: it moved some of the bits to the exponent from the mantissa, so offered the same numeric range as FP32, but with reduced precision.
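
    You can see the trade-off directly from torch.finfo: BF16 keeps roughly FP32’s range but with far coarser spacing between representable values (eps is the gap between 1.0 and the next value up):

    import torch

    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        info = torch.finfo(dtype)
        print(f"{str(dtype):>15}  max={info.max:.2e}  eps={info.eps:.2e}")

    # torch.float32   max=3.40e+38  eps=1.19e-07
    # torch.float16   max=6.55e+04  eps=9.77e-04
    # torch.bfloat16  max=3.39e+38  eps=7.81e-03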

    Since TPUv3 back in 2018 and Ampere in 2020 it’s been finding its way into hardware and has become the go-to format for training for many models. Life was good, and training in FP16 was mainly discussed as a memory of hard winters past.

    Last week [2510.26788] Defeating the Training-Inference Mismatch via FP16 dropped and threw ML Twitter into a tizzy by making the argument everyone was doing Reinforcement Learning wrong and the solution… was FP16.

    “In this work, we take a step back from the complex algorithmic fixes and investigate the root cause of the numerical mismatch: floating-point precision. We identify that the modern standard for mixed-precision training, BFloat16 (BF16), is the primary culprit. While BF16 has a wide dynamic range which is excellent for stable pre-training, its low precision makes it highly susceptible to rounding errors that accumulate and eventually cause the training and inference policies to diverge.”

    The process for RL generally looks like:

    • Get a problem in a prompt
    • Do inference on the model to generate complete responses (a rollout)
    • Get a reward score for the response(s)
    • Run a training loop on the model to update weights based on the reward

    If you want to be on-policy (which generally trains better) you need the “model” in steps 2 and 4 to be identical, but the actual code running around the model in the two steps is different: for example, you don’t use a KV cache in training and you don’t store gradients in inference. But you do want to keep the weights and numerics of the model the same, else your on-policy training becomes a little bit off-policy.
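
    One way to see the mismatch is to score the same rollout tokens with both engines and look at the importance ratio, which would be exactly 1 if the numerics matched (an illustrative sketch, not any particular framework’s API):

    import torch

    def mismatch_stats(logprobs_train, logprobs_infer):
        # Same weights, same tokens: with identical numerics the ratio is all ones.
        # BF16 rounding differences between the two code paths push it away from 1,
        # and that bias compounds over long rollouts and many updates.
        ratio = torch.exp(logprobs_train - logprobs_infer)
        return ratio, (ratio - 1).abs().mean()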

    The last year of LLM research has been scaling this up, which requires managing a training and inference flow efficiently. This ongoing pressure to optimize the two paths independently leads to a risk of divergence. The paper finds that this absolutely happens, and the divergence collapses the effectiveness of the learning. Unless, that is, you use FP16:

    “This is precisely why switching to FP16 provides a fundamental solution. With its 10 mantissa bits, FP16 offers 8 times more precision (2^10 values vs. 2^7 values) than BF16. This higher fidelity means that the outputs of the training and inference engines are much more likely to be numerically identical. The increased precision creates a buffer that absorbs the minor implementation differences between the two engines, preventing rounding errors from accumulating and causing a policy divergence”

    The paper does an excellent job of breaking down the many reasons why this happens, but it’s pretty clear that FP16 is a patch: if you can’t get your numerics perfectly matched, then having more precision gives you more wiggle room.

    About a month before this the ByteDance folks posted a fantastic deep dive into RL collapse from discrepancies between training and inference: When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

    They identify a range of concerns, including straight up bugs:

    “According to this GitHub issue, we set disable_cascade_attn=True when initializing the vLLM engine and found that it significantly helps reduce the training-inference mismatch in experiments conducted on A100 GPUs.”

    Many of the experiments in the FP16 vs BF16 paper were run on A100s2, so some backlash emerged suggesting that perhaps this whole thing is just a kernel error. But as ByteDance showed, there really is a lot going on that can make things worse.

    Another example is Horace He’s recent work at Thinking Machines around a related problem: Defeating Nondeterminism in LLM Inference – Thinking Machines Lab

    “As mentioned above, one common explanation for why kernels add numbers in different orders is the “concurrency + floating point” hypothesis. The hypothesis states that if the order in which concurrent threads finish is nondeterministic and the accumulation order depends on the order in which concurrent threads finish (such as with an atomic add), our accumulation order will be nondeterministic as well.”

    Horace calls out variance in batching as the primary cause of non-determinism, and hence another quite plausible cause of inference/training mismatch:

    “In other words, the primary reason nearly all LLM inference endpoints are nondeterministic is that the load (and thus batch-size) nondeterministically varies! This nondeterminism is not unique to GPUs — LLM inference endpoints served from CPUs or TPUs will also have this source of nondeterminism.”

    The meta-point is that despite being a field fundamentally based in mathematical precision we have been sloppy with numerics, pretty much everywhere.

    Ed Yang’s session in the PyTorch Conference keynote3 a couple of weeks back called this problem out from the perspective of scaling up ML infrastructure. He presented a number of solutions to try and address it, which often come down to giving folks control over precisely how the numerics work in different parts of their model.

    While the focus here was on RL and FP16, the reality is we deal with this for training->inference in much simpler cases, as well as when moving models between different hardware. Even within generations this can be hard: one of the fun infra problems when the H100 came out was everyone discovering that the FP8 tensor cores in the Hopper used a 22-bit accumulator for intermediate calculations, which wasn’t really documented!

    The balance between speed and accuracy is often effectively made empirically: if something is faster, and works, then at some level it’s right! Reinforcement Learning mixes together different evolutionary chains of optimizations, so maybe those serious scientists with their FP64 were onto something. Not because they absolutely needed the precision, but because they needed to know they had the precision.

    We’re probably not going to switch industry wide back to FP164, but getting a better numerical grounding into the tools we use is going to make everyone’s lives easier, eventually!

    1. torch.cuda.amp and friends ↩︎
    2. Though they did verify on Hopper some as well, which some people seemed to miss ↩︎
    3. Check out the recording: Keynote: PyTorch Technical Deep Dive – Alban Desmaison, Peng Wu, Mark Saroufim & Edward Yang, Meta ↩︎
    4. Especially since most labs are doing so much with FP8 or less these days, and it would probably annoy a bunch of chip designers ↩︎
  • Helion and the evolving GPU programming model

    Helion: A High-Level DSL for Performant and Portable ML Kernels – PyTorch

    Lots of announcements around the Triton and PyTorch Conferences this week, including the 1.0 of Helion, a high-level kernel authoring DSL:

     It establishes a new layer of abstraction that bridges the user-friendly simplicity of PyTorch with the performance of a lower level language. By automating tedious and error-prone tasks like tensor indexing, memory management, and hardware-specific tuning, Helion empowers developers to focus on algorithmic logic rather than hardware-specific implementation details. Helion achieves this balance by pairing a familiar, PyTorch-centric syntax with a powerful autotuning engine that automates the complex search for optimal kernel configurations. This results in a system that delivers performance portability across hardware architectures while drastically reducing development effort. 

    There has been a bit of an explosion in kernel-authoring options recently with CuTe-DSL and CuTile from Nvidia, TileLang (as featured in recent DeepSeek releases), Gluon and TLX1 as well as evolutions to core Triton, ThunderKittens, Pallas, and others.

    There are a couple of different axes of progress occurring in GPU authoring. The first is between researcher-friendly declarative code that is easy to iterate on and tightly written, hardware-friendly imperative code.

    It’s a classic developer-experience trade-off: you let people tell you what they want to do (matmul these things then apply a softmax) or you let people tell you precisely how to do it (run this dot product on these SMs then aggregate the result).

    In general you want to stay as high-level as possible, particularly if you are experimenting with lots of different variants in a research type setting, but you may have a bound on the performance hit you can accept. A common example is you want to iterate on some attention variant, but don’t want to completely give up on the performance wins of Flash Attention.2

    Triton and others provided an interesting middle ground: it was easy enough to iterate with thanks to being embedded in Python, and was performant enough as it leveraged a compiler to automatically apply some optimizations. You are still much more imperative than in a PyTorch program, but you work at a higher level of abstraction: rather than writing programs which own a thread of data, as in CUDA, you think about a tile of data. The ThunderKittens docs put this well:

    A GPU is not really a 1000×1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16×16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16×16 values.
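
    A toy Triton kernel shows the shift in mental model: each program instance owns a whole block of elements rather than a single thread owning one (a minimal sketch, not a performant kernel):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, out_ptr, n, scale, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)   # this program's tile of indices
        mask = offs < n                            # guard the ragged final tile
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * scale, mask=mask)

    x = torch.randn(10_000, device="cuda")
    out = torch.empty_like(x)
    scale_kernel[(triton.cdiv(x.numel(), 1024),)](x, out, x.numel(), 2.0, BLOCK=1024)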

    The next abstraction that frameworks developed was how to represent data across the memory hierarchy. To take advantage of the tensor cores you have to have data laid out in a specific way in registers. But you are better off loading data in a different order in global or shared memory. CuTe offered a big benefit by giving you types to represent layouts that could be composed, making it easier to keep track of the data movement required. Triton and others leaned on the compiler to infer the right layouts and offered higher-level APIs to copy data between stages.

    This started to get challenging on Hopper, thanks to TMA3 and the limitations of memory bandwidth, which gets to the second evolution happening in GPU kernels: how do you orchestrate the movement of data between memory layers while keeping the tensor cores saturated? This involved techniques like warp specialization, where individual warps do different operations towards a shared goal. That means carefully allocating ownership over registers to avoid warps stepping on each other. Blackwell4 made this even trickier with the addition of TMEM, 2-CTA mode and other features that offered more performance but required even more careful orchestration.

    In compiler terms this is a scheduling problem and in general the industry is quite good at it! CPUs give compilers a lot of leeway to schedule operations efficiently because they have a great deal of support for out-of-order execution, well documented ops, and substantial caches. GPUs process groups of threads5 in lockstep and demand strict timing about when to insert barriers, issue async operations and so on.

    A GPU scheduler has to tag operations to specific warp-slots in advance, assign numbers of registers to them to avoid conflicts, and sync them with barriers. It’s a lot more brittle: if we guess wrong, we can idle the Tensor cores and tank efficiency. The actual execution model is a bit of a black box too: the target for compilers (PTX) is actually further compiled to SASS by ptxas.

    Across the industry we’ve been exploring ways to be more explicit without giving away all of the operational and developer efficiency gains of higher-level languages. CuTe-DSL offers a very close-to-hardware model but in a Pythonic package6; Gluon (OpenAI) and TLX (Meta) add extensions to allow modelling pipelines in code without getting rid of the Triton compiler; TileLang builds on TVM with explicit pipeline declarations.

    One of the reasons for this variety is we don’t quite know how to express a warp-group pipelined execution model. For example, TileLang has a pipelined construct:

    for k in T.Pipelined(loop_range, num_stages=num_stages):
        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)  # Q @ K^T
        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
        Rescale(acc_o, scores_scale)  # Apply correction
        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)  # P @ V

    Gluon has a descriptor that allocates resources like registers explicitly to warps:

    gl.warp_specialize(
            (config, chnls, descs, M, STAGE),     # Args to correction stage
            _attn_fwd_correction,                  # Trunk task (1 warp, 192 regs)
            (config, chnls, descs, M, STAGE),     # Args to specialized tasks
            [
                _attn_fwd_softmax0,    # 4 warps, 192 registers - Softmax tile 0
                _attn_fwd_softmax1,    # 4 warps, 192 registers - Softmax tile 1
                _attn_fwd_mma,         # 1 warp, 24 registers  - Matrix multiplies
                _attn_fwd_load,        # 1 warp, 24 registers  - TMA loads
                _attn_fwd_epilogue,    # 1 warp, 24 registers  - Store results
            ],
            [4, 4, 1, 1, 1],          # Warps per stage
            [192, 192, 24, 24, 24]    # Registers per stage
        )

    And TLX tags sections of code with contexts to indicate groupings, and also allocates resources:

    with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
                        registers=232,
                        replicate=NUM_MMA_GROUPS):

    They can all work and finding the best trade-off is a good goal, but in all cases they do force a lot of decisions. As an example, that allocation of how many registers to use is not only operation dependent, it’s hardware dependent, and that makes portability between hardware (even different generations from the same vendor) expensive. Manual controls are necessary: it takes time to develop the compiler passes and heuristics to optimally divide work, so handing explicit control over7 is beneficial, particularly when serving at scale. The cost is complexity and portability. This is where Helion takes a different tack.

    Anyway, so what about Helion?

    Helion instead takes a point on the line above Triton, but below the ML frameworks. It focuses on just expressing what you want to happen from the tile perspective.

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc

    Under the hood, this compiles down to Triton. You might think this would be a bit of a no-op on performance, but in practical terms it’s often better. The reason is search: Helion can autotune across a wide number of parameters, then let you bake them into your kernel once you’ve identified good ones for your specific setup. The example in the blog post shows how many dimensions of search need to occur:

    @helion.kernel(config=helion.Config(
        block_sizes=[64, 64, 64],
        loop_orders=[[0, 1]],
        l2_groupings=[4],
        range_unroll_factors=[0, 1],
        range_warp_specializes=[None, False],
        range_num_stages=[0, 3],
        range_multi_buffers=[None, False],
        range_flattens=[None, None],
        num_warps=8,
        num_stages=6,
        indexing='block_ptr',
        pid_type='flat'
    ))
    def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

    This makes moving to different hardware as simple as redoing the search process, and offers a much more comprehensive exploration than most folks would do when hand-rolling a lower level kernel. It’s a very interesting idea, and I’m glad to see more people kicking the tires!

    Low-level optimizations aren’t going away any time soon, but I’m glad to have more exploration in the kernel development space. Finding the right abstractions and right compiler approaches to keep scaling kernel development will help make it accessible to more and more people and ensure that we can evolve our kernels with the hardware.

    1. Also a Meta thing, disclaimer. ↩︎
    2. This is the logic behind FlexAttention, which was one of the lights that guided the way towards Helion. ↩︎
    3. Fully async copies – a separate execution engine to move data ↩︎
    4. Well, datacenter Blackwell. Consumer Blackwell lacks TMEM and 2-CTA, so has a more Hopper-like programming model. I’m not sure yet what the DGX Sparks have! ↩︎
    5. Warps – 32 threads on Nvidia, or waves, 64 threads on AMD. The important distinction is that all these threads are doing the same thing: you can mask some of them out, but they have a fairly simple march through the instruction. ↩︎
    6. With a JIT! ↩︎
    7. Without making people write templated C++, sorry Ben ↩︎
  • Qwen-Image

    GPT-4o’s image generation was a remarkable event, beyond the brief Ghiblification of all social media. GPT-4o offered significantly more steerability than earlier image generation models, while offering image quality in the ballpark of the best diffusion models. Qwen-Image gives a similar level of fidelity and accuracy and is an open-weights model with a pretty decent technical report: QwenLM/Qwen-Image.

    While I was fairly familiar with diffusion models, I wasn’t really familiar with the backbone of this model, the multimodal diffusion transformer (MMDiT). Rather than just read about it, I vibed up a repo with Claude Code that walks through the architectures, training on good old MNIST: ianbarber/diffusion-edu.

    This ended up being a helpful way to go step by step through the evolution of diffusion models. 

    Loss/Target

    Modern image generation really kicked off with GANs. GANs were a clever idea that exploited the fact that we are better at building classifiers than generators by using one to bootstrap the other. A generator would generate an image against a reference, the discriminator would be given the generated image and the reference and have to predict which was the real one, and both networks were scored on how well they did on their tasks. This was effective, but was challenging to train. The generator also had to start from somewhere and what it effectively started from was noise: the generator would start with fairly random output and the discriminator would learn to identify noise vs the real image.

    The clever idea Jonathan Ho and co had with DDPM was to focus on that noise: what if, instead of learning to generate images, we learned to remove noise from images? In the snippet below we:

    • Pick a timestep between 0 and 1000
    • Generate some noise
    • Add an amount of noise to the training image proportional to the timestep
    • Get the model to predict the noise, given the time step
    • Calculate the loss as the mean squared error between the known noise and the predicted noise
    # Sample a random timestep for each image in the batch
    t = torch.randint(0, 1000, (B,), device=device)
    
    # Add noise to the image, scaled by the schedule at this timestep
    eps = torch.randn_like(x0)
    alpha_t = self.alpha_schedule(t).view(B, 1, 1, 1)  # reshape to broadcast over (C, H, W)
    xt = torch.sqrt(alpha_t) * x0 + torch.sqrt(1 - alpha_t) * eps
    
    # Predict the noise we just added
    eps_pred = self.model(xt, t, cond)
    
    return F.mse_loss(eps_pred, eps)

    This pretty much worked! You needed to use quite a few timesteps (around 1000), but the model would learn to discriminate noise from data. Then, you can reverse the process to generate: given a noisy starting point, predict the noise at the current timestep, remove it, add a little fresh noise back, move to the next timestep, and repeat.

    Song et al. followed this up with DDIM, identifying that one of the reasons you need so many steps is that you are injecting new noise at each one. If you start with the noise up front when sampling, you have a much more deterministic process, and can generate in more like 50 steps than 1000:

    x = torch.randn(*x_shape)  # Start with pure noise
    dt = 1.0 / steps           # fixed step size for the deterministic reverse process
    
    for i in reversed(range(steps)):
      t = torch.full((B,), i / steps)
      if target == TargetMode.EPS:
        eps = model(x, t, cond)
        x = (x - eps * dt) / (1 - dt) ** 0.5

    The next step, in 2021, was Classifier-Free Guidance from Ho and Salimans. The clever idea was to pass a conditioning variable through to the model: for example in our MNIST example it could be the digit we are learning from. However, during training we would sometimes zero it out. This means the model learns to generate conditionally (for the specific digit) and unconditionally (just in whichever direction looks the best). 

    if cond is not None and self.cfg_dropout_prob > 0:
      mask = torch.rand(B, 1, 1) < self.cfg_dropout_prob
    
      cond = cond * ~mask  # Zero out conditioning randomly
    
    return F.mse_loss(self.model(xt, t, cond), target)

    This gets useful at generation time. When sampling, we can sample both conditionally and unconditionally, and diff out the unconditioned part: 

    # Run model twice: with and without conditioning
    cond_pred = model(x, t, cond)
    uncond_pred = model(x, t, None)
    
    # Amplify the difference
    return uncond_pred + cfg_scale * (cond_pred - uncond_pred)

    If you imagine the sampling process as denoising, this is saying there is the “best” direction, given the condition, and the “best direction” overall. By reducing the influence of the overall best direction, we get clearer steerability, and effectively the model serves as its own iterative classifier. 

    Also in 2021, Song et al published Score-Based Generative Modeling through Stochastic Differential Equations. They framed the diffusion problem as a Stochastic Differential Equation (SDE), effectively a regular differential equation dx = f(x, t)dt with an additional noise term: dx = f(x, t)dt + g(t)dw1. That latter term g(t) controls how much random noise is involved.

    The contribution from the paper is that they worked out how to reframe this without that dw noise – i.e. they turned it into an “Ordinary” Differential Equation (ODE) without the random component. Then the model can be viewed as a velocity field that ends up having a similar shape to the one modelled by the random noise version, but is deterministic.

    Salimans & Ho were not done, and proposed another improvement to loss in V-Parameterization in the Imagen paper. One of the challenges with predicting the noise (eps above) is that when you get pretty close to a finished image there isn’t much noise, so the prediction isn’t particularly good. Similarly, when you are starting with pure noise it’s predicting almost everything, so also doesn’t give much information. Predicting the noise involves estimating both the clean sample and the noise. Some reordering lets you predict a single value, the velocity field (or gradients) which combines the clean sample (alpha), the noise (eps) the time step and the current (noised) sample. By having the model predict that we can balance between predicting the image and the noise, giving better results better at extremes. 

    v_target = alpha_t * eps - sigma_t * x0
    v_pred = self.model(xt, t, cond)
    
    return F.mse_loss(v_pred, v_target)

    Finally (on the loss) we get to flow matching from folks at Meta FAIR (Flow Matching) and UT Austin (Rectified Flow). Rather than making the target a blend of start and noise, why not just predict the straight path to the data? Compare the v_target below to the one above:

    t = torch.rand(B, 1, 1, 1)
    z = torch.randn_like(x0)
    
    # Straight line: xt = (1-t)*x0 + t*z
    xt = (1 - t) * x0 + t * z
    
    # Learn the velocity field pointing from noise to data
    v_target = x0 - z  # The straight path direction
    v_pred = self.model(xt, t.squeeze(), cond)
    
    return F.mse_loss(v_pred, v_target)

    Flow matching models often converge faster during training and can generate good samples with fewer steps. They also tend to have more consistent quality across different sampling step counts.
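
    Sampling is then just integrating the learned velocity from noise back to data, for example with a few Euler steps (a sketch consistent with the training snippet above, where the model predicts x0 - z):

    import torch

    @torch.no_grad()
    def sample_flow(model, shape, steps=50, cond=None, device="cuda"):
        x = torch.randn(*shape, device=device)   # start at t=1: pure noise
        dt = 1.0 / steps
        for i in reversed(range(steps)):
            t = torch.full((shape[0],), (i + 1) / steps, device=device)
            v = model(x, t, cond)                # predicted velocity, pointing from noise to data
            x = x + v * dt                       # one Euler step along the straight path
        return x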

    Architecture

    All of that evolution was about the loss function and sampling, and we haven’t really discussed the model architecture itself. The original diffusion models used an approach called U-Nets: a series of convolutions that compressed the (latent) visual information into fewer dimensions, then expanded it back up (giving a sort of U shape). But post-ChatGPT the Transformer was ascendant, so in 2023 Peebles and Xie proposed swapping out the U-Net for a stack of transformers in the Diffusion Transformers (DiT) paper.

    class DiTTiny(nn.Module):
        def __init__(self, embed_dim=256, depth=6):
            super().__init__()
            # Patchify the image (like ViT); pos_embed, time_embed and
            # unpatchify are assumed to be defined alongside it
            self.patch_embed = PatchEmbed(patch_size=2)
    
            # Stack of transformer blocks
            self.blocks = nn.ModuleList([
                TransformerBlock(embed_dim) for _ in range(depth)
            ])
    
        def forward(self, x, t, cond=None):
            # Convert image to patches
            x = self.patch_embed(x)  # (B, num_patches, embed_dim)
    
            # Add positional encoding
            x = x + self.pos_embed
    
            # Embed the timestep and add it to every token
            x = x + self.time_embed(t)[:, None, :]
    
            # Transform through attention layers
            for block in self.blocks:
                x = block(x)
    
            # Reshape back to image
            return self.unpatchify(x)

    This looks like a regular transformer but with patches (segments of the image) rather than text tokens, as in ViT understanding models. The transformer block will also look pretty familiar:

    class TransformerBlock(nn.Module):
      def __init__(self, dim, heads=8, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
          nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
      )
    
      def forward(self, x):
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        
        return x

    They got good results and, more importantly, it was easier to scale up to more compute and larger inputs. For what it’s worth, I found DiTs a bit tricky for training on small data sets (like the MNIST example), but didn’t spend much time on it, since:

    MMDiTs emerged in 2024, and were used for Stable Diffusion 3 and Flux, largely setting the standard in terms of image quality. The idea is to process images and text in parallel with the ability to attend across each other, reminiscent of cross-encoder models.

    class MMDiTTiny(nn.Module):
        def __init__(self, img_dim=256, txt_dim=256, depth=6):
            super().__init__()
            # Separate encoders for each modality (patchify/unpatchify assumed defined)
            self.img_encoder = nn.Linear(patch_dim, img_dim)
            self.txt_encoder = nn.Linear(txt_dim, txt_dim)
    
            # Joint transformer blocks
            self.blocks = nn.ModuleList([
                CrossTransformerBlock(img_dim, txt_dim) for _ in range(depth)
            ])
    
        def forward(self, img, t, txt=None):
            # Process both modalities
            img_tokens = self.img_encoder(patchify(img))
            txt_tokens = self.txt_encoder(txt) if txt is not None else None
    
            # Attention across modalities (this tiny sketch only updates the
            # image stream; full MMDiT updates both)
            for block in self.blocks:
                img_tokens = block(img_tokens, txt_tokens)
    
            return unpatchify(img_tokens)

    MMDiT models demonstrate great prompt adherence and can handle complex requests. The bidirectional flow means text understanding improves alongside image generation.

    class CrossTransformerBlock(nn.Module):
    """Cross‑attention: query=image tokens, key/value = text tokens."""
    
        def __init__(self, dim_img, dim_txt, heads=8, mlp_ratio=4.0):
            super().__init__()
            self.q_proj = nn.Linear(dim_img, dim_img)
            self.k_proj = nn.Linear(dim_txt, dim_img)
            self.v_proj = nn.Linear(dim_txt, dim_img)
    
            self.attn = nn.MultiheadAttention(dim_img, heads, batch_first=True)
    
            self.ln_q = nn.LayerNorm(dim_img)
            self.ln = nn.LayerNorm(dim_img)
            self.mlp = nn.Sequential(
                nn.Linear(dim_img, int(dim_img*mlp_ratio)), nn.GELU(), nn.Linear(int(dim_img*mlp_ratio), dim_img)
            )
    
        def forward(self, x_img, x_txt):
            q = self.q_proj(self.ln_q(x_img))
            k = self.k_proj(x_txt)
            v = self.v_proj(x_txt)
    
            x = x_img + self.attn(q, k, v, need_weights=False)[0]
            x = x + self.mlp(self.ln(x))
    
            return x

    Here, in the cross attention block the image is used for the Query part and the text for the K and V parts of the attention. The results are combined with the image input. 

    Putting this all together, you can see the evolution of the common diffusion baselines across both scale and steerability:

    1. DDPM: Clean but slow. The baseline everything else improves on.
    2. SD1-style (UNet + Epsilon + CFG): The first practical system. Good quality, reasonable speed, follows prompts well with CFG.
    3. SD2-style (UNet + V-param + CFG): Slightly better contrast and stability, especially at high resolutions.
    4. SD3-style (MMDiT + Flow): The current state-of-the-art. Fastest training, best prompt adherence, most efficient sampling.

    Back to Qwen

    The Qwen-Image model is a good, practical example of scaling this up. It uses an existing multimodal model2 to encode text and image inputs, a pretrained VAE3 for translating between pixel and latent space, and then as its backbone an MMDiT. The use of strong (understanding) models for encoding really helps enhance the steerability of the results in the MMDiT.

    In the MMDiT sketch above we just concatenate image and text together. In real systems we first add the positional embeddings for the image tokens, then add on text tokens. This works, but makes it difficult to adapt to different image resolutions.

    Seedream introduced Scaling RoPE4 that instead puts the image positional encoding in the middle of the image, treats the text tokens as 2D shapes [1,L], then applies 2D RoPE to both text and image tokens. This worked better, but had some problems where the positions were confusable between text and image latents, meaning the model couldn’t properly differentiate in some cases. The Qwen team updated this by implementing positional encoding across both dimensions of the text tokens, and concatenating the text along the diagonal of the image:

    This design allows MSRoPE to leverage resolution scaling advantages on the image side while maintaining functional equivalence to 1D-RoPE on the text side, thereby obviating the need to determine the optimal positional encoding for text.

    The resolution independence is important for the training recipe. The model is progressively trained with images starting at 256×256 and increasing in steps up to 1328x, in a variety of aspect ratios. They follow it up with post-training consisting of SFT on organized, high quality image-text pairs and DPO against preference pairs judged by human raters5. Finally, they do a GRPO stage with a “reward model”: though it isn’t clear if that’s based on the aforementioned preference data or is some other secret sauce.

    While we don’t know how GPT-image is trained, this recipe certainly gave some comparable results. I was surprised to learn that the combination of a strong text and image encoding model plus MMDiT6 gives this level of steerability and fidelity. As usual, it’s exciting to have open models and papers to bring these concepts together! 

    1. It’s w because the noise is a Wiener process, also known as standard Brownian motion. I am heavily conditioned to think of this as the motion in a cup of tea thanks to HHGTTG
      ↩︎
    2. Qwen 2.5-VL ↩︎
    3. Interestingly, a video one from Wan-2.1 ↩︎
    4. Roughly the same idea was around as column-wise position encoding, as I understand it.
      ↩︎
    5.  The same prompt with two different seeds, and — if present — a reference image
      ↩︎
    6. And a lot of very carefully curated and programmatically generated data, to be fair
      ↩︎
  • Automation & Managerial Control

    There’s a chart making the rounds that caused Tim Lee over at Understanding AI to rewrite his recent (excellent!) article about the impact of AI on jobs. Stanford’s Erik Brynjolfsson and colleagues found1 that young workers in AI-exposed jobs2 have seen their employment drop by 13% since ChatGPT arrived. Meanwhile, their older colleagues in the same fields are doing just fine.

    […] the youngest workers saw dramatic job losses—but only if they worked in occupations (like accountants or computer programmers) that were highly exposed to AI. Young workers in less exposed occupations (like nurses or construction workers) saw normal employment growth over the same period.

    From a tech industry focus, it’s a little hard to disentangle the impact of reduced hiring after layoffs3 from the growth of AI, but likely both had an impact. AI coding agents are making it easier to complete the kind of introductory tasks that might have been left for junior engineers.

    New grads don’t just do simple tasks though: they grow and develop tacit knowledge of their company and industry, raising the question of whether this is permanent disruption or temporary dislocation as the skills needed shift. As Tim calls out:

    It’s important not to read too much into this research. Workers between the ages of 22 and 25 are a small slice of the job market, and their employment has always been more volatile than for older workers. When I graduated with a computer science degree in 2002, the economy was just emerging from the recession that followed the dot-com bubble. It was a hard time for a young adult to get their first programming job, but most of my peers eventually found work in the field.

    To give an analogy, there was a time when becoming a junior programmer meant learning how to write fast code as cycles were too important to waste. Now, writing particularly efficient code is largely the preserve of specialist, more senior people: some folks opt in to that route early because of their personal interests, but in general raw performance of code is not the blocking factor to building something valuable.

    My sense is we are seeing the same thing in terms of general “program composition”: senior folks with experience on large, collaborative projects can benefit from LLM automation as they understand how to put in the right project guardrails and how to translate needs into technical direction. Junior people are still mostly trained how to write working code, and that need has become less pressing as LLMs have proved moderately competent at it.

    Rodney Brooks, the robotics legend, made a point back in 2018 that stuck with me: it’s not automation that disrupts workers—it’s digitalization. In his article, Brooks wrote:

    Digitalization is replacing old methods of sharing information or the flow of control within a processes, with computer code, perhaps thousands of different programs running on hundreds or thousands of computers, that make that flow of information or control process amenable to new variations and rapid redefinition by loading new versions of code into the network of computers.

    An example that Brooks uses is bridge toll takers. This directly happened on the Bay Bridge between San Francisco and Oakland, which used to employ toll takers in booths. Then FasTrak was added, allowing passing through without interacting with anyone, while still offering cash tolls for those without a transponder. Now, between that and direct mail to people via cameras watching license plates, the tollbooths are empty.

    LLMs also digitalize. Task descriptions and project documentation, for example, have been stored in human language: digital, but not particularly accessible to automation. Much of the work of managing a large bug tracking system has been in adding metadata that is accessible to automation. LLMs digitalize language, imperfectly to be sure, but enough to expose new swathes of work to automation.  

    High Road/Low Road

    How will companies respond? Thomas Kochan at MIT has been mapping this kind of choice for years, and describes the separation between what he called the high road and low road. 

    The language that was used to differentiate these two approaches quickly evolved to a comparison of “high road” and “low-road” business strategies and “high-performance work systems,” which viewed labor as an asset, versus “command and control” systems, which viewed labor as a cost like any other factor of production. A comparison of the business strategies of two household names, Walmart and Costco, illustrates the differences between low-road and high-road business strategies. Walmart has been extremely successful (when judged solely on the grounds of finances and shareholder value) by pursuing a business strategy best captured by its marketing tag line: “Every day low prices.” To achieve this strategy, it places top priority on minimizing and tightly controlling labor costs, discouraging long-term tenure of its “associates,” investing little in training and development, and avoiding unions at all costs. Costco’s business strategy places a higher value on product quality and customer service, and to achieve these objectives it pays higher wages, invests more in training its work force to understand and serve customer needs, and has longer tenure patterns (and thus lower turnover costs). As a result, Costco’s employees are more productive, stay with the firm longer, and have more discretion to use their time and knowledge to solve customer problems.

    Tech companies have, for the most part, been high-road employers. Employees have been an asset, and in some cases the key asset of the company. The low road, though, is not simply driven by cost cutting; it's about control. Having a more fungible, replaceable workforce gives executives more options. Having more specialized, skilled workers offers more flexibility in how work is done, but shifts control to the workers and away from management.

    We can see this play out in some of the post-pandemic cultural changes. There is a concept in work called deskilling, where work is atomized to improve efficiency: take something which was a skill and divide it up until the individual components become unskilled. Classic examples are in factory work, where a skilled person is replaced with an operator of a machine, or more often a series of operators of a series of machines4. This trades a higher up-front cost in terms of capital and procedure development for a lower labor cost, transferring both money and power from workers to managers.

    A recent article extended this concept to virtues, with the idea of “moral deskilling”. A virtue is a positive behavior, such as building responsibly or with high quality. Virtues tend to be individual qualities, things we recognize and reward in others: much of a company's culture is about instilling virtues. That is inherently messy, and the idea of systematizing virtue is appealing: move from a fuzzy, personal conception to a verifiable checklist or a rule that can be followed. This worked in a lot of cases, but it also enabled a form of deskilling:

    Systematising virtue handed control to managers. Who, endlessly mistrusting these expert folk who were always trying to do things the expensive way, converted that mistrust into endless, endless paper work.

    It was endless because it broke every little aspect of what had been virtue into tiny components. Fearful of losing control of any scrap of virtue, managers needed to relentless check on every little task.

    If we want to see this play out in real time we can look at the return-to-office mess in tech. A vibrant, collaborative office culture is a good thing, and it requires a compact. Employees would deal with the misery of a commute5, but in exchange they would participate in an environment where they could learn and teach, build camaraderie and so on.

    When return-to-office pushes arrived post-pandemic, people had found pleasure and benefit in not doing the commute. When they returned, they found the offices less vibrant, the workforce more distributed, and cost-driven reductions in space making the experience harder through shortages of meeting rooms and desks.

    Compounded by a series of layoffs and a change in the prior relationship between company and employee, the in-office deal felt worse. Frustrated with the lack of the old compact, management exerted control through systems. They set required days and logged attendance through badge ins. Workers responded by treating the atomized requirements as mere requirements, not aspects of a culture: even a small percentage of folks coffee badging or trying to work from more convenient offices were visible in the empty desks, exacerbating tensions for workers “doing the right thing”. 

    Rather than analyze the problem and step back, management in many cases doubled down on systematizing: validating time at desk, logging badge out times or adding similar extra controls. This continued to take what had been a morally complex set of trade-offs and reduce it to a checklist. For many newer staff, that was the in-office experience. 

    This is the essence of the low road: prioritizing the systematized and legible over the messy, complex, but more interesting world of dealing with real people; prioritizing power and control over exploring new outcomes.

    One way to view what’s happening is through the lens of debt, which is one of the angles in a recent position paper that frames the future of work as an AI Safety risk. Every time a company chooses to replace junior workers with LLMs rather than training them, they’re borrowing against the future. Matt Garman of AWS was pretty clear on his position: 

    “I was at a group, a leadership group and people were telling me they’re like we think that with AI we can replace all of our junior people in our company. I was like that’s the like one the dumbest thing I’ve ever heard […] They’re probably the least expensive employees you have. They’re the most leaned into your AI tools and like how’s that going to work when you go like 10 years in the future and you have no one that has built up or learned anything.”

    But understanding something and acting on it are different things. Both the low road and high road can lead to a lot of success in business, but I do hope we can navigate this transition towards a place where the craft can be retained in software development. The question is whether enough companies will choose the messy, complex work of developing people over the appealing simplicity of trying to replace them.

    1. Canaries in the Coal Mine? Six Facts about the Recent Employment Effects of Artificial Intelligence — Stanford Digital Economy Lab ↩︎
    2. Like programming and accountancy, knowledge work fields that have a large amount of machine interaction ↩︎
    3. As well as pandemic-driven overhiring and the end of zero interest rates ↩︎
    4. Or now robots in entirely lights out factories for sufficiently high scale productions ↩︎
    5. Particularly in the SF bay area! ↩︎
  • A Primer on Post-Training

    A Primer on LLM Post-Training – PyTorch

    Very excited to see this publicly available. David moved to the PyTorch team at the start of the year, having worked on Llama, and wrote up this excellent guide for post-training internally. This is a cleaned up version of the same doc, and provides a fantastic introduction to the world of post-training for modern LLMs.

    It also includes one of my favorite perverse incentive examples:

    Note: this happens with humans too! We just call these Perverse Incentives, but they are literally the same thing. The British government, concerned about the number of venomous cobras in Delhi, offered a bounty for every dead cobra. Initially, this was a successful strategy; large numbers of snakes were killed for the reward. Eventually, however, people began to breed cobras for income.

    The real kicker in that one came when the government realized what was happening and canceled the bounty. The folks who had been breeding cobras didn’t want to look after them any more, so just released them, leading to a lot more cobras than there had been before!

  • Layouts

    You could have invented CuTe hierarchical layout (but maybe not the rest of it?) : ezyang’s blog

    Ed posted the best intro to CuTe layouts I have seen, by showing how to extrapolate them from PyTorch striding1.

    Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if its an N-D object.)

    Relatedly, Simon Veitner put together a nicely visual walkthrough of layouts: https://veitner.bearblog.dev/intuition-behind-hierarchical-layouts/ – the graphics are helpful once you have the baseline intuition from Ed's post!
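
    To make the striding connection concrete, here's a small sketch (mine, not from Ed's post) of how a plain PyTorch (size, stride) pair maps coordinates to linear offsets, and how letting sizes and strides nest gives you the hierarchical flavour CuTe uses. The nested sizes/strides below are illustrative values, not CuTe syntax.

    ```python
    import torch

    # A plain tensor already carries a flat layout: sizes and strides.
    x = torch.arange(12).reshape(3, 4)
    print(x.size(), x.stride())  # torch.Size([3, 4]) (4, 1)

    def offset(coord, strides):
        # Linear offset of a logical coordinate under a flat stride tuple.
        return sum(c * s for c, s in zip(coord, strides))

    assert x[2, 3].item() == x.view(-1)[offset((2, 3), x.stride())].item()

    # Hierarchical layouts let a dimension itself be a (size, stride) tree,
    # e.g. a 3 x (2, 2) shape whose second mode is addressed as a 2x2 block.
    strides = (4, (1, 2))

    def offset_nested(coord, strides):
        total = 0
        for c, s in zip(coord, strides):
            total += offset_nested(c, s) if isinstance(c, tuple) else c * s
        return total

    print(offset_nested((2, (1, 1)), strides))  # 2*4 + 1*1 + 1*2 = 11
    ```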

    1. If you’re not familiar with striding, Ed’s PyTorch Internals talk/post remains the best intro! ↩︎
  • The TPU book, on GPUs

    How to Think About GPUs | How To Scale Your Model

    The Jax “How To Scale Your Model” book is one of my favorite references for folks trying to get their head round pretraining1. It breaks down the performance characteristics of model training (often using Llama 3 as an example) in an incredibly clear way. The only slight limitation is that it is primarily focused on scaling LLMs on TPUs: interesting, but probably not your main platform target (unless you work at Deepmind). They just released a new chapter covering GPUs, and it’s also a great summary2.

    There are also plenty of mildly snarky comments about design choices to leaven the reading:

    Takeaway: in theory, NVIDIA SHARP (available on most NVIDIA switches) should reduce the cost of an AllReduce on B bytes from about 2 * B / W to B / W. However, in practice we only see a roughly 30% improvement in bandwidth. Since pure AllReduces are fairly rare in LLMs, this is not especially useful.
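
    For context, here's my gloss on where those numbers come from (not from the chapter itself): a ring AllReduce is a reduce-scatter followed by an all-gather, each of which moves roughly the full buffer across every link, hence ~2 * B / W; an in-switch reduction like SHARP only needs to move the data once, hence ~B / W in theory.

    ```python
    # Rough time estimates for an AllReduce over B bytes on W bytes/sec links.
    def ring_allreduce_seconds(b: float, w: float, n_devices: int) -> float:
        # reduce-scatter + all-gather each move ~B * (N-1)/N bytes per link
        return 2 * b * (n_devices - 1) / n_devices / w

    def switch_reduced_allreduce_seconds(b: float, w: float) -> float:
        # idealized in-network reduction: the buffer crosses each link once
        return b / w

    B, W = 2**30, 400e9  # 1 GiB buffer, 400 GB/s links (illustrative numbers)
    print(ring_allreduce_seconds(B, W, 8))         # ~4.7 ms
    print(switch_reduced_allreduce_seconds(B, W))  # ~2.7 ms
    ```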

    1. Though they include a chapter on inference too! ↩︎
    2. Though if you haven’t read the rest of the book it moves pretty fast – definitely best to read through the whole thing and treat this as the appendix it is intended to be! ↩︎
  • Extending Arcee’s FM context length

    Extending AFM-4.5B to 64k Context Length

    Via Nathan Lambert, an extremely fun write-up of the journey to a 64k context length for Arcee's 4.5B foundation model. There are a lot of good takeaways, but this one particularly resonated with me:

    Experimentation is Key: As in everything I write, I am unable to stress enough the importance of trying dumb things. If you try enough dumb things, eventually one of them will turn into a smart thing. Embrace the chaos.

  • Rubrics

    Rubrics

    Pre-training is about making AI correct, post-training is about making AI helpful1. That helpfulness is (primarily) shaped by reinforcement learning. RL for LLMs really took off with RLHF (RL from Human Feedback), which trained based on the score from a reward model.

    The reward model was designed to score responses based on how well they met certain preferences, and the preferences were inferred from a set of human ratings: the graders were told what to look for in pairs of responses, and the reward model was trained to predict what they would pick. This worked, but was gated on how much signal you could get into the reward model, and hence on how many humans you had generating preference data.

    RLAIF (RL from AI Feedback) naturally extended this to using an LLM to make the preference picks rather than humans2. Folks also started to use LLMs in an LLM-as-Judge pattern for evaluation after training: give the model a list of criteria, and ask it to rate how well the responses meet them. 

    The next notable step was RLVR (RL with Verifiable Rewards), which uses ground-truth data to provide reward scores instead of a model. For example, a math problem might have a defined numeric answer, or a generated proof could be verified by a dedicated theorem prover program. This turned out to work very well for code and math, and led to the O-series of OpenAI models3 and many open reasoners, particularly Deepseek R1.

    It’s a pretty natural idea to take a verifiable-reward pipeline and plug AI scoring in directly: rather than having a model generate preference pairs and training a separate reward model, give the model criteria and ask it how well the response satisfies them. This means instead of letting a model work out what “good code” looks like from pairs of different (but similar!) solutions to a problem, you have a model working through a checklist, asking things like “Does it have types? Does it have comments? Would your coworkers hate you if you landed this?”

    These checklists are referred to as rubrics, and Snorkel have started an interesting-looking blog series introducing rubrics, which offers a definition:

    A rubric is a structured guide that spells out what “good” looks like for each response from an AI system. 

    A rubric consists of:

    • A list of criteria: Does the code compile? Does it have comments?
    • How the model performed on each criterion: “Compiles” could be yes/no. It could also be more nuanced: yes/yes with warnings/no.
    • Scoring rules that turn performance into numbers: Clean = 0. Warnings = 1. No = 2.
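
    To make that concrete, here's a minimal sketch of a rubric as a data structure (my own illustration, not Snorkel's or any paper's exact scheme): weighted pass/fail criteria, normalized so that adding more criteria doesn't inflate the score.

    ```python
    from dataclasses import dataclass

    @dataclass
    class Criterion:
        description: str
        weight: float  # e.g. essential = 1.0, important = 0.5

    def rubric_score(criteria: list[Criterion], passed: list[bool]) -> float:
        """Weighted fraction of criteria satisfied, normalized to [0, 1]."""
        total = sum(c.weight for c in criteria)
        earned = sum(c.weight for c, ok in zip(criteria, passed) if ok)
        return earned / total if total else 0.0

    rubric = [
        Criterion("Does the code compile?", 1.0),
        Criterion("Does it have comments?", 0.5),
        Criterion("Does it have type annotations?", 0.5),
    ]
    # In training, an LLM judge would produce `passed`; here it's hard-coded.
    print(rubric_score(rubric, [True, True, False]))  # 0.75
    ```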

    In Nathan Lambert’s recent interview with Ross Taylor, Taylor calls rubrics out as an underappreciated research opportunity, particularly for agentic training:

    Rubrics are underhyped on social media – they were driving force behind projects like DeepResearch – and GenRMs are interesting but perhaps slightly overhyped.

    This caught my eye, as Moonshot leveraged rubric-based rewards heavily in Kimi K2, notably using the model they were training as the judge of itself:

    The framework operates using a Self-Critique Rubric Reward mechanism, where the model evaluates its own outputs to generate preference signals. To bootstrap K2 as a competent judge, we curated a mixture of open-source and in-house preference datasets and initialize its critic capability in the SFT stage.

    One of the core values of rubrics is that they work for both LLMs and humans. You can iterate on rubrics with people, scale them with LLMs, and spot-check LLM results with human raters to ensure reliability. 

    The paper [2507.17746] Rubrics as Rewards: Reinforcement Learning Beyond Verifiable Domains formalizes them as a full peer to Verifiable Rewards. The paper sets up rubrics so each criterion is a simple pass/fail check with a predefined importance weight. They normalize everything so the system can't be gamed by just adding more criteria4, and then plug the resulting score into the RL loop5.

    Of course, you actually have to write the rubrics, which leads to a specificity versus generality tradeoff: take more time to write more rubrics or rely on fewer, more general ones. The RaR paper makes it clear that more is better:

    predefined generic rubrics substantially underperform compared to prompt-specific ones, underscoring the importance of contextualization. Rubrics that include a broader range of criteria—both positive and negative—consistently outperform those limited to essential checks, suggesting that richer evaluation signals lead to better learning.

    As you might have guessed, the solution was more LLM: use a model to generate prompt-specific rubrics:  

    For each domain, the prompt (included in Appendix H) instructs the LLM to generate 7–20 rubric items based on the complexity of the input question. Each item is assigned a categorical weight (e.g., Essential Criteria, Important Criteria) to determine its importance to a correct answer. The rubrics are designed to be fully self-contained which means that non-expert readers should be able to evaluate response quality using only the rubric. 

    This particularly benefited from having a reference answer attached to the prompt. The models do a much better job of coming up with a good rubric if provided with a (human generated) “good” answer to judge against rather than just the question/prompt. This really opens the door to 1:1 rubrics: given questions and reference answers, you can generate a scoring checklist for each one and mix it with verifiable rewards during post-training. 
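
    As a hypothetical sketch of what that generation step might look like (the actual prompt is in the paper's Appendix H; this is just an illustration, and `llm` stands in for whatever completion call you use): feed the question and the reference answer to a judge model and ask for self-contained, weighted criteria.

    ```python
    RUBRIC_PROMPT = """You are writing a grading rubric.

    Question:
    {question}

    Reference answer (treat as correct):
    {reference}

    Write 7-20 self-contained criteria that a non-expert could check, each
    labelled "Essential" or "Important", one per line."""

    def generate_rubric(question: str, reference: str, llm) -> str:
        # `llm` is any callable mapping a prompt string to a completion string.
        return llm(RUBRIC_PROMPT.format(question=question, reference=reference))
    ```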

    The field continues to be turtles all the way down: using LLMs to write rubrics to have LLM judges evaluate LLM training outputs. At some point, someone’s going to suggest we use rubrics to evaluate how good our rubrics are, and honestly, I’m surprised that paper doesn’t already exist6.

    1. Correct in predicting the next token, and helpful, honest and harmless, specifically. ↩︎
    2. With humans still looped in to validate that the ratings were reasonable. The human graders went from generating ratings to rating the raters. ↩︎
    3. This is the part where everyone pretends they know exactly how O1 works, but actually we’re all just pattern-matching from breadcrumbs ↩︎
    4. Else we’d risk giving more focus to problems with more rubrics, and end up with something unthinkable like a coding model that liberally sprinkles emojis everywhere ↩︎
    5. In practice, they also tried a single LLM judge that took in all criteria and weights and generated a scalar reward, which seemed to work fine. ↩︎
    6. It probably does, I’m just scared to look ↩︎
  • Constraints & Orchestrators

    I recently read a few posts that helped connect the dots on why Python is a) so successful as the lingua franca of ML and b) likely to stay successful in the future1.

    ML code reads like one program, but runs many: CUDA kernels, vectorized CPU loops, graph compilers and a bunch of glue code moving data around and tying things together. Python has continually improved at balancing two somewhat competing challenges: constraining the hot path so compilers can optimize it and structuring an orchestration path so humans can reason about it.

    Hot Path

    constrained languages are easier to optimize by Jynn Nelson touches on this:

    we should not be asking “what language can i use everywhere for every purpose”; we should build meta-languages that allow you to easily use the right tool for the job. this is already true for regular expressions and query languages; let’s go further. i want inline futhark; inline CSS selectors; inline datalog; ffi between python and C that’s trivially easy. the easier we make it to interop, the easier it becomes to pick the right tool for the job.

    Compilers are generally going to perform better if you have regular shapes, minimal side effects, predictable memory access and so on, but you want languages to be expressive and flexible, particularly when “research” is a big part of the work. In practice, that's precisely what happens with ML: torch.compile lowers PyTorch graphs to an IR and (often) emits Triton kernels. Being able to hand off inner loops to specialized languages allows compilers and runtimes to optimize and target the use cases they are best at.
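
    As a tiny illustration of that hand-off (a generic example, not from the posts): decorating a function with torch.compile traces it, lowers it to an IR, and on a GPU backend typically emits fused Triton kernels for an elementwise chain like this.

    ```python
    import torch

    @torch.compile  # traced and lowered by TorchInductor on first call
    def fused_gelu_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
        # An elementwise chain that benefits from being fused into one kernel.
        return torch.nn.functional.gelu(x) * scale

    x = torch.randn(1024, 1024)
    print(fused_gelu_scale(x, 0.5).shape)  # torch.Size([1024, 1024])
    ```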

    While this is (somewhat) clear for GPUs or other accelerators with distinctive programming models, I think it's also largely true for getting the best out of modern CPUs. Daniel Lemire's SEA 2025 talk covers nearly a decade of performance work and sums it up: modern CPUs can execute nearly as many instructions per cycle as you can feed them. To really maximize performance you need to batch work, reduce instruction counts and vectorize. We can do some of that in the general Python2 runtime, but dynamic dispatch, aliasing and side effects all make the job a lot harder. We can add speculative guards, which can be hard to reason about, or give up and lose performance. By having DSLs3 that add additional constraints, we give ourselves the ability to get much, much higher performance without sacrificing the overall flow of our program.

    Orchestration Path

    Python is unusually good as an orchestrator. The language itself is very readable, and as long as libraries and DSLs stay Pythonic they tend to inherit that intelligibility. The challenge with orchestration is coordinating work in such a way that your most precious resources are well utilized. The investments in Free-Threaded Python make it a lot cheaper to do concurrency, but they don't magically fix the challenge of coordination.

    asyncio: a library with too many sharp corners covers some of the many failure modes the community has encountered with asyncio, and makes a case for Trio- or AnyIO-style structured concurrency that allows for manageable failure modes.

    asyncio is not a good library. It is constantly full of sharp edges everywhere with implementation details leaking and poorly designed APIs forcing end users into odd code patterns to avoid fundamental flaws in the interfaces.

    This is very much a readability version of the constraints concern on the hot path. Threads are a bad application-level abstraction over shared mutable state, reasoning about races and cancellation is hard, and the primitives are always leaky. But threads are a perfectly fine implementation detail behind a more constrained API, like task groups, actors, and so on.
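
    A small sketch of what that more constrained API looks like in practice (using the standard library's asyncio.TaskGroup from Python 3.11, rather than Trio or AnyIO specifically): tasks are scoped to a block, so failures and cancellation propagate instead of leaking orphaned tasks.

    ```python
    import asyncio

    async def step(name: str, delay: float) -> str:
        await asyncio.sleep(delay)  # stand-in for real I/O or remote work
        return f"{name} done"

    async def main() -> None:
        # All tasks created in the group are awaited (or cancelled together)
        # by the time the `async with` block exits.
        async with asyncio.TaskGroup() as tg:
            t1 = tg.create_task(step("tokenize", 0.1))
            t2 = tg.create_task(step("upload", 0.2))
        print(t1.result(), t2.result())

    asyncio.run(main())
    ```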

    One area that I do think needs sustained improvement is how we debug and trace across this kind of setup: it's been challenging even in a controlled environment to really understand how all the pieces interact in a reasonably scaled ML workload, and I imagine that problem will only get worse. But I also expect that the flexibility and breadth of Python will end up being a boon there as well.

    1. Beyond just sheer momentum, of course. ↩︎
    2. Or any language! Certainly for some optimizations having a JIT for Python would (and does) make life easier. ↩︎
    3. Whether that is an embedded JIT like Triton or a library+execution engine like Polars. ↩︎
  • Overthinking Everything

    Yann was taking victory laps on Threads a few weeks back over a recent paper, one of several recently that have explored how autoregressive models do as the amount of information they are dealing with gets longer. His general complaint is that each token they generate can either push them towards the right answer or further away from it, and that the models are inherently bad at recovering if they end up too far outside the correct trajectory.

    This “more might be worse” idea shows up anywhere folks are leveraging large context windows, and one of those1 is in agentic tasks. This post summarizes some research trying to measure the fall-off in chances of succeeding as task length2 increases.

    Is there a Half-Life for the Success Rates of AI Agents? — Toby Ord

    It provides indirect evidence that what really is going on under the hood is that tasks are made up of many sequential subtasks and the chance of succeeding at the whole requires succeeding at every individual component. Moreover, this suggests that the current AI agents are not very good at recovering from earlier mistakes.

    The framing they use is a constant hazard rate: each subtask is another roll of the dice, and if you roll a failure you don’t have much chance of recovering. So more (or longer) is pretty much always worse.
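
    A quick sketch of that framing in code (my gloss on the post, with made-up numbers): if each subtask independently fails with some fixed probability and there is no recovery, the overall success rate decays exponentially with task length, which is exactly what gives you a "half-life".

    ```python
    import math

    def success_rate(n_subtasks: int, p_fail: float) -> float:
        # Constant hazard: every subtask is an independent chance to fail.
        return (1 - p_fail) ** n_subtasks

    def half_life(p_fail: float) -> float:
        """Task length (in subtasks) at which success drops to 50%."""
        return math.log(2) / -math.log(1 - p_fail)

    print(success_rate(20, 0.05))  # ~0.36
    print(half_life(0.05))         # ~13.5 subtasks
    ```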

    One interesting aspect is that they also investigate the human failure rate, which increases over time, but much more slowly:

    This could indicate a different scaling behaviour of success rate with time horizon for humans compared to AI agents, which would be well worth investigating and may suggest important underlying mechanisms (e.g. that the humans were better at correcting earlier failed subtasks). If human performance scales differently with task length than AI agent performance, that would be an important result, suggesting that there is a notable inefficiency in the current AI paradigm.

    They're testing with multiple runs, so these aren't just models hitting problems they can't do: it's models hitting problems they can't do given the specific tokens they have already generated on that attempt.

    Agentic use cases aren't the only situation where a model is generating responses that add to its context window. There were a lot of early observations after the release of O1 last year that thinking for longer on easy problems does not add value. This recent paper goes further and suggests there is an inverse scaling law: more time thinking makes the model worse.

    [2507.14417] Inverse Scaling in Test-Time Compute

    More specifically, they devised some stress tests: things like counting problems in the presence of distracting information, performing a regression where there is easy-to-understand but spurious data, and so on. Different models are more susceptible to some failure modes than others, but performance consistently drops as the trace length increases:

    Our experiments reveal distinct failure modes across model families—Claude models are particularly vulnerable to distraction from irrelevant information, while OpenAI o-series models show greater resistance but overfit to familiar problem framings. Extended reasoning amplifies different weaknesses: models overthink simple problems, shift attention to spurious correlations, and lose focus during Deduction tasks with constraint tracking.

    In contrast, Chroma's recent Technical Report investigates how models do on single prompts, but with increasingly long contexts.

    Context Rot: How Increasing Input Tokens Impacts LLM Performance | Chroma Research

    Unlike in the agentic case, here all of the context is passed in at once, so the model isn’t poisoning its own context window through bad choices. It is still dealing with a large amount of content where it needs to choose which parts to attend to. Traditionally the test of long context has been needle-in-a-haystack evaluations: a relevant fact is hidden at different points in a long prompt and the test evaluates whether the model can effectively pull it out.

    The Chroma folks make the test a lot more nuanced — adding distractors3 and irrelevant content in both the broader context and the question. They find that performance consistently degrades as context increases.

    More broadly, our findings point to the importance of context engineering: the careful construction and management of a model’s context window. Where and how information is presented in a model’s context strongly influences task performance

    All of these papers rhyme with LeCun’s gripe about autoregressive transformers, which is (roughly!) that they (also) have a constant hazard rate on generating the “right” token.

    This is a very active area of research though. Process-based rewards in RL training make updates on each step vs only at the end. Multi-token prediction reduces the effective generation length or number of chances of misprediction. Summarizing context effectively compresses existing tokens, also reducing error rate.

    Similarly, if you have good verifiers4 you can use beam or tree searches to explore multiple reasoning paths during generation, which can reduce the error rate, at the cost of more compute.

    The closest (LLMish) techniques to LeCun’s vision are things like the recent Hierarchical Reasoning Model that has a layer of persisting hidden state, but it’s still pretty experimental!

    As agentic and reasoning traces get longer, I’m sure we’ll see more entries documenting failure modes, and proposals for techniques to scale around them.

    1. And the one being referenced in the post! ↩︎
    2. In time — they characterize tasks based on how long it takes humans to do them, which is a good control factor ↩︎
    3. As in additional content related to the question, but that doesn’t give the answer. ↩︎
    4. Similar to process-based rewards this is somewhat pushing the problem to the ability to judge how well you are doing during the generation ↩︎
  • The Tools Are Made Up

    The Tools Are Made Up

    It has been hard to keep up with the flurry of strong agentic open-source models coming out of Chinese labs recently, including Moonshot’s Kimi K2, Z.ai’s GLM 4.5, and Qwen3-Coder1.

    Each of them has the same mix of clever pre-training recipes and verifiable-rewards post-training. Notably, Kimi and GLM both use the Muon optimizer, which seems to be gaining ground among the OSS labs at least. GLM's description of the recipe is as follows:

    Our base model undergoes several training stages. During pre-training, the model is first trained on 15T tokens of a general pre-training corpus, followed by 7T tokens of a code & reasoning corpus. After pre-training, we introduce additional stages to further enhance the model’s performance on key downstream domains. Unlike the earlier pre-training stage on large-scale universal documents, these stages leverage medium-sized domain-specific datasets, including instruction data.

    The additional stages, which they refer to as mid-training, extend the context window and help grow capabilities in specific domains. They then move to post-training, with SFT over reasoning and agentic traces followed by RL with Verified Rewards2.

    The Kimi-K2 technical report goes into more detail about how to actually train for tool use. Unlike the others, Kimi is not a reasoning model, so it doesn't use much in the way of extended thinking. The fact that this wasn't required to get to strong levels of tool use/agentic capability feels pretty notable to me — most of the recent3 agentic models have been built on a reasoning foundation.

    What I really found interesting from the Kimi report was the level of synthetic data that the team used. This starts in pretraining: to extend high-quality data sources they rewrite them with another LLM, giving the same facts with new phrasing, instead of looping over the same “good” data for multiple epochs.

    Their approach to tool training takes this kind of idea even further:

    We construct a comprehensive tool repository through two complementary approaches. First, we directly fetch over 3,000 real MCP (Model Context Protocol) tools from GitHub repositories, leveraging existing high-quality tool specifications. Second, we systematically evolve 82 synthetic tools through a hierarchical domain generation process. We begin with key categories (e.g., financial trading, software applications, robot control), then evolve multiple specific application domains within each category. Specialized tools are then synthesized for each domain, with clear interfaces, descriptions, and operational semantics. This evolution process produces over 20,000 synthetic tools.

    They analyze a set of real tools, generate some novel (but derivative) ones, then domain-specialize them for a lot of use cases.

    Once they have this tool zoo, the actual training loop involves:

    1. Randomly sample a subset of tools and give them to a new agent with a fresh system prompt, generating tool-appropriate tasks with explicit success rubrics.
    2. Run an LLM-driven user simulator to drive the agent, while running the tools in a sandbox that keeps state.
    3. Filter trajectories using another LLM as judge, keeping only successful ones for SFT.
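
    Pulling that together as pseudocode (my own loose sketch, not Moonshot's implementation; every helper here is hypothetical):

    ```python
    import random

    def generate_sft_trajectories(tool_zoo, make_task, user_simulator,
                                  agent, judge, n_episodes=1000, k_tools=8):
        kept = []
        for _ in range(n_episodes):
            tools = random.sample(tool_zoo, k_tools)         # 1. sample tools
            task, rubric = make_task(tools)                  #    + task and rubric
            trajectory = user_simulator(agent, tools, task)  # 2. simulated user
            if judge(trajectory, rubric):                    # 3. LLM-as-judge filter
                kept.append(trajectory)
        return kept  # successful trajectories become SFT data
    ```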

    They're using models at every stage to generate data and evaluate options. When it comes to the actual RL training, they lean on verifiable rewards wherever possible. They, and the Qwen folks, talk about their simulator setup for code4: thousands of sandbox environments.

    For software engineering tasks, we collect a vast amount of pull requests and issues from GitHub to build software development environment that consists of user prompts/issues and executable unit tests. This environment was built on a robust sandbox infrastructure, powered by Kubernetes for scalability and security. It supports over 10,000 concurrent sandbox instances with stable performance, making it ideal for both competitive coding and software engineering tasks

    The combination of very sophisticated synthetic data and operationally intense sandboxes seems like table stakes for the current agentic game, and one that a lot of labs have figured out. It feels very promising for a growth in capabilities of these models over time, particularly as we work out how best to distill them down to smaller sizes for inference.

    1. Which seems a very solid model, but they haven’t released a lot of extra details about how they got there. One interesting component of the release though was that they forked Gemini CLI to make a qwen-code tool that works with any OpenAI compatible API, and I had some success locally plugging it into the smaller Qwen3 (non-coder) releases in case you were looking for some offline agentic capabilities! ↩︎
    2. Then GLM is distilled between the RL and base version of the model, which apparently helps generalize. This seems like a fun and relatively simple way of smoothing out the learning. ↩︎
    3. Though Claude 3.5 wasn’t, and that is really the trend-setter here I guess! ↩︎
    4. And other tasks that allow fully verifiable rewards. They use other models to score softer domains like creative writing. ↩︎

  • PyTorch Conference 2025

    The schedule is up for the 2025 edition of the PyTorch conference, which is now at the Moscone West in San Francisco.

    https://events.linuxfoundation.org/pytorch-conference/program/schedule/

    There are a lot of great sessions, but I’ll highlight some I personally find particularly interesting:

    • Post-Training: clearly a big theme this year, with some interesting talks from multiple groups
    • General Training
    • Kernel development
    • Compilers
    • Inference

    I’m looking forward to October!