Category: til

Today I Learned, and other useful tips

  • Native DSL Ops in PyTorch

    You may have noticed that FlashAttention 4 was supported in PyTorch really quickly. That required a bit of new infrastructure: torch._native by Simon Layton. Prior versions of FlashAttention were written in CUTLASS/C++, but for FA4 the team implemented the kernel in CuteDSL.

    Edit: Simon kindly pointed out that the FA4 integration work predated his formalization of this pattern and was the impetus for it: long-time SDPA maintainer Driss landed the change. As always, PyTorch takes a village, and I am glad for everyone’s contribution!

    You wouldn’t think that using an embedded Python DSL in a Python-based ML framework would be a challenge, except that almost all of the code that does ML in PyTorch is in fact… not written in Python. Replacing a PyTorch operator meant shipping a new native kernel and dealing with the build and dispatch pipeline.

    Layton’s change opened the door to overriding default ops with ones authored in an embedded DSL, initially Triton or CuteDSL.

    To be clear, this is not a replacement for custom ops, which are most of the time the best way of adding a new operator. torch.library.triton_op already lets you register a custom Triton kernel, for example. But FA4 is the kind of situation where we needed an alternative: it’s the right path for newer GPUs, it’s written in CuteDSL, and the PyTorch team wanted it to be available quickly to all PyTorch users without modifying their models.

    To give an example, we can replace the built-in aten::_fused_rms_norm with a Triton version1:

    """
    Triton kernel for fused RMS normalization.
    RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
    """
    import triton
    import triton.language as tl
    import torch
    @triton.jit
    def _rms_norm_fwd_kernel(
    X_ptr,
    W_ptr,
    Y_ptr,
    RRMS_ptr, # reciprocal RMS, saved for backward
    stride_x_row,
    N_COLS: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    ):
    # [...]
    tl.store(RRMS_ptr + row_idx, rrms)
    def triton_rms_norm_forward(
    x: torch.Tensor,
    normalized_shape: list[int],
    weight: torch.Tensor | None,
    eps: float | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused RMSNorm forward pass using Triton."""
    # [...]
    return y.reshape(orig_shape), rrms

    Actually hooking it up requires calling a DSL-specific op override function, in this case triton_utils.register_op_override. This goes directly into the dispatch architecture, which means it works with autograd, torch.compile and so on.2

    """
    Register a Triton-based RMSNorm as a native op override using torch._native.
    """
    from torch._native import triton_utils
    def _triton_fused_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
    """
    Wrapper that lazily imports the Triton kernel on first call.
    """
    from triton_kernels import triton_rms_norm_forward
    return triton_rms_norm_forward(x, normalized_shape, weight, eps)
    def register():
    """Register the Triton RMSNorm override."""
    triton_utils.register_op_override(
    "aten", # lib_symbol: override an aten op
    "_fused_rms_norm", # op_symbol: the specific op
    "CUDA", # dispatch_key: only on CUDA
    _triton_fused_rms_norm, # impl: our wrapper
    unconditional_override=False, # receives dispatch_keys as first arg
    )

    Now when we call torch.ops.aten._fused_rms_norm(x, shape, weight, eps), PyTorch will automatically use our Triton override!3

    The unconditional_override param in the registration call is a helpful one: if False, the function receives a torch.DispatchKeySet as its first argument. This allows overriding only in specific circumstances. For example, our Triton kernel is faster than the C++ one only for larger shapes, so we could gate the decision on that:

    def _smart_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
        n_rows = x.numel() // normalized_shape[-1]
        if n_rows < 4096:
            # Fall back to the default C++ kernel for small shapes.
            return torch.ops.aten._fused_rms_norm.default(
                x, normalized_shape, weight, eps
            )
        from triton_kernels import triton_rms_norm_forward

        return triton_rms_norm_forward(x, normalized_shape, weight, eps)

    Going back to FlashAttention 4, this overrides aten::_scaled_dot_product_flash_attention so that any code using torch.nn.functional.scaled_dot_product_attention will transparently get the FA4 implementation.

    torch._native fundamentally lowers the barriers to entry for bringing new kernel implementations into PyTorch. That’s good for mainline PyTorch, and it also allows ML infrastructure teams to ship optimized kernels for new hardware without waiting for the PyTorch release cycle.

    1. Trimmed for length; ask your favorite coding agent to write you an RMSNorm kernel, or look at the gist. ↩︎
    2. To avoid hammering the import latency, all DSL runtimes are lazily loaded when the kernel is first called. ↩︎
    3. In this case we only override for CUDA tensors, so CPU ops will continue to be handled by the default implementation. ↩︎
  • TIL: You don’t need PIL for decoding images with TorchVision

    pytorch.org/vision/main/generated/torchvision.io.decode_image

    The always-busy Nicolas Hug was sharing this at work, and I hadn’t realized just how comprehensive the image decoding support has become in TorchVision. Over the last year TorchVision has added a lot of image decoding capabilities and gained a better entry API. It should generally be faster than PIL now (with the exception of animated GIFs).

    Rather than decoding with PIL:

    from PIL import Image
    # Load the image
    image_path = "chungus.png"
    image = Image.open(image_path)

    You can use the built-in decoders like this:

    from torchvision.io import read_file, decode_image
    # Load the image as a tensor
    image_path = "chungus.png"
    image_data = read_file(image_path)
    image_tensor = decode_image(image_data)

    TorchVision’s transforms support PIL transparently, so you might be using it without intending to! Relatedly, you’ll want to move to the v2 transforms if you happen to be on the older versions.

    In general this complements the release of TorchCodec which has been improving decoding for video – you now have a really good range of options for decoding media in a PyTorch native way!

  • Functionalization in PyTorch

    Functionalization in PyTorch: Everything You Wanted To Know – compiler – PyTorch Developer Mailing List

    Over a year old, but a very in-depth breakdown from Brian Hirsh of how AOTAutograd functionalizes – i.e. removes mutations from – various graphs, what that enables, and what kinds of edge cases exist. Inductor as a backend can handle mutation, but many other situations (including export!) can’t. It got bumped up because of a question on exactly that!

    torch.export uses functionalization. In particular, when you export for inference, you’ll get out a functionalized ATen graph!

  • TIL: weights-only model loading will be the default in PyTorch 2.6

    I had missed this, but weights-only is going to be the default for torch.load in PyTorch 2.6:

    https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573

    This is one of those small-sounding changes which requires quite a lot of follow-through to actually land. The default torch.load supports pickled Python code and so allows arbitrary code execution: very helpful in a lot of cases (hence the many places that need special consideration!), but, particularly these days when many users may be trying models of fairly unknown provenance, a source of ongoing security concerns. Making that behavior an explicit opt-in is a great win for the wider community. Hugging Face have done some good work in this area too with their safetensors project, and having the core safe-by-default is a very welcome addition!
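    A quick sketch of what the flip means in practice: with weights_only=True, torch.load uses a restricted unpickler that only reconstructs tensors and plain Python containers, so a state dict round-trips fine while arbitrary pickled objects are rejected.

```python
import io
import torch

state = {"w": torch.ones(2, 2), "step": 3}
buf = io.BytesIO()
torch.save(state, buf)
buf.seek(0)

# Tensors, primitives, and plain containers are allowed;
# anything needing arbitrary class construction raises instead.
loaded = torch.load(buf, weights_only=True)
print(loaded["step"])
```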

  • TIL: torchdbg

    https://github.com/ezyang/torchdbg

    Step-by-step debugging through a PyTorch program to see the underlying operators and shapes. Helpful for getting a view of the graph and shapes – just wrap the code in with torchdbg.LoggingMode(): and set TORCH_TRACE=./log to dump the logging file. Comes with a handy viewer.

  • TIL: TunableOp in PyTorch

    I wasn’t aware of this particular autotuning lever! There is a breakdown of TunableOp on the AMD blog from back in July:

    https://rocm.blogs.amd.com/artificial-intelligence/pytorch-tunableop/README.html

    Instead of using the default GEMMs, TunableOp will search for the best GEMMs for your specific environment. It does so by first querying the underlying BLAS library for a list of all solutions for a given GEMM, benchmarking each of them, and then selecting the fastest. TunableOp then writes the solutions to disk, which can then be used on subsequent runs. 

    Though the infrastructure is generic, this is effectively an AMD-specific tuning tool right now, as mentioned in the original docs.

    Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of PyTorch will function correctly when using TunableOp but the only solution available to CUDA builds is the ‘Default’ implementation i.e. the original cuBLAS default, now called through TunableOp.
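    If you want to try it, tuning is driven by environment variables; a minimal setup sketch based on the documented knobs (filename and paths here are illustrative):

```shell
# Enable TunableOp, let it tune (benchmark candidate GEMM solutions),
# and persist the selections so later runs just load them from disk.
export PYTORCH_TUNABLEOP_ENABLED=1
export PYTORCH_TUNABLEOP_TUNING=1
export PYTORCH_TUNABLEOP_FILENAME=tunableop_results.csv
python train.py  # first run tunes; subsequent runs reuse the CSV
```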

  • Chinese Tech Term Glossary

    Very interesting list, via Jeff Ding at ChinAI.

    from Chinese primary sources on technology and security, with expert translations and annotations by CSET’s translation team. It’s created and maintained by Ben Murphy with support from the Emerging Technology Observatory.

    https://docs.google.com/spreadsheets/d/15MS8Qp9U-KOaoQF0R_e7lVxF3UBMyZFpC6-cohhnsD0/htmlview#gid=0

  • TIL: When does PyTorch upgrade Python versions

    The policy is in the RELEASE documentation:

    PyTorch supports all minor versions of CPython that are not EOL: https://devguide.python.org/versions/

    This is a little more consistent than how it was handled in the past, with annual upgrades and deprecations to match the CPython release schedule.

  • Unsupported: dynamic shape operator: aten.nonzero.default with boolean masks in torch.compile

    The error message you get actually tells you the fix, but I found it non-intuitive enough, relative to what I was doing, that I was hesitant to just try the config:

    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    The general issue is that capturing ops with data-dependent output shapes isn’t turned on by default due to various issues, but for your case it may actually work. It is also interesting to see where TorchVision hit this and worked around it with torch.where instead.

  • TIL: ROCM is actually open source

    I think I did know this at some point, but I was reminded today that unlike the (sometimes) black box that is CUDA, https://github.com/ROCm/ROCm is actually available on GitHub, which is operationally much nicer!

    I also recall learning that NCCL, which is open (https://github.com/NVIDIA/nccl), is so in part because some of it was funded by Lawrence Berkeley National Laboratory!

  • TIL: conda cudatoolkit

    Every time I have to set up a clean system I manage to mess up CUDA somehow, so leaving this as an aide-mémoire. In general, there is a default nvidia-cuda-toolkit package that ships with Ubuntu-based systems, and you should ignore it. The right options are either:

    • conda install cudatoolkit
      • This gives you the same thing in a conda env, but is somewhat limited in terms of the range of versions available
    • nvidia-container-toolkit
      • Lets you run different cuda versions in docker containers, which can work for having completely isolated environments.

    The conda option generally works for me, but do install it before PyTorch: you will end up in a painful package resolution process if you forget and install PyTorch first.
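    For my own future reference, the order that avoids the solver fight (exact package names and channels vary by setup; treat this as a sketch):

```shell
conda create -n ml python=3.11
conda activate ml
conda install cudatoolkit          # toolkit first...
conda install pytorch -c pytorch   # ...then PyTorch resolves against it
```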

  • TIL: Numerical stability in TorchScript vs Export

    A very good passing comment from Nikita at work. When folks train in eager mode they are generally able to leverage the full set of aten ops, with their attendant implementations in different backends.

    Taking the trained model and exporting with TorchScript largely guaranteed the same results, assuming it TorchScripted cleanly, as it’s just using libtorch underneath. The downside is shipping a lot of stuff you probably don’t need, which is not great for LiteInterpreter on mobile in particular, plus of course all the TorchScript weirdness.

    torch.export offers a much cleaner flow, but on the flip side does a lot more processing of the model. The aten ops are decomposed into a simpler set of ~200 ops for the export IR, which might then be further processed for specific hardware. While you have a lot of control over this process, it means there are a lot of places to potentially introduce subtle differences between training time and inference, further training, or whatever you are doing next.

    The important thing here is not to treat the two technologies as straight substitutes, but as different paths to the same goal that require different processes around them.

  • TIL: How to measure memory usage from your PyTorch model without running it

    Fantastic tip from Alban, particularly useful when you have a giant model and limited VRAM.

    The short answer is this ~30-line TorchDispatchMode that tracks all Tensor memory use:

    https://dev-discuss.pytorch.org/t/how-to-measure-memory-usage-from-your-model-without-running-it/2024
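    The idea, roughly: a TorchDispatchMode sees every aten op, so it can tally the bytes of each tensor an op produces. A minimal sketch of that mechanism (the linked post’s version is more careful, tracking deallocation and using fake tensors so nothing is actually allocated):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


class MemoryTracker(TorchDispatchMode):
    """Tallies bytes of every tensor produced while the mode is active."""

    def __init__(self):
        super().__init__()
        self.total_bytes = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        flat, _ = tree_flatten(out)
        for t in flat:
            if isinstance(t, torch.Tensor):
                self.total_bytes += t.numel() * t.element_size()
        return out


with MemoryTracker() as tracker:
    x = torch.ones(1024, 1024)  # 4 MiB of float32
    y = x @ x                   # at least another 4 MiB
print(tracker.total_bytes)
```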

  • Timeline on the Google Reader shutdown

    Interesting to read this breakdown of what was happening with Reader: I recall the discussions internally at Google and knew some folks floating around it (including at least one person who volunteered to take on maintenance), but this was largely new to me. Uncomfortably enough, I was on stage giving a talk for Google the day the shutdown happened. I asked for questions, got a lot of hands, then qualified I was looking for questions about the talk and not Reader, and most of the hands went down.

    https://blog.persistent.info/2013/06/google-reader-shutdown-tidbits.html?m=1

  • DiLoCo

    https://arxiv.org/html/2311.08105v3

    Interesting paper from Google doing data-parallel training with an inner optimizer, then a central node that collects gradients every so often, optimizes, and shards back out. Somewhat federated-learning-like, but avoids full sync across a larger cluster. Important ideas as scaling starts exceeding network capabilities.

    standard approaches to training LLM require a large number of tightly interconnected accelerators, with devices exchanging gradients and other intermediate states at each optimization step. While it is difficult to build and maintain a single computing cluster hosting many accelerators, it might be easier to find several computing clusters each hosting a smaller number of devices. In this work, we propose a distributed optimization algorithm, Distributed Low-Communication (DiLoCo), that enables training of language models on islands of devices that are poorly connected. The approach is a variant of federated averaging, where the number of inner steps is large, the inner optimizer is AdamW, and the outer optimizer is Nesterov momentum
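    The outer loop is simple enough to sketch (my own toy rendition, not the paper’s code): each island runs inner AdamW steps on a copy of the model, and the averaged parameter delta is fed to an outer Nesterov-momentum SGD step as a pseudo-gradient.

```python
import copy
import torch


def diloco_round(model, shards, inner_steps=10, inner_lr=1e-3, outer_lr=0.7):
    # Inner phase: each island trains its own copy with AdamW.
    deltas = []
    for x, y in shards:
        local = copy.deepcopy(model)
        opt = torch.optim.AdamW(local.parameters(), lr=inner_lr)
        for _ in range(inner_steps):
            opt.zero_grad()
            torch.nn.functional.mse_loss(local(x), y).backward()
            opt.step()
        with torch.no_grad():
            deltas.append([p - q for p, q in zip(model.parameters(),
                                                 local.parameters())])
    # Outer phase: the averaged delta acts as a pseudo-gradient for
    # Nesterov SGD (in the real algorithm the outer optimizer state
    # persists across rounds rather than being recreated each time).
    outer = torch.optim.SGD(model.parameters(), lr=outer_lr,
                            momentum=0.9, nesterov=True)
    for p, ds in zip(model.parameters(), zip(*deltas)):
        p.grad = torch.stack(list(ds)).mean(dim=0)
    outer.step()


model = torch.nn.Linear(2, 1)
shards = [(torch.randn(16, 2), torch.randn(16, 1)) for _ in range(3)]
before = model.weight.detach().clone()
diloco_round(model, shards)
```

    Each "shard" here stands in for an island's local data; communication only happens once per round, which is the whole point.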