Gradient Accumulation (was) busted

This weekend I was reading the Tulu v3 paper (link), which offers a deep dive into building robust post-training setups. It is a very good resource for anyone aiming to build solid fine-tuning workflows. It covers critical elements like dataset selection, synthetic data generation (with example prompts!), strategies for SFT and preference tuning, and various things they struggled with.

One of those struggles was a gradient accumulation issue: with accumulation turned on, the loss was worse than without it. The wider community hit this too, and fixed it, thanks to an excellent blog post by Unsloth (link).

The bug

Gradient accumulation is a technique for simulating larger batch sizes: each of several smaller mini-batches gets its own forward and backward pass, their gradients are accumulated, and only then does a single optimizer step update the weights. This is particularly useful under memory constraints, so it comes up a lot when post-training a biggish model on more limited hardware.
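To make the mechanics concrete, here is a minimal pure-Python sketch on a hypothetical one-parameter model (predict `y = w * x`, mean squared error) with a hand-derived gradient. It assumes equal-sized micro-batches and checks that scaling each micro-batch gradient by `1/accum_steps` and applying one update matches a single full-batch step. In PyTorch the same pattern would be calling `backward()` on each micro-batch loss divided by the number of accumulation steps (gradients add up in `.grad`) and calling the optimizer's `step()` once.

```python
def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over `batch`."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def full_batch_step(w, data, lr):
    """One plain gradient-descent update using the whole batch at once."""
    return w - lr * grad_mse(w, data)


def accumulated_step(w, data, lr, accum_steps):
    """Split `data` into `accum_steps` equal micro-batches, accumulate their
    scaled gradients, then apply a single update."""
    assert len(data) % accum_steps == 0, "sketch assumes equal micro-batches"
    size = len(data) // accum_steps
    grad = 0.0
    for i in range(accum_steps):
        micro = data[i * size:(i + 1) * size]
        # Scaling by 1/accum_steps makes the sum equal the full-batch mean gradient.
        grad += grad_mse(w, micro) / accum_steps
    return w - lr * grad  # one optimizer step per accumulated batch


# Hypothetical toy data: (x, y) pairs.
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
w_full = full_batch_step(0.5, data, lr=0.01)
w_accum = accumulated_step(0.5, data, lr=0.01, accum_steps=2)
print(w_full, w_accum)  # the two agree up to floating-point rounding
```

The equal-size assumption is doing real work here: it is exactly what stops holding once mini-batches contain different numbers of valid tokens.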

The problem arises when the sequences in these mini-batches have varying lengths. In standard practice, a mini-batch's cross-entropy loss is averaged over its non-padded (i.e., valid) tokens. When accumulating across multiple mini-batches that contain different numbers of valid tokens, naively adding up (or averaging) these per-mini-batch losses gives a different result from the loss over the whole batch.
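As a rough illustration of per-mini-batch normalization, the sketch below averages hypothetical per-token losses over valid positions only, marking padding with the label `-100` (the default `ignore_index` of PyTorch's `CrossEntropyLoss`). The loss values are made up; in practice they would come from the model's logits.

```python
IGNORE_INDEX = -100  # label marking padded positions (PyTorch's default ignore_index)


def masked_mean_loss(token_losses, labels):
    """Average per-token losses over positions whose label is not IGNORE_INDEX.

    Both arguments are lists of equal-length sequences (one per example).
    Returns the mean loss and the number of valid tokens it was averaged over.
    """
    total, count = 0.0, 0
    for seq_losses, seq_labels in zip(token_losses, labels):
        for loss, label in zip(seq_losses, seq_labels):
            if label != IGNORE_INDEX:
                total += loss
                count += 1
    return total / count, count


# One mini-batch of two sequences; the second is padded to length 3.
token_losses = [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
labels = [[5, 7, 9], [3, IGNORE_INDEX, IGNORE_INDEX]]
loss, n_valid = masked_mean_loss(token_losses, labels)
print(loss, n_valid)  # 2.5 4
```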

The discrepancy occurs because the cross-entropy loss normalizes by the number of valid tokens, and that count varies between mini-batches. When these normalized losses are accumulated without adjustment, tokens in mini-batches with few valid tokens are weighted more heavily than tokens in mini-batches with many, so the final loss does not match what a single large batch would give. In practice this showed up as a higher training loss with gradient accumulation than with full-batch training.
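A small worked example with made-up per-token losses shows the mismatch. Mini-batch A has 2 valid tokens and mini-batch B has 6; the naive approach averages the two per-mini-batch means, while a single large batch averages over all 8 tokens.

```python
def mean(xs):
    return sum(xs) / len(xs)


# Hypothetical per-token losses for the valid tokens of two mini-batches.
batch_a = [4.0, 4.0]   # 2 valid tokens
batch_b = [1.0] * 6    # 6 valid tokens

# Single large batch: one mean over all 8 valid tokens.
full_batch_loss = mean(batch_a + batch_b)  # (8 + 6) / 8 = 1.75

# Naive accumulation: normalize each mini-batch by its own token count,
# then average the per-mini-batch means over the accumulation steps.
accum_steps = 2
naive_loss = (mean(batch_a) + mean(batch_b)) / accum_steps  # (4 + 1) / 2 = 2.5

print(full_batch_loss, naive_loss)
```

Each token in A carries weight 1/4 in the naive loss, but only 1/8 in the full-batch loss.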

Daniel and co at Unsloth addressed this by normalizing once over the whole accumulated batch: the per-token losses from every mini-batch are summed and divided by the total number of valid tokens across all mini-batches, rather than by each mini-batch's own count. This makes gradient accumulation match full-batch training, giving more accurate loss values and better training behavior.
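The sketch below shows the idea behind that fix in plain Python: count valid tokens across all micro-batches first, then divide each micro-batch's summed token loss by that global count, so the contributions add up to the full-batch mean. The function name `accumulated_loss` and the loss values are hypothetical, not Unsloth's or any library's API. In a real training loop each micro-batch's contribution is what you would backpropagate, which means the total count has to be known before the first backward pass.

```python
def accumulated_loss(micro_batches):
    """Combine micro-batch token losses the way a single large batch would.

    `micro_batches` is a list of lists of per-token losses for valid tokens.
    """
    # Count valid tokens across *all* micro-batches before normalizing.
    total_tokens = sum(len(b) for b in micro_batches)
    loss = 0.0
    for b in micro_batches:
        # In a training loop, this per-micro-batch term is what you would
        # backpropagate; note the denominator is the global count.
        loss += sum(b) / total_tokens
    return loss


# Hypothetical losses: a 2-token mini-batch and a 6-token mini-batch.
batch_a = [4.0, 4.0]
batch_b = [1.0] * 6
fixed = accumulated_loss([batch_a, batch_b])
full = sum(batch_a + batch_b) / len(batch_a + batch_b)
print(fixed, full)  # 1.75 1.75
```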

Fixes and Workarounds

Recent updates in both Hugging Face Transformers (pull request) and TorchTune (pull request) include fixes. And, at least in Evan’s case, a little bit of snark:

In honor of the day the ML community first discovered the fact that (x1 / n1) + (x2 / n2) != (x1 + x2) / (n1 + n2)
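Spelled out for two accumulation steps, with x_i the summed token loss and n_i the number of valid tokens in mini-batch i (and the naive version averaging over the two steps), the two sides agree when n_1 = n_2 (or when both mini-batches happen to have the same mean loss), but not in general:

```latex
\begin{aligned}
\text{naive: } & \tfrac{1}{2}\left(\frac{x_1}{n_1} + \frac{x_2}{n_2}\right)
  \qquad \text{full batch: } \frac{x_1 + x_2}{n_1 + n_2} \\
n_1 = n_2 = n: \quad & \tfrac{1}{2}\cdot\frac{x_1 + x_2}{n} = \frac{x_1 + x_2}{2n} = \frac{x_1 + x_2}{n_1 + n_2} \\
x_1 = 8,\ n_1 = 2,\ x_2 = 6,\ n_2 = 6: \quad & \tfrac{1}{2}(4 + 1) = 2.5 \neq \frac{14}{8} = 1.75
\end{aligned}
```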

I really like seeing these small but practical problems pop up and the community rallying around to fix them. I missed this when it happened in October, so I’m glad to look back at it now!

Discover more from Ian’s Blog
