17 changes: 11 additions & 6 deletions README.md
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup**
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
quantize_(my_model, qat_config)
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))
```

Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
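For reference, the `# train model (not shown)` step in the snippet above is an ordinary PyTorch fine-tuning loop run on the fake-quantized model. A minimal sketch, assuming a hypothetical `train_loader` and a standard cross-entropy objective (the optimizer choice and data loader are not part of the original example):

```python
import torch
import torch.nn.functional as F

# Hypothetical fine-tuning loop between the "prepare" and "convert" steps.
# `my_model` already has fake-quantize ops inserted by quantize_.
optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-5)
for inputs, labels in train_loader:  # train_loader is assumed, not defined above
    optimizer.zero_grad()
    logits = my_model(inputs)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()  # gradients flow through the fake-quant ops during QAT
    optimizer.step()
```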
13 changes: 8 additions & 5 deletions docs/source/api_ref_qat.rst
@@ -6,7 +6,7 @@ torchao.quantization.qat

.. currentmodule:: torchao.quantization.qat

QAT Configs for quantize_
Main Config for quantize_
---------------------------------------
For a full example of how to use QAT with our main `quantize_` API,
please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
:toctree: generated/
:nosignatures:

IntXQuantizationAwareTrainingConfig
FromIntXQuantizationAwareTrainingConfig
QATConfig
QATConfigStep
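As a quick orientation for these two entries, here is a sketch of how `QATConfig` and `QATConfigStep` are used together, assuming (consistent with the README example in this PR) that `step` accepts the string values `"prepare"` and `"convert"`, and that `QATConfigStep` exposes corresponding members; the toy model below is illustrative only:

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig, QATConfigStep

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# "prepare" inserts fake-quantize ops; "convert" swaps in the real quantized tensors.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...
quantize_(model, QATConfig(base_config, step="convert"))
# The string values are assumed to mirror QATConfigStep members,
# e.g. QATConfigStep.PREPARE and QATConfigStep.CONVERT.
```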

Custom QAT APIs
---------------
.. autosummary::
:toctree: generated/
:nosignatures:

FakeQuantizeConfig
FakeQuantizeConfigBase
IntxFakeQuantizeConfig
FakeQuantizedLinear
FakeQuantizedEmbedding
FakeQuantizer
linear.enable_linear_fake_quant
linear.disable_linear_fake_quant
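For quantization schemes that have no corresponding base config, the custom APIs above can be combined directly. The sketch below assumes that `QATConfig` also accepts explicit `activation_config`/`weight_config` keyword arguments mirroring the old `IntXQuantizationAwareTrainingConfig` constructor, and that `IntxFakeQuantizeConfig` keeps the same constructor arguments as the former `FakeQuantizeConfig`:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Explicit fake-quantize settings (same values as the legacy example in this PR):
# int8 asymmetric per-token activations, int4 per-group (group_size=32) weights.
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# Assumed keyword form of QATConfig for custom schemes.
quantize_(
    model,
    QATConfig(
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    ),
)
```

Presumably `linear.enable_linear_fake_quant` and `linear.disable_linear_fake_quant` can then be used to toggle the inserted fake quantization, for example when evaluating the un-quantized baseline mid-training.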

Legacy QAT Quantizers
Legacy QAT APIs
---------------------

.. autosummary::
:toctree: generated/
:nosignatures:

IntXQuantizationAwareTrainingConfig
FromIntXQuantizationAwareTrainingConfig
Int4WeightOnlyQATQuantizer
linear.Int4WeightOnlyQATLinear
Int8DynActInt4WeightQATQuantizer
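The legacy quantizer classes listed here follow a prepare/convert object API rather than the `quantize_` function. A rough sketch based on the pre-existing torchao QAT recipes (the `groupsize` parameter name and the toy model are assumptions):

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Legacy flow: the quantizer object owns both steps.
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = qat_quantizer.prepare(model)   # insert fake quantization
# ... fine-tune `model` here ...
model = qat_quantizer.convert(model)   # lower to actual quantized ops
```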
35 changes: 11 additions & 24 deletions docs/source/finetuning.rst
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.

.. code:: py

from torchao.quantization import (
quantize_,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# fine-tune
train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

.. code:: py

from torchao.quantization import (
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FromIntXQuantizationAwareTrainingConfig,
)
from torchao.quantization import Int8DynamicActivationInt4WeightConfig

# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
quantize_(model, QATConfig(base_config, step="convert"))

# inference or generate

Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
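As background for the "not actually casting" remark above: fake quantization rounds values to the integer grid and immediately dequantizes them, so tensors stay in their floating-point dtype throughout training. An illustrative sketch of symmetric per-group int4 fake quantization follows; this is a hypothetical helper for intuition only, not torchao's actual implementation:

```python
import torch

def fake_quantize_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Reshape so each row is one quantization group.
    orig_shape = w.shape
    groups = w.reshape(-1, group_size)
    # Symmetric int4 range is [-8, 7]; scale each group by its max magnitude.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(groups / scale), -8, 7)
    # Dequantize immediately: the result stays in the original float dtype.
    return (q * scale).reshape(orig_shape)

w = torch.randn(64, 64)
w_fq = fake_quantize_int4_per_group(w)
print(w_fq.dtype)  # torch.float32 -- no actual int4 cast happened
```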
4 changes: 2 additions & 2 deletions test/prototype/test_parq.py
@@ -30,8 +30,8 @@
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntxFakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
@@ -393,7 +393,7 @@ def test_int8_dynamic_activation_intx_e2e(
optimizer.step()

# apply torchao quantized activations on top
activation_config = FakeQuantizeConfig(
activation_config = IntxFakeQuantizeConfig(
torch.int8,
granularity="per_token",
mapping_type=config.act_mapping_type,