Perplexed

The normal loss when pre-training a language model is cross-entropy, which sounds more complicated than it is. At each step, the model doesn’t just predict a single token; it predicts a probability distribution across all possible tokens. Cross-entropy loss is -log(probability of the correct token) under that distribution.

  • If p(correct) = 0.99 → CE ≈ 0.01
  • If p(correct) = 0.5 (unsure between two tokens) → CE ≈ 0.693
  • If p(correct) = 1/100_000 (e.g. guessing uniformly) → CE ≈ 11.5
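The arithmetic above is easy to check directly (a minimal sketch; natural log, as is standard for CE):

```python
import math

def cross_entropy(p_correct: float) -> float:
    """Cross-entropy contribution of a single token: -log p(correct)."""
    return -math.log(p_correct)

print(cross_entropy(0.99))         # ≈ 0.01
print(cross_entropy(0.5))          # ≈ 0.693
print(cross_entropy(1 / 100_000))  # ≈ 11.5
```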

If you average the CE over a whole bunch of tokens (say, your validation set) and take e^(average CE), you get the perplexity, or PPL.

The number gives you an idea of how many choices the model was considering. Perplexity of 1 means the model was always 100% sure and 100% right (a feat only Elon can achieve). PPL 2 means the model was flipping a coin between two tokens most of the time. PPL 50 means the model was about as uncertain as if it were choosing uniformly among 50 plausible next tokens. Because you’re already calculating the loss, PPL is very cheap to compute, so it gets used a lot.
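The round trip from averaged loss to perplexity can be sketched like this (the token probabilities are made up to match the cases above):

```python
import math

def perplexity(token_probs):
    """exp of the mean cross-entropy over the correct-token probabilities."""
    avg_ce = sum(-math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_ce)

# Always certain and correct -> PPL 1
print(perplexity([1.0, 1.0, 1.0]))  # ≈ 1.0
# Coin-flipping between two tokens -> PPL 2
print(perplexity([0.5, 0.5, 0.5]))  # ≈ 2.0
# Uniform over 50 plausible next tokens -> PPL 50
print(perplexity([1 / 50] * 4))     # ≈ 50.0
```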

Prior to pre-training you’ll typically run a sweep of experiments with different architecture tweaks and see which ones lower perplexity. During pre-training you’ll want to check whether the model is actually learning, and whether you should nuke a run rather than continue: improvements in perplexity are a good guide to that. You can also use a well-trained model to score perplexity on fresh data: data with a surprisingly high perplexity might be garbage, or a counting subreddit.

Still, you can have too much of a good thing. A new paper from Veličković et al., “Perplexity cannot always tell right from wrong”, argues that, much like with humans, it’s very easy to select for confidently wrong over uncertainly right.

We prove that, for a wide class of decoder-only Transformer-based language models, should the model be highly confident and correct on a sufficiently long input sequence, this must imply existence of another input where the model’s prediction is wrong, yet the log-perplexity of that prediction approaches *zero*

The basic idea is that when the model is confident, you can construct a different sequence that the model would be equally confident on but also… wrong.

This particularly shows up when contexts get longer, because all tokens are not equal. To give a trivial example:

In the word "strawberry," there are 8 Rs.

This is correct at every single token except ‘8’ (strawberry has three Rs). A model that is highly confident on all the easy tokens may score a lower perplexity for that sequence, as a whole, than a less confident model that gets the count right.
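A toy illustration of that failure mode, with made-up per-token probabilities for two ten-token sequences:

```python
import math

def perplexity(token_probs):
    """exp of the mean cross-entropy over the correct-token probabilities."""
    avg_ce = sum(-math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_ce)

# Confidently wrong: near-certain on the nine easy tokens,
# and 0.95 on the final token — which is the wrong "8".
confidently_wrong = [0.99] * 9 + [0.95]
# Uncertainly right: only 0.6 on every token, but the final "3" is correct.
uncertainly_right = [0.6] * 10

print(perplexity(confidently_wrong))  # ≈ 1.01 — lower (better) PPL, wrong answer
print(perplexity(uncertainly_right))  # ≈ 1.67 — higher PPL, right answer
```

The sequence-level score rewards the model that was confidently wrong, which is exactly the selection pressure the paper warns about.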
