You may have noticed that FlashAttention 4 was supported in PyTorch really quickly. That required a bit of new infrastructure: torch._native, by Simon Layton. Prior versions of FlashAttention were written in CUTLASS/C++, but for FA4 the team implemented the kernel in CuteDSL.
You wouldn’t think that using an embedded Python DSL in a Python-based ML framework would be a challenge, except that almost all of the code that actually does ML in PyTorch is in fact… not written in Python. Replacing a PyTorch operator meant shipping a new native kernel and dealing with the build and dispatch pipeline.
Layton’s change opened the door to overriding default ops with ones authored in an embedded DSL, initially Triton or CuteDSL.
To be clear, this is not a replacement for custom ops, which are most of the time the best way of adding a new operator. torch.library.triton_op already lets you register a custom Triton kernel, for example. But FA4 is the kind of situation where we needed an alternative: it’s the right path for newer GPUs, it’s written in CuteDSL, and the PyTorch team wanted it to be available quickly to all PyTorch users without modifying their models.
To give an example, we can replace the built-in aten::_fused_rms_norm with a Triton version[^1]:
```python
"""Triton kernel for fused RMS normalization.

RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
"""
import torch
import triton
import triton.language as tl


@triton.jit
def _rms_norm_fwd_kernel(
    X_ptr, W_ptr, Y_ptr,
    RRMS_ptr,  # reciprocal RMS, saved for backward
    stride_x_row,
    N_COLS: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
):
    # [...]
    tl.store(RRMS_ptr + row_idx, rrms)


def triton_rms_norm_forward(
    x: torch.Tensor,
    normalized_shape: list[int],
    weight: torch.Tensor | None,
    eps: float | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused RMSNorm forward pass using Triton."""
    # [...]
    return y.reshape(orig_shape), rrms
```
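If you want a sanity check on what the kernel is supposed to compute, here is a dependency-free reference in plain Python. This is not part of the gist; it just applies the definition RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight row by row, the same quantity the Triton kernel computes in parallel (the function name is my own):

```python
import math

def rms_norm_reference(x, weight=None, eps=1e-6):
    """Reference RMSNorm over the last dimension of a list of rows.

    Per row: y = x * rrms * weight, where rrms = 1 / sqrt(mean(x^2) + eps).
    """
    out = []
    for row in x:
        rrms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        if weight is None:
            out.append([v * rrms for v in row])
        else:
            out.append([v * w * rrms for v, w in zip(row, weight)])
    return out
```

The rrms value is exactly what the kernel stashes in RRMS_ptr for the backward pass, which is why the forward returns it alongside y.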
Actually hooking it up requires calling a DSL-specific op override function, in this case triton_utils.register_op_override. This plugs directly into the dispatch architecture, which means it works with autograd, torch.compile, and so on.[^2]
```python
"""Register a Triton-based RMSNorm as a native op override using torch._native."""
from torch._native import triton_utils


def _triton_fused_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
    """Wrapper that lazily imports the Triton kernel on first call."""
    from triton_kernels import triton_rms_norm_forward
    return triton_rms_norm_forward(x, normalized_shape, weight, eps)


def register():
    """Register the Triton RMSNorm override."""
    triton_utils.register_op_override(
        "aten",                        # lib_symbol: override an aten op
        "_fused_rms_norm",             # op_symbol: the specific op
        "CUDA",                        # dispatch_key: only on CUDA
        _triton_fused_rms_norm,        # impl: our wrapper
        unconditional_override=False,  # receives dispatch_keys as first arg
    )
```
Now when we call torch.ops.aten._fused_rms_norm(x, shape, weight, eps), PyTorch will automatically use our Triton override![^3]
The unconditional_override parameter in the registration call is a helpful one: if False, the function receives a torch.DispatchKeySet as its first argument. This allows overriding only in specific circumstances. For example, our Triton kernel is faster than the C++ one only for larger shapes, so we could gate the decision on that:
```python
def _smart_rms_norm(dispatch_keys, x, normalized_shape, weight, eps):
    n_rows = x.numel() // normalized_shape[-1]
    if n_rows < 4096:
        # Fall back to the default C++ kernel for small shapes.
        return torch.ops.aten._fused_rms_norm.default(x, normalized_shape, weight, eps)
    from triton_kernels import triton_rms_norm_forward
    return triton_rms_norm_forward(x, normalized_shape, weight, eps)
```
Going back to FlashAttention 4, this overrides aten::_scaled_dot_product_flash_attention so that any code using torch.nn.functional.scaled_dot_product_attention will transparently get the FA4 implementation.
torch._native fundamentally lowers the barrier to entry for bringing new kernel implementations into PyTorch. That’s good for mainline PyTorch, and it also allows ML infrastructure teams to ship optimized kernels for new hardware without waiting for the PyTorch release cycle.
[^1]: Trimmed for length; ask your favorite coding agent to write you an RMSNorm kernel, or look at the gist.

[^2]: To avoid paying import latency up front, all DSL runtimes are lazily loaded when the kernel is first called.

[^3]: In this case we only override for CUDA tensors, so CPU ops will continue to be handled by the default implementation.