Commit 7a9fe90
Update on "New multi-step QAT API"
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs based on a PTQ base config provided by the user:

```
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(m, QATConfig(base_config, step="prepare"))

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:

- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking when we deprecate and remove the old path in the future.

Before:

```
import torch

from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import (
    FromIntXQuantizationAwareTrainingConfig,
    IntxFakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**

```
python test/quantization/test_qat.py
```

[ghstack-poisoned]
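As an aside on the new `step` argument: this update also exports a `QATConfigStep` enum (see `torchao/quantization/qat/api.py` below). Based on the enum's `str` mixin and the `.lower()` normalization in `__post_init__`, the following equivalence sketch should hold; the variable names are illustrative:

```
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig, QATConfigStep

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# `step` accepts the plain string (any case) or the enum member,
# since `QATConfigStep` mixes in `str` and the config lowercases it
config_from_str = QATConfig(base_config, step="PREPARE")
config_from_enum = QATConfig(base_config, step=QATConfigStep.PREPARE)
assert config_from_str.step == config_from_enum.step == "prepare"
```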
1 parent 8f56651

7 files changed: +40 −40 lines
docs/source/api_ref_qat.rst

Lines changed: 1 addition & 1 deletion

```diff
@@ -16,7 +16,7 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :nosignatures:

     QATConfig
-
+    QATConfigStep

 Custom QAT APIs
 ---------------
```

docs/source/finetuning.rst

Lines changed: 11 additions & 24 deletions

```diff
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.

 .. code:: py

-    from torchao.quantization import (
-        quantize_,
-    )
-    from torchao.quantization.qat import (
-        FakeQuantizeConfig,
-        IntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization.qat import QATConfig

     model = get_model()

-    # prepare: insert fake quantization ops
-    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
-    quantize_(model, qat_config)
+    # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))

     # fine-tune
     train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

 .. code:: py

-    from torchao.quantization import (
-        Int8DynamicActivationInt4WeightConfig,
-    )
-    from torchao.quantization.qat import (
-        FromIntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-    # convert: transform fake quantization ops into actual quantized ops
-    # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-    # quantized activation and weight tensor subclasses
-    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+    # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+    # inference or generate

 Now our model is ready for serving, and will typically have higher quantized
 accuracy than if we did not apply the prepare step (fake quantization) during
```
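Taken together, the two hunks above give the complete new fine-tuning flow. For readers who want to try it end to end, here is a self-contained sketch assembled from the updated docs; the toy model, batch, and single optimizer step are placeholders for a real fine-tuning setup:

```
import torch
import torch.nn as nn

from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# placeholder model; any model with `nn.Linear` layers works the same way
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))

# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# fine-tune: one illustrative step in place of a real train loop
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
quantize_(model, QATConfig(base_config, step="convert"))
```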

test/quantization/test_qat.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -1281,9 +1281,7 @@ def test_qat_config_init(self):
         self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")

         # Bad step
-        with self.assertRaisesRegex(
-            ValueError, "`step` must be either 'prepare' or 'convert'"
-        ):
+        with self.assertRaisesRegex(ValueError, "`step` must be one of"):
             QATConfig(base_config, step="blah")

         # Step was not a keyword arg
```
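Concretely, the regex now matches the message raised by `__post_init__` in `api.py` (shown later in this diff); a quick sketch of the failure the test asserts:

```
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Raises ValueError: `step` must be one of ['prepare', 'convert']
# (list rendering follows Python's repr of the enum values)
QATConfig(base_config, step="blah")
```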

torchao/quantization/qat/README.md

Lines changed: 4 additions & 2 deletions

````diff
@@ -93,7 +93,8 @@ model = get_model()
 base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
 quantize_(model, QATConfig(base_config, step="prepare"))

-# train (not shown)
+# train
+train_loop(model)

 # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
 quantize_(model, QATConfig(base_config, step="convert"))
@@ -123,7 +124,8 @@ qat_config = QATConfig(
 )
 quantize_(model, qat_config)

-# train (not shown)
+# train
+train_loop(model)

 # convert: (not shown, same as before)
 ```
````
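The README snippets now call `train_loop(model)` without defining it. A minimal hypothetical stand-in, assuming an ordinary supervised fine-tuning setup; the optimizer, batch shape, and placeholder loss below are illustrative, not part of torchao:

```
import torch

def train_loop(model, num_steps=100):
    """Hypothetical stand-in for the README's `train_loop`."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for _ in range(num_steps):
        inputs = torch.randn(8, 64)          # illustrative batch; use a real dataloader
        loss = model(inputs).pow(2).mean()   # placeholder loss for the sketch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```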

torchao/quantization/qat/__init__.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -3,6 +3,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     QATConfig,
+    QATConfigStep,
     from_intx_quantization_aware_training,
     initialize_fake_quantizers,
     intx_quantization_aware_training,
@@ -25,12 +26,13 @@
 )

 __all__ = [
+    "QATConfig",
+    "QATConfigStep",
     "FakeQuantizeConfigBase",
+    "IntxFakeQuantizeConfig",
+    "FakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
-    "FakeQuantizer",
-    "IntxFakeQuantizeConfig",
-    "QATConfig",
     # Prototype
     "initialize_fake_quantizers",
     # Legacy quantizers
```

torchao/quantization/qat/api.py

Lines changed: 17 additions & 6 deletions

```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, List, Optional, Tuple

 import torch
@@ -25,6 +26,15 @@
 from .linear import FakeQuantizedLinear


+class QATConfigStep(str, Enum):
+    """
+    Enum value for the `step` field in :class:`~torchao.quantization.qat.QATConfig`.
+    """
+
+    PREPARE = "prepare"
+    CONVERT = "convert"
+
+
 @dataclass
 class QATConfig(AOBaseConfig):
     """
@@ -114,7 +124,7 @@ class QATConfig(AOBaseConfig):
     base_config: Optional[AOBaseConfig]
     activation_config: Optional[FakeQuantizeConfigBase]
     weight_config: Optional[FakeQuantizeConfigBase]
-    step: str
+    step: QATConfigStep

     # Express `step` as a keyword argument
     # TODO: Use `kw_only=True` instead, added in python 3.10
@@ -124,7 +134,7 @@ def __init__(
         activation_config: Optional[FakeQuantizeConfigBase] = None,
         weight_config: Optional[FakeQuantizeConfigBase] = None,
         *,
-        step: str = "prepare",
+        step: QATConfigStep = "prepare",
     ):
         self.base_config = base_config
         self.activation_config = activation_config
@@ -134,8 +144,9 @@ def __init__(

     def __post_init__(self):
         self.step = self.step.lower()
-        if self.step not in ["prepare", "convert"]:
-            raise ValueError("`step` must be either 'prepare' or 'convert'")
+        all_step_values = [s.value for s in QATConfigStep]
+        if self.step not in all_step_values:
+            raise ValueError("`step` must be one of %s" % all_step_values)
         if self.base_config is None and self.weight_config is None:
             raise ValueError(
                 "One of `base_config` or `weight_config` must be specified"
@@ -178,7 +189,7 @@ def _qat_config_transform(
     # Swap nn.Embedding -> FakeQuantizedEmbedding
     base_config = config.base_config
     step = config.step
-    if step == "prepare":
+    if step == QATConfigStep.PREPARE:
         if base_config is not None:
             (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         else:
@@ -201,7 +212,7 @@ def _qat_config_transform(
     # Swap FakeQuantizedLinear -> nn.Linear
     # Swap FakeQuantizedEmbedding -> nn.Embedding
     # Then apply the base config's transform function to quantize the model
-    assert step == "convert", "unexpected step '%s' in QATConfig" % step
+    assert step == QATConfigStep.CONVERT, "unexpected step '%s' in QATConfig" % step
     assert base_config is not None, "expected `base_config` in convert step"
     if isinstance(module, FakeQuantizedLinear):
         module = module.to_linear()
```
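The comparisons `step == QATConfigStep.PREPARE` and `step == QATConfigStep.CONVERT` work even though `__post_init__` stores a lowercased value, because str-mixin enum members compare equal to their string values. A standalone sketch of the semantics being relied on (not part of the diff):

```
from enum import Enum

class Step(str, Enum):  # same pattern as `QATConfigStep`
    PREPARE = "prepare"
    CONVERT = "convert"

assert Step.PREPARE == "prepare"          # enum member equals its string value
assert "CONVERT".lower() == Step.CONVERT  # normalized strings match members too
assert Step.CONVERT.lower() == "convert"  # str methods return plain strings
```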

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -287,7 +287,7 @@ def _infer_fake_quantize_configs(
         is_symmetric=base_config.act_mapping_type == MappingType.SYMMETRIC,
     )
     weight_config = IntxFakeQuantizeConfig(
-        dtype=torch.int4,
+        dtype=TorchAODType.INT4,
         group_size=base_config.group_size,
         is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC,
     )
```
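To see the effect of this one-line change, a quick inspection sketch. Note that `_infer_fake_quantize_configs` is a private helper, so the import path and attribute access below are assumptions based on this diff rather than a stable API:

```
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat.fake_quantize_config import _infer_fake_quantize_configs

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
act_config, weight_config = _infer_fake_quantize_configs(base_config)

print(weight_config.dtype)       # TorchAODType.INT4 after this change (was torch.int4)
print(weight_config.group_size)  # 32, inherited from the base config
```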
