|
8 | 8 | # This test takes a long time to run
|
9 | 9 |
|
10 | 10 | import copy
|
| 11 | +import io |
| 12 | +import logging |
11 | 13 | import unittest
|
12 | 14 | from typing import List
|
13 | 15 |
|
@@ -1841,6 +1843,64 @@ def test_legacy_quantize_api_e2e(self):
|
1841 | 1843 | baseline_out = baseline_model(*x2)
|
1842 | 1844 | torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
|
1843 | 1845 |
|
| 1846 | + def _test_deprecation(self, deprecated_class, *example_args, first_time=True): |
| 1847 | + """ |
| 1848 | + Assert that instantiating a deprecated class triggers the deprecation warning. |
| 1849 | + """ |
| 1850 | + try: |
| 1851 | + log_stream = io.StringIO() |
| 1852 | + handler = logging.StreamHandler(log_stream) |
| 1853 | + logger = logging.getLogger(deprecated_class.__module__) |
| 1854 | + logger.addHandler(handler) |
| 1855 | + logger.setLevel(logging.WARN) |
| 1856 | + deprecated_class(*example_args) |
| 1857 | + if first_time: |
| 1858 | + regex = ( |
| 1859 | + "'%s' is deprecated and will be removed in a future release" |
| 1860 | + % deprecated_class.__name__ |
| 1861 | + ) |
| 1862 | + self.assertIn(regex, log_stream.getvalue()) |
| 1863 | + else: |
| 1864 | + self.assertEqual(log_stream.getvalue(), "") |
| 1865 | + finally: |
| 1866 | + logger.removeHandler(handler) |
| 1867 | + handler.close() |
| 1868 | + |
| 1869 | + @unittest.skipIf( |
| 1870 | + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" |
| 1871 | + ) |
| 1872 | + def test_qat_api_deprecation(self): |
| 1873 | + """ |
| 1874 | + Test that the appropriate deprecation warning has been logged. |
| 1875 | + """ |
| 1876 | + from torchao.quantization.qat import ( |
| 1877 | + FakeQuantizeConfig, |
| 1878 | + from_intx_quantization_aware_training, |
| 1879 | + intx_quantization_aware_training, |
| 1880 | + ) |
| 1881 | + from torchao.quantization.qat.utils import _LOGGED_DEPRECATED_CLASS_NAMES |
| 1882 | + |
| 1883 | + # Reset deprecation warning state, otherwise we won't log warnings here |
| 1884 | + _LOGGED_DEPRECATED_CLASS_NAMES.clear() |
| 1885 | + |
| 1886 | + # Assert that the deprecation warning is logged |
| 1887 | + self._test_deprecation(IntXQuantizationAwareTrainingConfig) |
| 1888 | + self._test_deprecation(FromIntXQuantizationAwareTrainingConfig) |
| 1889 | + self._test_deprecation(intx_quantization_aware_training) |
| 1890 | + self._test_deprecation(from_intx_quantization_aware_training) |
| 1891 | + self._test_deprecation(FakeQuantizeConfig, torch.int8, "per_channel") |
| 1892 | + |
| 1893 | + # Assert that warning is only logged once per class |
| 1894 | + self._test_deprecation(IntXQuantizationAwareTrainingConfig, first_time=False) |
| 1895 | + self._test_deprecation( |
| 1896 | + FromIntXQuantizationAwareTrainingConfig, first_time=False |
| 1897 | + ) |
| 1898 | + self._test_deprecation(intx_quantization_aware_training, first_time=False) |
| 1899 | + self._test_deprecation(from_intx_quantization_aware_training, first_time=False) |
| 1900 | + self._test_deprecation( |
| 1901 | + FakeQuantizeConfig, torch.int8, "per_channel", first_time=False |
| 1902 | + ) |
| 1903 | + |
1844 | 1904 |
|
1845 | 1905 | if __name__ == "__main__":
|
1846 | 1906 | unittest.main()
|
0 commit comments