
Commit 899fb5e

New multi-step QAT API
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs from a PTQ base config provided by the user:

```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking when we deprecate and remove the old path in the future.

Before:

```python
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**
```
python test/quantization/test_qat.py
```

ghstack-source-id: 7adbc7c
Pull Request resolved: #2629
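For orientation, here is a minimal end-to-end sketch of where the two steps fit into a training loop. This is not part of the commit: the toy model, optimizer, and data below are hypothetical, and only the `QATConfig` / `quantize_` calls come from the API described above.

```python
import torch
import torch.nn.functional as F

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

# Hypothetical toy model; any model with linear layers would do.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16),
)

# Step 1: prepare -- insert fake quantization that simulates the target numerics.
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# Step 2: train as usual; fake quantization runs inside the forward pass.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    inputs = torch.randn(8, 64)
    targets = torch.randn(8, 16)
    loss = F.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Step 3: convert -- swap fake quantization for the real quantized representation.
quantize_(model, QATConfig(base_config, step="convert"))
```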
1 parent a4e0235 commit 899fb5e

8 files changed: +457 −119 lines changed

README.md

Lines changed: 11 additions & 6 deletions
````diff
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):
 
 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```
 
 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
````

docs/source/api_ref_qat.rst

Lines changed: 7 additions & 4 deletions
```diff
@@ -6,7 +6,7 @@ torchao.quantization.qat
 
 .. currentmodule:: torchao.quantization.qat
 
-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------
 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
+    QATConfig
+
 
 Custom QAT APIs
 ---------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizer
    linear.enable_linear_fake_quant
    linear.disable_linear_fake_quant
 
-Legacy QAT Quantizers
+Legacy QAT APIs
 ---------------------
 
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
```
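The quantizer classes that remain under "Legacy QAT APIs" follow a prepare/convert method flow rather than `quantize_` configs. As a rough sketch of that older style (the toy model here is hypothetical; the `prepare`/`convert` calls are the ones used as baselines in the tests below):

```python
import torch

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical toy model for illustration.
model = torch.nn.Sequential(torch.nn.Linear(512, 512))

quantizer = Int8DynActInt4WeightQATQuantizer()

# prepare: swap in fake-quantized modules
model = quantizer.prepare(model)

# ... train the model here ...

# convert: produce the actual quantized model
model = quantizer.convert(model)
```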

test/quantization/test_qat.py

Lines changed: 135 additions & 48 deletions
```diff
@@ -34,6 +34,7 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +60,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1262,67 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(
+            ValueError, "`step` must be either 'prepare' or 'convert'"
+        ):
+            QATConfig(base_config, step="blah")
+
+        # Step was not a keyword arg
+        with self.assertRaisesRegex(
+            TypeError, "4 positional arguments but 5 were given"
+        ):
+            QATConfig(base_config, None, None, "prepare")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))
 
         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1290,20 +1347,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )
 
         # Compare model values
@@ -1322,37 +1374,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()
 
         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))
 
         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))
 
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
 
         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
@@ -1370,16 +1414,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1429,7 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.convert(baseline_model)
 
         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))
 
         # Compare converted values
         torch.manual_seed(self.SEED)
@@ -1447,14 +1482,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
@@ -1653,7 +1686,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))
 
         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
@@ -1756,6 +1789,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
```
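Outside of the unit tests, the prepare flow exercised in `test_quantize_api_prepare` above (fake-quantized linears plus weight-only fake-quantized embeddings via a `filter_fn`) can be written roughly as follows; the model and sizes below are hypothetical, and only the config and `quantize_` calls mirror the test code:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Hypothetical model with an embedding followed by a linear layer.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 64)
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(self.emb(x))

model = TinyModel()

act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# Linear layers: fake quantize both activations and weights (default filter targets linears).
quantize_(
    model,
    QATConfig(activation_config=act_config, weight_config=weight_config, step="prepare"),
)

# Embeddings: weight-only, since activation fake quantization is not supported for embeddings.
quantize_(
    model,
    QATConfig(weight_config=weight_config, step="prepare"),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```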
