PyTorch compiler engineer Richard Zou did a great Q&A session with the GPU Mode discord community recently. You can watch the session on YouTube, but Richard also collected questions into a doc with some nice snippets and references.
Our value proposition: You can sit down for hours/days/weeks tuning a custom kernel. torch.compile provides good baseline performance so you don’t need to do that all the time!
The goal with the compiler is that you can spend most of your time thinking about the model, get the majority of the speedups, and only have to go down to custom kernel authoring when you’ve established an opportunity or need for further performance.
torch.compile offers some knobs for controlling the trade-off of execution performance with longer compile times. This is particularly useful for inference, where the same model will be running for a long time.
Passing the max-autotune mode to torch.compile instructs the compiler to test more implementation options for each operation. The compiler can use pre-built ATen kernels, leverage kernels from libraries like cuDNN or CUTLASS, or use templated Triton kernels. When autotuning, specific variants are tested on device with the shape information identified during tracing, and the fastest options are selected. Thanks to Triton templates, it can also apply fusions, where pointwise ops are folded into a single kernel, saving kernel launch overhead.
The downside of this is that testing the options takes more time, so using max-autotune can lead to some very extended compile times. You also need a hefty enough GPU to get the benefit: is_big_gpu gates it on the number of SMs, so it works best on a 3090, V100 or above.
You can see a lot of the autotuning options in _inductor/config.py. Backends that are considered are set separately for GEMMs and convolution ops:
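For example, a minimal sketch of the relevant settings (names as they appear in _inductor/config.py; the exact defaults vary by PyTorch version):

max_autotune_gemm_backends = "ATEN,TRITON,CPP"   # candidate backends considered for GEMMs
max_autotune_conv_backends = "ATEN,TRITON"       # candidate backends considered for convolutions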
Each kernel has implementations for the different backends, which are added to the possible choices. For example, in _inductor/kernels/mm.py you can see calls to use_[backend]_template that check whether the backend in question is a valid choice:
if is_nonzero and use_cutlass_template(layout, m, n, k):
    CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
If you run autotuning, you’ll get some log output, and caches will be written to /tmp/torchinductor_yourusername.
We can try this out on a simple MLP:
import torch, time

class SimpleMLP(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

# Set up device and model
device = 'cuda'
model = SimpleMLP(in_features=1024, hidden_features=1024, out_features=1024).to(device)
x = torch.randn(256, 1024, device=device)  # batch of 256, 1024 features each

# Compile the model in default mode
model_default = torch.compile(model, mode="default")

# Warm-up runs (to trigger compilation)
torch.compiler.reset()
with torch.no_grad():
    model_default(x)
torch.cuda.synchronize()  # ensure warm-up completes

# Measure performance of default compiled model
start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    for _ in range(50):
        _ = model_default(x)
    end.record()
torch.cuda.synchronize()
time_default_ms = start.elapsed_time(end) / 50.0

# Compile the model in max-autotune mode and warm up
torch.compiler.reset()
model_autotune = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    model_autotune(x)
torch.cuda.synchronize()  # ensure warm-up completes

# Measure performance of max-autotune compiled model
start = torch.cuda.Event(enable_timing=True); end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    for _ in range(50):
        _ = model_autotune(x)
    end.record()
torch.cuda.synchronize()
time_autotune_ms = start.elapsed_time(end) / 50.0

print(f"Average inference time - torch.compile default: {time_default_ms:.3f} ms")
print(f"Average inference time - torch.compile max-autotune: {time_autotune_ms:.3f} ms")
Disappointingly, this is the result:
Average inference time - torch.compile default: 0.113 ms
Average inference time - torch.compile max-autotune: 3.251 ms
We can turn on logging with the TORCH_LOGS env variable: some useful options are inductor, autotuning, and perf_hints.
TORCH_LOGS="perf_hints" python tune.py
You can control many more autotune options via the options flag, though it's incompatible with passing a mode value. We can recreate the max-autotune mode and turn on some useful tracing options like this (note that the option names use an underscore, while the mode uses a hyphen!):
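Something like the following rough sketch (the option names are assumed from _inductor/config.py and the mode definitions, so check them against your PyTorch version):

model_tuned = torch.compile(
    model,
    options={
        "max_autotune": True,          # note the underscore, vs the "max-autotune" mode string
        "triton.cudagraphs": True,     # max-autotune also enables cudagraphs
        "trace.enabled": True,         # write out debug trace artifacts
        "trace.graph_diagram": True,   # draw a diagram of the captured graph
    },
)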
The "trace.enabled": True and "trace.graph_diagram": True options generate trace outputs, including a nice diagram of the captured graph. Cudagraphs turned out to be the culprit here, which is common enough that there is a non-cudagraph mode, max-autotune-no-cudagraphs, available to save you having to remember all the options.
As you can see in the captured graphs of the two runs, the slower version actually has an extra fusion performed!
Captured graphs for the two runs
Triton Autotuning
Triton also conducts autotuning, but it’s a little more explicit. When authoring a Triton kernel you can specify configurations. At compile time each config variant will be tested, the most performant one picked and the choice stored for future calls. A key value can be provided to indicate when to re-autotune based on changing inputs:
import os
import torch
import triton
import triton.language as tl

# Just to save passing this on the command line
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=8, num_stages=2),
    ],
    key=['N']  # re-tune only if the length N changes
)
@triton.jit
def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x + y, mask=mask)

def vec_add(x: torch.Tensor, y: torch.Tensor):
    assert x.is_cuda and y.is_cuda
    N = x.numel()
    out = torch.empty_like(x)
    grid = (triton.cdiv(N, 128),)  # 128 = smallest BLOCK_SIZE we declared
    vecadd_kernel[grid](x, y, out, N)
    return out

x = torch.randn(1 << 20, device="cuda")  # 1,048,576 elements
y = torch.randn_like(x)
_ = vec_add(x, y)  # first call → autotuning prints to stdout
_ = vec_add(x, y)  # second call → no autotuning, uses the best config found
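One thing to note: because the grid above is fixed to the smallest declared BLOCK_SIZE, larger configs just launch extra fully-masked blocks. A more idiomatic option (a sketch, not part of the original snippet) is to size the grid from whichever config the autotuner picks:

grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)  # grid depends on the chosen BLOCK_SIZE
vecadd_kernel[grid](x, y, out, N)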
Setting the env variable TRITON_PRINT_AUTOTUNING documents the process as it goes:
There are a couple of different options to profile a Triton kernel.
Proton
Proton is the profiler that ships with Triton (a "profiler for Triton"). You can enable it and (optionally) activate/deactivate it around specific regions you want to profile. You can also annotate functions with specific metrics.
import triton.profiler as proton  # Proton ships with Triton

session = proton.start()  # Start profiling session
bias = torch.rand((256,), device='cuda', dtype=torch.float16)  # Bias vector
flops = 2 * M * N * K
bytes_accessed = A.element_size() * M*K + B.element_size() * K*N + C.element_size() * M*N  # rough bytes
with proton.scope(f"fused_gemm_bias_relu [M={M}, N={N}, K={K}]", {"flops": flops, "bytes": bytes_accessed}):
    fused_gemm_bias_relu[grid](
        A, B, C, bias,
        M, N, K,
        A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
    )
proton.finalize()
The output can be visualized with the built-in viewer:
proton-viewer -m time/ms,tflop/s ./proton.hatchet
0.040 6.645 ROOT
├─ 0.004   nan _ZN2at6native55_GLOBAL__N__11f7a751_22_DistributionUniform_[...]_15PhiloxCudaStateESH_SI_
└─ 0.037 7.284 fused_gemm_bias_relu [M=1024, N=256, K=512]
   └─ 0.037   nan fused_gemm_bias_relu
In this case you can see both the (trimmed!) generated name for the bias tensor set up as well as the name of my custom kernel.
nsight-compute
Nvidia also has a good range of tools for looking at performance. Note that you will need to enable access to the GPU performance counters on the device for this.
Nvidia ships Nsight Systems, which tracks larger system-wide metrics, and Nsight Compute, which is more focused on profiling kernel execution. You can run it against a script like so:
ncu -o profile_results python test.py
The tool comes with a nice GUI for inspecting the results. It can show you the PTX or SASS source for the kernels, offers metrics like actively used registers (good for checking on register spilling), and calls out warnings on poor utilization or memory clashes.
There is an extension coming for Proton that enables profiling within kernels. This reserves a pre-allocated buffer on device and logs metrics locally, for reading out at the end of the execution. It outputs as a chrome trace for use within a wide range of dev tools. While this isn’t merged into mainline yet, you can see an example of the usage in the dev repo.
JEPA is an example of a predictive coding architecture. Predictive coding is the idea that the brain works by predicting the next sensory input and learns from the error between the prediction and the actual input. Hierarchies of this kind of prediction allow higher level elements to predict the outputs of lower-level elements, building up deeper and more complex semantics.
The core idea in JEPA is to take two related things (say consecutive frames of a video), x and y, encode each of them into a latent space (an embedding), and then predict the embedding for y, s(y), based on s(x). The encoders can be Transformer-based models — in practice, models like I-JEPA train the x (context) encoder and update the y (target) encoder's weights via a moving average.
The learning is not based on how well the end-result predicts the target e.g. how close the pixels of the next frame are predicted. Instead, it’s based on how well the latent representation of the next frame is predicted.
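As a rough sketch of what a training step looks like under this setup (the encoder and predictor names here are illustrative, not from any specific JEPA paper):

import torch

def jepa_step(x, y, context_encoder, target_encoder, predictor):
    sx = context_encoder(x)                 # embed the context, e.g. the current frame
    with torch.no_grad():
        sy = target_encoder(y)              # embed the target; typically an EMA copy, no gradients
    pred_sy = predictor(sx)                 # predict the target's latent from the context's latent
    return torch.nn.functional.mse_loss(pred_sy, sy)  # error is measured in latent space, not pixels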
The advantage of working in the latent space for the prediction is that the model can choose what level of detail it wants to capture, discarding some aspects and focusing on more foundational concepts. This helps build a more robust world-model, with the hope being that training in this way will then allow easier generalization to more tasks, with less data required.
Similarities
This is somewhat similar to autoencoders. Autoencoders take an input, compress it in a latent space, then reconstruct the original from the latent space and propagate back the error. JEPA does a similar process across two different items with separate encoders, and only cares about error within the latent space.
Contrastive models embed two different items into the same space and try to increase similarity between the embeddings for things known to be similar and make them dissimilar to other items. This is used in CLIP and other multimodal text-image encoders, where the text and the image embed to the same space so that a text caption and a matching image are close in embedding space. This requires a lot of pairwise comparisons, while JEPA is a more straightforward s(x)->s(y) prediction in training.
Challenges
Because JEPA models leave you with a latent they need to be paired with a generator for getting an observable/human viewable output, which is a per-domain challenge. This makes it harder to evaluate how well the model is learning, beyond measuring loss.
Training stability can also be tricky — it is possible for the model to collapse and learn trivial representations to minimize prediction error. Even without complete collapse it can require some experimentation to ensure the model is learning a deep enough conceptual level. For example, I-JEPA, which worked in image space, found that using large enough masked patches was important to ensure the model captured sufficient detail.
Other than the helpful docs the Quansight folks maintain (py-free-threading), it's been interesting to see some projects and tools pop up that I hadn't seen. One is Zsolt's py-free-threading, which checks whether your project and deps have free-threaded wheels for any non-pure-Python deps, and which can be run as a one-liner thanks to uv:
Dives deep into TMEM in particular, and the trend over the last few Nvidia generations of special-casing GEMMs in hardware:
Tensor Memory and UMMA do for MMA just what TMA did for copy, making it a single-threaded, asynchronous operation that does not consume registers. As a result, registers can primarily be used for other tasks like scheduling and fused epilogue operations.
Edit: link no longer seems to be working! It was a great post though, so hopefully comes back! Edit edit: it did!
LLMs have most of their parameters in the FFN parts of the transformer layers — 50+ bn params of the Llama 3 70B model, for example. The compute and memory requirements are a bit different between the FFN and attention parts of the model: attention requires a different KV cache for each request, so attention tends to be memory bound while the dense FFNs tend to be compute bound.
Because of this it's pretty common to split up tasks at inference time. The initial prefill stage (processing the initial prompt) populates the KV cache for the following autoregressive decoding. The decode can then be more aggressively batched to get better utilization. vLLM really helped popularize this idea!
ByteDance extend this idea for mixture of expert models. In MoEs the compute intensity of the FFNs is limited by needing to load different experts, and having only a proportion of tokens going through a given expert. They extend the disagg idea to go from M “attention” GPUs to N (fewer!) expert GPUs, with a larger batch size for each of the expert calls. This gets better utilization on the matmuls and lowers overall cost of serving. The natural structure of transformer layers alternating attention and FFN lends itself well to a ping-pong pipelining approach that lets them hide the comms overhead.
We present MegaScale-Infer, an efficient and cost-effective system designed for large-scale MoE serving. MegaScale-Infer disaggregates the attention and expert modules, assigning them to separate GPUs—a strategy we term disaggregated expert parallelism. Our approach offers two major benefits. First, it enables independent scaling of each module with customized model parallelism strategies. Specifically, attention modules are replicated using data parallelism, while FFN modules are scaled with expert parallelism. By consolidating requests from multiple attention replicas, the GPU utilization of each expert increases significantly as the batch size per attention replica grows. Second, it enables the deployment of attention and FFN modules on heterogeneous GPUs to fully leverage their different capabilities and achieve lower costs. For example, attention modules can be deployed on GPUs with more cost-effective memory capacity and bandwidth, while FFN modules can utilize GPUs with more affordable compute capability. As shown in Figure 1(c), FFN can easily become compute-intensive in MegaScale-Infer, while attention achieves higher GPU utilization per cost under heterogeneous deployment.
Ed Yang was recently recommending keeping your own benchmark of LLM evals, so you can test newer models on problems that they have struggled with in the past. I have recommended similar things to people, but there is some barrier to entry in knowing how to start. Ed references (and forks) Nicolas Carlini's personal benchmark repo, but it's nice to have some light(ish)-weight options too.
Pydantic Evals is a powerful evaluation framework designed to help you systematically test and evaluate the performance and accuracy of the systems you build, especially when working with LLMs.
You can install the library with uv or pip:
uv add pydantic-evals
I tried it out with a strawberry test, calling openrouter with different models. I needed a custom eval as the default Contains is a bit rigid, but the approach seems nice!
import os
import asyncio
from dataclasses import dataclass
from typing import Any, Optional, cast

from openai import OpenAI
from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import Evaluator, EvaluationReason, EvaluatorContext
from pydantic_evals.evaluators.common import _truncated_repr


@dataclass
class FlexibleContains(Evaluator[object, object, object]):
    """
    Check if the output contains any one of the expected options.
    """
    value: Any
    case_sensitive: bool = False

    def evaluate(
        self, ctx: EvaluatorContext[object, object, object]
    ) -> EvaluationReason:
        failure_reason: Optional[str] = None
        # Normalize value into a list of options if it isn't already a list or tuple.
        options = self.value if isinstance(self.value, (list, tuple)) else [self.value]
        output_str = str(ctx.output)
        if not self.case_sensitive:
            output_str = output_str.lower()
        match_found = False
        for opt in options:
            opt_str = str(opt)
            if not self.case_sensitive:
                opt_str = opt_str.lower()
            if opt_str in output_str:
                match_found = True
                break
        if not match_found:
            failure_reason = (
                f"Output string {_truncated_repr(output_str, max_length=100)} does not contain "
                f"any of expected strings: {[str(opt) for opt in options]}"
            )
        return EvaluationReason(value=match_found, reason=failure_reason)


strawberry = Case(
    name="strawberry",
    inputs="How many rs are in strawberry?",
    evaluators=[FlexibleContains(value=["3", "three"])],
    metadata={"difficulty": "easy"},
)
dataset = Dataset(cases=[strawberry])

MODELS = [
    "anthropic/claude-3.5-sonnet",
    "openai/gpt-4o",
    "meta-llama/llama-4-maverick:free",
    "meta-llama/llama-4-scout:free",
    "openrouter/optimus-alpha",  # secret model!
]


def generate_completion(inputs: str, model: str) -> str:
    """Generate a completion using OpenRouter with specified model"""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.getenv("OPENROUTER_API_KEY"),
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": inputs},
        ],
        max_tokens=50,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()


def evaluate_models():
    """Run evaluations across multiple models"""
    for model in MODELS:
        print(f"\nResults for model: {model}")
        print("=" * 50)

        # Wrap the synchronous generate_completion in an async function:
        async def model_specific_generate(inputs: str) -> str:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, generate_completion, inputs, model)

        # Run evaluation for this model
        report = dataset.evaluate_sync(model_specific_generate)
        # Print results for this model
        report.print(include_input=True, include_output=True, include_durations=False)


def main():
    evaluate_models()


if __name__ == "__main__":
    main()
To give a trimmed output:
Results for model: openrouter/optimus-alpha
==================================================
Evaluation Summary: model_specific_generate
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Case ID ┃ Inputs ┃ Outputs ┃ Assertions ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ strawberry │ How many rs are in strawberry? │ The word **"strawberry"** contains **three** letter "r"s. │ ✔ │
├────────────┼────────────────────────────────┼───────────────────────────────────────────────────────────┼────────────┤
│ Averages │ │ │ 100.0% ✔ │
└────────────┴────────────────────────────────┴───────────────────────────────────────────────────────────┴────────────┘
JMMV shares his thoughts on builds beyond Bazel, and highlights Ed Schouten (of BuildBarn) and his experiments with Bonanza. Bonanza maintains Bazel compatibility, but is remote-execution-first, does analysis on the remote workers with a distributed cache to minimize cold builds, and has Starlark-only rules (as with Buck2). Also like Buck2, it's written in Rust.
When I was at Lyft we went through a laborious build system migration, including on Android going from Gradle to Buck(1) to Bazel, which at one point involved a full shim that allowed single build files to work on either Buck or Bazel. The idea of being able to keep the same build definitions but swap out engines is pretty appealing.
Julio actually calls for yet another Bazel replacement as well, this one more focused on the small project/local build case: I can definitely see the appeal there!
The time for these next-generation Bazel-compatible build systems is now. Google has spent the last 10 years Starlark-ifying Bazel, making the core execution engine replaceable. We are reaching a point where the vast majority of the build logic can be written in Starlark as Bonanza proves, and thus we should be able to have different build tools that implement the same build system for different use cases
Eugene Yan has put together a really extensive survey of recent research exploring the use of LLMs in recommendation systems.
Although early research in 2023—that applied LLMs to recommendations and search—often fell short, these recent efforts show more promise, especially since they’re backed by industry results. It suggests that there are tangible benefits from exploring the augmentation of recsys and search systems with LLMs, increasing performance while reducing cost and effort.
Recommendation systems are enormously important to a large swathe of tech businesses, primarily for e-commerce, content, and advertisement targeting. Traditional deep recommenders typically use a two-tower architecture: one tower for users and another for items, independently encoding features into embeddings that can be scored together to retrieve and rank items. Features in each tower include both sparse features (usually categorical, e.g., item categories, user histories) and dense features (often continuous, e.g., age, price).
This design is popular because it's effective and scalable: you can cache each tower's embedding vectors and only pull in the ones you need for a given query (e.g. the batch of users you are getting recommendations for right now).
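A minimal sketch of the two-tower idea (the layer sizes and feature handling here are illustrative, not from any particular system):

import torch

class TwoTower(torch.nn.Module):
    def __init__(self, user_dim, item_dim, embed_dim=64):
        super().__init__()
        self.user_tower = torch.nn.Sequential(
            torch.nn.Linear(user_dim, 128), torch.nn.ReLU(), torch.nn.Linear(128, embed_dim))
        self.item_tower = torch.nn.Sequential(
            torch.nn.Linear(item_dim, 128), torch.nn.ReLU(), torch.nn.Linear(128, embed_dim))

    def forward(self, user_feats, item_feats):
        u = self.user_tower(user_feats)   # user embeddings, computed per request
        v = self.item_tower(item_feats)   # item embeddings, typically precomputed and cached
        return u @ v.T                    # similarity scores used to retrieve and rank items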
Despite the effectiveness and scalability of this approach, traditional systems often struggle with a set of known issues, such as cold-start problems—predicting relevant content for new items or users—and typically don’t consider interaction recency without additional engineering.
Yan categorizes recent research into four areas:
LLM/Multimodal Architectures:
Directly embedding content understanding within the models. Content understanding has been used for a long time via separate models to generate additional metadata for content items to help both with cold start and accuracy.
Generative approaches, which reframe recommendation as predicting future user actions based on interaction sequences.
LLM-Assisted Data Generation and Analysis:
Improving content understanding and generating richer metadata for items.
Scaling Laws, Transfer, and Distillation:
Adapting LLMs to meet latency requirements of recommendations, through smaller models and efficient inference techniques. RecSys, particularly models for advertising, tend to have very low latency requirements.
Unified Architectures for Search and Recommendations:
Consolidating search and recommendation tasks into unified models that enable returning items based on interaction histories and/or user queries simultaneously.
There are a couple of common themes from reading the summaries:
Semantic Content Integration & Joint Tasks: Techniques like YouTube’s Semantic IDs and Kuaishou’s M3CSR generate content-based identifiers replacing traditional hashed IDs. The idea is to have inputs to the models represent the content in a way that carries meaning, rather than represent an identifier for the content.
Efficiency in Inference: Teacher-student distillation and efficient fine-tuning allow generating smaller, performant models for specific needs. For instance, Alibaba’s MLoRA trains a base model then LoRA fine-tunes for specific types of content, replacing a number of independently trained models.
These two combine somewhat to enable a trend towards more foundation-model-like training in RecSys that tackles a variety of user personalization tasks with a unified view of users, content, and user/content interactions.
Dynamic shapes are one of the more distinctive parts of torch.compile. Rather than specializing a graph to static shapes (which works in many cases!), PyTorch’s approach allows a single graph to work for a variety of sizes, so things like sequence length or batch size can vary. It does this by reasoning about shapes symbolically: instead of using fixed shape values, it uses placeholders and infers rules that constrain those shapes.
Tracing & Symbolic Shapes
PyTorch uses tracing in Python (via Dynamo) to capture the graph of operations. By default, it marks shapes as static during tracing. If a shape marked as static changes at runtime, it is marked as dynamic and treated symbolically in a recompilation. You can also proactively mark a dimension as dynamic to encourage symbolic treatment from the start:
torch._dynamo.mark_dynamic(x, 0) # Mark dim 0 as dynamic
Under the hood, PyTorch uses SymPy to represent and manipulate symbolic shapes. Each dynamic shape is replaced with a SymInt and tracked by a central ShapeEnv.
Every operation in PyTorch has a meta function — a lightweight implementation that computes metadata like shape changes without actually performing the computation. This lets PyTorch propagate symbolic shapes through the graph. For example, concatenating two tensors along dimension zero is represented symbolically as:
s0 = s_x0 + s_y0
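You can see the same idea in miniature with meta tensors, which propagate shapes without allocating or moving any data (the shapes here are just for illustration):

import torch

a = torch.empty(4, 8, device="meta")
b = torch.empty(3, 8, device="meta")
c = torch.cat([a, b], dim=0)   # only metadata is computed, no real work happens
print(c.shape)                 # torch.Size([7, 8])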
To support branching logic, symbolic shapes carry a “hint” of their current concrete value. This allows specific branches of conditionals like if tensor.size(0) > 2 to be taken during tracing based on the hint. PyTorch adds a guard at this point to ensure that the resulting graph is only used if that branch is the correct one.
Guards
Guards are runtime checks inserted into the compiled graph to ensure the assumptions made during tracing still hold. For example, in the case of tensor.size(0) > 2, if the tensor is the result of concatenation, the guard will check that a.size(0) + b.size(0) > 2. If this fails, the code is retraced and a graph for the new branch is generated. Multiple graphs can be cached and selected at runtime based on guard validation.
Guards don’t need to assert exact sizes; they can use symbolic constraints like x.size(0) > 2. This allows dimensions to vary within bounded ranges. The backend compiler (usually Inductor) can then compile code that operates over symbolic dimensions, as long as the variability is within the guarded constraints.
For example, operations like broadcasting typically generalize well to symbolic shapes. In contrast, if an op specializes on a fixed shape (e.g., optimized path for 1D input), it may require conditional tracing and guards.
What this means in practice is that most of the time compilation will follow this process:
Take a batch of data, assume all shapes are static, insert guards, and pass static sizes to the compiler
On the next batch see which guards have been violated, mark those dimensions as dynamic, add appropriate guards and pass symbolic dimensions to the compiler
Assuming no control flow, continue to reuse this dynamic graph without recompilation (a short sketch of this flow follows below)
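A small sketch of that flow in practice (the sizes are arbitrary):

import torch

@torch.compile
def f(x):
    return torch.nn.functional.relu(x) * 2

f(torch.randn(8, 16))    # first call: traced assuming static shapes
f(torch.randn(32, 16))   # new batch size: dim 0 becomes dynamic, one recompilation
f(torch.randn(64, 16))   # reuses the dynamic graph, no further recompilation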
Backed vs. Unbacked SymInts
Most symbolic shapes are backed, meaning they have an associated concrete value at trace time. These are usually derived from inputs and show up in traces as s0, s1, etc.
UnbackedSymInts lack a concrete value. These arise from data-dependent operations, e.g.:
n = (x > 1).nonzero().size(0)
Here, n depends on the data in x, so its size cannot be known at trace time. It will be represented as an unbacked SymInt like u0.
If a control flow decision depends on an unbacked SymInt, tracing cannot proceed, resulting in a graph break or a GuardOnDataDependentSymNode error (when fullgraph=True).
However, you can guide the compiler with additional constraints, e.g.:
torch._check(x.size(0) == y, lambda: f"size mismatch: {x.size(0)} != {y}")
This lets PyTorch treat x.size(0) as equivalent to y throughout the graph. The check will be validated at runtime.
There are other APIs to help mark unbacked SymInts as size-like to enable meta function compatibility (see Ed’s docs for more).
Controlling Dynamic Shape Usage
You can control dynamic behavior in torch.compile with the dynamic flag:
Not passed: default shape inference behavior
dynamic=False: force all shapes to be static
dynamic=True: treat all shapes as dynamic
The default is usually best, but dynamic=True can help in testing.
Use fullgraph=True to attempt to generate a single, complete graph without breaks. This is often critical for performance, as graph breaks can drastically affect runtime, and it's easy to make innocuous-looking code changes that trigger additional breaks!
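For example, combining the two flags:

compiled = torch.compile(model, fullgraph=True, dynamic=True)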
Trent Nelson has written an extremely detailed breakdown of his experiments with running inference on GPT-2 on PyTorch and the GIL-free version of Python from 3.13 and 3.14.
He implements parallel generation using multiple threads (on one GPU and later multiple devices) and parallel model loading, and then covers some of the challenges with torch.compile (which doesn't work great with nogil yet!)
Hopefully this encourages more folks to experiment with free-threaded Python, or perhaps port their existing Python packages to play nicely when installed in a free-threaded Python environment. I personally can’t wait until free-threaded Python is the default! Although that’s probably at least five or so years out at this point.
Free threaded python really changes the performance trade-offs around Python, and I expect it to be the default for ML work a lot sooner than that!
Further support for elicitation, and a very good example of why it's worth starting with the evals.
They started with evals for mathematical reasoning, and then tested base DeepSeek, Qwen and Llama models with different templates to see how they did prior to any RL. In doing so they discovered that Qwen did best with no template, and that DeepSeek V3 would already create "aha" moments (reasoning self-reflection) without any further tuning. Some of the examples are amusing, particularly the "awkward silence":
In Pascal’s Triangle, every row starts and ends with 1, … … This can be calculated as: awkward silence Wait, I’m overthinking. Let’s try again. The number of elements in the first n rows of Pascal’s Triangle…
The second interesting takeaway for me was that the GRPO implementation (along with most PPO implementations) have a length bias: wrong but long answers are preferred over wrong but short ones. Their corrected, unbiased approach resulted in the same performance, but better token efficiency:
We also revisit the GRPO’s optimization bias with the Llama base model. The right plot of Fig. 8 compares the model performance and response length trained with GRPO and Dr. GRPO [Ed: the unbiased version]. We can clearly see that GRPO can produce the “double-increase” phenomenon, potentially leading to a misperception that long-CoT can also emerge on Llama models after math pretraining. Unfortunately, the increase of length might be due to the optimization bias
When data is in the global memory on a GPU it’s usually in row-major or column-major order. Loading from global memory is quite slow though, so for performance we want to move the data to shared memory for the threads in a warp to work on.
To make that load from global memory performant we want memory reads to be coalesced, meaning we read a contiguous chunk of memory at a time. Shared memory, on the other hand, is divided into banks, typically 32 banks that are each 4 bytes wide. If multiple threads in the same warp try to write to different addresses in the same bank then the requests are processed sequentially, slowing things down while the threads wait on each other. Nsight and other profiling tools will helpfully point this out to you!
For example, let’s say we’re loading a row major and column major tensor, and will be doing a multiplication between them (this is naive, to demonstrate the issue):
__shared__ float Asub[TILE_DIM][TILE_DIM];
__shared__ float Bsub[TILE_DIM][TILE_DIM]; // (No padding in this naive version)
int lane = threadIdx.x; // 0...31 (warp lane index)
int tileRow = blockIdx.y * TILE_DIM;
int tileCol = blockIdx.x * TILE_DIM;
int globalRow = tileRow + lane;
int globalCol = tileCol + lane;
Asub[lane][0] = A[globalRow * N + tileCol + 0];
Bsub[lane][0] = B[(tileRow + lane) + (tileCol + 0) * N];
Now when we fill Bsub we will be writing everything to the same shared memory bank, significantly slowing things down. One easy fix is just to add padding:
__shared__ float Asub[TILE_DIM][TILE_DIM]; // A tile (row-major, no conflict in our case)
__shared__ float Bsub[TILE_DIM][TILE_DIM + PAD]; // B tile (extra column to prevent conflicts)
With PAD as 1 (and TILE_DIM as 32) each row of the tile is now 33 floats, or 132 bytes, offsetting the writes and ensuring that each thread gets its own bank.
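A quick back-of-the-envelope check of the bank mapping (assuming 32 banks, each 4 bytes wide, and the column-0 writes from the example above):

def bank_of(float_index: int) -> int:
    return float_index % 32   # one 4-byte float per bank slot, 32 banks

TILE_DIM = 32
print([bank_of(lane * TILE_DIM) for lane in range(4)])         # no padding: [0, 0, 0, 0], every lane hits bank 0
print([bank_of(lane * (TILE_DIM + 1)) for lane in range(4)])   # PAD=1: [0, 1, 2, 3], one bank per lane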
The downside is that this wastes shared memory, a scarce resource, so an alternative approach is swizzling: changing the layout such that consecutive thread accesses aren’t causing bank conflicts. That’s what Bert implemented to get performance in his recent GEMM walkthrough, but it’s easy to get it wrong.
To make life easier than writing this in raw CUDA, CUTLASS has a system called CuTe, a set of templates to express the layout of data:
auto tileLayout = make_layout(make_shape(Int<32>{}, Int<32>{}), GenRowMajor{});
auto swizzledLayout = composition(Swizzle<5, 0, 5>{}, tileLayout);
Here you specify how the data is laid out in global memory with the shape and stride, then make_layout and the copy operation take care of translating from the row-major layout in global memory to the swizzled layout in shared memory.
From a Triton perspective, Lei Zhang has a great post on memory access, and how it works in Triton, specifically the LinearLayout class that allows the language to similarly handle swizzling and layouts for you:
Indeed the whole point of LLs is that they allow us to specify transposed and swizzled layouts as a “general case”. Instead of a layout class for registers in a thread, and another layout for registers in a thread but in MMAv2 order, and so on, all of these can be represented by different LLs. This gets rid of special cases and lets us write more general code.
There’s a great colfax report on building GEMMS that covers shared memory bank conflicts, and Lei Mao has a post with a nice illustration. Axel Feldman also has a post about benchmarking different approaches and identifying bank conflicts, and some more efficient loading techniques.
Very useful insight in this paper out of Stanford.
Test-time inference has emerged as a powerful paradigm for enabling language models to “think” longer and more carefully about complex challenges, much like skilled human experts. While reinforcement learning (RL) can drive self-improvement in language models on verifiable tasks, some models exhibit substantial gains while others quickly plateau.
The authors were running a reasoning post-training process on both Qwen 2.5 3B and Llama 3.2 3B. They noticed that while both learned, Llama was consistently worse than Qwen, which feels odd as both models are strong. In looking at the reasoning approaches exhibited they observed 4 distinct reasoning strategies:
verification
backtracking
subgoal setting
backward chaining
They noticed that Qwen exhibited these behaviors more from the base model, and those behaviors were enhanced more in the RL process.
While the larger Llama-3.1-70B showed generally increased activation of these behaviors compared to Llama-3.2-3B, this improvement was notably uneven — backtracking, in particular, remained limited even in the larger model.
They then generated some custom reasoning traces that intentionally demonstrated all 4 behaviors, using Claude.
We generate these datasets using Claude-3.5-Sonnet, leveraging its ability to produce reasoning trajectories with precisely specified behavioral characteristics. While Claude does not always produce the correct answer (see Fig. 9), it consistently demonstrates the requested reasoning patterns, providing clean behavioral primitives for our analysis.
They found that when using the SFT set before RL they closed most of the gap between Llama and Qwen. They also found that it isn’t even important that the reasoning traces are correct – demonstrating the behavior is more important than the reasoning itself at the SFT stage.
Priming models with cognitive behaviors, by a small amount of finetuning, enabled significant performance gains even in models that initially lack these capabilities. Remarkably, this holds even when primed with incorrect solutions that exhibit the target behavioral patterns, suggesting that cognitive behaviors matter more than solution accuracy.
This fits with the elicitation idea — the SFT is training a style. By having increased activation of the reasoning styles the RL process is more able to explore these capabilities and reinforce extended reasoning generation.
This also fits with my mental model that a base model’s capabilities are often pretty underexplored: a combination of targeted SFT + RL seems to be a very powerful elicitation tool!