Commit 5517031

Update on "New multi-step QAT API"
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs based on a PTQ base config provided by the user:

```Py
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:

- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```Py
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking once we deprecate and remove the old path in the future.

Before:

```Py
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**

```
python test/quantization/test_qat.py
```

[ghstack-poisoned]
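For reference, below is a minimal end-to-end sketch of the new flow. The toy model, optimizer, and random data are illustrative assumptions; only `quantize_`, `QATConfig`, and `Int8DynamicActivationInt4WeightConfig` come from the API described in this commit:

```Py
import torch
import torch.nn as nn

from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# Toy model and training setup, for illustration only
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# prepare: insert fake quantization ops inferred from the PTQ base config
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# train: fine-tune with fake quantization in the loop (random data here)
for _ in range(10):
    x, y = torch.randn(8, 64), torch.randn(8, 16)
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()

# convert: replace fake quantization with real quantized ops
quantize_(model, QATConfig(base_config, step="convert"))
```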
2 parents 7147dcb + 81096ae commit 5517031

File tree

4 files changed (+12, -9 lines changed)

docs/source/api_ref_qat.rst

Lines changed: 1 addition & 1 deletion

```diff
@@ -16,7 +16,7 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :nosignatures:

     QATConfig
-    QATConfigStep
+    QATStep

 Custom QAT APIs
 ---------------
```

test/quantization/test_qat.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -35,6 +35,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     QATConfig,
+    QATStep,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -1272,6 +1273,8 @@ def test_qat_config_init(self):
         # OK
         QATConfig(base_config, step="prepare")
         QATConfig(base_config, step="convert")
+        QATConfig(base_config, step=QATStep.PREPARE)
+        QATConfig(base_config, step=QATStep.CONVERT)
         QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
         QATConfig(weight_config=fq_config, step="prepare")
```

torchao/quantization/qat/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -3,7 +3,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     QATConfig,
-    QATConfigStep,
+    QATStep,
     from_intx_quantization_aware_training,
     initialize_fake_quantizers,
     intx_quantization_aware_training,
@@ -27,7 +27,7 @@

 __all__ = [
     "QATConfig",
-    "QATConfigStep",
+    "QATStep",
     "FakeQuantizeConfigBase",
     "IntxFakeQuantizeConfig",
     "FakeQuantizer",
```

torchao/quantization/qat/api.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -26,7 +26,7 @@
 from .linear import FakeQuantizedLinear


-class QATConfigStep(str, Enum):
+class QATStep(str, Enum):
     """
     Enum value for the `step` field in :class:`~torchao.quantization.qat.QATConfig`.
     """
```
```diff
@@ -124,7 +124,7 @@ class QATConfig(AOBaseConfig):
     base_config: Optional[AOBaseConfig]
     activation_config: Optional[FakeQuantizeConfigBase]
     weight_config: Optional[FakeQuantizeConfigBase]
-    step: QATConfigStep
+    step: QATStep

     # Express `step` as a keyword argument
     # TODO: Use `kw_only=True` instead, added in python 3.10
@@ -134,7 +134,7 @@ def __init__(
         activation_config: Optional[FakeQuantizeConfigBase] = None,
         weight_config: Optional[FakeQuantizeConfigBase] = None,
         *,
-        step: QATConfigStep = "prepare",
+        step: QATStep = "prepare",
     ):
         self.base_config = base_config
         self.activation_config = activation_config
@@ -144,7 +144,7 @@

     def __post_init__(self):
         self.step = self.step.lower()
-        all_step_values = [s.value for s in QATConfigStep]
+        all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
             raise ValueError("`step` must be one of %s" % all_step_values)
         if self.base_config is None and self.weight_config is None:
```
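The `__post_init__` hunk lower-cases `step` before checking it against the enum's values, so casing is forgiving but unknown steps fail fast. A self-contained illustration of that check, reusing the sketched enum from above (member values are still an inference):

```Py
from enum import Enum

class QATStep(str, Enum):  # sketch; values inferred from usage
    PREPARE = "prepare"
    CONVERT = "convert"

all_step_values = [s.value for s in QATStep]  # ['prepare', 'convert']
assert "Prepare".lower() in all_step_values   # mixed case is normalized first
assert "train" not in all_step_values         # such a step would raise:
# ValueError: `step` must be one of ['prepare', 'convert']
```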
```diff
@@ -189,7 +189,7 @@ def _qat_config_transform(
     # Swap nn.Embedding -> FakeQuantizedEmbedding
     base_config = config.base_config
     step = config.step
-    if step == QATConfigStep.PREPARE:
+    if step == QATStep.PREPARE:
         if base_config is not None:
             (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         else:
@@ -212,7 +212,7 @@
         # Swap FakeQuantizedLinear -> nn.Linear
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
-        assert step == QATConfigStep.CONVERT, "unexpected step '%s' in QATConfig" % step
+        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
         assert base_config is not None, "expected `base_config` in convert step"
         if isinstance(module, FakeQuantizedLinear):
             module = module.to_linear()
```
