Tag: python

  • Constraints & Orchestrators

    I recently read a few posts that helped connect the dots on why Python is (a) so successful as the lingua franca of ML, and (b) likely to stay successful in the future¹.

    ML code reads like one program, but runs many: CUDA kernels, vectorized CPU loops, graph compilers and a bunch of glue code moving data around and tying things together. Python has continually improved at balancing two somewhat competing challenges: constraining the hot path so compilers can optimize it and structuring an orchestration path so humans can reason about it.

    Hot Path

    “constrained languages are easier to optimize” by Jynn Nelson touches on this:

    we should not be asking “what language can i use everywhere for every purpose”; we should build meta-languages that allow you to easily use the right tool for the job. this is already true for regular expressions and query languages; let’s go further. i want inline futhark; inline CSS selectors; inline datalog; ffi between python and C that’s trivially easy. the easier we make it to interop, the easier it becomes to pick the right tool for the job.

    Compilers generally perform better when you have regular shapes, minimal side effects, predictable memory access and so on, but you want languages to be expressive and flexible, particularly when “research” is a big part of the work. In practice, that split is exactly what happens in ML: torch.compile lowers PyTorch graphs to an IR and (often) emits Triton kernels. Being able to hand off inner loops to specialized languages lets compilers and runtimes optimize for the use cases they are best at.

    While this is (somewhat) clear for GPUs or other accelerators with distinctive programming models, I think it’s also largely true for getting the best out of modern CPUs. Daniel Lemire’s SEA 2025 talk covers nearly a decade of performance work and sums it up: modern CPUs can execute nearly as many instructions per cycle as you can feed them. To really maximize performance you need to batch work, reduce instruction counts and vectorize. We can do some of that in the general Python² runtime, but dynamic dispatch, aliasing and side effects all make the job a lot harder. We can add speculative guards, which can be hard to reason about, or give up and lose performance. By using DSLs³ that add extra constraints, we can get much, much higher performance without sacrificing the overall flow of our program.
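    To make the constraint argument concrete, here is a toy sketch (not any real library's API; Expr, Var, Const, BinOp, emit and fuse are hypothetical names). Because this tiny DSL only allows elementwise +, * and constants over equal-length inputs, a whole expression can be lowered into a single fused loop with no intermediate results, which is the kind of rewrite that is hard to do safely on arbitrary Python.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Union

Number = Union[int, float]


class Expr:
    """Base node of a toy elementwise DSL: only +, * and constants are allowed."""

    def __add__(self, other: Expr | Number) -> Expr:
        return BinOp("+", self, lift(other))

    def __mul__(self, other: Expr | Number) -> Expr:
        return BinOp("*", self, lift(other))

    # Both ops are commutative, so the reflected versions can reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


@dataclass(frozen=True)
class Var(Expr):
    name: str  # assumed: a valid identifier other than _i / _len


@dataclass(frozen=True)
class Const(Expr):
    value: float


@dataclass(frozen=True)
class BinOp(Expr):
    op: str
    lhs: Expr
    rhs: Expr


def lift(x: Expr | Number) -> Expr:
    return x if isinstance(x, Expr) else Const(float(x))


def free_vars(e: Expr) -> list[str]:
    """Input names in first-use order, without duplicates."""
    if isinstance(e, Var):
        return [e.name]
    if isinstance(e, BinOp):
        left = free_vars(e.lhs)
        return left + [n for n in free_vars(e.rhs) if n not in left]
    return []


def emit(e: Expr) -> str:
    """Lower the whole tree to ONE scalar expression for element _i."""
    if isinstance(e, Var):
        return f"{e.name}[_i]"
    if isinstance(e, Const):
        return repr(e.value)
    if isinstance(e, BinOp):
        return f"({emit(e.lhs)} {e.op} {emit(e.rhs)})"
    raise TypeError(f"unsupported node: {e!r}")


def fuse(e: Expr) -> Callable[..., list[float]]:
    """Generate a single-pass kernel: no temporaries between operations."""
    names = free_vars(e)  # assumes the expression has at least one input
    args = ", ".join(names)
    src = (
        f"def kernel({args}):\n"
        f"    _len = len({names[0]})\n"
        f"    assert all(len(v) == _len for v in ({args},)), 'shapes must match'\n"
        f"    return [{emit(e)} for _i in range(_len)]\n"
    )
    namespace: dict[str, object] = {}
    exec(src, namespace)
    return namespace["kernel"]  # type: ignore[return-value]
```

    For example, fuse(Var("x") * 2 + Var("y")) generates one comprehension over ((x[_i] * 2.0) + y[_i]), whereas an eager, op-by-op evaluation would first materialize a temporary for x * 2. Real DSLs like Triton go much further (tiling, vectorization, hardware-specific codegen), but the enabling idea is the same: the constraints are what make the rewrite safe.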

    Orchestration Path

    Python is unusually good as an orchestrator. The language is very readable to begin with, and as long as libraries and DSLs stay Pythonic they tend to inherit that readability. The challenge with orchestration is coordinating work so that your most precious resources stay well utilized. The investments in Free-Threaded Python make concurrency a lot cheaper, but they don’t magically solve the coordination problem.

    “asyncio: a library with too many sharp corners” covers some of the many failure modes the community has encountered with asyncio, and makes a case for Trio- or AnyIO-style structured concurrency, which keeps failure modes manageable.

    asyncio is not a good library. It is constantly full of sharp edges everywhere with implementation details leaking and poorly designed APIs forcing end users into odd code patterns to avoid fundamental flaws in the interfaces.

    This is very much the readability version of the constraints argument on the hot path. Threads over shared mutable state are a poor application-level abstraction: reasoning about races and cancellation is hard, and the primitives are always leaky. But threads are a perfectly fine implementation detail behind a more constrained API, such as task groups or actors.

    One area that I do think needs sustained improvement is how we debug and trace across this kind of setup: it’s been challenging, even in a controlled environment, to really understand how all the pieces interact in an ML workload of reasonable scale, and I imagine that problem will only get worse. But I also expect the flexibility and breadth of Python to end up being a boon there as well.

    1. Beyond just sheer momentum, of course. ↩︎
    2. Or any language! Certainly for some optimizations having a JIT for Python would (and does) make life easier. ↩︎
    3. Whether that is an embedded JIT like Triton or a library+execution engine like Polars. ↩︎
  • Cute-DSL

    In May Nvidia shipped CuTe‑DSL, the Python library they teased at GTC earlier in the year that mirrors CUTLASS’s C++ tensor‑layout abstractions. Then, at the start of June, the ‑dev label disappeared (so presumably it’s production-ready now). The pitch is simple: write speed‑of‑light kernels from the comfort of Python.

    Of course, nothing about CUDA is ever really simple. CuTe‑DSL gives the full CUTLASS experience¹, wrapped in an ever so slightly more approachable interface.

    Getting Cute: Transpose

    Matrix transpose felt like a reasonable ‘hello world’: B[j,i] = A[i,j]. The PyTorch version is simple: torch.transpose(input, 0, 1).

    To get a baseline, here is a simple transpose kernel in Triton. We tl.load a tile, swap the coordinates, and tl.store it back.

    import triton
    import triton.language as tl

    @triton.jit
    def triton_transpose_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        # 2D block coordinates
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        # Calculate offsets
        offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

        # Load a (BLOCK_SIZE x BLOCK_SIZE) tile of row-major A (M x N), with masking
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        a = tl.load(input_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)

        # Store transposed: A[m, n] goes to B[n, m], i.e. offset n * M + m in the
        # row-major N x M output. Keeping the (m, n) orientation lets us reuse `mask`.
        tl.store(output_ptr + offs_n[None, :] * M + offs_m[:, None], a, mask=mask)
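    To make the indexing concrete, here is a plain-Python CPU emulation of what the kernel's program instances do together over flat row-major buffers. It assumes a launch grid of (cdiv(M, BLOCK_SIZE), cdiv(N, BLOCK_SIZE)), which isn't shown above, and emulate_triton_transpose is a hypothetical helper, not part of Triton.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def emulate_triton_transpose(a: list, M: int, N: int, block: int) -> list:
    """Run the transpose kernel's indexing on the CPU over flat row-major buffers."""
    out = [0] * (M * N)
    # The launch grid: one program per (pid_m, pid_n) tile.
    for pid_m in range(ceil_div(M, block)):
        for pid_n in range(ceil_div(N, block)):
            # Inside one program: every (i, j) lane of the BLOCK x BLOCK tile.
            for i in range(block):
                for j in range(block):
                    m, n = pid_m * block + i, pid_n * block + j
                    if m < M and n < N:  # the load/store mask
                        out[n * M + m] = a[m * N + n]
    return out
```

    Lanes that fall off the edge of the matrix are simply skipped, which is what the mask does on the GPU.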

    Here’s the same idea in CuTe‑DSL. CuTe leans on decorators and Python’s ability to host JITs: anything decorated with @cute.jit runs host-side (on the CPU), while @cute.kernel marks device-side code (on the GPU). There are both AST-based and tracing-based options, with the choice depending on the presence of dynamic shapes or control flow.

        @cute.kernel
        def transpose_kernel(self, mA: cute.Tensor, mB: cute.Tensor):
            tidx = cute.arch.thread_idx()[0]
            tidy = cute.arch.thread_idx()[1]
            bidx = cute.arch.block_idx()[0]
            bidy = cute.arch.block_idx()[1]
            # This might all be unnecessary
            # but I was fearful of the compiler
            tile_start_m = cutlass.Int32(0)
            tile_start_n = cutlass.Int32(0)
            global_m = cutlass.Int32(0)
            global_n = cutlass.Int32(0)
            M = cutlass.Int32(0)
            N = cutlass.Int32(0)
            val = cutlass.Float32(0.0)
            # Calculate tile starting positions
            tile_start_m = bidy * self._tile_size
            tile_start_n = bidx * self._tile_size
            # Calculate global coordinates for this thread
            global_m = tile_start_m + tidy
            global_n = tile_start_n + tidx
            # Get matrix dimensions at runtime
            M = mA.shape[0]
            N = mA.shape[1]
            # Bounds checking and transpose operation
            if global_m < M and global_n < N:
                val = mA[global_m, global_n]
                # Transpose: B[n, m] = A[m, n]
                mB[global_n, global_m] = val

    What just happened?

    • Thread and block indices come straight from CUDA (thread_idx, block_idx), rather than Triton’s block-level program abstraction
    • No explicit loads or stores: CuTe uses overloaded [] to generate them.

    Launching isn’t a million miles away from Triton:

    @cute.jit   # host side
    def launch(self, A: cute.Tensor, B: cute.Tensor):
        M, N = A.shape
        T = self._tile_size  # same tile size the kernel uses
        # Grid x covers the N columns, grid y covers the M rows.
        grid = ((N + T - 1) // T,
                (M + T - 1) // T, 1)
        self.transpose_kernel(A, B).launch(
            grid=grid,
            block=[T, T, 1],
        )

    Because CuTe‑DSL speaks DLPack, you can hand it a PyTorch tensor directly. If you wanted to cache the conversion, it looks like this:

    from cutlass.cute.runtime import from_dlpack

    A_cute = from_dlpack(A).mark_layout_dynamic()

    mark_layout_dynamic() turns on dynamic-shape support and avoids specializing the kernel on a particular shape. The one place where this went a bit funky in my testing was singleton (size-1) leading dimensions: there you need to be more explicit about the shape to satisfy the compiler.

    Layouts and Memory

    This kernel doesn’t really leverage the fundamental value of CuTe, though: composable tensor layouts and memory management. CuTe‑DSL exposes the full memory hierarchy (global, shared and register memory, plus tensor memory, TMEM, for those with Blackwell GPUs) and lets you tile, copy, and pipeline data between them. Common primitives:

    • make_layout / make_layout_tv: describe how a tensor is laid out.
    • cute.zipped_divide(tensor, tiler): tile a tensor.
    • cute.copy(copy_atom, src, dst, pred=mask): copy between tensors using a copy atom or tiled copy (which may be async), with optional predication.
    • cute.arch.sync_threads(): explicit barrier.
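    The core idea behind those layouts fits in a few lines: a CuTe layout pairs a shape with a stride (written shape:stride) and maps a logical coordinate to a memory offset. Below is a toy pure-Python model (Layout, row_major, transpose and tile are hypothetical names, not the CuTe‑DSL API) showing why that composes nicely: a transpose or a tile is just a new layout over the same memory, with no data movement.

```python
from __future__ import annotations

from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Layout:
    """Toy CuTe-style layout: maps a logical coordinate to a memory offset."""
    shape: tuple[int, ...]
    stride: tuple[int, ...]
    offset: int = 0

    def __call__(self, *coord: int) -> int:
        assert len(coord) == len(self.shape)
        assert all(0 <= c < s for c, s in zip(coord, self.shape)), "out of bounds"
        return self.offset + sum(c * d for c, d in zip(coord, self.stride))

    def offsets(self) -> list[int]:
        # Enumerate every coordinate (last mode fastest) and map it to memory.
        return [self(*c) for c in product(*(range(s) for s in self.shape))]


def row_major(m: int, n: int) -> Layout:
    return Layout((m, n), (n, 1))  # (m, n):(n, 1)


def transpose(layout: Layout) -> Layout:
    # A transpose is just a new layout over the same memory: swap the modes.
    return Layout(layout.shape[::-1], layout.stride[::-1], layout.offset)


def tile(layout: Layout, tile_shape: tuple[int, int], tile_coord: tuple[int, int]) -> Layout:
    # Tiling keeps the strides and moves the base offset to the tile's corner.
    # Assumes tile_shape evenly divides the layout's shape (no partial tiles).
    corner = tuple(t * s for t, s in zip(tile_coord, tile_shape))
    return Layout(tile_shape, layout.stride, layout(*corner))
```

    Reading A's memory through transpose(row_major(3, 5)) visits offsets 0, 5, 10, 1, 6, 11, ..., which is exactly the transposed order, without a copy. CuTe's real layouts are hierarchical and support much richer algebra (composition, zipped_divide, and so on); this only models the flat case.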

    HGEMMony²

    CuTe ships with some example kernels, so I grabbed one, a batched HGEMM (half-precision FP16 GEMM), and compared it to an example implementation in Triton.

    To express the same thing in PyTorch, we can unleash our inner Jeremy Howard and use einsum notation: torch.einsum("mkl,nkl->mnl", a, b). That takes L batches of an MxK matrix and L batches of an NxK matrix, and returns L batches of an MxN matrix, summing over K.
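    As a reference for what the kernels below compute, here is a naive pure-Python version of that einsum (batched_gemm_ref is a hypothetical helper). It uses Python floats (FP64) throughout, so it only pins down the indexing, not the FP16-in, FP32-accumulate numerics of the real kernels.

```python
def batched_gemm_ref(a: list, b: list) -> list:
    """C[m][n][l] = sum_k A[m][k][l] * B[n][k][l], i.e. einsum('mkl,nkl->mnl')."""
    M, K, L = len(a), len(a[0]), len(a[0][0])
    N = len(b)
    assert len(b[0]) == K and len(b[0][0]) == L, "B must be N x K x L"
    c = [[[0.0] * L for _ in range(N)] for _ in range(M)]
    for l in range(L):          # each batch is independent
        for m in range(M):
            for n in range(N):
                acc = 0.0       # per-element accumulator, like the kernels
                for k in range(K):
                    acc += a[m][k][l] * b[n][k][l]
                c[m][n][l] = acc
    return c
```

    Per batch this is A @ Bᵀ, which is why both kernels below load B "transposed".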

    Here is the Triton version:

    import triton
    import triton.language as tl

    @triton.jit
    def triton_batched_hgemm_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K, L,
        stride_am, stride_ak, stride_al,
        stride_bn, stride_bk, stride_bl,
        stride_cm, stride_cn, stride_cl,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        """Triton batched half-precision GEMM kernel: C[m,n,l] = sum_k A[m,k,l] * B[n,k,l]"""
        pid = tl.program_id(axis=0)
        pid_batch = tl.program_id(axis=1)  # Batch dimension
        
        # Calculate batch offsets
        batch_offset_a = pid_batch * stride_al
        batch_offset_b = pid_batch * stride_bl  
        batch_offset_c = pid_batch * stride_cl
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        
        # Include batch offsets in pointer calculations
        a_ptrs = a_ptr + batch_offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        # transpose for GEMM (load B[K, N] pattern)
        b_ptrs = b_ptr + batch_offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        
        # We accumulate into fp32 for higher accuracy.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Load the next block of A and B, mask in K
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # Convert back to FP16 for output
        c = accumulator.to(tl.float16)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + batch_offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    The core loop here is:

    • Split K into chunks of BLOCK_SIZE_K
    • For each do a masked load (so when we hit the edges we don’t bring in garbage)
    • Do the dot product for the tile into an accumulator for the result matrix
    • Advance the A and B pointers
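    Those four steps can be sketched in plain Python for a single batch and output tile (gemm_k_tiled and masked_load are hypothetical names). masked_load plays the role of tl.load(..., mask=..., other=0.0): lanes past the end of K read 0.0, so the last, partial chunk doesn't change the sum.

```python
def masked_load(row: list[float], start: int, block: int) -> list[float]:
    # Like tl.load(..., mask=offs_k < K, other=0.0): out-of-range lanes read 0.0.
    return [row[start + i] if start + i < len(row) else 0.0 for i in range(block)]


def gemm_k_tiled(a: list[list[float]], b: list[list[float]], block_k: int) -> list[list[float]]:
    """C = A @ B^T for A (M x K) and B (N x K), walking K in block_k chunks."""
    M, N = len(a), len(b)
    K = len(a[0])
    acc = [[0.0] * N for _ in range(M)]  # the accumulator tile
    for k0 in range(0, K, block_k):      # cdiv(K, block_k) iterations; k0 "advances the pointers"
        a_tile = [masked_load(a[m], k0, block_k) for m in range(M)]
        b_tile = [masked_load(b[n], k0, block_k) for n in range(N)]
        for m in range(M):               # the tl.dot for this chunk
            for n in range(N):
                acc[m][n] += sum(x * y for x, y in zip(a_tile[m], b_tile[n]))
    return acc
```

    Changing block_k changes how the work is chunked but not the result, which is what makes BLOCK_SIZE_K a pure tuning knob (up to floating-point summation order on real hardware).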

    The CuTe kernel is, uhhh… a bit more involved: the full kernel is several hundred lines long. You can see the source in the CUTLASS repo, which demonstrates some cool features like the ability to pass in an epilogue function for fusion.

    For now, let’s focus on the main loop. As before, we are looping over tiles of K³.

    The first thing we do is move the shared-memory views on to the next pipeline stage. The kernel explicitly pipelines transfers from global memory to shared memory to registers, so we need to wait on the outstanding async-copy groups (and sync the threads) to ensure the loading has completed before we advance. We’re not actually doing any loading in this section, just prepping the ground:

    for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
        for k_block in range(num_k_block):
            if k_block == num_k_block - 1:
                tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
                tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
                cute.arch.cp_async_wait_group(num_smem_stages - 2)
                cute.arch.sync_threads()

    Next we kick off a copy from shared memory (smem) to registers (rmem) using cute.copy for the (future) A and B tiles.

    k_block_next = (k_block + 1) % num_k_block  # static
    cute.copy(
        tiled_copy_s2r_A,
        tCsA_p[None, None, k_block_next],
        tCrA_copy_view[None, None, k_block_next],
    )
    cute.copy(
        tiled_copy_s2r_B,
        tCsB_p[None, None, k_block_next],
        tCrB_copy_view[None, None, k_block_next],
    )

    Finally, we interleave the transfers of the next A and B tiles from global to shared memory with the actual gemm operation (the equivalent of tl.dot); the excerpt below shows the B half. I will trust the folks at Nvidia that this is an optimal pattern. The pred= in there is the equivalent of masking in Triton.

    if k_block == 0:
        if k_tile + num_smem_stages - 1 < k_tile_count:
            cute.copy(
                tiled_copy_B,
                tBgB[None, None, None, k_tile_index],
                tBsB[None, None, None, smem_pipe_write],
                pred=tBpB,
            )
        k_tile_index = k_tile_index + 1
        cute.arch.cp_async_commit_group()
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = smem_pipe_read + 1

        if smem_pipe_read == num_smem_stages:
            smem_pipe_read = 0

    The pipelining is explicit, which is nice for debuggability and optimization, but very manual.
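    The stage bookkeeping is easier to see without the copies. Here is a host-side simulation in plain Python that tracks which K-tile each shared-memory stage holds as smem_pipe_read and smem_pipe_write rotate. simulate_smem_pipeline is a hypothetical helper, and the prologue (prefetching num_smem_stages - 1 tiles and starting with smem_pipe_write = num_smem_stages - 1) is an assumption, since it isn't in the excerpts above. The modulo is equivalent to the reset to 0 in the kernel.

```python
from __future__ import annotations


def simulate_smem_pipeline(k_tile_count: int, num_smem_stages: int) -> list[int | None]:
    """Return the K-tile consumed at each main-loop iteration."""
    smem: list[int | None] = [None] * num_smem_stages  # what each stage slot holds
    k_tile_index = 0
    # Assumed prologue: prefetch up to num_smem_stages - 1 tiles before computing.
    for stage in range(num_smem_stages - 1):
        if k_tile_index < k_tile_count:
            smem[stage] = k_tile_index
        k_tile_index += 1
    smem_pipe_read, smem_pipe_write = 0, num_smem_stages - 1
    consumed: list[int | None] = []
    for k_tile in range(k_tile_count):
        # Compute on the stage being read (stands in for the s2r copies + gemm).
        consumed.append(smem[smem_pipe_read])
        # Refill the free stage with the tile num_smem_stages - 1 steps ahead.
        if k_tile + num_smem_stages - 1 < k_tile_count:
            smem[smem_pipe_write] = k_tile_index
        k_tile_index += 1
        # Rotate: the slot just consumed becomes the next write target.
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = (smem_pipe_read + 1) % num_smem_stages
    return consumed
```

    Tiles come out in order 0, 1, 2, ... even though every stage slot is reused, which is the invariant the cp_async_wait_group and sync_threads calls protect on real hardware, where the copies are actually in flight.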

    Debugging Tips

    export CUTE_DSL_LOG_TO_CONSOLE=1
    export CUTE_DSL_LOG_LEVEL=10   # up to 100
    export CUTE_DSL_PRINT_IR=1     # dump MLIR
    • cute.printf() gives you a GPU‑side printf.
    • Kernels are aggressively cached; rm ~/.cache/cutedsl if things look stale.
    • Multiple @cute.jit host functions in the same Python scope can confuse MLIR (mainly for launching kernels).
    • The control‑flow rules are strict: no return inside a kernel; initialize everything.

    If you’re exploring GPU kernels for the first time, I strongly recommend starting with Triton. When you need to really get into the weeds, or want to reuse CUTLASS building blocks, it’s great to have CuTe‑DSL as an option in Python (provided you’re comfortable spelunking in GPU internals).

    1. I spent a lot of time holding it wrong. Arguably, still holding it wrong. ↩︎
    2. No one knows what it means, but it’s provocative ↩︎
    3. Note the explicit unroll tag. When you really want #pragma but can’t. ↩︎
  • Free-Threaded Python gets ‘supported’ status

    Huge congratulations to Thomas, Matt and Sam for shepherding through PEP 779, which moves the no-GIL/free-threaded build of Python from experimental to supported:

    https://discuss.python.org/t/pep-779-criteria-for-supported-status-for-free-threaded-python/84319/123

    With these recommendations and the acceptance of this PEP, we as the Python developer community should broadly advertise that free-threading is a supported Python build option now and into the future, and that it will not be removed without following a proper deprecation schedule

    Having confidence in the long term of this feature is great for anyone building on it. I’m very grateful to (and feel lucky to be working with!) the many folks who have been squashing bugs and improving performance, and to the people adding support for FT Python across the ecosystem!

    The Steering Council has laid out some solid documentation and performance expectations for the ongoing work, and is setting an expectation of broad compatibility for future CPython work:

    New experimental projects within CPython must be compatible with, and should be based on the free-threading build. The SC encourages this direction to reduce engineering complexity caused by supporting both GIL and free-threaded builds

    I also appreciate the call for building out more high-level concurrency primitives: I think there are a lot of exciting projects to come as we move more of this into production!

  • Monarch: PyTorch Single Controller

    I’ve been excited for this to make it to OSS: the PyTorch team at Meta recently soft-launched Monarch on GitHub.

    pytorch-labs/monarch: PyTorch Single Controller

    Back in 2022, Google’s Pathways paper proposed (or rather revisited) a single-controller approach for managing machine learning runs. Typically, ML jobs use an SPMD (Single Program, Multiple Data) approach, distributing identical code across multiple hosts. Each host runs independently, synchronizing during collective operations. This works, as evidenced by the many large training runs in the world. It also introduces complexity, especially with pipeline parallelism, where conditional logic for different ranks can clutter up your training code. Even without that, subtle issues can arise: for example, slight differences in torch.compile optimization have (in the past!) led to deadlocks by placing collectives differently on separate nodes.

    The single-controller model simplifies this by centralizing program execution on one main node and using generic workers on the hosts that execute assigned tasks. This provides a consistent, global view of the entire computation, making it easier to get to a correct implementation of parallelisms and other distributed work. This doesn’t come for free though: the main node must efficiently manage (potentially!) thousands of GPUs without becoming a bottleneck, and existing code must adapt to this new centralized model.

    Monarch is the PyTorch team’s implementation of this single-controller concept. It provides a familiar PyTorch frontend, additional module wrappers, and a high-performance Rust-based actor system for distributing and managing work.

    The fundamental abstraction in Monarch is the Actor. Each Actor runs on its own accelerator and encapsulates state and behavior. Actors communicate via async message passing to methods decorated with @endpoint. The nice thing about the programming model is that you can interact with all of the actors in your mesh in a consistent way.
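    Monarch’s actors run on its Rust actor system, but the programming model itself can be sketched with asyncio. The toy below (ToyActor, ToyMesh and their methods are hypothetical, not Monarch’s API) gives each actor private state and a mailbox processed one message at a time, plus a mesh-level call() that sends the same message to every actor and gathers the replies, loosely mirroring the .call() pattern Monarch uses.

```python
import asyncio
from typing import Any


class ToyActor:
    """Hypothetical actor: private state, and a mailbox processed one message at a time."""

    def __init__(self, rank: int) -> None:
        self.rank = rank
        self.steps = 0
        self.mailbox: asyncio.Queue = asyncio.Queue()

    async def serve(self) -> None:
        # Messages are handled sequentially, so state is never touched concurrently.
        while True:
            method, args, reply = await self.mailbox.get()
            try:
                reply.set_result(getattr(self, method)(*args))
            except Exception as exc:  # ship errors back to the caller
                reply.set_exception(exc)

    def step(self, n: int) -> tuple[int, int]:
        self.steps += n
        return self.rank, self.steps


class ToyMesh:
    """Hypothetical mesh: broadcasts one message to every actor and gathers replies."""

    def __init__(self, size: int) -> None:
        # Must be constructed inside a running event loop (create_task needs one).
        self.actors = [ToyActor(rank) for rank in range(size)]
        self.servers = [asyncio.create_task(a.serve()) for a in self.actors]

    async def call(self, method: str, *args: Any) -> list[Any]:
        loop = asyncio.get_running_loop()
        replies = []
        for actor in self.actors:
            reply = loop.create_future()
            actor.mailbox.put_nowait((method, args, reply))
            replies.append(reply)
        return await asyncio.gather(*replies)

    async def stop(self) -> None:
        for task in self.servers:
            task.cancel()
        await asyncio.gather(*self.servers, return_exceptions=True)


async def main() -> tuple[list[Any], list[Any]]:
    mesh = ToyMesh(2)
    first = await mesh.call("step", 1)
    second = await mesh.call("step", 5)
    await mesh.stop()
    return first, second
```

    Each actor keeps its own counter across calls, and the caller sees one reply per rank, which is the shape of the per-rank results in the examples below.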

    Monarch is appealing even if you’re not GPU-rich. For instance, at home, I have a machine equipped with two (mismatched) 3090s, and Monarch allows me to run and debug jobs directly in notebooks without relying on external services.

    Installation had minor hurdles because I built from source rather than using the available pip package. Although the README specifies Python 3.10, Python 3.13 worked fine. The dependencies reference dnf (reflecting Meta’s internal Linux distro choice), so adapting commands to other Linux distributions (Ubuntu in my case) was necessary. Additionally, I had to set BINDGEN_EXTRA_CLANG_ARGS="-I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11" to resolve Rust compilation issues.

    Once installed, running Monarch’s distributed data-parallel notebook was straightforward (see: monarch/examples/notebooks/spmd_ddp.ipynb). The notebook shows that minimal code changes to the standard DDP example are required, mainly subclassing Actor (e.g., class DDPActor(Actor)), while keeping the training loop virtually identical. Monarch handles the rest, including distributed execution and debugging across multiple GPUs.

    Setting up the environment means providing the mesh configuration and launching the actors, which can be done from a cell:

    # Spawn a process mesh
    local_proc_mesh = await proc_mesh(
        gpus=WORLD_SIZE,
        env={
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12455",
        },
    )
    # Spawn our actor mesh on top of the process mesh
    ddp_actor = await local_proc_mesh.spawn("ddp_actor", DDPActor)

    I didn’t have to manually start any other services; it all happened under the hood. Triggering the run is just:

    await ddp_actor.demo_basic.call()

    Which output:

    self.rank=0 Running basic DDP example
    self.rank=1 Running basic DDP example
    self.rank=1 Finished running basic DDP example
    self.rank=0 Finished running basic DDP example

    What I find really appealing is how easy it is to execute across ranks. For example, to query for system info:

    print("Gathering system info from all ranks...")
    system_info = await ddp_actors.get_system_info.call()
    
    print("\n SYSTEM INFORMATION ACROSS ALL RANKS:")
    print("=" * 60)
    
    for point, rank_info in system_info:
        print(f"Rank {rank_info['rank']}: PID={rank_info['process_id']}, "
              f"Device={rank_info['device_name']}, "
              f"GPU Memory={rank_info['gpu_memory_allocated']/1e6:.1f}MB")
    
    print(f"\nFound {len(system_info)} ranks in the mesh")
    
    all_rank_info = [value for point, value in system_info]
    print(f"Total GPU memory across all ranks: {sum(info['gpu_memory_allocated'] for info in all_rank_info)/1e6:.1f}MB")

    Outputting:

    Gathering system info from all ranks...
    [Rank 0] System Info: PID=10519, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    [Rank 1] System Info: PID=10520, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
    
     SYSTEM INFORMATION ACROSS ALL RANKS:
    ============================================================
    Rank 0: PID=10519, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    Rank 1: PID=10520, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
    
    Found 2 ranks in the mesh
    Total GPU memory across all ranks: 0.1MB

    I can also stop training and dump state if I need to, making it easy to check norms and debug:

    print("Running training steps...")
    for step in range(3):
        print(f"\n--- Step {step + 1} ---")
        
        step_results = await ddp_actors.train_step.call()
        
        all_results = [value for point, value in step_results]
        
        losses = [result['loss'] for result in all_results]
        grad_norms = [result['grad_norm'] for result in all_results]
        throughputs = [result['throughput'] for result in all_results]
        
        print(f"Losses across ranks: {[f'{l:.4f}' for l in losses]}")
        print(f"Gradient norms: {[f'{g:.4f}' for g in grad_norms]}")
        print(f"Avg throughput: {sum(throughputs)/len(throughputs):.1f} samples/sec")

    Outputting:

    --- Step 1 ---
    [Rank 1] Step 1: Loss=1.1128, GradNorm=0.3198, Time=0.241s
    [Rank 0] Step 1: Loss=1.0414, GradNorm=0.3198, Time=0.253s
    Losses across ranks: ['1.0414', '1.1128']
    Gradient norms: ['0.3198', '0.3198']
    Avg throughput: 129.8 samples/sec
    
    --- Step 2 ---
    [Rank 0] Step 2: Loss=1.1526, GradNorm=0.3096, Time=0.003s
    [Rank 1] Step 2: Loss=1.0546, GradNorm=0.3096, Time=0.003s
    Losses across ranks: ['1.1526', '1.0546']
    Gradient norms: ['0.3096', '0.3096']
    Avg throughput: 9800.9 samples/sec
    
    --- Step 3 ---
    [Rank 1] Step 3: Loss=0.9116, GradNorm=0.2243, Time=0.002s
    [Rank 0] Step 3: Loss=0.9662, GradNorm=0.2243, Time=0.002s
    Losses across ranks: ['0.9662', '0.9116']
    Gradient norms: ['0.2243', '0.2243']
    Avg throughput: 19977.5 samples/sec

    While the distributed stuff here is cool, it’s not wildly different from using a distributed framework like Ray with a little bit of setup (though I am pretty allergic to setup). What is most interesting is how this changes the programming model of PyTorch, making it really easy to build out distributed experiments.

    For example, if I were building a parameter server, syncing only requires an awaited read of the weights from another actor, taking advantage of the RDMA support for an efficient copy¹:

        @endpoint
        async def worker_sync_with_ps(self, param_server) -> bool:
            """Synchronize with parameter server and get RDMA handles"""
                
            self._log("Synchronizing with parameter server...")
            
            # Get RDMA buffer handles
            self.weight_buffers = await param_server.ps_get_weight_handles.call_one()
            self.gradient_buffers = await param_server.ps_get_gradient_handles.call_one()
            
            # Get metadata
            metadata = await param_server.ps_get_metadata.call_one()
            self.weight_metadata = metadata['weights']
            self.gradient_metadata = metadata['gradients']
            
            # Perform initial weight sync
            sync_time = await self._sync_weights_from_ps()
            
            self._log(f"Synchronized with parameter server (sync time: {sync_time:.3f}s)")
            return True

    Getting those weight buffers is as simple as creating the right Monarch object:

    def tensor_to_rdma_buffer(tensor: torch.Tensor) -> RDMABuffer:
        # RDMA requires 1D contiguous uint8 tensors
        byte_tensor = tensor.view(torch.uint8).flatten()
        return RDMABuffer(byte_tensor)
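    The view(torch.uint8).flatten() step reinterprets the tensor’s storage as raw bytes without copying. As a standard-library analogy (not Monarch or PyTorch code; to_byte_view is a hypothetical name), memoryview.cast does the same thing for an array, and it likewise only works on a contiguous buffer.

```python
from array import array


def to_byte_view(values: array) -> memoryview:
    # Zero-copy: reinterpret the contiguous buffer as unsigned bytes ("B").
    return memoryview(values).cast("B")


weights = array("f", [1.0, 2.0, 3.0])  # "f" is a C float (4 bytes on common platforms)
raw = to_byte_view(weights)
```

    Writes through the array show up in the byte view, since both share one buffer.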

    For an early preview of a library, Monarch is surprisingly complete, and definitely worth a look.

    1. Not that this would do anything for my 3090s! ↩︎
  • Pyrefly

    https://pyrefly.org

    I’m at PyCon (mildly awkward photo thanks to Simon Willison!) and earlier had to steal some extra chairs for the Typing Summit as it was full! There is a lot of energy and interest around type checking, thanks to Astral’s Ty and Meta’s Pyrefly projects entering the space recently.

    While the playground is a great way to try it out, I wanted to see what Pyrefly was like on a larger codebase I was familiar with. I picked TorchTune, which uses type hints but doesn’t configure a type checker explicitly in CI (as far as I know!), relying on LSP squiggles as the main type-hinting feedback (which is reasonable!).

    I tried running mypy over it with a very basic config, timing it with time mypy:

    [mypy]
    python_version = 3.13
    ignore_missing_imports = True
    warn_unused_ignores = True
    strict_optional = True
    files = .

    Unsurprisingly, there are quite a few errors!

    Found 1361 errors in 211 files (checked 485 source files)
    real 10m45.596s
    user 0m12.906s
    sys 0m0.929s

    I installed pyrefly and init’d it:

    pip install pyrefly
    pyrefly init

    This created a pyrefly.toml containing a very minimal config:

    project_includes = ["."]
    python_version = "3.13.0"

    Running pyrefly check then gave me:

    INFO 2,966 errors shown, 7 errors ignored, 485 modules, 7,364 transitive dependencies, 3,522,743 lines, took 47.94s (checking 34.88s; reporting 12.98s), peak memory physical 863.6 MiB

    It’s impressively fast: over 10 minutes of wall-clock time for mypy vs. under 50 seconds for pyrefly. There are also a lot more errors, and it’s tricky to tell whether they’re false positives from pyrefly, errors mypy skipped, or something else. I scoped it down to TorchTune’s KV cache module in torchtune/modules/kv_cache.py to get a better look.

    There pyrefly returns 9 errors and mypy 17, but from what I can see that’s down to the two tools reporting some of the same problems slightly differently. For example:

    k_out[:, :, self.cache_pos[:seq_len]] = k_val

    This code is doing a bit of Tensor slicing:

    • cache_pos is a max_seq_len long tensor holding absolute write positions
    • k_out is the key cache, with shape batch_size x num_heads x max_seq_len x head_dim
    • Here we write the latest key values (k_val) into the cache at the positions given by cache_pos[:seq_len]

    MyPy gives these errors:

    torchtune/modules/kv_cache.py:104: error: "Tensor" not callable [operator]
    torchtune/modules/kv_cache.py:104: error: Value of type "Tensor | Module" is not indexable [index]

    While pyrefly gives:

    torchtune/modules/kv_cache.py:104:9-46: Item assignment is not supported on Module | Tensor   Expected __setitem__ to be a callable, got BoundMethod...

    In this module cache_pos and k_cache are created by calling PyTorch’s register_buffer, which stores tensors as module state (by default they show up in the state_dict) without making them trainable parameters. Type checkers only see these attributes through nn.Module.__getattr__, which is where the Tensor | Module union in both error messages comes from. Adding explicit type declarations in the class body fixes the errors in both mypy and pyrefly:

    cache_pos: torch.Tensor
    k_cache: torch.Tensor

  • The First Year of Free-Threaded Python

    https://labs.quansight.org/blog/free-threaded-one-year-recap

    Quansight have been doing a tonne of work on the roll-out of Free-Threaded Python, and they released an insightful blog post covering the last year of progress that complements the recent Meta engineering blog post well.

    At this time last year, when Python 3.13.0b1 shipped, the wider ecosystem of Python packages was more or less completely broken on the free-threaded build. Trying to pip install anything but the simplest package with no dependencies or only pure-Python dependencies would likely lead to build errors. Most of these issues were not due to fundamental problems but because of unsupported default options or minor assumptions broken on the free-threaded build.

    Together with package maintainers and other contributors in the community, we have fixed many of these issues and today things are much better. With the release of Cython 3.1.0, which ships official support for the free-threaded build, we also helped fix one of the most significant sources of build issues.

    There have been a number of good talks about Free-Threaded Python at PyCon US this year, and lots of familiar stories of rediscovering, in the process, the concurrency principles that other languages have worked through! I’d recommend keeping an eye on the PyCon YouTube Channel for talk recordings, particularly Lisandro and Nathan’s talk about their journey, David Hewitt’s talk about using Rust with FT Python from his experience with PyO3, and Alvaro Duran’s on his FT Python load balancer!

    Some useful links I’ve been looking at from the talks:

  • LSP & Standards

    https://www.michaelpj.com/blog/2024/09/03/lsp-good-bad-ugly.html

    Having recently spent a lot more time around typecheckers, I’ve been reading about the Language Server Protocol that glues IDEs and language support services together. The article, from September last year, gives a breakdown of the good and the bad about the protocol, and is a really great dive into the broader topic.

    Much of the pain stems from how the protocol emerged and is managed:

    There is zero open discussion of features before they are added to the spec. Typically they are implemented in VSCode, and then the specification is updated as a fait accompli to document those changes. Implementers of open-source language servers get very [little] influence on the development of the specification. There is not even a community space for implementers of language servers to get together and talk about the many tricky corners.

    I feel echoes of this in a lot of different projects I have been around, including PHP internals, ZeroMQ's protocol, various CNCF working groups, PyTorch and Triton. Protocols and technologies emerge from a need, and grow because that need is shared, but transitioning from a narrow, tightly connected origin to a true standard is difficult. Standardize and bring in voices too early and you slow progress to the point where something else emerges that solves immediate needs better; leave it too long and governance questions alone can be enough to push folks to rally around forks or alternatives.

    One example of that playing out at the moment is Anthropic's (and dsp's!) Model Context Protocol. Tim Kellogg wrote a nice post the other day comparing it to OpenAPI, concluding:

    Standards are mostly sociological advancements. Yes, they concern technology, but they govern how society interacts with them. The biggest reason for MCP is simply that everyone else is doing it. Sure, you can be a purist and demand that OpenAPI is adequate, but how many clients support it?

    The reason everyone is agreeing on MCP is because it’s far smaller than OpenAPI. Everything in the tools part of an MCP server is directly isomorphic to something else in OpenAPI. In fact, I can easily generate an MCP server from an openapi.json file, and vice versa. But MCP is far smaller and purpose-focused than OpenAPI is.
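    Here is one way to make the "isomorphic" claim concrete. An MCP tool is described by a `name`, a `description` and a JSON Schema `inputSchema`. These map fairly directly onto an OpenAPI operation's `operationId`, `summary` and `parameters`. The rough sketch below assumes a simplified OpenAPI 3 document with only path and query parameters, so no request bodies, `$ref`s or auth are handled. `openapi_to_mcp_tools` is a hypothetical helper, not part of any MCP SDK.

```python
import re

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "head", "options"}


def openapi_to_mcp_tools(spec: dict) -> list[dict]:
    """Turn each OpenAPI operation into an MCP-style tool definition (heavily simplified)."""
    tools = []
    for path, item in spec.get("paths", {}).items():
        for method, op in item.items():
            if method not in HTTP_METHODS:
                continue  # skip path-level keys such as "parameters" or "summary"
            properties, required = {}, []
            for param in op.get("parameters", []):
                # Each OpenAPI parameter already carries a JSON Schema under "schema".
                properties[param["name"]] = param.get("schema", {"type": "string"})
                if param.get("required"):
                    required.append(param["name"])
            # Fall back to a sanitized "method_path" name when there is no operationId.
            name = op.get("operationId") or re.sub(r"\W+", "_", f"{method}{path}").strip("_")
            tools.append({
                "name": name,
                "description": op.get("summary", f"{method.upper()} {path}"),
                "inputSchema": {"type": "object", "properties": properties, "required": required},
            })
    return tools


# Hypothetical single-operation spec.
spec = {"openapi": "3.0.0", "paths": {"/pets/{petId}": {"get": {
    "operationId": "getPet", "summary": "Fetch a pet by id",
    "parameters": [{"name": "petId", "in": "path", "required": True,
                    "schema": {"type": "integer"}}]}}}}
tools = openapi_to_mcp_tools(spec)
```

    The reverse direction, from tool definitions back to OpenAPI operations, is similarly mechanical. The hard parts are the features MCP leaves out.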

  • Free-Threaded Python community

    As Thomas posted on the Python Discourse forum, there is a Discord server for discussing free-threaded/nogil Python: https://discord.gg/rqgHCDqdRr

    Other than the helpful docs the Quansight folks maintain (py-free-threading), it's been interesting to see some projects and tools pop up there that I hadn't seen. One is Zsolt's python-ft-deps, which checks whether your project's non-pure-Python dependencies ship free-threaded wheels. It can be run as a one-liner thanks to uv:

    uvx python-ft-deps
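    The core of such a check is easy to picture. Free-threaded CPython wheels carry an ABI tag with a `t` suffix, such as `cp313t`, while pure-Python wheels use the ABI tag `none`. The sketch below classifies wheel filenames by that tag, assuming standard `name-version(-build)-python-abi-platform.whl` names. It is an illustration only, not how python-ft-deps is implemented, and the example filenames are hypothetical.

```python
def wheel_abi_tags(filename: str) -> set[str]:
    """Return the ABI tag(s) from a wheel filename: the second-to-last dash-separated field."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # an optional build tag makes it six fields
        raise ValueError(f"unexpected wheel name: {filename}")
    # A tag field may hold several dot-separated values (a "compressed tag set").
    return set(parts[-2].split("."))


def supports_free_threading(filename: str) -> bool:
    """True for pure-Python wheels (ABI 'none') or wheels built for a free-threaded ABI like cp313t."""
    abis = wheel_abi_tags(filename)
    return "none" in abis or any(tag.startswith("cp") and tag.endswith("t") for tag in abis)
```

    Note that `abi3` (stable ABI) wheels are treated as unsupported here, because the free-threaded builds of 3.13 don't support the limited API.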

  • PyTorch and Python Free Threading

    https://trent.me/articles/pytorch-and-python-free-threading/

    Trent Nelson has written an extremely detailed breakdown of his experiments running GPT-2 inference with PyTorch on the free-threaded (GIL-free) builds of Python 3.13 and 3.14.

    He implements parallel generation using multiple threads (on one GPU and later on multiple devices) and parallel model loading. He then digs into some of the challenges with torch.compile, which doesn't work great with nogil yet!

    Hopefully this encourages more folks to experiment with free-threaded Python, or perhaps port their existing Python packages to play nicely when installed in a free-threaded Python environment. I personally can’t wait until free-threaded Python is the default! Although that’s probably at least five or so years out at this point.

    Free-threaded Python really changes the performance trade-offs around Python, and I expect it to be the default for ML work a lot sooner than that!
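    The basic pattern behind parallel generation is ordinary `threading` code. What changes on a free-threaded build is that CPU-bound Python work in those threads can run in parallel instead of taking turns on the GIL. Below is a minimal sketch with a pure-Python stand-in for the per-request work. `fake_generate` is hypothetical and no PyTorch is involved, so treat it as an illustration of the threading model, not of Trent's code.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_generate(seed: int, steps: int = 100_000) -> int:
    """Hypothetical stand-in for one generation request: pure-Python, CPU-bound work."""
    x = seed
    for _ in range(steps):
        x = (x * 1103515245 + 12345) % (2**31)  # simple linear congruential update
    return x


def run(seeds: list[int], workers: int) -> tuple[list[int], float]:
    """Run one task per seed on a thread pool; return results (in seed order) and wall time."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(fake_generate, seeds))
    return results, time.perf_counter() - start


if __name__ == "__main__":
    seeds = list(range(8))
    serial, t1 = run(seeds, workers=1)
    parallel, t8 = run(seeds, workers=8)
    assert serial == parallel  # same answers either way
    # On a free-threaded build 8 threads should be noticeably faster; with the GIL, roughly not.
    print(f"1 thread: {t1:.2f}s, 8 threads: {t8:.2f}s")
```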

  • Performance of the tail-call interpreter in Python 3.14

    Performance of the Python 3.14 tail-call interpreter – Made of Bugs

    A great example of the benefits (and complexities!) of open source today. Nelson Elhage at Anthropic investigated the recent tail-call interpreter improvements in Python 3.14 after becoming suspicious of the claimed 10-15% performance win. It turned out that the baseline was artificially slow due to a regression in LLVM 19.

    Unfortunately, as I will document in this post, these impressive performance gains turned out to be primarily due to inadvertently working around a regression in LLVM 19. When benchmarked against a better baseline (such as GCC, clang-18, or LLVM 19 with certain tuning flags), the performance gain drops to 1-5% or so depending on the exact setup.

    That's still a very impressive speedup for something as widely used and heavily optimized as the Python interpreter. Ken Jin, the original change author, wrote a nice apology post, but this is a very tricky situation!

    In order to avoid catastrophic slowdowns (or memory usage) in certain cases, LLVM 19 implemented some limits on tail-duplication pass, causing it to bail out if duplication would blow up the size of the IR past certain limits. Unfortunately, on CPython those limits resulted in Clang leaving all of the dispatch jumps merged, and entirely undoing the whole purpose of the computed goto-based implementation! 
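    To picture what "dispatch jumps merged" means, compare the two interpreter shapes. Python can express neither computed gotos nor guaranteed tail calls, so the following is only a structural sketch, using a hypothetical four-opcode stack machine. In the switch style every handler returns to one shared loop, so all dispatch goes through a single branch point. That is effectively the situation the LLVM 19 behaviour recreated. In the tail-call style each handler dispatches the next instruction itself, which in C, with guaranteed tail calls, gives every opcode its own indirect jump.

```python
# Hypothetical bytecode: a list of (opcode, argument) pairs.
PUSH, ADD, MUL, HALT = range(4)


def run_switch(code):
    """Central-loop dispatch: every instruction goes back through the same branch point."""
    stack, pc = [], 0
    while True:
        op, arg = code[pc]
        pc += 1
        if op == PUSH:
            stack.append(arg)
        elif op == ADD:
            b, a = stack.pop(), stack.pop()
            stack.append(a + b)
        elif op == MUL:
            b, a = stack.pop(), stack.pop()
            stack.append(a * b)
        elif op == HALT:
            return stack.pop()


# Tail-call style: each handler does its work, then dispatches the next instruction itself.
# In C this lookup would be inlined into every handler; Python does not eliminate tail calls,
# so long programs here would hit the recursion limit. This is purely a structural sketch.
def _dispatch(code, pc, stack):
    op, arg = code[pc]
    return HANDLERS[op](code, pc + 1, stack, arg)


def _push(code, pc, stack, arg):
    stack.append(arg)
    return _dispatch(code, pc, stack)


def _add(code, pc, stack, arg):
    b, a = stack.pop(), stack.pop()
    stack.append(a + b)
    return _dispatch(code, pc, stack)


def _mul(code, pc, stack, arg):
    b, a = stack.pop(), stack.pop()
    stack.append(a * b)
    return _dispatch(code, pc, stack)


def _halt(code, pc, stack, arg):
    return stack.pop()


HANDLERS = {PUSH: _push, ADD: _add, MUL: _mul, HALT: _halt}


def run_tailcall(code):
    return _dispatch(code, 0, [])


# (2 + 3) * 4
program = [(PUSH, 2), (PUSH, 3), (ADD, None), (PUSH, 4), (MUL, None), (HALT, None)]
```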

    Baselines are a persistent issue, and the post has some good things to say about them. A healthy level of skepticism helps: in this case an optimization in one area ended up being a performance hit in another, which is pretty common. And while this was a genuine mistake, and a very hard one to spot, it's much more common to (somewhat) intentionally choose a flattering baseline. I recently had discussions at work about something that had delivered a double-digit improvement, but only when compared against a poor baseline. Measured against a more meaningful alternative, the gain was much more modest. Nelson makes the point succinctly:

    I work in machine learning at Anthropic these days, and we see this all the time in ML papers. When a paper comes out claiming some algorithmic improvement or other advance, I’ve noticed that the first detail our researchers ask is often not “What did they do?” but “What baseline did they compare against?” It’s easy to get impressive-looking results if you’re comparing against a poorly-tuned baseline, and that observation turns out to explain a surprising fraction of supposed improvements.

    This happens at large scale too: on a recent earnings call, Google touted that 25% of their new code was AI-generated. That was apparently true, but in part it reflects replacing traditional autocomplete with a model-based copilot. The real question is what percentage would have been "machine generated" anyway, and how much of the delta is really down to LLMs. That is very likely a much smaller number, especially for a company with very sophisticated developer infrastructure.

    The investigation is also a good counter to some of the breathless takes around vibe coding: this stuff is hard, and it will be a while before we can automatically catch, root-cause and resolve something like this:

    If you’d asked me, a month ago, to estimate the likelihood that an LLVM release caused a 10% performance regression in CPython and that no one noticed for five months, I’d have thought that a pretty unlikely state of affairs! Those are both widely-used projects, both of which care a fair bit about performance, and “surely” someone would have tested and noticed.
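    On the baseline point, one habit that helps is reporting a change against every reasonable baseline rather than just one. Below is a minimal sketch using `timeit`. The workload and the three functions are hypothetical stand-ins: a deliberately poor loop, an idiomatic baseline, and the candidate being evaluated. The printed ratios will vary by machine and Python version, so no expected numbers are given.

```python
import timeit


def best_time(fn, number: int = 50, repeat: int = 5) -> float:
    """Best-of-`repeat` seconds per call; the timeit docs suggest the minimum as the useful summary."""
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number


def gains(candidate, baselines: dict) -> dict:
    """Speedup of `candidate` relative to each named baseline (> 1.0 means the candidate is faster)."""
    t_candidate = best_time(candidate)
    return {name: best_time(fn) / t_candidate for name, fn in baselines.items()}


N = 10_000  # hypothetical workload: sum of squares below N


def naive_loop():
    # Deliberately poor baseline: builds an intermediate list with repeated appends.
    squares = []
    for i in range(N):
        squares.append(i * i)
    total = 0
    for s in squares:
        total += s
    return total


def generator_sum():
    # Reasonable baseline: idiomatic built-in sum over a generator.
    return sum(i * i for i in range(N))


def candidate():
    # The "new" approach under evaluation.
    return sum([i * i for i in range(N)])


if __name__ == "__main__":
    for name, ratio in gains(candidate, {"naive loop": naive_loop, "generator sum": generator_sum}).items():
        print(f"candidate vs {name}: {ratio:.2f}x")
```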

  • JIT decorators in Python

    https://eli.thegreenplace.net/2025/decorator-jits-python-as-a-dsl/

    Eli Bendersky breaks down how different JIT decorators work, from AST-based to bytecode-based to tracing, using Triton, Numba and JAX as examples.

    In both cases, the function decorated with jit doesn’t get executed by the Python interpreter in the normal sense. Instead, the code inside is more like a DSL (Domain Specific Language) processed by a special purpose compiler built into the library (JAX or Triton). Another way to think about it is that Python is used as a meta language to describe computations.
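    Of the three styles, tracing is the easiest to sketch in a few lines. You call the function once with proxy objects that record every operation instead of computing it. You then replay the recorded graph; a real system would compile it instead. `trace_jit` and `Tracer` are hypothetical names, and JAX's actual tracer is far more involved (abstract shapes and dtypes, a full IR, lowering to XLA). As with JAX, Python-level control flow that depends on traced values can't be captured this way.

```python
import functools
import operator


class Tracer:
    """Proxy value that records operations into a shared graph instead of computing them."""

    def __init__(self, graph, node_id):
        self.graph, self.node_id = graph, node_id

    def _record(self, op, other):
        if isinstance(other, Tracer):
            other_id = other.node_id
        else:  # Python constants become explicit "const" nodes in the graph
            self.graph.append(("const", other))
            other_id = len(self.graph) - 1
        self.graph.append((op, self.node_id, other_id))
        return Tracer(self.graph, len(self.graph) - 1)

    def __add__(self, other):
        return self._record(operator.add, other)

    def __mul__(self, other):
        return self._record(operator.mul, other)

    __radd__ = __add__  # add and mul are commutative, so reflected versions can reuse them
    __rmul__ = __mul__


def _evaluate(graph, out_id, args):
    """Replay the recorded graph on concrete arguments."""
    values = []
    for node in graph:
        if node[0] == "input":
            values.append(args[node[1]])
        elif node[0] == "const":
            values.append(node[1])
        else:
            op, a, b = node
            values.append(op(values[a], values[b]))
    return values[out_id]


def trace_jit(fn):
    """Trace `fn` once per arity on first call, then replay the cached graph afterwards."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        key = len(args)  # real tracers key on shapes/dtypes; this sketch only keys on arity
        if key not in cache:
            graph = [("input", i) for i in range(len(args))]
            inputs = [Tracer(graph, i) for i in range(len(args))]
            out = fn(*inputs)  # the Python body runs once, only to build the graph
            cache[key] = (graph, out.node_id)
        graph, out_id = cache[key]
        return _evaluate(graph, out_id, args)

    return wrapper


calls = 0


@trace_jit
def poly(x, y):
    global calls
    calls += 1  # this side effect only happens while tracing
    return x * x + 3 * y + 1


print(poly(2, 5), poly(3, 1), calls)  # → 20 13 1
```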