From e1d7de3d96e4ddad3e6196fad90d31bee701f7fa Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 10:50:46 -0700 Subject: [PATCH 1/5] Deprecate old TORCH_VERSION variables **Summary:** This commit deprecates the following variables: ``` TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means we can drop support for 2.5 and before since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 46 ++++++++++++++++++++++++++++++- torchao/utils.py | 69 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3ba2f32613..0697a97f72 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest +import warnings from unittest.mock import patch import torch @@ -12,7 +13,7 @@ from torchao.utils import TorchAOBaseTensor, torch_version_at_least -class TestTorchVersionAtLeast(unittest.TestCase): +class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ ("2.5.0a0+git9f17037", "2.5.0", True), @@ -35,6 +36,49 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) + def test_torch_version_deprecation(self): + """ + Test that TORCH_VERSION_AT_LEAST_2_5 and before and TORCH_VERSION_AFTER* + trigger a deprecation warning. + """ + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Importing and referencing should not trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + from torchao.utils import ( + TORCH_VERSION_AFTER_2_2, + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + TORCH_VERSION_AFTER_2_5, + TORCH_VERSION_AT_LEAST_2_2, + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + ) + + deprecated_api_to_name = { + TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", + TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", + TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", + TORCH_VERSION_AT_LEAST_2_2: "TORCH_VERSION_AT_LEAST_2_2", + TORCH_VERSION_AFTER_2_5: "TORCH_VERSION_AFTER_2_5", + TORCH_VERSION_AFTER_2_4: "TORCH_VERSION_AFTER_2_4", + TORCH_VERSION_AFTER_2_3: "TORCH_VERSION_AFTER_2_3", + TORCH_VERSION_AFTER_2_2: "TORCH_VERSION_AFTER_2_2", + } + self.assertEqual(len(_warnings), 0) + + # Accessing the boolean value should trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + for api, name in deprecated_api_to_name.items(): + num_warnings_before = len(_warnings) + if api: + pass + regex = f"{name} is deprecated and will be removed" + self.assertEqual(len(_warnings), num_warnings_before + 1) + self.assertIn(regex, str(_warnings[-1].message)) + class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): diff --git a/torchao/utils.py b/torchao/utils.py index fb82b9f005..e0ffabc3cf 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -8,6 +8,7 @@ import itertools import re import time +import warnings from functools import reduce from importlib.metadata import version from math import gcd @@ -377,13 +378,62 @@ def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 +# Deprecated, will be deleted in the future +def _torch_version_after(min_version): + return is_fbcode() or version("torch") >= min_version + + +def _get_old_torch_version_deprecation_msg(version_str: str) -> str: + return f"TORCH_VERSION_AT_LEAST_{version_str} is deprecated and will be removed in torchao 0.14.0" + + +def _get_torch_version_after_deprecation_msg(version_str: str) -> str: + return f"TORCH_VERSION_AFTER_{version_str} is deprecated and will be removed in torchao 0.14.0" + + +class _BoolDeprecationWrapper: + """ + A deprecation wrapper that logs a warning when the given bool value is accessed. + """ + + def __init__(self, bool_value: bool, msg: str): + self.bool_value = bool_value + self.msg = msg + + def __bool__(self): + warnings.warn(self.msg) + return self.bool_value + + TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0") + +# Deprecated +TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( + torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") +) +TORCH_VERSION_AT_LEAST_2_4 = _BoolDeprecationWrapper( + torch_version_at_least("2.4.0"), _get_old_torch_version_deprecation_msg("2_4") +) +TORCH_VERSION_AT_LEAST_2_3 = _BoolDeprecationWrapper( + torch_version_at_least("2.3.0"), _get_old_torch_version_deprecation_msg("2_3") +) +TORCH_VERSION_AT_LEAST_2_2 = _BoolDeprecationWrapper( + torch_version_at_least("2.2.0"), _get_old_torch_version_deprecation_msg("2_2") +) +TORCH_VERSION_AFTER_2_5 = _BoolDeprecationWrapper( + _torch_version_after("2.5.0.dev"), _get_torch_version_after_deprecation_msg("2_5") +) +TORCH_VERSION_AFTER_2_4 = _BoolDeprecationWrapper( + _torch_version_after("2.4.0.dev"), _get_torch_version_after_deprecation_msg("2_4") +) +TORCH_VERSION_AFTER_2_3 = _BoolDeprecationWrapper( + _torch_version_after("2.3.0.dev"), _get_torch_version_after_deprecation_msg("2_3") +) +TORCH_VERSION_AFTER_2_2 = _BoolDeprecationWrapper( + _torch_version_after("2.2.0.dev"), _get_torch_version_after_deprecation_msg("2_2") +) """ @@ -766,11 +816,6 @@ def fill_defaults(args, n, defaults_tail): return r -## Deprecated, will be deleted in the future -def _torch_version_at_least(min_version): - return is_fbcode() or version("torch") >= min_version - - # Supported AMD GPU Models and their LLVM gfx Codes: # # | AMD GPU Model | LLVM gfx Code | @@ -857,12 +902,6 @@ def ceil_div(a, b): return (a + b - 1) // b -TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") - - def is_package_at_least(package_name: str, min_version: str): package_exists = importlib.util.find_spec(package_name) is not None if not package_exists: From 922fc3e01872101b32463781279f49f780a8fc90 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 11:07:05 -0700 Subject: [PATCH 2/5] Update on "Deprecate old TORCH_VERSION variables" **Summary:** This commit deprecates the following variables: ``` TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means we can drop support for 2.5 and before since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 4 +++- torchao/utils.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 0697a97f72..ddbbf68ecc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -38,7 +38,7 @@ def test_torch_version_at_least(self): def test_torch_version_deprecation(self): """ - Test that TORCH_VERSION_AT_LEAST_2_5 and before and TORCH_VERSION_AFTER* + Test that TORCH_VERSION_AT_LEAST_2_6 and before and TORCH_VERSION_AFTER* trigger a deprecation warning. """ # Reset deprecation warning state, otherwise we won't log warnings here @@ -55,9 +55,11 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, ) deprecated_api_to_name = { + TORCH_VERSION_AT_LEAST_2_6: "TORCH_VERSION_AT_LEAST_2_6", TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", diff --git a/torchao/utils.py b/torchao/utils.py index e0ffabc3cf..307e02c4a7 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -407,9 +407,11 @@ def __bool__(self): TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") # Deprecated +TORCH_VERSION_AT_LEAST_2_6 = _BoolDeprecationWrapper( + torch_version_at_least("2.6.0"), _get_old_torch_version_deprecation_msg("2_6") +) TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") ) From d8a98dedf44b404773aec561329b575b760f7035 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 12:54:15 -0700 Subject: [PATCH 3/5] Update on "Deprecate old TORCH_VERSION variables" **Summary:** This commit deprecates the following variables: ``` # Always True TORCH_VERSION_AT_LEAST_2_6 TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 # TORCH_VERSION_AFTER* was confusing to users TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means the oldest pytorch version we support is now 2.6 since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- torchao/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/utils.py b/torchao/utils.py index 307e02c4a7..9122284947 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -404,6 +404,9 @@ def __bool__(self): warnings.warn(self.msg) return self.bool_value + def __eq__(self, other): + return bool(self) == bool(other) + TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") From ccb28b48bd9d9b6019422c42c096e3ec761c96d1 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 13:10:20 -0700 Subject: [PATCH 4/5] Update on "Deprecate old TORCH_VERSION variables" **Summary:** This commit deprecates the following variables: ``` # Always True TORCH_VERSION_AT_LEAST_2_6 TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 # TORCH_VERSION_AFTER* was confusing to users TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means the oldest pytorch version we support is now 2.6 since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 24 +++++++++---------- torchao/utils.py | 57 +++++++++++++++++----------------------------- 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ddbbf68ecc..ebc23466c1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -58,22 +58,22 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_6, ) - deprecated_api_to_name = { - TORCH_VERSION_AT_LEAST_2_6: "TORCH_VERSION_AT_LEAST_2_6", - TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", - TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", - TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", - TORCH_VERSION_AT_LEAST_2_2: "TORCH_VERSION_AT_LEAST_2_2", - TORCH_VERSION_AFTER_2_5: "TORCH_VERSION_AFTER_2_5", - TORCH_VERSION_AFTER_2_4: "TORCH_VERSION_AFTER_2_4", - TORCH_VERSION_AFTER_2_3: "TORCH_VERSION_AFTER_2_3", - TORCH_VERSION_AFTER_2_2: "TORCH_VERSION_AFTER_2_2", - } + deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), + (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), + (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), + (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), + (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), + (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), + (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), + (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), + (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), + ] self.assertEqual(len(_warnings), 0) # Accessing the boolean value should trigger deprecation warning with warnings.catch_warnings(record=True) as _warnings: - for api, name in deprecated_api_to_name.items(): + for api, name in deprecated_api_to_name: num_warnings_before = len(_warnings) if api: pass diff --git a/torchao/utils.py b/torchao/utils.py index 9122284947..4743806b44 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -378,17 +378,20 @@ def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 -# Deprecated, will be deleted in the future -def _torch_version_after(min_version): - return is_fbcode() or version("torch") >= min_version - - -def _get_old_torch_version_deprecation_msg(version_str: str) -> str: - return f"TORCH_VERSION_AT_LEAST_{version_str} is deprecated and will be removed in torchao 0.14.0" +def _deprecated_torch_version_at_least(version_str: str) -> str: + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper( + torch_version_at_least(version_str), + deprecation_msg, + ) -def _get_torch_version_after_deprecation_msg(version_str: str) -> str: - return f"TORCH_VERSION_AFTER_{version_str} is deprecated and will be removed in torchao 0.14.0" +def _deprecated_torch_version_after(version_str: str) -> str: + bool_value = is_fbcode() or version("torch") >= version_str + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper(bool_value, deprecation_msg) class _BoolDeprecationWrapper: @@ -412,33 +415,15 @@ def __eq__(self, other): TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") # Deprecated -TORCH_VERSION_AT_LEAST_2_6 = _BoolDeprecationWrapper( - torch_version_at_least("2.6.0"), _get_old_torch_version_deprecation_msg("2_6") -) -TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( - torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") -) -TORCH_VERSION_AT_LEAST_2_4 = _BoolDeprecationWrapper( - torch_version_at_least("2.4.0"), _get_old_torch_version_deprecation_msg("2_4") -) -TORCH_VERSION_AT_LEAST_2_3 = _BoolDeprecationWrapper( - torch_version_at_least("2.3.0"), _get_old_torch_version_deprecation_msg("2_3") -) -TORCH_VERSION_AT_LEAST_2_2 = _BoolDeprecationWrapper( - torch_version_at_least("2.2.0"), _get_old_torch_version_deprecation_msg("2_2") -) -TORCH_VERSION_AFTER_2_5 = _BoolDeprecationWrapper( - _torch_version_after("2.5.0.dev"), _get_torch_version_after_deprecation_msg("2_5") -) -TORCH_VERSION_AFTER_2_4 = _BoolDeprecationWrapper( - _torch_version_after("2.4.0.dev"), _get_torch_version_after_deprecation_msg("2_4") -) -TORCH_VERSION_AFTER_2_3 = _BoolDeprecationWrapper( - _torch_version_after("2.3.0.dev"), _get_torch_version_after_deprecation_msg("2_3") -) -TORCH_VERSION_AFTER_2_2 = _BoolDeprecationWrapper( - _torch_version_after("2.2.0.dev"), _get_torch_version_after_deprecation_msg("2_2") -) +TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") +TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") +TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") +TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") +TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") +TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") +TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") +TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") +TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") """ From 2b74165bcd9fa7842cbae257662669c13858d5e5 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 13:59:11 -0700 Subject: [PATCH 5/5] Update on "Deprecate old TORCH_VERSION variables" **Summary:** This commit deprecates the following variables: ``` # Always True TORCH_VERSION_AT_LEAST_2_6 TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 # TORCH_VERSION_AFTER* was confusing to users TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means the oldest pytorch version we support is now 2.6 since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 8 ++++++-- torchao/utils.py | 15 +++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ebc23466c1..9213097276 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -38,8 +38,8 @@ def test_torch_version_at_least(self): def test_torch_version_deprecation(self): """ - Test that TORCH_VERSION_AT_LEAST_2_6 and before and TORCH_VERSION_AFTER* - trigger a deprecation warning. + Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* + trigger deprecation warnings on use, not on import. """ # Reset deprecation warning state, otherwise we won't log warnings here warnings.resetwarnings() @@ -56,9 +56,13 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, ) deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), + (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), diff --git a/torchao/utils.py b/torchao/utils.py index 4743806b44..ea939bdd9a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -379,6 +379,10 @@ def torch_version_at_least(min_version): def _deprecated_torch_version_at_least(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log + a deprecation warning if the variable is used. + """ version_str_var_name = "_".join(version_str.split(".")[:2]) deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" return _BoolDeprecationWrapper( @@ -388,6 +392,10 @@ def _deprecated_torch_version_at_least(version_str: str) -> str: def _deprecated_torch_version_after(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AFTER* variables that will log + a deprecation warning if the variable is used. + """ bool_value = is_fbcode() or version("torch") >= version_str version_str_var_name = "_".join(version_str.split(".")[:2]) deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" @@ -411,10 +419,9 @@ def __eq__(self, other): return bool(self) == bool(other) -TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") - -# Deprecated +# Deprecated, use `torch_version_at_least` directly instead +TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") +TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0")