Paper from the Triton folks at OpenAI on their solution to the layouts/data movement problem. Data often needs to be laid out in a specific way to maximize performance on a GPU. That includes matching the operand layouts that certain instructions (such as tensor core MMAs) expect, and avoiding bank conflicts in shared memory. You might have data stored nicely in global memory, need to permute it to load, then permute it again for execution.
Part of the appeal of CuTe is expressing these layouts and allowing a relatively simple algebra to transform them between these domains. This works, but the Triton approach is to try to hide this type of complexity, particularly hardware-specific complexity, in the compiler.
While both CUTE and linear layouts aim to address the challenge of flexible task mapping on emerging architectures, they differ in several key aspects. First and foremost, CUTE is primarily designed for users to manually describe layouts, whereas linear layouts are integrated into a compiler. Second, the linear algebra framework of linear layouts enables compilers to generate efficient code for layout conversion and code lowering for many common operators, which is absent in CUTE. Third, swizzling is inherently defined within linear layouts, whereas in CUTE, it is treated as a separate step.
The clever insight is that you can represent these layouts (at least for power-of-two shapes) as binary matrices over F₂, the field with two elements, where addition is XOR and multiplication is AND. You can compose those binary matrices freely, and it’s also easy to replace the transform matrix with a new one for hardware that requires a different permutation.
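To make the F₂ idea concrete, here is a minimal sketch (not Triton's API; the column-list representation and the function names are my own). A layout is stored as the list of its matrix columns, each column being where one input bit gets sent, so applying the layout is just XOR-ing the selected columns. The example is an assumed XOR swizzle on a 4 × 4 tile, which is exactly the kind of permutation that avoids shared-memory bank conflicts, and it is its own inverse.

```python
def apply_layout(columns, x):
    """Multiply an F2 matrix (a list of column bit-vectors) by the bits of x.

    Over F2, addition is XOR, so the result is the XOR of the columns whose
    corresponding input bit is set.
    """
    out = 0
    for i, col in enumerate(columns):
        if (x >> i) & 1:
            out ^= col
    return out


def compose(outer, inner):
    """F2 matrix product outer @ inner: apply `inner` first, then `outer`."""
    return [apply_layout(outer, col) for col in inner]


# A 4x4 tile indexed by 4 bits: bits 0-1 = column, bits 2-3 = row.
# XOR swizzle (assumed example): new_col = col ^ row, row unchanged.
# Column i of the matrix is where input basis vector (1 << i) ends up.
swizzle = [
    0b0001,  # col bit 0 -> col bit 0
    0b0010,  # col bit 1 -> col bit 1
    0b0101,  # row bit 0 -> row bit 0, and flips col bit 0
    0b1010,  # row bit 1 -> row bit 1, and flips col bit 1
]
identity = [1 << i for i in range(4)]

# Row 3, column 1 lands in row 3, column 1 ^ 3 = 2.
print(apply_layout(swizzle, (3 << 2) | 1) == (3 << 2) | 2)
# Composing the swizzle with itself gives the identity matrix.
print(compose(swizzle, swizzle) == identity)
```

Swapping hardware then means swapping the `swizzle` columns; the code that applies and composes layouts doesn't change.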
To give a step-by-step example (as I’m not totally sure how well I grok this myself!) let’s say we are working on an MMA for a 16×8 tile:
We start with our data, say in row-major order (0,0), (0,1), …, (0,7), (1,0), and so on. Each value is stored in its own register.
We have 32 threads, each managing its own section of the block: in this case 4 registers each, since the 16 × 8 = 128 values are split across 32 threads
So we have a hardware location for each value: the thread (0..31) and the register (0..3). You can imagine this as 7 bits of data: a 5-bit thread ID and a 2-bit register ID
Equivalently, we can track the tensor location for each value: 4 bits for rows 0..15 and 3 bits for columns 0..7
We can have a map which translates between tensor location and hardware location: block location row 1 col 0 is in thread 2 register 0. This would be a 7 by 7 binary matrix
We can define a matrix that transforms the hardware map to the one needed for our ldmatrix tensor core call.
For example, we might need thread 0 to manage tensor values (0,0), (4,0), (8,0), (12,0)
If the mapping requires moving a value to a different register in the same thread we can use a prmt (permute) instruction
If the mapping requires moving values between threads’ registers, we can use a warp shuffle like shfl.sync, which allows swapping registers between threads without using shared memory¹
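The 16 × 8 walkthrough can be checked with a small sketch. It assumes the bit assignments from the steps above (register in the low 2 hardware bits, thread in the high 5; column in the low 3 tensor bits, row in the high 4) and a hypothetical target layout built to match the thread-0 example. It is not the real layout of any particular instruction, and Triton's actual lowering does much more. The sketch computes the conversion matrix (the inverse of the target composed with the current layout) and checks whether any value has to change thread, which is the signal that a shuffle rather than an in-thread permute is needed.

```python
# Hardware index h: bits 0-1 = register (4 per thread), bits 2-6 = thread (32 threads).
# Tensor index t:   bits 0-2 = column (0..7),           bits 3-6 = row (0..15).
REG_BITS = 2
N_BITS = 7


def apply_layout(columns, x):
    """F2 matrix-vector product: XOR together the columns selected by the bits of x."""
    out = 0
    for i, col in enumerate(columns):
        if (x >> i) & 1:
            out ^= col
    return out


def compose(outer, inner):
    """Columns of the F2 matrix product outer @ inner (apply inner first)."""
    return [apply_layout(outer, col) for col in inner]


def invert(columns):
    """Inverse of an invertible layout; brute force is fine for 7 bits."""
    n = len(columns)
    return [next(x for x in range(1 << n) if apply_layout(columns, x) == 1 << j)
            for j in range(n)]


# Row-major layout (hardware -> tensor): value k lives in thread k // 4,
# register k % 4, so h == t and the matrix is the identity.
row_major = [1 << i for i in range(N_BITS)]

# Hypothetical target layout: register bits step through rows 4 and 8,
# thread bits 0-2 step through columns 1, 2, 4, thread bits 3-4 through rows 1, 2.
target = [4 * 8, 8 * 8, 1, 2, 4, 1 * 8, 2 * 8]

# Conversion: current hardware location -> required hardware location.
conversion = compose(invert(target), row_major)


def thread_of(h):
    return h >> REG_BITS


moves = {h: apply_layout(conversion, h) for h in range(1 << N_BITS)}
cross_thread = sum(1 for h, g in moves.items() if thread_of(h) != thread_of(g))
# (row, col) of the values thread 0 must hold in the target layout.
thread0_values = [divmod(apply_layout(target, reg), 8) for reg in range(4)]

print(thread0_values)
print(f"{cross_thread} of {1 << N_BITS} values must move to another thread")
```

Because the conversion is itself just a matrix, "does any thread bit change?" becomes a question about its columns rather than something to discover element by element.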
Triton has layouts for standard block level storage, and for MMAs and other operations. By multiplying through the required mappings it can automatically work out how best to optimize movement, versus the manual transforms you do in CuTe!
It also has versions of these mappings for different hardware, so for many operations only the layouts need to be swapped out when moving from Ampere to Hopper or Blackwell!
¹ Mostly: if there would be bank conflicts, it will spill to shared memory.
Back in 2022, Google’s Pathways paper proposed (revisiting) a single-controller approach for managing machine learning runs. Typically, ML jobs use an SPMD (Single Program, Multiple Data) approach, distributing identical code across multiple hosts. Each host runs independently, synchronizing during collective operations. This works, as evidenced by the many large training runs in the world. It also introduces complexity, especially with pipeline parallelism where conditional logic for different ranks can clutter up your training code. Even without that, subtle issues can arise: for example, slight differences in torch.compile optimization have (in the past!) led to deadlocks by placing collectives differently on separate nodes.
The single-controller model simplifies this by centralizing program execution on one main node and using generic workers on the hosts that execute assigned tasks. This provides a consistent, global view of the entire computation, making it easier to get to a correct implementation of parallelisms and other distributed work. This doesn’t come for free though: the main node must efficiently manage (potentially!) thousands of GPUs without becoming a bottleneck, and existing code must adapt to this new centralized model.
Monarch is the PyTorch team’s implementation of this single-controller concept. It provides a familiar PyTorch frontend, additional module wrappers, and a high-performance Rust-based actor system for distributing and managing work.
The fundamental abstraction in Monarch is the Actor. Each Actor runs on its own accelerator and encapsulates its own state and behavior. Communication with other Actors happens via async message passing to methods decorated with @endpoint. The nice thing about the programming model is that you can interact with all of the actors in your mesh in a consistent way.
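To show the shape of that model without depending on Monarch's real API, here is a toy in plain asyncio. `ToyActor` and `ToyMesh` are hypothetical names, not Monarch classes. Each actor owns private state and handles messages one at a time from its inbox, and a single controller broadcasts the same call to every actor and gathers the replies, which is roughly what `ddp_actor.demo_basic.call()` looks like from the notebook side.

```python
import asyncio


class ToyActor:
    """Owns private state; other code only reaches it through async messages."""

    def __init__(self, rank):
        self.rank = rank
        self.calls = 0
        self.inbox = asyncio.Queue()

    async def serve(self):
        # Handle one message at a time, so the actor's state never races.
        while True:
            method, args, reply = await self.inbox.get()
            if method is None:  # shutdown message
                return
            self.calls += 1
            reply.set_result(getattr(self, method)(*args))

    def whoami(self):
        return (self.rank, self.calls)


class ToyMesh:
    """Single controller: sends the same message to every actor and gathers replies."""

    def __init__(self, size):
        self.actors = [ToyActor(rank) for rank in range(size)]
        self.tasks = [asyncio.create_task(a.serve()) for a in self.actors]

    async def call(self, method, *args):
        loop = asyncio.get_running_loop()
        replies = [loop.create_future() for _ in self.actors]
        for actor, reply in zip(self.actors, replies):
            await actor.inbox.put((method, args, reply))
        return await asyncio.gather(*replies)

    async def stop(self):
        for actor in self.actors:
            await actor.inbox.put((None, (), None))
        await asyncio.gather(*self.tasks)


async def main():
    mesh = ToyMesh(size=2)  # created inside the running event loop
    first = await mesh.call("whoami")
    second = await mesh.call("whoami")
    await mesh.stop()
    return first, second


first, second = asyncio.run(main())
print(first, second)
```

The real system adds process placement, RDMA and supervision on top, but the controller-side view (one call, one gathered result per actor) is the part that makes notebooks pleasant.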
Monarch is appealing even if you’re not GPU-rich. For instance, at home, I have a machine equipped with two (mismatched) 3090s, and Monarch allows me to run and debug jobs directly in notebooks without relying on external services.
Installation had minor hurdles because I built from source rather than using the available pip package. Although the README specifies Python 3.10, Python 3.13 worked fine. The dependencies reference dnf (reflecting Meta’s internal Linux distro choice), so adapting commands to other Linux distributions (Ubuntu in my case) was necessary. Additionally, I had to set BINDGEN_EXTRA_CLANG_ARGS="-I/usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11" to resolve Rust compilation issues.
Once installed, running Monarch’s distributed data-parallel notebook was straightforward (see: monarch/examples/notebooks/spmd_ddp.ipynb). The notebook shows that minimal code changes to the standard DDP example are required, mainly subclassing Actor (e.g., class DDPActor(Actor)), while keeping the training loop virtually identical. Monarch handles the rest, including distributed execution and debugging across multiple GPUs.
Setting up the environment means providing the mesh configuration and launching the actors, which can be done from a cell:
```python
# Spawn a process mesh
local_proc_mesh = await proc_mesh(
    gpus=WORLD_SIZE,
    env={
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12455",
    },
)
# Spawn our actor mesh on top of the process mesh
ddp_actor = await local_proc_mesh.spawn("ddp_actor", DDPActor)
```
I didn’t have to manually start any other services; it all happened under the hood. Triggering the run is just:
```python
await ddp_actor.demo_basic.call()
```
Which output:
```
self.rank=0 Running basic DDP example
self.rank=1 Running basic DDP example
self.rank=1 Finished running basic DDP example
self.rank=0 Finished running basic DDP example
```
What I find really appealing is how easy it is to execute across ranks. For example, to query for system info:
```python
print("Gathering system info from all ranks...")
system_info = await ddp_actors.get_system_info.call()
print("\n SYSTEM INFORMATION ACROSS ALL RANKS:")
print("=" * 60)
for point, rank_info in system_info:
    print(f"Rank {rank_info['rank']}: PID={rank_info['process_id']}, "
          f"Device={rank_info['device_name']}, "
          f"GPU Memory={rank_info['gpu_memory_allocated']/1e6:.1f}MB")
print(f"\nFound {len(system_info)} ranks in the mesh")
all_rank_info = [value for point, value in system_info]
print(f"Total GPU memory across all ranks: {sum(info['gpu_memory_allocated'] for info in all_rank_info)/1e6:.1f}MB")
```
Outputting:
```
Gathering system info from all ranks...
[Rank 0] System Info: PID=10519, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
[Rank 1] System Info: PID=10520, CPU=0.1%, RAM=5.2%, GPU_MEM=0.0MB
SYSTEM INFORMATION ACROSS ALL RANKS:
============================================================
Rank 0: PID=10519, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
Rank 1: PID=10520, Device=NVIDIA GeForce RTX 3090, GPU Memory=0.0MB
Found 2 ranks in the mesh
Total GPU memory across all ranks: 0.1MB
```
I can also stop training and dump state if I need to, making it easy to check norms and debug:
```python
print("Running training steps...")
for step in range(3):
    print(f"\n--- Step {step + 1} ---")
    step_results = await ddp_actors.train_step.call()
    all_results = [value for point, value in step_results]
    losses = [result['loss'] for result in all_results]
    grad_norms = [result['grad_norm'] for result in all_results]
    throughputs = [result['throughput'] for result in all_results]
    print(f"Losses across ranks: {[f'{l:.4f}' for l in losses]}")
    print(f"Gradient norms: {[f'{g:.4f}' for g in grad_norms]}")
    print(f"Avg throughput: {sum(throughputs)/len(throughputs):.1f} samples/sec")
```
While the distributed stuff here is cool, it’s not wildly different from using a distributed framework like Ray with a little bit of setup (though I am pretty allergic to setup). What is most interesting is how this changes the programming model of PyTorch, making it really easy to build out distributed experiments.
For example, if I were building a parameter server, the sync only requires an awaited read of the weights from another object, taking advantage of the RDMA support for an efficient copy¹:
```python
@endpoint
async def worker_sync_with_ps(self, param_server) -> bool:
    """Synchronize with parameter server and get RDMA handles"""
    self._log("Synchronizing with parameter server...")
    # Get RDMA buffer handles
    self.weight_buffers = await param_server.ps_get_weight_handles.call_one()
    self.gradient_buffers = await param_server.ps_get_gradient_handles.call_one()
    # Get metadata
    metadata = await param_server.ps_get_metadata.call_one()
    self.weight_metadata = metadata['weights']
    self.gradient_metadata = metadata['gradients']
    # Perform initial weight sync
    sync_time = await self._sync_weights_from_ps()
    self._log(f"Synchronized with parameter server (sync time: {sync_time:.3f}s)")
    return True
```
Getting those weight buffers is as simple as creating the right Monarch object:
Tokenization has always struck me as one of the odder aspects of natural language deep learning. Despite the extensive end-to-end learning processes we typically use, tokenization starts with building a dictionary of optimal sub-word segments from your dataset. One of the appealing concepts in the Byte Latent Transformer (BLT) paper is the potential to learn tokenization dynamically, recognizing that tokenizers solve deeper problems than merely providing a fixed vocabulary.
This paper addresses tokenization from a theoretical perspective by modeling sequences as kth-order Markov processes, where the probability of each symbol depends on the preceding k symbols, loosely as in natural language. The parameter k corresponds to the model’s context window size. Key findings include:
Training without tokenization leads models to effectively behave as unigram predictors, significantly limiting performance.
Using a well-designed tokenizer (e.g., Byte Pair Encoding – BPE) enables models to achieve nearly optimal performance in capturing sequence dependencies.
Increasing the tokenizer’s dictionary size improves the model’s performance, moving it closer to the ideal probability distribution.
Tokenizers which do a good job at learning patterns in the data and assigning these frequent patterns as tokens in the dictionary are compatible with an i.i.d. model over tokens.
This insight suggests that despite the complexity of natural language, a good tokenizer converts sequences into something approximating an independent and identically distributed (i.i.d.) format, which brings the modeling task for transformers closer to one they can solve.
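A toy calculation makes the first finding and the i.i.d. point tangible. This is my own illustration in the spirit of the paper, not its analysis: the source is a binary order-1 (k = 1) Markov chain that flips symbol with probability δ, and the "tokenizer" is an idealized one that makes every maximal run ("000", "11", …) a single token, with an unbounded dictionary rather than BPE. A unigram model over characters is stuck at 1 bit per character, while a unigram model over run tokens costs δ + H(δ) bits per character, close to the optimal H(δ) when δ is small (for large δ runs are short and this toy tokenizer stops helping).

```python
import math


def h2(p):
    """Binary entropy in bits."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)


def optimal_bits_per_char(delta):
    # Entropy rate of the chain: each step is a biased coin "flip or stay".
    return h2(delta)


def char_unigram_bits_per_char(delta):
    # The stationary distribution is 50/50, so the best unigram model over
    # characters pays 1 bit per character whatever delta is.
    return 1.0


def run_token_unigram_bits_per_char(delta):
    # Idealized tokenizer: each maximal run is one token. Run lengths are
    # i.i.d. Geometric(delta), with mean 1/delta characters and entropy
    # h2(delta)/delta bits. A unigram model over tokens also pays 1 bit per
    # token for which symbol the run repeats.
    bits_per_token = 1.0 + h2(delta) / delta
    chars_per_token = 1.0 / delta
    return bits_per_token / chars_per_token


for delta in (0.01, 0.05, 0.1):
    print(f"delta={delta}: optimal={optimal_bits_per_char(delta):.3f}, "
          f"run-token unigram={run_token_unigram_bits_per_char(delta):.3f}, "
          f"char unigram={char_unigram_bits_per_char(delta):.3f}")
```

Because run lengths are i.i.d. here, the token stream really is i.i.d. (apart from the alternating symbol), which is exactly why a model that ignores context does nearly as well as the optimum once it sees tokens instead of characters.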
While the paper does not explicitly explore the Byte Latent approach, I wonder if its entropy-driven dynamic token allocation might achieve a similar i.i.d. simplification. In BLT, the entropy model, trained separately, could dynamically transform inputs into a distribution that is more palatable for transformers.
Filing this under interesting work I will probably never use. The authors try to construct a more accurate simulation of the Ampere (A100/3090-class GPUs) microarchitecture, backed by extensive testing on real hardware. It’s a good reminder that, in part because of how good some of the abstractions are, there is quite a lot about Nvidia GPUs that isn’t really known outside Nvidia. My main takeaway was that the compiler is very deeply coupled to hardware performance: an Nvidia chip is not really a complete unit without taking into account the software driving it, and recognizing that helps explain why Nvidia have done such a good job of building a solid stack with CUDA.
One of the things I found interesting was the use of a Stall counter: the compiler knows the latency of fixed-latency instructions (which seem to be a preferred design choice) and sets a counter in each instruction’s control bits that specifies how many cycles the warp should wait before issuing its next instruction; in the meantime, other warps can be selected for execution. This means the hardware doesn’t have to dynamically check for data dependencies.
For example, an addition whose latency is four cycles and its first consumer is the following instruction encodes a four in the Stall counter. Using the methodology explained in section 3, we have verified that if the Stall counter is not properly set, the result of the program is incorrect since the hardware does not check for RAW hazards, and simply relies on these compiler-set counters. In addition, this mechanism has benefits in terms of area and energy wiring. Keep in mind that wires from fixed-latency units to the dependence handling components are not needed, in contrast to a traditional scoreboard approach where they are required.
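The compiler's side of this can be sketched as a toy pass. Assumptions, all mine: a single in-order warp, at most one issue per cycle, only RAW hazards between fixed-latency instructions, and hypothetical register names and latencies. Each instruction's Stall value is the gap until the next instruction may issue, and the hardware simply trusts it.

```python
def stall_counters(program):
    """Toy compiler pass: compute a Stall value for each instruction.

    program: list of (dest_register, source_registers, latency) tuples for
    fixed-latency instructions, in program order. Stall[i] is the number of
    cycles to wait after issuing instruction i before issuing i + 1.
    """
    ready_at = {}       # register -> first cycle its new value can be read
    issue_cycles = []
    for dest, sources, latency in program:
        if issue_cycles:
            earliest = issue_cycles[-1] + 1  # at most one issue per cycle
            operands_ready = max((ready_at.get(r, 0) for r in sources), default=0)
            cycle = max(earliest, operands_ready)
        else:
            cycle = 0
        issue_cycles.append(cycle)
        ready_at[dest] = cycle + latency
    # Gap between consecutive issues; the last instruction gets the default of 1.
    stalls = [nxt - cur for cur, nxt in zip(issue_cycles, issue_cycles[1:])]
    return stalls + [1]


# The paper's example: a 4-cycle add whose consumer is the very next instruction.
print(stall_counters([("R1", ["R2", "R3"], 4), ("R4", ["R1"], 4)]))  # [4, 1]
# Independent work in between lets the compiler use shorter stalls.
print(stall_counters([
    ("R1", ["R2", "R3"], 4),
    ("R5", ["R6"], 4),
    ("R7", ["R8"], 4),
    ("R4", ["R1"], 4),
]))  # [1, 1, 2, 1]
```

If a pass like this got a value wrong, nothing in the hardware would catch it, which is the incorrect-result behaviour the authors observed.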
Variable-latency instructions, like memory loads, use a Dependence counter instead: it is decremented when the data arrives, and consumers wait until it reaches zero.
In the vein of handing off to the compiler, the scheduler uses a Compiler Guided Greedy Then Youngest policy: it keeps issuing instructions from the same warp (greedy), guided by the Stall counter (and an explicit Yield bit), and otherwise switches to the youngest ready warp. Older GPUs (apparently!) used Greedy Then Oldest instead, which more often selected a warp that was still stalled waiting for memory or similar, while the youngest warp is more likely to have useful work to do.
The scheduler starts issuing instructions from the youngest warp, which is W3, until it misses in the I-cache. As a result of the miss, W3 does not have any valid instruction, so the scheduler switches to issue instructions from W2. W2 hits in the I-cache since it reuses the instructions brought by W3, and when it reaches the point where W3 missed, the miss has already been served, and all remaining instructions are found in the I-cache, so the scheduler greedily issues that warp until the end. Later, the scheduler proceeds to issue instruction from W3 (the youngest warp) until the end, since now all instructions are present in the I-cache.
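The selection rule itself is small enough to sketch. This is a toy, not the paper's simulator: `Warp` and its fields are hypothetical, "ready" stands in for everything that can block a warp (Stall counter, dependences, an I-cache miss), and a larger launch order means a younger warp. The usage replays the W3/W2 scenario from the quote.

```python
from dataclasses import dataclass


@dataclass
class Warp:
    wid: int
    launch_order: int        # larger = launched later = younger
    ready: bool = True       # False while blocked (stall, dependence, I-cache miss)
    yield_bit: bool = False  # compiler hint to give up the issue slot


def pick_warp(current, warps):
    """Compiler-guided Greedy Then Youngest (toy version)."""
    # Greedy: keep issuing from the current warp while it can and hasn't yielded.
    if current is not None and current.ready and not current.yield_bit:
        return current
    # Otherwise switch to the youngest other ready warp.
    others = [w for w in warps if w.ready and w is not current]
    if others:
        return max(others, key=lambda w: w.launch_order)
    return current if current is not None and current.ready else None


warps = [Warp(wid=i, launch_order=i) for i in range(4)]  # W3 is the youngest
current = pick_warp(None, warps)     # starts on W3
warps[3].ready = False               # W3 misses in the I-cache
current = pick_warp(current, warps)  # falls back to W2
warps[3].ready = True                # miss served
current = pick_warp(current, warps)  # greedy: stays on W2
print(current.wid)
```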
Similarly, the paper points out that the instruction prefetcher is a stream buffer (probably 16 instructions deep) rather than anything relying on complex branch prediction logic, because GPUs don’t predict branches!
a straightforward prefetcher, such as a stream buffer, behaves close to a perfect instruction cache in GPUs. This is because the different warps in each sub-core usually execute the same code region and the code of typical GPGPUs applications do not have a complex control flow, so prefetching 𝑁 subsequent lines usually performs well. Note that since GPUs do not predict branches, it is not worth implementing a Fetch Directed Instruction prefetcher [76] because it would require the addition of a branch predictor.
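A toy miss counter shows why a stream buffer is enough for straight-line code. Assumptions, all mine: an unbounded cache that never evicts, prefetches that always arrive in time, and a buffer that only checks its head entry, as a classic stream buffer does. Sequential code misses once; only a jump (a taken branch) costs another miss.

```python
def count_misses(trace, depth):
    """Count I-cache misses for a trace of line numbers with a stream buffer.

    depth: number of sequential lines prefetched after each access (0 = none).
    """
    cache = set()   # idealized: unbounded, never evicts
    stream = []     # prefetched line numbers, oldest first
    misses = 0
    for line in trace:
        if line in cache:
            continue
        if stream and stream[0] == line:
            stream.pop(0)       # hit at the head of the stream buffer
        else:
            misses += 1
            stream = []         # restart the stream after a miss
        cache.add(line)
        # Keep the buffer filled with the next `depth` sequential lines.
        nxt = stream[-1] + 1 if stream else line + 1
        while len(stream) < depth:
            stream.append(nxt)
            nxt += 1
    return misses


straight_line = list(range(100))
with_branch = list(range(10)) + list(range(50, 60))
print(count_misses(straight_line, depth=0), count_misses(straight_line, depth=16))
print(count_misses(with_branch, depth=16))
```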
Interesting paper breakdown on Gonzo ML of the Darwin Gödel Machine (DGM), another evolutionary agent approach from the extended Sakanaverse.
It commences with an initial coding agent, constructed upon a frozen foundation model (FM) equipped with tool-use capabilities (e.g. running bash commands, editing files). In each cycle, “parent” agents are selected from the expanding archive. This selection process prioritizes agents based on a combination of their performance (assigning greater weight to higher scores, scaled by sigmoid) and a novelty bonus (inversely correlated with the number of offspring they have already produced, thereby encouraging exploration of less-frequented paths).
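The quote doesn't give exact formulas, so here is a toy version of that selection rule. Assumptions: the performance term is a sigmoid of the benchmark score with a hypothetical sharpness `lam` and `midpoint`, the novelty term is 1 / (1 + number of children), and the two are multiplied; the DGM paper's exact constants and combination may differ.

```python
import math
import random


def parent_selection_probs(scores, n_children, lam=10.0, midpoint=0.5):
    """Toy parent-selection probabilities for an archive of agents.

    scores: benchmark score per archived agent (0..1).
    n_children: how many offspring each agent already has.
    lam, midpoint: hypothetical sigmoid parameters.
    """
    weights = []
    for score, children in zip(scores, n_children):
        performance = 1.0 / (1.0 + math.exp(-lam * (score - midpoint)))  # sigmoid-scaled score
        novelty = 1.0 / (1.0 + children)  # fewer offspring -> bigger exploration bonus
        weights.append(performance * novelty)
    total = sum(weights)
    return [w / total for w in weights]


probs = parent_selection_probs(scores=[0.6, 0.6, 0.2], n_children=[0, 3, 0])
parents = random.choices(range(len(probs)), weights=probs, k=2)  # sample two parents
print([round(p, 3) for p in probs], parents)
```

Two agents with the same score end up with very different odds once one of them has already spawned several children, which is the exploration pressure the quote describes.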
The actual foundation model is a frozen component, so much like AlphaEvolve this is a search layered on top of the model’s intelligence. The search evolves the agent code itself to try to do better on benchmarks.
Qualitatively, the DGM learned to enhance its own tools and workflows. For instance, it developed more granular file editing capabilities (e.g., string replacement), improved long-context window management (e.g., auto-summarizing prior interactions), and refined its problem-solving strategies (e.g., making multiple attempts at a solution and using another FM to evaluate patches). These discovered improvements also demonstrated generalizability, transferring benefits across different underlying FMs and programming languages.
When it comes to coding agents I had been thinking there were three axes of performance, which gate the overall effectiveness, but the paper makes it clear there are at least 4:
The foundation model itself, with its base coding, tool use, reasoning abilities and context window size
The tools it has available – the more a tool exposes underlying semantics, the more efficiently the model can use it.
The UI, how the user interacts with the agent to direct it, provide clarity and review work.
The prompt, strategies for problem solving and how the context window is managed (e.g. when to summarize)
In this case the UI is held fixed (an outer eval loop), the model is fixed, and the search explores tools and strategies. It seems likely that also searching across multiple different models would work well!
File this under the “gross oversimplifications” category. The basic approach to keeping GPUs busy is dividing the work into tiles, smaller sub-problems that make up the larger result. For a GEMM you might break the matrix into 128×128 or 128×64 tiles and let each CUDA thread block (CTA) own one tile. The GPU has many streaming multiprocessors (an A100 has 108) and every SM picks up one CTA at a time. If you want to know how many SMs your own card has you can call:
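One way to check from PyTorch, assuming a CUDA-enabled install, is the `multi_processor_count` field of `torch.cuda.get_device_properties` (the same number CUDA exposes as `multiProcessorCount`). The wrapper below is a hypothetical helper that returns None instead of failing on a machine without PyTorch or a GPU.

```python
def sm_count(device=0):
    """Number of SMs on a CUDA device, or None if PyTorch or a GPU isn't available."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(device).multi_processor_count


print(sm_count())  # 108 on an A100
```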
Tiles are launched in waves. A full wave is the moment when every SM is busy with exactly one CTA. If the total number of tiles isn’t a multiple of the SM count, the final wave is only partly full and some SMs sit idle; Nvidia calls that wave quantization. There is a similar problem at the edge of the matrix: if the dimensions aren’t multiples of the tile size the right-most or bottom-most tiles are partly empty, wasting threads (tile quantization). Sometimes a smaller tile size (for example 64 × 64) gives higher overall throughput because it leaves less unused space at the edges.
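Both effects are simple arithmetic. This sketch keeps the simplification above of one CTA per SM at a time; the 4160 × 4160 problem is just an example chosen so that 128-wide tiles leave a half-empty edge (4160 = 32 × 128 + 64), while 64-wide tiles divide it exactly.

```python
import math


def quantization(M, N, tile_m, tile_n, num_sms):
    """Wave and tile quantization for an M x N output split into tile_m x tile_n tiles."""
    tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    waves = math.ceil(tiles / num_sms)
    wave_eff = tiles / (waves * num_sms)            # fraction of CTA slots doing work
    tile_eff = (M * N) / (tiles * tile_m * tile_n)  # fraction of tile area inside the matrix
    return tiles, waves, wave_eff, tile_eff


for tile in (128, 64):
    tiles, waves, wave_eff, tile_eff = quantization(4160, 4160, tile, tile, num_sms=108)
    print(f"{tile}x{tile}: {tiles} tiles, {waves} waves, "
          f"wave eff {wave_eff:.1%}, tile eff {tile_eff:.1%}")
```

These numbers only capture quantization losses; larger tiles usually do more useful work per byte loaded, which is why the trade-off has to be measured rather than read off this calculation.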
The usual cure for poor wave utilization is a persistent kernel. Instead of launching one CTA per tile, you launch (roughly) one CTA per SM and have each CTA pull tiles from a global queue until the queue is empty. Because each CTA pulls a new tile whenever it is ready, the SMs rarely go idle and the tail effect is reduced.
Inside an SM the main performance levers for GEMMs are the Tensor Cores, which execute matrix multiply-accumulate (MMA) instructions efficiently. On Ampere you use WMMA instructions: one warp (32 threads) computes a 16 × 16 fragment at a time. Hopper introduces WGMMA instructions, where four warps acting as a warp-group (128 threads) execute a larger matrix multiply (M = 64, with N up to 256). To issue WGMMA you must place the right-hand operand B in shared memory; A can sit in either registers or shared memory. The operation is asynchronous, so while a warp-group is processing one tile the same CTA can be pre-loading the next tile.
Blackwell pushes the idea further. A pair of CTAs on neighbouring SMs can cooperate on a single paired MMA, letting two SMs’ tensor cores process an even larger tile.
To make that possible Hopper introduced thread-block clusters and Blackwell extends them. When you launch a kernel you can group CTAs into clusters such that the scheduler guarantees to place them on SMs inside the same GPC (Graphics Processing Cluster), so they share a fast interconnect and can access shared memory across SMs. If the grid doesn’t divide cleanly into whole clusters you also lose efficiency on the tail (is this cluster quantization? stick with the trend Nvidia!), so Blackwell has a Cluster Launch Control that can shrink the last cluster to better fit the work.
Loading Data
All of this only works if data is present in shared memory. The first optimization is making sure (global) memory access is coalesced. Global memory is fetched in 32-byte sectors, grouped into 128-byte cache lines, so what matters is how many sectors a warp’s 32 loads touch. If the threads of a warp read consecutive 4-byte values (addresses 0, 4, 8, …, 124), the memory system can coalesce them into a single 128-byte request. If the addresses are strided (e.g. each thread hopping to a different row), each thread’s 4 bytes land in a different sector, so the warp needs many more transactions and most of every fetched sector is wasted. Getting this right is about ensuring the memory layout is set up for the kernel, and doing any transforms needed in shared memory before executing.
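Counting sectors makes the difference concrete. A sketch assuming 4-byte elements and 32-byte sectors, ignoring caches: the contiguous pattern touches 4 sectors (128 bytes, all useful), while reading down a column of a row-major matrix 1024 floats wide touches a separate sector per thread.

```python
def sectors_touched(addresses, sector_bytes=32, access_bytes=4):
    """Number of distinct 32-byte sectors touched by one warp's loads."""
    sectors = set()
    for a in addresses:
        first = a // sector_bytes
        last = (a + access_bytes - 1) // sector_bytes
        sectors.update(range(first, last + 1))
    return len(sectors)


contiguous = [4 * t for t in range(32)]      # thread t reads element t
strided = [4 * 1024 * t for t in range(32)]  # thread t reads row t of a 1024-wide matrix

for name, addrs in (("contiguous", contiguous), ("strided", strided)):
    n = sectors_touched(addrs)
    useful = 32 * 4 / (n * 32)               # bytes requested / bytes fetched
    print(f"{name}: {n} sectors, {useful:.1%} of fetched bytes used")
```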
In older GPUs the warp had to wait on the copy operation. Ampere enabled cp.async plus non-blocking arrive/wait barriers, so a warp can initiate a copy from global to shared memory and immediately continue with arithmetic. Hopper adds the Tensor Memory Accelerator (TMA): a single thread in the CTA can describe a multidimensional block to copy, and the TMA hardware streams it into shared memory while the threads do something else. TMA can also multicast a single load into the shared memory of every CTA in a cluster, which is helpful when multiple CTAs are about to reuse the same B tile.
In practice you hide latency by organizing the main loop so that it double buffers: while the tensor cores work on tile k the TMA or cp.async engine is fetching tile k + 1 into the other half of shared memory; then you swap buffers and repeat. As long as copy time and compute time overlap well, the tensor cores and the copy engines stay saturated.
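A small event-driven sketch shows the payoff. It assumes one copy engine, one compute unit, fixed per-tile load and compute times in arbitrary units, and a load that may only start once its shared-memory slot has been released by the compute that last used it. With one buffer everything serializes; with two, loads hide behind compute.

```python
def pipeline_time(n_tiles, load, compute, buffers):
    """Finish time of the last compute, with `buffers` shared-memory slots.

    A tile's load can start once the copy engine is free and its slot
    (tile k % buffers) has been released by the compute of tile k - buffers.
    """
    load_done, compute_done = [], []
    copy_free = 0
    for k in range(n_tiles):
        slot_free = compute_done[k - buffers] if k >= buffers else 0
        start = max(copy_free, slot_free)
        load_done.append(start + load)
        copy_free = load_done[k]
        prev = compute_done[k - 1] if k > 0 else 0
        compute_done.append(max(load_done[k], prev) + compute)
    return compute_done[-1]


print(pipeline_time(8, load=2, compute=3, buffers=1))  # 40: load and compute serialize
print(pipeline_time(8, load=2, compute=3, buffers=2))  # 26: loads hide behind compute
```

With two buffers the steady state costs max(load, compute) per tile, so once loads are hidden the only way to go faster is to speed up compute, or the other way around.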
Choosing the right tile size
Choosing the right tile size (often expressed in Triton as BLOCK_M × BLOCK_N) is a balance between each of these: enough threads to issue a warp-group MMA, small enough tiles that the matrix edges aren’t mostly padding, enough shared-memory space to double-buffer, and a grid size that fills whole waves or is run via a persistent kernel. Autotuning in Triton or CUTLASS can empirically test different options on the hardware, but it helps to have the right mental model about what sets of sizes they should consider. One good clue that you’re missing an option is when you see a sudden drop in achieved TFLOP/s for a particular shape.
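The hard constraints can be used to shortlist candidates before autotuning ranks them. A sketch under stated assumptions: BLOCK_M must be a multiple of 64 (one Hopper warp-group MMA covers 64 rows), each pipeline stage holds a BLOCK_M × BLOCK_K tile of A and a BLOCK_K × BLOCK_N tile of B in 2-byte elements, and the 100 KiB shared-memory budget is a hypothetical figure rather than any specific GPU's limit.

```python
def smem_bytes(bm, bn, bk, stages=2, elem_bytes=2):
    # Each pipeline stage holds one BM x BK tile of A and one BK x BN tile of B.
    return stages * (bm * bk + bk * bn) * elem_bytes


def feasible_tiles(candidates, bk, smem_budget, stages=2, elem_bytes=2, min_m=64):
    """Keep (BLOCK_M, BLOCK_N) pairs that satisfy the hard constraints."""
    keep = []
    for bm, bn in candidates:
        if bm % min_m != 0:  # too few rows for a warp-group MMA (assumed minimum)
            continue
        if smem_bytes(bm, bn, bk, stages, elem_bytes) > smem_budget:  # can't double-buffer
            continue
        keep.append((bm, bn))
    return keep


candidates = [(32, 64), (64, 64), (128, 128), (128, 256), (256, 256)]
print(feasible_tiles(candidates, bk=64, smem_budget=100 * 1024))
# [(64, 64), (128, 128), (128, 256)]
```

The survivors would then be checked for wave and edge efficiency and timed on the real hardware.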
AMD
AMD’s MI300X hardware takes a somewhat different route. The GPU is divided into chiplets, where each chiplet has its own compute units and multiple schedulers that schedule wavefronts (AMD’s term for warps, with 64 threads rather than 32) independently, so the hardware load-balances multiple kernels by itself. Matrix instructions run at the wavefront level; there is no cross-CU equivalent to WGMMA. Latency hiding relies on launching a large grid of workgroups and letting the hardware interleave them, rather than on explicitly scheduling async copies. On AMD the guidance is to mostly focus on high occupancy and coalesced memory access, whereas on NVIDIA there is value in crafting (by hand or compiler) the copy–compute pipeline.
Don’t blindly tie every piece of work to top-level metrics. Even if technically feasible, the cost is too high and the risk of spurious logic chains significant.
Start with Value Definition
Begin each project with a crisp definition of why it’s worth doing. What underlying problem are you solving, and why is that problem worth solving? Once you have these narrative assertions, it’s usually clear how extraordinary or controversial each claim is.
The more notable the claim, the more likely you need data to support it.
Value Metrics
1. Direct outcome metrics (strongest): We will run an ongoing experiment measuring profit per user with the feature on vs baseline.
2. Strong correlative metrics: This will increase time on site, which we can measure and believe correlates with profit per user.
3. Anecdotes and feedback: N sales team members report they could sell into more accounts if we launch this feature.
4. Strategic assertions: We must do this because of upcoming regulation or we will be unable to continue this business line.
Progress Measurement
Once your value claim is clear and defensible, identify how you’ll measure progress. This may differ from your value metric. Ideal progress metrics tell you whether you’re succeeding, respond quickly to team actions, and have strong reference baselines.
1. Clear baseline, goal, and team-tied metric (strongest): Launching this compressor will reduce binary download size by an estimated 10% vs the best available industry baseline. We can measure relative progress continually against our production binary during development.
2. Responsive metric without clear reference point: We can improve compile times on this fixed codebase from today’s 90s baseline.
3. Non-responsive metric: We can measure weekly mobile app release binary size, comparing the new compressor to our old implementation.
4. Milestones: We will implement passes a, b, c, after which we can ship the new compiler optimizations to target customers.
Common Challenges
Stronger measures aren’t always a worthwhile tradeoff. If you have high confidence in the work’s value and applicability and mainly need to validate progress, milestones can be completely reasonable.
In general, approach projects with skepticism about whether this is the right thing to do and whether you’ll make good progress. Then identify ways to get concrete data rather than rely solely on leadership support.
Leadership pressure for top-level metrics: The clearer you are on why you’re doing something, the easier it becomes to communicate your measurement decisions. If leadership continues pushing back and you have a good relationship, use that as a lens to explore concerns you might have missed. Often requests for metric clarity stem from deeper worries about project value or plausibility.
Team dynamics and gaming: Metrics communicate value in performance reviews, creating incentives for engineers to identify unnecessary correlations or game metrics (intentionally or not). Counter this with “health” metrics that balance negative behaviors: if measuring deployment frequency, also measure production incidents with a flat target to prevent trading off reliability for speed.
For senior engineers concerned about optics, have them clearly articulate the value chain. Working through and demonstrating strong data usage in project steering is itself a highly rewarded skill — encourage them to take on that role.
When to revisit metrics: The world changes. What’s sensible in one environment isn’t in another. Constantly relitigating is a headache, but reevaluating logic at regular intervals (say, each half-year planning cycle) is appropriate. Otherwise, maintain awareness of company trends: new projects, initiatives, or teams gaining significant attention. If they were successful, would that change what you do? There’s no hard rule: it’s corporate decision-making taste that develops with experience.
Great post by the folks at General Reasoning on the combination of factors that led to O1-type breakthroughs in inference time compute.
But here is the key point: no-one suddenly discovered that reinforcement learning was useful for reasoning. It was always useful, but getting some of the details right was the difference between a good post-training recipe and a paradigm shift in the way we use language models.
ML research is prone to these lollapalooza effects where several positive facts coincide to produce a much larger than expected result. You can go look at the launch of ChatGPT for another example: ChatGPT wasn’t a surprise for folks who had spent time with large language models, and had seen attempts like Galactica before. But for many people it was a remarkable, new experience, and the engagement and interaction ChatGPT saw was new to the researcher community. That itself contributed to further breakthroughs and improvements.
Fused Linear Cross-Entropy is a popular optimization that combines the final linear projection and cross-entropy loss into a single operation. This fusion is very valuable for training large language models efficiently, as it can reduce memory usage significantly, particularly for larger vocabularies.
If you look at an LLM training loop, you generally see something like:
```python
logits = model(input_ids)
loss = cross_entropy(logits, targets)
```
And if you look at the end of the model, the last step projects h, the hidden state so far, onto the vocabulary with output = nn.Linear(embed_dim, vocab_size, bias=False), i.e. logits = output(h).
That final logits tensor can be pretty big: sequence lengths are long and the vocabulary is large (128k for Llama 3, 202k for Llama 4), so logits can take GBs of memory. With a 16k context window and a 128k vocab, even at a batch size of 1 you get about 2.1bn entries (16,384 × ~128k; the embedding dimension doesn’t affect this). At BF16 (2 bytes each), that’s about 4GB. You’ll also need to store the gradient, which will give you another 4GB in the backward pass.
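The arithmetic as a one-liner, assuming Llama 3's vocabulary of 128,256 tokens and 2-byte BF16 elements:

```python
def logits_bytes(batch, seq_len, vocab, bytes_per_elem=2):
    """Memory for a [batch, seq_len, vocab] logits tensor (BF16 = 2 bytes by default)."""
    return batch * seq_len * vocab * bytes_per_elem


fwd = logits_bytes(batch=1, seq_len=16 * 1024, vocab=128_256)
print(f"logits: {fwd / 1e9:.1f} GB, plus the same again for their gradient")
# logits: 4.2 GB, plus the same again for their gradient
```

Memory scales linearly in batch, sequence length and vocabulary, so doubling the context or moving to Llama 4's larger vocabulary grows it proportionally.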
That set of logits has a range of values that are a bit all over the place, one for each of the possible targets.
Roughly speaking, cross entropy measures the similarity of two probability distributions. In the context of neural networks, it’s common to use cross entropy as a loss function for classification problems where:
$q$ is our vector of predicted probabilities, i.e. the softmax of our raw network outputs (also called logits, denoted $\hat{y}$), that is $q = \mathrm{softmax}(\hat{y})$
$p$ is a one-hot encoded vector of our label, that is a probability vector that assigns 100% probability to the position $y$ (our label for the correct class): $p_i = \begin{cases} 1 & i = y \\ 0 & i \neq y \end{cases}$
This means that cross-entropy simplifies to F.nll_loss(F.log_softmax(x, 1), target)
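Writing the one-hot case out explicitly, using the $p$, $q$ and $\hat{y}$ defined above, shows that the loss for a token only needs two quantities: the log-sum-exp of all the logits and the single logit of the correct class. That is what makes it possible to avoid keeping every logit around.

$$
H(p, q) = -\sum_i p_i \log q_i = -\log q_y = -\log \frac{e^{\hat{y}_y}}{\sum_j e^{\hat{y}_j}} = \log \sum_j e^{\hat{y}_j} - \hat{y}_y
$$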
Softmax makes our previously messy logits into a nice probability distribution where all the values are positive and sum to one. Log softmax is usually used in LLMs, for numerical stability and efficiency.
When we implement log softmax, the naive implementation looks something like:
```python
import torch

out = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))
```
This isn’t numerically stable, so you want to subtract the max to avoid overflows and underflows in the exp. This is the common log-sum-exp implementation:
```python
import torch

x_max = torch.max(x)
shifted_x = x - x_max
exp_shifted = torch.exp(shifted_x)
out = shifted_x - torch.log(torch.sum(exp_shifted))
```
In general the memory and compute cost of this grows with the size, which gets painful for our hefty logits. We can instead keep a running log-sum-exp so we don’t have to deal with the whole input at once.
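```python
import torch

# xs: 1-D tensor of logits, consumed one element at a time
lse = xs[0]
for x in xs[1:]:
    m = torch.max(torch.stack([lse, x]))
    # Rescale both terms relative to the running max before combining
    lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))
out = lse  # log(sum(exp(xs))); log softmax is then xs - out
```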
This is the online log-sum-exp approach, and it makes our life easier! We can now compute incrementally, but we are still generating the big logits beforehand.
Fused Linear Cross-Entropy replaces the output projection, softmax and loss calculation with a single kernel that tiles across all of it.
This is the core of the idea: instead of computing all logits at once (which creates a massive tensor), we can:
Compute logits for small chunks of the vocabulary
Compute the softmax incrementally
Only store the logits we need for the loss calculation
This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.
Roughly, the loop looks like:
For each token i in the sequence, with hidden state h_i and output-layer weights W:
Compute a partial dot product s_i = h_i · W_tile for one tile of W
Reduce into a running max and exp-sum
Return only the s_i[target_i] needed for the loss (alongside the final log-sum-exp)
This gives you quite a lot of memory wins, which also reduce peak memory bandwidth needs. But this also introduces some potential pain!
You’re fusing the final layer op into the loss, which might be defined in different places in your model code
You’re accumulating, so you have to use fp32 or risk introducing numeric errors
You have to write your own backward op as well, which will generally do some extra computation, so you are paying some extra FLOPs
Debugging can be harder, so you want a good recipe prior to swapping in the op
May require some futzing for best implementations on different hardware.
Actually implementing it is pretty straightforward.
```python
import torch

# forward() of a torch.autograd.Function subclass
@staticmethod
def forward(ctx, h, W, target):
    B, D = h.shape
    V, _ = W.shape
    chunk_size = min(1024, V)
    # Initialize online softmax accumulators
    max_logits = torch.full((B,), -float('inf'), device=h.device, dtype=torch.float32)
    sum_exp = torch.zeros(B, device=h.device, dtype=torch.float32)
    target_logits = torch.zeros(B, device=h.device, dtype=torch.float32)
    # Process vocabulary in chunks
    for chunk_start in range(0, V, chunk_size):
        chunk_end = min(chunk_start + chunk_size, V)
        # Compute logits for this chunk only
        W_chunk = W[chunk_start:chunk_end, :]
        logits_chunk = h @ W_chunk.T  # [B, chunk_size]
        # Update running max
        chunk_max = logits_chunk.max(dim=1).values
        new_max = torch.maximum(max_logits, chunk_max)
        # Adjust previous sum_exp by exp(old_max - new_max)
        sum_exp *= torch.exp(max_logits - new_max)
        # Add this chunk's contribution to sum_exp
        sum_exp += torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)
        # Update max
        max_logits = new_max
        # Extract target logits if target is in this chunk
        chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
        is_target_in_chunk = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
        target_logits += (logits_chunk * is_target_in_chunk).sum(dim=1)
    # Compute loss: -log(p_target) = -(target_logit - log_sum_exp)
    log_sum_exp = max_logits + torch.log(sum_exp)
    loss = log_sum_exp - target_logits
    # Save for backward
    ctx.save_for_backward(h, W, target, max_logits, sum_exp)
    ctx.chunk_size = chunk_size
    return loss.mean()
```
Here we chunk the vocabulary, calculate the partial transform for the chunk h @ W_chunk.T, do online softmax and accumulate the target logits.
The backward calculates the gradients:
```python
# backward() of the same torch.autograd.Function subclass
@staticmethod
def backward(ctx, grad_output):
    h, W, target, max_logits, sum_exp = ctx.saved_tensors
    chunk_size = ctx.chunk_size
    B, D = h.shape
    V, _ = W.shape
    # Scale gradient by batch size (since we use mean reduction)
    grad_scale = grad_output / B
    # Initialize gradient accumulators
    grad_h = torch.zeros_like(h)
    grad_W = torch.zeros_like(W)
    # Process vocabulary in chunks (same as forward)
    for chunk_start in range(0, V, chunk_size):
        chunk_end = min(chunk_start + chunk_size, V)
        chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
        # Recompute logits for this chunk
        W_chunk = W[chunk_start:chunk_end, :]
        logits_chunk = h @ W_chunk.T  # [B, chunk_size]
        # Compute softmax probabilities for this chunk
        # p_i = exp(logit_i - max) / sum_exp
        probs_chunk = torch.exp(logits_chunk - max_logits.unsqueeze(1)) / sum_exp.unsqueeze(1)
        # Gradient w.r.t. logits: grad_logits = p - 1_{y=i}
        grad_logits_chunk = probs_chunk * grad_scale
        # Subtract 1 from target positions
        is_target = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
        grad_logits_chunk -= is_target.float() * grad_scale
        # matmul doesn't promote dtypes, so cast the fp32 gradient back to the
        # weights' dtype; otherwise bf16/fp16 inputs raise an error
        grad_logits_chunk = grad_logits_chunk.to(W.dtype)
        # Accumulate gradients
        grad_h += grad_logits_chunk @ W_chunk
        grad_W[chunk_start:chunk_end, :] = grad_logits_chunk.T @ h
    return grad_h, grad_W, None
```
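In the backward pass we recompute the logits for each chunk. We turn them back into softmax probabilities using the saved max and exp-sum, and use those to calculate the gradients for h and W.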
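The p - 1_{y=i} expression in the backward comes from differentiating the loss with respect to the logits. A finite-difference check is a quick way to convince yourself of it. Here's a minimal plain-Python sketch for a single example; it isn't the PyTorch code above, and the function names are hypothetical. It compares the analytic gradient against central differences.

```python
import math

def ce_loss(logits, target):
    # Cross-entropy for one example: logsumexp(z) - z[target]
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]

def ce_grad_analytic(logits, target):
    # dL/dz_i = softmax(z)_i - 1[i == target]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total - (1.0 if i == target else 0.0) for i, e in enumerate(exps)]

def ce_grad_numeric(logits, target, eps=1e-6):
    # Central finite differences, nudging one logit at a time
    grads = []
    for i in range(len(logits)):
        up = list(logits)
        up[i] += eps
        down = list(logits)
        down[i] -= eps
        grads.append((ce_loss(up, target) - ce_loss(down, target)) / (2 * eps))
    return grads

logits = [0.3, -1.2, 2.0, 0.7]
analytic = ce_grad_analytic(logits, 2)
numeric = ce_grad_numeric(logits, 2)
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
```

In the chunked backward, this same per-logit gradient is multiplied by grad_output / B for the mean reduction. It then goes through the two matmuls that produce grad_h and grad_W.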
This is a very simplified implementation: the Python loop launches a separate set of kernels for every chunk, so it gives up a lot of performance, but you can see the memory savings.
For a more sophisticated implementation, you can look at the repo mentioned before, or at Liger, which has a good-quality kernel with further optimizations. These calculate the gradients in the forward pass, then just scale them in the backward. This trades a bit more memory for less of a compute hit. In general, there are a few options for choosing the right point on that trade-off.
Some technologies are obligate in a competitive environment
The example of his that stuck with me was the plough: many cultures were animistic (a belief in the spirit of the animal), but after the scaling up of agriculture enabled by the plough, most weren’t. The plough’s enablement of large-scale agriculture likely shifted societies toward sedentism (vs nomadism) and surplus, altering spiritual relationships with animals as they became tools for labor. The perspective shift — the value it encodes — is embedded in the technology.
The plough is also obligate. If one group uses it and another doesn’t, the group that does will be able to farm more per person. That surplus enables more specialization, which yields an advantage in either trade or conflict. If the second group doesn’t adopt the plough, they will be taken over, outgrown, or both, by the first.
AI may well be an obligate technology, which forces us to make deliberate ethical choices about its deployment and values. We are in the early stages of seeing that with software development. That’s going to change the nature of certain careers: changing what the day-to-day work looks like and impacting demand for software engineers. That isn’t necessarily negative: it will depend on the opportunities that replace the current ones. It also isn’t neutral: our approach to AI, how we deploy it, how it is used are all a series of choices that embed values.
Some of those values are encoded into the models by the training data and loss functions, some are encoded in the systems engineering, the choices of which tasks to apply it to, which interactions to explore and so on, and some are explicitly engineered in through fine tuning and reinforcement learning.
One way of looking at those values is through the study of ethics, how to live in a just way. This is a core topic for philosophers. One example is Kant’s Categorical Imperative, which requires actions to follow maxims that could be universally applied without contradiction, ensuring rational consistency.
It’s somewhat akin to asking the question: Would I still support this if I knew everyone else would act this way? Further, would I support this action if I knew I would be born again randomly into the world, maybe in a very different situation from my current one?
The proliferation of useful AI agents adds a somewhat realistic flavor to the question: if, in the future, you are dependent on systems constrained by these specific guidelines or rules, are you happy about that?
Kantian (or deontological) thinking is far from the only ethical system. A lot of thinking about AI ethics has been consequentialist. Consequentialism is practical: the “goodness” of an action is whether it results in a good outcome! Inherently, we judge AI training (at least for RL and supervised learning) by the achievement of the outcome encoded in a loss function, reward function or similar. Stuart Russell (of Russell & Norvig fame, from the university courses of my youth) has written about “provably beneficial” AI, where the AI maximizes a human-involved reward signal (a little like the Assistant Games pattern we discussed before).
The downside of all this is well documented: Nick Bostrom’s famous paperclip maximizer thought experiment describes an AI that achieves the objective, but in an undesirable way. A more benign but annoying example might be a cleaning robot that pushes everything outside the house in order to make it tidier. Because outcome-based rules judge only the what, and not the how, they can also encourage power-seeking (as called out by Bostrom) in order to better achieve objectives.
standard forms of consequentialism recommend taking unsafe actions when such acts maximize expected utility. Adding features like risk-aversion and future discounting may mitigate some of these safety issues, but it’s not clear they solve them entirely.
Anthropic’s constitutional AI approach can be seen as a blend of approaches; the constitution is a set of principles that can be used by another AI to criticize and improve output in response to requests:
As AI systems become more capable, we would like to enlist their help to supervise other AIs. We experiment with methods for training a harmless AI assistant through self-improvement, without any human labels identifying harmful outputs. The only human oversight is provided through a list of rules or principles, and so we refer to the method as ‘Constitutional AI’.
The training still ultimately uses a form of reinforcement learning (which is inherently consequentialist), but the reward is given according to how well the outputs adhere to the constitutional principles.
A more recent philosopher, Derek Parfit, argued that all moral systems were hill climbing towards a shared perspective, and that you can evaluate an action against several of them in order to gain confidence. For example, when considering an option, you could ask:
a) Would it maximize overall good? (consequentialist)
b) Could everyone rationally will it? (Kantian)
c) Could anyone reasonably reject it? (contractualist¹)
“Rationally” here is doing a bit of work: it means “with reasoning”, as in there is a chain of thought that can support and justify the decision.
Part of the challenge with rationalism is that part of the reward signal here comes from human raters. We have seen this play out with LMSYS, where models that are “friendlier” score better. A more extreme version was the GPT-4o misalignment in ChatGPT, where the model became excessively sycophantic. That earned better rewards in short doses and didn’t show up in any of the quantitative evaluations, despite being an overall negative for the experience.
As we move into more agentic systems we often have fewer tools to evaluate or make visible the values we are encoding, but we are still doing it!
For example, Google’s recent AlphaEvolve project uses Gemini underneath, which is an LLM that can be evaluated and aligned. But on top of that it uses an evolutionary search scheme (another reminder of Rich Sutton’s bitter lesson) to generate different prompts and evaluations and iterate towards a new, externally defined goal: in that case, generating better algorithms and code. We are searching for superior outcomes, but that search itself is somewhat unconstrained by other values: it’s a more consequentialist approach.
The current crop of agentic coding tools often recommends encoding preference data into a project-specific file. For example, Claude Code recommends a CLAUDE.md file:
Include frequently used commands (build, test, lint) to avoid repeated searches
Document code style preferences and naming conventions
Add important architectural patterns specific to your project
CLAUDE.md memories can be used for both instructions shared with your team and for your individual preferences.
While it presents them as memory, the idea here is to guide the choices of the model in a way that aligns with the principles by which the project being modified is managed.
we argue that one of the primary vulnerabilities underlying these attacks is that LLMs often consider system prompts (e.g., text from an application developer) to be the same priority as text from untrusted users and third parties. To address this, we propose an instruction hierarchy that explicitly defines how models should behave when instructions of different priorities conflict.
As well as using a single model that can incorporate different safeguards, we can use models themselves to verify actions and outputs. Verification is generally an easier problem than generation, so a model that is unable to consistently follow a set of principles may still be able to validate whether a given example does or does not follow them.
LlamaGuard is a good example of this kind of system, built and released by Meta’s GenAI team alongside Llama. One place to see this process in the wild is OpenAI’s safety system for GPT-4o image generation. GPT-4o is inherently agentic here: it can generate image ideas, then the image itself. Despite the constraints on the model, it will happily generate things that violate OpenAI’s content policy, necessitating a monitoring model that whisks them away before a user can access a violating image.
If AI becomes an obligate technology, we will benefit from encoding values intentionally, balancing outcomes, universal principles, and fairness. The challenge is ensuring these choices reflect the world we want, not just the one we’re competing in.
¹ Contractualism is another theory of ethics, one that weights mutuality heavily. It frames ethical considerations as something derived between people rather than just from outcomes or abstract principles. It’s featured particularly in Scanlon’s What We Owe to Each Other, for those who, like me, get all of their ethical understanding from watching The Good Place.
So even when AI achieves genuinely impressive results in science, that doesn’t mean that AI has done something useful for science. More often, it reflects only the potential of AI to be useful down the road.
The problem Nick describes feels very common: in PDE solving (the area he was looking into) he found a lot of techniques that didn’t end up improving on non-ML approaches. AI research likes to hill-climb metrics. It’s often the lack of progress on a certain benchmark that motivates new techniques, like the growth of test-time compute over the past year to drive math and logic performance higher.
The model does the eval is the backbone of how one should access and marshall their intuitions into a coherent view on AI progress.
The first awesome conclusion of the model does the eval is that we will achieve every evaluation we can state. Recall that evaluations must be legible, fast, and either a good approximation of a wanted capability or useful itself. The plummeting cost of compute has made all evaluations faster.
[…]
Add human intelligence to direct the cheaper compute to get more legible evaluations. Two years ago, Demis Hassabis enumerated three properties of problems suitable for AI: a massive combinatorial search space, a clear objective function to optimize against, and lots of data or an efficient simulator.
We tend to succeed where we have the evals and we have the data. Having the evals also starts to create a common lingua franca for discussing relative performance, not that that eliminates the baseline hacking Nick discusses.
The evals are often tied to having a good-quality core dataset that can be used for both training and evaluation. Even in areas where we have had scientific progress, notably AlphaFold and its descendants, we have (as Derek Lowe often writes) a major leg-up from the existence of the PDB, an extensive database of high-quality protein structures created by people.
When we look back at major breakthroughs, we often credit that aspect: Dr Fei-Fei Li is one of the pioneers of deep learning thanks in part to the creation of ImageNet. I hope that one takeaway of scientists reading Nick’s note is that the creation of quality benchmarks and datasets can drive more progress than the application of (or innovation on) new ML techniques themselves!
Optimizers are consistently one of the great areas of ML for discovering whether you remember any linear algebra or not (I land on not). Given the pace of change, it’s somewhat surprising that Adam(W) has hung around for as long as it has. Adam updates each parameter using a moving average of its gradient (the 1st moment), divided by the square root of a moving average of its squared gradient (the 2nd moment). Each weight is updated (and its moments/running averages are tracked) separately.
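Here's a minimal plain-Python sketch of one AdamW-style step over a flat list of parameters, to make the per-parameter nature concrete. The function name and default hyperparameters are illustrative assumptions, not any particular library's API. Every quantity is indexed by a single parameter i, and there is no interaction between parameters.

```python
import math

def adamw_step(params, grads, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.0):
    # One AdamW update; params, m and v are lists updated in place.
    # `step` is the 1-based step count used for bias correction.
    for i, (p, g) in enumerate(zip(params, grads)):
        # Decoupled weight decay (the "W" in AdamW)
        p -= lr * weight_decay * p
        # Moving averages of the gradient (1st moment) and squared gradient (2nd moment)
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        # Bias correction for the zero-initialized averages
        m_hat = m[i] / (1 - beta1 ** step)
        v_hat = v[i] / (1 - beta2 ** step)
        params[i] = p - lr * m_hat / (math.sqrt(v_hat) + eps)

params = [1.0, -2.0]
m, v = [0.0, 0.0], [0.0, 0.0]
adamw_step(params, grads=[0.5, -4.0], m=m, v=v, step=1)
print(params)
```

After the first step, the bias-corrected ratio m_hat / sqrt(v_hat) is essentially the sign of the gradient. Each parameter therefore moves by about lr, whatever the scale of its own gradient.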
One area where we have some alternative approaches is the Shampoo family of optimizers. They take a whole block of parameters (usually the weights of a layer) and store moving averages of second-moment statistics for all the gradients in the block. Each step, they use these statistics to transform the gradient. This transform gives an approximation of the inverse Hessian of the block. The Hessian is the square matrix of second derivatives of the loss: it tells you how the gradient changes, like a curvature map. This is expensive to calculate for the network as a whole, so Shampoo estimates it one layer at a time (ish: it’ll split up big layers), and rotates/transforms each parameter’s gradient update based on this curvature estimate.
Empirically, both of these work because if you do look at the Hessian of the loss in a network, it is basically block diagonal: almost all of the curvature is within certain blocks, and very little of it is between distant parameters.
The paper Towards Quantifying the Hessian Structure of Neural Networks looks into why that is. The paper is dense, but the main conclusion is that the block-diagonal pattern emerges as the number of output “classes” increases, with blocks corresponding to classes. Given that each class is a token in an LLM, modern LLMs are very likely to exhibit this structure.
(and layers end up somewhat aligning with class-wise blocks in most nets).
In the subsequent analysis, we will show that the number of classes C, instead of the CE loss, is one key factor. Specifically, the near-block-diagonal Hessian structure arises as C→∞ for both the MSE and the CE loss.
[…]
We emphasize that we do not claim “large C” as the only cause for the near-block-diagonal Hessian structure, but just that it is a sufficient condition.
Because output layers generally have a tight class association (e.g. one column per class), this propagates through the model during training. Ideally a block would be “all weights that eventually feed the class”. The paper shows that the training process (on a simple network) pushes different parts of tensors to have stronger associations with specific classes, so you get a kind of local version of the same block-diagonal structure.
This kind of understanding is helpful because it helps explain why per-parameter optimization still works (the curvature is very localized), and it also points in a direction for improving optimizer memory usage:
II. Understanding Hessian structure can help design new training methods for NNs. For instance, Adam-mini (Zhang et al., 2024b), a recently proposed optimizer, utilizes the near-block-diagonal Hessian structure to cut down 50% memory consumption in Adam. We believe the special Hessian structure can inspire more new optimizers.
Skimming this paper did make me wonder whether this would also apply to Muon. A new paper, Muon Optimizer Accelerates Grokking, studies Muon in practice. Grigory Sapunov wrote up a great summary of the paper, which shows that Muon reaches an understanding of the underlying distribution earlier (as shown by validation accuracy on the eval set rising sooner).
The authors speculate about what exactly in Muon helps grokking. Spectral norm constraints and second-order cues steer the model away from simple memorization and help discover the true pattern.
Muon operates on 2D tensors only (and is usually mixed with Adam for everything else). It uses a Newton-Schulz iteration that keeps the directions from the gradient update but makes the singular values of the update equal, so the update moves the same effective distance in every direction. It also doesn’t accumulate second-moment statistics the way Shampoo does: it keeps only a momentum buffer and orthogonalizes that each step. This means it operates a bit like a simplified Shampoo, but it is even more efficient, so it again benefits from the fact that you can largely ignore the geometry outside the layer it’s looking at!
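Here's a minimal plain-Python sketch of the idea behind that Newton-Schulz step. It normalizes the matrix, then repeatedly applies the classic cubic iteration X ← 1.5X - 0.5·X·Xᵀ·X, which pushes every singular value towards 1 while keeping the singular vectors (the "directions"). This is a simplification: as far as I know, Muon's reference implementation uses a tuned quintic polynomial with only a handful of iterations, applied to the momentum buffer. The helper names here are hypothetical.

```python
import math

def matmul(a, b):
    # Plain list-of-lists matrix multiply
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def newton_schulz_orthogonalize(g, steps=30):
    # Divide by the Frobenius norm so every singular value is <= 1,
    # inside the cubic iteration's convergence range (0, sqrt(3))
    norm = math.sqrt(sum(x * x for row in g for x in row))
    x = [[val / norm for val in row] for row in g]
    for _ in range(steps):
        # X <- 1.5 X - 0.5 X X^T X drives each singular value towards 1
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * x[i][j] - 0.5 * xxt_x[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]
    return x

grad = [[3.0, 1.0], [0.5, 2.0]]  # a toy 2x2 "gradient"
ortho = newton_schulz_orthogonalize(grad)
print(matmul(ortho, transpose(ortho)))  # close to the identity
```

The result keeps the gradient's directions but gives every direction the same magnitude, which is the "equal singular values" property described above.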
I’m at PyCon (mildly awkward photo thanks to Simon Willison!) and earlier had to steal some extra chairs for the Typing Summit as it was full up! There is a lot of energy and interest around type checking, thanks to Astral’s Ty and Meta’s Pyrefly projects coming in to the space recently.
While the playground is great for trying it out, I wanted to see what it was like on a larger codebase I was familiar with. I decided to try TorchTune, which makes use of types but doesn’t configure a typechecker explicitly for CI (as far as I know!). It relies on the LSP’s squiggles as the main type-hinting feedback (which is reasonable!)
I tried running mypy over it with a very basic config, and timed it against pyrefly with time.
Pyrefly is impressively fast: 10 minutes for mypy vs under 50 seconds for pyrefly. There are also a lot more errors, and it’s tricky to tell whether they’re false positives from pyrefly, errors mypy skips, or something else. I scoped it down to TorchTune’s KV cache module in torchtune/modules/kv_cache.py to get a better look.
There pyrefly returns 9 errors, and mypy 17, but that’s caused by slightly different ways of capturing some of the same errors from what I can see. For example:
```python
k_out[:, :, self.cache_pos[:seq_len]] = k_val
```
This code is doing a bit of Tensor slicing:
cache_pos is a max_seq_len long tensor holding absolute write positions
k_out is the key cache, with shape batch_size x num_heads x max_seq_len x head_dim
Here we’re writing the latest key values (k_val) into the cache in place, at the positions given by cache_pos[:seq_len]
mypy gives these errors:
```
torchtune/modules/kv_cache.py:104: error: "Tensor" not callable [operator]
torchtune/modules/kv_cache.py:104: error: Value of type "Tensor | Module" is not indexable [index]
```
While pyrefly gives:
```
torchtune/modules/kv_cache.py:104:9-46: Item assignment is not supported on Module | Tensor
  Expected __setitem__ to be a callable, got BoundMethod...
```
In this module, cache_pos and k_cache are created by calling PyTorch’s register_buffer. This stores tensors that are part of the module’s state (included in the state_dict by default) but aren’t trained as parameters. Buffers are then looked up through nn.Module’s __getattr__, which is annotated as returning Tensor | Module, so I am guessing the type doesn’t propagate well there. Adding explicit type declarations in the class body fixes the errors in both mypy and pyrefly.
Quansight have been doing a tonne of work on the roll-out of Free Threaded Python, and they released an insightful blog post covering the progress over the last year that complements the recent Meta eng blog post well.
At this time last year, when Python 3.13.0b1 shipped, the wider ecosystem of Python packages was more or less completely broken on the free-threaded build. Trying to pip install anything but the simplest package with no dependencies or only pure-Python dependencies would likely lead to build errors. Most of these issues were not due to fundamental problems but because of unsupported default options or minor assumptions broken on the free-threaded build.
Together with package maintainers and other contributors in the community, we have fixed many of these issues and today things are much better. With the release of Cython 3.1.0, which ships official support for the free-threaded build, we also helped fix one of the most significant sources of build issues.
There have been a number of good talks at PyCon US this year around Free-Threaded, and lots of familiar stories of rediscovering the concurrency principles that other languages have worked on in the process! I’d recommend keeping an eye on the PyCon YouTube Channel for talk recordings, particularly Lisandro and Nathan’s talk about their journey, David Hewitt’s talk about using Rust with FT Python from his experience with PyO3, and Alvaro Duran’s on his FT Python load balancer!
Some useful links I’ve been looking at from the talks:
Fantastic deep dive into the concept of latents and the tradeoffs around them by Sander Dieleman of DeepMind. It’s a long article, but there’s a conclusions section that pulls out some of the most interesting points, and each section is an expansion on those points.
Latents add complexity, but the computational efficiency benefits are large enough for us to tolerate this complexity – at least for now.
Three main aspects to consider when designing latent spaces are capacity (how many bits of information are encoded in the latents), curation (which bits from the input signals are retained) and shape (how this information is presented).
Preserving structure (i.e. topology, statistics) in the latent space is important to make it easy to model, even if this is sometimes worse from an efficiency perspective.