
Commit ccc9f4f

New multi-step QAT API
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs from a PTQ base config provided by the user:

```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking when we deprecate and remove the old path in the future.

Before:

```python
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**

```
python test/quantization/test_qat.py
```

ghstack-source-id: 5dcd7e8
Pull Request resolved: #2629
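For readers who want the "train (not shown)" step spelled out, below is a minimal sketch of the full prepare → train → convert loop. The toy model, data, optimizer, and loss are hypothetical placeholders (not part of this commit); the only torchao calls used are the `quantize_`, `QATConfig`, and `Int8DynamicActivationInt4WeightConfig` APIs shown above.

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# hypothetical toy model and data, for illustration only
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
data = [(torch.randn(16, 64), torch.randint(0, 8, (16,))) for _ in range(10)]

# prepare: insert fake quantization derived from the PTQ base config
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# train: an ordinary fine-tuning loop on the prepared (fake-quantized) model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
for inputs, labels in data:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()

# convert: swap in the real quantized representation using the same base config
quantize_(model, QATConfig(base_config, step="convert"))
```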
1 parent a4e0235 commit ccc9f4f

File tree

8 files changed (+436, -120 lines)


README.md

Lines changed: 11 additions & 6 deletions
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):
 
 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```
 
 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).

docs/source/api_ref_qat.rst

Lines changed: 7 additions & 4 deletions
@@ -6,7 +6,7 @@ torchao.quantization.qat
 
 .. currentmodule:: torchao.quantization.qat
 
-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------
 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
+    QATConfig
+
 
 Custom QAT APIs
 ---------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant
 
-Legacy QAT Quantizers
+Legacy QAT APIs
 ---------------------
 
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
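As a companion to the "Custom QAT APIs" section above, here is a minimal sketch of the custom fake-quantization path, mirroring the usage exercised in `test_quantize_api_prepare` in the test diff below: an activation+weight config for linear layers, and a weight-only config applied to embeddings via `filter_fn`. The toy model is a hypothetical placeholder, not part of this commit.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# hypothetical toy model with an embedding and a linear, for illustration only
model = torch.nn.Sequential(
    torch.nn.Embedding(128, 64),
    torch.nn.Linear(64, 64),
)

act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# linears: fake-quantize both activations and weights (step defaults to "prepare")
quantize_(model, QATConfig(activation_config=act_config, weight_config=weight_config))

# embeddings: weight-only fake quantization, selected via filter_fn
quantize_(
    model,
    QATConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```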

test/quantization/test_qat.py

Lines changed: 129 additions & 48 deletions
@@ -34,6 +34,7 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +60,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1262,61 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(
+            ValueError, "`step` must be either 'prepare' or 'convert'"
+        ):
+            QATConfig(base_config, step="blah")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))
 
         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1290,20 +1341,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )
 
         # Compare model values
@@ -1322,37 +1368,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()
 
         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))
 
         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))
 
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
 
         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
@@ -1370,16 +1408,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1423,7 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.convert(baseline_model)
 
         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))
 
         # Compare converted values
         torch.manual_seed(self.SEED)
@@ -1447,14 +1476,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
@@ -1653,7 +1680,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))
 
         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
@@ -1756,6 +1783,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
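For quick reference, the argument validation exercised by `test_qat_config_init` above translates to the following runtime behavior. This is a sketch based on the error messages asserted in that test, not an exhaustive list of `QATConfig` checks.

```python
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")

QATConfig(base_config)                  # ok, step defaults to "prepare"
QATConfig(base_config, step="CONVERT")  # ok, step values are case-insensitive

try:
    QATConfig(step="prepare")  # neither base_config nor weight_config given
except ValueError as e:
    print(e)  # "One of `base_config` or `weight_config` must be specified"

try:
    QATConfig(base_config, weight_config=fq_config)  # clashing configs
except ValueError as e:
    print(e)  # "Cannot specify both ..."

try:
    QATConfig(weight_config=fq_config, step="convert")  # convert needs base_config
except ValueError as e:
    print(e)  # "... must be specified in the convert step"
```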
