https://github.com/deepseek-ai/FlashMLA/
DeepSeek’s infra team are having a great week open-sourcing some of their components, to the benefit of everyone! The first day's release was their Multi-head Latent Attention (MLA) kernel, which takes the FlashAttention approach: it leans heavily on shared memory as an additional level of tiling during the attention computation, uses split-K to divide the work along the K dimension, and so on.
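To make the tiling idea concrete, here is a minimal pure-Python sketch of the running-softmax trick that FlashAttention-style kernels rely on: keys and values are visited one tile at a time, and the partial result is rescaled whenever a larger score appears. This is not FlashMLA's code; the function names and `tile_size` parameter are made up for illustration, and a real kernel does this per thread block with the tile held in shared memory.

```python
import math

def attention_reference(q, keys, values):
    """Plain softmax attention for a single query vector, no tiling."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attention_tiled(q, keys, values, tile_size=2):
    """Same result, but keys/values are consumed one tile at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")   # largest score seen so far
    running_sum = 0.0             # softmax denominator, relative to running_max
    acc = [0.0] * len(values[0])  # unnormalised weighted sum of values

    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]

        # Rescale everything accumulated so far to the new maximum.
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]

        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max

    return [a / running_sum for a in acc]
```

Because each tile only needs the running maximum, the running sum and the accumulator, the full row of attention scores never has to exist in memory at once, which is what makes the shared-memory tiling worthwhile.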
On top of that they add a latent K/V vector: the input to K and V is first projected down to a lower latent dimension, and the K and V matrices then project it back up to full dimension per head. Only the latent vector is cached for past tokens. While this uses a little more compute than a traditional KV cache, choosing a sufficiently small latent dimension yields significant savings in memory and memory bandwidth, which are typically the constraint.
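As a rough illustration of the caching scheme, the toy sketch below stores only a down-projected latent vector per token and rebuilds per-head keys and values from it on demand. Everything here is an assumption for demonstration: the sizes, the random weights and the helper names (`append_token`, `keys_and_values`) are hypothetical, and it leaves out positional encoding and every kernel-level optimization.

```python
import random

# Hypothetical toy sizes; real models use far larger dimensions.
MODEL_DIM, LATENT_DIM, NUM_HEADS, HEAD_DIM = 8, 2, 2, 4

rng = random.Random(0)

def random_matrix(rows, cols):
    """Stand-in for learned weights."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def matvec(matrix, vec):
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]

# One shared down-projection, then one up-projection per head for K and for V.
w_down = random_matrix(LATENT_DIM, MODEL_DIM)
w_up_k = [random_matrix(HEAD_DIM, LATENT_DIM) for _ in range(NUM_HEADS)]
w_up_v = [random_matrix(HEAD_DIM, LATENT_DIM) for _ in range(NUM_HEADS)]

latent_cache = []  # one LATENT_DIM vector per past token, nothing else

def append_token(hidden):
    """Cache only the latent vector for a new token."""
    latent_cache.append(matvec(w_down, hidden))

def keys_and_values(head):
    """Rebuild full-size keys and values for one head from the latent cache."""
    keys = [matvec(w_up_k[head], latent) for latent in latent_cache]
    values = [matvec(w_up_v[head], latent) for latent in latent_cache]
    return keys, values

for _ in range(5):
    append_token([rng.uniform(-1.0, 1.0) for _ in range(MODEL_DIM)])

keys, values = keys_and_values(head=0)

# Numbers stored per token: latent cache vs. a conventional per-head K/V cache.
latent_floats_per_token = LATENT_DIM
full_floats_per_token = 2 * NUM_HEADS * HEAD_DIM
print(latent_floats_per_token, full_floats_per_token)
```

With these toy sizes the cache holds 2 numbers per token instead of 16, at the price of redoing the up-projections whenever keys and values are needed. That is the compute-for-memory trade described above.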
MLA was introduced back in the DeepSeek-V2 paper, if you want to read the full breakdown!