Pyrefly

https://pyrefly.org

I’m at PyCon (mildly awkward photo thanks to Simon Willison!), and earlier I had to steal some extra chairs for the Typing Summit because it was full! There is a lot of energy and interest around type checking right now, thanks in part to Astral’s ty and Meta’s Pyrefly recently entering the space.

While the playground is great for trying it out, I wanted to see what it was like on a larger codebase I already knew. I chose TorchTune, which uses type hints but (as far as I know!) doesn’t configure a type checker in CI, relying instead on squiggles from the LSP as the main type feedback (which is reasonable!).

I tried running mypy over it with a very basic config, timing the run with time mypy:

[mypy]
python_version = 3.13
ignore_missing_imports = True
warn_unused_ignores = True
strict_optional = True
files = .

Unsurprisingly, there are quite a few errors!

Found 1361 errors in 211 files (checked 485 source files)
real 10m45.596s
user 0m12.906s
sys 0m0.929s

I installed pyrefly and init’d it:

pip install pyrefly
pyrefly init

This created a pyrefly.toml containing a very minimal config:

project_includes = ["."]
python_version = "3.13.0"

Running pyrefly check then gave:

INFO 2,966 errors shown, 7 errors ignored, 485 modules, 7,364 transitive dependencies, 3,522,743 lines, took 47.94s (checking 34.88s; reporting 12.98s), peak memory physical 863.6 MiB

It’s impressively fast: about 10 minutes of wall-clock time for mypy versus under 50 seconds for pyrefly. It also reports a lot more errors, and it’s hard to tell whether they’re false positives from pyrefly, errors mypy skipped, or something else. To get a closer look, I scoped it down to TorchTune’s KV cache module in torchtune/modules/kv_cache.py.
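The speed comparison is simple arithmetic on the two outputs above. This small sketch converts the bash time fields to seconds and compares mypy’s real (wall-clock) time with pyrefly’s self-reported total of 47.94s. The helper name parse_time is made up for illustration, and since pyrefly’s figure comes from its own summary line rather than from time, treat the ratio as approximate.

```python
import re


def parse_time(value: str) -> float:
    """Convert a bash `time` field such as '10m45.596s' to seconds (hypothetical helper)."""
    match = re.fullmatch(r"(\d+)m([\d.]+)s", value)
    if match is None:
        raise ValueError(f"unrecognised time value: {value!r}")
    minutes, seconds = match.groups()
    return int(minutes) * 60 + float(seconds)


# Figures copied from the mypy run above
mypy_real = parse_time("10m45.596s")
mypy_cpu = parse_time("0m12.906s") + parse_time("0m0.929s")  # user + sys

# pyrefly's self-reported total from its INFO summary line
pyrefly_total = 47.94

print(f"mypy wall-clock: {mypy_real:.1f}s, CPU: {mypy_cpu:.1f}s")  # mypy wall-clock: 645.6s, CPU: 13.8s
print(f"wall-clock speedup: {mypy_real / pyrefly_total:.1f}x")      # wall-clock speedup: 13.5x
```

Note that mypy’s user and sys times add up to under 14 seconds, so the 10-minute figure is elapsed time rather than CPU time.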

There, pyrefly reports 9 errors and mypy 17, but from what I can see the gap mostly comes from the two tools reporting some of the same problems in slightly different ways. For example:

k_out[:, :, self.cache_pos[:seq_len]] = k_val

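This line does a bit of tensor indexing:

  • cache_pos is a tensor of length max_seq_len holding absolute write positions
  • k_out is the key cache, with shape batch_size x num_heads x max_seq_len x head_dim
  • The assignment writes the latest key values, k_val, into k_out at the sequence positions given by the first seq_len entries of cache_pos. Because the index is itself a tensor, this is advanced indexing, so the selected slots of the cache are updated in place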
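To make the indexing concrete, here is a toy version of that assignment. It uses NumPy rather than PyTorch so it runs without a torch install (tensor advanced indexing behaves the same way for this pattern), and the sizes and the starting write position of 2 are made up for illustration.

```python
import numpy as np

# Hypothetical toy sizes, kept small so the result is easy to read
batch_size, num_heads, max_seq_len, head_dim = 1, 2, 8, 4
seq_len = 3  # number of new tokens to write

# Key cache with shape batch_size x num_heads x max_seq_len x head_dim
k_out = np.zeros((batch_size, num_heads, max_seq_len, head_dim))

# Absolute write positions; assume (for illustration) the next free slot is 2
cache_pos = np.arange(2, 2 + max_seq_len)

# Keys for the new tokens: batch_size x num_heads x seq_len x head_dim
k_val = np.ones((batch_size, num_heads, seq_len, head_dim))

# Indexing with an integer array is advanced indexing: reading returns a copy...
selected = k_out[:, :, cache_pos[:seq_len]]
assert not np.shares_memory(selected, k_out)

# ...but assigning through it writes into k_out in place, at slots 2, 3 and 4
k_out[:, :, cache_pos[:seq_len]] = k_val

print(k_out[0, 0, :, 0])  # [0. 0. 1. 1. 1. 0. 0. 0.]
```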

mypy gives these errors:

torchtune/modules/kv_cache.py:104: error: "Tensor" not callable [operator]
torchtune/modules/kv_cache.py:104: error: Value of type "Tensor | Module" is not indexable [index]

While pyrefly gives:

torchtune/modules/kv_cache.py:104:9-46: Item assignment is not supported on Module | Tensor
  Expected __setitem__ to be a callable, got BoundMethod...

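In this module, cache_pos and k_cache are created by calling PyTorch’s register_buffer, which registers a tensor as part of the module’s state (saved in the state_dict by default) without making it a trainable parameter. Because the buffers are attached at runtime, neither checker sees a declared type for them, so attribute access falls back to nn.Module.__getattr__, which is annotated as returning Union[Tensor, Module]; that is where the Tensor | Module in both sets of errors comes from. Adding explicit type declarations in the class body fixes the errors in both mypy and pyrefly: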

cache_pos: torch.Tensor
k_cache: torch.Tensor
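To see why the class-body declarations help, here is a minimal pure-Python sketch that mimics the relevant part of PyTorch. Tensor, Module and KVCache below are simplified, hypothetical stand-ins, not the real torch or TorchTune classes (the real nn.Module.__getattr__ also looks up parameters and submodules). Like nn.Module, the stand-in keeps buffers in a dict and only finds them through __getattr__, so unless the class declares the attribute itself, that method’s return annotation is all a type checker has to go on.

```python
from typing import Union, get_type_hints


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


class Module:
    """Hypothetical, much-simplified stand-in for torch.nn.Module."""

    def __init__(self) -> None:
        # Buffers live in a dict, not as ordinary instance attributes
        self._buffers: dict[str, Tensor] = {}

    def register_buffer(self, name: str, tensor: Tensor) -> None:
        self._buffers[name] = tensor

    def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
        # Only called when normal attribute lookup fails. Without a declaration,
        # this return annotation is the only type a checker can infer.
        buffers = self.__dict__.get("_buffers", {})
        if name in buffers:
            return buffers[name]
        raise AttributeError(name)


class KVCache(Module):
    # Annotation-only declarations: no class attribute is created, so runtime
    # lookup still reaches __getattr__, but checkers now type these as Tensor.
    cache_pos: Tensor
    k_cache: Tensor

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("cache_pos", Tensor())
        self.register_buffer("k_cache", Tensor())


cache = KVCache()
print(type(cache.cache_pos).__name__)             # Tensor
print(get_type_hints(KVCache)["cache_pos"] is Tensor)  # True
```

The declarations change nothing at runtime; they only give the checker a precise type to use instead of the Tensor | Module union.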

Discover more from Ian’s Blog
