Quantization in PyTorch

Jerry Zhang recently posted a couple of updates on the evolution of the quantization APIs in PyTorch, and the unification around TorchAO.

If you haven’t spent much time around quantization, or mostly quantize LLMs via off-the-shelf tools, the range of options can be pretty gnarly.

Quantization compresses a model by replacing a wide-range number format with a narrower one. To recover an approximation of the original value you track a scale factor and a zero point (this scheme is known as affine quantization).

For example, if you have a float32 layer, but all of its parameters are between 1 and 10, then using int8 (256 values) to represent the whole float32 range would collapse that narrow band into just a handful of values, losing a huge amount of information. Instead, you set the scale factor to cover only the range of values actually present. This lowers the quantization error: the difference between the dequantized value and the original.
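As a rough sketch of the idea (with hypothetical helper names, not a PyTorch API), affine quantization of a known range to int8 looks something like:

```python
def affine_quantize_params(xmin, xmax, qmin=-128, qmax=127):
    """Compute a scale and zero point mapping [xmin, xmax] onto [qmin, qmax]."""
    scale = (xmax - xmin) / (qmax - qmin)
    zero_point = round(qmin - xmin / scale)
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # clamp into the int8 range

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

# Values between 1 and 10 get the full int8 range, so the
# quantization error stays below one scale step.
scale, zp = affine_quantize_params(1.0, 10.0)
q = quantize(3.7, scale, zp)
err = abs(dequantize(q, scale, zp) - 3.7)
```

(Real schemes also constrain the zero point so it can itself be stored as an int8; this sketch skips that detail.)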

That’s easy to do for the weights of the model, but you need to handle the activations as well (so you are multiplying matrices of the same type). “Static” quantization determines the scale factor and zero point for activations up front, while “dynamic” quantization calculates them at inference time, resulting in better accuracy. The downside is that dynamic quantization generally only works on CPUs, as it’s inherently data-dependent.
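The difference can be sketched in a few lines (hypothetical helpers, not a real API): static quantization fixes the activation scale from calibration data ahead of time, while dynamic quantization recomputes it from each incoming batch:

```python
def scale_for(values, qmax=127):
    """Symmetric scale covering the observed absolute range."""
    return max(abs(v) for v in values) / qmax

# Static: scale chosen once, from calibration data seen ahead of time.
calibration_batch = [0.5, -2.0, 1.5]
static_scale = scale_for(calibration_batch)

# Dynamic: scale recomputed from the actual activations at inference time.
def dynamic_quantize(activations):
    scale = scale_for(activations)  # data-dependent, per batch
    return [round(a / scale) for a in activations], scale

# A batch with a wider range than calibration saw: the dynamic scale
# adapts, while the static scale would clip these values.
qs, s = dynamic_quantize([4.0, -8.0, 2.0])
```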

You don’t have to quantize everything the same way, and it’s common to see quantization schemes in the format A16W8. That indicates the weights are quantized to 8-bit (usually int8), but the activations are kept at 16-bit (usually float16 or bfloat16). In those cases, the weights are upcast to a matching dtype at compute time, but you still benefit from the faster loading and lower persistent memory usage.
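A weight-only scheme like A16W8 can be sketched with plain tensors (this illustrates the idea, not how any particular backend implements it):

```python
import torch

# Quantize a weight matrix to int8 with a per-tensor symmetric scale.
w = torch.randn(4, 8)
scale = w.abs().max() / 127
w_q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)

# At compute time the int8 weights are upcast back to the activation
# dtype, so the matmul itself runs in 16-bit; only storage shrinks.
x = torch.randn(2, 8, dtype=torch.bfloat16)
w_deq = w_q.to(torch.bfloat16) * scale.to(torch.bfloat16)
y = x @ w_deq.t()

assert w_q.dtype == torch.int8    # 1 byte per weight instead of 4
assert y.dtype == torch.bfloat16  # compute stays in 16-bit
```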

The general flow for quantization is:

  • Identify which parts of the model you want to quantize: often you’ll want to quantize some parts, and not others (for example, all the linear layers, but not a softmax)
  • Prepare the parts being quantized by adding quantize and dequantize operations around the normal ops.
  • For static quantization (where scale/zero point are set ahead of time for activations) calibrate the model by sending input data through and observing the ranges.
  • Convert the model by replacing the layers with their quantized, lower-bit representations, and ensure the appropriate operators are in place to dequantize where full-precision values are needed.

Quantizing after training is Post-Training Quantization (PTQ). Quantization-Aware Training (QAT) introduces quantization during training, allowing the model to learn and partially recover accuracy loss.
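The core trick in QAT is “fake quantization”: during training, the forward pass rounds values through the quantized grid and back, so the network learns weights that survive the rounding. A minimal sketch (hypothetical function name; real QAT also pairs the round with a straight-through gradient estimator):

```python
import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Quantize then immediately dequantize, so downstream layers
    see the rounding error during training."""
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.tensor([0.1, 0.5, 0.9])
y = fake_quantize(x, scale=0.01, zero_point=0)
# y stays float32, but only takes values on the quantized grid
```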

Quantizing in PyTorch

There are (at least!) 4 different approaches to quantization in PyTorch:

  • Eager Mode: Deprecated, simple to use with quantize_dynamic for quick dynamic quantization.
  • FX Mode: Also deprecated; separates quantization from model code via FX graph tracing but requires FX-traceability.
  • PT2E Mode: Current method using PyTorch 2 export. It captures the model graph, supports backend-specific configs (like XNNPack), and is preferable for exported models.
  • TorchAO Quantization: Latest method optimized for torch.compile. Includes easy-to-use features like autoquant for automatic quantization tuning as well as manual options.

Eager mode quantization

The original, and soon to be deprecated, quantization method in PyTorch was eager mode quantization. The simplest version is calling torch.ao.quantization.quantize_dynamic. This takes a config (identifying which parts of the model you want to quantize) and a target dtype. That call scales and downcasts the weights, and injects ops to dynamically quantize activations. The prepare and convert steps are handled automatically, and there is no calibration step as the activations are scaled dynamically.
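On a toy model (the model here is made up for illustration; quantize_dynamic is the real call), dynamically quantizing just the linear layers is a one-liner:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Quantize only the nn.Linear layers, weights to int8; activations
# are quantized dynamically at inference time.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

out = qmodel(torch.randn(2, 16))  # runs with dynamically quantized linears
```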

Eager mode can also do static quantization with torch.ao.quantization.quantize. This requires a lot more modification: you manually add torch.ao.quantization.QuantStub() and DeQuantStub() around the nn.Module calls you want to run quantized. Then you torch.ao.quantization.prepare the model and call forward with some example data to calibrate it. Prepare adds observers that collect statistics on the activations to determine good zero point and scale values. Finally, torch.ao.quantization.convert processes the model and returns the quantized version.
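Putting the static flow together on a toy module looks something like this; one detail the flow needs that isn’t mentioned above is a qconfig, and this sketch assumes the default qconfig for whatever quantized engine is active on your platform:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()      # float -> int8
        self.fc = nn.Linear(8, 4)
        self.dequant = torch.ao.quantization.DeQuantStub()  # int8 -> float

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

m = M().eval()
m.qconfig = torch.ao.quantization.get_default_qconfig(
    torch.backends.quantized.engine
)
prepared = torch.ao.quantization.prepare(m)
prepared(torch.randn(4, 8))  # calibration: observers record ranges
quantized = torch.ao.quantization.convert(prepared)
out = quantized(torch.randn(2, 8))
```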

Quantization aware training works the same as before in terms of adding stubs, but rather than calibrating with input data you call torch.ao.quantization.prepare_qat and then run a regular training loop on the model.

Modifying the model code itself is pretty painful as well, particularly if you are in an active, multi-collaborator code base. This led to the second evolution of PyTorch quantization:

FX Mode Quantization

The idea behind FX mode quantization was to offer the same range of options as before, without having to change model code. This is particularly helpful when you have (say) a research team developing a model and a production team trying to make it fast for inference!

The FX graph is a graph of operations created by tracing the model, and by working on the FX ops directly the library can make the quantization modifications while leaving the original code untouched. It does this by pattern-matching in the FX graph and applying transforms that add quant and dequant stubs automatically. The downside is that it needs your model to be FX-traceable (as it would need to be for TorchScript), which often requires model changes.

torch.ao.quantization.quantize_fx contains the methods, and they follow the same pattern: a quantization config, then prepare_fx, then convert_fx with options for running input through for calibration or running a training loop for QAT.
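The FX flow on a toy model looks like the following (the toy model and the choice of default qconfig mapping are assumptions for illustration; prepare_fx/convert_fx are the real entry points):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).eval()

qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine)
example_inputs = (torch.randn(1, 8),)

# prepare_fx traces the model and inserts observers; no source changes.
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(torch.randn(16, 8))      # calibration pass
quantized = convert_fx(prepared)  # swap in quantized ops
out = quantized(torch.randn(2, 8))
```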

FX tracing and TorchScript have been on the outs for a while due to their mix of complexity and inflexibility, and under the hood both this and the Eager variant use a quantized Tensor type which has been slated for deprecation. So, in their place we have…

PT2E Quantization

This one is not deprecated! But it is still changing a bit. The basic idea is very similar to FX quantization, except instead of using FX to capture the graph, we use PyTorch 2’s export feature. torch.ao.quantization.quantize_pt2e offers prepare_pt2e and convert_pt2e, again with an optional calibration/QAT step.

One big difference is that the quantization config is now backend specific. Prior to calling prepare, you would set up the backend: e.g. for the XNNPack library:

from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config()
)

This allows different backends to define their quantization setup based on the ops available to them.

This does require capturing the full model graph with PT2 export, which can be painful, but if you’re going through that pain it’s best to do so with this flow rather than FX!

If you don’t want to work with full graph capture, there is one other option which integrates with torch.compile:

TorchAO Quantization

Also not deprecated! The TorchAO library has a toolkit for quantization, and one particularly nice feature is autoquant:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

Under the hood, this will try different quantization schemes to identify the best tradeoff between size reduction and accuracy loss.

You can also quantize manually:

from torchao.quantization import quantize_, Int4WeightOnlyConfig
quantize_(m, Int4WeightOnlyConfig(group_size=group_size))

Here, the quantize_ function wraps the prepare/convert steps for you, so it’s pretty user-friendly!

Not sure what to use? Jerry’s table can help you choose, but if you’re quantizing for use in PyTorch (and using torch.compile) then prefer TorchAO quantization. If you are quantizing as part of making the model available in non-Python environments, then you’ll want to be able to export, and should use the PT2E flow.
