Skip to content

Deprecate old TORCH_VERSION variables #2719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings
from unittest.mock import patch

import torch
Expand All @@ -12,7 +13,7 @@
from torchao.utils import TorchAOBaseTensor, torch_version_at_least


class TestTorchVersionAtLeast(unittest.TestCase):
class TestTorchVersion(unittest.TestCase):
def test_torch_version_at_least(self):
test_cases = [
("2.5.0a0+git9f17037", "2.5.0", True),
Expand All @@ -35,6 +36,55 @@ def test_torch_version_at_least(self):
f"Failed for torch.__version__={torch_version}, comparing with {compare_version}",
)

def test_torch_version_deprecation(self):
"""
Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
trigger deprecation warnings on use, not on import.
"""
# Reset deprecation warning state, otherwise we won't log warnings here
warnings.resetwarnings()

# Importing and referencing should not trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
from torchao.utils import (
TORCH_VERSION_AFTER_2_2,
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
TORCH_VERSION_AT_LEAST_2_2,
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
TORCH_VERSION_AT_LEAST_2_8,
)

deprecated_api_to_name = [
(TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"),
(TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"),
(TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"),
(TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"),
(TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"),
(TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"),
(TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"),
(TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"),
(TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"),
(TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"),
(TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"),
]
self.assertEqual(len(_warnings), 0)

# Accessing the boolean value should trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
for api, name in deprecated_api_to_name:
num_warnings_before = len(_warnings)
if api:
pass
regex = f"{name} is deprecated and will be removed"
self.assertEqual(len(_warnings), num_warnings_before + 1)
self.assertIn(regex, str(_warnings[-1].message))


class TestTorchAOBaseTensor(unittest.TestCase):
def test_print_arg_types(self):
Expand Down
72 changes: 54 additions & 18 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
import re
import time
import warnings
from functools import reduce
from importlib.metadata import version
from math import gcd
Expand Down Expand Up @@ -377,13 +378,59 @@ def torch_version_at_least(min_version):
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0


TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0")
TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0")
TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0")
def _deprecated_torch_version_at_least(version_str: str) -> str:
"""
Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log
a deprecation warning if the variable is used.
"""
version_str_var_name = "_".join(version_str.split(".")[:2])
deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
return _BoolDeprecationWrapper(
torch_version_at_least(version_str),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is current torch version related to this? this is talking about deprecation of these variables right? if so, then these variables should be deprecated now, regardless of what the system torch version is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I can deprecate them all. I think these variables exist so we don't keep calling this function, but yeah we definitely don't need to expose them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok just deprecated them all

Copy link
Contributor

@jerryzh168 jerryzh168 Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK thanks, also the deprecation doesn't need to depend on the current pytorch version I think

we could also add a cache for torch_version_at_least as well I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deprecation doesn't need to depend on the current pytorch version

Do you mean just mention TORCH_VERSION_AT_LEAST_* is deprecated in the warning instead of including the version number?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the current deprecation message seems to be printed only if the current pytorch version is above some version_str, what I'm proposing is just remove this check and always print the deprecation

Copy link
Contributor Author

@andrewor14 andrewor14 Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the warning should always be printed I think. The current pytorch version is not used to decide whether or not we print it (it's only passed to torch_version_at_least)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh OK, I misunderstood, so the bool value is to pass around the return value of the torch_version_at_least function while also print the message

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah exactly

deprecation_msg,
)


def _deprecated_torch_version_after(version_str: str) -> str:
"""
Wrapper for existing TORCH_VERSION_AFTER* variables that will log
a deprecation warning if the variable is used.
"""
bool_value = is_fbcode() or version("torch") >= version_str
version_str_var_name = "_".join(version_str.split(".")[:2])
deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
return _BoolDeprecationWrapper(bool_value, deprecation_msg)


class _BoolDeprecationWrapper:
"""
A deprecation wrapper that logs a warning when the given bool value is accessed.
"""

def __init__(self, bool_value: bool, msg: str):
self.bool_value = bool_value
self.msg = msg

def __bool__(self):
warnings.warn(self.msg)
return self.bool_value

def __eq__(self, other):
return bool(self) == bool(other)


# Deprecated, use `torch_version_at_least` directly instead
TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0")
TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0")
TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0")
TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0")
TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0")
TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev")
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")


"""
Expand Down Expand Up @@ -766,11 +813,6 @@ def fill_defaults(args, n, defaults_tail):
return r


## Deprecated, will be deleted in the future
def _torch_version_at_least(min_version):
return is_fbcode() or version("torch") >= min_version


# Supported AMD GPU Models and their LLVM gfx Codes:
#
# | AMD GPU Model | LLVM gfx Code |
Expand Down Expand Up @@ -857,12 +899,6 @@ def ceil_div(a, b):
return (a + b - 1) // b


TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev")


def is_package_at_least(package_name: str, min_version: str):
package_exists = importlib.util.find_spec(package_name) is not None
if not package_exists:
Expand Down
Loading