diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index bd6ede0af5..48a9f780b6 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+import warnings
 from typing import List
 
 import torch
@@ -1844,6 +1845,45 @@ def test_legacy_quantize_api_e2e(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
+    )
+    def test_qat_api_deprecation(self):
+        """
+        Test that the appropriate deprecation warning is logged exactly once per class.
+        """
+        from torchao.quantization.qat import (
+            FakeQuantizeConfig,
+            from_intx_quantization_aware_training,
+            intx_quantization_aware_training,
+        )
+
+        # Reset deprecation warning state, otherwise we won't log warnings here
+        warnings.resetwarnings()
+
+        # Map from deprecated API to the args needed to instantiate it
+        deprecated_apis_to_args = {
+            IntXQuantizationAwareTrainingConfig: (),
+            FromIntXQuantizationAwareTrainingConfig: (),
+            intx_quantization_aware_training: (),
+            from_intx_quantization_aware_training: (),
+            FakeQuantizeConfig: (torch.int8, "per_channel"),
+        }
+
+        with warnings.catch_warnings(record=True) as _warnings:
+            # Call each deprecated API twice
+            for cls, args in deprecated_apis_to_args.items():
+                cls(*args)
+                cls(*args)
+
+            # Each call should trigger the warning only once
+            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
+            for w in _warnings:
+                self.assertIn(
+                    "is deprecated and will be removed in a future release",
+                    str(w.message),
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 0b7c1228b0..0d69f44bd9 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -24,6 +24,7 @@
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
+from .utils import _log_deprecation_warning
 
 
 class QATStep(str, Enum):
@@ -224,11 +225,11 @@ def _qat_config_transform(
         return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
 
 
-# TODO: deprecate
 @dataclass
 class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    (Will be deprecated soon)
+    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
+
     Config for applying fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
@@ -256,9 +257,13 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
     activation_config: Optional[FakeQuantizeConfigBase] = None
     weight_config: Optional[FakeQuantizeConfigBase] = None
 
+    def __post_init__(self):
+        _log_deprecation_warning(self)
+
 
 # for BC
-intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
+    pass
 
 
 @register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
@@ -286,10 +291,11 @@ def _intx_quantization_aware_training_transform(
     raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
 
 
-# TODO: deprecate
+@dataclass
 class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    (Will be deprecated soon)
+    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
+
     Config for converting a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
@@ -306,11 +312,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     )
     """
 
-    pass
+    def __post_init__(self):
+        _log_deprecation_warning(self)
 
 
 # for BC
-from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
+class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
+    pass
 
 
 @register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 77b40267ad..554ed2a065 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -25,6 +25,8 @@
     ZeroPointDomain,
 )
 
+from .utils import _log_deprecation_warning
+
 
 class FakeQuantizeConfigBase(abc.ABC):
     """
@@ -134,6 +136,14 @@ def __init__(
         if is_dynamic and range_learning:
             raise ValueError("`is_dynamic` is not compatible with `range_learning`")
 
+        self.__post_init__()
+
+    def __post_init__(self):
+        """
+        For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630.
+        """
+        pass
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -260,7 +270,13 @@ def __setattr__(self, name: str, value: Any):
 
 
 # For BC
-FakeQuantizeConfig = IntxFakeQuantizeConfig
+class FakeQuantizeConfig(IntxFakeQuantizeConfig):
+    """
+    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
+    """
+
+    def __post_init__(self):
+        _log_deprecation_warning(self)
 
 
 def _infer_fake_quantize_configs(
diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py
index 5fc51ab7ca..e2f425a1d5 100644
--- a/torchao/quantization/qat/utils.py
+++ b/torchao/quantization/qat/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
+from typing import Any
 
 import torch
 
@@ -104,3 +106,33 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
         qmin = 0
         qmax = 2**n_bit - 1
     return (qmin, qmax)
+
+
+def _log_deprecation_warning(old_api_object: Any):
+    """
+    Log a helpful deprecation message pointing users to the new QAT API,
+    only once per deprecated class.
+    """
+    warnings.warn(
+        """'%s' is deprecated and will be removed in a future release. Please use the following API instead:
+
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))
+    # train (not shown)
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+Alternatively, if you prefer to pass in fake quantization configs:
+
+    activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+    weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
+    qat_config = QATConfig(
+        activation_config=activation_config,
+        weight_config=weight_config,
+        step="prepare",
+    )
+    quantize_(model, qat_config)
+
+Please see https://github.com/pytorch/ao/issues/2630 for more details.
+        """
+        % old_api_object.__class__.__name__
+    )
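Not part of the diff above: a minimal migration sketch illustrating the replacement flow that the new `_log_deprecation_warning` message points users to. The `QATConfig` prepare/convert calls mirror the example embedded in the warning text; the import paths and the toy model are assumptions made for illustration, not something this diff adds.

    # Hypothetical migration sketch (assumed import paths; toy model for illustration only)
    import torch

    from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
    from torchao.quantization.qat import QATConfig

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))

    # Before (deprecated): quantize_(model, IntXQuantizationAwareTrainingConfig(...))
    # After: drive both QAT phases through QATConfig with a single base config.
    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

    # "prepare" swaps in fake-quantized modules for training
    quantize_(model, QATConfig(base_config, step="prepare"))

    # ... fine-tune the model here (not shown) ...

    # "convert" replaces fake quantization with the real quantized model
    quantize_(model, QATConfig(base_config, step="convert"))

As the warning text notes, explicit IntxFakeQuantizeConfig activation/weight configs can be passed to QATConfig instead of a base config when finer control over fake quantization is needed.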