Skip to content

Commit 41c4876

Browse files
committed
Deprecate old QAT APIs
**Summary:** Deprecates QAT APIs that should no longer be used. Print helpful deprecation warning to help users migrate. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_api_deprecation ``` Also manual testing: ``` 'IntXQuantizationAwareTrainingConfig' is deprecated and will be removed in a future release. Please use the following API instead: base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) quantize_(model, QATConfig(base_config, step="prepare")) # train (not shown) quantize_(model, QATConfig(base_config, step="convert")) Alternatively, if you prefer to pass in fake quantization configs: activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = QATConfig( activation_config=activation_config, weight_config=weight_config, step="prepare", ) quantize_(model, qat_config) Please see #2630 for more details. IntXQuantizationAwareTrainingConfig(activation_config=None, weight_config=None) ``` ghstack-source-id: bb3fc80 Pull Request resolved: #2641
1 parent 374d6af commit 41c4876

File tree

5 files changed

+129
-11
lines changed

5 files changed

+129
-11
lines changed

docs/source/api_ref_qat.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,13 @@ Custom QAT APIs
3232
linear.enable_linear_fake_quant
3333
linear.disable_linear_fake_quant
3434

35-
Legacy QAT APIs
35+
Legacy QAT Quantizers
3636
---------------------
3737

3838
.. autosummary::
3939
:toctree: generated/
4040
:nosignatures:
4141

42-
IntXQuantizationAwareTrainingConfig
43-
FromIntXQuantizationAwareTrainingConfig
4442
Int4WeightOnlyQATQuantizer
4543
linear.Int4WeightOnlyQATLinear
4644
Int8DynActInt4WeightQATQuantizer

test/quantization/test_qat.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# This test takes a long time to run
99

1010
import copy
11+
import io
12+
import logging
1113
import unittest
1214
from typing import List
1315

@@ -1841,6 +1843,64 @@ def test_legacy_quantize_api_e2e(self):
18411843
baseline_out = baseline_model(*x2)
18421844
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
18431845

def _test_deprecation(self, deprecated_class, *example_args, first_time=True):
    """
    Assert that instantiating `deprecated_class` with `example_args` logs the
    deprecation warning when `first_time` is True, and logs nothing otherwise.

    Args:
        deprecated_class: the deprecated class (or BC alias) to instantiate
        example_args: positional args forwarded to the constructor
        first_time: whether this is the first instantiation of the class,
            i.e. whether the one-time warning is expected to fire
    """
    # Set up capture before the `try` so `finally` can never hit an
    # unassigned `logger`/`handler` (the original cleanup could NameError
    # if setup itself raised).
    log_stream = io.StringIO()
    handler = logging.StreamHandler(log_stream)
    logger = logging.getLogger(deprecated_class.__module__)
    old_level = logger.level
    logger.addHandler(handler)
    # WARNING is the documented name; WARN is a legacy alias of the same level
    logger.setLevel(logging.WARNING)
    try:
        deprecated_class(*example_args)
        if first_time:
            # Plain substring match (assertIn), not a regex, so no escaping needed
            expected_msg = (
                "'%s' is deprecated and will be removed in a future release"
                % deprecated_class.__name__
            )
            self.assertIn(expected_msg, log_stream.getvalue())
        else:
            self.assertEqual(log_stream.getvalue(), "")
    finally:
        # Restore the logger so the helper leaves no state behind
        logger.setLevel(old_level)
        logger.removeHandler(handler)
        handler.close()
@unittest.skipIf(
    not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.3 or lower"
)
def test_qat_api_deprecation(self):
    """
    Test that the appropriate deprecation warning has been logged exactly
    once per deprecated QAT class, on first instantiation only.
    """
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        from_intx_quantization_aware_training,
        intx_quantization_aware_training,
    )
    from torchao.quantization.qat.utils import _LOGGED_DEPRECATED_CLASS_NAMES

    # One entry per deprecated API: (class, *example constructor args)
    deprecated_apis = [
        (IntXQuantizationAwareTrainingConfig,),
        (FromIntXQuantizationAwareTrainingConfig,),
        (intx_quantization_aware_training,),
        (from_intx_quantization_aware_training,),
        (FakeQuantizeConfig, torch.int8, "per_channel"),
    ]

    # Reset deprecation warning state, otherwise we won't log warnings here
    _LOGGED_DEPRECATED_CLASS_NAMES.clear()

    # Assert that the deprecation warning is logged on first instantiation
    for deprecated_class, *example_args in deprecated_apis:
        self._test_deprecation(deprecated_class, *example_args)

    # Assert that the warning is only logged once per class
    for deprecated_class, *example_args in deprecated_apis:
        self._test_deprecation(deprecated_class, *example_args, first_time=False)
18441904

18451905
if __name__ == "__main__":
18461906
unittest.main()

torchao/quantization/qat/api.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_infer_fake_quantize_configs,
2525
)
2626
from .linear import FakeQuantizedLinear
27+
from .utils import _log_deprecation_warning
2728

