Tag: pytorch

  • Native DSL Ops in PyTorch

    You may have noticed that FlashAttention 4 was supported in PyTorch really quickly. That required a bit of new infrastructure: torch._native, by Simon Layton. Prior versions of FlashAttention were written in CUTLASS/C++, but for FA4 the team implemented the kernel in CuteDSL.

    Edit: Simon kindly pointed out that the FA4 integration work predated his formalization of this pattern and was the impetus for it: long-time SDPA maintainer Driss landed the change. As always, PyTorch takes a village, and I am glad for everyone’s contribution!

    You wouldn’t think that using an embedded Python DSL in a Python-based ML framework would be a challenge, except that almost all of the code that does the ML in PyTorch is in fact… not written in Python. Replacing a PyTorch operator meant shipping a new native kernel and dealing with the build and dispatch pipeline.

    Layton’s change opened the door to overriding default ops with ones authored in an embedded DSL, initially Triton or CuteDSL.

    To be clear, this is not a replacement for custom ops, which are most of the time the best way of adding a new operator. torch.library.triton_op already lets you register a custom Triton kernel, for example. But FA4 is the kind of situation where we needed an alternative: it’s the right path for newer GPUs, it’s written in CuteDSL, and the PyTorch team wanted it to be available quickly to all PyTorch users without them modifying their models.

    To give an example, we can replace the built-in aten::_fused_rms_norm with a Triton version1:

    """
    Triton kernel for fused RMS normalization.
    RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
    """
    import triton
    import triton.language as tl
    import torch

    @triton.jit
    def _rms_norm_fwd_kernel(
        X_ptr,
        W_ptr,
        Y_ptr,
        RRMS_ptr,  # reciprocal RMS, saved for backward
        stride_x_row,
        N_COLS: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        HAS_WEIGHT: tl.constexpr,
    ):
        # [...]
        tl.store(RRMS_ptr + row_idx, rrms)

    def triton_rms_norm_forward(
        x: torch.Tensor,
        normalized_shape: list[int],
        weight: torch.Tensor | None,
        eps: float | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused RMSNorm forward pass using Triton."""
        # [...]
        return y.reshape(orig_shape), rrms
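    When writing a kernel like this, it helps to have a plain reference to check it against. Below is a minimal NumPy sketch of the same math, RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight, returning the output plus one reciprocal RMS per row, like the Triton wrapper above. The function name and the choice to fall back to the dtype's machine epsilon when eps is None are assumptions for illustration, not taken from the gist.

```python
import numpy as np


def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
    """Hypothetical reference RMSNorm: returns (y, rrms) with one rrms value per row."""
    if eps is None:
        # Assumption: use the input dtype's machine epsilon when no eps is given.
        eps = float(np.finfo(x.dtype).eps)
    n_cols = int(np.prod(normalized_shape))
    rows = x.reshape(-1, n_cols).astype(np.float32)  # accumulate in fp32
    # Reciprocal RMS per row: the quantity the kernel saves for backward.
    rrms = 1.0 / np.sqrt(np.mean(rows * rows, axis=1) + eps)
    y = rows * rrms[:, None]
    if weight is not None:
        y = y * np.asarray(weight, dtype=np.float32).reshape(1, n_cols)
    return y.astype(x.dtype).reshape(x.shape), rrms
```

    On a GPU you would compare the Triton output against the same computation in PyTorch with a tolerance (for example with torch.testing.assert_close) rather than expecting bit-exact equality.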

    Actually hooking it up requires calling a DSL-specific op override function, in this case triton_utils.register_op_override. This goes directly into the dispatch architecture, which means it works with autograd, torch.compile and so on.2

    """
    Register a Triton-based RMSNorm as a native op override using torch._native.
    """
    from torch._native import triton_utils

    def _triton_fused_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
        """
        Wrapper that lazily imports the Triton kernel on first call.
        """
        from triton_kernels import triton_rms_norm_forward
        return triton_rms_norm_forward(x, normalized_shape, weight, eps)

    def register():
        """Register the Triton RMSNorm override."""
        triton_utils.register_op_override(
            "aten",  # lib_symbol: override an aten op
            "_fused_rms_norm",  # op_symbol: the specific op
            "CUDA",  # dispatch_key: only on CUDA
            _triton_fused_rms_norm,  # impl: our wrapper
            unconditional_override=False,  # receives dispatch_keys as first arg
        )

    Now when we call torch.ops.aten._fused_rms_norm(x, shape, weight, eps), PyTorch will automatically use our Triton override!3

    The unconditional_override param in the registration call is a helpful one: if false, the function receives a torch.DispatchKeySet as its first argument. This allows overriding only in specific circumstances. For example, our Triton kernel is faster than the C++ one only for larger shapes, so we could gate the decision on that:

    import math

    def _smart_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
        # Rows = total elements / size of the normalized (trailing) dims
        n_rows = x.numel() // math.prod(normalized_shape)
        if n_rows < 4096:
            # Fall back to the default C++ kernel for small shapes.
            return torch.ops.aten._fused_rms_norm.default(x, normalized_shape, weight, eps)
        from triton_kernels import triton_rms_norm_forward
        return triton_rms_norm_forward(x, normalized_shape, weight, eps)
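    The gating idea itself can be modelled without any PyTorch machinery. The sketch below is a deliberately toy dispatcher with hypothetical names throughout (it is not the torch._native API): default kernels and overrides are keyed by (op, device), and a conditional override is handed the default kernel so it can decline a call. Passing the default in is just this toy's way of letting the override fall back; strings stand in for "which kernel ran".

```python
import math
from typing import Any, Callable

Kernel = Callable[..., Any]

# Hypothetical toy registries keyed by (op_name, device); not PyTorch's dispatcher.
_defaults: dict[tuple[str, str], Kernel] = {}
_overrides: dict[tuple[str, str], tuple[Kernel, bool]] = {}


def register_default(op: str, device: str, fn: Kernel) -> None:
    _defaults[(op, device)] = fn


def register_override(op: str, device: str, fn: Kernel, unconditional: bool = True) -> None:
    _overrides[(op, device)] = (fn, unconditional)


def dispatch(op: str, device: str, *args: Any) -> Any:
    default = _defaults[(op, device)]
    if (op, device) not in _overrides:
        return default(*args)  # no override for this device: use the built-in kernel
    fn, unconditional = _overrides[(op, device)]
    # Conditional overrides receive the default kernel so they can fall back to it.
    return fn(*args) if unconditional else fn(default, *args)


def smart_rms_norm(default: Kernel, shape: tuple[int, ...], normalized_size: int) -> str:
    n_rows = math.prod(shape) // normalized_size
    if n_rows < 4096:
        return default(shape, normalized_size)  # small input: keep the built-in kernel
    return "triton"  # stand-in for launching the Triton kernel


for device in ("CPU", "CUDA"):
    register_default("rms_norm", device, lambda shape, n: "cpp")
register_override("rms_norm", "CUDA", smart_rms_norm, unconditional=False)
```

    With this setup a (8, 1024) input on "CUDA" still goes to "cpp", a (8192, 1024) input goes to "triton", and "CPU" calls never see the override at all, which mirrors the CUDA-only registration above.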

    Going back to FlashAttention 4, it overrides aten::_scaled_dot_product_flash_attention so that any code using torch.nn.functional.scaled_dot_product_attention will transparently get the FA4 implementation.

    torch._native fundamentally lowers the barriers to entry for bringing new kernel implementations into PyTorch. That’s good for mainline PyTorch, and it also allows ML infrastructure teams to ship optimized kernels for new hardware without waiting for the PyTorch release cycle.

    1. Trimmed for length; ask your favorite coding agent to write you an RMSNorm kernel, or look at the gist. ↩︎
    2. To avoid hammering the import latency, all DSL runtimes are lazily loaded when the kernel is first called. ↩︎
    3. In this case we only override for CUDA tensors, so CPU ops will continue to be handled by the default implementation. ↩︎
  • Helion and the evolving GPU programming model

    Helion: A High-Level DSL for Performant and Portable ML Kernels – PyTorch

    Lots of announcements around the Triton and PyTorch Conferences this week, including the 1.0 release of Helion, a high-level kernel authoring DSL:

     It establishes a new layer of abstraction that bridges the user-friendly simplicity of PyTorch with the performance of a lower level language. By automating tedious and error-prone tasks like tensor indexing, memory management, and hardware-specific tuning, Helion empowers developers to focus on algorithmic logic rather than hardware-specific implementation details. Helion achieves this balance by pairing a familiar, PyTorch-centric syntax with a powerful autotuning engine that automates the complex search for optimal kernel configurations. This results in a system that delivers performance portability across hardware architectures while drastically reducing development effort. 

    There has been a bit of an explosion in kernel-authoring options recently: CuTe-DSL and CuTile from Nvidia, TileLang (as featured in recent DeepSeek releases), Gluon and TLX1, as well as evolutions of core Triton, ThunderKittens, Pallas, and others.

    There are a couple of different axes of progress in GPU kernel authoring. The first runs between researcher-friendly declarative code that is easy to iterate on and tightly written, hardware-friendly imperative code.

    It’s a classic developer-experience trade-off: you either let people tell you what they want to do (matmul these things, then apply a softmax) or let them tell you precisely how to do it (run this dot product on these SMs, then aggregate the result).

    In general you want to stay as high-level as possible, particularly if you are experimenting with lots of different variants in a research type setting, but you may have a bound on the performance hit you can accept. A common example is you want to iterate on some attention variant, but don’t want to completely give up on the performance wins of Flash Attention.2

    Triton and others provided an interesting middle ground: they were easy enough to iterate with thanks to being embedded in Python, and performant enough because a compiler applied many optimizations automatically. You are still much more imperative than in a PyTorch program, but you work at a higher level of abstraction than CUDA: rather than writing programs from the point of view of a single thread, you think about a tile of data. The ThunderKittens docs put this well:

    A GPU is not really a 1000×1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16×16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16×16 values.

    The next abstraction that frameworks developed was how to represent data across the memory hierarchy. To take advantage of the tensor cores you have to have data laid out in a specific way in registers. But you are better off loading data in a different order in global or shared memory. CuTe offered a big benefit by giving you types to represent layouts that could be composed, making it easier to keep track of the data movement required. Triton and others leaned on the compiler to infer the right layouts and offered higher-level APIs to copy data between stages.

    This started to get challenging on Hopper, thanks to TMA3 and the limits of memory bandwidth, which gets to the second evolution happening in GPU kernels: how do you orchestrate the movement of data between memory layers while keeping the tensor cores saturated? This involves techniques like warp specialization, where individual warps do different operations towards a shared goal. That means carefully allocating ownership over registers to avoid warps stepping on each other. Blackwell4 made this even trickier with the addition of TMEM, 2-CTA mode and other features that offered more performance but required even more careful orchestration.

    In compiler terms this is a scheduling problem, and in general the industry is quite good at it! CPUs give compilers a lot of leeway: out-of-order execution, well-documented instructions, and substantial caches make an imperfect schedule forgiving. GPUs process groups of threads5 in lockstep and demand strict timing about when to insert barriers, issue async operations, and so on.

    A GPU scheduler has to assign operations to specific warps in advance, allocate registers to them to avoid conflicts, and sync them with barriers. It’s a lot more brittle: if we guess wrong, we can idle the tensor cores and tank efficiency. The actual execution model is a bit of a black box too: the compiler target (PTX) is further compiled to SASS by ptxas (which nvcc invokes under the hood).

    Across the industry we’ve been exploring ways to be more explicit without giving away all of the operational and developer efficiency gains of higher-level languages. CuTe-DSL offers a very close-to-hardware model but in a Pythonic package6; Gluon (OpenAI) and TLX (Meta) add extensions that allow modelling pipelines in code without getting rid of the Triton compiler; and TileLang builds on TVM with explicit pipeline declarations.

    One of the reasons for this variety is we don’t quite know how to express a warp-group pipelined execution model. For example, TileLang has a pipelined construct:

    for k in T.Pipelined(loop_range, num_stages=num_stages):
        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)  # Q @ K^T
        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
        Rescale(acc_o, scores_scale)  # Apply correction
        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)  # P @ V
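    The Rescale step is the subtle part of that loop: each new K tile can raise the running max, so partial results computed against the old max have to be scaled down before the new tile is added. A NumPy sketch of what the loop computes for one query tile (plain NumPy, not TileLang; the function names, the 1/sqrt(d) scaling and the block size are assumptions for illustration):

```python
import numpy as np


def blockwise_attention(q, k, v, block_size=2):
    """softmax(q @ k.T / sqrt(d)) @ v for one query tile, streamed over K/V tiles."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    acc_o = np.zeros((n_q, v.shape[1]))  # un-normalized output accumulator
    row_max = np.full(n_q, -np.inf)  # running max of the scores per query row
    row_sum = np.zeros(n_q)  # running sum of exp(score - row_max)
    for start in range(0, k.shape[0], block_size):
        k_blk, v_blk = k[start:start + block_size], v[start:start + block_size]
        s = (q @ k_blk.T) * scale  # MMA0: Q @ K^T for this tile
        new_max = np.maximum(row_max, s.max(axis=1))
        correction = np.exp(row_max - new_max)  # how much to shrink earlier partials
        p = np.exp(s - new_max[:, None])  # Softmax numerators for this tile
        row_sum = row_sum * correction + p.sum(axis=1)
        acc_o = acc_o * correction[:, None] + p @ v_blk  # Rescale, then MMA1: P @ V
        row_max = new_max
    return acc_o / row_sum[:, None]  # normalize once at the end


def naive_attention(q, k, v):
    """Reference that materializes the full score matrix."""
    s = (q @ k.T) / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v
```

    The two functions should agree for any block size; the blockwise version just never holds more than one tile of scores at a time.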

    Gluon has a construct that explicitly allocates resources like registers to groups of warps:

    gl.warp_specialize(
            (config, chnls, descs, M, STAGE),     # Args to correction stage
            _attn_fwd_correction,                  # Trunk task (1 warp, 192 regs)
            (config, chnls, descs, M, STAGE),     # Args to specialized tasks
            [
                _attn_fwd_softmax0,    # 4 warps, 192 registers - Softmax tile 0
                _attn_fwd_softmax1,    # 4 warps, 192 registers - Softmax tile 1
                _attn_fwd_mma,         # 1 warp, 24 registers  - Matrix multiplies
                _attn_fwd_load,        # 1 warp, 24 registers  - TMA loads
                _attn_fwd_epilogue,    # 1 warp, 24 registers  - Store results
            ],
            [4, 4, 1, 1, 1],          # Warps per stage
            [192, 192, 24, 24, 24]    # Registers per stage
        )

    And TLX tags sections of code with contexts to indicate groupings, and also allocates resources:

    with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS,
                        registers=232,
                        replicate=NUM_MMA_GROUPS):

    They can all work, and finding the best trade-off is a good goal, but in all cases they force a lot of decisions. As an example, the number of registers to allocate is not only operation dependent, it’s hardware dependent, and that makes portability between hardware (even different generations from the same vendor) expensive. Manual controls are necessary: it takes time to develop the compiler passes and heuristics to optimally divide work, so handing explicit control over7 is beneficial, particularly when serving at scale. The cost is complexity and portability. This is where Helion takes a different tack.
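    To see why those register numbers are hardware-specific, it helps to add them up. A rough pure-Python budget check, assuming a register file of 65,536 32-bit registers per SM (as on recent NVIDIA datacenter GPUs), 32-thread warps and a 255-register-per-thread cap, and taking the Gluon comment's "1 warp, 192 regs" for the trunk task at face value. It ignores allocation granularity and multiple CTAs per SM, so treat it as back-of-the-envelope only:

```python
THREADS_PER_WARP = 32
REGISTERS_PER_SM = 65_536  # assumption: 64K 32-bit registers per SM
MAX_REGISTERS_PER_THREAD = 255


def register_budget(warps, regs_per_thread, registers_per_sm=REGISTERS_PER_SM):
    """Return (registers used, fits) for a warp-specialized partition plan."""
    if len(warps) != len(regs_per_thread):
        raise ValueError("need one register count per task")
    if max(regs_per_thread) > MAX_REGISTERS_PER_THREAD:
        raise ValueError("per-thread register count exceeds the hardware cap")
    used = sum(w * THREADS_PER_WARP * r for w, r in zip(warps, regs_per_thread))
    return used, used <= registers_per_sm


# Trunk task first, then the five specialized tasks from the Gluon call.
print(register_budget([1, 4, 4, 1, 1, 1], [192, 192, 192, 24, 24, 24]))  # (57600, True)
# Give both softmax groups 232 registers (the TLX figure) and the plan no longer fits.
print(register_budget([1, 4, 4, 1, 1, 1], [192, 232, 232, 24, 24, 24]))  # (67840, False)
```

    A GPU with a different register file size, or a kernel with different per-stage needs, changes the answer, which is exactly the portability cost described above.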

    Anyway, so what about Helion?

    Helion instead takes a point on the spectrum above Triton, but below the ML frameworks. It focuses on expressing just what you want to happen from the tile perspective.

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc
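    Semantically, that loop nest is an ordinary blocked matrix multiply. A NumPy sketch of what it computes, with the tile sizes as explicit parameters (in Helion they come from the config, e.g. block_sizes; the mapping of block_m/block_n/block_k here is an assumption for illustration, not Helion's code generation):

```python
import numpy as np


def tiled_matmul(x, y, block_m=2, block_n=2, block_k=2):
    """Blocked x @ y mirroring the Helion loop nest: (m, n) output tiles, k reduction tiles."""
    m, k = x.shape
    k2, n = y.shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    out = np.empty((m, n), dtype=np.float32)
    for m0 in range(0, m, block_m):  # for tile_m, tile_n in hl.tile([m, n])
        for n0 in range(0, n, block_n):
            tile_m = slice(m0, min(m0 + block_m, m))
            tile_n = slice(n0, min(n0 + block_n, n))
            acc = np.zeros((tile_m.stop - m0, tile_n.stop - n0), dtype=np.float32)
            for k0 in range(0, k, block_k):  # for tile_k in hl.tile(k)
                tile_k = slice(k0, min(k0 + block_k, k))
                acc += x[tile_m, tile_k] @ y[tile_k, tile_n]  # the addmm step
            out[tile_m, tile_n] = acc
    return out
```

    The result doesn't depend on the tile sizes; only the memory traffic and parallelism would on real hardware, which is what the autotuner gets to search over.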

    Under the hood, this compiles down to Triton. You might think that would make it at best performance-neutral compared to hand-written Triton, but in practice it’s often faster. The reason is search: Helion can autotune across a wide range of parameters, then let you bake them into your kernel once you’ve identified good ones for your specific setup. The example in the blog post shows how many dimensions the search has to cover:

    @helion.kernel(config=helion.Config(
        block_sizes=[64, 64, 64],
        loop_orders=[[0, 1]],
        l2_groupings=[4],
        range_unroll_factors=[0, 1],
        range_warp_specializes=[None, False],
        range_num_stages=[0, 3],
        range_multi_buffers=[None, False],
        range_flattens=[None, None],
        num_warps=8,
        num_stages=6,
        indexing='block_ptr',
        pid_type='flat'
    ))
    def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

    This makes moving to different hardware as simple as redoing the search process, and offers a much more comprehensive exploration than most folks would do when hand-rolling a lower-level kernel. It’s a very interesting idea, and I’m glad to see more people kicking the tires!
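    The basic shape of "search, then bake in" fits in a few lines. This is a brute-force grid search with hypothetical names (autotune, wall_clock); it makes no claim about how Helion's autotuner actually explores the space, which is far larger than a toy grid:

```python
import itertools
import time
from typing import Any, Callable


def autotune(benchmark: Callable[[dict], float], space: dict[str, list]) -> tuple[dict, float]:
    """Brute-force search: try every combination in `space`, keep the fastest."""
    best_config, best_time = {}, float("inf")
    keys = list(space)
    for values in itertools.product(*(space[key] for key in keys)):
        config = dict(zip(keys, values))
        elapsed = benchmark(config)
        if elapsed < best_time:
            best_config, best_time = config, elapsed
    return best_config, best_time


def wall_clock(kernel: Callable[..., Any], *args: Any, repeats: int = 5) -> Callable[[dict], float]:
    """Build a benchmark that times kernel(config, *args) with the host clock."""
    def benchmark(config: dict) -> float:
        times = []
        for _ in range(repeats):
            start = time.perf_counter()
            kernel(config, *args)
            times.append(time.perf_counter() - start)
        return min(times)  # min is less noisy than mean for short runs
    return benchmark
```

    For a real GPU kernel the timing would need device synchronization (e.g. torch.cuda.synchronize()) around each launch. Once you have the best config, "baking it in" amounts to fixing that argument, e.g. functools.partial(kernel, best), which is roughly analogous to passing config= to @helion.kernel above.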

    Low-level optimizations aren’t going away any time soon, but I’m glad to have more exploration in the kernel development space. Finding the right abstractions and right compiler approaches to keep scaling kernel development will help make it accessible to more and more people and ensure that we can evolve our kernels with the hardware.

    1. Also a Meta thing, disclaimer. ↩︎
    2. This is the logic behind FlexAttention, which was one of the lights that guided the way towards Helion. ↩︎
    3. Fully async copies – a separate execution engine to move data ↩︎
    4. Well, datacenter Blackwell. Consumer Blackwell lacks TMEM and 2-CTA, so its programming model is a bit more Hopper-like. I’m not sure yet what the DGX Sparks have! ↩︎
    5. Warps – 32 threads on Nvidia, or wavefronts, 64 threads on AMD's datacenter GPUs. The important distinction is that all these threads are doing the same thing: you can mask some of them out, but they march through the instructions together in a fairly simple way. ↩︎
    6. With a JIT! ↩︎
    7. Without making people write templated C++, sorry Ben ↩︎
  • Layouts

    You could have invented CuTe hierarchical layout (but maybe not the rest of it?) : ezyang’s blog

    Ed posted the best intro to CuTe layouts I have seen, by showing how to extrapolate them from PyTorch striding1.

    Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if its an N-D object.)
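    To make the "tree of ints" concrete, here is a small pure-Python sketch of mapping a coordinate to a memory offset for a hierarchical shape/stride pair. An integer coordinate is split across sub-modes colexicographically (leftmost varies fastest), matching CuTe's convention; the names echo CuTe's size/crd2idx, but this is an illustrative reimplementation that assumes in-range coordinates, not CuTe's code:

```python
import math


def size(shape):
    """Number of elements in a (possibly nested) shape."""
    return shape if isinstance(shape, int) else math.prod(size(s) for s in shape)


def crd2idx(coord, shape, stride):
    """Map a coordinate to an offset for a hierarchical shape/stride (in-range coords only)."""
    if isinstance(shape, int):
        return coord * stride
    if isinstance(coord, int):
        # Linear coordinate: peel off each sub-mode, leftmost varying fastest.
        offset = 0
        for sub_shape, sub_stride in zip(shape, stride):
            offset += crd2idx(coord % size(sub_shape), sub_shape, sub_stride)
            coord //= size(sub_shape)
        return offset
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))


# A flat, PyTorch-style row-major 2x3 layout: strides (3, 1).
print(crd2idx((1, 2), (2, 3), (3, 1)))  # 5
# A hierarchical layout: mode 0 is itself (2, 2) with strides (1, 4).
print([crd2idx(i, ((2, 2), 2), ((1, 4), 2)) for i in range(8)])  # [0, 1, 4, 5, 2, 3, 6, 7]
```

    Reading the second output: the first mode walks a 2x2 block (offsets 0, 1, 4, 5) before the second mode steps by 2, so the 8 logical positions land in a blocked rather than row-major order, even though both modes are still addressed with a single linear index.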

    Relatedly, Simon Veitner put together a very visual explanation of layouts: https://veitner.bearblog.dev/intuition-behind-hierarchical-layouts/ – the graphics are helpful once you have the baseline intuition from Ed’s post!

    1. If you’re not familiar with striding, Ed’s PyTorch Internals talk/post remains the best intro! ↩︎
  • PyTorch Conference 2025

    The schedule is up for the 2025 edition of the PyTorch conference, which is now at the Moscone West in San Francisco.

    https://events.linuxfoundation.org/pytorch-conference/program/schedule/

    There are a lot of great sessions, but I’ll highlight some I personally find particularly interesting:

    Post-Training: Clearly a big theme this year, with some interesting talks from multiple groups:

    General Training

    Kernel development

    Compilers

    Inference

    I’m looking forward to October!

  • Monarch: PyTorch Single Controller

    I’ve been excited for this to make it to OSS: The PyTorch team at Meta recently soft-launched Monarch on Github.

    pytorch-labs/monarch: PyTorch Single Controller

    Back in 2022, Google’s Pathways paper proposed (or rather, revisited) a single-controller approach for managing machine learning runs. Typically, ML jobs use an SPMD (Single Program, Multiple Data) approach, distributing identical code across multiple hosts. Each host runs independently, synchronizing during collective operations. This works, as evidenced by the many large training runs in the world. It also introduces complexity, especially with pipeline parallelism, where conditional logic for different ranks can clutter up your training code. Even without that, subtle issues can arise: for example, slight differences in torch.compile optimization have (in the past!) led to deadlocks by placing collectives differently on separate nodes.

    The single-controller model simplifies this by centralizing program execution on one main node and using generic workers on the hosts that execute assigned tasks. This provides a consistent, global view of the entire computation, making it easier to get to a correct implementation of parallelisms and other distributed work. This doesn’t come for free though: the main node must efficiently manage (potentially!) thousands of GPUs without becoming a bottleneck, and existing code must adapt to this new centralized model.

    Monarch is the PyTorch team’s implementation of this single-controller concept. It provides a familiar PyTorch frontend, additional module wrappers, and a high-performance Rust-based actor system for distributing and managing work.

    The fundamental abstraction in Monarch is the Actor. Each Actor runs on its own accelerator and encapsulates state and behavior. Communication with other Actors is via async message passing to methods decorated with @endpoint. The nice thing about the programming model is that you can interact with all of the actors in your mesh in a consistent way.
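    The shape of that programming model, a controller broadcasting a method call to every actor and getting back one result per rank, fits in a few lines of asyncio. This is a single-process toy with hypothetical ToyActor/ToyMesh classes, not Monarch's API (which uses @endpoint methods and .call() as in the examples below); it only shows the control flow, with no processes, GPUs or networking:

```python
import asyncio
from typing import Any


class ToyActor:
    """Hypothetical actor: owns per-rank state and exposes async methods."""

    def __init__(self, rank: int) -> None:
        self.rank = rank
        self.steps = 0

    async def train_step(self) -> dict:
        await asyncio.sleep(0)  # stand-in for real work or a message round trip
        self.steps += 1
        return {"rank": self.rank, "steps": self.steps}


class ToyMesh:
    """Hypothetical stand-in for an actor mesh: broadcast a call, gather (rank, result)."""

    def __init__(self, actor_cls: type, size: int) -> None:
        self.actors = [actor_cls(rank) for rank in range(size)]

    async def call(self, method: str, *args: Any) -> list[tuple[int, Any]]:
        results = await asyncio.gather(*(getattr(a, method)(*args) for a in self.actors))
        return list(enumerate(results))  # gather preserves actor order


async def main() -> list[tuple[int, Any]]:
    mesh = ToyMesh(ToyActor, size=2)
    await mesh.call("train_step")
    return await mesh.call("train_step")


print(asyncio.run(main()))  # [(0, {'rank': 0, 'steps': 2}), (1, {'rank': 1, 'steps': 2})]
```

    The point is that state lives in the actors while the driver code reads like a single program, which is what makes the notebook-style debugging below pleasant.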

    Monarch is appealing even if you’re not GPU-rich. For instance, at home, I have a machine equipped with two (mismatched) 3090s, and Monarch allows me to run and debug jobs directly in notebooks without relying on external services.

    Installation had minor hurdles because I built from source rather than using the available pip package. Although the README specifies Python 3.10, Python 3.13 worked fine. The dependencies reference dnf (reflecting Meta’s internal Linux distro choice), so adapting commands to other Linux distributions (Ubuntu in my case) was necessary. Additionally, I had to set BINDGEN_EXTRA_CLANG_ARGS="-I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11" to resolve Rust compilation issues.

    Once installed, running Monarch’s distributed data-parallel notebook was straightforward (see: monarch/examples/notebooks/spmd_ddp.ipynb). The notebook shows that minimal code changes to the standard DDP example are required, mainly subclassing Actor (e.g., class DDPActor(Actor)), while keeping the training loop virtually identical. Monarch handles the rest, including distributed execution and debugging across multiple GPUs.

    Setting up the environment means providing the mesh configuration and launching the actors, which can be done from a cell:

    # Spawn a process mesh
    local_proc_mesh = await proc_mesh(
        gpus=WORLD_SIZE,
        env={
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12455",
        },
    )
    # Spawn our actor mesh on top of the process mesh
    ddp_actor = await local_proc_mesh.spawn("ddp_actor", DDPActor)

    I didn’t have to manually start any other services; it all happened under the hood. Triggering the run is just:

    await ddp_actor.demo_basic.call()

    Which output:

    self.rank=0 Running basic DDP example
    self.rank=1 Running basic DDP example
    self.rank=1 Finished running basic DDP example
    self.rank=0 Finished running basic DDP example

    What I find really appealing is how easy it is to execute across ranks. For example, to query for system info:

    print("Gathering system info from all ranks...")
    system_info = await ddp_actor.get_system_info.call()
    
    print("\n SYSTEM INFORMATION ACROSS ALL RANKS:")
    print("=" * 60)
    
    for point, rank_info in system_info:
        print(f"Rank {rank_info['rank']}: PID={rank_info['process_id']}, "
              f"Device={rank_info['device_name']}, "
              f"GPU Memory={rank_info['gpu_memory_allocated']/1e6:.1f}MB")
    
    print(f"\nFound {len(system_info)} ranks in the mesh")
    
    all_rank_info = [value for point, value in system_info]
    print(f"Total GPU memory across all ranks: {sum(info['gpu_memory_allocated'] for info in all_rank_info)/1e6:.1f}MB")

    Outputting:

    Gathering system info from all ranks...
    [Rank 0] System Info: PID=10519, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    [Rank 1] System Info: PID=10520, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    
     SYSTEM INFORMATION ACROSS ALL RANKS:
    ============================================================
    Rank 0: PID=10519, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    Rank 1: PID=10520, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    
    Found 2 ranks in the mesh
    Total GPU memory across all ranks: 0.1MB

    I can also stop training and dump state if I need to, making it easy to check norms and debug:

    print("Running training steps...")
    for step in range(3):
        print(f"\n--- Step {step + 1} ---")
        
        step_results = await ddp_actor.train_step.call()
        
        all_results = [value for point, value in step_results]
        
        losses = [result['loss'] for result in all_results]
        grad_norms = [result['grad_norm'] for result in all_results]
        throughputs = [result['throughput'] for result in all_results]
        
        print(f"Losses across ranks: {[f'{l:.4f}' for l in losses]}")
        print(f"Gradient norms: {[f'{g:.4f}' for g in grad_norms]}")
        print(f"Avg throughput: {sum(throughputs)/len(throughputs):.1f} samples/sec")

    Outputting:

    --- Step 1 ---
    [Rank 1] Step 1: Loss=1.1128, GradNorm=0.3198, Time=0.241s
    [Rank 0] Step 1: Loss=1.0414, GradNorm=0.3198, Time=0.253s
    Losses across ranks: ['1.0414', '1.1128']
    Gradient norms: ['0.3198', '0.3198']
    Avg throughput: 129.8 samples/sec
    
    --- Step 2 ---
    [Rank 0] Step 2: Loss=1.1526, GradNorm=0.3096, Time=0.003s
    [Rank 1] Step 2: Loss=1.0546, GradNorm=0.3096, Time=0.003s
    Losses across ranks: ['1.1526', '1.0546']
    Gradient norms: ['0.3096', '0.3096']
    Avg throughput: 9800.9 samples/sec
    
    --- Step 3 ---
    [Rank 1] Step 3: Loss=0.9116, GradNorm=0.2243, Time=0.002s
    [Rank 0] Step 3: Loss=0.9662, GradNorm=0.2243, Time=0.002s
    Losses across ranks: ['0.9662', '0.9116']
    Gradient norms: ['0.2243', '0.2243']
    Avg throughput: 19977.5 samples/sec

    While the distributed stuff here is cool, it’s not wildly different from using a distributed framework like Ray with a little bit of setup (though I am pretty allergic to setup). What is most interesting is how this changes the programming model of PyTorch, and makes it really easy to build out distributed experiments.

    For example, if I were building a parameter server, the sync would only require an awaited read of the weights from another actor, taking advantage of the RDMA support for an efficient copy1:

        @endpoint
        async def worker_sync_with_ps(self, param_server) -> bool:
            """Synchronize with parameter server and get RDMA handles"""
                
            self._log("Synchronizing with parameter server...")
            
            # Get RDMA buffer handles
            self.weight_buffers = await param_server.ps_get_weight_handles.call_one()
            self.gradient_buffers = await param_server.ps_get_gradient_handles.call_one()
            
            # Get metadata
            metadata = await param_server.ps_get_metadata.call_one()
            self.weight_metadata = metadata['weights']
            self.gradient_metadata = metadata['gradients']
            
            # Perform initial weight sync
            sync_time = await self._sync_weights_from_ps()
            
            self._log(f"Synchronized with parameter server (sync time: {sync_time:.3f}s)")
            return True

    Getting those weight buffers is as simple as creating the right Monarch object:

    def tensor_to_rdma_buffer(tensor: torch.Tensor) -> RDMABuffer:
        # RDMA requires 1D contiguous uint8 tensors
        byte_tensor = tensor.view(torch.uint8).flatten()
        return RDMABuffer(byte_tensor)

    For an early preview of a library, Monarch is surprisingly complete, and definitely worth a look.

    1. Not that this would do anything for my 3090s! ↩︎
  • Fused Linear Cross-Entropy

    Fused Linear Cross-Entropy is a popular optimization that combines the final linear projection and cross-entropy loss into a single operation. This fusion is very valuable for training large language models efficiently, as it can reduce memory usage significantly, particularly for larger vocabularies.

    If you look at a LLM training loop, you generally see something like:

    logits = model(input_ids)
    loss = cross_entropy(logits, targets)

    And if you look at the end of the model, you’ll see something like the below, where h is the hidden state so far and self.output is nn.Linear(embed_dim, vocab_size, bias=False):

    # shape: [b, seq_len, out_dim]
    output = self.output(h)

    That final logits tensor can be pretty big: the sequence length is long and the vocabulary is large (128k for Llama 3, 202k for Llama 4), so the logits can take gigabytes of memory. With a 16k context window and a 128k vocab, even at a batch size of 1 you get about 2bn entries (the 4k embedding dimension only affects the projection weights, not the size of the logits). At BF16, that’s 4GB. You’ll also need to hold the gradient, which gives you another 4GB in the backward pass.
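    As a quick check of that arithmetic, treating 16k and 128k as powers of two (16,384 and 131,072). The helper name is just for illustration; note that the embedding dimension never appears, because the logits are [batch, seq_len, vocab]:

```python
GIB = 2**30


def logits_bytes(batch: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 2) -> int:
    """Bytes for a materialized [batch, seq_len, vocab_size] logits tensor."""
    return batch * seq_len * vocab_size * bytes_per_elem


entries = 1 * (16 * 1024) * (128 * 1024)
print(entries)  # 2147483648, i.e. ~2.1bn
print(logits_bytes(1, 16 * 1024, 128 * 1024) / GIB)  # 4.0 GiB in BF16
print(logits_bytes(1, 16 * 1024, 128 * 1024, bytes_per_elem=4) / GIB)  # 8.0 GiB in FP32
```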

    That set of logits has a range of values that are a bit all over the place, one for each of the possible targets.

    Cross-entropy is a loss between two probability distributions. Jay Mody has an excellent blog post breaking down softmax and CE loss:

    Roughly speaking, cross entropy measures the similarity of two probability distributions. In the context of neural networks, it’s common to use cross entropy as a loss function for classification problems where:

    • q is our predicted probabilities vector (i.e. the softmax of our raw network outputs, also called logits, denoted as $\hat{y}$), that is $q = \mathrm{softmax}(\hat{y})$
    • p is a one-hot encoded vector of our label, that is a probability vector that assigns 100% probability to the position y (our label for the correct class): $p_i = \begin{cases} 1 & i = y \\ 0 & i \neq y \end{cases}$

    This means that cross-entropy simplifies to $-\log q_y$, the negative log-probability of the correct class, which in PyTorch is F.nll_loss(F.log_softmax(x, 1), target)
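    A tiny worked example of that simplification: with a one-hot p, every term of the cross-entropy sum except the label's vanishes. Plain Python, with made-up logits:

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i log q_i, skipping terms where p_i == 0."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


logits = [2.0, 1.0, 0.1]
label = 0
q = softmax(logits)
p = [1.0 if i == label else 0.0 for i in range(len(logits))]  # one-hot target
print(cross_entropy(p, q), -math.log(q[label]))  # the two values are equal
```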

    Softmax makes our previously messy logits into a nice probability distribution where all the values are positive and sum to one. Log-softmax is usually used in LLMs, for numerical stability and efficiency.

    When we implement log-softmax, a naive implementation looks something like:

    out = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))

    This isn’t numerically stable, so you want to subtract the max to avoid overflows and underflows in the exp. This is the common log-sum-exp implementation:

    x_max = torch.max(x)
    shifted_x = x - x_max
    exp_shifted = torch.exp(shifted_x)
    out = shifted_x - torch.log(torch.sum(exp_shifted))

    In general, this needs the whole input in memory at once and makes two passes over it (one for the max, one for the sum), which gets painful for our hefty logits. We can instead keep a running log-sum-exp so we don’t have to deal with the whole input at once.

    lse = xs[0]
    for x in xs[1:]:
        m = torch.max(torch.stack([lse, x]))
        lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))
    out = lse

    This is the online log-sum-exp approach, and it makes our life easier! We can now compute incrementally, but we are still generating the big logits beforehand.

    Fused Linear Cross-Entropy replaces the output projection, softmax and loss calculation with a single kernel that tiles across all of it.
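    The same running update works a chunk at a time, which is what the fused kernel does per vocabulary tile: rescale the old sum to the new max, then add the new chunk. A pure-Python sketch (the name online_logsumexp and the default chunk size are illustrative), assuming a non-empty list of finite values:

```python
import math


def online_logsumexp(values: list[float], chunk_size: int = 4) -> float:
    """log(sum(exp(values))) computed one chunk at a time with a running max."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(x - running_max) over everything seen so far
    for start in range(0, len(values), chunk_size):
        chunk = values[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new max, then fold in this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in chunk
        )
        running_max = new_max
    return running_max + math.log(running_sum)


print(online_logsumexp([1000.0, 1000.0, 999.0]))  # finite: no overflow despite exp(1000)
```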

    This is the core of the idea: instead of computing all logits at once (which creates a massive tensor), we can:

    1. Compute logits for small chunks of the vocabulary
    2. Compute the softmax incrementally
    3. Only store the logits we need for the loss calculation

    Quoting https://github.com/mgmalek/efficient_cross_entropy

    This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.

    Roughly, the loop looks like:

    For each token i in the sequence, with hidden state h_i and output-layer weights W split into vocabulary tiles W_tile:

    • Compute a partial dot product s_i = W_tile · h_i (one logit per vocabulary entry in the tile)
    • Reduce for a running max and exp-sum
    • Keep only the s_i[target_i] needed for the loss.

    This gives you quite a lot of memory wins, which also reduce peak memory bandwidth needs. But this also introduces some potential pain!

    1. You’re fusing the final layer op into the loss, which might be defined in different places in your model code
    2. You’re accumulating, so you have to use fp32 or risk introducing numeric errors
    3. You have to write your own backward op as well, which will generally do some extra computation (such as recomputing logits), so you are paying some extra FLOPS
    4. Debugging can be harder, so you want a good recipe prior to swapping in the op
    5. May require some futzing for best implementations on different hardware.

    Actually implementing it is pretty straightforward.

    @staticmethod
    def forward(ctx, h, W, target):
        # h: [B, D] hidden states, W: [V, D] output weights, target: [B] class indices
        B, D = h.shape
        V, _ = W.shape

        chunk_size = min(1024, V)

        # Initialize online softmax accumulators in fp32
        max_logits = torch.full((B,), -float('inf'), device=h.device, dtype=torch.float32)
        sum_exp = torch.zeros(B, device=h.device, dtype=torch.float32)
        target_logits = torch.zeros(B, device=h.device, dtype=torch.float32)

        # Process vocabulary in chunks
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)

            # Compute logits for this chunk only, upcast to fp32 before reducing
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = (h @ W_chunk.T).float()  # [B, chunk_size]

            # Update running max
            chunk_max = logits_chunk.max(dim=1).values
            new_max = torch.maximum(max_logits, chunk_max)

            # Adjust previous sum_exp by exp(old_max - new_max)
            sum_exp *= torch.exp(max_logits - new_max)

            # Add this chunk's contribution to sum_exp
            sum_exp += torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)

            # Update max
            max_logits = new_max

            # Extract target logits if target is in this chunk
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
            is_target_in_chunk = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            target_logits += (logits_chunk * is_target_in_chunk).sum(dim=1)

        # Compute loss: -log(p_target) = -(target_logit - log_sum_exp)
        log_sum_exp = max_logits + torch.log(sum_exp)
        loss = log_sum_exp - target_logits

        # Save for backward
        ctx.save_for_backward(h, W, target, max_logits, sum_exp)
        ctx.chunk_size = chunk_size

        return loss.mean()

    Here we chunk the vocabulary, calculate the partial transform for the chunk h @ W_chunk.T, do online softmax and accumulate the target logits.
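    To see the online softmax piece in isolation, here is a minimal sketch that streams over vocabulary chunks of an ordinary logits matrix and checks the running max/exp-sum result against torch.logsumexp. The function name chunked_logsumexp and the chunk sizes are illustrative assumptions, not part of the implementation above.

```python
import torch

def chunked_logsumexp(logits: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Row-wise log-sum-exp over the last dim, computed one chunk at a time."""
    B, V = logits.shape
    running_max = torch.full((B,), -float("inf"), dtype=torch.float32, device=logits.device)
    running_sum = torch.zeros(B, dtype=torch.float32, device=logits.device)
    for start in range(0, V, chunk_size):
        chunk = logits[:, start:start + chunk_size].float()
        new_max = torch.maximum(running_max, chunk.max(dim=1).values)
        # Rescale the old sum to the new max before adding this chunk's terms
        running_sum = running_sum * torch.exp(running_max - new_max)
        running_sum = running_sum + torch.exp(chunk - new_max.unsqueeze(1)).sum(dim=1)
        running_max = new_max
    return running_max + torch.log(running_sum)

torch.manual_seed(0)
logits = torch.randn(4, 1000) * 5
print(torch.allclose(chunked_logsumexp(logits, 128), torch.logsumexp(logits, dim=1), atol=1e-4))
```

    On the first chunk the running max is -inf, so the rescale factor is exp(-inf) = 0 and the empty running sum stays at zero rather than turning into NaN.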

    The backward calculates the gradients:

    @staticmethod
    def backward(ctx, grad_output):
        h, W, target, max_logits, sum_exp = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        B, D = h.shape
        V, _ = W.shape

        # Scale gradient by batch size (since we use mean reduction)
        grad_scale = grad_output / B

        # Initialize gradient accumulators
        grad_h = torch.zeros_like(h)
        grad_W = torch.zeros_like(W)

        # Process vocabulary in chunks (same as forward)
        for chunk_start in range(0, V, chunk_size):
            chunk_end = min(chunk_start + chunk_size, V)
            chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)

            # Recompute logits for this chunk (in fp32, as in the forward)
            W_chunk = W[chunk_start:chunk_end, :]
            logits_chunk = (h @ W_chunk.T).float()  # [B, chunk_size]

            # Compute softmax probabilities for this chunk
            # p_i = exp(logit_i - max) / sum_exp
            probs_chunk = torch.exp(logits_chunk - max_logits.unsqueeze(1)) / sum_exp.unsqueeze(1)

            # Gradient w.r.t. logits: grad_logits = (p - 1_{y=i}) * grad_output / B
            grad_logits_chunk = probs_chunk * grad_scale

            # Subtract 1 from target positions
            is_target = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
            grad_logits_chunk -= is_target.float() * grad_scale

            # Cast back so the matmuls below see matching dtypes (e.g. bf16 inputs)
            grad_logits_chunk = grad_logits_chunk.to(h.dtype)

            # grad_h accumulates across chunks; each chunk owns its rows of grad_W
            grad_h += grad_logits_chunk @ W_chunk
            grad_W[chunk_start:chunk_end, :] = grad_logits_chunk.T @ h

        return grad_h, grad_W, None

    In the backward pass we recompute the logits for each chunk and use them to calculate the gradients.
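    The backward relies on the standard result that, for mean-reduced cross-entropy, the gradient with respect to the logits is (softmax(logits) - one_hot(target)) / B. A quick way to convince yourself is to compare that closed form against autograd on a small, unchunked example; the sizes below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, V = 8, 50
logits = torch.randn(B, V, requires_grad=True)
target = torch.randint(0, V, (B,))

# Autograd gradient of the mean-reduced loss
F.cross_entropy(logits, target).backward()

# Closed form that the chunked backward computes one vocabulary slice at a time
closed_form = (torch.softmax(logits.detach(), dim=1) - F.one_hot(target, V).float()) / B

print(torch.allclose(logits.grad, closed_form, atol=1e-6))
```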

    This is a very simplified implementation: it issues lots of small kernel launches and recomputes the logits in the backward, so it gives up a lot of performance, but you can see the memory savings:

    Regular:
    Time: 285.18 ms
    Memory (total): 3072.0 MB
    Loss: 11.142737
    Chunked online softmax:
    Time: 470.27 ms
    Memory (total): 356.0 MB
    Loss: 11.142738

    For a more sophisticated implementation, you can look at the repo mentioned before, or at Liger, which has a good-quality kernel with further optimizations. These compute the gradients during the forward pass, so the backward only has to scale them. This trades a bit more memory for a smaller compute hit. In general there are a few options for choosing the right point
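    As a rough illustration of the "gradients in the forward" idea, here is a pure-PyTorch sketch. It is not Liger's kernel: it chunks over the batch (token) dimension rather than the vocabulary, materializes full logits for only one row chunk at a time, computes the loss and both gradients immediately, and leaves backward with nothing to do but scale them by grad_output. The class name FusedLinearCrossEntropy and the row_chunk argument are hypothetical.

```python
import torch
import torch.nn.functional as F

class FusedLinearCrossEntropy(torch.autograd.Function):
    """Hypothetical sketch: mean cross-entropy of h @ W.T, gradients computed in forward."""

    @staticmethod
    def forward(ctx, h, W, target, row_chunk):
        B = h.shape[0]
        loss = torch.zeros((), device=h.device, dtype=torch.float32)
        grad_h = torch.empty_like(h)
        grad_W = torch.zeros(W.shape, device=W.device, dtype=torch.float32)
        for start in range(0, B, row_chunk):
            end = min(start + row_chunk, B)
            h_c, t_c = h[start:end], target[start:end]
            # Full-vocabulary logits, but only for this chunk of rows
            logits = (h_c @ W.T).float()
            loss += F.cross_entropy(logits, t_c, reduction="sum")
            # d(mean loss)/d(logits) = (softmax - one_hot) / B
            grad_logits = torch.softmax(logits, dim=1)
            grad_logits[torch.arange(end - start, device=h.device), t_c] -= 1.0
            grad_logits /= B
            grad_h[start:end] = (grad_logits @ W.float()).to(h.dtype)
            grad_W += grad_logits.T @ h_c.float()
        # Keep the finished gradients instead of the inputs: more memory, no recompute
        ctx.save_for_backward(grad_h, grad_W.to(W.dtype))
        return loss / B

    @staticmethod
    def backward(ctx, grad_output):
        grad_h, grad_W = ctx.saved_tensors
        # Only the chain-rule scale by grad_output is left to do
        return grad_h * grad_output, grad_W * grad_output, None, None

torch.manual_seed(0)
h = torch.randn(300, 16, requires_grad=True)
W = torch.randn(40, 16, requires_grad=True)
target = torch.randint(0, 40, (300,))
loss = FusedLinearCrossEntropy.apply(h, W, target, 128)
loss.backward()
```

    The extra memory is visible in the saved state: grad_W is as large as W itself, which is the price for skipping the recompute in backward.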

  • Richard Zou on torch.compile

    GPU MODE 4/19 Q&A – Google Docs

    PyTorch compiler engineer Richard Zou did a great Q&A session with the GPU MODE Discord community recently. You can watch the session on YouTube, but Richard also collected questions into a doc with some nice snippets and references.

    Our value proposition: You can sit down for hours/days/weeks tuning a custom kernel. torch.compile provides good baseline performance so you don’t need to do that all the time!

    The goal with the compiler is that you can spend most of your time thinking about the model, get the majority of the speedups, and only have to go down to custom kernel authoring when you’ve established an opportunity or need for further performance.

  • Dynamic Shapes in PyTorch

    Dynamic shapes are one of the more distinctive parts of torch.compile. Rather than specializing a graph to static shapes (which works in many cases!), PyTorch’s approach allows a single graph to work for a variety of sizes, so things like sequence length or batch size can vary. It does this by reasoning about shapes symbolically: instead of using fixed shape values, it uses placeholders and infers rules that constrain those shapes.

    Tracing & Symbolic Shapes

    PyTorch uses tracing in Python (via Dynamo) to capture the graph of operations. By default, it marks shapes as static during tracing. If a shape marked as static changes at runtime, it is marked as dynamic and treated symbolically in a recompilation. You can also proactively mark a dimension as dynamic to encourage symbolic treatment from the start:

    torch._dynamo.mark_dynamic(x, 0)  # Mark dim 0 as dynamic

    Under the hood, PyTorch uses SymPy to represent and manipulate symbolic shapes. Each dynamic shape is replaced with a SymInt and tracked by a central ShapeEnv.

    Every operation in PyTorch has a meta function — a lightweight implementation that computes metadata like shape changes without actually performing the computation. This lets PyTorch propagate symbolic shapes through the graph. For example, concatenating two tensors along dimension zero is represented symbolically as:

    s0 = s_x0 + s_y0
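    You can watch meta functions work without tracing anything by creating tensors on the "meta" device: they carry shapes and dtypes but no storage, so shape rules like the concatenation one above run with no real computation. This uses concrete sizes rather than SymInts, and the sizes are arbitrary.

```python
import torch

# Meta tensors have shapes and dtypes but no data
x = torch.empty(3, 4, device="meta")
y = torch.empty(5, 4, device="meta")

out = torch.cat([x, y], dim=0)    # the cat shape rule gives 3 + 5 rows
print(out.shape, out.device)      # torch.Size([8, 4]) meta

# No memory is allocated, so even a huge matmul only computes the output shape
big = torch.empty(100_000, 100_000, device="meta")
print((big @ big.T).shape)        # torch.Size([100000, 100000])
```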

    To support branching logic, symbolic shapes carry a “hint” of their current concrete value. This allows specific branches of conditionals like if tensor.size(0) > 2 to be taken during tracing based on the hint. PyTorch adds a guard at this point to ensure that the resulting graph is only used if that branch is the correct one.

    Guards

    Guards are runtime checks attached to the compiled graph that verify the assumptions made during tracing still hold before the graph is reused. For example, in the case of tensor.size(0) > 2, if the tensor is the result of concatenating a and b, the guard will check that a.size(0) + b.size(0) > 2. If this fails, the code is retraced and a graph for the new branch is generated. Multiple graphs can be cached and selected at runtime based on guard validation.

    Guards don’t need to assert exact sizes; they can use symbolic constraints like x.size(0) > 2. This allows dimensions to vary within bounded ranges. The backend compiler (usually Inductor) can then compile code that operates over symbolic dimensions, as long as the variability is within the guarded constraints.

    For example, operations like broadcasting typically generalize well to symbolic shapes. In contrast, if an op specializes on a fixed shape (e.g., optimized path for 1D input), it may require conditional tracing and guards.

    What this means in practice is that most of the time compilation will follow this process:

    • Take a batch of data, assume all shapes are static, insert guards, and pass static sizes to the compiler
    • On the next batch see which guards have been violated, mark those dimensions as dynamic, add appropriate guards and pass symbolic dimensions to the compiler
    • Assuming no control flow, continue to reuse this dynamic graph for subsequent batches without recompilation

    Backed vs. Unbacked SymInts

    Most symbolic shapes are backed, meaning they have an associated concrete value at trace time. These are usually derived from inputs and show up in traces as s0, s1, etc.

    Unbacked SymInts lack a concrete value. These arise from data-dependent operations, e.g.:

    n = (x > 1).nonzero().size(0)

    Here, n depends on the data in x, so its size cannot be known at trace time. It will be represented as an unbacked SymInt like u0.

    If a control flow decision depends on an unbacked SymInt, tracing cannot proceed, resulting in a graph break, or a GuardOnDataDependentSymNode error when fullgraph=True.

    However, you can guide the compiler with additional constraints, e.g.:

    torch._check(x.size(0) == y, lambda: f"size mismatch: {x.size(0)} != {y}")

    This lets PyTorch treat x.size(0) as equivalent to y throughout the graph. The check will be validated at runtime.

    There are other APIs to help mark unbacked SymInts as size-like to enable meta function compatibility (see Ed’s docs for more).

    Controlling Dynamic Shape Usage

    You can control dynamic behavior in torch.compile with the dynamic flag:

    • Not passed: default shape inference behavior
    • dynamic=False: force all shapes to be static
    • dynamic=True: treat all shapes as dynamic

    The default is usually best, but dynamic=True can help in testing.

    Use fullgraph=True to require a single, complete graph: instead of silently splitting the program, torch.compile raises an error at any graph break. This is often critical for performance, as graph breaks can drastically affect runtime, and it’s easy to make innocuous-looking code changes that trigger additional breaks!
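    A small demonstration of the difference, using torch._dynamo.graph_break() as a stand-in for code Dynamo can't trace and the built-in "eager" backend so no code generation is involved. The function with_break is hypothetical.

```python
import torch
import torch._dynamo

def with_break(x):
    y = x * 2
    torch._dynamo.graph_break()  # explicit break, standing in for untraceable code
    return y + 1

x = torch.randn(4)

# Default: Dynamo splits the function around the break and still runs it
torch._dynamo.reset()
result = torch.compile(with_break, backend="eager")(x)
print(torch.allclose(result, x * 2 + 1))

# fullgraph=True: the same break becomes an error at compile time
torch._dynamo.reset()
fullgraph_error = None
try:
    torch.compile(with_break, backend="eager", fullgraph=True)(x)
except Exception as err:  # typically torch._dynamo.exc.Unsupported
    fullgraph_error = err
print("fullgraph error:", type(fullgraph_error).__name__)
```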

  • PyTorch and Python Free Threading

    https://trent.me/articles/pytorch-and-python-free-threading/

    Trent Nelson has written an extremely detailed breakdown of his experiments running GPT-2 inference with PyTorch on the free-threaded (GIL-free) builds of Python 3.13 and 3.14.

    He implements parallel generation using multiple threads (on one GPU and later multiple devices), parallel model loading, and then some of the challenges with torch.compile (which doesn’t work great with nogil yet!)
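    The basic shape of thread-level parallel generation is a thread pool where each request owns its own RNG and tensors. This is not Trent's code: generate is a hypothetical stand-in for one generation request, and sys._is_gil_enabled() only exists on Python 3.13+, so the sketch falls back to assuming a GIL on older versions.

```python
import sys
from concurrent.futures import ThreadPoolExecutor

import torch

# sys._is_gil_enabled() exists on Python 3.13+; earlier versions always have a GIL
gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)()
print("GIL enabled:", gil_enabled)

def generate(seed: int) -> float:
    # Hypothetical stand-in for one request: its own RNG, its own tensors
    g = torch.Generator().manual_seed(seed)
    a = torch.randn(128, 128, generator=g)
    for _ in range(10):
        a = torch.tanh(a @ a.T / 128)
    return a.sum().item()

# One request per worker thread; on a free-threaded build the Python code
# between kernel launches no longer serializes on a global lock
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(generate, range(8)))
print(len(results), "requests done")
```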

    Hopefully this encourages more folks to experiment with free-threaded Python, or perhaps port their existing Python packages to play nicely when installed in a free-threaded Python environment. I personally can’t wait until free-threaded Python is the default! Although that’s probably at least five or so years out at this point.

    Free-threaded Python really changes the performance trade-offs around Python, and I expect it to become the default for ML work a lot sooner than that!

  • A HBS paper on the PyTorch Foundation

    Igniting Innovation: Evidence from PyTorch on Technology Control in Open Collaboration by Daniel Yue, Frank Nagle :: SSRN

    Unexpectedly interesting HBS paper!

    This study looks at the impact of technology control on external contributions in open collaboration contexts by examining the case of PyTorch, a popular machine learning framework, which shifted its governance from a for-profit corporation (Meta) to a non-profit foundation in 2022. The results show that this shift led to a significant decrease in contributions from Meta but a notable increase from external companies.

    The PyTorch project was moved to a foundation in 2022, and that has been a pretty big success (from most any angle you care to look). This paper uses PyTorch as a natural experiment: an already open-source project that changed its governance structure, and what happened as a result.

    The net result they find is a similar level of overall contributions, but increased contributions from hardware companies. They conclude that there was previously a concern on project direction that could hold up certain types of contributions:

    openness does not magically create incentives for external participation without costs, but rather shifts incentives between focal [meaning the originators of the project] and external firms. In particular, control rights theory emphasizes that consideration of the optimal allocation of control rights (with respect to overall welfare) depends on the marginal returns to the ex-ante effort of each party

    There is a lot of formal structure added to intuitively reasonable ideas, e.g. that “Users” have a higher incentive to collaborate or contribute as they capture value by using the API, which is unlikely to change regardless of directional shifts by the project owner. “Complementors”, on the other hand, benefit when their product is used in conjunction with the framework, and therefore need more ongoing cooperation, so they are more sensitive to who controls the project. In PyTorch’s case hardware manufacturers are complementors, and hence their contributions were expected to increase with the change in governance (and did).

    What’s interesting, as they note in the paper, is that on a technical level the governance of PyTorch hasn’t changed that much. It does seem a fair conclusion, though, that adding the overall project governance group makes it more likely that such changes could be made, if needed:

    Nevertheless, by changing the governance to a model run by a voting board of other organizations and bringing in the LF, Meta’s singular control of the technical direction of the project (and potentially its social status as the creator of the tool) was greatly diluted.

    My main concern with the conclusions is the confounder of massive extra interest in generative AI post-ChatGPT. They do address that, attempting to control for it by looking at TensorFlow:

    usage of AI technologies dramatically increased in December 2022 due to public release of OpenAI’s ChatGPT. And while Chip Manufacturer and Application Developer companies are both affected by this demand shock, it is possible that they are differentially affected by this change in a way that confounds our analysis. To rule out this possibility, we augment our sample by further gathering the external company commits data to TensorFlow, Google’s open source machine learning framework.

    TensorFlow doesn’t feel like a great baseline, given its different adoption in the research community. I would be mildly interested to see if XLA had any changes in contributions: the Google diaspora have certainly spread the technology via JAX!

    Regardless, this is an interesting paper and a good contribution in the larger economics of open source. More foundation led projects are better for everyone, so I’m glad to see research in these areas!

  • Quantization in PyTorch

    Jerry Zhang recently posted a couple of updates on the evolution of the quantization APIs in PyTorch, and the unification around TorchAO.

    If you haven’t spent much time around quantization, or are used mainly to quantizing LLM models via tools, the range of options can be pretty gnarly.

    Quantization compresses the model by taking a number format with a wide range and replacing it with something shorter. To recover the original value you track a scale factor and a zero point (sometimes referred to as affine quantization).

    For example, if you have a float32 layer, but all of the parameters within it are between 1 and 10, then using an int8 (256 values) to represent the whole float32 range will collapse all those values to a single value, losing a huge amount of information. Instead, you set the scale factor to cover the range of values actually present. This lowers the quantization error: the difference between the dequantized value and the original.
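    The 1-to-10 example can be checked directly. This is a hand-rolled sketch of affine int8 quantization, not PyTorch's observer code; the helper names are hypothetical. As PyTorch's min/max observers do, it widens the observed range to include 0 so that 0.0 stays exactly representable.

```python
import torch

def affine_qparams(x_min: float, x_max: float, qmin: int = -128, qmax: int = 127):
    # Widen the range to include 0 (as PyTorch's observers do) so 0.0 stays exact
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def quantize(x: torch.Tensor, scale: float, zero_point: int, qmin: int = -128, qmax: int = 127):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    return (q.to(torch.float32) - zero_point) * scale

x = torch.linspace(1.0, 10.0, steps=1000)

# Scale fitted to the values actually present
scale, zp = affine_qparams(x.min().item(), x.max().item())
max_err = (dequantize(quantize(x, scale, zp), scale, zp) - x).abs().max().item()
print(f"fitted: scale={scale:.4f} zero_point={zp} max error={max_err:.4f}")

# Scale stretched over the whole float32 range: every value lands on one integer
fmax = torch.finfo(torch.float32).max
wide_scale, wide_zp = affine_qparams(-fmax, fmax)
print(quantize(x, wide_scale, wide_zp).unique())
```

    With the fitted range the worst-case error is about half a quantization step (scale / 2); with the full float32 range all 1000 inputs become the same integer.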

    That’s easy to do for the weights of the model, but you need to quantize the activations as well (so you are multiplying matrices of the same type). “Static” quantization determines the scale factor and zero point for activations up front, while “dynamic” quantization calculates them at inference time, resulting in better accuracy. The downside is that the approach generally only works on CPUs, as it’s inherently data-dependent.

    You don’t have to quantize everything the same way, and it’s common to see quantization schemes written in the format A16W8. That indicates the weights are quantized to 8-bit (usually int8), but the activations are kept at 16-bit (usually float16 or bfloat16). In those cases, the weights are upcast to a matching dtype at compute time, but you still benefit from faster loading and lower persistent memory usage.
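    A minimal A16W8-style sketch, assuming symmetric per-output-channel int8 weights and bfloat16 activations (both choices are illustrative). Real kernels fuse the dequantize into the matmul instead of materializing the upcast weight as this does; the helper names are hypothetical.

```python
import torch

def quantize_weight_per_channel(w: torch.Tensor):
    # Symmetric int8 quantization with one scale per output channel (row of w)
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

def weight_only_linear(x: torch.Tensor, q_weight: torch.Tensor, scale: torch.Tensor):
    # Upcast the stored int8 weights to the activation dtype at compute time
    w = q_weight.to(x.dtype) * scale.to(x.dtype)
    return x @ w.T

torch.manual_seed(0)
w = torch.randn(64, 128)
x = torch.randn(4, 128, dtype=torch.bfloat16)

q_w, s = quantize_weight_per_channel(w)
# int8 storage is a quarter of the fp32 bytes
print(q_w.numel() * q_w.element_size(), "bytes vs", w.numel() * w.element_size())

out = weight_only_linear(x, q_w, s)
ref = x.float() @ w.T
print("max abs error:", (out.float() - ref).abs().max().item())
```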

    The general flow for quantization is:

    • Identify which parts of the model you want to quantize: often you’ll want to quantize some parts, and not others (for example, all the linear layers, but not a softmax)
    • Prepare the parts being quantized by adding quantize and dequantize operations around the normal ops.
    • For static quantization (where scale/zero point are set ahead of time for activations) calibrate the model by sending input data through and observing the ranges.
    • Convert the model, by replacing the layers with their quantized, lower bit representations, and ensure the appropriate operators are in place to dequantize when the

    Quantizing after training is Post-Training Quantization (PTQ). Quantization-Aware Training (QAT) introduces quantization during training, allowing the model to learn and partially recover accuracy loss.

    Quantizing in PyTorch

    There are (at least!) 4 different approaches to quantization in PyTorch:

    • Eager Mode: Deprecated, simple to use with quantize_dynamic for quick dynamic quantization.
    • FX Mode: Also deprecated; separates quantization from model code via FX graph tracing but requires FX-traceability.
    • PT2E Mode: Current method using PyTorch 2 export. It captures the model graph, supports backend-specific configs (like XNNPack), and is preferable for exported models.
    • TorchAO Quantization: Latest method optimized for torch.compile. Includes easy-to-use features like autoquant for automatic quantization tuning as well as manual options.

    Eager mode quantization

    The original, and soon to be deprecated, quantization method in PyTorch was eager mode quantization. The simplest version to use is calling torch.ao.quantization.quantize_dynamic. This takes a config (identifying which parts of the model you want to quantize) and a target dtype. That call scales and downcasts the weights, and injects ops to dynamically quantize activations. The prepare and convert steps are handled automatically, and there is no calibration as the activations will be scaled dynamically.
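    A minimal quantize_dynamic sketch, assuming a CPU build with a quantized engine available (fbgemm/x86 or qnnpack). Here the "config" is a set of module types to quantize; the toy model is arbitrary, and recent releases may emit deprecation warnings for this API.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Quantize the weights of every nn.Linear to int8; activations are quantized
# on the fly at inference time, so there is no calibration step
qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

print(qmodel)
x = torch.randn(2, 64)
print("max abs diff vs float:", (qmodel(x) - model(x)).abs().max().item())
```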

    Eager mode can also do static quantization with torch.ao.quantization.quantize. This requires a lot more modification: you manually add torch.ao.quantization.QuantStub() and DeQuantStub() around the nn.Module calls you want to run quantized. Then you call torch.ao.quantization.prepare on the model and run the forward with some example data to calibrate. Prepare adds observers that collect statistics on the activations to determine good zero point and scale values. Finally torch.ao.quantization.convert processes and returns the quantized model.
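    The stub/prepare/calibrate/convert sequence looks roughly like this. A sketch assuming a CPU build whose default quantized engine (torch.backends.quantized.engine) is one get_default_qconfig accepts; SmallNet and the random calibration data are placeholders for a real model and representative inputs.

```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qconfig, prepare

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the model input
        self.fc = nn.Linear(16, 8)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        return self.dequant(self.relu(self.fc(self.quant(x))))

model = SmallNet().eval()
model.qconfig = get_default_qconfig(torch.backends.quantized.engine)

prepared = prepare(model)             # inserts observers
with torch.no_grad():
    for _ in range(8):                # calibration passes with example data
        prepared(torch.randn(32, 16))
quantized = convert(prepared)         # swaps in quantized modules

print(quantized.fc)
print(quantized(torch.randn(4, 16)).shape)
```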

    Quantization aware training works the same as before in terms of adding stubs, but rather than calibrating with input data you call torch.ao.quantization.prepare_qat and then run a regular training loop on the model.

    Modifying the model code itself is pretty painful as well, particularly if you are in an active, multi-collaborator code base. This led to the second evolution of PyTorch quantization:

    FX Mode Quantization

    The idea behind FX mode quantization was to give the same range of options as before, without having to change model code. This is particularly helpful when you have (say) a research team developing a model, and a production team that is trying to make it fast for inference!

    The FX graph is a graph of operations created by tracing the model, and by working on the FX ops directly the library can make the quantization modifications while leaving the original code untouched. This works by pattern-matching in the FX graph and applying transforms that add quant and dequant stubs automatically. The downside is that it needs your model to be FX-traceable (much as TorchScript needed it to be scriptable), which often requires model changes.
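    To see what the FX graph looks like, and what "FX-traceable" rules out, here is a small sketch using torch.fx.symbolic_trace directly. The toy models are arbitrary; data-dependent control flow like the branch in Branchy is a typical reason tracing fails.

```python
import torch
from torch import fx, nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
traced = fx.symbolic_trace(model)

# Each node is one operation; FX quantization pattern-matches on nodes like these
for node in traced.graph.nodes:
    print(node.op, node.target)

class Branchy(nn.Module):
    def forward(self, x):
        if x.sum() > 0:   # data-dependent control flow
            return x
        return -x

trace_failed = False
try:
    fx.symbolic_trace(Branchy())
except Exception as err:  # torch.fx raises a TraceError here
    trace_failed = True
    print("not FX-traceable:", err)
```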

    torch.ao.quantization.quantize_fx contains the methods, and they follow the same pattern: a quantization config, then prepare_fx, then convert_fx with options for running input through for calibration or running a training loop for QAT.

    FX tracing and TorchScript have been on the outs for a while due to their mix of complexity and inflexibility, and under the hood both this and the Eager variant use a quantized Tensor type which has been slated for deprecation. So, in their place we have…

    PT2E Quantization

    This one is not deprecated! But it is still changing a bit. The basic idea is very similar to FX quantization, except instead of using FX to capture the graph, we use PyTorch 2’s export feature. torch.ao.quantization.quantize_pt2e offers prepare_pt2e and convert_pt2e, again with an optional calibration/QAT step.

    One big difference is that the quantization config is now backend specific. Prior to calling prepare, you would set up the backend: e.g. for the XNNPack library:

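    # torch.ao location used in PyTorch 2.x; newer releases move this quantizer out of core
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config()
    )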

    This allows different backends to define quantization set up based on the ops available to them.

    This does require capturing the full model graph with PT2 export, which can be painful, but if you’re going through the pain it’s best to do so with this flow rather than FX!

    If you don’t want to work with full graph capture, there is one other option which integrates with torch.compile:

    TorchAO Quantization

    Also not deprecated! The TorchAO library has a toolkit for quantization, and one particularly nice feature is autoquant:

    model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

    Underneath, this tries different quantization schemes on each layer and picks the one that performs best.

    You can also quantize manually:

    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    quantize_(m, Int4WeightOnlyConfig(group_size=group_size))

    Here, quantize_ wraps the prepare/convert steps for you, so it’s pretty user-friendly!

    Not sure what to use? Jerry’s table can help you distinguish, but if you’re quantizing for use in PyTorch (and using torch.compile) then prefer the TorchAO quantization. If you are quantizing as part of making the model available in non-Python environments then you’ll want to be able to export and use the PT2E flow.

  • Horace He’s thoughts on PyTorch

    https://www.thonking.ai/p/why-pytorch-is-an-amazing-place-to

    I read the internal version of this (which also doubled as Horace’s leaving note!) and am really glad to see he has published it broadly. It’s a great look at working in open source and how it tied into a (very successful) career as an engineer, as well as covering his thoughts on Thinky and the space as a whole!

    In high school, one of the things I feared most was that I would work on some project for 10 years and eventually realize that I’ve wasted my life improving something that nobody cared about. One of the greatest things about working on PyTorch is the certainty that I haven’t.

  • Ping-Pong Kernels on Hopper

    Deep Dive on CUTLASS Ping-Pong GEMM Kernel | PyTorch

    A useful deep dive on this performance technique. The TL;DR is: using warp specialization, set up a producer group that loads data (using TMA), and two consumer groups executing MatMuls on the Tensor Cores. When a consumer group finishes, it executes the epilogue (e.g. copying the results elsewhere, though you could imagine doing other work on the CUDA cores) while the other consumer group takes over the Tensor Cores. Hence, I presume, the name: Tensor Core usage ping-pongs between the two consumers!

    The producer warp group focuses on producing data movement to fill the shared memory buffers (via TMA). Two other warp groups are dedicated consumers that process the math (MMA) portion with tensor cores, and then do any follow up work and write their results back to global memory (epilogue).

  • GRPO & Verifiable Rewards

    GRPO (Group Relative Policy Optimization) is an RL technique originally proposed in the DeepSeekMath paper. Instead of using a full-blown value network like PPO does, GRPO samples a group of completions for a given prompt and then computes a relative (normalized) reward for each output. The rewards are “verifiable” because they come from checking each output against known criteria: for example, does the response follow the expected format (i.e. a <think>…</think> block for reasoning and an <answer>…</answer> block for the solution), and does the answer match a known ground truth? Not every problem fits this model, but there are a bunch that do, including math reasoning with the GSM8K dataset of grade-school math word problems. These look like this:

    “Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?”
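    For this problem the ground truth is 42 (12 + 24 = 36 pages read, 84 remain, half of that is 42). A minimal sketch of a verifiable reward plus GRPO's group-relative normalization follows; the 0.5/0.5 split between format and correctness, the regex, and the use of a population standard deviation are illustrative assumptions, and real recipes differ in these details.

```python
import re
import statistics

# Expected format: reasoning in <think>...</think>, then the final <answer>...</answer>
PATTERN = re.compile(r"^<think>.*?</think>\s*<answer>(.*?)</answer>$", re.DOTALL)

def verifiable_reward(completion: str, ground_truth: str) -> float:
    match = PATTERN.match(completion.strip())
    if match is None:
        return 0.0                       # wrong format: no reward at all
    correct = match.group(1).strip() == ground_truth
    return 0.5 + (0.5 if correct else 0.0)

def group_relative_advantages(rewards, eps=1e-4):
    # Normalize each reward against the group sampled for the same prompt
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

group = [
    "<think>12 + 24 = 36 read, 84 left, half is 42</think><answer>42</answer>",
    "<think>she read 36 pages, so 84 remain</think><answer>84</answer>",
    "The answer is 42",
    "<think>120 / 2 = 60</think><answer>60</answer>",
]
rewards = [verifiable_reward(c, "42") for c in group]
print(rewards)  # [1.0, 0.5, 0.0, 0.5]
print([round(a, 2) for a in group_relative_advantages(rewards)])
```

    Note the third completion has the right number but no reward: the check is on format as well as content.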

    How Does the Training Work?

    1. Sampling Completions: For each prompt, the model generates a group of candidate completions. These are produced in inference mode (gradients aren’t collected) using a KV cache for speed (or a dedicated inference engine like vLLM).
    2. Verifiable Reward Calculation: Each completion is scored between 0 and 1, rewarding outputs that follow the prescribed format and yield the correct answer.
    3. Forward Pass for Gradients: Both the “policy” (the model being tuned) and a reference (typically the base, unmodified model) are used for a forward pass with the prompt and completions to compute per-token logits and log-probabilities.
    4. Loss and Backwards: The loss is then calculated as a combination of the (group-averaged) reward and a KL divergence term between the tuned model and the baseline, to constrain learning to similar responses. This loss is backpropagated through the policy model based on the earlier forward pass.

    Getting it going in TorchTune

    Last weekend I hacked up a quick and dirty version of the training loop in TorchTune, and over a couple of bus rides to Menlo Park cleaned it up into something that could work as a more general recipe (PR). Most of the work goes into the recipe and getting the dataset shaped properly to generate completions. This version, tested on a smaller model (the 1B Llama 3.2 variant, with LoRA), showed some promising improvements, but I didn’t get to the point of having something converge enough to be confident in the overall recipe. In the DeepSeek R1 paper they had discussed trying a smaller model, but found 3B was the lowest they were able to get results on with some of their fine-tuning approaches.

    Luckily for everyone, at around the same time Ariel Kwiatkowski also put together a version that included distributed device support, making it easier to experiment on bigger models. This PR is more modular, and I’m excited to see it refined and landed so the recipe is widely available!

    There’s a growing energy around tools like torchtune, and it’s exciting to see how easy it is to “hack on” these ideas. It’s also great to see the techniques show up in other libraries, like HuggingFace’s TRL, which is being used as part of the OpenR1 replication effort!

  • Gradient Accumulation (was) busted

    This weekend I was reading the Tulu v3 paper (link), which offers a deep dive into building robust post-training setups. This is a very good resource for anyone aiming to build robust fine-tuning workflows. It covers critical elements like dataset selection, synthetic data generation (with example prompts!), strategies for SFT and preference tuning, and various things they struggled with.

    One struggle was an issue with gradient accumulation: the loss was worse with it turned on than without it. The community at large also hit this, and fixed it, thanks to an excellent blog post by Unsloth (link).

    The bug

    Gradient accumulation is a technique used to simulate larger batch sizes by accumulating gradients over several smaller batches before performing an optimizer step. This approach is particularly useful for managing memory constraints during training, so it comes up a lot when post-training a biggish model on more limited hardware.

    The problem arises when dealing with sequences of varying lengths within these mini-batches. In standard practice, the loss is calculated and normalized by the number of non-padded (i.e., valid) tokens in each mini-batch. However, when accumulating gradients across multiple mini-batches, each with a different number of valid tokens, naively summing (or averaging) these per-mini-batch losses leads to an incorrect total loss calculation.

    The discrepancy occurs because the cross-entropy loss function normalizes by the number of valid tokens, and this normalization factor can vary between mini-batches. When these normalized losses are accumulated without proper adjustment, the final loss does not match what would have been obtained using a single large batch. This results in a higher observed loss during training when using gradient accumulation compared to full batch training.

    Daniel and co at Unsloth addressed this issue by developing a methodology that ensures the accumulated gradients are correctly scaled, accounting for the varying sequence lengths across mini-batches. This fix aligns the gradient accumulation process more closely with the theoretical foundations of full batch training, leading to more accurate loss calculations and improved training performance.
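    The bug and one way to express the fix fit in a few lines. This is not Unsloth's code: it is a sketch assuming two micro-batches that contain only valid tokens (3 and 29 of them), a toy linear "model", and comparing each accumulation strategy's weight gradient against a single full batch. The fix sums token losses and divides by the total token count across all accumulation steps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim = 10, 8
model = torch.nn.Linear(dim, vocab)

# Two micro-batches with very different numbers of valid tokens
micro_batches = [(torch.randn(n, dim), torch.randint(0, vocab, (n,))) for n in (3, 29)]

def grads_of(step_fn):
    model.zero_grad()
    step_fn()
    return model.weight.grad.clone()

def full_batch():
    # Reference: one big batch, loss averaged over all 32 tokens
    x = torch.cat([x for x, _ in micro_batches])
    y = torch.cat([y for _, y in micro_batches])
    F.cross_entropy(model(x), y).backward()

def naive_accumulation():
    # Per-micro-batch mean, then averaged across steps: (x1/n1 + x2/n2) / 2
    for x, y in micro_batches:
        (F.cross_entropy(model(x), y) / len(micro_batches)).backward()

def fixed_accumulation():
    # Sum token losses and divide by the total token count: (x1 + x2) / (n1 + n2)
    total_tokens = sum(len(y) for _, y in micro_batches)
    for x, y in micro_batches:
        (F.cross_entropy(model(x), y, reduction="sum") / total_tokens).backward()

ref = grads_of(full_batch)
print("naive matches full batch:", torch.allclose(grads_of(naive_accumulation), ref, atol=1e-5))
print("fixed matches full batch:", torch.allclose(grads_of(fixed_accumulation), ref, atol=1e-5))
```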

    Fixes and Workarounds

    Recent updates in both Hugging Face Transformers (pull request) and TorchTune (pull request) offer fixes. And, at least in Evan’s case, a little bit of snark:

    In honor of the day the ML community first discovered the fact that (x1 / n1) + (x2 / n2) != (x1 + x2) / (n1 + n2)

    I really like seeing these small but practical problems pop up, and seeing the community rally around to fix them. I missed this when it happened in October, so glad to look back at it now!