Merged
Changes from 21 commits (26 commits total)
c358b1b  [bc-breaking] Generalize FakeQuantizeConfig beyond intx (andrewor14, Jul 29, 2025)
d076264  New multi-step QAT API (andrewor14, Jul 29, 2025)
8f56651  Update on "New multi-step QAT API" (andrewor14, Jul 29, 2025)
7a9fe90  Update on "New multi-step QAT API" (andrewor14, Jul 30, 2025)
1e88ebf  Deprecate old QAT APIs (andrewor14, Jul 30, 2025)
12e8c3f  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
2ed5e50  Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
019b665  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
b0c4721  Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
0eb0983  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
c91b218  Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
affc74e  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
3069075  Update on "Deprecate old QAT APIs" (andrewor14, Jul 30, 2025)
b41f4e7  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
12d920b  Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
68728d7  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
52f72a5  Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
56415d6  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
08f87af  Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
be45ff4  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
28b3b41  Update on "Deprecate old QAT APIs" (andrewor14, Jul 31, 2025)
9baae23  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Aug 1, 2025)
62cd942  Update on "Deprecate old QAT APIs" (andrewor14, Aug 1, 2025)
1c30bbb  Update base for Update on "Deprecate old QAT APIs" (andrewor14, Aug 1, 2025)
2fbfbb6  Update on "Deprecate old QAT APIs" (andrewor14, Aug 1, 2025)
3f06429  Merge branch 'main' into gh/andrewor14/15/head (andrewor14, Aug 4, 2025)
17 changes: 11 additions & 6 deletions README.md
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
```

Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
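For context on the README snippet above, here is a minimal end-to-end sketch of the new two-step flow. The toy model, dummy data, and training loop are illustrative assumptions; only the `QATConfig` prepare/convert calls and the base config come from the diff.

```python
import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

# illustrative stand-ins for a real model and dataset
my_model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256))
example_inputs = torch.randn(8, 256)

# prepare: insert fake quantization so training observes quantization error
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# stand-in for a real fine-tuning loop
optimizer = torch.optim.SGD(my_model.parameters(), lr=1e-3)
for _ in range(10):
    loss = my_model(example_inputs).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# convert: replace fake quantization with actual quantized ops, reusing the same base_config
quantize_(my_model, QATConfig(base_config, step="convert"))
```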
9 changes: 5 additions & 4 deletions docs/source/api_ref_qat.rst
@@ -6,7 +6,7 @@ torchao.quantization.qat

.. currentmodule:: torchao.quantization.qat

-QAT Configs for quantize_
+Main Config for quantize_
---------------------------------------
For a full example of how to use QAT with our main `quantize_` API,
please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,16 +15,17 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
:toctree: generated/
:nosignatures:

-IntXQuantizationAwareTrainingConfig
-FromIntXQuantizationAwareTrainingConfig
+QATConfig
+QATStep

Custom QAT APIs
---------------
.. autosummary::
:toctree: generated/
:nosignatures:

-FakeQuantizeConfig
+FakeQuantizeConfigBase
+IntxFakeQuantizeConfig
FakeQuantizedLinear
FakeQuantizedEmbedding
FakeQuantizer
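For readers mapping the renamed entries above to code: `IntxFakeQuantizeConfig` takes the same constructor arguments that the old `FakeQuantizeConfig` does elsewhere in this PR. A small sketch follows; the `FakeQuantizedLinear.from_linear` call is an assumption about the existing lower-level custom QAT API and is not something this hunk changes.

```python
import torch
import torch.nn as nn

from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig

# same arguments the old FakeQuantizeConfig accepted in the removed README snippet
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# assumed usage of the lower-level custom QAT API: wrap a single linear layer
linear = nn.Linear(128, 128)
fq_linear = FakeQuantizedLinear.from_linear(linear, activation_config, weight_config)
```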
35 changes: 11 additions & 24 deletions docs/source/finetuning.rst
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.

.. code:: py

-from torchao.quantization import (
-quantize_,
-)
-from torchao.quantization.qat import (
-FakeQuantizeConfig,
-IntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig

model = get_model()

-# prepare: insert fake quantization ops
-# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
-quantize_(model, qat_config)
+# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(model, QATConfig(base_config, step="prepare"))

# fine-tune
train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

.. code:: py

-from torchao.quantization import (
-Int8DynamicActivationInt4WeightConfig,
-)
-from torchao.quantization.qat import (
-FromIntXQuantizationAwareTrainingConfig,
-)
+from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-# convert: transform fake quantization ops into actual quantized ops
-# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-# quantized activation and weight tensor subclasses
-quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+quantize_(model, QATConfig(base_config, step="convert"))

# inference or generate

Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
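As a sanity check on the prepare step described above (which the new comment summarizes as swapping `torch.nn.Linear` for `FakeQuantizedLinear`), one might verify the module swap directly. The toy model below is an assumption standing in for `get_model()`.

```python
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import FakeQuantizedLinear, QATConfig

model = nn.Sequential(nn.Linear(128, 128))  # stand-in for get_model()
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# every nn.Linear should now be a FakeQuantizedLinear (which subclasses nn.Linear)
assert all(
    isinstance(m, FakeQuantizedLinear)
    for m in model.modules()
    if isinstance(m, nn.Linear)
)
```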
4 changes: 2 additions & 2 deletions test/prototype/test_parq.py
@@ -30,8 +30,8 @@
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
-FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
+IntxFakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
@@ -393,7 +393,7 @@ def test_int8_dynamic_activation_intx_e2e(
optimizer.step()

# apply torchao quantized activations on top
-activation_config = FakeQuantizeConfig(
+activation_config = IntxFakeQuantizeConfig(
torch.int8,
granularity="per_token",
mapping_type=config.act_mapping_type,
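Downstream projects that, like this test, still reference the old class name may want a compatibility import while the rename rolls out. The try/except below is a suggested pattern and is not part of this PR; whether a deprecation alias remains for `FakeQuantizeConfig` is not shown in this diff.

```python
import torch

try:
    # new name introduced by this PR
    from torchao.quantization.qat import IntxFakeQuantizeConfig
except ImportError:
    # older torchao releases only expose the pre-rename class
    from torchao.quantization.qat import FakeQuantizeConfig as IntxFakeQuantizeConfig

# construction mirrors the updated call in test_parq.py
activation_config = IntxFakeQuantizeConfig(torch.int8, granularity="per_token")
```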