How to Think About GPUs | How To Scale Your Model
The JAX “How To Scale Your Model” book is one of my favorite references for folks trying to get their heads around pretraining¹. It breaks down the performance characteristics of model training (often using Llama 3 as an example) in an incredibly clear way. The only slight limitation is that it is primarily focused on scaling LLMs on TPUs: interesting, but probably not your main platform target (unless you work at DeepMind). They just released a new chapter covering GPUs, and it’s also a great summary².
There are also plenty of mildly snarky comments about design choices to leaven the reading:
Takeaway: in theory, NVIDIA SHARP (available on most NVIDIA switches) should reduce the cost of an AllReduce on B bytes from about 2 * B / W to B / W. However, in practice we only see a roughly 30% improvement in bandwidth. Since pure AllReduces are fairly rare in LLMs, this is not especially useful.
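To put rough numbers on that takeaway, here is a minimal Python sketch of the bandwidth-only cost model the quote describes. The function names, the 1 GB payload, and the 400 GB/s bandwidth are illustrative assumptions of mine, not figures from the book, and reading "roughly 30% improvement in bandwidth" as a 1.3x effective-bandwidth gain is my interpretation. Latency and topology effects are ignored.

```python
def allreduce_time(num_bytes: float, bandwidth: float) -> float:
    """Bandwidth-bound AllReduce without in-network reduction: about 2 * B / W."""
    return 2 * num_bytes / bandwidth


def sharp_ideal_time(num_bytes: float, bandwidth: float) -> float:
    """Theoretical AllReduce cost with in-network reduction: about B / W."""
    return num_bytes / bandwidth


def sharp_observed_time(
    num_bytes: float, bandwidth: float, bandwidth_gain: float = 1.3
) -> float:
    """Baseline cost scaled by an assumed effective-bandwidth gain.

    bandwidth_gain=1.3 is an assumption standing in for the quoted
    "roughly 30% improvement in bandwidth".
    """
    return allreduce_time(num_bytes, bandwidth) / bandwidth_gain


# Hypothetical example values, chosen only to make the arithmetic concrete.
B = 1e9    # bytes to AllReduce
W = 400e9  # bytes per second of per-device bandwidth

baseline = allreduce_time(B, W)
ideal = sharp_ideal_time(B, W)
observed = sharp_observed_time(B, W)

print(f"baseline (2B/W):       {baseline * 1e3:.2f} ms")
print(f"SHARP in theory (B/W): {ideal * 1e3:.2f} ms")
print(f"SHARP at ~1.3x:        {observed * 1e3:.2f} ms")
```

Under these assumptions, the observed figure lands between the baseline and the theoretical best, which is the gap the takeaway is pointing at: most of the promised 2x never shows up.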