Dynamic Shapes in PyTorch

Dynamic shapes are one of the more distinctive parts of torch.compile. Rather than specializing a graph to static shapes (which works in many cases!), PyTorch’s approach allows a single graph to work for a variety of sizes, so things like sequence length or batch size can vary. It does this by reasoning about shapes symbolically: instead of using fixed shape values, it uses placeholders and infers rules that constrain those shapes.

Tracing & Symbolic Shapes

PyTorch uses tracing in Python (via Dynamo) to capture the graph of operations. By default, it marks shapes as static during tracing. If a shape marked as static changes at runtime, it is marked as dynamic and treated symbolically in a recompilation. You can also proactively mark a dimension as dynamic to encourage symbolic treatment from the start:

torch._dynamo.mark_dynamic(x, 0)  # Mark dim 0 as dynamic

Under the hood, PyTorch uses SymPy to represent and manipulate symbolic shapes. Each dynamic shape is replaced with a SymInt and tracked by a central ShapeEnv.

Every operation in PyTorch has a meta function — a lightweight implementation that computes metadata like shape changes without actually performing the computation. This lets PyTorch propagate symbolic shapes through the graph. For example, concatenating two tensors along dimension zero is represented symbolically as:

s0 = s_x0 + s_y0
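
This shape propagation can be sketched in plain Python (a hypothetical `Sym` class for illustration, not torch internals): the meta function combines symbolic sizes without ever touching tensor data.

```python
# Minimal sketch (not the real torch implementation): symbolic sizes are
# small expression objects, and a meta function combines them.
class Sym:
    def __init__(self, name):
        self.name = name
    def __add__(self, other):
        return Sym(f"{self.name} + {other.name}")
    def __repr__(self):
        return self.name

def cat_meta(shape_x, shape_y, dim=0):
    # Meta function for concatenation: computes the output shape only.
    out = list(shape_x)
    out[dim] = shape_x[dim] + shape_y[dim]  # symbolic addition, no data
    return tuple(out)

out_shape = cat_meta((Sym("s_x0"), 8), (Sym("s_y0"), 8))
print(out_shape)  # (s_x0 + s_y0, 8)
```

In real PyTorch the expressions are SymPy objects tracked by the ShapeEnv, but the principle is the same: shapes flow through the graph as expressions in the input symbols.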

To support branching logic, symbolic shapes carry a “hint” of their current concrete value. This allows specific branches of conditionals like if tensor.size(0) > 2 to be taken during tracing based on the hint. PyTorch adds a guard at this point to ensure that the resulting graph is only used if that branch is the correct one.
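
The hint-plus-guard mechanism can be sketched like this (hypothetical classes, not torch internals): a comparison on a hinted symbol resolves using the hint, and the comparison is recorded as a guard.

```python
# Sketch: a symbolic int that carries a concrete "hint". Comparisons
# evaluate against the hint so tracing can pick a branch, and each
# comparison is recorded as a guard on the resulting graph.
guards = []

class HintedSym:
    def __init__(self, name, hint):
        self.name, self.hint = name, hint
    def __gt__(self, other):
        result = self.hint > other
        guards.append((f"{self.name} > {other}", result))
        return result

size0 = HintedSym("s0", hint=4)
if size0 > 2:        # the hint (4) says True, so the True branch is traced
    branch = "big"
else:
    branch = "small"
print(branch, guards)  # big [('s0 > 2', True)]
```

The recorded guard is what lets PyTorch later decide whether the traced graph is still valid for a new input.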

Guards

Guards are runtime checks inserted into the compiled graph to ensure the assumptions made during tracing still hold. For example, in the case of tensor.size(0) > 2, if the tensor is the result of concatenation, the guard will check that a.size(0) + b.size(0) > 2. If the guard fails, the code is retraced and a graph for the new branch is generated. Multiple graphs can be cached and selected at runtime based on guard validation.

Guards don’t need to assert exact sizes; they can use symbolic constraints like x.size(0) > 2. This allows dimensions to vary within bounded ranges. The backend compiler (usually Inductor) can then compile code that operates over symbolic dimensions, as long as the variability is within the guarded constraints.

For example, operations like broadcasting typically generalize well to symbolic shapes. In contrast, if an op specializes on a fixed shape (e.g., optimized path for 1D input), it may require conditional tracing and guards.

What this means in practice is that most of the time compilation will follow this process:

  • Take a batch of data, assume all shapes are static, insert guards, and pass static sizes to the compiler
  • On the next batch see which guards have been violated, mark those dimensions as dynamic, add appropriate guards and pass symbolic dimensions to the compiler
  • Assuming no control flow, continue to reuse this dynamic graph without recompilation
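
The static-first policy above can be sketched as a small cache of (guard, graph) pairs (a toy simulation, not torch internals):

```python
# Sketch of the static-first policy: compile with exact sizes first, and
# on a guard failure recompile with the dimension treated symbolically.
compiled = []  # list of (guard, graph) pairs, checked in order

def run(n):
    for guard, graph in compiled:
        if guard(n):
            return graph  # cache hit: an existing graph's guards hold
    if not compiled:
        # First batch: specialize, guarding on the exact size.
        graph = f"static graph, n == {n}"
        compiled.append((lambda m, k=n: m == k, graph))
    else:
        # A size guard failed: mark the dimension dynamic.
        graph = "dynamic graph, symbolic n"
        compiled.append((lambda m: True, graph))
    return graph

print(run(4))   # static graph, n == 4
print(run(8))   # static guard fails -> dynamic graph, symbolic n
print(run(16))  # reuses the dynamic graph, no recompilation
```

In real torch.compile the guards are symbolic constraints checked by the ShapeEnv rather than Python lambdas, but the caching and fallback-to-dynamic behavior follows this shape.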

Backed vs. Unbacked SymInts

Most symbolic shapes are backed, meaning they have an associated concrete value at trace time. These are usually derived from inputs and show up in traces as s0, s1, etc.

Unbacked SymInts lack a concrete value. These arise from data-dependent operations, e.g.:

n = (x > 1).nonzero().size(0)

Here, n depends on the data in x, so its size cannot be known at trace time. It will be represented as an unbacked SymInt like u0.
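
The difference from a backed symbol can be sketched directly (hypothetical class, not torch internals): with no hint, there is nothing to resolve a branch against, so the only honest behavior is to fail.

```python
# Sketch: an unbacked symbolic int has no hint, so any branch on it
# cannot be resolved at trace time.
class UnbackedSym:
    def __init__(self, name):
        self.name = name
    def __gt__(self, other):
        raise RuntimeError(
            f"data-dependent branch on {self.name}: no hint available"
        )

u0 = UnbackedSym("u0")
try:
    if u0 > 2:  # analogous to branching on a data-dependent size
        pass
except RuntimeError as e:
    print(e)  # data-dependent branch on u0: no hint available
```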

If a control flow decision depends on an unbacked SymInt, tracing cannot proceed, resulting in a graph break or a GuardOnDataDependentSymNode error (when fullgraph=True).

However, you can guide the compiler with additional constraints, e.g.:

torch._check(x.size(0) == y, lambda: f"size mismatch: {x.size(0)} != {y}")

This lets PyTorch treat x.size(0) as equivalent to y throughout the graph. The check will be validated at runtime.

There are other APIs, such as torch._check_is_size, to help mark unbacked SymInts as size-like to enable meta function compatibility (see Ed’s docs for more).

Controlling Dynamic Shape Usage

You can control dynamic behavior in torch.compile with the dynamic flag:

  • Not passed: default shape inference behavior
  • dynamic=False: force all shapes to be static
  • dynamic=True: treat all shapes as dynamic

The default is usually best, but dynamic=True can help in testing.

Use fullgraph=True to attempt to generate a single, complete graph without breaks. This is often critical for performance, as graph breaks can drastically affect runtime, and it’s easy to make innocuous-looking code changes that trigger additional breaks!
