Fused Linear Cross-Entropy

Fused Linear Cross-Entropy is a popular optimization that combines the final linear projection and cross-entropy loss into a single operation. This fusion is very valuable for training large language models efficiently, as it can reduce memory usage significantly, particularly for larger vocabularies.

If you look at an LLM training loop, you generally see something like:

logits = model(input_ids)
loss = cross_entropy(logits, targets)

And if you look at the end of the model, you’ll see something like the below, where h is the hidden state so far and output is output = nn.Linear(embed_dim, vocab_size, bias=False):

# shape: [b, seq_len, vocab_size]
output = self.output(h)

That final logits tensor can be pretty big: the sequence is long and the vocabulary is large (128k for Llama 3, 202k for Llama 4), so the logits can take gigabytes of memory. With a 16k context window and a 128k vocab, even at a batch size of 1, you get just over 2 billion entries. At BF16, that’s 4GB. You’ll also need to capture the gradient, which will give you another 4GB in the backward pass.
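A quick back-of-envelope check of that arithmetic (the shapes here are illustrative round numbers, not exact Llama configs):

```python
# Logits memory for one forward pass at batch size 1.
batch, seq_len, vocab = 1, 16 * 1024, 128 * 1024
bytes_per_bf16 = 2

entries = batch * seq_len * vocab        # one logit per (token, vocab entry)
gb = entries * bytes_per_bf16 / 1024**3  # forward activation alone
print(f"{entries:,} entries -> {gb:.1f} GB (double it for the gradient)")
```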

That set of logits has one entry per possible target token, with values that are a bit all over the place.

Cross-entropy is a loss between two probability distributions. Jay Mody has an excellent blog post breaking down softmax and CE loss:

Roughly speaking, cross entropy measures the similarity of two probability distributions. In the context of neural networks, it’s common to use cross entropy as a loss function for classification problems where:

  • q is our predicted probabilities vector (i.e. the softmax of our raw network outputs, also called logits, denoted ŷ), that is q = softmax(ŷ)
  • p is a one-hot encoded vector of our label, that is a probability vector that assigns 100% probability to the position y (our label for the correct class): p_i = 1 if i = y, 0 if i ≠ y

This means that cross-entropy simplifies to F.nll_loss(F.log_softmax(x, 1), target)
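You can verify this equivalence directly against PyTorch’s built-in loss (a quick sanity check, not part of the fused implementation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 100)               # [batch, vocab] raw logits
target = torch.randint(0, 100, (8,))  # one class index per row

# cross_entropy is exactly nll_loss over log-softmax probabilities
a = F.cross_entropy(x, target)
b = F.nll_loss(F.log_softmax(x, dim=1), target)
print(torch.allclose(a, b))  # True
```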

Softmax turns our previously messy logits into a nice probability distribution where all the values are positive and sum to one. Log-softmax is usually used in LLMs, for numerical stability and efficiency.

When we implement log-softmax, the naive implementation looks something like:

out = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))

This isn’t numerically stable, so you want to subtract the max to avoid overflows and underflows in the exp. This is the common log-sum-exp implementation:

x_max = torch.max(x)
shifted_x = x - x_max
exp_shifted = torch.exp(shifted_x)
out = shifted_x - torch.log(torch.sum(exp_shifted))
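Plugging in large values shows why the shift matters (checked here against torch.log_softmax as a reference):

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])  # large enough to overflow exp()

# Naive: exp(1000) is inf in float32, so inf/inf gives nan everywhere
naive = torch.log(torch.exp(x) / torch.sum(torch.exp(x)))

# Shifted: subtract the max first so exp() stays in range
x_max = torch.max(x)
shifted_x = x - x_max
exp_shifted = torch.exp(shifted_x)
stable = shifted_x - torch.log(torch.sum(exp_shifted))

print(naive)                                            # tensor([nan, nan, nan])
print(torch.allclose(stable, torch.log_softmax(x, 0)))  # True
```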

In general the memory and compute cost of this grows with the input size, which gets painful for our hefty logits. We can instead keep a running log-sum-exp so we don’t have to deal with the whole input at once.

lse = xs[0]
for x in xs[1:]:
    m = torch.maximum(lse, x)
    lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))
out = lse

This is the online log-sum-exp approach, and it makes our life easier! We can now compute incrementally, but we are still materializing the big logits beforehand.
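A quick check that the running update gives the same answer as reducing everything at once (using torch.logsumexp as the reference; chunk sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
full = torch.randn(1000)
chunks = full.split(100)  # process in 10 pieces instead of all at once

# Per-chunk log-sum-exp, merged with the running update
lse = torch.logsumexp(chunks[0], dim=0)
for chunk in chunks[1:]:
    x = torch.logsumexp(chunk, dim=0)
    m = torch.maximum(lse, x)
    lse = m + torch.log(torch.exp(lse - m) + torch.exp(x - m))

print(torch.allclose(lse, torch.logsumexp(full, dim=0)))  # True
```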

Fused Linear Cross-Entropy replaces the output projection, softmax and loss calculation with a single kernel that tiles across all of it.

This is the core of the idea: instead of computing all logits at once (which creates a massive tensor), we can:

  1. Compute logits for small chunks of the vocabulary
  2. Compute the softmax incrementally
  3. Only store the logits we need for the loss calculation

Quoting https://github.com/mgmalek/efficient_cross_entropy

This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.

Roughly, the loop looks like:

For each token i in the sequence, with hidden state h_i and a tile of the output weights W_tile:

  • Compute the partial logits s_i = h_i · W_tile
  • Reduce into a running max and exp-sum
  • Keep only the s_i[target_i] needed for the loss.

This gives you quite a lot of memory wins, which also reduce peak memory bandwidth needs. But this also introduces some potential pain!

  1. You’re fusing the final layer op into the loss, which might be defined in different places in your model code
  2. You’re accumulating, so you have to use fp32 or risk introducing numeric errors
  3. You have to write your own backward op as well, which will generally do some extra computation, so you are paying some extra FLOPs
  4. Debugging can be harder, so you want a good recipe prior to swapping in the op
  5. May require some futzing for best implementations on different hardware.

Actually implementing it is pretty straightforward.

@staticmethod
def forward(ctx, h, W, target):
    B, D = h.shape
    V, _ = W.shape
    
    chunk_size = min(1024, V)
    
    # Initialize online softmax accumulators
    max_logits = torch.full((B,), -float('inf'), device=h.device, dtype=torch.float32)
    sum_exp = torch.zeros(B, device=h.device, dtype=torch.float32)
    target_logits = torch.zeros(B, device=h.device, dtype=torch.float32)
        
    # Process vocabulary in chunks
    for chunk_start in range(0, V, chunk_size):
        chunk_end = min(chunk_start + chunk_size, V)
            
        # Compute logits for this chunk only
        W_chunk = W[chunk_start:chunk_end, :]
        logits_chunk = h @ W_chunk.T  # [B, chunk_size]
            
        # Update running max
        chunk_max = logits_chunk.max(dim=1).values
        new_max = torch.maximum(max_logits, chunk_max)
            
        # Adjust previous sum_exp by exp(old_max - new_max)
        sum_exp *= torch.exp(max_logits - new_max)
        
        # Add this chunk's contribution to sum_exp
        sum_exp += torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)
        
        # Update max
        max_logits = new_max
            
        # Extract target logits if target is in this chunk
        chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
        is_target_in_chunk = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
        target_logits += (logits_chunk * is_target_in_chunk).sum(dim=1)
    
    # Compute loss: -log(p_target) = -(target_logit - log_sum_exp)
    log_sum_exp = max_logits + torch.log(sum_exp)
    loss = log_sum_exp - target_logits
    
    # Save for backward
    ctx.save_for_backward(h, W, target, max_logits, sum_exp)
    ctx.chunk_size = chunk_size
        
    return loss.mean()

Here we chunk the vocabulary, compute the partial projection for each chunk (h @ W_chunk.T), run the online softmax update, and accumulate the target logits.
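As a sanity check, the same chunked loop can be run standalone at toy sizes and compared against the unfused loss (variable names here are mine, not from a library):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, V, chunk_size = 4, 32, 1000, 256
h = torch.randn(B, D)
W = torch.randn(V, D)
target = torch.randint(0, V, (B,))

# Running accumulators, as in the forward above
max_logits = torch.full((B,), -float('inf'))
sum_exp = torch.zeros(B)
target_logits = torch.zeros(B)

for start in range(0, V, chunk_size):
    end = min(start + chunk_size, V)
    logits_chunk = h @ W[start:end].T  # [B, chunk]
    new_max = torch.maximum(max_logits, logits_chunk.max(dim=1).values)
    # Rescale the old sum to the new max, then add this chunk's contribution
    sum_exp = sum_exp * torch.exp(max_logits - new_max) \
        + torch.exp(logits_chunk - new_max.unsqueeze(1)).sum(dim=1)
    max_logits = new_max
    # Grab the target logit for rows whose target falls in this chunk
    in_chunk = (target >= start) & (target < end)
    target_logits[in_chunk] = logits_chunk[in_chunk, target[in_chunk] - start]

loss = (max_logits + torch.log(sum_exp) - target_logits).mean()
print(torch.allclose(loss, F.cross_entropy(h @ W.T, target)))  # True
```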

The backward calculates the gradients:

@staticmethod
def backward(ctx, grad_output):
    h, W, target, max_logits, sum_exp = ctx.saved_tensors
    chunk_size = ctx.chunk_size
        
    B, D = h.shape
    V, _ = W.shape
        
    # Scale gradient by batch size (since we use mean reduction)
    grad_scale = grad_output / B
        
    # Initialize gradient accumulators
    grad_h = torch.zeros_like(h)
    grad_W = torch.zeros_like(W)
        
    # Process vocabulary in chunks (same as forward)
    for chunk_start in range(0, V, chunk_size):
        chunk_end = min(chunk_start + chunk_size, V)
        chunk_indices = torch.arange(chunk_start, chunk_end, device=h.device)
            
        # Recompute logits for this chunk
        W_chunk = W[chunk_start:chunk_end, :]
        logits_chunk = h @ W_chunk.T  # [B, chunk_size]
            
        # Compute softmax probabilities for this chunk
        # p_i = exp(logit_i - max) / sum_exp
        probs_chunk = torch.exp(logits_chunk - max_logits.unsqueeze(1)) / sum_exp.unsqueeze(1)
            
        # Gradient w.r.t. logits: grad_logits = p - 1_{y=i}
        grad_logits_chunk = probs_chunk * grad_scale
            
        # Subtract 1 from target positions
        is_target = (target.unsqueeze(1) == chunk_indices.unsqueeze(0))
        grad_logits_chunk -= is_target.float() * grad_scale
            
        # Accumulate gradients
        grad_h += grad_logits_chunk @ W_chunk
            
        grad_W[chunk_start:chunk_end, :] = grad_logits_chunk.T @ h
        
    return grad_h, grad_W, None

In the backward pass we recompute the logits for each chunk, then use them to calculate the gradients with respect to h and W.
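The gradient formula the backward relies on, (softmax − one-hot) / B for the logits, can be sanity-checked against autograd on the unfused loss (again toy sizes, a quick check rather than part of the kernel):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, V = 4, 16, 50
h = torch.randn(B, D, requires_grad=True)
W = torch.randn(V, D, requires_grad=True)
target = torch.randint(0, V, (B,))

# Reference gradients via autograd on the unfused loss
F.cross_entropy(h @ W.T, target).backward()

# Analytic: grad_logits = (softmax(logits) - one_hot(target)) / B
with torch.no_grad():
    logits = h @ W.T
    grad_logits = (torch.softmax(logits, dim=1)
                   - F.one_hot(target, V).float()) / B
    grad_h = grad_logits @ W
    grad_W = grad_logits.T @ h

print(torch.allclose(grad_h, h.grad, atol=1e-6))  # True
print(torch.allclose(grad_W, W.grad, atol=1e-6))  # True
```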

This is a very simplified implementation that trades one big matmul for a bunch of small kernel launches, so it gives up a lot of performance, but you can see the memory savings:

Regular:
Time: 285.18 ms
Memory (total): 3072.0 MB
Loss: 11.142737
Chunked online softmax:
Time: 470.27 ms
Memory (total): 356.0 MB
Loss: 11.142738

For a more sophisticated implementation, you can look at the repo mentioned before, or Liger has a good quality kernel with further optimizations. These calculate the gradients in the forward pass, then can just scale them in the backward. This trades a bit more memory for less of a compute hit. In general there are a few options for choosing the right point on the memory/compute trade-off.
