Grouped GEMMs and MoE

One of the challenges discussed in the DeepSeek-V3 paper is the availability of grouped GEMM kernels, which amortize the overhead of launching many small kernels on GPUs. DeepSeek uses many small experts (256!) rather than a few large ones, which exacerbates this problem.

Mixture of Experts models introduce multiple experts in the feed-forward portion of each transformer layer. Rather than sharing a single set of experts across the model, each layer has its own. Each batch of tokens first passes through the standard attention block, followed by a lightweight linear layer with a softmax function¹. This determines, for each token, which expert(s) it should be sent to. Tokens designated for each expert are gathered and sent to the appropriate device via an all-to-all operation, since experts are typically distributed across different devices.
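The routing step can be sketched in a few lines of NumPy. This is a minimal top-k softmax gate, not DeepSeek's actual router; the function and variable names (`route_tokens`, `gate_w`, `top_k`) are illustrative.

```python
import numpy as np

def route_tokens(tokens, gate_w, top_k=2):
    """Pick the top-k experts for each token via a softmax gate.

    tokens: (n_tokens, d_model), gate_w: (d_model, n_experts).
    Returns expert indices (n_tokens, top_k) and gate probabilities.
    """
    logits = tokens @ gate_w                                # (n_tokens, n_experts)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)              # softmax over experts
    experts = np.argsort(-probs, axis=-1)[:, :top_k]        # top-k expert ids
    return experts, probs

rng = np.random.default_rng(0)
tokens = rng.standard_normal((8, 16))    # 8 tokens, hidden dim 16
gate_w = rng.standard_normal((16, 4))    # gate over 4 experts
experts, probs = route_tokens(tokens, gate_w)
```

The expert indices are then used to gather each expert's tokens before the all-to-all.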

Once the tokens are on the device hosting the right expert(s), we need to execute a matrix multiply for each expert over its set of tokens. The obvious solution is to loop over the experts and launch a GEMM for each, but because these multiplies are small (a small number of tokens, and comparatively small expert matrices), kernel launch overhead ends up dominating the runtime. A grouped GEMM moves this process on-device: it takes a list of token groups and expert weights and executes all the GEMMs with a single kernel launch.
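The naive per-expert loop looks like the following NumPy sketch, which assumes top-1 routing (each token goes to exactly one expert); on a GPU, each iteration of this loop would be a separate kernel launch, which is exactly the overhead a grouped GEMM eliminates.

```python
import numpy as np

def naive_expert_gemms(tokens, expert_ids, expert_weights):
    """One small GEMM per expert; on a GPU each loop iteration
    would incur its own kernel launch."""
    n_experts, _, d_out = expert_weights.shape
    out = np.empty((tokens.shape[0], d_out))
    for e in range(n_experts):
        mask = expert_ids == e                      # tokens routed to expert e
        out[mask] = tokens[mask] @ expert_weights[e]  # small GEMM
    return out

rng = np.random.default_rng(0)
tokens = rng.standard_normal((32, 16))        # 32 tokens, hidden dim 16
expert_ids = rng.integers(0, 4, size=32)      # top-1 routing to 4 experts
expert_weights = rng.standard_normal((4, 16, 16))
out = naive_expert_gemms(tokens, expert_ids, expert_weights)
```

A grouped GEMM kernel produces the same result as this loop, but with all the per-expert multiplies fused into one launch.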

This differs from a batched GEMM in that the problem sizes can vary – different experts might receive different numbers of tokens, so the per-group shapes are ragged rather than uniform.
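Concretely, the per-expert inputs form a ragged list, as in this sketch (the counts and dimensions are made up for illustration). A batched GEMM would require padding every group to the same size, whereas a grouped GEMM accepts the ragged shapes directly:

```python
import numpy as np

# Hypothetical routing outcome: tokens per expert are uneven,
# and one expert receives none at all.
counts = [5, 1, 12, 0]
d_in, d_out = 16, 32

rng = np.random.default_rng(1)
groups = [rng.standard_normal((n, d_in)) for n in counts]       # ragged inputs
weights = [rng.standard_normal((d_in, d_out)) for _ in counts]  # one matrix per expert

# A grouped GEMM conceptually computes all of these in one launch:
outputs = [a @ w for a, w in zip(groups, weights)]
```

Note that a batched GEMM (a single 3-D `matmul`) could not express this without first padding each group up to 12 rows.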

There are example implementations available, including a tutorial on TritonLang that walks through a simple grouped GEMM kernel, as well as an example in Cutlass.

  1. In Switch-style MoEs at least, but similar gating networks appear elsewhere. ↩︎
