Tag: pytorch

  • TIL: You don’t need PIL for decoding images with TorchVision

    pytorch.org/vision/main/generated/torchvision.io.decode_image

    The always busy Nicolas Hug was sharing this at work, and I hadn’t realized just how comprehensive TorchVision’s image decoding support has become. Over the last year TorchVision has gained a lot of image decoding capabilities, along with a cleaner entry-point API, and it should now generally be faster than PIL (with the exception of animated GIFs).

    Rather than decoding with PIL:

    from PIL import Image
    # Load the image
    image_path = "chungus.png"
    image = Image.open(image_path)

    You can use the built in decoders like this:

    from torchvision.io import read_file, decode_image
    # Load the image as a tensor
    image_path = "chungus.png"
    image_data = read_file(image_path)
    image_tensor = decode_image(image_data)

    TorchVision’s transforms support PIL images transparently, so you might be using PIL without intending to! Relatedly, if you’re still on the older transforms, you’ll want to switch to the v2 versions.

    In general this complements the release of TorchCodec which has been improving decoding for video – you now have a really good range of options for decoding media in a PyTorch native way!

  • Functionalization in PyTorch

    Functionalization in PyTorch: Everything You Wanted To Know – compiler – PyTorch Developer Mailing List

    Over a year old, but a very in-depth breakdown from Brian Hirsh of how AOTAutograd functionalizes – i.e. removes mutations from – various graphs, what that enables, and what kinds of edge cases exist. Inductor as a backend can handle mutation, but many other situations (including export!) can’t. The thread got bumped because of a question on exactly that!

    torch.export uses functionalization. In particular, when you export for inference, you’ll get out a functionalized ATen graph!
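
    To make the idea concrete, here’s a minimal eager-mode sketch of my own (not from the thread) of what functionalization does to a program: every in-place op is replaced by an out-of-place op that produces a fresh tensor, so the resulting graph has no mutations but computes the same values.

```python
import torch

def f(x):
    x.add_(1)          # in-place mutation of the input
    return x * 2

def f_functional(x):
    x1 = x.add(1)      # functionalized: the mutation becomes a new tensor
    return x1 * 2

a, b = torch.zeros(3), torch.zeros(3)
out_mut, out_fun = f(a), f_functional(b)
print(out_mut, out_fun)   # same values...
print(a, b)               # ...but only f mutated its input
```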

  • TIL: weights-only model loading will be the default in PyTorch 2.6

    I had missed this, but weights-only is going to be the default for torch.load in PyTorch 2.6:

    https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573

    This is one of those small-sounding changes that requires a lot of follow-through to actually land. The default torch.load supports pickled Python code, and so allows arbitrary code execution: very helpful in a lot of cases (hence the many places that need special consideration!), but also a source of ongoing security concerns, particularly these days when many users are trying models of fairly unknown provenance. Making that behavior an explicit opt-in is a great win for the wider community. Hugging Face have done good work in this area too with their safetensors project, and having the core library be safe by default is a very welcome addition!
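
    Concretely, plain tensors and state dicts keep working under the new default, while anything relying on full pickle needs an explicit opt-out. A quick sketch of the behavior:

```python
import os
import tempfile
import torch

# Plain tensors and state dicts load fine under the safer default.
path = os.path.join(tempfile.mkdtemp(), "weights.pt")
torch.save({"w": torch.ones(3)}, path)
state = torch.load(path, weights_only=True)  # explicit today, the default from 2.6
print(state["w"])

# Files containing arbitrary pickled objects now need an explicit opt-out:
# torch.load(path, weights_only=False)  # only for checkpoints you trust
```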

  • TIL: torchdbg

    https://github.com/ezyang/torchdbg

    Step through a PyTorch program and see the underlying operators and shapes. Helpful for getting a view of the graph and shapes – just wrap the code in with torchdbg.LoggingMode(): and set TORCH_TRACE=./log to dump the logging file. Comes with a handy viewer.

  • Ways to use torch.export

    Ways to use torch.export

    Part 2 of the series covers export, AOTInductor for getting runnable models on the server, and ExecuTorch for runnable models on edge devices (phones, wearables, etc.). There are a number of good examples from real-world experience of how to use these tools as well. As usual I learned about something I didn’t know existed, in this case intermediate tensor logging in AOTI:

    AOTInductor has an option to add dumps of intermediate tensor values in the compiled C++ code. This is good for determining, e.g., the first time where a NaN shows up, in case you are suspecting a miscompilation.

  • TIL: TunableOp in PyTorch

    I wasn’t aware of this particular autotuning lever! There is a breakdown of TunableOp on the AMD blog from back in July:

    https://rocm.blogs.amd.com/artificial-intelligence/pytorch-tunableop/README.html

    Instead of using the default GEMMs, TunableOp will search for the best GEMMs for your specific environment. It does so by first querying the underlying BLAS library for a list of all solutions for a given GEMM, benchmarking each of them, and then selecting the fastest. TunableOp then writes the solutions to disk, which can then be used on subsequent runs. 

    Though the infrastructure is generic, this is effectively an AMD-specific tuning tool right now, as mentioned in the original docs.

    Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of PyTorch will function correctly when using TunableOp but the only solution available to CUDA builds is the ‘Default’ implementation i.e. the original cuBLAS default, now called through TunableOp.
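
    As a sketch of how you’d turn it on: per the torch.cuda.tunable docs, TunableOp is driven by environment variables, which need to be set before the first GEMM is dispatched.

```python
import os

# These must be set before the first GEMM runs (see the torch.cuda.tunable docs).
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"   # turn TunableOp on
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"    # benchmark and record solutions on this run
os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "tunableop_results.csv"  # where solutions are saved
```

    On subsequent runs the recorded solutions file is read back, so the benchmarking cost is only paid once per environment.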

  • nonzero_static in PyTorch

    Natalia Gimelshein added CUDA support for nonzero_static to PyTorch the other day — it’s a feature JAX has had for a while that lets you avoid a bit of data-dependent annoyingness.

    I most often see nonzero pop up in logs: it underlies a lot of boolean mask operations and torch.where. A downside of torch.nonzero is that we don’t know the size of the returned tensor ahead of time. The shape is data dependent, which causes pain for the compiler, and when running on an accelerator it requires a device-to-host sync. This can significantly slow down otherwise fast operations.

    nonzero_static overcomes this by letting you supply a shape for the output — if the actual result is smaller it is padded; if larger, it is truncated. The PR enables this for CUDA. You can’t use it seamlessly via bool masks or torch.where, but you can easily replace a call to nonzero with one to nonzero_static.

    To see the difference, imagine we have a big tensor where we have some sense of the output shape, for example searching for a one-hot index like this:

    import torch
    
    # Set device
    device = torch.device("cuda:0")
    
    # Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
    size = 100_000_000
    large_tensor = torch.zeros(size, device=device, dtype=torch.long)
    large_tensor[size // 2] = 1  # Place a single 1 in the middle of the tensor
    
    # Record GPU events to ensure asynchronicity
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    start_event.record()
    ones_indices = torch.nonzero(large_tensor)
    end_event.record()
    torch.cuda.synchronize()
    
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"torch.nonzero() execution time: {elapsed_time_ms:.2f} ms")
    print("Indices of 1s:", ones_indices.cpu())  

    That gives us (on my laptop)

    torch.nonzero() execution time: 177.58 ms
    Indices of 1s: tensor([[50000000]])

    If we modify it to use nonzero_static:

    import torch
    
    # Set device
    device = torch.device("cuda:0")
    
    # Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
    size = 100_000_000
    large_tensor = torch.zeros(size, device=device, dtype=torch.long)
    large_tensor[size // 2] = 1  # Place a single 1 in the middle of the tensor
    
    # Record GPU events to ensure asynchronicity
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    start_event.record()
    ones_indices = torch.nonzero_static(large_tensor, size=1)
    end_event.record()
    torch.cuda.synchronize()
    
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"torch.nonzero_static() execution time: {elapsed_time_ms:.2f} ms")
    print("Indices of 1s:", ones_indices.cpu())  
    
    torch.nonzero_static() execution time: 78.53 ms
    Indices of 1s: tensor([[50000000]])

    A very nice improvement!

  • Ways to use torch.compile

    Ways to use torch.compile

    The dream we sold with torch.compile is that you could slap it on top of your model and get a speedup. This turns out to… not quite be true? But the fact remains that if you’re willing to put in some work, there is almost always performance waiting at the end of the road for you. The post collects some tips.

  • PyTorch while_loop

    I’ve been following the development of the higher order ops in PyTorch nightlies for a little bit, and got a chance to try out while_loop. The best examples right now are in the tests, but as another, here’s a Mandelbrot example:

    import torch
    from torch._higher_order_ops.while_loop import while_loop
    import matplotlib.pyplot as plt
    
    def mandelbrot_step(z, c):
        """Performs one iteration of the Mandelbrot sequence."""
        return z**2 + c
    
    def mandelbrot(c, max_iter, threshold):
        """Compute Mandelbrot set membership for a grid of complex numbers."""
        def cond_fn(z, iter_count, mask):
            return torch.any(mask & (iter_count < max_iter))
    
        def body_fn(z, iter_count, mask):
            z_next = mandelbrot_step(z, c)
            diverged = torch.abs(z_next) > threshold
            mask_next = mask & ~diverged
            iter_count_next = iter_count + mask_next
            return z_next, iter_count_next, mask_next
    
        # Initialize variables
        z0 = torch.zeros_like(c)
        iter_count = torch.zeros(c.shape, dtype=torch.int32)
        mask = torch.ones(c.shape, dtype=torch.bool)  # All points start as candidates
        final_state = while_loop(cond_fn, body_fn, (z0, iter_count, mask))
        
        _, iterations, _ = final_state
        return iterations
    
    # Define the grid of complex numbers
    x = torch.linspace(-2.0, 1.0, 500)
    y = torch.linspace(-1.5, 1.5, 500)
    xx, yy = torch.meshgrid(x, y, indexing="ij")  # explicit indexing avoids the deprecation warning
    complex_grid = xx + 1j * yy
    
    # Compute the Mandelbrot set
    max_iter = 100
    threshold = 2.0
    mandelbrot_set = mandelbrot(complex_grid, max_iter, threshold)
    
    # Plot the Mandelbrot set
    plt.figure(figsize=(10, 10))
    plt.imshow(mandelbrot_set, extent=(-2, 1, -1.5, 1.5), cmap="inferno")
    plt.colorbar(label="Iteration count")
    plt.title("Mandelbrot Set")
    plt.xlabel("Real")
    plt.ylabel("Imaginary")
    plt.show()

    In general, the only non-obvious things about while_loop are that cond_fn returns a tensor, not a bool — so make sure you get your types right — and that the carried shapes must stay consistent from iteration to iteration. If you need more accumulator-style behavior, look at scan!

  • Unsupported: dynamic shape operator: aten.nonzero.default with boolean masks in torch.compile

    The error message you get actually tells you the fix, but it felt non-obvious enough, relative to what I was doing, that I was hesitant to just try the config:

    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    The general issue is that capturing ops with data-dependent output shapes isn’t turned on by default due to various issues, but for your case it may work just fine. It is also interesting to see where TorchVision hit this and worked around it with torch.where instead.

  • TIL: Numerical stability in TorchScript vs Export

    A very good passing comment from Nikita at work. When folks train in eager mode they are generally able to leverage the full set of aten ops, with their attendant implementations in different backends.

    Taking the trained model and exporting with TorchScript largely guaranteed the same results, assuming it TorchScripted cleanly, as it’s just using libtorch underneath. The downside is shipping a lot of stuff you probably don’t need, which is not great for the LiteInterpreter on mobile in particular, plus of course all the TorchScript weirdness.

    torch.export offers a much cleaner flow, but on the flip side it does a lot more processing of the model. The aten ops are decomposed into a simpler set of ~200 ops for the export IR, which might then be further processed for specific hardware. While you have a lot of control over this process, it means there are a lot of places to introduce subtle differences between training time and inference, further training, or whatever you are doing next.

    The important thing here is not to treat the two technologies as straight substitutes but as different paths to the same goal, that require different processes around them.

  • TorchFT: Fault tolerant training

    https://github.com/pytorch-labs/torchft

    The repo behind Tristan Rice and Chirag Pandya’s poster at the PyTorch conference has been continually updated with refinements and improvements.

    torchft implements a lighthouse server that coordinates across the different replica groups and then a per replica group manager and fault tolerance library that can be used in a standard PyTorch training loop. This allows for membership changes at the training step granularity which can greatly improve efficiency by avoiding stop the world training on errors.

    There are lots of clever techniques in here. They center around the idea of having replica groups serve as the failure domain, rather than the whole training job. In a loose sense, this means that when there is a failure you simply drop the replica group it’s in, carry on with the rest to the next batch, and add the replica group back in when it’s recovered.

    To make that possible, there’s custom comms that allow for error handling, health monitoring of individual processes, and fast checkpointing that allows recovered workers to be quickly added back to replica sets.

  • Alphafold3 PyTorch

    The always prolific Lucidrains has been implementing AlphaFold3 in PyTorch:

    https://github.com/lucidrains/alphafold3-pytorch

    This illustrated explainer is a great breakdown of the model in general:

    https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/

  • TIL: How to measure memory usage from your PyTorch model without running it

    Fantastic tip from Alban, particularly useful when you have a giant model and limited VRAM.

    The short answer is this ~30-line TorchDispatchMode that tracks all Tensor memory use:

    https://dev-discuss.pytorch.org/t/how-to-measure-memory-usage-from-your-model-without-running-it/2024
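
    As a flavor of the approach — this is a rough sketch of my own, not Alban’s exact code — a TorchDispatchMode can intercept every op and sum the storage bytes of the tensors it returns; the linked post combines this idea with fake tensors so nothing is actually allocated.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

class MemoryTrackingMode(TorchDispatchMode):
    """Rough sketch: add up the storage bytes of every tensor an op returns.
    (The linked post pairs this with fake tensors so no real memory is used.)"""
    def __init__(self):
        super().__init__()
        self.total_bytes = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        flat, _ = tree_flatten(out)
        for t in flat:
            if isinstance(t, torch.Tensor):
                self.total_bytes += t.untyped_storage().nbytes()
        return out

with MemoryTrackingMode() as mode:
    x = torch.randn(1024, 1024)   # 4 MiB of float32
    y = x @ x                     # another 4 MiB for the result

print(f"~{mode.total_bytes / 2**20:.0f} MiB of tensor storage allocated")
```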

  • Inductor notes

    Inductor is PyTorch’s compiler backend, designed to optimize and generate high-performance code for arbitrary models. It works in a few phases:

    1. AOTAutograd: Capturing the Forward and Backward Passes

    The graph of ops is traced while executing the forward and backward passes, wrapped in torch.fx.GraphModule containers, and shapes and sizes are made symbolic using SymPy.

    Operators are decomposed to a standard IR: either ATen IR, or further down into a simpler set of ops (Prims IR) for backends like Inductor that can do their own fusions. Backends can specify their own decompositions as well; they’re passed when hooking up to AOTAutograd:

    prims_decomp = torch._decomp.get_decompositions([
        torch.ops.aten.add,
        torch.ops.aten.expand.default,
    ])

    2. Inductor Lowering

    This is the start of Inductor itself as a backend, and it begins by converting the ATen IR into a Python-based define-by-run IR. Define-by-run means it allows dynamic execution – the same IR ops can do different things in different passes.

    In this process Inductor:

    • Eliminates views (operations that do not change the underlying data but modify how tensors are accessed).
    • Removes broadcasting overhead by explicitly adjusting tensor shapes.
    • Simplifies indexing patterns to enable more efficient execution.
    • Does classic compiler things like dead code elimination

    3. Scheduling

    The scheduler plans the execution of operations to optimize performance. It does vertical fusion (operations along the graph) and horizontal fusion (operations across different tensors), sets up tiling, and uses reductions for sums, averages, etc.

    This phase also does autotuning – profiling multiple implementations of ops to select the best one – plus memory planning to avoid bottlenecks, and so on.

    4. Code Generation

    Finally, Inductor generates executable code for the target hardware. It can choose between multiple backends, including Triton, OpenMP, CUTLASS, ROCm, XPU (for Intel GPUs), and others.

    It also generates kernel wrappers to handle memory allocation and orchestration.

    This code lives under torch/_inductor/codegen.