Let’s all switch to FP16?

Serious scientists use FP64 – 64-bit floating point numbers – for high-precision simulations, but in the world of machine learning we got by for the longest time with FP32. The perennial quest for more FLOPS, particularly when memory bound, made even that seem too expensive, though.

FP16 offered a reduced numeric range, but at half the size. Training with it in practice meant embracing automatic loss scaling[1], which ensured values stayed within the range FP16 could represent. Then Google developed BF16: it moved bits from the mantissa to the exponent, offering the same numeric range as FP32 but with reduced precision.
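For concreteness, here's a minimal sketch of what FP16 training with loss scaling looks like using torch.cuda.amp; the model, optimizer, and data are placeholders, and it assumes a CUDA device:

```python
import torch

# Placeholder model and optimizer, just to show the pattern.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# GradScaler multiplies the loss before backward() so small gradients don't
# underflow FP16's limited range, then unscales them before the optimizer step.
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    # autocast runs eligible ops in FP16 while the master weights stay in FP32.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(x).square().mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skipped automatically if inf/NaN gradients appear
    scaler.update()         # adjusts the loss scale for the next iteration
```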

BF16 has been finding its way into hardware since TPUv3 back in 2018 and Ampere in 2020, and it has become the go-to training format for many models. Life was good, and training in FP16 was mainly discussed as a memory of hard winters past.

Last week [2510.26788] Defeating the Training-Inference Mismatch via FP16 dropped and threw ML Twitter into a tizzy by arguing that everyone was doing Reinforcement Learning wrong and the solution… was FP16.

“In this work, we take a step back from the complex algorithmic fixes and investigate the root cause of the numerical mismatch: floating-point precision. We identify that the modern standard for mixed-precision training, BFloat16 (BF16), is the primary culprit. While BF16 has a wide dynamic range which is excellent for stable pre-training, its low precision makes it highly susceptible to rounding errors that accumulate and eventually cause the training and inference policies to diverge.”

The process for RL generally looks like this (sketched in code below):

  • Get a problem in a prompt
  • Do inference on the model to generate complete responses (a rollout)
  • Get a reward score for the response(s)
  • Run a training loop on the model to update weights based on the reward
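In heavily simplified pseudocode (all of the names here, like generate, reward_fn, and policy_gradient_loss, are hypothetical stand-ins rather than any particular framework's API), the loop looks something like this:

```python
def rl_step(policy, prompts, inference_engine, reward_fn, optimizer):
    # Step 2: rollouts come from the inference engine (KV cache, no gradients).
    rollouts = inference_engine.generate(policy.state_dict(), prompts)

    # Step 3: score each complete response.
    rewards = [reward_fn(p, r) for p, r in zip(prompts, rollouts)]

    # Step 4: the training engine recomputes log-probs (no KV cache, with
    # gradients) and updates the weights. If its numerics drift from the
    # inference engine's, the "on-policy" assumption quietly breaks.
    loss = policy.policy_gradient_loss(prompts, rollouts, rewards)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```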

If you want to be on-policy (which generally trains better) you need the “model” in steps 2 and 4 to be identical, but the actual code running around the model in the two steps is different: for example, you don’t use a KV cache in training and you don’t store gradients in inference. But you do want to keep the weights and numerics of the model the same, else your on-policy training becomes a little bit off-policy.

The last year of LLM research has been about scaling this up, which requires managing the training and inference flows efficiently. The ongoing pressure to optimize the two paths independently creates a risk of divergence. The paper finds that it absolutely happens, and that the divergence collapses the effectiveness of the learning. Unless, that is, you use FP16:

“This is precisely why switching to FP16 provides a fundamental solution. With its 10 mantissa bits, FP16 offers 8 times more precision (2¹⁰ values vs. 2⁷ values) than BF16. This higher fidelity means that the outputs of the training and inference engines are much more likely to be numerically identical. The increased precision creates a buffer that absorbs the minor implementation differences between the two engines, preventing rounding errors from accumulating and causing a policy divergence.”

The paper does an excellent job of breaking down the many reasons why this happens, but it’s pretty clear that FP16 is a patch: if you can’t get your numerics perfectly matched, then having more precision gives you more wiggle room.
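The extra wiggle room is easy to see in a toy experiment (this one is mine, not from the paper): accumulate the same random numbers with a BF16 accumulator and with an FP16 accumulator, and compare both against an FP64 reference.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096)
ref = x.double().sum()

def running_sum(t, dtype):
    # Accumulate strictly in the low-precision dtype, the way a kernel that
    # keeps its accumulator in that dtype would, so rounding error builds up.
    acc = torch.zeros((), dtype=dtype)
    for v in t.to(dtype):
        acc = acc + v
    return acc

print("bf16 error:", abs(running_sum(x, torch.bfloat16).double() - ref).item())
print("fp16 error:", abs(running_sum(x, torch.float16).double() - ref).item())
```

The FP16 error typically comes out well below the BF16 one; the exact gap depends on the data, but those three extra mantissa bits are doing the work.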

About a month before this, the ByteDance folks posted a fantastic deep dive into RL collapse from discrepancies between training and inference: When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch.

They identify a range of concerns, including straight-up bugs:

“According to this GitHub issue, we set disable_cascade_attn=True when initializing the vLLM engine and found that it significantly helps reduce the training-inference mismatch in experiments conducted on A100 GPUs.”
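For reference, that flag is set when constructing the engine. A hedged sketch: the model name is a placeholder, and the argument's availability depends on your vLLM version.

```python
from vllm import LLM

# disable_cascade_attn is the flag the ByteDance post refers to; check that
# your vLLM version exposes it. The model name is just a placeholder.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    disable_cascade_attn=True,
)
```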

Many of the experiments in the FP16 vs BF16 paper were run on A100s[2], so some backlash emerged suggesting that perhaps this whole thing is just a kernel error. But as ByteDance showed, there really is a lot going on that can make things worse.

Another example is Horace He’s recent work at Thinking Machines on a related problem: Defeating Nondeterminism in LLM Inference.

“As mentioned above, one common explanation for why kernels add numbers in different orders is the “concurrency + floating point” hypothesis. The hypothesis states that if the order in which concurrent threads finish is nondeterministic and the accumulation order depends on the order in which concurrent threads finish (such as with an atomic add), our accumulation order will be nondeterministic as well.”

Horace calls out variance in batching as the primary cause of nondeterminism, and hence as another quite plausible cause of inference/training mismatch:

“In other words, the primary reason nearly all LLM inference endpoints are nondeterministic is that the load (and thus batch-size) nondeterministically varies! This nondeterminism is not unique to GPUs — LLM inference endpoints served from CPUs or TPUs will also have this source of nondeterminism.”
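You can poke at the batching effect without any concurrency at all, in the spirit of the examples in that post: compute one row of a matmul on its own and as part of a larger batch, and see whether the two results agree bit-for-bit. This sketch is mine, not Horace’s, and whether you see a nonzero gap depends on your hardware and library versions.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
A = torch.randn(64, 4096, dtype=torch.bfloat16, device=device)
B = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)

row_alone = A[:1] @ B        # "batch size 1"
row_in_batch = (A @ B)[:1]   # the same row, computed inside a batch of 64

# Any nonzero difference is purely an artifact of which kernel and which
# reduction split got picked for each shape.
print("max abs difference:", (row_alone - row_in_batch).abs().max().item())
```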

The meta-point is that, despite being a field fundamentally based in mathematical precision, we have been sloppy with numerics pretty much everywhere.

Ed Yang’s session in the PyTorch Conference keynote[3] a couple of weeks back called this problem out from the perspective of scaling up ML infrastructure. He presented a number of solutions to try to address it, which often come down to giving folks control over precisely how the numerics work in different parts of their model.
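PyTorch already exposes some of those knobs today. A few of the relevant switches (this is a sampler I find useful, not a summary of the talk):

```python
import torch

# Make float32 matmuls actually use FP32 rather than the faster TF32 path.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision("highest")

# Warn (or error, with warn_only=False) when an op has no deterministic kernel.
torch.use_deterministic_algorithms(True, warn_only=True)

# cuDNN: pick deterministic algorithms and skip autotuned kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```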

While the focus here was on RL and FP16, the reality is that we deal with this for training->inference in much simpler cases, as well as when moving models between different hardware. Even within a single hardware generation this can be hard: one of the fun infra problems when the H100 came out was everyone discovering that Hopper’s FP8 tensor cores use a 22-bit accumulator for intermediate calculations, which wasn’t really documented!

The trade-off between speed and accuracy is often settled empirically: if something is faster, and works, then at some level it’s right! Reinforcement Learning mixes together different evolutionary chains of optimizations, so maybe those serious scientists with their FP64 were onto something. Not because they absolutely needed the precision, but because they needed to know they had the precision.

We’re probably not going to switch industry-wide back to FP16[4], but getting a better numerical grounding into the tools we use is going to make everyone’s lives easier, eventually!

  1. torch.cuda.amp and friends ↩︎
  2. Though they did verify some runs on Hopper as well, which some people seemed to miss ↩︎
  3. Check out the recording: Keynote: PyTorch Technical Deep Dive – Alban Desmaison, Peng Wu, Mark Saroufim & Edward Yang, Meta ↩︎
  4. Especially since most labs are doing so much with FP8 or less these days, and it would probably annoy a bunch of chip designers ↩︎
