Blog

  • How does Triton do Warp Spec?

    Kapil Sharma from the PyTorch team has a great series of posts diving into the Triton compiler process: 1, 2, 3. As covered there, Triton lowers to a series of intermediate representations, and each level has a set of transformational passes that implement optimizations. TTIR is the generic Triton IR and leverages a number of standard MLIR passes like common subexpression elimination, as well as some Triton-specific passes like managing broadcast ops. That’s then lowered to TTGIR, a GPU-specific IR.[1]

    triton/third_party/nvidia/backend/compiler.py at rc/3.3.x · triton-lang/triton

    The different backends configure the passes appropriate for their targets, so the Nvidia TTGIR configuration above details the passes for Nvidia hardware. Some passes are only enabled for particular targets, like warp specialization:

    passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
                                         opt.reg_dec_producer, opt.reg_inc_consumer)
    passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
    passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
    passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)

    Another example is using the Tensor Memory Accelerator (TMA) for async loads on Hopper and newer (compute capability 9.0+):

    if capability // 10 >= 9:
       nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
       nvidia.passes.ttnvgpuir.add_fence_insertion(pm)

    For a quick recap of the why and how of Warp Specialization, check Colfax’s guide to optimizing a GEMM:

    The most basic method by which GPU programmers can create overlapping is via excess warps (warps are groupings of 32 contiguous threads). Nvidia GPUs allow a large number of warps per SM (streaming multiprocessors), and can switch between them with minimal overhead. In particular, the warp schedulers can simply switch to another warp if one warp encounters a slow memory fetch. In order to give the warp schedulers more opportunity to hide latency, a technique called warp-specialization was introduced circa 2011 [1, 2]. With warp-specialization, some warps are dedicated to memory fetches (producers), while others are dedicated to compute (consumers), and named barriers are used for synchronization between them. The idea is that the warp schedulers can then more easily hide the latency of copy operations within compute (and vice-versa).

    More generally, you can overlap other kinds of work, not just memory loads. Each SM pairs a Tensor Core with 32 ALUs (CUDA cores), and recent hardware has four of these units per SM, so there is plenty of capacity for work that isn’t dot products. It’s really common to want to load memory, do a matmul, then apply a pointwise function like a ReLU or other activation function. The aim is to keep the Tensor Core as busy as possible with a series of matmuls, and warp specialization lets you do that.

    The transforms that implement this live in triton/lib/Dialect/TritonGPU/Transforms at rc/3.3.x · triton-lang/triton

    The first pass partitions the ops in the kernel. It looks for load ops and dot ops, and partitions them into producer (load) and consumer (compute) groups:

    // Step 1. Select loads into the first task, which is the producer task by
    // default. Place dots into the second task, which is the consumer.
    // Only consider loads that are connected to a dot op in a loop.
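    To make that selection rule concrete, here is a toy Python model of it. This is a sketch under stated assumptions, not Triton’s C++ pass: ops are plain objects with a hypothetical `users` list standing in for MLIR def-use chains, and task IDs 0 and 1 are the producer and consumer tasks described in the comment above.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Op:
    name: str
    kind: str                                        # "load", "dot", or anything else
    users: List["Op"] = field(default_factory=list)  # ops that consume this op's result
    task: Optional[int] = None                       # async task id once assigned

def feeds_dot(op: Op, seen: Optional[set] = None) -> bool:
    """True if following users from `op` eventually reaches a dot op."""
    seen = set() if seen is None else seen
    for user in op.users:
        if id(user) in seen:
            continue
        seen.add(id(user))
        if user.kind == "dot" or feeds_dot(user, seen):
            return True
    return False

def partition_tasks(loop_ops: List[Op], producer: int = 0, consumer: int = 1) -> None:
    for op in loop_ops:
        if op.kind == "load" and feeds_dot(op):
            op.task = producer   # loads connected to a dot -> producer task
        elif op.kind == "dot":
            op.task = consumer   # dots -> consumer task

dot = Op("dot", "dot")
cvt = Op("convert_a", "convert", users=[dot])
load_a = Op("load_a", "load", users=[cvt])
load_b = Op("load_b", "load", users=[dot])
load_idx = Op("load_idx", "load")  # not connected to any dot
partition_tasks([load_a, cvt, load_b, load_idx, dot])
print([(op.name, op.task) for op in (load_a, cvt, load_b, load_idx, dot)])
```

    The convert op is deliberately left unlabeled here: only loads and dots are anchors at this step.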

    The next pass is a bookkeeping one that propagates the task IDs, so any unlabeled ops are attached to one of the partitions.
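    One plausible way to picture propagation, as a sketch rather than the actual TaskIdPropagate logic: treat the loop body as a def-use graph, let each unlabeled op inherit the task IDs of the ops that use its result, fall back to its operands’ task IDs when none of its users are labeled, and repeat until nothing changes. The graph encoding and the fallback rule are assumptions for illustration.

```python
from typing import Dict, List, Set

def propagate_task_ids(users: Dict[str, List[str]],
                       anchors: Dict[str, Set[int]]) -> Dict[str, Set[int]]:
    """users: op -> ops consuming its result. anchors: labels from the partition pass."""
    producers: Dict[str, List[str]] = {op: [] for op in users}
    for op, consumers in users.items():
        for c in consumers:
            producers[c].append(op)

    tasks = {op: set(anchors.get(op, ())) for op in users}
    changed = True
    while changed:                         # iterate to a fixed point
        changed = False
        for op in users:
            if anchors.get(op):
                continue                   # keep the partition pass's choice
            inherited: Set[int] = set()
            for c in users[op]:
                inherited |= tasks[c]      # backward: take the users' task ids
            if not inherited:
                for p in producers[op]:
                    inherited |= tasks[p]  # fallback: take the operands' task ids
            if not inherited <= tasks[op]:
                tasks[op] |= inherited
                changed = True
    return tasks

users = {
    "offsets": ["ptr_a", "ptr_b"],
    "ptr_a": ["load_a"], "ptr_b": ["load_b"],
    "load_a": ["dot"], "load_b": ["dot"],
    "acc_init": ["dot"],
    "dot": ["store"], "store": [],
}
result = propagate_task_ids(users, {"load_a": {0}, "load_b": {0}, "dot": {1}})
print(result)
```

    Address arithmetic ends up with the producer, while the accumulator init and the store after the dot end up with the consumer.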

    The WSDataPartition transform partitions the dot operations across consumer groups. It splits the dot products inside loops along the M or N dimension so that each slice can be processed by a single warp group, ensuring all dot operations are sliced and the slices labelled with task_ids.

    Just to look at some of the numbers: a warp consists of 32 threads, and a sync op (for loading) uses a whole warp. A warp group is a set of 4 warps: this is pertinent because the Tensor Core MMA on Hopper (wgmma) is issued by a warp group of 4 warps working on 64x(64/128/256)xK tiles. Triton already has a WarpGroupDotOp that tries to set this up, and that’s one of the operations targeted in this pass. The pass splits a Triton CTA tile, which may be 128×256, so that each consumer warp group gets a 64-row (or 256-column) chunk.

    if (sliceSizeM >= 64) {
      LLVM_DEBUG({ LDBG("partition along M\n"); });
      partitionDim = 0;
      partitionSize = sliceSizeM;
      partitionOperand = opndA;
    } else if (sliceSizeN >= 256) {
      LLVM_DEBUG({ LDBG("partition along N\n"); });
      partitionDim = 1;
      partitionSize = sliceSizeN;
      partitionOperand = opndB;
    } else {
      LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN);
      return false;
    }
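    The decision above can be mirrored in a few lines of Python. This sketch assumes each slice size is the CTA tile dimension divided by the number of consumer groups; the helper names are hypothetical, and it only covers the M-then-N choice, not the IR rewriting.

```python
from typing import List, Optional, Tuple

def choose_partition(tile_m: int, tile_n: int,
                     num_consumer_groups: int) -> Optional[Tuple[int, int]]:
    """Return (partition_dim, slice_size), or None if the tile can't be split."""
    slice_m = tile_m // num_consumer_groups  # assumption: even split across groups
    slice_n = tile_n // num_consumer_groups
    if slice_m >= 64:
        return 0, slice_m   # split along M (operand A)
    if slice_n >= 256:
        return 1, slice_n   # split along N (operand B)
    return None

def slice_tile(tile_m: int, tile_n: int,
               num_consumer_groups: int) -> List[Tuple[range, range]]:
    """Row/column ranges of the CTA tile owned by each consumer warp group."""
    choice = choose_partition(tile_m, tile_n, num_consumer_groups)
    if choice is None:
        return []
    dim, size = choice
    groups = range(num_consumer_groups)
    if dim == 0:
        return [(range(g * size, (g + 1) * size), range(tile_n)) for g in groups]
    return [(range(tile_m), range(g * size, (g + 1) * size)) for g in groups]

print(slice_tile(128, 256, 2))
```

    With a 128×256 tile and two consumer groups, each group gets a 64×256 slice, matching the warp-group-sized chunk described above.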

    The next pass is WSCodePartition. This is a big transform. It takes the task-sliced IR from WSDataPartition and sets up the producer warp group to copy from global GPU memory to SMEM (or, on Blackwell, TMEM). It also drops in barriers using producer.acquire/commit and consumer.wait/release to ensure proper ordering between the groups. The transform identifies “channels” (places where data is loaded, using load, or descriptor_load for TMA on Hopper+) and associates the producer task_id with all the consumer task IDs that need to process that data.

    Conceptually, the transform is turning the loops in the original kernel into something like this:

    for k in range(num_k_tiles):  # K-dimension tiles
        ##### PRODUCER #####
        producer.acquire(token, idx, phase)    # reserve smem[idx]
        async_copy_global_to_local(smem[idx])  # GMEM → SMEM[idx]
        producer.commit(token, idx)            # make slot visible
    
        ##### CONSUMERS (run in parallel warps) #####
        ## Consumer 0 ##
        consumer.wait(token, idx, phase)  # sleep until slot ready
        mma_sync(accum0, smem[idx], …)    # read-only, do matmul
        consumer.release(token, idx)      # hand the slot back to the producer
    
        ## Consumer 1 ##
        consumer.wait(token, idx, phase)
        mma_sync(accum1, smem[idx], …)
        consumer.release(token, idx)
    
        # Repeat for extra consumers.
    
        # Advance the circular buffer
        idx = (idx + 1) % numBuffers
        # Flip the phase each time the buffer wraps around, so data from the
        # previous pass over a reused slot can be told apart from new data.
        phase = phase ^ (idx == 0)
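    The index/phase bookkeeping and the acquire/commit/wait/release handshake can be simulated on the CPU. In this sketch Python threads stand in for warp groups, two semaphores stand in for the hardware barriers (one counting free slots, one counting filled slots), and there is a single consumer. The function names are hypothetical; it shows the synchronization pattern, not how Triton lowers it.

```python
import threading
from typing import List, Tuple

def ring_schedule(num_k_tiles: int, num_buffers: int) -> List[Tuple[int, int, int]]:
    """(k, idx, phase) for each K tile, using the same update rule as above."""
    idx, phase, out = 0, 0, []
    for k in range(num_k_tiles):
        out.append((k, idx, phase))
        idx = (idx + 1) % num_buffers
        phase ^= int(idx == 0)  # flip when the ring wraps back to slot 0
    return out

def run_pipeline(tiles: List[int], num_buffers: int = 2) -> List[int]:
    smem = [None] * num_buffers
    empty = threading.Semaphore(num_buffers)  # free slots: producer.acquire waits here
    full = threading.Semaphore(0)             # ready slots: consumer.wait waits here
    results: List[int] = []

    def producer() -> None:
        for k, tile in enumerate(tiles):
            empty.acquire()                   # producer.acquire: reserve the slot
            smem[k % num_buffers] = tile      # stand-in for the GMEM -> SMEM copy
            full.release()                    # producer.commit: publish the slot

    def consumer() -> None:
        for k in range(len(tiles)):
            full.acquire()                    # consumer.wait: block until ready
            results.append(smem[k % num_buffers] * 2)  # stand-in for the MMA
            empty.release()                   # consumer.release: free the slot

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(ring_schedule(7, 3))
print(run_pipeline([1, 2, 3, 4, 5]))
```

    Slot 0 is reused at k = 0, 3 and 6 with phases 0, 1, 0, which is what lets a waiter tell a fresh fill from a stale one.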

    The next pass is a generic pipeline pass. Each op is assigned a stage based on latency: slower ops like loads go into stage 0. The loop is then transformed with modulo scheduling: sync ops are converted to async variants (in lowerLoops), and then expandLoops writes out the prologue, kernel and epilogue loops that contain all the relevant ops.
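    A simplified picture of what that expansion produces for a loop with loads in stage 0 and MMAs in the last stage: the prologue issues the first loads, each kernel iteration overlaps the load for iteration i with the MMA for an earlier iteration, and the epilogue drains the remaining MMAs. This is a sketch of the general modulo-scheduling shape, assuming only two op kinds; it is not the output format of Triton’s expandLoops.

```python
from typing import List, Tuple

Op = Tuple[str, int]  # (op kind, loop iteration it belongs to)

def expand_pipeline(num_iters: int,
                    num_stages: int = 2) -> Tuple[List[Op], List[List[Op]], List[Op]]:
    lag = num_stages - 1  # how many iterations the MMA trails the load
    prologue = [("load", i) for i in range(min(lag, num_iters))]
    kernel = [[("load", i), ("mma", i - lag)] for i in range(lag, num_iters)]
    epilogue = [("mma", i) for i in range(max(num_iters - lag, 0), num_iters)]
    return prologue, kernel, epilogue

prologue, kernel, epilogue = expand_pipeline(4)
print("prologue:", prologue)
for step in kernel:
    print("kernel:  ", step)
print("epilogue:", epilogue)
```

    Raising num_stages lengthens the prologue and epilogue, letting more loads be in flight before the first MMA needs its data.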

    Finally, the WSLowering pass takes the various ops we have (like producer.acquire) and replaces them with the hardware-specific variants (e.g. wait_barrier). It also handles bookkeeping like computing the warp group and task_ids from the warp ID:

    OpBuilder builder(op);
    Value _4 = builder.create<arith::ConstantIntOp>(loc, WARPS_PER_TASK, 32);
    Value warpId = builder.create<ttng::GetCanonicalWarpIdOp>(loc);
    Value asyncTaskId = builder.create<arith::DivUIOp>(loc, warpId, _4);
    op.getResult().replaceAllUsesWith(asyncTaskId);

    This is a wordy IR way of saying

    warpId = gpu.thread_id().x / 32 
    task_id = warpId / 4

    Now the code is ready for lowering to regular PTX, and all of the warp-specific stuff is captured explicitly!

    1. Note: Because of some project weirdness, warp specialization is quite different in the release branches from main, so I’ll refer to 3.3 from here on. It’s in very active development (by teams at Meta, OpenAI and Nvidia!) so the specifics are quite likely to change over coming releases! ↩︎
  • The Mold Linker

    rui314/mold: Mold: A Modern Linker 🦠

    Having discussed this a bit at work, I finally got round to checking out the project. The README has good examples of what it is, and why to care.

    mold is a faster drop-in replacement for existing Unix linkers. It is several times quicker than the LLVM lld linker, the second-fastest open-source linker, which I initially developed a few years ago. mold aims to enhance developer productivity by minimizing build time, particularly in rapid debug-edit-rebuild cycles.

    Looking at fundamentals with a performance-oriented perspective and an awareness of modern systems continues to be a source of remarkable opportunities. At the same time, getting something so foundational established takes time and awareness.

  • LSP & Standards

    https://www.michaelpj.com/blog/2024/09/03/lsp-good-bad-ugly.html

    Having recently spent a lot more time around typecheckers, I’ve been reading about the Language Server Protocol that glues IDEs and language support services together. The article, from September last year, gives a breakdown of the good and the bad about the protocol, and is a really great dive into the broader topic.

    Much of the pain stems from how the protocol emerged and is managed:

    There is zero open discussion of features before they are added to the spec. Typically they are implemented in VSCode, and then the specification is updated as a fait accompli to document those changes. Implementers of open-source language servers get very little influence on the development of the specification. There is not even a community space for implementers of language servers to get together and talk about the many tricky corners.

    I feel echoes of this in a lot of different projects I have been around, including PHP internals, ZeroMQ’s protocol, various CNCF working groups, PyTorch and Triton. Protocols and technologies emerge from a need, and grow because that need is shared, but transitioning from a narrow and highly connected problem source to a true standard is difficult: attempt to standardize and bring in voices too early and you just slow down progress to the point something else emerges which solves immediate needs better; leave it too long and the governance questions can be sufficient to encourage folks to rally around forks or alternatives.

    One example of that playing out at the moment is Anthropic’s (and dsp’s!) Model Context Protocol. Tim Kellogg wrote a nice post the other day comparing it to OpenAPI, concluding:

    Standards are mostly sociological advancements. Yes, they concern technology, but they govern how society interacts with them. The biggest reason for MCP is simply that everyone else is doing it. Sure, you can be a purist and demand that OpenAPI is adequate, but how many clients support it?

    The reason everyone is agreeing on MCP is because it’s far smaller than OpenAPI. Everything in the tools part of an MCP server is directly isomorphic to something else in OpenAPI. In fact, I can easily generate an MCP server from an openapi.json file, and vice versa. But MCP is far smaller and purpose-focused than OpenAPI is.

  • Scalably Solving Assistance Games

    Scalably Solving Assistance Games | OpenReview

    Assistance games are an RL approach where the assistant and human cooperate on achieving a goal, and receive a shared reward signal for the joint effort. This paper proposes them as a better mechanism than RLHF for aligning models in post-training.

    Normally, RLHF is focused on single responses, or a single “turn” or interaction:

    Assistance games avoid the aforementioned drawbacks of RLHF by explicitly accounting for both the interactive nature of assistance and uncertainty about the user’s goal. In particular, an assistance game is a two player game in which an assistant and a user take actions in a shared environment. The two agents share a reward function, but crucially the assistant is initially uncertain about it.

    Assistance games remove incentives for deception since the assistant’s performance depends on the true latent reward function, rather than human feedback. They also incentivize the assistant to interact with the user to resolve its uncertainty about the reward function.

    The paper uses building structures in Minecraft as the learning environment and gets some very positive results. They mention possible applications to chatbot alignment as a postscript.

    Practically this requires, given chat history h, predicting:

    • the next assistant message (or tool call)
    • the next human message in response to that
    • how satisfied the human is with the response

    The algorithm does a tree search, trying various candidate replies and simulated human responses, and picks the assistant action that scored best. They generally sampled ~100 actions in the paper.
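    A stripped-down version of that selection loop shows its shape: for each candidate assistant action, simulate the human’s reply, score it with a predicted-satisfaction model, and keep the action with the best average. This is a one-step lookahead sketch rather than the paper’s full tree search, and the three model callables are hypothetical stand-ins for learned predictors.

```python
from statistics import mean
from typing import Callable, List, Optional, Sequence

History = List[str]

def choose_action(history: History,
                  propose: Callable[[History, int], Sequence[str]],
                  simulate_human: Callable[[History, str], str],
                  predict_reward: Callable[[History, str, str], float],
                  num_actions: int = 100,
                  rollouts_per_action: int = 4) -> Optional[str]:
    best_action, best_value = None, float("-inf")
    for action in propose(history, num_actions):  # next assistant message / tool call
        value = mean(
            # predicted next human message, then how satisfied the human is
            predict_reward(history, action, simulate_human(history, action))
            for _ in range(rollouts_per_action)
        )
        if value > best_value:
            best_action, best_value = action, value
    return best_action

# Toy stand-ins: the user secretly wants a blue tower.
def propose(history, n):
    return ["build red tower", "build blue tower", "ask which colour"][:n]

def simulate_human(history, action):
    return "thanks!" if "blue" in action else "not quite"

def predict_reward(history, action, reply):
    if reply == "thanks!":
        return 1.0
    return 0.5 if action.startswith("ask") else 0.0

print(choose_action([], propose, simulate_human, predict_reward))  # build blue tower
```

    Here the simulated human is deterministic, so the assistant effectively knows the goal; with a sampled human model and real uncertainty about the reward, asking a clarifying question can come out on top.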

    In the Minecraft example, they can see whether a placed or removed block moves the shape towards the human target, so they can give a reward score each step; doing the same thing with conversations might need some clever goal crafting or propagating back from only a final reward signal.

  • Richard Zou on torch.compile

    GPU MODE 4/19 Q&A – Google Docs

    PyTorch compiler engineer Richard Zou did a great Q&A session with the GPU Mode discord community recently. You can watch the session on YouTube, but Richard also collected questions into a doc with some nice snippets and references.

    Our value proposition: You can sit down for hours/days/weeks tuning a custom kernel. torch.compile provides good baseline performance so you don’t need to do that all the time!

    The goal with the compiler is that you can spend most of your time thinking about the model, get the majority of the speedups, and only have to go down to custom kernel authoring when you’ve established an opportunity or need for further performance.

  • Autotuning in PyTorch & Triton

    torch.compile offers some knobs for trading longer compile times for better execution performance. This is particularly useful for inference, where the same model will be running for a long time.

    model_autotune = torch.compile(model, mode="max-autotune")

    Passing mode="max-autotune" instructs the compiler to test more options for each operation. The compiler can use pre-built ATen kernels, leverage kernels from libraries like cuDNN or CUTLASS, or use templated Triton kernels. When autotuning, candidate variants are benchmarked on device with the shape information identified during tracing, and the fastest options are selected. Triton templates also enable fusions: pointwise ops can be fused into the templated Triton kernel, saving kernel launch overhead.

    The downside of this is that testing the options takes more time, so using max-autotune can lead to some very extended compile times. You also need a hefty enough GPU to get the benefit: is_big_gpu gates it on the number of SMs, so it works best on a 3090, V100 or above.

    You can see a lot of the autotuning options in _inductor/config.py. Backends that are considered are set separately for GEMMs and convolution ops:

    max_autotune_gemm_backends = os.environ.get(
        "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
    ).upper()

    Each kernel has implementations for the different backends, which are added to the list of possible choices. For example, in _inductor/kernel/mm.py you can see calls to use_[backend]_template that check whether the backend in question is a valid choice:

    if is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    _inductor/select_algorithm.py does the actual benchmarking through the choices.
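    The overall shape of that process (filter candidate implementations by the enabled backends, time each one, keep the fastest) can be sketched in plain Python. This is an illustration under assumptions, not Inductor’s select_algorithm code: the candidates are ordinary callables tagged with a backend name, the choice names are made up, and timing takes the minimum of a few `time.perf_counter` measurements after a warm-up call.

```python
import os
import time
from typing import Callable, Dict, Tuple

def enabled_backends(env_var: str = "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS",
                     default: str = "ATEN,TRITON,CPP") -> set:
    # Same shape as the config line above: a comma-separated, upper-cased list.
    raw = os.environ.get(env_var, default).upper()
    return {name.strip() for name in raw.split(",") if name.strip()}

def pick_fastest(candidates: Dict[Tuple[str, str], Callable[[], object]],
                 backends: set, repeats: int = 5) -> str:
    """candidates maps (backend, choice name) -> zero-arg callable."""
    timings = {}
    for (backend, name), fn in candidates.items():
        if backend not in backends:
            continue            # backend not enabled: not a valid choice
        fn()                    # warm-up (compilation, caches, ...)
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            best = min(best, time.perf_counter() - start)
        timings[name] = best
    return min(timings, key=timings.get)

# Hypothetical choices; busy loops of different lengths stand in for kernels.
candidates = {
    ("ATEN", "aten_mm"): lambda: sum(range(200_000)),
    ("TRITON", "triton_mm_128x64"): lambda: sum(range(1_000)),
    ("CUTLASS", "cutlass_gemm"): lambda: None,  # fastest, but CUTLASS is not enabled
}
print(pick_fastest(candidates, enabled_backends()))  # triton_mm_128x64
```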

    If you run autotuning, you’ll get some log output, and caches will be written to /tmp/torchinductor_yourusername.

    We can try this out on a simple MLP:

    import torch, time
    
    class SimpleMLP(torch.nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, hidden_features)
            self.relu = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(hidden_features, out_features)
        def forward(self, x):
            return self.linear2(self.relu(self.linear1(x)))
    
    # Set up device and model
    device = 'cuda'
    model = SimpleMLP(in_features=1024, hidden_features=1024, out_features=1024).to(device)
    x = torch.randn(256, 1024, device=device)  # batch of 256, 1024 features each
    
    # Compile the model in default mode first (max-autotune is compiled further down)
    model_default = torch.compile(model, mode="default")
    
    # Warm-up runs (to trigger compilation)
    torch.compiler.reset()
    with torch.no_grad():
        model_default(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of default compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_default(x)
        end.record()
    torch.cuda.synchronize()
    time_default_ms = start.elapsed_time(end) / 50.0
    torch.compiler.reset()
    
    model_autotune = torch.compile(model, mode="max-autotune")
    
    with torch.no_grad():
        model_autotune(x)
    torch.cuda.synchronize()  # ensure warm-up completes
    
    # Measure performance of max-autotune compiled model
    start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        for _ in range(50):
            _ = model_autotune(x)
        end.record()
    torch.cuda.synchronize()
    time_autotune_ms = start.elapsed_time(end) / 50.0
    
    print(f"Average inference time - torch.compile default: {time_default_ms:.3f} ms")
    print(f"Average inference time - torch.compile max-autotune: {time_autotune_ms:.3f} ms")
    

    Disappointingly, this is the result:

    Average inference time - torch.compile default: 0.113 ms
    Average inference time - torch.compile max-autotune: 3.251 ms

    We can turn on logging with the TORCH_LOGS env variable: some useful options are inductor, autotuning, and perf_hints.

    TORCH_LOGS="perf_hints" python tune.py

    You can control many more autotune options via the options argument, though it’s incompatible with passing a mode value. We can recreate the max-autotune mode and turn on some useful tracing options like this (note that the options version uses an underscore, the mode a hyphen!):

    model_autotune = torch.compile(
        model,
        options={
            "max_autotune": True,
            "triton.cudagraphs": True,
            "coordinate_descent_tuning": True,
            "trace.enabled": True,
            "trace.graph_diagram": True,
        },
    )

    The "trace.enabled" and "trace.graph_diagram" options generate trace output, including a nice diagram of the captured graph. CUDA graphs turned out to be the culprit here, which is common enough that there is a non-cudagraph mode available to save you remembering all the options:

    model_autotune = torch.compile(model, mode="max-autotune-no-cudagraphs")

    As you can see in the captured graphs with and without CUDA graphs, the slower version actually has an extra fusion performed!

    Captured graphs for the two runs

    Triton Autotuning

    Triton also autotunes, but it’s a little more explicit. When authoring a Triton kernel you can specify a list of configurations. The first time the kernel is called, each config variant is benchmarked, the most performant one is picked, and the choice is cached for future calls. A key can be provided to indicate when to re-autotune based on changing inputs:

    import os
    import torch
    import triton
    import triton.language as tl
    
    # Just to save passing this on the command line
    os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  
    
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 128}, num_warps=4,  num_stages=2),
            triton.Config({'BLOCK_SIZE': 256}, num_warps=8,  num_stages=2),
        ],
        key=['N']            # re‑tune only if the length N changes
    )
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        pid   = tl.program_id(0)
        offs  = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        x     = tl.load(x_ptr  + offs, mask=mask, other=0.0)
        y     = tl.load(y_ptr  + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, x + y, mask=mask)
    
    def vec_add(x: torch.Tensor, y: torch.Tensor):
        assert x.is_cuda and y.is_cuda
        N   = x.numel()
        out = torch.empty_like(x)
        grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)  # size the grid for whichever config runs
        vecadd_kernel[grid](x, y, out, N)    
        return out
    
    x = torch.randn(1 << 20, device="cuda")   # 1 048 576 elements
    y = torch.randn_like(x)
    
    _ = vec_add(x, y)  # first call → autotuning prints to stdout
    _ = vec_add(x, y)  # second call → no autotuning, uses the best config found
    

    Setting the env variable TRITON_PRINT_AUTOTUNING documents the process as it goes:

    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 256, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
    Triton autotuning for function vecadd_kernel finished after 0.44s; best config selected: BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;
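    The caching behaviour behind that log can be modelled in a few lines. This is a toy sketch, not triton.autotune’s implementation: the `bench` argument is a hypothetical cost function standing in for do_bench timings, and configs are plain dicts. What it shows is the key mechanism: tuning runs once per distinct key value, and later calls with the same key reuse the cached winner.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

class KeyedAutotuner:
    def __init__(self, fn: Callable[..., Any], configs: List[Dict[str, Any]],
                 key: Sequence[str],
                 bench: Callable[[Dict[str, Any], Dict[str, Any]], float]):
        self.fn = fn
        self.configs = configs
        self.key = list(key)   # argument names that trigger re-tuning
        self.bench = bench     # bench(call_args, config) -> cost, lower is better
        self.cache: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
        self.tuning_runs = 0

    def __call__(self, **kwargs: Any) -> Any:
        key = tuple(kwargs[name] for name in self.key)
        if key not in self.cache:  # first time we see this key: benchmark every config
            self.tuning_runs += 1
            self.cache[key] = min(self.configs, key=lambda cfg: self.bench(kwargs, cfg))
        return self.fn(**kwargs, **self.cache[key])

# Hypothetical cost model: pretend the ideal block size is N // 8.
tuner = KeyedAutotuner(
    fn=lambda N, BLOCK_SIZE: BLOCK_SIZE,
    configs=[{"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 256}],
    key=["N"],
    bench=lambda args, cfg: abs(cfg["BLOCK_SIZE"] - args["N"] // 8),
)
print(tuner(N=1024), tuner(N=1024), tuner(N=2048), tuner.tuning_runs)  # 128 128 256 2
```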

    You can use the same do_bench helper the autotuner uses, and see how the performance varies yourself:

    import torch, triton, triton.testing as tt
    import triton.language as tl
    
    @triton.jit
    def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
        offs  = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask  = offs < N
        tl.store(out_ptr + offs,
                 tl.load(x_ptr + offs, mask=mask) +
                 tl.load(y_ptr + offs, mask=mask),
                 mask=mask)
    
    # tensors
    N   = 1 << 20
    x   = torch.randn(N, device='cuda')
    y   = torch.randn_like(x)
    out = torch.empty_like(x)
    
    def bench(block_size, num_warps):
        grid = (triton.cdiv(N, block_size),)
        # With quantiles set, tt.do_bench returns [median, p20, p80] in milliseconds.
        # warmup and rep are time budgets in ms, not iteration counts.
        return tt.do_bench(
            lambda: vecadd_kernel[grid](x, y, out, N, BLOCK_SIZE=block_size, num_warps=num_warps),
            warmup=5, rep=16, quantiles=(0.5, 0.2, 0.8)
        )
    timings = {
        "128/4": bench(128, 4),
        "256/8": bench(256, 8),
    }
    
    print("timings:", timings)
    

    Running that shows both configs are basically equivalent: the medians are identical (about 19.5 µs), and the first has a slightly lower 80th percentile.

    timings: {'128/4': [0.01945599913597107, 0.01945599913597107, 0.02028159946203232], '256/8': [0.01945599913597107, 0.01945599913597107, 0.020479999482631683]}

  • Profiling Triton

    There are a couple of different options to profile a Triton kernel.

    Proton

    Proton (“profiler for Triton”) is the profiler that ships with Triton. You can enable it and, optionally, activate/deactivate it around specific regions you want to profile. You can also annotate scopes with specific metrics.

    import torch
    import triton.profiler as proton
    
    # A, B, C, M, N, K, grid, BLOCK_M/BLOCK_N/BLOCK_K and the fused_gemm_bias_relu
    # kernel are assumed to be defined earlier in the script.
    session = proton.start()  # Start profiling session
    
    bias = torch.rand((256,), device='cuda', dtype=torch.float16)  # Bias vector
    flops = 2 * M * N * K
    bytes_accessed = A.element_size() * M*K + B.element_size() * K*N + C.element_size() * M*N  # rough bytes
    with proton.scope(f"fused_gemm_bias_relu [M={M}, N={N}, K={K}]", {"flops": flops, "bytes": bytes_accessed}):
        fused_gemm_bias_relu[grid](
            A, B, C, bias,
            M, N, K,
            A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
        )
    
    proton.finalize()

    The output can be visualized with the built-in viewer:

    proton-viewer -m time/ms,tflop/s ./proton.hatchet

    0.040 6.645 ROOT
    ├─ 0.004 nan _ZN2at6native55_GLOBAL__N__11f7a751_22_DistributionUniform_[...]_15PhiloxCudaStateESH_SI_
    └─ 0.037 7.284 fused_gemm_bias_relu [M=1024, N=256, K=512]
       └─ 0.037 nan fused_gemm_bias_relu

    In this case you can see both the (trimmed!) mangled name of the ATen kernel that generated the random bias tensor, and the name of my custom kernel under its scope.

    nsight-compute

    Nvidia also has a good range of tools for looking at performance. Note that you will need to enable access to the GPU performance counters for this:

    NVIDIA Development Tools Solutions – ERR_NVGPUCTRPERM: Permission issue with Performance Counters | NVIDIA Developer

    On the off chance you’re doing this on WSL, https://peterchng.com/blog/2024/03/02/profiling-cuda-programs-on-wsl-2/ walks through the setup!

    Nvidia ships Nsight Systems, which tracks broader system-wide metrics, and Nsight Compute, which is more focused on profiling kernel execution. You can run Nsight Compute against a script like so:

    ncu -o profile_results python test.py

    The tool comes with a nice GUI for inspecting the results. It can show you the PTX or SASS source for the kernels, offers metrics like actively used registers (good for checking on register spilling), and calls out warnings on poor utilization or memory clashes.

    Upcoming intra-kernel profiler

    [tool][proton] Intra kernel profiling support by fywkevin · Pull Request #4861 · triton-lang/triton

    There is an extension coming for Proton that enables profiling within kernels. It reserves a pre-allocated buffer on device and logs metrics locally, for reading out at the end of execution. The output is a Chrome trace, usable in a wide range of dev tools. While this isn’t merged into mainline yet, you can see an example of the usage in the dev repo.

  • JEPA

    JEPA is an example of a predictive coding architecture. Predictive coding is the idea that the brain works by predicting the next sensory input and learns from the error between the prediction and the actual input. Hierarchies of this kind of prediction allow higher level elements to predict the outputs of lower-level elements, building up deeper and more complex semantics.

    The core idea in JEPA is to take two related things x and y (say, consecutive frames of a video), encode each of them into a latent space (an embedding), and then predict y’s embedding s(y) from s(x). The encoders can be Transformer-based models; in practice, models like I-JEPA train the x (context) encoder directly and update the y (target) encoder’s weights as a moving average of the context encoder’s.

    The learning signal is not how well the model predicts the target itself (e.g. how close the predicted pixels of the next frame are). Instead, it’s based on how well the latent representation of the next frame is predicted.
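    A toy, fully linear version of one JEPA training step makes the two ideas above concrete: the loss compares the predicted embedding with the target embedding (never raw y), and the target encoder is only updated as a moving average of the context encoder. Everything here is an assumption for illustration (2-D vectors, a linear encoder W, a linear predictor P, plain gradient descent); real models like I-JEPA use Transformer encoders and masking, and nothing in this sketch guards against collapse.

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]

def matvec(m: Mat, v: Vec) -> Vec:
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def outer(u: Vec, v: Vec) -> Mat:
    return [[a * b for b in v] for a in u]

def transpose(m: Mat) -> Mat:
    return [list(col) for col in zip(*m)]

def sgd(m: Mat, grad: Mat, lr: float) -> Mat:
    return [[a - lr * g for a, g in zip(row, grow)] for row, grow in zip(m, grad)]

def ema(target: Mat, online: Mat, tau: float) -> Mat:
    return [[tau * t + (1 - tau) * o for t, o in zip(tr, orow)] for tr, orow in zip(target, online)]

def jepa_step(W: Mat, P: Mat, W_tgt: Mat, x: Vec, y: Vec,
              lr: float = 0.1, tau: float = 0.99) -> Tuple[Mat, Mat, Mat, float]:
    s_x = matvec(W, x)                 # context embedding s(x)
    s_y = matvec(W_tgt, y)             # target embedding s(y), treated as a constant
    pred = matvec(P, s_x)              # predicted s(y)
    err = [p - t for p, t in zip(pred, s_y)]
    loss = sum(e * e for e in err)     # squared error in latent space
    d_err = [2 * e for e in err]
    grad_P = outer(d_err, s_x)                      # dL/dP
    grad_W = outer(matvec(transpose(P), d_err), x)  # dL/dW
    W_new, P_new = sgd(W, grad_W, lr), sgd(P, grad_P, lr)
    W_tgt_new = ema(W_tgt, W_new, tau)  # no gradient step for the target encoder
    return W_new, P_new, W_tgt_new, loss

eye = [[1.0, 0.0], [0.0, 1.0]]
W, P, W_tgt, loss = jepa_step(eye, eye, eye, x=[1.0, 0.0], y=[0.0, 1.0])
print(loss, W_tgt)
```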

    The advantage of working in the latent space for the prediction is that the model can choose what level of detail it wants to capture, discarding some aspects and focusing on more foundational concepts. This helps build a more robust world model, with the hope that training in this way will allow easier generalization to more tasks, with less data required.

    Similarities

    This is somewhat similar to autoencoders. Autoencoders take an input, compress it into a latent space, then reconstruct the original from the latent space and propagate back the error. JEPA does a similar process across two different items with separate encoders, and only cares about error within the latent space.

    Contrastive models embed two different items into the same space and try to increase similarity between the embeddings for things known to be similar and make them dissimilar to other items. This is used in CLIP and other multimodal text-image encoders, where the text and the image embed to the same space so that a text caption and a matching image are close in embedding space. This requires a lot of pairwise comparisons, while JEPA is a more straightforward s(x)->s(y) prediction in training.

    Challenges

    Because JEPA models leave you with a latent, they need to be paired with a generator to get an observable, human-viewable output, which is a per-domain challenge. This also makes it harder to evaluate how well the model is learning, beyond measuring loss.

    Training stability can also be tricky — it is possible for the model to collapse and learn trivial representations to minimize prediction error. Even without complete collapse it can require some experimentation to ensure the model is learning a deep enough conceptual level. For example, I-JEPA, which worked in image space, found that using large enough masked patches was important to ensure the model captured sufficient detail.

  • Free-Threaded Python community

    As Thomas posted on the Python Discuss board, there is a Discord for discussing free-threaded/no-GIL Python: https://discord.gg/rqgHCDqdRr

    Other than the helpful docs the Quansight folks maintain (py-free-threading), it’s been interesting to see some projects and tools pop up there that I hadn’t seen. One is a tool from Zsolt that checks whether your project’s non-pure-Python dependencies have free-threaded wheels, and which can be run as a one-liner thanks to uv:

    uvx python-ft-deps

  • Colfax on Blackwell GEMMs

    CUTLASS Tutorial: Writing GEMM Kernels Using Tensor Memory For NVIDIA® Blackwell GPUs – Colfax Research

    Dives deep into TMEM in particular, and into the trend over the last few Nvidia generations of special-casing GEMMs in hardware:

    Tensor Memory and UMMA do for MMA just what TMA did for copy, making it a single-threaded, asynchronous operation that does not consume registers. As a result, registers can primarily be used for other tasks like scheduling and fused epilogue operations.

    Edit: link no longer seems to be working! It was a great post though, so hopefully comes back! Edit edit: it did!

  • MegaScale-infer: disagg MoE inference

    https://arxiv.org/abs/2504.02263

    LLMs have most of their parameters in the FFN parts of the transformer layers — 50+bn params of the Llama 3 70b model, for example. The compute and memory requirements are a bit different between the FFN and attention parts of the model: attention requires a different KV cache for each request, so attention tends to be memory bound while the dense FFNs tend to be compute bound.
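    A quick back-of-the-envelope check of that FFN figure. This assumes the published Llama 3 70B configuration (hidden size 8192, FFN size 28672, 80 layers, 64 query heads, 8 KV heads, 128,256-token vocabulary, untied output head), counts only the large weight matrices, and ignores the small norm weights.

```python
HIDDEN = 8192
FFN_HIDDEN = 28672
LAYERS = 80
N_HEADS = 64
N_KV_HEADS = 8
VOCAB = 128256
HEAD_DIM = HIDDEN // N_HEADS  # 128

ffn_per_layer = 3 * HIDDEN * FFN_HIDDEN                   # gate, up and down projections
attn_per_layer = (2 * HIDDEN * HIDDEN                     # q and o projections
                  + 2 * HIDDEN * N_KV_HEADS * HEAD_DIM)   # k and v projections (GQA)
embeddings = 2 * VOCAB * HIDDEN                           # input embedding + output head

ffn_total = LAYERS * ffn_per_layer
attn_total = LAYERS * attn_per_layer
total = ffn_total + attn_total + embeddings

print(f"FFN: {ffn_total / 1e9:.1f}B, attention: {attn_total / 1e9:.1f}B, "
      f"embeddings: {embeddings / 1e9:.1f}B, total: {total / 1e9:.1f}B "
      f"({ffn_total / total:.0%} FFN)")
```

    That works out to roughly 56.4B of about 70.6B parameters, around 80%, sitting in the FFNs.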

    Because of this, it’s pretty common to split up tasks at inference time. The initial prefill stage (processing the initial prompt) populates the KV cache for the following autoregressive decoding, and the decode can be batched more aggressively to get better utilization. vLLM really helped popularize this idea!

    ByteDance extend this idea to mixture-of-experts models. In MoEs, the compute intensity of the FFNs is limited by needing to load different experts, with only a fraction of tokens going through any given expert. They extend the disaggregation idea to go from M “attention” GPUs to N (fewer!) expert GPUs, with a larger batch size for each of the expert calls. This gets better utilization on the matmuls and lowers the overall cost of serving. The natural structure of transformer layers, alternating attention and FFN, lends itself well to a ping-pong pipelining approach that lets them hide the communication overhead.
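    Some illustrative arithmetic for why consolidation helps. All numbers are hypothetical, and routing is assumed to be perfectly balanced, so each expert sees about replicas × tokens × top_k / num_experts tokens per step. Because the expert weights dominate memory traffic at small batch sizes, the arithmetic intensity (FLOPs per byte) of an expert GEMM grows roughly linearly with that per-expert batch.

```python
def tokens_per_expert(tokens_per_replica: int, attention_replicas: int,
                      top_k: int, num_experts: int) -> float:
    """Average tokens routed to each expert per step, assuming balanced routing."""
    return attention_replicas * tokens_per_replica * top_k / num_experts

def expert_gemm_intensity(batch_tokens: float, d_model: int, d_ff: int,
                          bytes_per_value: int = 2) -> float:
    """FLOPs per byte for one expert projection (d_model -> d_ff), counting the
    weight matrix plus input and output activations (bf16 by default)."""
    flops = 2 * batch_tokens * d_model * d_ff
    bytes_moved = bytes_per_value * (d_model * d_ff + batch_tokens * (d_model + d_ff))
    return flops / bytes_moved

# Hypothetical deployment: 128 decode tokens per attention replica, top-2 of 64 experts.
for replicas in (1, 8):
    b = tokens_per_expert(128, replicas, top_k=2, num_experts=64)
    print(f"{replicas} attention replica(s): {b:.0f} tokens/expert, "
          f"~{expert_gemm_intensity(b, d_model=4096, d_ff=14336):.1f} FLOPs/byte")
```

    Going from one to eight attention replicas feeding the same experts takes each expert from about 4 to about 32 tokens per step, with intensity rising almost in proportion.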

    We present MegaScale-Infer, an efficient and cost-effective system designed for large-scale MoE serving. MegaScale-Infer disaggregates the attention and expert modules, assigning them to separate GPUs—a strategy we term disaggregated expert parallelism. Our approach offers two major benefits. First, it enables independent scaling of each module with customized model parallelism strategies. Specifically, attention modules are replicated using data parallelism, while FFN modules are scaled with expert parallelism. By consolidating requests from multiple attention replicas, the GPU utilization of each expert increases significantly as the batch size per attention replica grows. Second, it enables the deployment of attention and FFN modules on heterogeneous GPUs to fully leverage their different capabilities and achieve lower costs. For example, attention modules can be deployed on GPUs with more cost-effective memory capacity and bandwidth, while FFN modules can utilize GPUs with more affordable compute capability. As shown in Figure 1(c), FFN can easily become compute-intensive in MegaScale-Infer, while attention achieves higher GPU utilization per cost under heterogeneous deployment.
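    The ping-pong idea can be illustrated with a toy two-resource schedule: split the batch into micro-batches so the attention GPUs work on one micro-batch while the expert GPUs work on another. This is a simplified sketch, not MegaScale-Infer’s scheduler. It ignores communication time and assumes fixed per-layer costs, with half the tokens taking half the time (optimistic for small batches).

```python
def ping_pong_makespan(num_microbatches: int, num_layers: int,
                       t_attn: float, t_ffn: float) -> float:
    """Finish time of a toy pipeline where each layer's attention runs on one
    set of GPUs and its FFN on another, with micro-batches interleaved."""
    attn_free = 0.0  # when the attention GPUs are next idle
    ffn_free = 0.0   # when the expert GPUs are next idle
    ready = [0.0] * num_microbatches  # when each micro-batch's last op finished
    for _ in range(num_layers):
        for mb in range(num_microbatches):
            start = max(attn_free, ready[mb])
            attn_free = ready[mb] = start + t_attn
        for mb in range(num_microbatches):
            start = max(ffn_free, ready[mb])
            ffn_free = ready[mb] = start + t_ffn
    return max(ready)


# One full batch, no overlap: each GPU pool idles while the other works.
serial = ping_pong_makespan(1, 4, t_attn=1.0, t_ffn=1.0)
# Two half-size micro-batches: attention on one overlaps FFN on the other.
pipelined = ping_pong_makespan(2, 4, t_attn=0.5, t_ffn=0.5)
print(serial, pipelined)  # 8.0 4.5
```

    With two micro-batches each pool is busy almost all the time, so the finish time approaches the per-pool work (4.0) plus one pipeline fill (0.5). In practice the M-to-N communication between attention and expert GPUs also has to fit inside those overlapped windows, which this toy leaves out.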

  • Pydantic Evals

    https://ai.pydantic.dev/evals/

    Ed Yang recently recommended keeping your own benchmark of LLM evals, so you can test newer models on problems they have struggled with in the past. I have recommended similar things to people, but there is some barrier to entry in knowing how to start. Ed references (and forks) Nicholas Carlini’s personal benchmark repo, but it’s nice to have some light(ish)weight options too.

    Pydantic Evals is a powerful evaluation framework designed to help you systematically test and evaluate the performance and accuracy of the systems you build, especially when working with LLMs.

    You can install the library with uv or pip:

    uv add pydantic-evals
    pip install pydantic-evals

    I tried it out with a strawberry test, calling OpenRouter with different models. I needed a custom evaluator, as the built-in Contains is a bit rigid (it matches a single expected value, and I wanted to accept either “3” or “three”), but the approach seems nice!

    import asyncio
    import os
    from dataclasses import dataclass
    from typing import Any, Optional

    from openai import OpenAI
    from pydantic_evals import Case, Dataset
    from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext


    def _truncate(text: str, max_length: int = 100) -> str:
        """Shorten long outputs for failure messages (avoids importing a private pydantic_evals helper)."""
        return text if len(text) <= max_length else text[: max_length - 3] + "..."


    @dataclass
    class FlexibleContains(Evaluator[object, object, object]):
        """
        Check if the output contains any one of the expected options.
        """

        value: Any
        case_sensitive: bool = False

        def evaluate(
            self, ctx: EvaluatorContext[object, object, object]
        ) -> EvaluationReason:
            failure_reason: Optional[str] = None
            # Normalize value into a list of options if it isn't already a list or tuple.
            options = self.value if isinstance(self.value, (list, tuple)) else [self.value]
            output_str = str(ctx.output)
            if not self.case_sensitive:
                output_str = output_str.lower()
            # Pass as soon as any option appears as a substring of the output.
            match_found = False
            for opt in options:
                opt_str = str(opt)
                if not self.case_sensitive:
                    opt_str = opt_str.lower()
                if opt_str in output_str:
                    match_found = True
                    break
            if not match_found:
                failure_reason = (
                    f"Output string {_truncate(repr(output_str))} does not contain "
                    f"any of expected strings: {[str(opt) for opt in options]}"
                )
            return EvaluationReason(value=match_found, reason=failure_reason)


    strawberry = Case(
        name="strawberry",
        inputs="How many rs are in strawberry?",
        evaluators=[FlexibleContains(value=["3", "three"])],
        metadata={"difficulty": "easy"},
    )
    dataset = Dataset(cases=[strawberry])

    MODELS = [
        "anthropic/claude-3.5-sonnet",
        "openai/gpt-4o",
        "meta-llama/llama-4-maverick:free",
        "meta-llama/llama-4-scout:free",
        "openrouter/optimus-alpha",  # secret model!
    ]


    def generate_completion(inputs: str, model: str) -> str:
        """Generate a completion using OpenRouter with the specified model."""
        client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": inputs},
            ],
            max_tokens=50,
            temperature=0.7,
        )
        # message.content can be None; treat that as an empty answer rather than crashing.
        return (response.choices[0].message.content or "").strip()


    def evaluate_models():
        """Run the dataset against each model in turn."""
        for model in MODELS:
            print(f"\nResults for model: {model}")
            print("=" * 50)

            # Wrap the blocking OpenAI call in an async task that runs it in a worker thread.
            async def model_specific_generate(inputs: str) -> str:
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, generate_completion, inputs, model)

            # Run evaluation for this model and print the per-case results.
            report = dataset.evaluate_sync(model_specific_generate)
            report.print(include_input=True, include_output=True, include_durations=False)


    def main():
        evaluate_models()


    if __name__ == "__main__":
        main()
    

    Here is the (trimmed) output for one of the models:

    Results for model: openrouter/optimus-alpha
    ==================================================
                                          Evaluation Summary: model_specific_generate
    ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    ┃ Case ID    ┃ Inputs                         ┃ Outputs                                                   ┃ Assertions ┃
    ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    │ strawberry │ How many rs are in strawberry? │ The word **"strawberry"** contains **three** letter "r"s. │ ✔          │
    ├────────────┼────────────────────────────────┼───────────────────────────────────────────────────────────┼────────────┤
    │ Averages   │                                │                                                           │ 100.0% ✔   │
    └────────────┴────────────────────────────────┴───────────────────────────────────────────────────────────┴────────────┘

  • Bonanza Build System

    https://blogsystem5.substack.com/p/bazel-next-generation

    JMMV shares his thoughts on build systems beyond Bazel, and highlights Ed Schouten (of Buildbarn) and his experiments with Bonanza. Bonanza maintains Bazel compatibility, but is remote-execution-first, runs analysis on the remote workers with a distributed cache to minimize cold builds, and implements rules only in Starlark (as with Buck2). Also like Buck2, it’s written in Rust.

    When I was at Lyft we went through a laborious build system migration, including on Android going from Gradle to Buck(1) to Bazel, which at one point involved a full shim that let the same build files work with either Buck or Bazel. The idea of keeping the same build definitions but swapping out the engine is pretty appealing.

    Julio actually calls for yet another Bazel replacement as well, this one more focused on the small-project/local-build case: I can definitely see the appeal there!

    The time for these next-generation Bazel-compatible build systems is now. Google has spent the last 10 years Starlark-ifying Bazel, making the core execution engine replaceable. We are reaching a point where the vast majority of the build logic can be written in Starlark as Bonanza proves, and thus we should be able to have different build tools that implement the same build system for different use cases

  • Improving Recommendation Systems with LLMs

    https://eugeneyan.com/writing/recsys-llm/

    Eugene Yan has put together a really extensive survey of recent research exploring the use of LLMs in recommendation systems.

    Although early research in 2023—that applied LLMs to recommendations and search—often fell short, these recent efforts show more promise, especially since they’re backed by industry results. It suggests that there are tangible benefits from exploring the augmentation of recsys and search systems with LLMs, increasing performance while reducing cost and effort.

    Recommendation systems are enormously important to large swathes of the tech business, primarily for e-commerce, content, and advertisement targeting. Traditional deep recommenders typically use a two-tower architecture: one tower for users and another for items, each independently encoding features into embeddings that can be scored together to retrieve and rank items. Features in each tower include both sparse features (usually categorical, e.g., item categories, user histories) and dense features (often continuous, e.g., age, price).

    This design is popular because it’s effective and scalable: you can cache each tower’s embedding vectors and only pull in the ones you need for a given query (e.g. the batch of users you are getting recommendations for right now).
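    A minimal sketch of two-tower retrieval in plain Python, assuming toy random (untrained) embedding tables. Each tower sums the embeddings of its sparse feature ids and adds a linear projection of its dense features. Item vectors are computed once and cached, and a user is scored against the cache with a dot product. Real systems learn these weights and use an approximate nearest-neighbour index instead of a full scan; the class and variable names here are hypothetical.

```python
import random

DIM = 4  # embedding size for both towers (toy value)


class Tower:
    """Toy tower: sum of sparse-feature embeddings plus a linear map of dense features."""

    def __init__(self, vocab_size: int, num_dense: int, seed: int):
        rng = random.Random(seed)
        self.sparse_table = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(vocab_size)]
        self.dense_proj = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(num_dense)]

    def __call__(self, sparse_ids: list[int], dense: list[float]) -> list[float]:
        out = [0.0] * DIM
        for i in sparse_ids:  # look up and pool categorical features
            out = [o + e for o, e in zip(out, self.sparse_table[i])]
        for x, row in zip(dense, self.dense_proj):  # project continuous features
            out = [o + x * w for o, w in zip(out, row)]
        return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


user_tower = Tower(vocab_size=100, num_dense=1, seed=0)  # e.g. history ids, age
item_tower = Tower(vocab_size=20, num_dense=1, seed=1)   # e.g. category ids, price

# Offline: encode every item once and cache the vectors.
items = {f"item{k}": ([k % 20], [k / 10]) for k in range(50)}
item_cache = {name: item_tower(*feats) for name, feats in items.items()}

# Online: encode only the user, then score against the cached item vectors.
user_vec = user_tower([3, 17, 42], [0.35])
scores = {name: dot(user_vec, vec) for name, vec in item_cache.items()}
top5 = sorted(scores, key=scores.get, reverse=True)[:5]
print(top5)
```

    Because the towers never look at each other’s inputs, the item side can be precomputed in bulk and only the user side runs per request, which is exactly what makes the caching described above possible.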

    Despite the effectiveness and scalability of this approach, traditional systems often struggle with a set of known issues, such as cold-start problems—predicting relevant content for new items or users—and typically don’t consider interaction recency without additional engineering.

    Yan categorizes recent research into four areas:

    1. LLM/Multimodal Architectures:
      • Directly embedding content understanding within the models. Content understanding has been used for a long time via separate models to generate additional metadata for content items to help both with cold start and accuracy.
      • Generative approaches, which reframe recommendation as predicting future user actions based on interaction sequences.
    2. LLM-Assisted Data Generation and Analysis:
      • Improving content understanding and generating richer metadata for items.
    3. Scaling Laws, Transfer, and Distillation:
      • Adapting LLMs to meet the latency requirements of recommendations through smaller models and efficient inference techniques. RecSys models, particularly for advertising, tend to have very tight latency budgets.
    4. Unified Architectures for Search and Recommendations:
      • Consolidating search and recommendation tasks into unified models that enable returning items based on interaction histories and/or user queries simultaneously.

    There are a couple of common themes from reading the summaries:

    • Semantic Content Integration & Joint Tasks: Techniques like YouTube’s Semantic IDs and Kuaishou’s M3CSR generate content-based identifiers replacing traditional hashed IDs. The idea is to have inputs to the models represent the content in a way that carries meaning, rather than represent an identifier for the content.
    • Efficiency in Inference: Teacher-student distillation and efficient fine-tuning allow generating smaller, performant models for specific needs. For instance, Alibaba’s MLoRA trains a base model then LoRA fine-tunes for specific types of content, replacing a number of independently trained models.
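    To make the Semantic ID idea concrete, here is a toy residual-quantization sketch. A content embedding is quantized level by level against small codebooks, and the tuple of chosen code indices becomes the item’s ID, so items with similar content tend to share ID prefixes. YouTube’s approach learns the codebooks (with an RQ-VAE); the fixed, hand-picked 2-D codebooks and embeddings below are assumptions purely for illustration.

```python
def nearest(codebook: list[list[float]], vec: list[float]) -> int:
    """Index of the codebook entry closest to vec (squared Euclidean distance)."""
    return min(range(len(codebook)),
               key=lambda i: sum((c - v) ** 2 for c, v in zip(codebook[i], vec)))


def semantic_id(embedding: list[float], codebooks: list[list[list[float]]]) -> tuple:
    """Residual quantization: each level encodes what the previous levels missed."""
    residual = list(embedding)
    codes = []
    for codebook in codebooks:
        idx = nearest(codebook, residual)
        codes.append(idx)
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return tuple(codes)


# Hypothetical 2-level codebooks: coarse content clusters, then refinements.
CODEBOOKS = [
    [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]],
    [[0.1, 0.1], [-0.1, 0.1], [0.0, -0.1]],
]
cat_video = semantic_id([0.9, 0.15], CODEBOOKS)
kitten_video = semantic_id([1.1, 0.1], CODEBOOKS)
cooking_video = semantic_id([0.05, 0.95], CODEBOOKS)
print(cat_video, kitten_video, cooking_video)  # (0, 1) (0, 0) (1, 2)
```

    The two cat-like items share the first code while the cooking item does not, so a model that embeds these codes (rather than a hashed item ID) can share what it learned about one cat video with a brand-new one, which is one way this helps with the cold-start problem.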

    Together these point to a trend towards more foundation-model-like training in RecSys: models that tackle a variety of user personalization tasks with a unified view of users, content, and user/content interactions.

  • Dynamic Shapes in PyTorch

    Dynamic shapes are one of the more distinctive parts of torch.compile. Rather than specializing a graph to static shapes (which works in many cases!), PyTorch’s approach allows a single graph to work for a variety of sizes, so things like sequence length or batch size can vary. It does this by reasoning about shapes symbolically: instead of using fixed shape values, it uses placeholders and infers rules that constrain those shapes.

    Tracing & Symbolic Shapes

    PyTorch uses tracing in Python (via Dynamo) to capture the graph of operations. By default, it marks shapes as static during tracing. If a shape marked as static changes at runtime, it is marked as dynamic and treated symbolically in a recompilation. You can also proactively mark a dimension as dynamic to encourage symbolic treatment from the start:

    torch._dynamo.mark_dynamic(x, 0)  # Mark dim 0 as dynamic

    Under the hood, PyTorch uses SymPy to represent and manipulate symbolic shapes. Each dynamic shape is replaced with a SymInt and tracked by a central ShapeEnv.

    Every operation in PyTorch has a meta function — a lightweight implementation that computes metadata like shape changes without actually performing the computation. This lets PyTorch propagate symbolic shapes through the graph. For example, concatenating two tensors along dimension zero is represented symbolically as:

    out.size(0) = s0 + s1, where s0 = x.size(0) and s1 = y.size(0)
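    A toy model of this idea, written in plain Python rather than with PyTorch’s actual SymInt or ShapeEnv classes (ToySym and cat_meta are hypothetical names). A symbolic size carries both an expression and a concrete hint, and a shape-only “meta” version of cat computes the output shape without touching any data.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySym:
    """Hypothetical stand-in for a SymInt: a symbolic expression plus a concrete hint."""
    expr: str
    hint: int

    def __add__(self, other: "ToySym") -> "ToySym":
        # Adding sizes builds a new expression and adds the hints.
        return ToySym(f"{self.expr} + {other.expr}", self.hint + other.hint)


def cat_meta(x_shape: tuple, y_shape: tuple, dim: int = 0) -> tuple:
    """Shape-only cat: sizes add along `dim`; every other dim must match."""
    assert len(x_shape) == len(y_shape), "cat needs tensors of the same rank"
    for d, (xs, ys) in enumerate(zip(x_shape, y_shape)):
        if d != dim:
            assert xs == ys, f"dim {d} mismatch"
    out = list(x_shape)
    out[dim] = x_shape[dim] + y_shape[dim]
    return tuple(out)


s0, s1 = ToySym("s0", hint=3), ToySym("s1", hint=2)
out_shape = cat_meta((s0, 16), (s1, 16))
print(out_shape[0].expr, out_shape[0].hint, out_shape[1])  # s0 + s1 5 16
```

    The static dimension (16) stays a plain int, while the dynamic one propagates as an expression whose hint is still available for the branching described next.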

    To support branching logic, symbolic shapes carry a “hint” of their current concrete value. This allows specific branches of conditionals like if tensor.size(0) > 2 to be taken during tracing based on the hint. PyTorch adds a guard at this point to ensure that the resulting graph is only used if that branch is the correct one.

    Guards

    Guards are runtime checks, evaluated before a compiled graph is run, that ensure the assumptions made during tracing still hold. For example, in the case of tensor.size(0) > 2, if the tensor is the result of concatenating a and b, the guard will check that a.size(0) + b.size(0) > 2. If this fails, the code is retraced and a graph for the new branch is generated. Multiple graphs can be cached and selected at runtime based on which guards pass.
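    The guard-and-retrace loop can be mimicked with a toy cache. This is a simplified illustration, not Dynamo’s implementation, and ToyGuardedCache is a hypothetical name. Each “compiled” entry stores the guard recorded while tracing; a call reuses the first entry whose guard passes, and otherwise traces again.

```python
class ToyGuardedCache:
    """Hypothetical sketch of guard-checked reuse, one traced 'graph' per branch."""

    def __init__(self):
        self.entries = []  # (guard, graph) pairs, checked in insertion order
        self.traces = 0

    def _trace(self, a_len: int, b_len: int) -> str:
        self.traces += 1
        took_big_branch = a_len + b_len > 2  # branch chosen using the concrete sizes

        def guard(a: int, b: int) -> bool:
            # Reuse this graph only if the same branch would be taken.
            return (a + b > 2) == took_big_branch

        graph = "big-path graph" if took_big_branch else "small-path graph"
        self.entries.append((guard, graph))
        return graph

    def __call__(self, a_len: int, b_len: int) -> str:
        # a_len and b_len stand for a.size(0) and b.size(0) of the inputs to cat.
        for guard, graph in self.entries:
            if guard(a_len, b_len):
                return graph
        return self._trace(a_len, b_len)


f = ToyGuardedCache()
results = [f(2, 3), f(4, 1), f(1, 1), f(0, 2)]
print(results, f.traces)
# ['big-path graph', 'big-path graph', 'small-path graph', 'small-path graph'] 2
```

    Four calls with four different size pairs only trigger two traces, because the guards are symbolic (a + b > 2) rather than checks on exact sizes.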

    Guards don’t need to assert exact sizes; they can use symbolic constraints like x.size(0) > 2. This allows dimensions to vary within bounded ranges. The backend compiler (usually Inductor) can then compile code that operates over symbolic dimensions, as long as the variability is within the guarded constraints.

    For example, operations like broadcasting typically generalize well to symbolic shapes. In contrast, if an op specializes on a fixed shape (e.g., an optimized path for 1D input), it may require conditional tracing and guards.

    What this means in practice is that most of the time compilation will follow this process:

    • On the first batch, assume all shapes are static, insert guards on the exact sizes, and pass static sizes to the compiler
    • On the next batch, see which guards have been violated, mark those dimensions as dynamic, add appropriate guards, and pass symbolic dimensions to the compiler
    • Assuming no shape-dependent control flow adds new guards, keep reusing this dynamic graph for later batches without recompilation
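    That static-then-dynamic progression can be sketched as a toy. The real mechanism is Dynamo’s automatic dynamic shapes; ToyAutoDynamic and its fields are hypothetical. The first compile guards on exact sizes, a guard failure marks the changed dimension dynamic, and the recompiled entry only guards on the remaining static dims, so later sizes reuse it.

```python
class ToyAutoDynamic:
    """Hypothetical sketch of static-first compilation with automatic dynamic dims."""

    def __init__(self):
        self.dynamic_dims = set()
        self.entries = []  # one tuple per compile: exact size, or None for symbolic
        self.compiles = 0

    @staticmethod
    def _guards_pass(entry: tuple, shape: tuple) -> bool:
        return len(entry) == len(shape) and all(
            size is None or size == shape[d] for d, size in enumerate(entry))

    def __call__(self, shape: tuple) -> tuple:
        for entry in self.entries:
            if self._guards_pass(entry, shape):
                return entry  # cache hit: no recompilation
        # Guard failure: dims whose static size changed become dynamic from now on.
        for entry in self.entries:
            if len(entry) == len(shape):
                self.dynamic_dims |= {d for d, size in enumerate(entry)
                                      if size is not None and size != shape[d]}
        self.compiles += 1
        entry = tuple(None if d in self.dynamic_dims else size
                      for d, size in enumerate(shape))
        self.entries.append(entry)
        return entry


f = ToyAutoDynamic()
for batch in (8, 16, 32, 64):
    f((batch, 128))  # (batch, hidden): only the batch size varies
print(f.compiles, f.entries)  # 2 [(8, 128), (None, 128)]
```

    The hidden dimension never changes, so it stays static (and can still be specialized on by the backend), while the batch dimension becomes symbolic after the first recompile.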

    Backed vs. Unbacked SymInts

    Most symbolic shapes are backed, meaning they have an associated concrete value at trace time. These are usually derived from inputs and show up in traces as s0, s1, etc.

    Unbacked SymInts lack a concrete value. These arise from data-dependent operations, e.g.:

    n = (x > 1).nonzero().size(0)

    Here, n depends on the data in x, so its value cannot be known at trace time. It will be represented as an unbacked SymInt like u0.

    If a control flow decision depends on an unbacked SymInt, tracing cannot proceed, resulting in a graph break, or a GuardOnDataDependentSymNode error when fullgraph=True.
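    A toy illustration of why that happens (ToySymInt and DataDependentError are hypothetical stand-ins, not PyTorch classes). A backed symbol has a hint, so a branch on it can be decided during tracing; an unbacked one has no hint until the data is computed, so the comparison cannot be resolved, which is roughly where PyTorch raises GuardOnDataDependentSymNode. For simplicity the comparison returns a plain bool here, whereas PyTorch produces a SymBool that fails when converted to a Python bool.

```python
from typing import Optional


class DataDependentError(RuntimeError):
    """Stand-in for PyTorch's GuardOnDataDependentSymNode error."""


class ToySymInt:
    """Hypothetical symbolic size: backed symbols have a hint, unbacked ones don't."""

    def __init__(self, name: str, hint: Optional[int] = None):
        self.name, self.hint = name, hint

    def __gt__(self, other: int) -> bool:
        if self.hint is None:
            raise DataDependentError(f"cannot guard on {self.name} > {other}: no hint")
        return self.hint > other


def branchy(n: ToySymInt) -> str:
    return "big" if n > 2 else "small"


print(branchy(ToySymInt("s0", hint=5)))  # big: backed, decided by the hint
try:
    branchy(ToySymInt("u0"))  # unbacked, e.g. from nonzero().size(0)
    unbacked_error = None
except DataDependentError as e:
    unbacked_error = str(e)
print(unbacked_error)  # cannot guard on u0 > 2: no hint
```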

    However, you can guide the compiler with additional constraints, e.g.:

    torch._check(x.size(0) == y, lambda: f"size mismatch: {x.size(0)} != {y}")

    This lets PyTorch treat x.size(0) as equivalent to y throughout the graph. The check will be validated at runtime.

    There are other APIs to help mark unbacked SymInts as size-like to enable meta function compatibility (see Ed’s docs for more).

    Controlling Dynamic Shape Usage

    You can control dynamic behavior in torch.compile with the dynamic flag:

    • Not passed (dynamic=None): the default, automatic dynamic behavior described above
    • dynamic=False: force all shapes to be static
    • dynamic=True: treat all shapes as dynamic from the start

    The default is usually best, but dynamic=True can help in testing.

    Use fullgraph=True to require a single, complete graph: compilation raises an error instead of inserting a graph break. This is often critical for performance, as graph breaks can drastically affect runtime, and it’s easy to make innocuous-looking code changes that trigger additional breaks!