2829

2930
class QATConfigStep(str, Enum):
@@ -224,11 +225,11 @@ def _qat_config_transform(
224225
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
225226

226227

227-
# TODO: deprecate
228228
@dataclass
229229
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
230230
"""
231-
(Will be deprecated soon)
231+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
232+
232233
Config for applying fake quantization to a `torch.nn.Module`.
233234
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
234235
@@ -256,9 +257,13 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
256257
activation_config: Optional[FakeQuantizeConfigBase] = None
257258
weight_config: Optional[FakeQuantizeConfigBase] = None
258259

260+
def __post_init__(self):
261+
_log_deprecation_warning(self)
262+
259263

260264
# for BC
class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
    """
    (Deprecated) Backward-compatibility alias for
    :class:`IntXQuantizationAwareTrainingConfig`. Instantiating it runs the
    parent's ``__post_init__``, which logs the one-time deprecation warning.

    NOTE(review): this commit changes the alias from a plain name binding to a
    subclass; if ``quantize_`` dispatches handlers by exact ``type(config)``
    (as ``_QUANTIZE_CONFIG_HANDLER[type(base_config)]`` suggests), the subclass
    may need its own handler registration — confirm.
    """

    pass
262267

263268

264269
@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
@@ -286,10 +291,11 @@ def _intx_quantization_aware_training_transform(
286291
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
287292

288293

289-
# TODO: deprecate
294+
@dataclass
290295
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
291296
"""
292-
(Will be deprecated soon)
297+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
298+
293299
Config for converting a model with fake quantized modules,
294300
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
295301
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
@@ -306,11 +312,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
306312
)
307313
"""
308314

309-
pass
315+
def __post_init__(self):
316+
_log_deprecation_warning(self)
310317

311318

312319
# for BC
class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
    """
    (Deprecated) Backward-compatibility alias for
    :class:`FromIntXQuantizationAwareTrainingConfig`. Instantiating it runs the
    parent's ``__post_init__``, which logs the one-time deprecation warning.

    NOTE(review): previously a plain name binding, now a subclass; if
    ``quantize_`` dispatches handlers by exact ``type(config)``, the subclass
    may need its own handler registration — confirm.
    """

    pass
314322

315323

316324
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ZeroPointDomain,
2626
)
2727

28+
from .utils import _log_deprecation_warning
29+
2830

2931
@dataclass
3032
class FakeQuantizeConfigBase(abc.ABC):
@@ -135,6 +137,8 @@ def __init__(
135137
if is_dynamic and range_learning:
136138
raise ValueError("`is_dynamic` is not compatible with `range_learning`")
137139

140+
self.__post_init__()
141+
138142
def _get_granularity(
139143
self,
140144
granularity: Union[Granularity, str, None],
@@ -261,7 +265,13 @@ def __setattr__(self, name: str, value: Any):
261265

262266

263267
# For BC
class FakeQuantizeConfig(IntxFakeQuantizeConfig):
    """
    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
    """

    def __post_init__(self):
        # IntxFakeQuantizeConfig.__init__ calls self.__post_init__() at the
        # end of construction, so this fires on every instantiation;
        # _log_deprecation_warning itself deduplicates per class name.
        _log_deprecation_warning(self)
265275

266276

267277
def _infer_fake_quantize_configs(

torchao/quantization/qat/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
from typing import Any
79

810
import torch
911

@@ -104,3 +106,43 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
104106
qmin = 0
105107
qmax = 2**n_bit - 1
106108
return (qmin, qmax)
# Log the deprecation warning only once per deprecated class, regardless of
# how many instances of that class are created.
_LOGGED_DEPRECATED_CLASS_NAMES = set()


def _log_deprecation_warning(old_api_object: Any) -> None:
    """
    Log a helpful deprecation message pointing users of `old_api_object`'s
    (deprecated) class to the new QAT API, at most once per class.

    Args:
        old_api_object: an instance of a deprecated QAT API class; only its
            class name and defining module are read.
    """
    # No `global` needed: the module-level set is mutated in place, never rebound.
    deprecated_class_name = old_api_object.__class__.__name__
    if deprecated_class_name in _LOGGED_DEPRECATED_CLASS_NAMES:
        return
    _LOGGED_DEPRECATED_CLASS_NAMES.add(deprecated_class_name)
    # Use the deprecated object's module logger so the warning is attributed
    # to the module that defines the deprecated API.
    logger = logging.getLogger(old_api_object.__module__)
    logger.warning(
        """'%s' is deprecated and will be removed in a future release. Please use the following API instead:

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))
# train (not shown)
quantize_(model, QATConfig(base_config, step="convert"))

Alternatively, if you prefer to pass in fake quantization configs:

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

Please see https://github.com/pytorch/ao/issues/2630 for more details.
"""
        % deprecated_class_name
    )

0 commit comments

Comments
 (0)