Tag: llm

  • The model can probably write the code

    The current vibes in software engineering are a mix of crushing despair at years of accumulated personal skills being displaced by the CEO prompting some stuff, and crushing despair at years of corporate investment in an existing codebase that isn’t vibe-y enough. People worry about whether the models will be effective in their programming language of choice specifically, not just on general benchmarks.

    One angle to approach that is to ask how well the language is covered by the distribution of the training data1. An interesting paper the other day gave a pretty clear idea of how to check: 1-shot some prompts against the base model and see if they ever get it right. Getting access to base models is not always possible, but you can certainly call the post-trained models with roughly the same idea: no tools, no iterations, just generate this program.

    To try this, I2 wrote up 20 Project Euler-like3 puzzles of varying difficulties and had a few different models YOLO solutions in several languages. These ranged from common ones like Python to fairly rare ones like Zig and Hack.

    After validating all the solutions, we can calculate some stats using pass@k: in k trials, how often did the model solve the problem? Here are some stats for pass@1: what % of the time can you expect the model to one-shot the solution:
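    Concretely, pass@k is usually computed with the unbiased estimator from the Codex/HumanEval paper rather than by literally re-running k trials. A minimal sketch (my own, not necessarily how the numbers below were produced):

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k samples drawn
    (without replacement) from n total samples, of which c are correct,
    is correct."""
    if n - c < k:
        return 1.0  # fewer incorrect samples than draws, so a hit is guaranteed
    return 1.0 - comb(n - c, k) / comb(n, k)

# e.g. 128 samples of which 3 were correct, estimating pass@10
print(pass_at_k(128, 3, 10))
```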

    Lang         GPT-4.1 Mini   Gemini 3 Flash   OLMo 3.1   Kimi K2.5   GLM-5
    Python       .93            .99              .72        .97         .98
    TypeScript   .94            1.00             .43        .95         .95
    Go           .95            .91              .46        .86         .86
    Rust         .89            .94              .43        .95         .95
    Kotlin       .90            .99              .29        .91         .93
    OCaml        .76            .86              .08        .94         .90
    Zig          .14            .55              .00        .79         .88
    Hack         .43            .76              .05        .47         .68

    And here is the same thing for pass@128: what is the chance it is right at least once in 128 samples:

    Lang         GPT-4.1 Mini   Gemini 3 Flash   OLMo 3.1   Kimi K2.5   GLM-5
    Python       1.00           1.00             .95        1.00        1.00
    TypeScript   1.00           1.00             .90        1.00        1.00
    Go           1.00           1.00             .85        1.00        1.00
    Rust         .95            1.00             .88        1.00        1.00
    Kotlin       1.00           1.00             .59        1.00        1.00
    OCaml        .98            1.00             .38        1.00        1.00
    Zig          .49            1.00             .05        1.00        1.00
    Hack         .99            1.00             .46        1.00        1.00

    To make that a bit more visual, here is a per-language chart for GPT-4.1-mini:

    [Figure: pass@k curves for GPT-4.1 Mini, with k (number of attempts) on the x-axis and pass rate averaged across problems on the y-axis, one line each for Python, TypeScript, Go, Rust, Kotlin, OCaml, Zig, and Hack.]

    Given enough chances GPT-4.1 Mini solves all the problems, in almost all the languages. Of course, we don’t actually know what GPT-4.1 was trained on, but we do know what OLMo 3.1 was trained on, thanks to the wonderful folks at AI2. That means we can see how much code-specific data for each language there was4:

    Language     Code Corpus (GB)   Est. Tokens (B)   Category
    Python       60.40              17.3              High-resource
    TypeScript   26.52              7.6               High-resource
    Go           23.78              6.8               High-resource
    Rust         9.11               2.6               Medium-resource
    Kotlin       5.68               1.6               Medium-resource
    OCaml        1.03               0.29              Low-resource
    Zig          0.18               0.05              Low-resource
    Hack         0.00               0.00              Very-low-resource

    There is a pretty decent correlation between the presence of training data and the pass@k rates. But, importantly, it’s not 1: despite Hack having no StarCoder data and Zig negligible amounts, the model clearly does know at least something about them. Given enough chances it has a decent chance at coming up with the correct answer for Hack, and a non-zero one for Zig:

    [Figure: training data volume vs. average pass@k for each language, including Python, Rust, Go, and Zig, with separate markers for pass@1, pass@10, and pass@128.]

    We have seen for human language that models learn a language substrate, enabling them to perform strongly even on tasks they haven’t seen such as translating between unseen language pairs. I suspect something similar happens with code: despite the language differences there is a logical programming substrate, and the model doesn’t need much exposure to the language in order to generalize to it.

    Once you start giving the model multiple attempts, it gets into the right region quickly for the high-resource languages: with GPT-4.1 Mini, Python, TypeScript, Go and Kotlin saturate at k=10. The less-common languages continue to rise: the model can write valid OCaml or Zig or Hack but needs more attempts to stumble into the right region.

    Thinking models flatten the curve substantially. Kimi K2.5 and GLM 5 both use high effort by default5, and that appears to give them multiple bites at the apple from internally exploring and self-correcting. By k=10 the models saturate all problems on all languages, though at the cost of a remarkable number of tokens6!

    It’s also instructive to see the ways in which the models get it wrong. There were four patterns that showed up:

    1. Ecosystem: One problem involved a sum of very large digits. GPT-4.1 Mini regularly used num::BigUint. This is a crate, not a standard language feature, and in an agentic loop would probably be a very valid choice but doesn’t strictly work. In contrast, GLM-5, a thinking model, implements digit-by-digit multiplication from scratch with Vec<u32>.
    2. API confusion: The model knows roughly what the code should look like, but chooses the wrong API. For example, OLMo generated while ... do ... in, mixing OCaml’s while...do...done loop with Haskell’s do notation and OCaml’s let...in binding.
    3. Surface-form invention: The model has a sense of how things stylistically look in the language, but doesn’t know the real API. GLM occasionally writes Zig with invented functions: std.mem.Allocator.alloc(usize, limit) (Allocator is a type, not a callable) or @intCast(usize, limit), which actually was valid syntax in earlier versions.
    4. Systematic convention gaps: Models would regularly put in <?hh for the Hack samples, which broke in modern Hack.

    My takeaway from this is that models learn to code, not just to reproduce syntax. That means you can almost certainly post-train or prompt your way out of most programming language problems with any frontier model: while some models were still pretty poor at Zig even with a lot of tries, Gemini most certainly was not. I doubt the folks at GDM spent a whole lot of time on Zig evals7.

    A well pre-trained model has broad capabilities in programming, and it’s mostly a case of eliciting them rather than having to teach them.

    1. I’m going to take as a given that models are good at generalizing within the distribution of their training data, and poor at generalizing outside it. This is not settled! Reasonable people can disagree! But, it’s a decent starting point. ↩︎
    2. Claude. ↩︎
    3. Not actually Project Euler. I confirmed that the models never respond with an actual Euler puzzle answer in the incorrect ones, so I’m fairly (this is not good science) sure it wasn’t memorization. ↩︎
    4. OLMo’s full training corpus (Dolma v1.7) includes a massive web crawl in addition to code-specific data from StarCoder, so the 0.00 GB for Hack means “absent from code-specific training”, not “absent from all training data”. Hack documentation and other content are almost certainly present in the web crawl portion. ↩︎
    5. Gemini also reasons, but the 2.5 Flash model was doing minimal reasoning when answering. ↩︎
    6. Somehow averaging over 3k per sample for GLM, I say while ruefully staring at my OpenRouter bill. ↩︎
    7. By posting this on the internet I am guaranteed to be corrected, at length, by a Googler. ↩︎
  • Do MoEs Think Different?

    When I was writing recently about MoEs I was focused mostly on the architectural reasons that we use them. One thing I hadn’t considered is that they might actually be better at learning as well.

    Meanwhile, Deconstructing Pre-training: Knowledge Attribution Analysis in MoE and Dense Models

    Our findings reveal that MoE architectures form a low entropy backbone of consistently reinforced neurons, which leads to an early consolidation of their importance profiles and, in turn, underpins their functional robustness. This resilience, stemming from more distributed knowledge storage, contrasts with the greater brittleness and knowledge concentration in the dense model. These phenomena collectively demonstrate that architectural sparsity is not merely a computational shortcut but also acts as a useful inductive bias that fosters stable and robust learning

    To land that somewhere between academic prose and GPT-speak1: the paper’s results suggest that MoEs learn more effectively, and store their core knowledge more robustly.

    They measure this with Log-Probability Increase (LPI), which lets you estimate how much each column in the output projection for a layer in the model contributes to the final score. It gives you a sense of how much smarter the model gets from that specific chunk of the weights2. They track this “neuron importance” measure over multiple checkpoints using the (very!) open models from AI2, OLMo-7B and OLMoE-1B-7B.
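    As a toy illustration of the general idea (this is my sketch of ablation-style attribution, not the paper’s actual Gated-LPI formula): score each hidden unit by how much zeroing it changes the log-probability of the target token.

```python
import math

def log_softmax(logits, idx):
    """Log-probability of the token at idx under a softmax over logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[idx] - lse

def neuron_importance(W, h, target):
    """For each hidden unit j, how much does zeroing it reduce the
    log-probability of the target token? (Toy stand-in for LPI-style
    attribution; W is a tiny output projection, h the hidden vector.)"""
    def logits(hvec):
        return [sum(w * x for w, x in zip(row, hvec)) for row in W]
    base = log_softmax(logits(h), target)
    scores = []
    for j in range(len(h)):
        h_abl = list(h)
        h_abl[j] = 0.0  # ablate unit j
        scores.append(base - log_softmax(logits(h_abl), target))
    return scores

W = [[2.0, 0.0], [0.0, 1.0]]   # vocab of 2, hidden size 2
h = [1.0, 1.0]
imp = neuron_importance(W, h, target=0)
print(imp)
```

    Unit 0 feeds the target logit, so its score is positive; unit 1 feeds the competing token, so ablating it actually helps and its score is negative.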

    In the MoE the set of important weights is both more stable and stabilizes earlier in training: the model develops a core of understanding and builds on that. This might mean MoE training is genuinely more effective than dense. The dense model is regularly thrashing its core understanding as updates come in, while the MoE protects it and lets the model focus more on nuance.

    Or! It might be entirely an artifact of model differences. As the authors note the two models are quite different: different training data sets, different lengths of training, and different depths (16 vs 32 layers), as well as, you know, being an MoE or not. Finally, the actual LPI version they use3, Gated-LPI, bakes in the MoE routing. It’s not totally clear whether we are seeing “neurons that matter”, or mostly seeing “routing patterns that matter”.

    I do think4 this is likely showing something interesting, even with some skepticism. The “smearing” of knowledge across weights is how I described what we are trying to avoid with MoEs, and it may be useful to have a more mechanistic understanding of how that actually happens. The authors observe that the stability curve rises, drops and consolidates. Even if this is just an artifact of routing, it’s quite possible there is a critical phase in the training where that routing locks in.

    If that idea is right, we might already be shaping that phase. The load-balancing tricks that made MoEs practical could be doing double duty as scaffolds for learning.

    1. Sparsity is not just a shortcut — it’s crucial to learning ↩︎
    2. For a given prompt. They actually use some fairly advanced evals for this, rather than the general basic benchmarks ↩︎
    3. And created, to make it plausible to do this work! ↩︎
    4. Do not draw any research conclusions based on this website ↩︎
  • Attention, Compression & Predicting the next token

    Language modelling is one of the great ideas in ML: if you train a model to accurately predict the next word in a sequence of text1, you are forcing it to learn a deep structure for human language. Because language is how we map reality, hopefully then you can do many useful things. This turned out to be right!

    The challenge with actually, you know, doing this is that text is messy. It’s sequential, variable length, and has structure, but the structure is kind of weird: the phrase “the cat, a mellow long-haired persian, sat on the mat” very clearly associates “sat” with “cat”, but the actual words are quite far away2.

    Dealing with sequential, variable-length data with a fixed network is a bit of an inherent mismatch. In training you often know the sizes you’re dealing with, but at inference time it’s variable. One elegant solution to that was the Recurrent Neural Net (RNN): start at the beginning, read one word at a time and keep a “hidden state” as a scratch pad to provide memory of what has come before.

    Training RNNs was painful, because now you have to backpropagate over multiple steps, and it was a minefield of vanishing and exploding gradients. The hidden state was used for two different things: the long-term memory of the whole sequence and as the key to the next word.

    Getting to Attention

    The architecture that really addressed this was the LSTM: instead of a single memory they split short and long-term memory and added activation functions to keep the gradient updates sane. They also made updating the memory a function of the input rather than of the weights alone, adding learnable gates that let the model decide which parts of the input to remember, and what information from the memory to forget. This unlocked real sequence-to-sequence models, which proved immediately useful in areas like machine translation: one model reads a sequence and compresses it to a hidden state (the encoder), another generates new output based on it (the decoder).
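    A minimal sketch of the gating idea, with scalar states for readability (real LSTMs use vectors and weight matrices; the parameter names here are just illustrative):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h, c, p):
    """One LSTM step with scalar state for clarity. p holds toy weights
    for the forget, input, and output gates plus the candidate memory."""
    f = sigmoid(p["wf"] * x + p["uf"] * h + p["bf"])    # what to keep of the old memory
    i = sigmoid(p["wi"] * x + p["ui"] * h + p["bi"])    # how much new info to write
    o = sigmoid(p["wo"] * x + p["uo"] * h + p["bo"])    # how much memory to expose
    g = math.tanh(p["wg"] * x + p["ug"] * h + p["bg"])  # candidate new memory
    c = f * c + i * g          # long-term memory: gated blend of old and new
    h = o * math.tanh(c)       # short-term memory / output
    return h, c

p = {k: 0.0 for k in ["wf", "uf", "bf", "wi", "ui", "bi",
                      "wo", "uo", "bo", "wg", "ug", "bg"]}
p["bf"] = 10.0   # bias the forget gate wide open: "keep everything"
h, c = lstm_step(1.0, 0.0, 2.0, p)
```

    With the forget gate saturated, the long-term memory c passes through nearly untouched, which is exactly the behaviour that keeps gradients from vanishing across steps.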

    This solved the training stability bottleneck, and introduced a new one: compression. The entire sequence got compressed to a single hidden state, which limited how much complexity could be captured.

    Bahdanau et al. addressed that with the idea of attention in 2014. The hidden state gets updated in the encoder with each new word, so why not keep all the hidden states around? Then, have a small network score which hidden states are relevant to the current decoder state, and make a new contextualized input to the decoder that is a weighted sum of the encoder states. This was called “attention” as it allowed the model to put different amounts of focus on different parts of the input sequence.
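    The mechanism is small enough to sketch. Here the scoring function is a plain dot product for brevity (Bahdanau’s original used a small feed-forward network):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attend(decoder_state, encoder_states):
    """Score each encoder hidden state against the decoder state, then
    return a context vector: the weighted sum of encoder states."""
    scores = [sum(d * e for d, e in zip(decoder_state, h))
              for h in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(wgt * h[i] for wgt, h in zip(weights, encoder_states))
               for i in range(dim)]
    return context, weights

enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three encoder hidden states
ctx, w = attend([1.0, -1.0], enc)            # a toy decoder state
```

    The decoder state most aligned with an encoder state pulls the most weight, so different decoding steps can "focus" on different parts of the input.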

    The new bottleneck though was throughput: to generate hidden state n, you first needed hidden state n-1. That made it hard to parallelize, which made it hard to take advantage of emerging accelerators. Luong et al first showed that you could simplify the state scoring to make it more hardware friendly, then Attention Is All You Need in 2017 stripped away the recurrent part entirely. In the Transformer architecture they got rid of the RNN and hidden state, replacing it with another version of the attention mechanism: self-attention.

    Rather than a stack of hidden states that progressively encode the state of the sequence, each incoming word is transformed at once into a contextualized representation that carries information about it and its surroundings. This was really parallelizable; you don’t need to care about previous time steps to make decisions, so you can scale the computation on GPUs and other accelerators.

    In regular attention you can think of the current decoder3 state as a query, and the various encoder hidden states as keys: the scoring function would generate a value for each pair of key and query. In self-attention, all the tokens were projected through key and query networks, and the query for each token was compared to the key of all the others. The transformer also added a value projection: in the older attention the “key” from the hidden state was both “what makes a good match” and “what information the token provides”, but in the transformer the two were decoupled.
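    A stripped-down single-head self-attention, without the 1/√d scaling or causal masking of the real thing, just to show the query/key/value roles:

```python
import math

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def self_attention(tokens, Wq, Wk, Wv):
    """Each token's query is scored against every token's key, and the
    value projections are mixed by the resulting weights."""
    Q = [matvec(Wq, t) for t in tokens]
    K = [matvec(Wk, t) for t in tokens]
    V = [matvec(Wv, t) for t in tokens]
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) for k in K])
        dim = len(V[0])
        out.append([sum(wi * v[i] for wi, v in zip(w, V)) for i in range(dim)])
    return out

I2 = [[1.0, 0.0], [0.0, 1.0]]   # identity projections, for illustration
out = self_attention([[1.0, 0.0], [0.0, 1.0]], I2, I2, I2)
```

    Note the decoupling the Transformer added: Wk decides what makes a token a good match, Wv decides what information it contributes once matched.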

    The new bottleneck that emerged was performance, particularly during inference. Comparing everything to everything else is an O(n²) operation. During training you can ameliorate some of that through batching, but you’re directly exposed in inference. And, unlike an RNN, increasing the sequence length (aka context length) gives you a quadratic increase in time, not linear.

    There were various attempts at addressing this one too. In “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention” back in 2020, Katharopoulos et al. showed that the quadratic aspect of self-attention comes from having to materialize a big matrix to calculate the softmax for scoring. If you replace the softmax with a map-type function you can chunk the computation and get linear time performance. This was mathematically elegant, but didn’t actually work very well, so more engineering-oriented approaches like KV caching and FlashAttention were the mainstay for tackling the bottleneck.
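    The reordering trick is easy to demonstrate: with a positive feature map φ in place of the softmax, the O(n²) pairwise form and an O(n) running-sum form give identical outputs. (Here φ is relu+1 rather than the paper’s elu+1, purely for brevity.)

```python
def phi(x):
    """A simple positive feature map standing in for the softmax kernel."""
    return [max(v, 0.0) + 1.0 for v in x]

def quadratic(qs, ks, vs):
    """Causal attention with similarity phi(q)·phi(k): O(n^2) pairs."""
    out = []
    for i, q in enumerate(qs):
        fq = phi(q)
        sims = [sum(a * b for a, b in zip(fq, phi(ks[j]))) for j in range(i + 1)]
        z = sum(sims)
        dim = len(vs[0])
        out.append([sum(s * vs[j][d] for j, s in enumerate(sims)) / z
                    for d in range(dim)])
    return out

def linear(qs, ks, vs):
    """Same result, computed with running sums: O(n) in sequence length."""
    dk, dv = len(ks[0]), len(vs[0])
    S = [[0.0] * dv for _ in range(dk)]   # running sum of phi(k) v^T
    z = [0.0] * dk                        # running sum of phi(k)
    out = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi(k)
        for a in range(dk):
            z[a] += fk[a]
            for b in range(dv):
                S[a][b] += fk[a] * v[b]
        fq = phi(q)
        denom = sum(a * b for a, b in zip(fq, z))
        out.append([sum(fq[a] * S[a][b] for a in range(dk)) / denom
                    for b in range(dv)])
    return out

qs = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
ks = [[0.2, 0.1], [-0.3, 0.4], [0.1, 0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
```

    Because the query factors out of the sum, you can carry (S, z) forward like an RNN state instead of revisiting every past token.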

    So why talk about this now? Because of Moonshot AI, and their excellent Kimi models. Moonshot are perhaps the frontier-est of the Chinese tiger labs, and their recent model releases have included Kimi Linear: An Expressive, Efficient Attention Architecture

    The architecture mixes regular, self-attention layers with Kimi Delta Attention. And Kimi Delta Attention is just the latest in a thread of evolution which goes back (sorta!) to RNNs.

    State space models

    For a long time, folks modelled control systems using state-space models. These return both an output and a state, and have a linear update function. RNNs such as LSTMs weren’t strictly state-space models in part because of their use of non-linearities: when updating the memory LSTMs used a tanh activation, for example. If you hand-wave a bit and ignore that, you’re looking at a very similar process.

    But there is a gap between hand-waving and science, and luckily someone crossed it. The benefit of that activation function was that it squashed the state into a known range and avoided the vanishing gradient issue that plagued RNNs. The key realization was that you can drop the non-linearity entirely4 as long as the weight matrix that multiplies the hidden state is well behaved (specifically, has eigenvalues close to, but less than, one).
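    A linear state-space step is tiny; with a diagonal A whose entries sit just below one, an input impulse fades geometrically rather than exploding:

```python
def ssm_step(h, x, A_diag, B, C):
    """Linear state-space update with diagonal A: h' = A h + B x, y = C·h'.
    With |A_ii| < 1 the state decays smoothly instead of blowing up."""
    h = [a * hi + b * x for a, hi, b in zip(A_diag, h, B)]
    y = sum(c * hi for c, hi in zip(C, h))
    return h, y

A = [0.9, 0.99]        # eigenvalues close to, but less than, one
B = [1.0, 1.0]
C = [0.5, 0.5]
h = [0.0, 0.0]
for x in [1.0, 0.0, 0.0, 0.0]:   # an impulse, then silence
    h, y = ssm_step(h, x, A, B, C)
```

    The second channel (0.99) forgets more slowly than the first (0.9), which is the well-behaved-eigenvalue condition doing the work the tanh used to do.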

    Much of this is in the HiPPO and S4 papers, from Albert Gu, Chris Ré and Tri Dao. This was another neat idea, which included a clever bit of linear algebra with a technique called Diagonal+Low Rank to make the state updates relatively efficient, but didn’t perform as well as regular transformer models. Gu and Dao identified the challenge as those well-behaved weights that update the hidden state. Much like with RNNs prior to LSTMs, they were adding a fixed amount of information from the input to the state. In Mamba they reused the same kind of trick: adding a small network to gate the updates so the model can learn to remember more, or less, depending on the specific input5.

    Then, in the Mamba 2 paper from 2024, Gu and Dao brought everything together. They showed that the 2020-style linear attention, with a decay mask, was the same as a structured state-space model like Mamba 1. That meant they could apply the same chunking tricks from linear attention and get much better scaling and training, but with the ability to handle long sequences that Mamba had.

    The slow recreation of LSTM features in more scalable forms continued with Gated DeltaNet. The Mamba approach ‘faded’ old memories via a decay, but it couldn’t explicitly subtract information like the LSTM forget gate. Gated DeltaNet also calculated the difference (the delta) between the expected and actual state, allowing it to effectively edit the memory rather than just overwriting it6.
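    A sketch of the delta-rule update (scalar decay here; making the decay per-dimension is the later Kimi refinement):

```python
def delta_update(S, k, v, beta=1.0, decay=None):
    """Delta-rule memory update: read the current prediction S k, then
    write back only the error (v - S k), optionally after fading the
    memory. S maps keys to values as a (value_dim x key_dim) matrix."""
    if decay is not None:
        S = [[decay * s for s in row] for row in S]   # fade old memories
    pred = [sum(s * ki for s, ki in zip(row, k)) for row in S]
    # rank-1 correction: S += beta * (v - pred) k^T
    return [[s + beta * (vi - pi) * kj for s, kj in zip(row, k)]
            for row, vi, pi in zip(S, v, pred)]

S = [[0.0, 0.0], [0.0, 0.0]]
k = [1.0, 0.0]
S = delta_update(S, k, v=[1.0, 2.0])   # store v under key k
S = delta_update(S, k, v=[5.0, 2.0])   # edit: overwrite the first slot
```

    After the second update, reading S k returns [5.0, 2.0]: the memory was edited in place. A plain additive outer-product update would have returned the sum of both writes instead.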

    Kimi Linear sped this up, and improved the fading mechanism to be per-dimension rather than a single rate across the memory:

    “Crucially, KDA parameterizes its transition dynamics with a specialized variant of the Diagonal-Plus-Low-Rank (DPLR) matrices [30, 71], enabling a bespoke chunkwise-parallel algorithm that substantially reduces computation relative to general DPLR formulations while remaining consistent with the classical delta rule. Kimi Linear interleaves KDA with periodic full attention layers in a uniform 3:1 ratio.”

    They kill two birds with one stone of linear algebra. The DPLR trick from S4 lets you take a diagonal vector for the update rate and apply it alongside the matrix product of a low-rank approximation for the state transition. Moonshot realized that you could replace the approximation with the K and V matrices directly, which is much more efficient, and that the diagonal could come from a vector of the same dimension, giving you per-channel forgetting.

    Compression & Recall

    It seems likely we will see more sophisticated mixing of different types of attention in models as labs continue improving architectures. We started with recursive models as a natural expression of the problem, moved to transformers for scale, and have been slowly integrating the two expressions together. We are still just trying to predict the next word, but it turns out the best way to do it is to remember some things, forget most things, and accept that the map is not the territory.

    Reading through the papers on this journey really highlighted how the field moves between compression and breadth of recall. Sometimes researchers get a bad rap from their engineering brethren for being disconnected from reality, but this chain of evolutions is a pragmatic one.

    You want to get as much intelligence into the model as possible. That’s done by compressing the training data into efficient, useful and general representations, but finding those representations is hard! If you hit a limit in finding them, then one approach is to simply add more knowledge: add more parameters, consider more training data, and build more of the imperfect representations to give you more options to choose from.

    MoEs, synthetic data, and various other aspects of modern model training are playing with this same trade off: represent better or represent more. After his recent HotChips talk, Noam Shazeer was asked how we can find more efficient ways to encode knowledge into parameters, closer to how the brain does it. He responded first by asking the questioner: “why are you limited on parameters?”

    1. The idea dates back to Jeff Elman, I think, who showed that training a network on this objective caused the network to learn grammar categories and other features of English. ↩︎
    2. This kind of thing is even hard for humans at sufficient lengths of text: there is a version of War & Peace in English that is largely the original (translated, natch), but normalizes all the character names as they were such a common point of confusion ↩︎
    3. In the original paper they kept the same encoder/decoder set up as with earlier models, as its eminently sensible for translation tasks. The GPT models and others demonstrated you could go decoder-only effectively. What we tend to call “prefill” these days is effectively a (causal) encoder step within the decoder model that contextualizes the input, then the “decoder” is the autoregressive generation process after. ↩︎
    4. There actually still is non-linearity, as you need it for neural networks in general but rather than doing it in the loop memory update, it happens in projection MLPs after the layer. Then in Mamba it moved into the gating, so it’s only dependent on input, not the h_{t-1} state! ↩︎
    5. And it was Orvieto and the DeepMind folks that showed that you can get the same results in an RNN without the non-linearities if you can set up the matrix right. ↩︎
    6. Part of this reason was recall, which Jamba addressed. Because the RNN approach is inherently compression based it was harder to just cut and paste sections of the context when they were relevant. Jamba mixed regular attention layers with Mamba layers, giving back the global context while still providing better scaling. The specific recall problem is really emphasized by the fact that one of the standard long context evals is the “needle in a haystack” task, where a relevant fact is hidden in a long doc and needs to be pulled out. ↩︎
  • Let’s all switch to FP16?

    Serious scientists use FP64 – 64 bit floating point numbers – for high precision simulations, but in the world of machine learning we got by for the longest time with FP32. The perennial quest for increased FLOPS, particularly when memory bound, made even that seem too expensive though.

    FP16 offered a reduced numeric range, but at half the size. Training with it in practice meant embracing automatic loss scaling1, which ensured the values stayed within the range FP16 could represent. Then, Google developed BF16: it moved some bits from the mantissa to the exponent, so it offered the same numeric range as FP32, but with reduced precision.
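    You can see the precision trade with a toy rounding model that keeps only the mantissa bits and ignores exponent range and subnormals (so it is not a full FP16/BF16 emulation):

```python
import math

def quantize(x: float, mantissa_bits: int) -> float:
    """Round x to a float with the given number of explicit mantissa bits,
    ignoring exponent-range limits. FP16 has 10 such bits, BF16 has 7."""
    m, e = math.frexp(x)                  # x = m * 2**e, with 0.5 <= |m| < 1
    scale = 2 ** (mantissa_bits + 1)      # +1 for the implicit leading bit
    return math.ldexp(round(m * scale) / scale, e)

x = 1.0 + 2 ** -8
print(quantize(x, 10))   # FP16-like: 1.00390625 survives
print(quantize(x, 7))    # BF16-like: rounds back to 1.0
```

    Near 1.0, BF16’s spacing between representable values is 2⁻⁷ while FP16’s is 2⁻¹⁰, which is exactly the “8 times more precision” the paper leans on.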

    Since TPUv3 back in 2018 and Ampere in 2020 it’s been finding its way into hardware and has become the go-to format for training for many models. Life was good, and training in FP16 was mainly discussed as a memory of hard winters past.

    Last week [2510.26788] Defeating the Training-Inference Mismatch via FP16 dropped and threw ML Twitter into a tizzy by making the argument that everyone was doing Reinforcement Learning wrong and the solution… was FP16.

    “In this work, we take a step back from the complex algorithmic fixes and investigate the root cause of the numerical mismatch: floating-point precision. We identify that the modern standard for mixed-precision training, BFloat16 (BF16), is the primary culprit. While BF16 has a wide dynamic range which is excellent for stable pre-training, its low precision makes it highly susceptible to rounding errors that accumulate and eventually cause the training and inference policies to diverge.”

    The process for RL generally looks like:

    • Get a problem in a prompt
    • Do inference on the model to generate complete responses (a rollout)
    • Get a reward score for the response(s)
    • Run a training loop on the model to update weights based on the reward
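    The loop is easy to caricature with a toy policy. This is REINFORCE on a two-armed bandit standing in for the real pipeline; the four steps above map onto the comments:

```python
import math
import random

random.seed(0)

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    return [e / sum(es) for e in es]

theta = [0.0, 0.0]   # "policy weights": logits over two actions
lr = 0.1

for step in range(500):
    # 1. the "prompt" here is trivial: the same bandit every time
    probs = softmax(theta)
    # 2. rollout: sample a complete "response" (an arm pull) from the policy
    action = random.choices([0, 1], weights=probs)[0]
    # 3. reward: arm 0 is the verifiably correct answer
    reward = 1.0 if action == 0 else 0.0
    # 4. update weights from the reward (policy gradient)
    for i in range(2):
        indicator = 1.0 if i == action else 0.0
        theta[i] += lr * reward * (indicator - probs[i])

print(softmax(theta))   # heavily favours arm 0
```

    The on-policy assumption lives in step 2: the probabilities used to sample must match the ones the update in step 4 thinks were used, which is exactly what a training/inference numerical mismatch quietly breaks.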

    If you want to be on-policy (which generally trains better) you need the “model” in steps 2 and 4 to be identical, but the actual code running around the model in the two steps is different: for example, you don’t use a KV cache in training and you don’t store gradients in inference. But you do want to keep the weights and numerics of the model the same, else your on-policy training becomes a little bit off-policy.

    The last year of LLM research has been scaling this up, which requires managing a training and inference flow efficiently. This ongoing pressure to optimize the two paths independently leads to a risk of divergence. The paper finds that absolutely happens, and the divergence collapses the effectiveness of the learning. Unless, that is, you use FP16:

    “This is precisely why switching to FP16 provides a fundamental solution. With its 10 mantissa bits, FP16 offers 8 times more precision (2^10 values vs. 2^7 values) than BF16. This higher fidelity means that the outputs of the training and inference engines are much more likely to be numerically identical. The increased precision creates a buffer that absorbs the minor implementation differences between the two engines, preventing rounding errors from accumulating and causing a policy divergence”

    The paper does an excellent job of breaking down the many reasons why this happens, but it’s pretty clear that FP16 is a patch: if you can’t get your numerics perfectly matched, then having more precision gives you more wiggle room.

    About a month before this the ByteDance folks posted a fantastic deep dive into RL collapse from discrepancies between training and inference: When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

    They identify a range of concerns, including straight up bugs:

    “According to this GitHub issue, we set disable_cascade_attn=True when initializing the vLLM engine and found that it significantly helps reduce the training-inference mismatch in experiments conducted on A100 GPUs.”

    Many of the experiments in the FP16 vs BF16 paper were run on A100s2, so some backlash emerged suggesting that perhaps this whole thing is just a kernel error. But as ByteDance showed, there really is a lot going on that can make things worse.

    Another example is Horace He’s recent work at Thinking Machines around a related problem: Defeating Nondeterminism in LLM Inference – Thinking Machines Lab

    “As mentioned above, one common explanation for why kernels add numbers in different orders is the “concurrency + floating point” hypothesis. The hypothesis states that if the order in which concurrent threads finish is nondeterministic and the accumulation order depends on the order in which concurrent threads finish (such as with an atomic add), our accumulation order will be nondeterministic as well.”

    Horace calls out variance in batching as the primary cause of non-determinism, and hence another quite plausible cause of inference/training mismatch:

    “In other words, the primary reason nearly all LLM inference endpoints are nondeterministic is that the load (and thus batch-size) nondeterministically varies! This nondeterminism is not unique to GPUs — LLM inference endpoints served from CPUs or TPUs will also have this source of nondeterminism.”

    The meta-point is that despite being a field fundamentally based in mathematical precision we have been sloppy with numerics, pretty much everywhere.

    Ed Yang’s session in the PyTorch Conference keynote3 a couple of weeks back called this problem out from the perspective of scaling up ML infrastructure. He presented a number of solutions to try and address it, which often come down to giving folks control over precisely how the numerics work in different parts of their model.

    While the focus here was on RL and FP16, the reality is we deal with this for training->inference in much simpler cases, as well as when moving models between different hardware. Even within generations this can be hard: one of the fun infra problems when the H100 came out was everyone discovering that the FP8 tensor cores in the Hopper used a 22-bit accumulator for intermediate calculations, which wasn’t really documented!

    The balance between speed and accuracy is often struck empirically: if something is faster, and works, then at some level it’s right! Reinforcement Learning mixes together different evolutionary chains of optimizations, so maybe those serious scientists with their FP64 were onto something. Not because they absolutely needed the precision, but because they needed to know they had the precision.

    We’re probably not going to switch industry wide back to FP164, but getting a better numerical grounding into the tools we use is going to make everyone’s lives easier, eventually!

    1. torch.cuda.amp and friends ↩︎
    2. Though they did verify on Hopper some as well, which some people seemed to miss ↩︎
    3. Check out the recording: Keynote: PyTorch Technical Deep Dive – Alban Desmaison, Peng Wu, Mark Saroufim & Edward Yang, Meta ↩︎
    4. Especially since most labs are doing so much with FP8 or less these days, and it would probably annoy a bunch of chip designers ↩︎
  • Rubrics

    Pre-training is about making AI correct, post-training is about making AI helpful1. That helpfulness is (primarily) shaped by reinforcement learning. RL for LLMs really took off with RLHF (RL from Human Feedback), which trained based on the score from a reward model.

    The reward model was designed to score responses based on how well they met certain preferences, and the preferences were inferred from a set of human ratings: the graders were told what to look for in pairs of responses, and the reward model was trained to predict what they would pick. This worked, but was gated on how much signal you could get into the reward model and hence how many humans you had to generate preference data.

    RLAIF (RL from AI Feedback) naturally extended this to using an LLM to make the preference picks rather than humans2. Folks also started to use LLMs in an LLM-as-Judge pattern for evaluation after training: give the model a list of criteria, and ask it to rate how well the responses meet them. 

    The next notable step was RLVR (RL with Verifiable Rewards), which uses ground-truth data to provide reward scores instead of a model. For example, a math problem might have a defined numeric answer, or a generated proof could be verified by a dedicated theorem prover program. This turned out to work very well for code and math and led to the O-series of OpenAI models3 and many open reasoners, particularly DeepSeek R1.
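    A verifiable reward can be as simple as matching an extracted answer against the known one. A hypothetical checker (the regex and extraction rule are my own illustration, not any particular pipeline’s):

```python
import re

def verifiable_reward(response: str, ground_truth: str) -> float:
    """RLVR-style reward: extract the final number from the model's
    response and compare it to the known answer. No reward model needed."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response)
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == ground_truth else 0.0

print(verifiable_reward("The answer is 42.", "42"))   # 1.0
print(verifiable_reward("Probably 41?", "42"))        # 0.0
```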

    It’s a pretty natural idea to take a verifiable-reward pipeline and plug in AI scoring directly: rather than having a model generate preference pairs and training a separate reward model, give the model criteria and ask it how well the response satisfies them. This means instead of letting a model work out what “good code” looks like from pairs of different (but similar!) solutions to a problem, you have a model working through a checklist, asking things like “Does it have types? Does it have comments? Would your coworkers hate you if you landed this?”

    These checklists are referred to as rubrics, and Snorkel have started an interesting-looking blog series introducing rubrics, which offers a definition: 

    A rubric is a structured guide that spells out what “good” looks like for each response from an AI system. 

    A rubric consists of:

    • A list of criteria: Does the code compile? Does it have comments?
    • How the model performed on each criterion: “Compiles” could be yes/no. It could also be more nuanced: yes/yes with warnings/no.
    • Scoring rules that turn performance into numbers: Clean = 0. Warnings = 1. No = 2.

    In Nathan Lambert’s recent interview with Ross Taylor, Taylor calls rubrics out as an underappreciated research opportunity, particularly for agentic training:

    Rubrics are underhyped on social media – they were driving force behind projects like DeepResearch – and GenRMs are interesting but perhaps slightly overhyped.

    This caught my eye, as Moonshot leveraged rubric-based rewards heavily in Kimi K2, notably using the model they were training as the judge of itself: 

    The framework operates using a Self-Critique Rubric Reward mechanism, where the model evaluates its own outputs to generate preference signals. To bootstrap K2 as a competent judge, we curated a mixture of open-source and in-house preference datasets and initialize its critic capability in the SFT stage.

    One of the core values of rubrics is that they work for both LLMs and humans. You can iterate on rubrics with people, scale them with LLMs, and spot-check LLM results with human raters to ensure reliability. 

    The paper [2507.17746] Rubrics as Rewards: Reinforcement Learning Beyond Verifiable Domains formalizes them as a full peer to Verifiable Rewards. The paper sets up rubrics so each criterion is a simple pass/fail with a predefined importance weight. They normalize everything so the system can’t get gamed by just adding more criteria4, and then plug the resulting score into the RL loop5.
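    A minimal sketch of that weighted pass/fail scoring (criteria names and weights are invented for illustration, not taken from the paper):

```python
# Rubric-as-reward sketch: each criterion is a pass/fail judgment with
# an importance weight, and the reward is normalized by total weight so
# adding more criteria can't inflate the score.
def rubric_reward(judgments: dict, weights: dict) -> float:
    total = sum(weights.values())
    earned = sum(weights[c] for c, passed in judgments.items() if passed)
    return earned / total  # normalized to [0, 1]

judgments = {"compiles": True, "has_types": True, "has_comments": False}
weights = {"compiles": 3.0, "has_types": 1.0, "has_comments": 0.5}
print(round(rubric_reward(judgments, weights), 3))  # 0.889
```

    In the RL loop, this scalar goes in where a verifiable reward would otherwise go.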

    Of course, you actually have to write the rubrics, which leads to a specificity versus generality tradeoff: take more time to write more rubrics or rely on fewer, more general ones. The RaR paper makes it clear that more is better:

    predefined generic rubrics substantially underperform compared to prompt-specific ones, underscoring the importance of contextualization. Rubrics that include a broader range of criteria—both positive and negative—consistently outperform those limited to essential checks, suggesting that richer evaluation signals lead to better learning.

    As you might have guessed, the solution was more LLM: use a model to generate prompt-specific rubrics:  

    For each domain, the prompt (included in Appendix H) instructs the LLM to generate 7–20 rubric items based on the complexity of the input question. Each item is assigned a categorical weight (e.g., Essential Criteria, Important Criteria) to determine its importance to a correct answer. The rubrics are designed to be fully self-contained which means that non-expert readers should be able to evaluate response quality using only the rubric. 

    This particularly benefited from having a reference answer attached to the prompt. The models do a much better job of coming up with a good rubric if provided with a (human generated) “good” answer to judge against rather than just the question/prompt. This really opens the door to 1:1 rubrics: given questions and reference answers, you can generate a scoring checklist for each one and mix it with verifiable rewards during post-training. 

    The field continues to be turtles all the way down: using LLMs to write rubrics to have LLM judges evaluate LLM training outputs. At some point, someone’s going to suggest we use rubrics to evaluate how good our rubrics are, and honestly, I’m surprised that paper doesn’t already exist6.

    1. Correct in predicting the next token, and helpful, honest and harmless, specifically. ↩︎
    2. With humans still looped in to validate that the ratings were reasonable. The human graders went from generating ratings to rating the raters. ↩︎
    3. This is the part where everyone pretends they know exactly how O1 works, but actually we’re all just pattern-matching from breadcrumbs ↩︎
    4. Else we’d risk giving more focus to problems with more rubrics, and end up with something unthinkable like a coding model that liberally sprinkles emojis everywhere ↩︎
    5. In practice, they also tried a single LLM judge that took in all criteria and weights and generated a scalar reward, which seemed to work fine. ↩︎
    6. It probably does, I’m just scared to look ↩︎
  • Overthinking Everything

    Yann was taking laps on Threads a few weeks back over a recent paper, one of several recent papers exploring how autoregressive models fare as the amount of information they are dealing with grows. His general complaint is that each token they generate can either push them towards the right answer or further away from it, and that the models are inherently bad at recovering if they end up too far outside the correct trajectory.

    This “more might be worse” idea shows up anywhere folks are leveraging large context windows, and one of those1 is in agentic tasks. This post summarizes some research trying to measure the fall-off in chances of succeeding as task length2 increases.

    Is there a Half-Life for the Success Rates of AI Agents? — Toby Ord

    It provides indirect evidence that what really is going on under the hood is that tasks are made up of many sequential subtasks and the chance of succeeding at the whole requires succeeding at every individual component. Moreover, this suggests that the current AI agents are not very good at recovering from earlier mistakes.

    The framing they use is a constant hazard rate: each subtask is another roll of the dice, and if you roll a failure you don’t have much chance of recovering. So more (or longer) is pretty much always worse.
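    A toy version of the constant-hazard framing (my own sketch; the post works in human time horizons rather than discrete subtask counts):

```python
import math

# Constant hazard: each subtask succeeds independently with probability p
# and there is no recovery, so an n-subtask task succeeds with p**n.
# The continuous-time analogue is exp(-rate * t), with half-life ln(2)/rate.
def task_success(p_subtask: float, n_subtasks: int) -> float:
    return p_subtask ** n_subtasks

def half_life(hazard_rate: float) -> float:
    return math.log(2) / hazard_rate

print(round(task_success(0.99, 10), 3))   # 0.904
print(round(task_success(0.99, 100), 3))  # 0.366: long tasks fail even with good per-step odds
```

    The exponential decay is why success rates fall off so sharply with task length under this model.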

    One interesting aspect is that they also investigate the human failure rate, which increases over time, but much more slowly:

    This could indicate a different scaling behaviour of success rate with time horizon for humans compared to AI agents, which would be well worth investigating and may suggest important underlying mechanisms (e.g. that the humans were better at correcting earlier failed subtasks). If human performance scales differently with task length than AI agent performance, that would be an important result, suggesting that there is a notable inefficiency in the current AI paradigm.

    They’re testing with multiple runs, so these aren’t just models hitting problems they can’t do: it’s models hitting problems they can’t do given the specific tokens they have already generated.

    Agentic use cases aren’t the only situation where a model is generating responses that add to its context window. There were a lot of early observations after the release of O1 last year that thinking for longer on easy problems does not add value. This recent paper proposes not only that but suggests that there is an inverse scaling law: more time thinking makes the model worse.

    [2507.14417] Inverse Scaling in Test-Time Compute

    More specifically, they devised some stress tests: things like counting problems in the presence of distracting information, performing a regression where there is easy-to-understand but spurious data, and so on. Different models are more susceptible to some failure modes than others, but performance consistently drops as the trace length increases:

    Our experiments reveal distinct failure modes across model families—Claude models are particularly vulnerable to distraction from irrelevant information, while OpenAI o-series models show greater resistance but overfit to familiar problem framings. Extended reasoning amplifies different weaknesses: models overthink simple problems, shift attention to spurious correlations, and lose focus during Deduction tasks with constraint tracking.

    In contrast, Chroma’s recent Technical Report investigates how models do on single prompts, but of increasingly long contexts.

    Context Rot: How Increasing Input Tokens Impacts LLM Performance | Chroma Research

    Unlike in the agentic case, here all of the context is passed in at once, so the model isn’t poisoning its own context window through bad choices. It is still dealing with a large amount of content where it needs to choose which parts to attend to. Traditionally the test of long context has been needle-in-a-haystack evaluations: a relevant fact is hidden at different points in a long prompt and the test evaluates whether the model can effectively pull it out.

    The Chroma folks make the test a lot more nuanced — adding distractors3 and irrelevant content in both the broader context and the question. They find that performance consistently degrades as context increases.

    More broadly, our findings point to the importance of context engineering: the careful construction and management of a model’s context window. Where and how information is presented in a model’s context strongly influences task performance

    All of these papers rhyme with LeCun’s gripe about autoregressive transformers, which is (roughly!) that they (also) have a constant hazard rate on generating the “right” token.

    This is a very active area of research though. Process-based rewards in RL training make updates on each step vs only at the end. Multi-token prediction reduces the effective generation length or number of chances of misprediction. Summarizing context effectively compresses existing tokens, also reducing error rate.

    Similarly, if you have good verifiers4 you can use beam or tree searches to explore multiple reasoning paths during generation, which can reduce the error rate, at the cost of more compute.

    The closest (LLMish) techniques to LeCun’s vision are things like the recent Hierarchical Reasoning Model that has a layer of persisting hidden state, but it’s still pretty experimental!

    As agentic and reasoning traces get longer, I’m sure we’ll see more entries documenting failure modes, and proposals for techniques to scale around them.

    1. And the one being referenced in the post! ↩︎
    2. In time — they characterize tasks based on how long it takes humans to do them, which is a good control factor ↩︎
    3. As in additional content related to the question, but that doesn’t give the answer. ↩︎
    4. Similar to process-based rewards this is somewhat pushing the problem to the ability to judge how well you are doing during the generation ↩︎
  • Toward a Theory of Tokenization in LLMs

    [2404.08335] Toward a Theory of Tokenization in LLMs

    Tokenization has always struck me as one of the odder aspects of natural language deep learning. Despite the extensive end-to-end learning processes we typically use, tokenization initially involves creating a dictionary of optimal sub-word segments from your dataset. One of the appealing concepts in the Byte Latent Transformers paper is the potential to learn tokenization dynamically, recognizing that tokenizers solve deeper problems than merely providing a fixed vocabulary.

    This paper addresses tokenization from a theoretical perspective by modeling sequences as kth-order Markov processes, where the likelihood of each token depends on the k preceding ones, as in natural language. The parameter k corresponds to the model’s context window size. Key findings include:

    1. Training without tokenization leads models to effectively behave as unigram predictors, significantly limiting performance.
    2. Using a well-designed tokenizer (e.g., Byte Pair Encoding – BPE) enables models to achieve nearly optimal performance in capturing sequence dependencies.
    3. Increasing the tokenizer’s dictionary size improves the model’s performance, moving it closer to the ideal probability distribution.
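    To make the BPE mention concrete, here is a toy single merge step (my own illustration, not the paper’s construction): find the most frequent adjacent pair and fuse it into one dictionary entry.

```python
from collections import Counter

# One BPE merge step: count adjacent pairs, then merge the most
# frequent pair everywhere it occurs. Repeating this grows the
# token dictionary with frequent patterns in the data.
def bpe_merge_step(tokens):
    pairs = Counter(zip(tokens, tokens[1:]))
    (a, b), _ = pairs.most_common(1)[0]
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == (a, b):
            merged.append(a + b)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

print(bpe_merge_step(list("abababca")))  # ['ab', 'ab', 'ab', 'c', 'a']
```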

    Tokenizers which do a good job at learning patterns in the data and assigning these frequent patterns as tokens in the dictionary are compatible with an i.i.d. model over tokens.

    This insight suggests that despite the complexity of natural language, a good tokenizer converts sequences into something approximating an independent and identically distributed (i.i.d.) format, which brings the modeling task closer to one transformers can solve.

    While the paper does not explicitly explore the Byte Latent approach, I wonder if its entropy-driven dynamic token allocation might similarly achieve this i.i.d. simplification. In BLT, the separately trained entropy model could dynamically transform inputs into a distribution that is more palatable for transformers.

  • Fused Linear Cross-Entropy

    Fused Linear Cross-Entropy is a popular optimization that combines the final linear projection and cross-entropy loss into a single operation. This fusion is very valuable for training large language models efficiently, as it can reduce memory usage significantly, particularly for larger vocabularies.

    If you look at a LLM training loop, you generally see something like:

    logits = model(input_ids)
    loss = cross_entropy(logits, targets)

    And if you look at the end of the model, you’ll see something like the below, where h is the hidden state so far and output is output = nn.Linear(embed_dim, vocab_size, bias=False)

    # shape: [b, seq_len, out_dim]
    output = self.output(h)

    That final logits value can be pretty big: sequence length is long and the vocabulary size is large (128k for Llama 3, 202k for Llama 4), so the logits can take GBs of memory: with a 16k context window and a 128k vocab, even at a batch size of 1 you get over 2bn entries. At BF16, that’s over 4GB. You’ll also need to capture the gradient, which will give you another 4GB in the backwards.
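    The arithmetic for that batch-size-1 case (shapes are illustrative, not exact Llama configs):

```python
# Logits tensor size for one 16k-token sequence over a 128k vocab.
seq_len, vocab = 16_384, 131_072
entries = seq_len * vocab                 # ~2.1bn logits
bf16_gib = entries * 2 / 2**30            # 2 bytes per BF16 entry
print(entries, bf16_gib)                  # 2147483648 4.0, doubled again by the gradient
```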

    That set of logits has a range of values that are a bit all over the place, one for each of the possible targets.

    Cross-entropy is a loss between two probability distributions. Jay Mody has an excellent blog post breaking down softmax and CE loss:

    Roughly speaking, cross entropy measures the similarity of two probability distributions. In the context of neural networks, it’s common to use cross entropy as a loss function for classification problems where:

    • q is our predicted probabilities vector (i.e. the softmax of our raw network outputs, also called logits, denoted ŷ), that is q = softmax(ŷ)
    • p is a one-hot encoded vector of our label, i.e. a probability vector that assigns 100% probability to position y (our label for the correct class): p_i = 1 if i = y, and p_i = 0 otherwise

    This means that cross-entropy simplifies to F.nll_loss(F.log_softmax(x, 1), target)
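    A pure-Python sanity check of that simplification: with a one-hot target, cross-entropy collapses to the negative log-softmax at the target index.

```python
import math

# Naive softmax + one-hot cross-entropy: with a one-hot label p,
# -sum(p_i * log(q_i)) collapses to -log(softmax(x)[target]).
def cross_entropy(xs, target):
    exps = [math.exp(x) for x in xs]
    q = [e / sum(exps) for e in exps]
    return -math.log(q[target])

logits = [2.0, -1.0, 0.5]
print(round(cross_entropy(logits, 0), 4))  # 0.2413
```

    This is fine at toy scale; real implementations need the stability tricks discussed next, since exp overflows for large logits.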

    Softmax makes our previously messy logits into a nice probability distribution where all the values are positive and sum to one. log softmax is usually used in LLMs, for numerical stability and efficiency.

    When we implement softmax, the naive implementation looks something like:

    out = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))

    This isn’t numerically stable, so you want to subtract the max to avoid overflows and underflows in the exp. This is the common log-sum-exp implementation:

    x_max = torch.max(x)
    shifted_x = x - x_max
    exp_shifted = torch.exp(shifted_x)
    out = shifted_x - torch.log(torch.sum(exp_shifted))

    In general the memory and compute cost of this grows with the input size, which gets painful for our hefty logits. We can instead keep a running log-sum-exp so we don’t have to deal with the whole input at once.

    lse = xs[0]
    for x in xs[1:]:
        m = torch.max(torch.stack([lse, x]))
        lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))
    out = lse

    This is the online log-sum-exp approach, and makes our life easier! We can now compute incrementally, but we are still generating the big logits beforehand.
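    A quick pure-Python check that the streaming version matches the direct shifted computation (values are arbitrary):

```python
import math

def lse_direct(xs):
    # Standard shifted log-sum-exp over the whole input at once.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def lse_online(xs):
    # Online variant: fold one element in at a time, rescaling the
    # running value whenever a new max appears.
    lse = xs[0]
    for x in xs[1:]:
        m = max(lse, x)
        lse = m + math.log(math.exp(lse - m) + math.exp(x - m))
    return lse

xs = [3.0, -2.0, 10.0, 0.5]
assert abs(lse_direct(xs) - lse_online(xs)) < 1e-12
print(round(lse_online(xs), 4))
```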

    Fused Linear Cross-Entropy replaces the output projection, softmax and loss calculation with a single kernel that tiles across all of it.

    This is the core of the idea: instead of computing all logits at once (which creates a massive tensor), we can:

    1. Compute logits for small chunks of the vocabulary
    2. Compute the softmax incrementally
    3. Only store the logits we need for the loss calculation

    Quoting https://github.com/mgmalek/efficient_cross_entropy

    This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.

    Roughly, the loop looks like:

    For each token i in the sequence, with hidden state hi and a tile of the output weights W_tile:

    • Compute a partial set of logits si = hi · W_tile
    • Reduce into the running max and exp-sum
    • Keep only the si[targeti] entry needed for the loss.

    This gives you quite a lot of memory wins, which also reduce peak memory bandwidth needs. But this also introduces some potential pain!

    1. You’re fusing the final layer op into the loss, which might be defined in different places in your model code
    2. You’re accumulating, so you have to use fp32 or risk introducing numeric errors
    3. You have to write your own backwards op as well, which will generally do some extra computation, so you are paying some extra FLOPS
    4. Debugging can be harder, so you want a good recipe prior to swapping in the op
    5. May require some futzing for best implementations on different hardware.

    Actually implementing it is pretty straightforward.

    @staticmethod
    def forward(ctx, h, W, target):
        B, D = h.shape
        V, _ = W.shape
        
        chunk_size = min(1024, V)
        
        # Initialize online softmax accumulators
        max_logits = torch.full((B,), -float('inf'), device=h.device, dtype=torch.float32)
        sum_exp = torch.zeros(B, device=h.device, dtype=torch.float32)
        target_logits = torch.zeros(B, device=h.device, dtype=torch.float32)
            
        # Process vocabulary in chunks
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)
                
            # Compute logits for this chunk only
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = h @ W_chunk.T  # [B, chunk_size]
                
            # Update running max
            chunk_max = logits_chunk.max(dim=1).values
            new_max = torch.maximum(max_logits, chunk_max)
                
            # Adjust previous sum_exp by exp(old_max - new_max)
            sum_exp *= torch.exp(max_logits - new_max)
            
            # Add this chunk's contribution to sum_exp
            sum_exp += torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)
            
            # Update max
            max_logits = new_max
                
            # Extract target logits if target is in this chunk
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
            is_target_in_chunk = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            target_logits += (logits_chunk * is_target_in_chunk).sum(dim=1)
        
        # Compute loss: -log(p_target) = -(target_logit - log_sum_exp)
        log_sum_exp = max_logits + torch.log(sum_exp)
        loss = log_sum_exp - target_logits
        
        # Save for backward
        ctx.save_for_backward(h, W, target, max_logits, sum_exp)
        ctx.chunk_size = chunk_size
            
        return loss.mean()

    Here we chunk the vocabulary, calculate the partial transform for the chunk h @ W_chunk.T, do online softmax and accumulate the target logits.

    The backward calculates the gradients:

    @staticmethod
    def backward(ctx, grad_output):
        h, W, target, max_logits, sum_exp = ctx.saved_tensors
        chunk_size = ctx.chunk_size
            
        B, D = h.shape
        V, _ = W.shape
            
        # Scale gradient by batch size (since we use mean reduction)
        grad_scale = grad_output / B
            
        # Initialize gradient accumulators
        grad_h = torch.zeros_like(h)
        grad_W = torch.zeros_like(W)
            
        # Process vocabulary in chunks (same as forward)
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
                
            # Recompute logits for this chunk
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = h @ W_chunk.T  # [B, chunk_size]
                
            # Compute softmax probabilities for this chunk
            # p_i = exp(logit_i - max) / sum_exp
            probs_chunk = torch.exp(logits_chunk - max_logits.unsqueeze(1)) / sum_exp.unsqueeze(1)
                
            # Gradient w.r.t. logits: grad_logits = p - 1_{y=i}
            grad_logits_chunk = probs_chunk * grad_scale
                
            # Subtract 1 from target positions
            is_target = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            grad_logits_chunk -= is_target.float() * grad_scale
                
            # Accumulate gradients
            grad_h += grad_logits_chunk @ W_chunk
                
            grad_W[chunk_start:chunk_end, :] = grad_logits_chunk.T @ h
            
        return grad_h, grad_W, None

    In the backwards we recompute the logits for each chunk, and calculate the gradients.

    This is a very simplified implementation that trades off a bunch of kernel launches, so gives up a lot of performance, but you can see the memory savings:

    Regular:
    Time: 285.18 ms
    Memory (total): 3072.0 MB
    Loss: 11.142737
    Chunked online softmax:
    Time: 470.27 ms
    Memory (total): 356.0 MB
    Loss: 11.142738

    For a more sophisticated implementation, you can look at the repo mentioned before, or at Liger, which has a good quality kernel with further optimizations. These calculate the gradients in the forward pass, then can just scale them in the backwards. This trades a bit more memory for less of a compute hit. In general, there are a few options for choosing the right point on the memory/compute tradeoff.

  • MegaScale-infer: disagg MoE inference

    https://arxiv.org/abs/2504.02263

    LLMs have most of their parameters in the FFN parts of the transformer layers — 50+bn params of the Llama 3 70b model, for example. The compute and memory requirements are a bit different between the FFN and attention parts of the model: attention requires a different KV cache for each request, so attention tends to be memory bound while the dense FFNs tend to be compute bound.

    Because of this it’s pretty common to split up tasks at inference time. The initial prefill stage (processing the initial prompt) populates the KV cache for the following autoregressive decoding. The decode can be batched more aggressively to get better utilization. vLLM really helped popularize this idea!

    ByteDance extend this idea for mixture of expert models. In MoEs the compute intensity of the FFNs is limited by needing to load different experts, and having only a proportion of tokens going through a given expert. They extend the disagg idea to go from M “attention” GPUs to N (fewer!) expert GPUs, with a larger batch size for each of the expert calls. This gets better utilization on the matmuls and lowers overall cost of serving. The natural structure of transformer layers alternating attention and FFN lends itself well to a ping-pong pipelining approach that lets them hide the comms overhead.
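    The batch-consolidation intuition in rough numbers (all invented for illustration, not from the paper):

```python
# M attention replicas each emit a micro-batch of b tokens per step;
# with E experts and top-k routing, each expert GPU sees roughly
# M * b * k / E tokens, so consolidating replicas fattens the expert GEMMs.
M, b, k, E = 8, 256, 2, 16
tokens_per_expert_disagg = M * b * k / E   # consolidated across replicas
tokens_per_expert_local = b * k / E        # what a single colocated replica would feed
print(tokens_per_expert_disagg, tokens_per_expert_local)  # 256.0 32.0
```

    The bigger per-expert batch is what moves the FFN matmuls back toward being compute bound.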

    We present MegaScale-Infer, an efficient and cost-effective system designed for large-scale MoE serving. MegaScale-Infer disaggregates the attention and expert modules, assigning them to separate GPUs—a strategy we term disaggregated expert parallelism. Our approach offers two major benefits. First, it enables independent scaling of each module with customized model parallelism strategies. Specifically, attention modules are replicated using data parallelism, while FFN modules are scaled with expert parallelism. By consolidating requests from multiple attention replicas, the GPU utilization of each expert increases significantly as the batch size per attention replica grows. Second, it enables the deployment of attention and FFN modules on heterogeneous GPUs to fully leverage their different capabilities and achieve lower costs. For example, attention modules can be deployed on GPUs with more cost-effective memory capacity and bandwidth, while FFN modules can utilize GPUs with more affordable compute capability. As shown in Figure 1(c), FFN can easily become compute-intensive in MegaScale-Infer, while attention achieves higher GPU utilization per cost under heterogeneous deployment.

  • Useful Reasoning Behaviors

    [2503.01307] Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs

    Very useful insight in this paper out of Stanford.

    Test-time inference has emerged as a powerful paradigm for enabling language models to “think” longer and more carefully about complex challenges, much like skilled human experts. While reinforcement learning (RL) can drive self-improvement in language models on verifiable tasks, some models exhibit substantial gains while others quickly plateau.

    The authors were running a reasoning post-training process on both Qwen 2.5 3B and Llama 3.2 3B. They noticed that while both learned, Llama was consistently worse than Qwen, which feels odd as both models are strong. In looking at the reasoning approaches exhibited they observed 4 distinct reasoning strategies:

    • verification
    • backtracking
    • subgoal setting
    • backward chaining 

    They noticed that Qwen exhibited these behaviors more from the base model, and those behaviors were enhanced more in the RL process.

    While the larger Llama-3.1-70B showed generally increased activation of these behaviors compared to Llama-3.2-3B, this improvement was notably uneven — backtracking, in particular, remained limited even in the larger model.

    They then generated some custom reasoning traces that intentionally demonstrated all 4 behaviors, using Claude.

    We generate these datasets using Claude-3.5-Sonnet4, leveraging its ability to produce reasoning trajectories with precisely specified behavioral characteristics. While Claude does not always produce the correct answer (see Fig. 9), it consistently demonstrates the requested reasoning patterns, providing clean behavioral primitives for our analysis.

    They found that when using the SFT set before RL they closed most of the gap between Llama and Qwen. They also found that it isn’t even important that the reasoning traces are correct – demonstrating the behavior is more important than the reasoning itself at the SFT stage.

    Priming models with cognitive behaviors, by a small amount of finetuning, enabled significant performance gains even in models that initially lack these capabilities. Remarkably, this holds even when primed with incorrect solutions that exhibit the target behavioral patterns, suggesting that cognitive behaviors matter more than solution accuracy. 

    This fits with the elicitation idea — the SFT is training a style. By having increased activation of the reasoning styles the RL process is more able to explore these capabilities and reinforce extended reasoning generation.

    This also fits with my mental model that a base model’s capabilities are often pretty underexplored: a combination of targeted SFT + RL seems to be a very powerful elicitation tool!

  • Post-Training & Elicitation

    Nathan Lambert of the Allen Institute writes about their (very strong) Olmo 2 32B release, and the just released Gemma 3 model from Google. One of the many interesting points:

    Comparing Gemma 3 27B to OLMo 32B, the pretraining evaluations for both are super similar, but Gemma 3 scores are way better after post-training. The ceiling on post-training expectations has been shifting extremely fast among open models.

    Given that Google have about the best crawling infrastructure in the world, and that Ai2 have published the complete pretraining dataset used for Olmo, I think this is slightly surprising. You can see the benchmarks in the blog and technical report: for example, Gemma 3 27B gets 78.8 on WinoGrande from pretraining (a little below Gemma 2, as it happens) while Olmo 2 32B gets 78.7.

    The vibes have definitely shifted to post-training as the source of model differentiation, opening the question of what exactly is happening there. Nathan also posted about that recently, linking to this post by Mohit Ragavendra of Scale and Georgia Tech:

    The post looks at The Superficial Alignment Hypothesis, which is (largely) that post-training is just about preference tuning for behaviors the base model can already do

    […]

    It initially seems like “Less Is More” in the sense that the LIMA model response was highly preferred by the GPT-4 evaluator for Math prompts (in-line with the work’s original claim). However, these model responses were also largely incorrect – the accuracy of models fine-tuned specifically for Math was substantially better, with the same data budget. If we went by subjective win-rate comparisons, we would have picked a model that was significantly worse.

    In the post (and the two linked papers) Mohit breaks down how post-training actually helps. Starting with SFT, the work shows that mimicking style happens quickly, with relatively few samples.

    with just a hundred finetuning examples, the model’s formatting mistakes were virtually solved – the model was perfect at mimicking the expected style. However, the model took a lot more supervised finetuning data to get better at reasoning – the substance of the task.

    They find though that, largely, more-is-more when it comes to SFT, but that there is a power-law style scaling curve: big gains initially followed by slower, marginal gains. Adding in RL doesn’t change the fundamental curve, but it does shift it, leading more efficiently to the model gaining the reasoning capabilities they were training towards:

    Preference data offers a weaker signal compared to supervised finetuning data. So, running DPO directly on the base model on reasoning tasks, is asking the model to learn a completely different response style from its reference model, with a weaker signal, while penalizing for being different from the reference model. Small amount of SFT on the base model teaches it the reasoning style and PFT can use the reward signal to focus on reasoning within the required response space.

    I did wonder when reading this whether the results would look different with an online process (like PPO), rather than an offline. Luckily, Mohit links to another recent paper on this topic:

    We prove that under idealized assumptions, online and offline PFT techniques should return policies of equivalent quality

    but also

    we observe that despite the lack of information-theoretic separation, online PFT out-performs offline PFT across different sampling distributions, labelers, and model sizes. Furthermore, it appears to be “easier” to learn a global RM than it is to learn a local RM, leading to higher validation likelihood.

    The result here seems to be that the reward model is simply easier to learn, and it helps “translate” the problem across distributions.

    This all feels like a continuum: at some level the superficial alignment hypothesis is directionally correct, but it’s not that “superficial”: the base models have a lot of capabilities that are hard to elicit, and fine-tuning/post-training can juice them effectively while adding some learning of its own (as more data is better!)

    The best way of performing that elicitation turns out to be solving different problems at different levels: SFT for format, then RL for the deeper capability, and having a reward model effectively simplifies the learning process again.

  • Demystifying Long Chain-of-Thought Reasoning in LLMs

    https://arxiv.org/abs/2502.03373

    Very clear paper on how RL and SFT combine to elicit reasoning capabilities, with some practical takeaways:

    • SFT on long chains of thought works much better than on short ones
    • Reward shaping is important for getting stable scaling
    • Reasoning approaches (like backtracking) are probably present in the base model, if it’s big enough

    Based on these observations, we hypothesize that RL primarily guides the model to recombine skills it already internalized during pre-training towards new behaviors to improve performance on complex problem-solving tasks.

  • Byte-Latent Transformers

    Who needs a tokenizer anyway!

    [2412.09871] Byte Latent Transformer: Patches Scale Better Than Tokens

    This paper, from back in December last year, presents an interesting approach to handling raw byte sequences in LLMs without relying on tokenization.

    Vocab sizes for tokenizers have gone up over the last couple of years with attendant gains in usefulness, but this remains a particularly hand-tuned number in the training process. BLT proposes a method that processes raw UTF-8 byte sequences directly, leveraging a dynamic patching mechanism to group bytes into variable-length patches based on entropy.

    Higher-entropy regions receive more attention and shorter patches, while lower-entropy regions can be processed more efficiently.

    There are conceptually three levels of processing:

    • Local Encoder: A small transformer stack encodes raw byte sequences into higher level representations, which are then structured into patches.
    • Latent Global Transformer: A standard large transformer model operating on patch-level representations
    • Local Decoder: The encoded patches are decoded back into byte sequences, using a cross-attention mechanism to reconstruct text.

    In the paper they show they can achieve parity in pretraining with a traditional tokenized Llama at a similar parameter count, while being more robust and offering some inference-time performance gains. The patching approach allows for allocating compute where it’s needed most.
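    The entropy-driven patching is simple to sketch. This is a minimal illustration of the global-threshold variant, assuming the per-byte next-byte entropies have already been computed by a small byte-level LM (the threshold value here is made up):

```python
def entropy_patches(entropies, threshold=2.0):
    """Group a byte stream into variable-length patches.

    entropies[i] is the byte LM's predicted next-byte entropy at
    position i. A new patch starts wherever the model is uncertain,
    so high-entropy regions get short patches while predictable runs
    are merged into long, cheap-to-process ones.
    """
    boundaries = [0]
    for i, h in enumerate(entropies[1:], start=1):
        if h > threshold:
            boundaries.append(i)
    # Return (start, end) index pairs covering the whole stream.
    return list(zip(boundaries, boundaries[1:] + [len(entropies)]))
```

    The number of patches (and so the number of expensive global-transformer steps) now adapts to the content, rather than being fixed by a tokenizer’s vocabulary.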

    Retrofitting existing models

    One of the ideas I found most interesting is starting with a traditionally pretrained model. The paper discusses using the main transformer layers from Llama and training the byte latent approach successfully.

    I gave the approach a go with a simplified local encoder, entropy and patching approach, and took the transformer layers from Qwen 2.5 3B, a strong model that could still be trained locally (no corporate resources were harmed, etc).

    The basic approach was replacing the tokenizer, adding a small transformer and patch pooling based on a local entropy measure to generate patches, then cross-attending in some of the Qwen layers. It’s training a new encoder while leveraging Qwen for the backbone of the global transformer, adding new cross-attention params to make it also the decoder, with the embedding layers at each end chopped off – so a significant domain shift. For inference I leverage the same patch-generation process to try to generate effective tokens.

    You can find my Torchtune recipe on GitHub, running through the Alpaca dataset. Thus far I’m still training, so while the loss is improving I have no idea whether it will turn into something useful. The fact that there is something trainable is fun though, and I have hopes that this kind of technique will lead to some breakthroughs in tokenizer-free models in the future!

  • Streaming DiLoCo

    [2501.18512] Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch

    Every paper in this series has been required reading in (very) large language model training. The basic theme is that model training requires gang-semantics, where a large cluster of accelerators need to do coordinated work together in order to make progress, which gets progressively more expensive to enable and harder to do reliably as the number of devices in the cluster increases.

    The prior papers explored ways of splitting up the training into an inner loop where the model trained fairly traditionally, and an outer optimization loop that aggregated the differences and updated based on them – the outer optimizer works on the deltas between parameter values at the sync point. The outer optimizer still runs on the same cluster as all the inner loops, but it means that only at the “outer” sync point do you need to do synchronization between all the devices. This loosens the coupling between devices and allows introducing failure domains.

    This paper addresses the challenge that when you do synchronize, you still have to send data for all the parameters, which requires a lot of bandwidth and can block forward progress. Streaming DiLoCo divides the model layers into different shards and syncs those at different times (in practice, every 5 inner optimizer steps), lowering the peak bandwidth required. Shards are taken in a strided fashion rather than sequentially, which mildly improves stability and performance.

    To further reduce bandwidth, the communication between devices for the outer loop is done in 4-bit floating point! They still do the accumulations/optimization in 32 bit, but they didn’t see any performance loss from using the lower precision for comms. All of these comms are overlapped with the inner loop training, which helps minimize stalls.

  • GRPO & Verifiable Rewards

    GRPO (Group Relative Policy Optimization) is an RL technique originally proposed in the DeepSeekMath paper. Instead of using a full-blown value network like PPO does, GRPO samples a group of completions for a given prompt and then computes a relative (normalized) reward for each output. The rewards are “verifiable” because they can be checked mechanically against ground truth: does the response follow the expected format (i.e. a <think>…</think> block for reasoning and an <answer>…</answer> block for the solution), and is the answer accurate against a predetermined fact? Not every problem fits this model, but there are a bunch that do, including math reasoning with the GSM8K dataset of grade-school math word problems. These look like this:

    “Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?”
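    A verifiable reward for a prompt like this can be a plain function over the completion text. This is a sketch, not the DeepSeek implementation; the 0.5/0.5 split between format and correctness is illustrative, and real recipes often sum several shaped reward terms:

```python
import re

def verifiable_reward(completion, ground_truth):
    """Score a completion in [0, 1]: partial credit for following the
    <think>/<answer> format, full credit only if the extracted answer
    also matches the known ground truth."""
    m = re.search(r"<think>.*?</think>\s*<answer>(.*?)</answer>",
                  completion, re.DOTALL)
    if m is None:
        return 0.0
    reward = 0.5                       # format followed
    if m.group(1).strip() == str(ground_truth):
        reward += 0.5                  # answer correct
    return reward
```

    No reward model, no human labels: the signal comes entirely from string checks, which is what makes these rewards cheap and un-gameable.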

    How Does the Training Work?

    1. Sampling Completions: For each prompt, the model generates a group of candidate completions. These are produced in inference mode (gradients aren’t collected) using a KV cache for speed (or a dedicated inference engine like vLLM).
    2. Verifiable Reward Calculation: Each completion is scored between 0 and 1—rewarding outputs that follow the prescribed format and yield the correct answer.
    3. Forward Pass for Gradients: Both the “policy” (the model being tuned) and a reference (typically the base, unmodified model) are used for a forward pass with the prompt and completions to compute per-token logits and log-probabilities.
    4. Loss and Backwards: The loss is then calculated as a combination of the (group-averaged) reward and a KL divergence term between the tuned model and the baseline, to constrain learning to similar responses. This loss is backpropagated through the policy model based on the earlier forward pass.
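    The “group relative” part of step 2 is small enough to write out. A minimal sketch of the advantage computation (real implementations then weight per-token log-probs by this advantage and add the KL term from step 4):

```python
import math

def group_advantages(rewards, eps=1e-6):
    """GRPO's group-relative advantage: normalize each completion's
    reward by the mean/std of its own sampling group. The group itself
    plays the role of a baseline, so no learned value network is needed.
    """
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]
```

    Completions that beat their group get a positive advantage and are reinforced; below-average ones are pushed down, even if every reward in the group was small in absolute terms.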

    Getting it going in TorchTune

    Over last weekend I hacked up a quick and dirty version of the training loop in TorchTune, and over a couple of bus rides to Menlo Park cleaned it up into something that could work as a more general recipe (PR). Most of the work goes into the recipe and getting the dataset shaped properly to generate completions. This version—tested on a smaller model (the 1B Llama 3.2 variant, with LoRA)—showed some promising improvements in approach, but I didn’t get to the point of having something converge enough to be confident in the overall recipe. In the DeepSeek R1 paper they discussed trying smaller models, but found 3B was the smallest they were able to get results on with some of their fine-tuning approaches.

    Luckily for everyone, at around the same time Ariel Kwiatkowski also put together a version that included distributed device support, making it easier to experiment on bigger models. This PR is more modular, and I’m excited to see it refined and landed so the recipe is widely available!

    There’s a growing energy around tools like torchtune, and it’s exciting to see how easy it is to “hack on” these ideas. It’s also great to see the techniques show up in other libraries, like HuggingFace’s TRL, which is being used as part of the OpenR1 replication effort!