Autotuning in PyTorch & Triton

torch.compile offers some knobs for trading longer compile times for faster execution. This is particularly useful for inference, where the same model will be running for a long time.

model_autotune = torch.compile(model, mode="max-autotune")

Passing the max-autotune mode instructs the compiler to test more implementation options for each operation. The compiler can use pre-built ATen kernels, leverage kernels from libraries like cuDNN or CUTLASS, or use templated Triton kernels. When autotuning, the candidate variants are benchmarked on device with the shape information identified during tracing, and the fastest ones are selected. The Triton templates also open up fusions: pointwise ops can be folded into the templated kernel, saving kernel launch overhead.

The downside is that benchmarking all those options takes time, so max-autotune can lead to very long compile times. You also need a hefty enough GPU to get the benefit: is_big_gpu gates the templated kernels on the number of SMs, so it only kicks in on something like a 3090 or V100 and above.
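As a rough illustration (this is not Inductor's actual implementation; the real check and its threshold live in torch/_inductor/utils.py and may change between releases), the gate boils down to comparing the device's SM count against a minimum:

import torch

MIN_SMS = 68  # assumed threshold for illustration; the real value may differ
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.multi_processor_count} SMs -> "
      f"eligible for templated autotuning: {props.multi_processor_count >= MIN_SMS}")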

You can see a lot of the autotuning options in _inductor/config.py. Backends that are considered are set separately for GEMMs and convolution ops:

max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
).upper()
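Since the backend list is read from that environment variable, you can narrow the search space yourself. A minimal sketch, setting the override before importing torch so the config module picks it up:

import os

# Restrict GEMM autotuning to Triton templates only; this is the same
# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS variable read in config.py above.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON"

import torch  # import after setting the variable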

Each kernel has implementations for the different backends, which are added to the list of possible choices. For example, in _inductor/kernels/mm.py you can see calls to use_[backend]_template that check whether the backend in question should be considered:

if is_nonzero and use_cutlass_template(layout, m, n, k):
    CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

The actual benchmarking across the choices is done in _inductor/select_algorithm.py.

If you run autotuning, you’ll get some log output, and caches will be written to /tmp/torchinductor_<username>.
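If you want those caches somewhere more permanent than /tmp, the TORCHINDUCTOR_CACHE_DIR environment variable redirects them; a minimal sketch (the path here is just an example):

import os

# Point Inductor's compilation cache at a persistent directory.
# Set this before the first torch.compile call in the process.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/var/cache/inductor"  # example path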

We can try this out on a simple MLP:

import torch, time

class SimpleMLP(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_features, out_features)
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

# Set up device and model
device = 'cuda'
model = SimpleMLP(in_features=1024, hidden_features=1024, out_features=1024).to(device)
x = torch.randn(256, 1024, device=device)  # batch of 256, 1024 features each

# Compile the model in default mode first (max-autotune follows below)
model_default = torch.compile(model, mode="default")

# Warm-up runs (to trigger compilation)
torch.compiler.reset()
with torch.no_grad():
    model_default(x)
torch.cuda.synchronize()  # ensure warm-up completes

# Measure performance of default compiled model
start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    for _ in range(50):
        _ = model_default(x)
    end.record()
torch.cuda.synchronize()
time_default_ms = start.elapsed_time(end) / 50.0
torch.compiler.reset()

model_autotune = torch.compile(model, mode="max-autotune")

with torch.no_grad():
    model_autotune(x)
torch.cuda.synchronize()  # ensure warm-up completes

# Measure performance of max-autotune compiled model
start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    for _ in range(50):
        _ = model_autotune(x)
    end.record()
torch.cuda.synchronize()
time_autotune_ms = start.elapsed_time(end) / 50.0

print(f"Average inference time - torch.compile default: {time_default_ms:.3f} ms")
print(f"Average inference time - torch.compile max-autotune: {time_autotune_ms:.3f} ms")

Disappointingly, this is the result:

Average inference time - torch.compile default: 0.113 ms
Average inference time - torch.compile max-autotune: 3.251 ms

We can turn on logging with the TORCH_LOGS env variable; some useful options are inductor, autotuning, and perf_hints.

TORCH_LOGS="perf_hints" python tune.py
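If you'd rather configure this from Python than via the environment, recent PyTorch versions expose torch._logging.set_logs:

import logging
import torch

# Turn on debug-level logging for the Inductor backend from inside the script
torch._logging.set_logs(inductor=logging.DEBUG)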

You can control many more autotune options via the options argument, though it's incompatible with passing a mode value. We can recreate the max-autotune mode and turn on some useful tracing options like this (note that the options key uses an underscore, while the mode uses a hyphen!):

model_autotune = torch.compile(
    model,
    options={
        "max_autotune": True,
        "triton.cudagraphs": True,
        "coordinate_descent_tuning": True,
        "trace.enabled": True,
        "trace.graph_diagram": True,
    },
)

Options "trace.enabled": True, "trace.graph_diagram": True generate trace outputs, and output a nice diagram of the captured graph. Cudagraphs turned out to be the culprit here, which is common enough there is a non-cudagraph mode available to stop you having to remember all the options:

model_autotune = torch.compile(model, mode="max-autotune-no-cudagraphs")

As you can see in the captured graphs of the runs with and without CUDA graphs, the slower version actually has an extra fusion performed!

Captured graphs for the two runs
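For context, the "triton.cudagraphs" option wraps the compiled region in a CUDA graph that is replayed against fixed input buffers. Roughly the same thing can be done by hand with the public torch.cuda.graph API; a minimal sketch using the eager MLP from above (not what Inductor actually generates):

g = torch.cuda.CUDAGraph()
static_x = torch.zeros_like(x)  # inputs must live at fixed addresses for replay

# warm up on a side stream before capture, as the PyTorch docs recommend
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# capture a single forward pass into the graph
with torch.cuda.graph(g), torch.no_grad():
    static_out = model(static_x)

# replay: copy fresh inputs into the captured buffer, then launch the whole graph at once
static_x.copy_(x)
g.replay()
result = static_out.clone()

The appeal is that a replay is a single launch rather than one per kernel, at the cost of the static buffers and input copies, which a tiny model like this has little chance to amortise.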

Triton Autotuning

Triton also does autotuning, but it’s a little more explicit. When authoring a Triton kernel you can specify a set of configurations. On the first launch each config variant is benchmarked, the most performant one is picked, and the choice is stored for future calls. A key can be provided to indicate which arguments should trigger re-autotuning when they change:

import os
import torch
import triton
import triton.language as tl

# Just to save passing this on the command line
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4,  num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=8,  num_stages=2),
    ],
    key=['N']            # re‑tune only if the length N changes
)
@triton.jit
def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid   = tl.program_id(0)
    offs  = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask  = offs < N
    x     = tl.load(x_ptr  + offs, mask=mask, other=0.0)
    y     = tl.load(y_ptr  + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x + y, mask=mask)

def vec_add(x: torch.Tensor, y: torch.Tensor):
    assert x.is_cuda and y.is_cuda
    N   = x.numel()
    out = torch.empty_like(x)
    grid = (triton.cdiv(N, 128),)         # 128 = smallest BLOCK_SIZE we declared
    vecadd_kernel[grid](x, y, out, N)    
    return out

x = torch.randn(1 << 20, device="cuda")   # 1 048 576 elements
y = torch.randn_like(x)

_ = vec_add(x, y)  # first call → autotuning prints to stdout
_ = vec_add(x, y)  # second call → no autotuning, uses the best config found
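Because we passed key=['N'], calling the kernel with a different length invalidates the stored choice and triggers a fresh round of autotuning:

x2 = torch.randn(1 << 18, device="cuda")   # different N
y2 = torch.randn_like(x2)
_ = vec_add(x2, y2)   # N changed → the configs are benchmarked again for this size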

Setting the env variable TRITON_PRINT_AUTOTUNING documents the process as it goes:

Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
Autotuning kernel vecadd_kernel with config BLOCK_SIZE: 256, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
Triton autotuning for function vecadd_kernel finished after 0.44s; best config selected: BLOCK_SIZE: 128, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;

You can use the same do_bench helper that the autotuner uses and see for yourself how the performance varies:

import torch, triton, triton.testing as tt
import triton.language as tl

@triton.jit
def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    offs  = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask  = offs < N
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) +
             tl.load(y_ptr + offs, mask=mask),
             mask=mask)

# tensors
N   = 1 << 20
x   = torch.randn(N, device='cuda')
y   = torch.randn_like(x)
out = torch.empty_like(x)

def bench(block_size, num_warps):
    grid = (triton.cdiv(N, block_size),)
    # with quantiles=(0.5, 0.2, 0.8), tt.do_bench returns [median, p20, p80] in milliseconds;
    # warmup and rep are given as milliseconds of benchmarking time
    return tt.do_bench(
        lambda: vecadd_kernel[grid](x, y, out, N, BLOCK_SIZE=block_size, num_warps=num_warps),
        warmup=5, rep=16, quantiles=(0.5, 0.2, 0.8)
    )
timings = {
    "128/4": bench(128, 4),
    "256/8": bench(256, 8),
}

print("timings:", timings)

Running that shows that both configurations are essentially equivalent, with the first one slightly faster at the 80th percentile.

timings: {'128/4': [0.01945599913597107, 0.01945599913597107, 0.02028159946203232], '256/8': [0.01945599913597107, 0.01945599913597107, 0.020479999482631683]}
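As a quick sanity check, the median time implies the kernel is doing little more than streaming memory: every element reads x and y and writes out, i.e. 12 bytes per float32 element:

N = 1 << 20
median_ms = 0.01945599913597107        # median of the 128/4 timings above
bytes_moved = N * 3 * 4                # read x, read y, write out (float32)
print(f"≈ {bytes_moved / (median_ms * 1e-3) / 1e9:.0f} GB/s effective bandwidth")  # ≈ 647 GB/s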
