-
Notifications
You must be signed in to change notification settings - Fork 320
Commit 6de76cd
committed
Update on "New multi-step QAT API"
**Summary:** This commit adds a new multi-step QAT API with the
main goal of simplifying the existing UX. The new API uses the
same `QATConfig` for both the prepare and convert steps, and
automatically infers the fake quantization configs based on
a PTQ base config provided by the user:
```
from torchao.quantization import (
quantize_,
Int8DynamicActivationInt4WeightConfig
)
from torchao.quantization.qat import QATConfig
\# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)
\# train (not shown)
\# convert
quantize_(m, QATConfig(base_config, step="convert"))
```
The main improvements include:
- A single config for both prepare and convert steps
- A single quantize_ for convert (instead of 2)
- No chance for incompatible prepare vs convert configs
- Much less boilerplate code for most common use case
- Simpler config names
For less common use cases such as experimentation, users can
still specify arbitrary fake quantization configs for
activations and/or weights as before. This is still important
since there may not always be a corresponding PTQ base config.
For example:
```
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
activation_config=activation_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)
\# train and convert same as above (not shown)
```
**BC-breaking notes:** This change by itself is technically not
BC-breaking since we keep around the old path, but will become
so when we deprecate and remove the old path in the future.
Before:
```
\# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
quantize_(model, qat_config)
\# train (not shown)
\# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```
After: (see above)
**Test Plan:**
```
python test/quantization/test_qat.py
```
[ghstack-poisoned]File tree
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedFilter options
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changed
0 commit comments