From e1d7de3d96e4ddad3e6196fad90d31bee701f7fa Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 10:50:46 -0700 Subject: [PATCH 01/13] Deprecate old TORCH_VERSION variables **Summary:** This commit deprecates the following variables: ``` TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means we can drop support for 2.5 and before since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 46 ++++++++++++++++++++++++++++++- torchao/utils.py | 69 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3ba2f32613..0697a97f72 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest +import warnings from unittest.mock import patch import torch @@ -12,7 +13,7 @@ from torchao.utils import TorchAOBaseTensor, torch_version_at_least -class TestTorchVersionAtLeast(unittest.TestCase): +class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ ("2.5.0a0+git9f17037", "2.5.0", True), @@ -35,6 +36,49 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) + def test_torch_version_deprecation(self): + """ + Test that TORCH_VERSION_AT_LEAST_2_5 and before and TORCH_VERSION_AFTER* + trigger a deprecation warning. 
+ """ + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Importing and referencing should not trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + from torchao.utils import ( + TORCH_VERSION_AFTER_2_2, + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + TORCH_VERSION_AFTER_2_5, + TORCH_VERSION_AT_LEAST_2_2, + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + ) + + deprecated_api_to_name = { + TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", + TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", + TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", + TORCH_VERSION_AT_LEAST_2_2: "TORCH_VERSION_AT_LEAST_2_2", + TORCH_VERSION_AFTER_2_5: "TORCH_VERSION_AFTER_2_5", + TORCH_VERSION_AFTER_2_4: "TORCH_VERSION_AFTER_2_4", + TORCH_VERSION_AFTER_2_3: "TORCH_VERSION_AFTER_2_3", + TORCH_VERSION_AFTER_2_2: "TORCH_VERSION_AFTER_2_2", + } + self.assertEqual(len(_warnings), 0) + + # Accessing the boolean value should trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + for api, name in deprecated_api_to_name.items(): + num_warnings_before = len(_warnings) + if api: + pass + regex = f"{name} is deprecated and will be removed" + self.assertEqual(len(_warnings), num_warnings_before + 1) + self.assertIn(regex, str(_warnings[-1].message)) + class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): diff --git a/torchao/utils.py b/torchao/utils.py index fb82b9f005..e0ffabc3cf 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -8,6 +8,7 @@ import itertools import re import time +import warnings from functools import reduce from importlib.metadata import version from math import gcd @@ -377,13 +378,62 @@ def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 +# Deprecated, will be deleted in the future +def _torch_version_after(min_version): + return is_fbcode() or version("torch") >= min_version + + +def _get_old_torch_version_deprecation_msg(version_str: str) -> str: + return f"TORCH_VERSION_AT_LEAST_{version_str} is deprecated and will be removed in torchao 0.14.0" + + +def _get_torch_version_after_deprecation_msg(version_str: str) -> str: + return f"TORCH_VERSION_AFTER_{version_str} is deprecated and will be removed in torchao 0.14.0" + + +class _BoolDeprecationWrapper: + """ + A deprecation wrapper that logs a warning when the given bool value is accessed. 
+ """ + + def __init__(self, bool_value: bool, msg: str): + self.bool_value = bool_value + self.msg = msg + + def __bool__(self): + warnings.warn(self.msg) + return self.bool_value + + TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0") + +# Deprecated +TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( + torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") +) +TORCH_VERSION_AT_LEAST_2_4 = _BoolDeprecationWrapper( + torch_version_at_least("2.4.0"), _get_old_torch_version_deprecation_msg("2_4") +) +TORCH_VERSION_AT_LEAST_2_3 = _BoolDeprecationWrapper( + torch_version_at_least("2.3.0"), _get_old_torch_version_deprecation_msg("2_3") +) +TORCH_VERSION_AT_LEAST_2_2 = _BoolDeprecationWrapper( + torch_version_at_least("2.2.0"), _get_old_torch_version_deprecation_msg("2_2") +) +TORCH_VERSION_AFTER_2_5 = _BoolDeprecationWrapper( + _torch_version_after("2.5.0.dev"), _get_torch_version_after_deprecation_msg("2_5") +) +TORCH_VERSION_AFTER_2_4 = _BoolDeprecationWrapper( + _torch_version_after("2.4.0.dev"), _get_torch_version_after_deprecation_msg("2_4") +) +TORCH_VERSION_AFTER_2_3 = _BoolDeprecationWrapper( + _torch_version_after("2.3.0.dev"), _get_torch_version_after_deprecation_msg("2_3") +) +TORCH_VERSION_AFTER_2_2 = _BoolDeprecationWrapper( + _torch_version_after("2.2.0.dev"), _get_torch_version_after_deprecation_msg("2_2") +) """ @@ -766,11 +816,6 @@ def fill_defaults(args, n, defaults_tail): return r -## Deprecated, will be deleted in the future -def _torch_version_at_least(min_version): - return is_fbcode() or version("torch") >= min_version - - # Supported AMD GPU Models and their LLVM gfx Codes: # # | AMD GPU Model | LLVM gfx Code | @@ -857,12 +902,6 @@ def ceil_div(a, b): return (a + b - 1) // b -TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") - - def is_package_at_least(package_name: str, min_version: str): package_exists = importlib.util.find_spec(package_name) is not None if not package_exists: From 922fc3e01872101b32463781279f49f780a8fc90 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 11:07:05 -0700 Subject: [PATCH 02/13] Update on "Deprecate old TORCH_VERSION variables" **Summary:** This commit deprecates the following variables: ``` TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` As of this commit, the latest released version of PyTorch is 2.8, which means we can drop support for 2.5 and before since we only support 3 of the latest releases. The next commit will remove usages of all of these variables from within torchao. 
**Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` [ghstack-poisoned] --- test/test_utils.py | 4 +++- torchao/utils.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 0697a97f72..ddbbf68ecc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -38,7 +38,7 @@ def test_torch_version_at_least(self): def test_torch_version_deprecation(self): """ - Test that TORCH_VERSION_AT_LEAST_2_5 and before and TORCH_VERSION_AFTER* + Test that TORCH_VERSION_AT_LEAST_2_6 and before and TORCH_VERSION_AFTER* trigger a deprecation warning. """ # Reset deprecation warning state, otherwise we won't log warnings here @@ -55,9 +55,11 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, ) deprecated_api_to_name = { + TORCH_VERSION_AT_LEAST_2_6: "TORCH_VERSION_AT_LEAST_2_6", TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", diff --git a/torchao/utils.py b/torchao/utils.py index e0ffabc3cf..307e02c4a7 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -407,9 +407,11 @@ def __bool__(self): TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") # Deprecated +TORCH_VERSION_AT_LEAST_2_6 = _BoolDeprecationWrapper( + torch_version_at_least("2.6.0"), _get_old_torch_version_deprecation_msg("2_6") +) TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") ) From fc7dffe6de70b24018b40249e80ddbd0cad64f53 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 12:01:09 -0700 Subject: [PATCH 03/13] Drop support for PyTorch 2.5 and before **Summary:** We gate on the PyTorch version throughout the repo. Recently PyTorch 2.8 was released, so the oldest PyTorch version we need to support is 2.6. After this commit, we assume the user is running PyTorch 2.6+, and remove all references to the following variables, which are deprecated. 
``` TORCH_VERSION_AT_LEAST_2_6 TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` **Test Plan:** CI [ghstack-poisoned] --- .github/workflows/regression_test.yml | 24 +- benchmarks/benchmark_aq.py | 47 +--- docs/source/pretraining.rst | 4 - docs/source/quick_start.rst | 6 - scripts/quick_start.py | 11 +- test/core/test_config.py | 5 +- test/dtypes/test_affine_quantized.py | 7 +- test/dtypes/test_affine_quantized_float.py | 9 - .../test_affine_quantized_tensor_parallel.py | 5 - test/dtypes/test_floatx.py | 6 +- test/dtypes/test_uint4.py | 15 +- test/dtypes/test_uintx.py | 45 +--- test/float8/test_base.py | 19 +- test/float8/test_compile.py | 14 +- test/float8/test_dtensor.py | 7 - test/float8/test_float8_utils.py | 4 - test/float8/test_fsdp.py | 7 - test/float8/test_fsdp2/test_fsdp2.py | 8 +- test/float8/test_fsdp2_tp.py | 7 - test/float8/test_fsdp_compile.py | 7 - test/float8/test_numerics_integration.py | 14 +- test/hqq/test_hqq_affine.py | 4 - test/integration/test_integration.py | 169 ++---------- test/prototype/moe_training/test_kernels.py | 10 +- test/prototype/test_autoround.py | 7 - test/prototype/test_awq.py | 9 +- test/prototype/test_codebook_coreml.py | 3 +- test/prototype/test_parq.py | 13 +- test/prototype/test_quantized_training.py | 36 +-- test/prototype/test_smoothquant.py | 17 +- .../pt2e/test_arm_inductor_quantizer.py | 28 +- test/quantization/pt2e/test_duplicate_dq.py | 6 +- test/quantization/pt2e/test_quantize_pt2e.py | 7 +- .../pt2e/test_quantize_pt2e_qat.py | 6 +- test/quantization/pt2e/test_representation.py | 6 +- .../pt2e/test_x86inductor_fusion.py | 11 +- .../pt2e/test_x86inductor_quantizer.py | 6 +- test/quantization/test_gptq.py | 6 - test/quantization/test_marlin_qqq.py | 2 - test/quantization/test_moe_quant.py | 21 +- test/quantization/test_qat.py | 124 +-------- test/quantization/test_quant_api.py | 83 ------ test/quantization/test_quant_primitives.py | 104 ++------ test/sparsity/test_fast_sparse_training.py | 4 +- test/sparsity/test_marlin.py | 2 - test/sparsity/test_sparse_api.py | 14 - test/test_low_bit_optim.py | 9 - test/test_ops.py | 33 +-- torchao/_executorch_ops.py | 61 +---- torchao/_models/llama/eval.py | 9 - torchao/_models/llama/generate.py | 19 +- torchao/_models/sam/eval_combo.py | 13 - torchao/dtypes/affine_quantized_tensor.py | 10 +- torchao/dtypes/fbgemm_fp8_tensor.py | 6 +- torchao/dtypes/nf4tensor.py | 7 +- torchao/dtypes/uintx/int4_cpu_layout.py | 42 +-- ...8_dynamic_activation_intx_weight_layout.py | 10 - .../dtypes/uintx/tensor_core_tiled_layout.py | 12 +- torchao/dtypes/uintx/uintx_layout.py | 27 +- torchao/float8/README.md | 8 - torchao/float8/__init__.py | 28 +- torchao/kernel/bsr_triton_ops.py | 10 +- torchao/kernel/intmm.py | 137 ++++------ torchao/kernel/intmm_triton.py | 20 +- torchao/ops.py | 12 +- torchao/optim/cpu_offload.py | 8 +- torchao/optim/subclass_4bit.py | 31 +-- torchao/optim/subclass_8bit.py | 30 +-- torchao/optim/subclass_fp8.py | 8 +- torchao/prototype/autoround/eval_autoround.py | 3 +- .../float8nocompile/examples/example.py | 4 - .../float8nocompile/test/fsdp_test.py | 4 - .../float8nocompile/test/train_test.py | 4 - torchao/prototype/hqq/hqq_tinygemm_linear.py | 7 +- .../mx_formats/inference_workflow.py | 20 +- torchao/prototype/mx_formats/kernels.py | 245 ++++++++---------- .../prototype/quantization/autoquant_v2.py | 25 +- 
.../int8_dynamic_activation_lut_tensor.py | 10 +- .../gguf/gguf_quantized_tensor.py | 10 +- torchao/prototype/spinquant/hadamard_utils.py | 13 +- torchao/quantization/README.md | 6 - torchao/quantization/autoquant.py | 36 +-- .../linear_activation_quantized_tensor.py | 10 +- .../quantization/linear_activation_scale.py | 12 +- ...inear_activation_weight_observed_tensor.py | 10 +- torchao/quantization/linear_quant_modules.py | 12 +- torchao/quantization/observer.py | 6 +- .../quantization/pt2e/_numeric_debugger.py | 12 +- torchao/quantization/pt2e/constant_fold.py | 8 +- torchao/quantization/pt2e/convert.py | 7 +- torchao/quantization/pt2e/observer.py | 7 - torchao/quantization/pt2e/prepare.py | 2 - torchao/quantization/pt2e/quantize_pt2e.py | 9 +- .../pt2e/quantizer/port_metadata_pass.py | 11 +- torchao/quantization/qat/linear.py | 6 +- torchao/quantization/quant_api.py | 113 ++++---- torchao/quantization/quant_primitives.py | 94 +++---- .../quantize_/common/kernel_preference.py | 5 +- .../workflows/float8/float8_tensor.py | 6 +- .../workflows/int4/int4_preshuffled_tensor.py | 6 +- .../quantize_/workflows/int4/int4_tensor.py | 6 +- torchao/quantization/utils.py | 9 +- ...t_tensor_linear_activation_quantization.py | 14 +- torchao/sparsity/training/__init__.py | 14 +- torchao/sparsity/training/autograd.py | 20 +- torchao/testing/pt2e/utils.py | 10 +- torchao/testing/utils.py | 5 - torchao/utils.py | 52 ++-- tutorials/quantize_vit/run_vit_b_quant.py | 6 - 109 files changed, 604 insertions(+), 1774 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 2453e7eaaf..0858076551 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -59,12 +59,6 @@ jobs: fail-fast: false matrix: include: - - name: CUDA 2.5.1 - runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' - gpu-arch-type: "cuda" - gpu-arch-version: "12.6" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CUDA 2.6 runs-on: linux.g5.12xlarge.nvidia.gpu torch-spec: 'torch==2.6.0' @@ -77,13 +71,13 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" + - name: CUDA 2.8 + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: 'torch==2.8.0' + gpu-arch-type: "cuda" + gpu-arch-version: "12.6" + dev-requirements-overrides: "" - - name: CPU 2.5.1 - runs-on: linux.4xlarge - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu' - gpu-arch-type: "cpu" - gpu-arch-version: "" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CPU 2.6 runs-on: linux.4xlarge torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu' @@ -96,6 +90,12 @@ jobs: gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" + - name: CPU 2.8 + runs-on: linux.4xlarge + torch-spec: 'torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu' + gpu-arch-type: "cpu" + gpu-arch-version: "" + dev-requirements-overrides: "" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index cdc6f6fe5a..7dd732debc 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -20,46 +20,26 @@ Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, -) def _int8wo_api(mod, **kwargs): - if 
TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod, **kwargs) + quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) def _int8da_int8w_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight(**kwargs), - set_inductor_config=False, - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod, **kwargs) + quantize_( + mod, + int8_dynamic_activation_int8_weight(**kwargs), + set_inductor_config=False, + ) def _int4wo_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy = kwargs.copy() - if "groupsize" in kwargs_copy: - kwargs_copy["group_size"] = kwargs_copy["groupsize"] - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int4_woqtensors(mod, **kwargs) + kwargs_copy = kwargs.copy() + if "groupsize" in kwargs_copy: + kwargs_copy["group_size"] = kwargs_copy["groupsize"] + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): @@ -195,13 +175,12 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) -if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): all_shapes = [ (20, 2048, 2048), ] print("_int8da_int8w_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -209,7 +188,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) print("_int8wo_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -218,7 +196,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): print("_int4wo_api") kwargs = {"groupsize": 32} - from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( diff --git a/docs/source/pretraining.rst b/docs/source/pretraining.rst index da9659b9a0..2f60719ec5 100644 --- a/docs/source/pretraining.rst +++ b/docs/source/pretraining.rst @@ -161,10 +161,6 @@ Below is a code snippet showing how to use it: from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 02b59c2430..0c01e992e0 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -95,16 +95,10 @@ it is also much faster! .. 
code:: py from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, benchmark_model, unwrap_tensor_subclass, ) - # Temporary workaround for tensor subclass + torch.compile - # Only needed for torch version < 2.5 - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - num_runs = 100 torch._dynamo.reset() example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) diff --git a/scripts/quick_start.py b/scripts/quick_start.py index 55c17a8684..6b56412f03 100644 --- a/scripts/quick_start.py +++ b/scripts/quick_start.py @@ -8,11 +8,7 @@ import torch from torchao.quantization import Int4WeightOnlyConfig, quantize_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - benchmark_model, - unwrap_tensor_subclass, -) +from torchao.utils import benchmark_model # ================ # | Set up model | @@ -50,11 +46,6 @@ def forward(self, x): # | Benchmark | # ============= -# Temporary workaround for tensor subclass + torch.compile -# Only needed for torch version < 2.5 -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - num_runs = 100 torch._dynamo.reset() example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) diff --git a/test/core/test_config.py b/test/core/test_config.py index fc752d989e..9574c3ec76 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -39,7 +39,6 @@ UIntXWeightOnlyConfig, ) from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # Define test configurations as fixtures configs = [ @@ -85,11 +84,9 @@ ), AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING), AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"), + FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]), ] -if TORCH_VERSION_AT_LEAST_2_6: - configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])] - # Create ids for better test naming def get_config_ids(configs): diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index bd5ed0c3b5..e27796bb74 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -41,7 +41,6 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, is_fbcode, @@ -151,11 +150,7 @@ def test_weights_only(self): with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) + _ = torch.load(f, weights_only=True) @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 1dfed4dda8..d705b2cfe1 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -3,15 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import copy import io import random diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index c2eff77b07..fd5f43a470 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -24,7 +24,6 @@ ) from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 if common_utils.SEED is None: common_utils.SEED = 1234 @@ -127,10 +126,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 237bc2bd92..9a99ba0802 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -33,7 +33,7 @@ quantize_, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import is_fbcode _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -107,10 +107,6 @@ def test_to_copy_device(self, ebits, mbits): assert floatx_tensor_impl.device.type == "cpu" @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - reason="quantization only works with torch.compile for 2.5+", - ) @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index f7656ef19e..aa9eccc903 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -34,7 +34,6 @@ _replace_with_custom_fn_if_matches_filter, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): @@ -243,16 +242,10 @@ def forward(self, x): # program capture m = copy.deepcopy(m_eager) - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.export.texport_for_training( - m, - example_inputs, - ).module() - else: - m = torch._export.capture_pre_autograd_graph( - m, - example_inputs, - ).module() + m = torch.export.texport_for_training( + m, + example_inputs, + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 35c722365d..dbc69b8ee9 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -14,24 +14,16 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, -) -# torch.uintx dtypes are introduced in 2.3 -if TORCH_VERSION_AT_LEAST_2_3: - dtypes = ( - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - ) -else: - dtypes = () +dtypes = ( + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, +) group_sizes = [32, 64, 128] devices = ["cpu", "cuda"] @@ -65,9 +57,6 @@ def forward(self, x): @pytest.mark.parametrize("dtype", dtypes) 
@pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") @@ -86,9 +75,6 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) @@ -103,9 +89,6 @@ def test_uintx_weight_only_model_quant(dtype, group_size, device): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_quant(dtype, group_size, device): input_float = torch.randn((1, 256), dtype=torch.float16, device=device) mapping_type = MappingType.SYMMETRIC @@ -140,9 +123,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_target_dtype(dtype): from torchao.quantization.quant_api import uintx_weight_only @@ -154,10 +134,6 @@ def test_uintx_target_dtype(dtype): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, - reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", -) def test_uintx_target_dtype_compile(dtype): from torchao.quantization.quant_api import uintx_weight_only @@ -170,9 +146,6 @@ def test_uintx_target_dtype_compile(dtype): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_model_size(dtype): from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes diff --git a/test/float8/test_base.py b/test/float8/test_base.py index c2b2c5488a..1f9ae19346 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -13,17 +13,6 @@ import torch import torch.nn as nn -from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - from torchao.float8.config import ( Float8LinearConfig, Float8LinearRecipeName, @@ -53,7 +42,13 @@ tensor_to_scale, ) from torchao.testing.training.test_utils import get_test_float8_linear_config -from torchao.utils import is_MI300, is_ROCM +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_MI300, + is_ROCM, + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) diff --git 
a/test/float8/test_compile.py b/test/float8/test_compile.py index a196d87430..04f03bb0ee 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -10,16 +10,6 @@ from io import StringIO import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -42,6 +32,10 @@ ScaledMMConfig, ) from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) def _test_compile_base( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index f357196785..7285d4bbc0 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -12,14 +12,7 @@ import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.distributed._tensor.common_dtensor import ( diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 888c7aadb1..c253af55ea 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -10,10 +10,6 @@ from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) # source for notable single-precision cases: diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 3017c8b539..a25bd53509 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -16,13 +16,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index ef87e5fcda..e7b3b8be91 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -10,13 +10,6 @@ from typing import Any, List, Optional import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - import torch import torch._dynamo.testing import torch.distributed as dist @@ -47,6 +40,7 @@ check_parity_bf16_mp, check_parity_no_mp, ) +from torchao.utils import is_sm_at_least_89 if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index 8a735c5865..ea93d5949d 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -13,14 +13,7 @@ import copy import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", 
allow_module_level=True) - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor.parallel import parallelize_module diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index a78a30925c..eb32c40aa3 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -12,13 +12,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index db02444109..8da36cef8e 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -10,16 +10,6 @@ from typing import Optional import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn import torch.nn.functional as F @@ -34,6 +24,10 @@ ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) torch.manual_seed(0) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index a6990549a3..728bf9378b 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -15,9 +15,6 @@ uintx_weight_only, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, -) cuda_available = torch.cuda.is_available() @@ -78,7 +75,6 @@ def _eval_hqq(dtype): @unittest.skipIf(not cuda_available, "Need CUDA available") -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+") class TestHQQ(unittest.TestCase): def _test_hqq( self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5514228f4b..6ff0e6f08c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -40,9 +40,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, _replace_with_custom_fn_if_matches_filter, - change_linear_weights_to_int4_woqtensors, change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -79,10 +77,6 @@ ) from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, benchmark_model, check_cpu_version, @@ -116,14 +110,7 @@ def _int8wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5 or ( - not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing - ): - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod) + quantize_(mod, int8_weight_only(set_inductor_config=False)) def _int8wo_groupwise_api(mod): @@ -135,18 +122,13 @@ def _int8da_int8w_api( mod, 
act_mapping_type=MappingType.SYMMETRIC, ): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight( - act_mapping_type=act_mapping_type, - set_inductor_config=False, - ), - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + set_inductor_config=False, + ), + ) def _int4wo_api(mod, use_hqq=False): @@ -163,18 +145,12 @@ def _int4wo_api(mod, use_hqq=False): mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False ) unwrap_tensor_subclass(mod) - elif TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int4_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) else: - change_linear_weights_to_int4_woqtensors(mod) + quantize_(mod, int4_weight_only(set_inductor_config=False)) def _int8da_int4w_api(mod): quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) # TODO: use this to reduce the number of tests @@ -393,7 +369,6 @@ def test_swap(self): assert torch.allclose(y_ref, y) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_weight_t_and_non_t_numerics_match(self): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format @@ -710,8 +685,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -730,8 +703,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": @@ -789,9 +760,6 @@ def _test_lin_weight_subclass_impl( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" - ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8DynamicallyQuantizedLinearWeight.from_float, @@ -808,9 +776,6 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8DynamicallyQuantizedLinearWeight.from_float, @@ -820,9 +785,6 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" @@ -836,9 +798,6 @@ def 
test_aq_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight2.from_float, @@ -848,9 +807,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight3.from_float, @@ -860,9 +816,6 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -892,9 +845,6 @@ def test_autoquantizable_flatten_unflatten(self): for device, dtype in COMMON_DEVICE_DTYPE ] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( @@ -919,9 +869,6 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): @@ -933,8 +880,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -953,8 +898,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") @unittest.skip("Skip to fix CI until we deprecate these APIs long term") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): @@ -1026,13 +969,6 @@ def _test_lin_weight_subclass_api_impl( ) ) def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): - if ( - not TORCH_VERSION_AT_LEAST_2_5 - and dtype in (torch.float16, torch.bfloat16) - and act_mapping is MappingType.ASYMMETRIC - and device == "cpu" - ): - self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, @@ -1042,12 +978,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): - if ( - not 
TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype @@ -1055,9 +985,6 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." - ) @skip_if_rocm("Test flaky on ROCm, under investigation") def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() @@ -1066,8 +993,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1079,7 +1004,6 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.") def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1093,9 +1017,6 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater" - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): if dtype != torch.float16: @@ -1139,8 +1060,6 @@ def test_gemlite_layout(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: @@ -1162,16 +1081,9 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): def api(mod): kwargs_copy = kwargs.copy() - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy["group_size"] = groupsize - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - kwargs_copy["inner_k_tiles"] = inner_k_tiles - del kwargs_copy["layout"] - change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy) + kwargs_copy["group_size"] = groupsize + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy)) self._test_lin_weight_subclass_api_impl( api, @@ -1252,11 +1164,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1289,11 +1197,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if 
TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1395,18 +1299,10 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch 2.3+.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1416,9 +1312,6 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." - ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1467,7 +1360,6 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") @@ -1562,7 +1454,6 @@ class TestAutoQuant(unittest.TestCase): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) @@ -1604,7 +1495,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1626,9 +1516,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # Skip certain shapes on older PyTorch versions - if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") # TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} is flaky, skipping") @@ -1657,7 +1544,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1685,7 +1571,6 @@ def forward(self, x): assert len(_AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if 
device != "cuda" or not torch.cuda.is_available(): @@ -1735,7 +1620,6 @@ def test_autoquant_manual(self, device, dtype): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1745,9 +1629,6 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest("bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. - if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") class NeedsKwargs(torch.nn.Module): def __init__(self): @@ -1782,7 +1663,6 @@ def forward(self, x, y): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1835,9 +1715,6 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." - ) def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 @@ -1868,9 +1745,6 @@ def test_autoquant_hp_float(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_autoquant_int4wo(self, device, dtype): if device == "cpu": @@ -1906,9 +1780,6 @@ def test_autoquant_int4wo(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." 
- ) @unittest.skipIf( True, "Skipping for now, do to lowering bug in inductor" ) # TODO unblock when fixed @@ -1948,7 +1819,6 @@ def test_autoquant_float8(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" @@ -2011,7 +1881,6 @@ def forward(self, x): ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( @@ -2067,12 +1936,9 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - if TORCH_VERSION_AT_LEAST_2_5: - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() - else: - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export_for_training( + model, example_inputs, strict=True + ).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: @@ -2111,7 +1977,6 @@ class TestUtils(unittest.TestCase): @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index b24b61be8c..a10f41e696 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -7,15 +7,9 @@ import pytest import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - # We need to skip before doing any imports which would use triton, since -# triton won't be available on CPU builds and torch < 2.5 -if not ( - TORCH_VERSION_AT_LEAST_2_5 - and torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9 -): +# triton won't be available on CPU builds +if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9): pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/prototype/test_autoround.py b/test/prototype/test_autoround.py index 483704a28c..cf7f956a11 100644 --- a/test/prototype/test_autoround.py +++ b/test/prototype/test_autoround.py @@ -25,7 +25,6 @@ prepare_model_for_applying_auto_round_, ) from torchao.prototype.autoround.multi_tensor import MultiTensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 _AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -92,9 +91,6 @@ def _check_params_and_buffers_type(module, check_fun): class TestAutoRound(TestCase): @pytest.mark.skip("these tests are broken on main branch") - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later" - ) @parametrize("device", _AVAILABLE_DEVICES) @torch.no_grad() def test_auto_round(self, device: str): @@ -136,9 +132,6 @@ def test_auto_round(self, device: str): assert after_quant is not None, "Quantized model forward pass failed" @pytest.mark.skip("these tests are broken on main 
branch") - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later" - ) @parametrize("device", _AVAILABLE_DEVICES) @torch.no_grad() def test_wrap_model_with_multi_tensor(self, device: str): diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 5538fa513d..181445470e 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -15,10 +15,7 @@ from torchao.prototype.awq import AWQConfig, AWQStep from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - _is_fbgemm_genai_gpu_available, -) +from torchao.utils import _is_fbgemm_genai_gpu_available class ToyLinearModel(torch.nn.Module): @@ -50,10 +47,6 @@ def forward(self, x): not _is_fbgemm_genai_gpu_available(), reason="need to install fbgemm_gpu_genai package", ) -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, - reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig", -) class TestAWQ(TestCase): def test_awq_config(self): base_config = Int4WeightOnlyConfig() diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py index 69956c7729..a9519f7321 100644 --- a/test/prototype/test_codebook_coreml.py +++ b/test/prototype/test_codebook_coreml.py @@ -14,7 +14,7 @@ ) from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least +from torchao.utils import is_package_at_least @unittest.skipIf( @@ -75,7 +75,6 @@ def test_quantize_api(self): ) assert type(m[0].weight) == CodebookQuantizedTensor - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.") def test_export(self): m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32) quantize_(m, CodebookWeightOnlyConfig(self.code_dtype, self.block_size)) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 6ceeb0d795..85a6e2b0c2 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -42,11 +42,7 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, - check_cpu_version, -) +from torchao.utils import check_cpu_version _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -198,7 +194,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize("group_size", [32, 256]) def test_int4_weight_only(self, group_size: int = 32): model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16) @@ -215,7 +210,6 @@ def test_int4_weight_only(self, group_size: int = 32): model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3, 4, 8]) @common_utils.parametrize("group_size", [32, 512]) def test_intx_weight_only(self, b: int = 2, group_size: int = 32): @@ -233,7 +227,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): quantizer = UnifTorchaoQuantizer() compare_quantized_models(model, m_ref, quantizer, b, group_size) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") def test_int4_weight_only_e2e(self, group_size: int = 32): model = M(m=512, 
n=512).to(torch.bfloat16).to(_DEVICE) @@ -255,7 +248,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32): ) compare_parq_convert(model, m_ref, optimizer, config) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @common_utils.parametrize("b", [2, 3, 4, 8]) def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): @@ -305,7 +297,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32 torch.testing.assert_close(q, q_ref, atol=0, rtol=0) torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3]) @common_utils.parametrize("group_size", [32, 512]) def test_intx_weight_only(self, b: int = 2, group_size: int = 32): @@ -327,7 +318,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): compare_quantized_models(model, m_ref, quantizer, b, group_size) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @common_utils.parametrize("b", [2, 3]) def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): @@ -359,7 +349,6 @@ class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3, 4, 8]) @common_utils.parametrize("model_dtype", [torch.float16, torch.float32]) @common_utils.parametrize("group_size", [32, 128]) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index c9d51389d1..836e2c302e 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -3,15 +3,9 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Requires torch>=2.4", allow_module_level=True) - import copy +import pytest import torch import torch.distributed as dist import torch.nn.functional as F @@ -312,21 +306,19 @@ def test_fsdp2_correctness(self): (bitnet_training(), mp_policy, 1e-5), ] - # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 - if TORCH_VERSION_AT_LEAST_2_6: - # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. - # We would need to cast all params to BF16 in forward and backward pass, while keeping - # the params in FP32 for optim step. - # torch.autocast() will only do this for F.linear() layer (and its backward). - # To keep it simple, we just use a larger tolerance here. - bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) - - extra_args = [ - (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), - (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), - (bitnet_training(), bf16_mp_policy, 1e-2), - ] - test_args.extend(extra_args) + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). 
+ # To keep it simple, we just use a larger tolerance here. + bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + + extra_args = [ + (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), + (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), + (bitnet_training(), bf16_mp_policy, 1e-2), + ] + test_args.extend(extra_args) self.run_subtests({"args": test_args}, self._run_subtest) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 568b2d964f..85893f2241 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -22,9 +22,6 @@ dequantize_per_channel, dynamically_quantize_per_channel, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) class ToyLinearModel(torch.nn.Module): @@ -56,9 +53,8 @@ class TestSmoothQuant(unittest.TestCase): @classmethod def setUpClass(cls): """Set up class-level configuration for tests.""" - if TORCH_VERSION_AT_LEAST_2_5: - # This test case will trigger recompilation many times, so set a large cache_size_limit here - torch._dynamo.config.cache_size_limit = 128 + # This test case will trigger recompilation many times, so set a large cache_size_limit here + torch._dynamo.config.cache_size_limit = 128 @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it") @common_utils.parametrize("bias", [True, False]) @@ -96,8 +92,7 @@ def forward(self, x): quantize_(m, SmoothQuantConfig(), is_observed_linear) # Apply compilation if supported - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) + m = torch.compile(m, fullgraph=True) # Step 2: Inference quantized model with torch.inference_mode(): @@ -213,8 +208,7 @@ def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): quantize_(m, SmoothQuantConfig(), is_observed_linear) # Apply compilation if supported - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) + m = torch.compile(m, fullgraph=True) # Step 2: Setup save/load model with recipe functionality insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) @@ -231,8 +225,7 @@ def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear) - if TORCH_VERSION_AT_LEAST_2_5: - m_save_load = torch.compile(m_save_load, fullgraph=True) + m_save_load = torch.compile(m_save_load, fullgraph=True) # Step 5: Validate outputs on full dataset with torch.inference_mode(): diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py index 750e88d451..4c3b397382 100644 --- a/test/quantization/pt2e/test_arm_inductor_quantizer.py +++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py @@ -6,12 +6,23 @@ # Owner(s): ["oncall: quantization"] import copy +import functools import itertools +import platform import unittest from enum import Enum import torch import torch.nn as nn +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skipIfNoInductorSupport, +) +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq from torchao.quantization.pt2e import ObserverBase @@ -26,22 +37,7 @@ from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( 
QUANT_ANNOTATION_KEY, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -import functools -import platform - -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, -) -from torch.testing._internal.common_quantization import ( - QuantizationTestCase, - skipIfNoInductorSupport, -) -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 def skipIfNoArm(fn): diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index a1b43b4f3a..8430f605e1 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -11,6 +11,7 @@ from typing import Any import torch +from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase from torch.testing._internal.common_utils import IS_WINDOWS, run_tests @@ -33,10 +34,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class TestHelperModules: diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 19f208a55c..0c1a1f23c9 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -19,6 +19,7 @@ per_channel_weight_observer_range_neg_127_to_127, weight_observer_range_neg_127_to_127, ) +from torch.export import export_for_training from torch.fx import Node from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -66,11 +67,7 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index d8a2c8df03..e0a51453a9 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -18,6 +18,7 @@ default_symmetric_qnnpack_qat_qconfig, ) from torch.ao.quantization.quantize_fx import prepare_qat_fx +from torch.export import export_for_training from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -51,10 +52,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class PT2EQATTestCase(QuantizationTestCase): diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index 2123995a4b..abe79a08e3 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -11,6 +11,7 @@ import torch from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.export import export_for_training from 
torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -27,10 +28,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @skipIfNoQNNPACK diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index ffaa4573d8..42439552c6 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -26,6 +26,7 @@ IS_FBCODE, IS_LINUX, IS_X86, + TEST_ACL, instantiate_parametrized_tests, parametrize, ) @@ -45,15 +46,7 @@ X86InductorQuantizer, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_8, -) - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.testing._internal.common_utils import TEST_ACL -else: - TEST_ACL = False +from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 # The dict value is match_nodes(computation_op+unary_op) unary_list = { diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 4476b18697..9dc7da3571 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -35,10 +36,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class NodePosType(Enum): diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py index 98760f8cf6..163819bea7 100644 --- a/test/quantization/test_gptq.py +++ b/test/quantization/test_gptq.py @@ -12,9 +12,6 @@ from torchao._models.llama.tokenizer import get_tokenizer from torchao.quantization import Int4WeightOnlyConfig, quantize_ from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, -) torch.manual_seed(0) @@ -101,7 +98,6 @@ def test_gptq_quantizer_int4_weight_only(self): class TestMultiTensorFlow(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ import MultiTensor @@ -114,7 +110,6 @@ def test_multitensor_add_tensors(self): self.assertTrue(torch.equal(mt.values[0], tensor1)) self.assertTrue(torch.equal(mt.values[1], tensor2)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ import MultiTensor @@ -126,7 +121,6 @@ def test_multitensor_pad_unpad(self): mt.unpad() self.assertEqual(mt.count, 1) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ import MultiTensor diff --git a/test/quantization/test_marlin_qqq.py 
b/test/quantization/test_marlin_qqq.py index 8fe21c6bd3..56b309b948 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -24,7 +24,6 @@ _choose_qparams_and_quantize_affine_qqq, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skip_if_rocm("ROCm enablement in progress") @@ -67,7 +66,6 @@ def test_marlin_qqq(self): "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..fae4d8e41e 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -27,11 +27,7 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - is_sm_at_least_90, -) +from torchao.utils import is_sm_at_least_90 if torch.version.hip is not None: pytest.skip( @@ -116,8 +112,6 @@ def _test_impl_moe_quant( def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -142,8 +136,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig(Int4WeightOnlyConfig()) tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -164,8 +156,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -188,8 +178,6 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): def test_int8wo_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -208,9 +196,6 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -230,8 +215,6 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8DynamicActivationInt8WeightConfig(), @@ -255,8 +238,6 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): def test_int8dq_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): 
self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index bb4bfe7f10..f591904fff 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -83,11 +83,6 @@ get_groupwise_affine_qparams, groupwise_affine_quantize_tensor, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, -) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() @@ -194,9 +189,6 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -241,9 +233,6 @@ def test_fake_quantize_per_channel_group(self): ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -341,9 +330,6 @@ def _set_ptq_weight( else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_linear(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear @@ -374,9 +360,6 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -412,9 +395,6 @@ def test_qat_8da4w_quantizer(self): ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -426,9 +406,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -487,9 +464,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. 
@@ -586,9 +560,6 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -655,9 +626,6 @@ def test_qat_4w_primitives(self): self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear @@ -693,18 +661,12 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer @@ -790,9 +752,6 @@ def test_composable_qat_quantizer(self): values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_embedding(self): from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, @@ -970,15 +929,14 @@ def test_fake_quantize_config_dtype(self): with self.assertRaisesRegex(ValueError, msg): IntxFakeQuantizeConfig(torch.float32, "per_token") # OK - if TORCH_VERSION_AT_LEAST_2_3: - IntxFakeQuantizeConfig(torch.uint1, "per_token") - IntxFakeQuantizeConfig(torch.uint2, "per_token") - IntxFakeQuantizeConfig(torch.uint3, "per_token") - IntxFakeQuantizeConfig(torch.uint4, "per_token") - IntxFakeQuantizeConfig(torch.uint5, "per_token") - IntxFakeQuantizeConfig(torch.uint6, "per_token") - IntxFakeQuantizeConfig(torch.uint7, "per_token") - IntxFakeQuantizeConfig(torch.uint8, "per_token") + IntxFakeQuantizeConfig(torch.uint1, "per_token") + IntxFakeQuantizeConfig(torch.uint2, "per_token") + IntxFakeQuantizeConfig(torch.uint3, "per_token") + IntxFakeQuantizeConfig(torch.uint4, "per_token") + IntxFakeQuantizeConfig(torch.uint5, "per_token") + IntxFakeQuantizeConfig(torch.uint6, "per_token") + IntxFakeQuantizeConfig(torch.uint7, "per_token") + IntxFakeQuantizeConfig(torch.uint8, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT1, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT2, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT3, "per_token") @@ -1003,9 +961,6 @@ def test_fake_quantize_config_dynamic_and_range_learning(self): torch.int8, "per_channel", is_dynamic=True, range_learning=True ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. 
@@ -1059,9 +1014,6 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -1108,9 +1060,6 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_8da4w(self): module = torch.nn.ModuleList( [ @@ -1130,9 +1079,6 @@ def test_replace_linear_8da4w(self): assert isinstance(module[0], Int8DynActInt4WeightQATLinear) assert isinstance(module[1], Int8DynActInt4WeightQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_int4(self): module = torch.nn.ModuleList( [torch.nn.Linear(in_features=256, out_features=50, bias=True)] @@ -1165,9 +1111,6 @@ def test_replace_linear_int4(self): ) assert isinstance(module[0], Int4WeightOnlyQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -1205,9 +1148,6 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. @@ -1261,9 +1201,6 @@ def test_qat_prototype_bc(self): Int8DynActInt4WeightQATQuantizer, ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_config_init(self): """ Test that the correct errors are thrown if `QATConfig` is not instantiated properly. 
@@ -1317,9 +1254,6 @@ def test_qat_config_init(self): ): QATConfig(fq_config, step="prepare") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_prepare(self): """ Test that the following: @@ -1368,9 +1302,6 @@ def test_quantize_api_prepare(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_errors(self): """ Test that we throw exceptions with helpful error messages if `quantize_` @@ -1390,9 +1321,6 @@ def test_quantize_api_errors(self): with self.assertRaisesRegex(ValueError, "does not have QAT support"): quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_e2e(self): """ Test that the following: @@ -1441,9 +1369,6 @@ def test_quantize_api_e2e(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantize_config_torch_intx(self): """ Test that `IntxFakeQuantizeConfig` works with torch.intx. @@ -1461,9 +1386,6 @@ def test_fake_quantize_config_torch_intx(self): out2 = linear2(*x2) torch.testing.assert_close(out1, out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantizer_repr(self): """ Test that `repr(IntxFakeQuantizer(config))` exposes useful config details. @@ -1476,9 +1398,6 @@ def test_fake_quantizer_repr(self): self.assertTrue("PerGroup" in fake_quantizer_repr) self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_linear_bias(self): """ Test that QAT supports linear bias. @@ -1494,9 +1413,6 @@ def test_qat_linear_bias(self): m(*example_inputs) @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): """ Test that the following produce the exact same numerics: @@ -1514,9 +1430,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): """ Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces @@ -1555,9 +1468,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): ) self.assertEqual(len(non_inf_sqnr), 0, fail_message) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_config_eps(self): """ Test that users can set arbitrary eps value in `IntxFakeQuantizeConfig`. 
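For context on the export-path change in the TestExport hunk near the top of this patch: with the pre-2.5 `capture_pre_autograd_graph` fallback gone, tests call `torch.export.export_for_training` directly. A minimal sketch of that pattern; the `model` and `example_inputs` below are hypothetical placeholders, not taken from the test suite:
```
import torch

# Hypothetical eager module and inputs; any model exportable by torch.export works here.
model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
example_inputs = (torch.randn(1, 16),)

# The same call the TestExport hunk switches to: export for training, then
# recover a callable nn.Module from the exported program.
model = torch.export.export_for_training(
    model, example_inputs, strict=True
).module()
out = model(*example_inputs)
```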
@@ -1584,9 +1494,6 @@ def test_fake_quantize_config_eps(self): actual_out = fake_quantizer(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_eps(self): """ Test that the 8da4w QAT flow uses the expected eps. @@ -1633,9 +1540,6 @@ def test_qat_8da4w_eps(self): actual_out = converted_model.linear1(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantizer_range_learning(self): """ Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly. @@ -1671,9 +1575,6 @@ def test_fake_quantizer_range_learning(self): self.assertTrue(fake_quantizer.zero_point.requires_grad) fake_quantizer(*example_inputs) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_range_learning(self): """ Test end-to-end QAT flow with range learning. @@ -1754,9 +1655,6 @@ def test_float8_rowwise_fake_quantize(self): ).to_original_precision() torch.testing.assert_close(out, out_expected, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_qat_fp8a4w_quantizer(self): """ Test basic model training with `Float8ActInt4WeightQATQuantizer`. @@ -1791,9 +1689,6 @@ def test_qat_fp8a4w_quantizer(self): self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) self.assertFalse(torch.equal(new_weight, prev_weight)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_legacy_quantize_api_e2e(self): """ Test that the following two APIs are numerically equivalent: @@ -1845,9 +1740,6 @@ def test_legacy_quantize_api_e2e(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_api_deprecation(self): """ Test that the appropriate deprecation warning is logged exactly once per class. 
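Where a test still needs a genuine version gate (for example the 2.7/2.8 checks this patch keeps), the string-based helper in `torchao.utils` expresses the same condition the deprecated boolean flags did. A minimal sketch, assuming `torch_version_at_least` keeps its current `min_version` string argument; the test class and method names are placeholders:
```
import unittest

import torch

from torchao.utils import torch_version_at_least


class TestVersionGated(unittest.TestCase):
    @unittest.skipIf(
        not torch_version_at_least("2.7.0"), "requires torch 2.7+"
    )
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_feature_added_in_2_7(self):
        # Placeholder body; the decorator above replaces a skip that used to
        # read the TORCH_VERSION_AT_LEAST_* booleans.
        self.assertTrue(torch_version_at_least("2.7.0"))
```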
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b9d99e7ac7..3b26cd25d6 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -66,10 +66,6 @@ from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, @@ -279,7 +275,6 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() @@ -308,9 +303,6 @@ def api(model): atol, rtol = (1e-2, 1e-2) if torch.version.hip else (None, None) torch.testing.assert_close(ref, res.cpu(), atol=atol, rtol=rtol) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -323,9 +315,6 @@ def test_8da4w_quantizer(self): assert isinstance(m.linear2, Int8DynActInt4WeightLinear) m(*example_inputs) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer_linear_bias(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -444,7 +433,6 @@ def test_eval_wrapper_llama3(self): ) # TODO: move to a separate test file - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize( "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] ) @@ -484,8 +472,6 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_quantized_tensor_subclass_int4(self): for device in self.GPU_DEVICES: @@ -512,7 +498,6 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -532,50 +517,6 @@ def test_quantized_tensor_subclass_int8_wo(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") - def test_quantized_tensor_subclass_int8_dyn_quant(self): - # use multiples of 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs( - batch_size=20, dtype=torch.bfloat16, device="cuda" - ) 
- quantize_(m, int8_dynamic_activation_int8_weight()) - - assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance( - m.linear1.weight.original_weight_tensor, AffineQuantizedTensor - ) - assert isinstance( - m.linear2.weight.original_weight_tensor, AffineQuantizedTensor - ) - - # reference - _ref_change_linear_weights_to_int8_dqtensors(m_copy) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - - # workaround for export path - from torchao.utils import unwrap_tensor_subclass - - m_unwrapped = unwrap_tensor_subclass(m) - - m = torch.export.export(m_unwrapped, example_inputs, strict=True).module() - exported_model_res = m(*example_inputs) - - self.assertTrue(torch.equal(exported_model_res, ref)) - - # make sure it compiles - torch._export.aot_compile(m_unwrapped, example_inputs) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -594,7 +535,6 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -608,25 +548,6 @@ def test_int8wo_quantized_model_to_device(self): cuda_res = m(*example_inputs_cuda) self.assertEqual(cuda_res.cpu(), ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") - def test_int4wo_quantized_model_to_device(self): - # TODO: change initial model to "cpu" - devices = ["cuda", "cuda:0"] - for device in devices: - m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - - quantize_(m, int4_weight_only()) - ref = m(*example_inputs) - - example_inputs_cuda = (example_inputs[0].to(device),) - m.to(device=device) - cuda_res = m(*example_inputs_cuda) - self.assertEqual(cuda_res.cpu(), ref) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load_map_location(self): m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda") @@ -648,7 +569,6 @@ def test_quantized_tensor_subclass_save_load_map_location(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_model_streaming(self): def reset_memory(): @@ -671,7 +591,6 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("use_hqq", [True, False]) @@ -698,7 +617,6 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert 
"aten.mm.default" not in code[0] # TODO(#1690): move to new config names - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( "config", @@ -795,7 +713,6 @@ def test_module_fqn_to_config_module_name(self): assert isinstance(model.linear2.weight, AffineQuantizedTensor) assert isinstance(model.linear2.weight._layout, PlainLayout) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+") def test_module_fqn_to_config_embedding_linear(self): weight_dtype = torch.int8 granularity = PerGroup(8) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 12027243a8..f3d265e14a 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -29,10 +29,6 @@ groupwise_affine_quantize_tensor_from_qparams, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, check_cpu_version, check_xpu_version, is_fbcode, @@ -132,11 +128,10 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) - if TORCH_VERSION_AT_LEAST_2_5: - if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) - if check_xpu_version(w.device): - w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) + if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if check_xpu_version(w.device): + w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) return w_int4x8 @@ -175,9 +170,6 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -264,34 +256,21 @@ def test_choose_qparams_group_sym_no_clipping_err(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 10) - if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - scale_dtype=torch.float64, - zero_point_dtype=torch.int64, - ) - else: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - ) - + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) scale_ref, zp_ref = ( torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( input, dtype @@ -347,9 +326,6 @@ def test_choose_qparams_tensor_sym(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 
2.4 or lower" - ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) @@ -380,17 +356,11 @@ def test_quantize_activation_per_token_abs_max(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) @@ -404,9 +374,6 @@ def test_quantize_activation_per_token_abs_max_dtype(self): quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) @@ -449,9 +416,6 @@ def test_quantize_dequantize_group_sym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) @@ -493,9 +457,6 @@ def test_quantize_dequantize_channel_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -535,9 +496,6 @@ def test_quantize_dequantize_tensor_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) @@ -578,9 +536,6 @@ def test_quantize_dequantize_channel_asym_4d(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC @@ -726,32 +681,22 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]: if zero_point_domain == ZeroPointDomain.INT: zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) - if TORCH_VERSION_AT_LEAST_2_5: - input_tmp = input - if (not (check_cpu_version(input.device))) and ( - not (check_xpu_version(input.device)) - ): - input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - if 
check_xpu_version(input.device): - input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain - ) - else: - if zero_point_domain == ZeroPointDomain.INT: - continue - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input, scales, zeros, n_bit, groupsize - ) + input_tmp = input + if (not (check_cpu_version(input.device))) and ( + not (check_xpu_version(input.device)) + ): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + if check_xpu_version(input.device): + input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain + ) w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( input, scales, zeros, n_bit, groupsize, zero_point_domain ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -785,9 +730,6 @@ def test_fake_quantize_affine(self): ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 804a585dd8..424306f897 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,7 +15,7 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode +from torchao.utils import is_fbcode class ToyModel(nn.Module): @@ -32,7 +32,6 @@ def forward(self, x): class TestRuntimeSemiStructuredSparsity(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") @@ -81,7 +80,6 @@ def test_runtime_weight_sparsification(self): for name, mod in model_c.named_modules(): assert not isinstance(mod, SemiSparseLinear) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 783de6c6ae..3cf310d71f 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -20,7 +20,6 @@ from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): @@ -58,7 +57,6 @@ def test_quant_sparse_marlin_layout_eager(self): "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): diff --git a/test/sparsity/test_sparse_api.py 
b/test/sparsity/test_sparse_api.py index 5e3086c411..30a063bf79 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -18,12 +18,6 @@ quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, -) logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -31,7 +25,6 @@ class TestSemiStructuredSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skip("Temporarily skipping to unpin nightlies") def test_sparse(self): @@ -59,7 +52,6 @@ def test_sparse(self): class TestQuantSemiSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [False]) @unittest.skip("Temporarily skip to unbreak CI") @@ -97,7 +89,6 @@ def test_quant_semi_sparse(self, compile): torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse_marlin(self, compile): @@ -132,10 +123,6 @@ def test_sparse_marlin(self, compile): class TestBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, - "pytorch 2.4+ feature due to need for custom op support", - ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("input_shape", [1, 1024]) @@ -170,7 +157,6 @@ def test_sparse(self, compile, input_shape): class TestQuantBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse(self, compile): diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py index 00c30b919a..64df37ac88 100644 --- a/test/test_low_bit_optim.py +++ b/test/test_low_bit_optim.py @@ -41,7 +41,6 @@ from torchao.optim.subclass_fp8 import OptimStateFp8 from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7, get_available_devices, ) @@ -222,8 +221,6 @@ def test_param_groups(self, optim_name, device): @parametrize("device", _DEVICES) def test_subclass_slice(self, subclass, shape, device): if subclass == OptimStateFp8: - if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5") if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): pytest.skip("FP8 CUDA requires compute capability >= 8.9") @@ -469,9 +466,6 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return _FSDP_WORLD_SIZE - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." 
- ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_fsdp2(self): @@ -587,9 +581,6 @@ def _test_fsdp2(self, args): v2 = v2.dequantize() self.assertEqual(v1, v2) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." - ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_uneven_shard(self): diff --git a/test/test_ops.py b/test/test_ops.py index faec689a69..bc9fe0e4f9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -28,7 +28,6 @@ ) from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7, compute_max_diff, ) @@ -281,25 +280,21 @@ def make_test_id(param): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) assert torch.equal(t, unpacked) # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): test_utils = [ @@ -308,13 +303,10 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) opcheck( @@ -345,7 +337,6 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -413,7 +404,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ 
-438,8 +428,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( # Unpack and dequantize unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) dq_ao = groupwise_affine_dequantize_tensor_from_qparams( unpacked, scales, zeros, n_bit=4, groupsize=group_size @@ -479,7 +468,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -488,8 +476,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size device = "cuda" q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - if TORCH_VERSION_AT_LEAST_2_5: - q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) + q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) @@ -501,9 +488,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size "test_autograd_registration", "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, (packed_w, scales_and_zeros, group_size, inner_k_tiles), @@ -766,9 +751,7 @@ def test_swizzle_mm(): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda") mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda") diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 4b761ad725..5d680bcf82 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -12,37 +12,17 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_channel_group to mitigate availability issue until it can be supplanted by new quantize_affine function. - - torch.ops.quantized_decomposed.quantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. 
- - torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs ) @@ -50,50 +30,21 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_channel_group to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.quantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." 
- ) + return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index cc4e439a49..57b67ab16e 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -28,7 +28,6 @@ quantize_, uintx_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass def run_evaluation( @@ -151,9 +150,6 @@ def run_evaluation( model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) quantizer.quantize(model, *inputs) model = model.to(device) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) if "float8wo" in quantization: quantize_(model, float8_weight_only()) if "float8dq" in quantization: @@ -239,11 +235,6 @@ def run_evaluation( ) elif quantization.startswith("awq-uintx"): from torchao._models._eval import TransformerEvalWrapper - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() from torchao.prototype.awq import ( AWQObservedLinear, awq_uintx, diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 8f02e83a99..0a18e41c39 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -20,11 +20,7 @@ write_json_result_ossci, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - get_model_size_in_bytes, -) +from torchao.utils import get_model_size_in_bytes torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False torch.backends.cuda.enable_cudnn_sdp(True) @@ -356,7 +352,6 @@ def ffn_or_attn_only(mod, fqn): uintx_weight_only, ) from torchao.quantization.granularity import PerRow, PerTensor - from torchao.utils import unwrap_tensor_subclass if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant @@ -505,11 +500,6 @@ def ffn_or_attn_only(mod, fqn): ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() from torchao.prototype.awq import ( AWQObservedLinear, awq_uintx, @@ -567,9 +557,6 @@ def ffn_or_attn_only(mod, fqn): group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "int8_dynamic_activation_intx_weight requires torch2.6+" - ) assert precision == torch.float32, ( "int8_dynamic_activation_intx_weight requires using precision=torch.float32" ) @@ -829,10 +816,6 @@ def ffn_or_attn_only(mod, fqn): model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # standalone sparsity elif sparsity: from torchao.sparsity import semi_sparse_weight, sparsify_ diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index a0410fb734..97bb04ef8b 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -28,7 +28,6 @@ quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass torch._dynamo.config.cache_size_limit = 50000 @@ -364,10 +363,6 @@ def mlp_only(mod, name): if compress == 
"int8_dynamic_quant": quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress == "sparse_mlp_only": def mlp_only(mod, name): @@ -395,10 +390,6 @@ def mlp_only(mod, name): mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress == "int4_weight_only_sparse": # apply sparsify first to set qparams apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) @@ -415,10 +406,6 @@ def mlp_only(mod, name): mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress is not None and "autoquant_v2" in compress: example_input = torch.randn( diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index f4386e43ad..63e0dcc562 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -35,10 +35,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor logger = logging.getLogger(__name__) aten = torch.ops.aten @@ -613,6 +610,5 @@ def _apply_fn_to_data(self, fn): # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([AffineQuantizedTensor]) +# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([AffineQuantizedTensor]) diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py index 85f83bcb50..6f007c9339 100644 --- a/torchao/dtypes/fbgemm_fp8_tensor.py +++ b/torchao/dtypes/fbgemm_fp8_tensor.py @@ -11,7 +11,6 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults, ) @@ -265,6 +264,5 @@ def _(func, types, args, kwargs): to_fbgemm_fp8 = FbgemmFp8Tensor.from_float -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([FbgemmFp8Tensor]) +# Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([FbgemmFp8Tensor]) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 4764e8b69b..5542a9de58 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -15,8 +15,6 @@ from torch._prims_common import make_contiguous_strides_for from torch.distributed.device_mesh import DeviceMesh -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -1156,6 +1154,5 @@ def nf4_constructor( ) -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([NF4Tensor]) - torch.serialization.add_safe_globals([NF4Tensor]) +torch.serialization.add_safe_globals([NF4Tensor]) 
+torch.serialization.add_safe_globals([NF4Tensor]) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index da19bbc259..cd09eec452 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -21,11 +21,7 @@ ZeroPointDomain, _quantize_affine_tinygemm, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - fill_defaults, -) +from torchao.utils import fill_defaults aten = torch.ops.aten @@ -114,29 +110,13 @@ def from_plain( ): assert isinstance(_layout, Int4CPULayout) - if TORCH_VERSION_AT_LEAST_2_6: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( - int_data, - 1, # TODO:remove - ) - elif TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - else: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + ) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) @@ -284,8 +264,7 @@ def _is_float(dtype): def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): return ( - TORCH_VERSION_AT_LEAST_2_6 - and is_device(input_tensor.device.type, "cpu") + is_device(input_tensor.device.type, "cpu") and is_device(weight_tensor.device.type, "cpu") and (bias is None or is_device(bias.device.type, "cpu")) and not is_traceable_wrapper_subclass(input_tensor) @@ -300,9 +279,6 @@ def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_6, ( - f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" - ) assert is_device(input_tensor.device.type, "cpu"), ( f"For CPU device only but got: {input_tensor.device}" ) diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index dc7b073f32..fb75f3380b 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -19,7 +19,6 @@ _DTYPE_TO_QVALUE_BOUNDS, ZeroPointDomain, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -170,9 +169,6 @@ def from_plain( if layout.target != Target.ATEN: _check_torchao_ops_loaded() else: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "aten target is requires torch version > 2.6.0" - ) assert torch.backends.kleidiai.is_available(), ( "ATEN target requires torch.backends.kleidiai.is_available()" ) @@ -378,7 +374,6 @@ def _impl_2d_aten(input_tensor, weight_tensor): ) if target == Target.ATEN: - assert 
TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" _impl_2d = _impl_2d_aten else: _impl_2d = _impl_2d_non_aten @@ -420,11 +415,6 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor( Constructs an AffineQuantizedTensor with PackedLinearInt8DynamicActivationIntxWeightLayout from plain data. """ - # TORCH_VERSION_AT_LEAST_2_6 is needed for torch.intx with x < 8 - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0" - ) - layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target) bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype] diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 591d9a9be1..992294b766 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -24,7 +24,6 @@ _quantize_affine_tinygemm, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple, ) @@ -274,14 +273,9 @@ def from_plain( ) def quant_2d(int_data_2d): - if TORCH_VERSION_AT_LEAST_2_5: - int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( - torch.uint8 - ) - else: - assert int_data_2d.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) return torch.ops.aten._convert_weight_to_int4pack( int_data_2d.contiguous(), _layout.inner_k_tiles ) diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 96e5401de5..3180e9f2c9 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -14,7 +14,7 @@ from torchao.dtypes.utils import ( Layout, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor from .bitpacking import pack, unpack @@ -24,20 +24,17 @@ _DTYPE_TO_BIT_WIDTH = {} _BIT_WIDTH_TO_DTYPE = {} -if TORCH_VERSION_AT_LEAST_2_3: - _DTYPE_TO_BIT_WIDTH = { - torch.uint1: 1, - torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - - _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} -else: - print("uintx feature requires torch 2.3+, please upgrade pytorch") +_DTYPE_TO_BIT_WIDTH = { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, +} + +_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} class UintxTensor(TorchAOBaseTensor): diff --git a/torchao/float8/README.md b/torchao/float8/README.md index ede3f66b3d..8856c9140a 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -27,10 +27,6 @@ import time import torch import torch.nn as nn from torchao.float8 import convert_to_float8_training, Float8LinearConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input M, K, N = 4096, 8192, 4096 @@ -232,10 +228,6 @@ import torch.nn.functional as F from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise 
AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 170d0ddd81..04589312a2 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -1,4 +1,7 @@ # Lets define a few top level things here +# Needed to load Float8TrainingTensor with weights_only = True +from torch.serialization import add_safe_globals + from torchao.float8.config import ( CastConfig, Float8GemmConfig, @@ -19,22 +22,17 @@ from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torchao.float8.inference import Float8MMConfig from torchao.float8.types import FP8Granularity -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if TORCH_VERSION_AT_LEAST_2_5: - # Needed to load Float8TrainingTensor with weights_only = True - from torch.serialization import add_safe_globals - add_safe_globals( - [ - Float8TrainingTensor, - ScaledMMConfig, - GemmInputRole, - LinearMMConfig, - Float8MMConfig, - ScalingGranularity, - ] - ) +add_safe_globals( + [ + Float8TrainingTensor, + ScaledMMConfig, + GemmInputRole, + LinearMMConfig, + Float8MMConfig, + ScalingGranularity, + ] +) __all__ = [ # configuration diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py index 18cfba9ad9..4d80c4c577 100644 --- a/torchao/kernel/bsr_triton_ops.py +++ b/torchao/kernel/bsr_triton_ops.py @@ -9,15 +9,7 @@ from typing import Optional import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if TORCH_VERSION_AT_LEAST_2_4: - from torch._dynamo.utils import warn_once -else: - import warnings - - warn_once = warnings.warn +from torch._dynamo.utils import warn_once from torch.sparse._triton_ops import ( broadcast_batch_dims, launch_kernel, diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 2f064b3f2f..2fd9854c17 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -8,17 +8,13 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, check_cpu_version +from torchao.utils import check_cpu_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) try: - # Only works for torch2.2 or newer. - if TORCH_VERSION_AT_LEAST_2_2: - from torchao.kernel import intmm_triton - else: - intmm_triton = None + from torchao.kernel import intmm_triton except ImportError: logger.warning( "Warning: Detected no triton, on systems without Triton certain kernels will not work" @@ -28,85 +24,66 @@ AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0))) -# torch._int_mm doesn't exist before 2.2 -if TORCH_VERSION_AT_LEAST_2_2: - from torch._dynamo import is_compiling as dynamo_is_compiling - from torch._higher_order_ops.out_dtype import out_dtype - - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a safe integer matrix multiplication, considering different paths for - torch.compile, cublas, and fallback cases. - - Args: - input (torch.Tensor): The input tensor of shape [i, j]. - mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. - - Returns: - torch.Tensor: The result of the matrix multiplication. - - Raises: - AssertionError: If the tensors are not on the same device. 
- """ - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - if input.device.type == "cpu": - # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend - return out_dtype( - torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() - ) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert mat2.device == input.device, ( - f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - ) - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 - ) +from torch._dynamo import is_compiling as dynamo_is_compiling +from torch._higher_order_ops.out_dtype import out_dtype + + +def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + """ + Performs a safe integer matrix multiplication, considering different paths for + torch.compile, cublas, and fallback cases. + + Args: + input (torch.Tensor): The input tensor of shape [i, j]. + mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul( - input.cpu().to(torch.int32), mat2.cpu().to(torch.int32) - ).to(input.device.type) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - try: - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' - return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( - torch.int32 + Returns: + torch.Tensor: The result of the matrix multiplication. + + Raises: + AssertionError: If the tensors are not on the same device. + """ + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + if input.device.type == "cpu": + # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend + return out_dtype( + torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() ) -else: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a fallback integer matrix multiplication for torch versions before 2.2. + # error checking for cublas path + assert mat2.device == input.device, ( + f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + ) + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 + ) - Args: - input (torch.Tensor): The input tensor of shape [i, j]. 
- mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) - Returns: - torch.Tensor: The result of the matrix multiplication in int32. - """ - # We can improve on this by writing Triton code that works for older versions of Triton - # that ship with 2.1 or 2.0. + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + try: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + except Exception: + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( torch.int32 ) diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 1a516a7163..6f657cdfd8 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -10,7 +10,6 @@ import triton.language as tl from torchao.kernel.autotuner import get_best_config_fn -from torchao.utils import TORCH_VERSION_AFTER_2_5 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option int8_mm_kernel_configs = sum( @@ -38,16 +37,15 @@ [], ) -if TORCH_VERSION_AFTER_2_5: - if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": - int8_mm_kernel_configs = [ - (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) - for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( - [16, 32, 64, 128, 256], repeat=3 - ) - for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] - for num_warps in [2, 4, 8] - ] +if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": + int8_mm_kernel_configs = [ + (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] + for num_warps in [2, 4, 8] + ] # Baseline configs from pytorch/pytorch diff --git a/torchao/ops.py b/torchao/ops.py index babe5506c0..4b643cae98 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -9,8 +9,6 @@ import torch from torch import Tensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") lib.define( "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" @@ -74,20 +72,14 @@ def register_custom_op(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - return torch.library.impl(f"{name}", "CUDA")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator diff --git a/torchao/optim/cpu_offload.py b/torchao/optim/cpu_offload.py index cca55749db..53acd4057f 100644 --- a/torchao/optim/cpu_offload.py +++ b/torchao/optim/cpu_offload.py @@ -8,7 +8,7 @@ import torch from 
torch.optim.optimizer import Optimizer, ParamsT -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices +from torchao.utils import get_available_devices # NOTE: We make this inherit Optimizer so it works with PyTorch's built-in LR @@ -36,11 +36,7 @@ def __init__( kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. """ # default to fused CPU AdamW - if ( - optimizer_class is torch.optim.AdamW - and TORCH_VERSION_AT_LEAST_2_4 - and "fused" not in kwargs - ): + if optimizer_class is torch.optim.AdamW and "fused" not in kwargs: kwargs.update(fused=True) param_groups = list(params) diff --git a/torchao/optim/subclass_4bit.py b/torchao/optim/subclass_4bit.py index bc5fd33414..82bb6a3788 100644 --- a/torchao/optim/subclass_4bit.py +++ b/torchao/optim/subclass_4bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -113,25 +110,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. -if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState4bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - self.shape, - ) - - OptimState4bit.to = _to - del _to # make sure to not re-use - - @OptimState4bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -268,7 +246,4 @@ def _(func, types, args, kwargs): return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState4bit]) +add_safe_globals([OptimState4bit]) diff --git a/torchao/optim/subclass_8bit.py b/torchao/optim/subclass_8bit.py index d3f7634526..bbc6cfa958 100644 --- a/torchao/optim/subclass_8bit.py +++ b/torchao/optim/subclass_8bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -101,24 +98,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. 
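For context on the CPUOffloadOptimizer hunk above, a small sketch of the behavior it now assumes without a version gate: fused AdamW on CPU parameters, which recent PyTorch provides. The parameter below is a throwaway example, not torchao code.

```python
# Minimal sketch: fused AdamW on a CPU parameter, the default that
# CPUOffloadOptimizer sets unless the caller passes `fused` explicitly.
import torch

params = [torch.nn.Parameter(torch.randn(16, 16))]
opt = torch.optim.AdamW(params, lr=1e-3, fused=True)

loss = (params[0] ** 2).sum()
loss.backward()
opt.step()
opt.zero_grad()
```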
-if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState8bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - ) - - OptimState8bit.to = _to - del _to # make sure to not re-use - - @OptimState8bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -237,7 +216,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState8bit]) +add_safe_globals([OptimState8bit]) diff --git a/torchao/optim/subclass_fp8.py b/torchao/optim/subclass_fp8.py index 1ae670dd6d..e898932138 100644 --- a/torchao/optim/subclass_fp8.py +++ b/torchao/optim/subclass_fp8.py @@ -7,9 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -192,7 +193,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimStateFp8]) +add_safe_globals([OptimStateFp8]) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 16c1736843..04864e546a 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -12,7 +12,6 @@ import torchao import torchao.prototype.autoround.utils as ar_utils import torchao.quantization -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 logger = logging.getLogger(__name__) @@ -165,7 +164,7 @@ def main(args): bench_accuracy(model, tokenizer, tasks=args.tasks, msg=msg) -if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) diff --git a/torchao/prototype/float8nocompile/examples/example.py b/torchao/prototype/float8nocompile/examples/example.py index 97d42eee90..1351e2c938 100644 --- a/torchao/prototype/float8nocompile/examples/example.py +++ b/torchao/prototype/float8nocompile/examples/example.py @@ -9,10 +9,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = ( diff --git a/torchao/prototype/float8nocompile/test/fsdp_test.py b/torchao/prototype/float8nocompile/test/fsdp_test.py index 4e73fb9b97..375e48311d 100644 --- a/torchao/prototype/float8nocompile/test/fsdp_test.py +++ b/torchao/prototype/float8nocompile/test/fsdp_test.py @@ -22,10 +22,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py index 3f2ee47cd7..aceca5b400 
100644 --- a/torchao/prototype/float8nocompile/test/train_test.py +++ b/torchao/prototype/float8nocompile/test/train_test.py @@ -11,10 +11,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index f15c9a8104..8f049b431b 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -17,7 +17,7 @@ from torch import Tensor, nn from torchao.dtypes.utils import is_device -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, check_cpu_version +from torchao.utils import check_cpu_version class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -209,9 +209,8 @@ def hqq_quants_to_torch_quants( .reshape(shape) .contiguous() ) - if TORCH_VERSION_AT_LEAST_2_5: - if not is_device(W_q.device.type, "cpu"): - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 133cedee74..96c4c6c73b 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -25,7 +25,6 @@ register_quantize_module_handler, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100, ) @@ -213,16 +212,15 @@ def _nvfp4_inference_linear_transform( return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - MXTensor, - NVFP4Tensor, - NVFP4MMConfig, - MXGemmKernelChoice, - _input_activation_quant_func_mxfp, - ] - ) +torch.serialization.add_safe_globals( + [ + MXTensor, + NVFP4Tensor, + NVFP4MMConfig, + MXGemmKernelChoice, + _input_activation_quant_func_mxfp, + ] +) import torch.nn as nn diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index f506681223..cabb61276a 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -17,7 +17,6 @@ _floatx_unpacked_to_f32, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100, ) @@ -25,7 +24,7 @@ # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert # at the callsite prevents usage of this on unsupported versions. 
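The kernels.py hunks that follow keep registering the `ao::` ops through `torch.library.custom_op` plus `register_fake`, now without a version gate. A standalone sketch of that registration pattern, using a hypothetical op name rather than one of the real kernels:

```python
# Sketch of the custom_op / register_fake pattern (hypothetical op "ao::toy_scale"):
# the fake implementation only describes output metadata, so the op can be
# traced with meta/fake tensors under torch.compile without running the kernel.
import torch


@torch.library.custom_op("ao::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@toy_scale.register_fake
def _(x, factor):
    return torch.empty_like(x)


out = toy_scale(torch.randn(4), 2.0)
```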
-if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): +if has_triton(): from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( @@ -752,7 +751,6 @@ def triton_f4_to_scaled_bf16( Output: a tensor of bfloat16 values, multiplied by the encoded scale """ s_e8m0 = s_e8m0.view(torch.uint8) - assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" new_shape = (*x.shape[:-1], x.shape[-1] * 2) output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) assert x.is_contiguous() @@ -855,119 +853,104 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor: return output -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) +@torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) +def triton_f6_e2m3_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 + packed_mx_block_size = 3 * mx_block_size // 4 - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E2M3, - mbits_f6=MBITS_F6_E2M3, - f6_exp_bias=F6_E2M3_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E2M3, + mbits_f6=MBITS_F6_E2M3, + f6_exp_bias=F6_E2M3_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. 
- Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 +@torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) +def triton_f6_e3m2_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + packed_mx_block_size = 3 * mx_block_size // 4 - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - n_mx_blocks = x.numel() // packed_mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E3M2, - mbits_f6=MBITS_F6_E3M2, - f6_exp_bias=F6_E3M2_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - @triton_f6_e3m2_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) + n_mx_blocks = x.numel() // packed_mx_block_size + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E3M2, + mbits_f6=MBITS_F6_E3M2, + f6_exp_bias=F6_E3M2_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @triton_f6_e2m3_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) -else: +@triton_f6_e3m2_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") +@triton_f6_e2m3_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, 
dtype=torch.bfloat16) # pack/unpack code copy-pasted from @@ -1049,48 +1032,42 @@ def pack_uint6_pytorch(uint8_data: torch.Tensor) -> torch.Tensor: ).view(packed_shape) -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::pack_uint6", mutates_args=()) - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # ensure input data is contiguous before passing to kernel - assert uint8_data.is_contiguous() +@torch.library.custom_op("ao::pack_uint6", mutates_args=()) +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + # ensure input data is contiguous before passing to kernel + assert uint8_data.is_contiguous() - # tensor should already be of shape [..., mx_block_size] - mx_block_size = uint8_data.shape[-1] - assert mx_block_size % 4 == 0 + # tensor should already be of shape [..., mx_block_size] + mx_block_size = uint8_data.shape[-1] + assert mx_block_size % 4 == 0 - # effective mx block size since we're packing 2 fp4 into 1 uint8 - packed_mx_block_size = 3 * mx_block_size // 4 - packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] - n_mx_blocks = uint8_data.numel() // mx_block_size + # effective mx block size since we're packing 2 fp4 into 1 uint8 + packed_mx_block_size = 3 * mx_block_size // 4 + packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] + n_mx_blocks = uint8_data.numel() // mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - # contiguous uint8 container in which we can store the unpacked tensor - packed_uint8_data = torch.empty( - packed_shape, dtype=torch.uint8, device=uint8_data.device - ) + # contiguous uint8 container in which we can store the unpacked tensor + packed_uint8_data = torch.empty( + packed_shape, dtype=torch.uint8, device=uint8_data.device + ) - triton_pack_uint6_kernel[grid]( - uint8_data, - packed_uint8_data, - n_mx_blocks, - MX_BLOCK_SIZE=mx_block_size, - PACKED_MX_BLOCK_SIZE=packed_mx_block_size, - ) + triton_pack_uint6_kernel[grid]( + uint8_data, + packed_uint8_data, + n_mx_blocks, + MX_BLOCK_SIZE=mx_block_size, + PACKED_MX_BLOCK_SIZE=packed_mx_block_size, + ) - return packed_uint8_data + return packed_uint8_data - @pack_uint6.register_fake - def _(uint8_data): - out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) - return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) -else: - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # Dummy placeholder op for torch < 2.4 - raise AssertionError("fp6 packing unsupported without torch >= 2.4") +@pack_uint6.register_fake +def _(uint8_data): + out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) + return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) if TORCH_VERSION_AT_LEAST_2_7 and has_triton(): diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index 9ddfddda08..1240bbacd0 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -47,8 +47,6 @@ ) from torchao.quantization.utils import _quantize_activation_per_token_absmax from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, @@ -469,6 +467,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep 
= kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -483,24 +483,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) return res diff --git a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py index c2e995e942..a15ea944fd 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py +++ b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py @@ -9,10 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten @@ -231,6 +228,5 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int8DynamicActivationLutTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int8DynamicActivationLutTensor]) +# Allow a model with Int8DynamicActivationLutTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int8DynamicActivationLutTensor]) diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index c1272fceb6..f26083b90d 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -14,10 +14,7 @@ _dequantize_gguf, _quantize_gguf, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor _QK_K = 256 aten = torch.ops.aten @@ -267,6 +264,5 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([GGUFQuantizedTensor]) +# Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([GGUFQuantizedTensor]) diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py index 515a38ad83..0b276a0d03 100644 --- a/torchao/prototype/spinquant/hadamard_utils.py +++ b/torchao/prototype/spinquant/hadamard_utils.py @@ -11,7 +11,6 @@ import torch -from torchao.ops import lib from torchao.prototype.spinquant._hadamard_matrices import ( get_had12, get_had20, @@ -26,7 +25,6 @@ get_had156, get_had172, ) -from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_4 try: from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform @@ -50,21 +48,14 @@ def matmul_hadU(X, hadK, K): def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor") - return torch.library.impl(f"{name}", "cuda")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator def register_custom_op_abstract(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 47ecb9aabe..fa0293bf82 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -304,12 +304,6 @@ quantize_(m, Int4WeightOnlyConfig(group_size=group_size)) ## If different zero_point_domain needed # quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT)) -# temporary workaround for tensor subclass + torch.compile -# NOTE: this is only need for torch version < 2.5+ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -from torchao.utils import unwrap_tensor_subclass -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(m) # compile the model to improve performance m = torch.compile(m, mode='max-autotune') diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index cf3fbad6ad..1fe30a59d1 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -31,8 +31,6 @@ compute_error, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, @@ -329,6 +327,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -343,24 +343,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - res = benchmarker.benchmark_gpu( lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) return res @@ -1346,12 +1331,11 @@ def finalize_autoquant(): return model -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) - torch.serialization.add_safe_globals( - [ - _to_float16, - _to_bfloat16, - _identity, - ] - ) +torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) +torch.serialization.add_safe_globals( + [ + _to_float16, + _to_bfloat16, + _identity, + ] +) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 
658b172994..cbeb9cdb6f 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationQuantizedTensor", @@ -290,6 +287,5 @@ def _(func, types, args, kwargs): to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float # Converts a float tensor to LinearActivationQuantizedTensor for dynamic activation quantization -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 005bc8d32d..500228cf3c 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -6,10 +6,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationScaleMetadata", @@ -119,8 +116,5 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationScaleMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationScaleMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([WeightTensorWithLinearActivationScaleMetadata]) diff --git a/torchao/quantization/linear_activation_weight_observed_tensor.py b/torchao/quantization/linear_activation_weight_observed_tensor.py index 029b89e54b..d17bc382db 100644 --- a/torchao/quantization/linear_activation_weight_observed_tensor.py +++ b/torchao/quantization/linear_activation_weight_observed_tensor.py @@ -9,10 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.observer import AffineQuantizedObserverBase -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationWeightObservedTensor", @@ -153,6 +150,5 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) diff --git a/torchao/quantization/linear_quant_modules.py b/torchao/quantization/linear_quant_modules.py index 73e95036f1..de6755a55d 100644 --- a/torchao/quantization/linear_quant_modules.py +++ b/torchao/quantization/linear_quant_modules.py @@ -16,10 +16,7 @@ import torch.nn.functional as F from torchao.dtypes.utils import is_device 
-from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - find_multiple, -) +from torchao.utils import find_multiple from .quant_primitives import ( MappingType, @@ -60,7 +57,7 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if is_device(x.device.type, "cpu"): c = torch.ops.aten._weight_int4pack_mm_for_cpu( x.to(precision), weight_int4pack, @@ -299,10 +296,7 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - if ( - is_device(w_int4x8.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(w_int4x8.device.type, "cpu"): weight_int4pack = ( torch.ops.aten._convert_weight_to_int4pack_for_cpu( w_int4x8.to(self.device), self.inner_k_tiles diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 6084da6e8d..6d928a4477 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -11,7 +11,6 @@ import torch from torchao.quantization.quant_primitives import _fake_quantize_affine -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from .granularity import ( Granularity, @@ -373,6 +372,5 @@ def calculate_qparams(self): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([PerRow, PerTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([PerRow, PerTensor]) diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index 0346981391..5211e0f340 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -14,13 +14,9 @@ from torch.ao.ns.fx.utils import compute_sqnr from torch.export import ExportedProgram from torch.fx import GraphModule, Node +from torch.fx.traceback import NodeSource from torch.nn import functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource - from .graph_utils import bfs_trace_with_node_process NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" @@ -262,12 +258,6 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: Returns: a model with output loggers for all unlifted nodes """ - if not TORCH_VERSION_AT_LEAST_2_6: - log.warning( - "prepare_for_propagation_comparison is only supported for PyTorch 2.6+" - ) - return model - # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: diff --git a/torchao/quantization/pt2e/constant_fold.py b/torchao/quantization/pt2e/constant_fold.py index 27f82e6757..365eb0a77a 100644 --- a/torchao/quantization/pt2e/constant_fold.py +++ b/torchao/quantization/pt2e/constant_fold.py @@ -12,8 +12,6 @@ from torch._inductor.freezing_utils import maybe_set_is_frozen_param from torch.utils._ordered_set import OrderedSet -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten # We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
@@ -162,13 +160,9 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.convert_element_type.no_fuse, + torch.ops.torchao.dequantize_affine, ] - if TORCH_VERSION_AT_LEAST_2_5: - DEQUANT_OPS += [ - torch.ops.torchao.dequantize_affine, - ] - if node.target in DEQUANT_OPS: # For the pattern fp32_weight -> q -> dq # We only folding fp32_weight -> q diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py index 99516ac4c3..3728d7c252 100644 --- a/torchao/quantization/pt2e/convert.py +++ b/torchao/quantization/pt2e/convert.py @@ -69,14 +69,11 @@ from torch.fx import GraphModule from torch.fx.graph import Argument, Graph, Node from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY +from torch.fx.traceback import NodeSource, NodeSourceAction from torch.nn.utils.parametrize import type_before_parametrizations from torchao.quantization.pt2e import FROM_NODE_KEY from torchao.quantization.pt2e.observer import _is_activation_post_process -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource, NodeSourceAction __all__ = [ "convert", @@ -188,8 +185,6 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): def add_quantize_dequantize_node_info(qdq_node, original_node): # propagate from_node info from observer/fake_quant node to quantize/dequantize node - if not TORCH_VERSION_AT_LEAST_2_6: - return qdq_node.meta[FROM_NODE_KEY] = [ NodeSource( original_node, diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index 4115040669..60962f8d41 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -1877,13 +1877,6 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): observer_node: the observer node to convert """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise NotImplementedError( - "convert for AffineQuantization is not implemented for pytorch version earlier than 2.5, please upgrade your pytorch to 2.5+." - ) - from torchao.quantization.pt2e.utils import create_getattr_from_value with model.graph.inserting_before(observer_node): diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py index d8f5b99fc5..a1d57062f2 100644 --- a/torchao/quantization/pt2e/prepare.py +++ b/torchao/quantization/pt2e/prepare.py @@ -38,7 +38,6 @@ SharedQuantizationSpec, ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # TODO: make pt2e folder private? 
__all__ = [ @@ -553,7 +552,6 @@ def _maybe_insert_output_observer_for_node( isinstance(node, Node) and isinstance(new_output, Node) and FROM_NODE_KEY in node.meta - and TORCH_VERSION_AT_LEAST_2_6 ): new_output.meta[FROM_NODE_KEY] = node.meta[FROM_NODE_KEY] return new_output diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py index 5eb385b7de..e58dc8e3ee 100644 --- a/torchao/quantization/pt2e/quantize_pt2e.py +++ b/torchao/quantization/pt2e/quantize_pt2e.py @@ -6,7 +6,7 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 if TORCH_VERSION_AT_LEAST_2_7: from .constant_fold import constant_fold @@ -217,14 +217,9 @@ def train_loop(model, train_data): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANT_OPS += [ - torch.ops.torchao.quantize_affine, - ] - def _quant_node_constraint(n: Node) -> bool: """If there is any pure ops between get_attr and quantize op they will be const propagated diff --git a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py index bef93a19fc..5e7e9344ee 100644 --- a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py +++ b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py @@ -15,7 +15,6 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY from torchao.quantization.pt2e.utils import _filter_sym_size_users from torchao.quantization.quant_primitives import quant_lib # noqa: F401 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from .quantizer import QuantizationSpecBase from .utils import is_valid_annotation @@ -34,27 +33,23 @@ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine, ] _CHOOSE_QPARAMS_OPS = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.torchao.choose_qparams_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANTIZE_OPS += [torch.ops.torchao.quantize_affine] - _DEQUANTIZE_OPS += [torch.ops.torchao.dequantize_affine] - _CHOOSE_QPARAMS_OPS += [torch.ops.torchao.choose_qparams_affine] - - def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: from_meta = from_node.meta for meta_name in _METADATA_TO_PORT: diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 59e759dab3..f94ec6f272 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -25,7 +25,6 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .fake_quantize_config import ( FakeQuantizeConfigBase, @@ -471,10 +470,7 @@ def _convert_qat_linear_4w(self, module: 
torch.nn.Module): n_bit, config.group_size, ) - if ( - is_device(q_weight.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(q_weight.device.type, "cpu"): q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( q_weight.to(child.weight.device), child.inner_k_tiles, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 5d79563ab1..4e6bf7fa41 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -86,9 +86,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, _is_fbgemm_genai_gpu_available, is_MI300, is_sm_at_least_89, @@ -182,16 +179,16 @@ def _in_features_greater_than_16(mod, *args): return hasattr(mod, "in_features") and mod.in_features > 16 +# TODO: delete def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass, effectively applying the same form of quantization as apply_dynamic_quant while not modifying the linear modules. """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) if filter_fn is None: filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( @@ -207,6 +204,7 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): ) +# TODO: delete def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the @@ -214,10 +212,9 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): effectively applying the same form of quantization as apply_weight_only_int8_quant while not modifying the linear modules. 
""" - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) _replace_with_custom_fn_if_matches_filter( model, @@ -228,6 +225,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): ) +# TODO: delete def change_linear_weights_to_int4_woqtensors( model, groupsize=128, @@ -251,10 +249,9 @@ def change_linear_weights_to_int4_woqtensors( ZeroPointDomain.INT, ZeroPointDomain.NONE] `preserve_zero`: whether to preserve zero, default is False """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) if filter_fn is None: filter_fn = _is_linear @@ -655,20 +652,15 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: scale_dtype = torch.float32 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int8 - if TORCH_VERSION_AT_LEAST_2_6: - return to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - return to_affine_quantized_intx( - x, mapping_type, _get_per_token_block_size(x), target_dtype - ) + return to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -679,27 +671,17 @@ def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: zero_point_dtype = torch.int32 quant_min = 0 quant_max = 255 - if TORCH_VERSION_AT_LEAST_2_6: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - ) + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) return out @@ -832,7 +814,6 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): args: weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 weight_granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(axis=0). weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. 
MappingType.SYMMETRIC requires ZeroPointDomain.NONE @@ -854,9 +835,6 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): layout: Layout = QDQLayout() def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Int8DynamicActivationIntxWeightConfig requires torch 2.6+" - ) assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" ) @@ -2045,7 +2023,6 @@ class IntxWeightOnlyConfig(AOBaseConfig): manner using the number of bits specified by weight_dtype. args: weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0). mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. @@ -2062,7 +2039,6 @@ class IntxWeightOnlyConfig(AOBaseConfig): layout: Layout = QDQLayout() def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+" assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" ) @@ -2286,16 +2262,15 @@ def _module_fqn_to_config_handler( return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - _int8_asymm_per_token_quant, - _int8_symm_per_token_reduced_range_quant, - _input_activation_quant_func_fp8, - _int4_symm_cutlass_quant, - _int8_symm_cutlass_quant, - _float8_cutlass_quant, - _float8_cutlass_quant_sparse, - Target, - ] - ) +torch.serialization.add_safe_globals( + [ + _int8_asymm_per_token_quant, + _int8_symm_per_token_reduced_range_quant, + _input_activation_quant_func_fp8, + _int4_symm_cutlass_quant, + _int8_symm_cutlass_quant, + _float8_cutlass_quant, + _float8_cutlass_quant_sparse, + Target, + ] +) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index a91c3acd28..ebd2c7ecd8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -16,9 +16,6 @@ _n_ones, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, _register_custom_op, _register_meta_op, ) @@ -107,8 +104,7 @@ class TorchAODType(Enum): INT7 = auto() -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) +torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) FP8_TYPES = { torch.float8_e4m3fn, @@ -152,53 +148,49 @@ class TorchAODType(Enum): TorchAODType.INT7: (-(2**6), 2**6 - 1), } -# torch.uintX available only in PyTorch 2.3+ -if TORCH_VERSION_AT_LEAST_2_3: - _SUB_BYTE_UINT_BOUNDS = { - torch.uint1: (0, 2**1 - 1), - torch.uint2: (0, 2**2 - 1), - torch.uint3: (0, 2**3 - 1), - torch.uint4: (0, 2**4 - 1), - torch.uint5: (0, 2**5 - 1), - torch.uint6: (0, 2**6 - 1), - torch.uint7: (0, 2**7 - 1), +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} +_DTYPE_TO_BIT_WIDTH.update( + { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, } - _DTYPE_TO_BIT_WIDTH.update( - { - torch.uint1: 1, - 
torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - ) - -# torch.intX available only in PyTorch 2.6+ -if TORCH_VERSION_AT_LEAST_2_6: - _SUB_BYTE_INT_BOUNDS.update( - { - torch.int1: (-(2**0), 2**0 - 1), - torch.int2: (-(2**1), 2**1 - 1), - torch.int3: (-(2**2), 2**2 - 1), - torch.int4: (-(2**3), 2**3 - 1), - torch.int5: (-(2**4), 2**4 - 1), - torch.int6: (-(2**5), 2**5 - 1), - torch.int7: (-(2**6), 2**6 - 1), - } - ) - _DTYPE_TO_BIT_WIDTH.update( - { - torch.int1: 1, - torch.int2: 2, - torch.int3: 3, - torch.int4: 4, - torch.int5: 5, - torch.int6: 6, - torch.int7: 7, - } - ) +) + +_SUB_BYTE_INT_BOUNDS.update( + { + torch.int1: (-(2**0), 2**0 - 1), + torch.int2: (-(2**1), 2**1 - 1), + torch.int3: (-(2**2), 2**2 - 1), + torch.int4: (-(2**3), 2**3 - 1), + torch.int5: (-(2**4), 2**4 - 1), + torch.int6: (-(2**5), 2**5 - 1), + torch.int7: (-(2**6), 2**6 - 1), + } +) +_DTYPE_TO_BIT_WIDTH.update( + { + torch.int1: 1, + torch.int2: 2, + torch.int3: 3, + torch.int4: 4, + torch.int5: 5, + torch.int6: 6, + torch.int7: 7, + } +) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py index 5430463543..c9b853f300 100644 --- a/torchao/quantization/quantize_/common/kernel_preference.py +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -8,8 +8,6 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - # can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) # after python 3.10 is end of life (https://devguide.python.org/versions/) @@ -33,5 +31,4 @@ class KernelPreference(str, Enum): FBGEMM = "fbgemm" -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([KernelPreference]) +torch.serialization.add_safe_globals([KernelPreference]) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 611c476b76..c15703e706 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -35,7 +35,6 @@ _choose_quant_func_and_quantize_tensor, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, _is_fbgemm_genai_gpu_available, fill_defaults, @@ -608,6 +607,5 @@ def _(func, types, args, kwargs): Float8Tensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) +# Allow a model with Float8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py index bd894ceea0..0cf52436cc 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py @@ -12,7 +12,6 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults, ) @@ -443,6 +442,5 @@ def _(func, types, args, kwargs): 
Int4PreshuffledTensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4PreshuffledTensor]) +# Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4PreshuffledTensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py index 371ab6de2b..8a153c350d 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py @@ -11,7 +11,6 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults, ) @@ -307,6 +306,5 @@ def _(func, types, args, kwargs): Int4Tensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int4Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4Tensor]) +# Allow a model with Int4Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4Tensor]) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index a4097ecc25..d56fa0732d 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -25,7 +25,6 @@ quantize_affine, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, ) @@ -449,7 +448,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min, quant_max, ) - if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: + if w.shape[-1] > 1: if (not (check_cpu_version(int_data.device))) and ( not (check_xpu_version(int_data.device)) ): @@ -470,10 +469,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if ( - TORCH_VERSION_AT_LEAST_2_5 - and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) - and not (check_cpu_version(w_int4x8.device)) + if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not ( + check_cpu_version(w_int4x8.device) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 diff --git a/torchao/quantization/weight_tensor_linear_activation_quantization.py b/torchao/quantization/weight_tensor_linear_activation_quantization.py index 6612213bc1..c0b0a893e4 100644 --- a/torchao/quantization/weight_tensor_linear_activation_quantization.py +++ b/torchao/quantization/weight_tensor_linear_activation_quantization.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationQuantizationMetadata", @@ -201,8 +198,7 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationQuantizationMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationQuantizationMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals( + 
[WeightTensorWithLinearActivationQuantizationMetadata] +) diff --git a/torchao/sparsity/training/__init__.py b/torchao/sparsity/training/__init__.py index 3c4212101b..87ce3add4f 100644 --- a/torchao/sparsity/training/__init__.py +++ b/torchao/sparsity/training/__init__.py @@ -4,17 +4,15 @@ # LICENSE file in the root directory of this source tree. import torch +# load pointwise op support, which exists only for CUTLASS +from torch.sparse import SparseSemiStructuredTensorCUTLASS + from torchao.sparsity.training.autograd import semi_structured_sparsify from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -# load pointwise op support, which exists only for CUTLASS -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import SparseSemiStructuredTensorCUTLASS - SparseSemiStructuredTensorCUTLASS._load_dispatch_table( - CUTLASS_POINTWISE_OP_DISPATCH_TABLE - ) +SparseSemiStructuredTensorCUTLASS._load_dispatch_table( + CUTLASS_POINTWISE_OP_DISPATCH_TABLE +) __all__ = [ "SemiSparseLinear", diff --git a/torchao/sparsity/training/autograd.py b/torchao/sparsity/training/autograd.py index fafbd7c3c3..40c6c98083 100644 --- a/torchao/sparsity/training/autograd.py +++ b/torchao/sparsity/training/autograd.py @@ -6,18 +6,14 @@ from enum import Enum import torch -from torch.sparse import SparseSemiStructuredTensor - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import ( - SparseSemiStructuredTensorCUSPARSELT, - SparseSemiStructuredTensorCUTLASS, - ) - - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) +from torch.sparse import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, +) + +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) GRADIENT_TYPE = Enum("GRADIENT_TYPE", ["DENSE", "SPARSE", "STE"]) diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index c4773231a5..a41d3f597f 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -15,6 +15,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec, QuantizationTestCase, @@ -29,16 +30,9 @@ prepare_pt2e, prepare_qat_pt2e, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - "only works for torch 2.5+ since export_for_training is only supported after 2.5", -) class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. 
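A recurring change in the hunks above is dropping the `if TORCH_VERSION_AT_LEAST_2_5:` guard around `torch.serialization.add_safe_globals(...)`. As a minimal, self-contained sketch of what that registration buys (the enum below is a placeholder, not a torchao type), allowlisting a class lets checkpoints that reference it be loaded with the safer `weights_only=True` default:

```python
import os
import tempfile
from enum import Enum

import torch


class FakeZeroPointDomain(Enum):
    # Placeholder standing in for the config enums / tensor subclasses that
    # the patch registers so they can appear inside saved checkpoints.
    INT = 1
    FLOAT = 2


# Mark the class as safe for the restricted unpickler used by weights_only=True.
torch.serialization.add_safe_globals([FakeZeroPointDomain])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pt")
    torch.save({"zero_point_domain": FakeZeroPointDomain.FLOAT}, path)
    # Without the add_safe_globals call above, this load would reject the
    # unknown global instead of reconstructing it.
    state = torch.load(path, weights_only=True)
    assert state["zero_point_domain"] is FakeZeroPointDomain.FLOAT
```

The guard existed because older torch releases lacked this API; with the minimum supported torch version raised, the registration can be made unconditionally.
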
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 38fc8b04ce..33def3f998 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -24,7 +24,6 @@ ) from torchao.testing.model_architectures import LlamaModelsLlama4Experts from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, DummyModule, get_compute_capability, ) @@ -420,10 +419,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) diff --git a/torchao/utils.py b/torchao/utils.py index 307e02c4a7..040fd4625e 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -141,9 +141,8 @@ def get_available_devices(): devices.append("cuda") elif torch.xpu.is_available(): devices.append("xpu") - if TORCH_VERSION_AT_LEAST_2_5: - if torch.mps.is_available(): - devices.append("mps") + if torch.mps.is_available(): + devices.append("mps") return devices @@ -216,37 +215,31 @@ def _the_op_that_needs_to_be_preserved(...) ) def decorator(fn): - if TORCH_VERSION_AT_LEAST_2_5: - from torch._library.infer_schema import infer_schema + from torch._library.infer_schema import infer_schema - assert not any(c in fn.__name__ for c in ".<>"), ( - f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" - ) - op_name = fn.__name__ - if op_name[0] == "_": - op_name = op_name[1:] - schema = op_name + infer_schema(fn, mutates_args={}) - lib.define(schema) - lib.impl(op_name, fn, dispatch_key) - - lib_namespace = lib.ns - op = getattr(getattr(torch.ops, lib_namespace), op_name) - if inductor_decomposed: - register_decomposition([op])(fn) - return op - else: - return fn + assert not any(c in fn.__name__ for c in ".<>"), ( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__ + if op_name[0] == "_": + op_name = op_name[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, dispatch_key) + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + if inductor_decomposed: + register_decomposition([op])(fn) + return op return decorator def _register_meta_op(lib, op_name): def decorator(fn): - if TORCH_VERSION_AT_LEAST_2_5: - op = lib.impl(op_name, fn, "Meta") - return op - else: - return fn + op = lib.impl(op_name, fn, "Meta") + return op return decorator @@ -617,9 +610,8 @@ def decorator(tensor_impl_class): tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = ( tensor_impl_class.from_plain ) - if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this tensor impl subclass - torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) + # Allow serialization to work for models uses this tensor impl subclass + torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) return tensor_impl_class return decorator diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index faaa9b1ae9..c326828219 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -37,12 +37,6 @@ torch._inductor.config.use_mixed_mm = True ## compilation configs end -# temporary workaround for the API to work with torch.compile -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass - -if not 
TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # temporary workaround to recover the perf with quantized model under torch.compile torch.backends.mha.set_fastpath_enabled(False) From 42f908154bf312c8e4094d52af34c77045f0cd6a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 12:20:30 -0700 Subject: [PATCH 04/13] Remove old `change_linear_weights_to_*` APIs **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- benchmarks/benchmark_aq.py | 4 +- test/integration/test_integration.py | 6 -- test/quantization/test_quant_api.py | 26 ------- torchao/quantization/README.md | 15 ---- torchao/quantization/quant_api.py | 103 --------------------------- 5 files changed, 3 insertions(+), 151 deletions(-) diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index 7dd732debc..5106eb5494 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -75,11 +75,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs """ from torchao.quantization.quant_api import ( _get_subclass_inserter, - _in_features_greater_than_16, _is_linear, ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight + def _in_features_greater_than_16(mod, *args): + return hasattr(mod, "in_features") and mod.in_features > 16 + if filter_fn is None: filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( *args diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 6ff0e6f08c..c0f3bb0883 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -40,7 +40,6 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, _replace_with_custom_fn_if_matches_filter, - change_linear_weights_to_int8_dqtensors, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -1829,11 +1828,6 @@ class TestAOTI(unittest.TestCase): list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) def test_aoti(self, api, test_device, test_dtype): - if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": - self.skipTest( - f"{api} in {test_device} is not support for aoti compilation yet" - ) - if ( test_device == "cuda" and torch.cuda.is_available() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 3b26cd25d6..f979c9a588 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -146,32 +146,6 @@ def forward(self, x): return x -def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - The deprecated implementation for int8 dynamic quant API, used as a reference for - numerics and performance - """ - from torchao.quantization.quant_api import ( - _get_subclass_inserter, - _in_features_greater_than_16, - _is_linear, - ) - from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight - - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - filter_fn, - ) - - def 
_get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index fa0293bf82..ab3a27f05a 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -125,7 +125,6 @@ be applied individually. While there are a large variety of quantization apis, t #### A16W4 WeightOnly Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int4WeightOnlyConfig group_size = 32 @@ -133,10 +132,6 @@ group_size = 32 # use_hqq flag for `Int4WeightOnlyConfig` quantization use_hqq = False quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors -change_linear_weights_to_int4_woqtensors(model) ``` Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model. @@ -144,25 +139,15 @@ Note: The quantization error incurred by applying int4 quantization to your mode #### A16W8 Int8 WeightOnly Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int8WeightOnlyConfig quantize_(model, Int8WeightOnlyConfig()) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors -change_linear_weights_to_int8_woqtensors(model) ``` #### A8W8 Int8 Dynamic Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig quantize_(model, Int8DynamicActivationInt8WeightConfig()) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors -change_linear_weights_to_int8_dqtensors(model) ``` ### A16W8 Float8 WeightOnly Quantization diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4e6bf7fa41..c23343afb0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -172,109 +172,6 @@ } -###### -# TO BE DEPRECATED START -###### -def _in_features_greater_than_16(mod, *args): - return hasattr(mod, "in_features") and mod.in_features > 16 - - -# TODO: delete -def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` - Tensor subclass, effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - """ - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - filter_fn, - ) - - -# TODO: delete -def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the - `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_weight_only_int8_quant while not modifying the linear modules. 
- """ - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - _is_linear if filter_fn is None else filter_fn, - ) - - -# TODO: delete -def change_linear_weights_to_int4_woqtensors( - model, - groupsize=128, - inner_k_tiles=8, - filter_fn=None, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, -): - """ - Converts all linear weight tensors to the - `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - Args: - `groupsize`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, choices are [256, 128, 64, 32] - `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] - `filter_fn`: function that takes a nn.Module instance and fully qualified name of the module, \ - returns True if we want to run `config` on - `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, \ - ZeroPointDomain.INT, ZeroPointDomain.NONE] - `preserve_zero`: whether to preserve zero, default is False - """ - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - if filter_fn is None: - filter_fn = _is_linear - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int4WeightOnlyQuantizedLinearWeight, - enable_parametrization=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ), - filter_fn, - ) - - -######## -# TO BE DEPRECATED END -######## - - def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, From 83fb7394f51a065268a9ad2ee21ca9341ae55655 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 12:54:15 -0700 Subject: [PATCH 05/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- torchao/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/utils.py b/torchao/utils.py index 040fd4625e..d5e43ca1f6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -397,6 +397,9 @@ def __bool__(self): warnings.warn(self.msg) return self.bool_value + def __eq__(self, other): + return bool(self) == bool(other) + TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") From 4697b22add0d51d6f6ed9e2b1daf0690513f6092 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 13:10:20 -0700 Subject: [PATCH 06/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- test/test_utils.py | 24 +++++++++---------- torchao/utils.py | 57 +++++++++++++++++----------------------------- 2 files 
changed, 33 insertions(+), 48 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ddbbf68ecc..ebc23466c1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -58,22 +58,22 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_6, ) - deprecated_api_to_name = { - TORCH_VERSION_AT_LEAST_2_6: "TORCH_VERSION_AT_LEAST_2_6", - TORCH_VERSION_AT_LEAST_2_5: "TORCH_VERSION_AT_LEAST_2_5", - TORCH_VERSION_AT_LEAST_2_4: "TORCH_VERSION_AT_LEAST_2_4", - TORCH_VERSION_AT_LEAST_2_3: "TORCH_VERSION_AT_LEAST_2_3", - TORCH_VERSION_AT_LEAST_2_2: "TORCH_VERSION_AT_LEAST_2_2", - TORCH_VERSION_AFTER_2_5: "TORCH_VERSION_AFTER_2_5", - TORCH_VERSION_AFTER_2_4: "TORCH_VERSION_AFTER_2_4", - TORCH_VERSION_AFTER_2_3: "TORCH_VERSION_AFTER_2_3", - TORCH_VERSION_AFTER_2_2: "TORCH_VERSION_AFTER_2_2", - } + deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), + (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), + (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), + (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), + (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), + (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), + (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), + (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), + (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), + ] self.assertEqual(len(_warnings), 0) # Accessing the boolean value should trigger deprecation warning with warnings.catch_warnings(record=True) as _warnings: - for api, name in deprecated_api_to_name.items(): + for api, name in deprecated_api_to_name: num_warnings_before = len(_warnings) if api: pass diff --git a/torchao/utils.py b/torchao/utils.py index d5e43ca1f6..f4c3d8fa3d 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -371,17 +371,20 @@ def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 -# Deprecated, will be deleted in the future -def _torch_version_after(min_version): - return is_fbcode() or version("torch") >= min_version - - -def _get_old_torch_version_deprecation_msg(version_str: str) -> str: - return f"TORCH_VERSION_AT_LEAST_{version_str} is deprecated and will be removed in torchao 0.14.0" +def _deprecated_torch_version_at_least(version_str: str) -> str: + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper( + torch_version_at_least(version_str), + deprecation_msg, + ) -def _get_torch_version_after_deprecation_msg(version_str: str) -> str: - return f"TORCH_VERSION_AFTER_{version_str} is deprecated and will be removed in torchao 0.14.0" +def _deprecated_torch_version_after(version_str: str) -> str: + bool_value = is_fbcode() or version("torch") >= version_str + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper(bool_value, deprecation_msg) class _BoolDeprecationWrapper: @@ -405,33 +408,15 @@ def __eq__(self, other): TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") # Deprecated -TORCH_VERSION_AT_LEAST_2_6 = _BoolDeprecationWrapper( - torch_version_at_least("2.6.0"), _get_old_torch_version_deprecation_msg("2_6") -) -TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( - torch_version_at_least("2.5.0"), 
_get_old_torch_version_deprecation_msg("2_5") -) -TORCH_VERSION_AT_LEAST_2_4 = _BoolDeprecationWrapper( - torch_version_at_least("2.4.0"), _get_old_torch_version_deprecation_msg("2_4") -) -TORCH_VERSION_AT_LEAST_2_3 = _BoolDeprecationWrapper( - torch_version_at_least("2.3.0"), _get_old_torch_version_deprecation_msg("2_3") -) -TORCH_VERSION_AT_LEAST_2_2 = _BoolDeprecationWrapper( - torch_version_at_least("2.2.0"), _get_old_torch_version_deprecation_msg("2_2") -) -TORCH_VERSION_AFTER_2_5 = _BoolDeprecationWrapper( - _torch_version_after("2.5.0.dev"), _get_torch_version_after_deprecation_msg("2_5") -) -TORCH_VERSION_AFTER_2_4 = _BoolDeprecationWrapper( - _torch_version_after("2.4.0.dev"), _get_torch_version_after_deprecation_msg("2_4") -) -TORCH_VERSION_AFTER_2_3 = _BoolDeprecationWrapper( - _torch_version_after("2.3.0.dev"), _get_torch_version_after_deprecation_msg("2_3") -) -TORCH_VERSION_AFTER_2_2 = _BoolDeprecationWrapper( - _torch_version_after("2.2.0.dev"), _get_torch_version_after_deprecation_msg("2_2") -) +TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") +TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") +TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") +TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") +TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") +TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") +TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") +TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") +TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") """ From ac6c78f663156b013f95f106a07152ff1733c920 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 13:16:17 -0700 Subject: [PATCH 07/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] From d6c4715260e4b3df37685d70356889748c218c5c Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 13:59:12 -0700 Subject: [PATCH 08/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- test/test_utils.py | 8 ++++++-- torchao/utils.py | 15 +++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ebc23466c1..9213097276 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -38,8 +38,8 @@ def test_torch_version_at_least(self): def test_torch_version_deprecation(self): """ - Test that TORCH_VERSION_AT_LEAST_2_6 and before and TORCH_VERSION_AFTER* - trigger a deprecation warning. + Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* + trigger deprecation warnings on use, not on import. 
""" # Reset deprecation warning state, otherwise we won't log warnings here warnings.resetwarnings() @@ -56,9 +56,13 @@ def test_torch_version_deprecation(self): TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, ) deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), + (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), diff --git a/torchao/utils.py b/torchao/utils.py index f4c3d8fa3d..63ea463b92 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -372,6 +372,10 @@ def torch_version_at_least(min_version): def _deprecated_torch_version_at_least(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log + a deprecation warning if the variable is used. + """ version_str_var_name = "_".join(version_str.split(".")[:2]) deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" return _BoolDeprecationWrapper( @@ -381,6 +385,10 @@ def _deprecated_torch_version_at_least(version_str: str) -> str: def _deprecated_torch_version_after(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AFTER* variables that will log + a deprecation warning if the variable is used. + """ bool_value = is_fbcode() or version("torch") >= version_str version_str_var_name = "_".join(version_str.split(".")[:2]) deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" @@ -404,10 +412,9 @@ def __eq__(self, other): return bool(self) == bool(other) -TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") - -# Deprecated +# Deprecated, use `torch_version_at_least` directly instead +TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") +TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") From 7980103f837a8d99da1d6284d504536effef4d18 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 8 Aug 2025 15:12:13 -0700 Subject: [PATCH 09/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- torchao/quantization/autoquant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 1fe30a59d1..5745f00e99 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -343,9 +343,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, 
return_mode="median" + ) return res From 670dccf73b20dfa849eb55708565c7bddeada1c1 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 12 Aug 2025 09:25:27 -0700 Subject: [PATCH 10/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- test/integration/test_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index a97fa1afd3..76ac6c9109 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -968,6 +968,7 @@ def _test_lin_weight_subclass_api_impl( ) ) ) + @unittest.skip("skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): api = partial( _int8da_int8w_api, From cc84d2e529e11d63bb5a6f3e3f3a7c5981595523 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 12 Aug 2025 11:33:51 -0700 Subject: [PATCH 11/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- test/integration/test_integration.py | 197 +++++++++++++++++++++------ 1 file changed, 152 insertions(+), 45 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 76ac6c9109..09af28ba47 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -40,7 +40,6 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, _replace_with_custom_fn_if_matches_filter, - change_linear_weights_to_int8_dqtensors, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -76,7 +75,13 @@ compute_error as SQNR, ) from torchao.testing.utils import skip_if_rocm + +# TODO: stop using these from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, benchmark_model, check_cpu_version, @@ -110,7 +115,14 @@ def _int8wo_api(mod): - quantize_(mod, int8_weight_only(set_inductor_config=False)) + if TORCH_VERSION_AT_LEAST_2_4: + quantize_(mod, int8_weight_only(set_inductor_config=False)) + if not TORCH_VERSION_AT_LEAST_2_5 or ( + not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing + ): + unwrap_tensor_subclass(mod) + else: + raise ValueError("should not be here") def _int8wo_groupwise_api(mod): @@ -122,13 +134,18 @@ def _int8da_int8w_api( mod, act_mapping_type=MappingType.SYMMETRIC, ): - quantize_( - mod, - int8_dynamic_activation_int8_weight( - act_mapping_type=act_mapping_type, - set_inductor_config=False, - ), - ) + if TORCH_VERSION_AT_LEAST_2_4: + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + set_inductor_config=False, + ), + ) + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(mod) + else: + raise ValueError("should not be here") def _int4wo_api(mod, use_hqq=False): @@ -145,12 +162,18 @@ def _int4wo_api(mod, use_hqq=False): mod, 
int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False ) unwrap_tensor_subclass(mod) - else: + elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(set_inductor_config=False)) + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(mod) + else: + raise ValueError("should not be here") def _int8da_int4w_api(mod): quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(mod) # TODO: use this to reduce the number of tests @@ -369,6 +392,7 @@ def test_swap(self): assert torch.allclose(y_ref, y) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_weight_t_and_non_t_numerics_match(self): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format @@ -685,6 +709,8 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -703,6 +729,8 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": @@ -760,6 +788,9 @@ def _test_lin_weight_subclass_impl( ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" + ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8DynamicallyQuantizedLinearWeight.from_float, @@ -776,6 +807,9 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8DynamicallyQuantizedLinearWeight.from_float, @@ -785,6 +819,9 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" @@ -798,6 +835,9 @@ def test_aq_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight2.from_float, @@ -807,6 +847,9 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): 
self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight3.from_float, @@ -816,6 +859,9 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -845,6 +891,9 @@ def test_autoquantizable_flatten_unflatten(self): for device, dtype in COMMON_DEVICE_DTYPE ] ) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( @@ -869,6 +918,9 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): @@ -880,6 +932,8 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -898,6 +952,8 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") @unittest.skip("Skip to fix CI until we deprecate these APIs long term") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): @@ -968,8 +1024,14 @@ def _test_lin_weight_subclass_api_impl( ) ) ) - @unittest.skip("skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): + if ( + not TORCH_VERSION_AT_LEAST_2_5 + and dtype in (torch.float16, torch.bfloat16) + and act_mapping is MappingType.ASYMMETRIC + and device == "cpu" + ): + self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, @@ -979,6 +1041,12 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): + if ( + not TORCH_VERSION_AT_LEAST_2_6 + and dtype in (torch.float16, torch.bfloat16) + and device == "cpu" + ): + self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype @@ -986,6 +1054,9 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." 
+ ) @skip_if_rocm("Test flaky on ROCm, under investigation") def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() @@ -994,6 +1065,8 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1005,6 +1078,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.") def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1018,6 +1092,9 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater" + ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): if dtype != torch.float16: @@ -1061,6 +1138,8 @@ def test_gemlite_layout(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: @@ -1082,9 +1161,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): def api(mod): kwargs_copy = kwargs.copy() - kwargs_copy["group_size"] = groupsize - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) + if TORCH_VERSION_AT_LEAST_2_4: + kwargs_copy["group_size"] = groupsize + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy)) + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(mod) + else: + kwargs_copy["inner_k_tiles"] = inner_k_tiles + del kwargs_copy["layout"] + raise ValueError("should not be here") self._test_lin_weight_subclass_api_impl( api, @@ -1165,7 +1251,11 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") + mixed_mm_key, mixed_mm_val = ( + ("mixed_mm_choice", "triton") + if TORCH_VERSION_AT_LEAST_2_5 + else ("force_mixed_mm", True) + ) with config.patch( { @@ -1198,7 +1288,11 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") + mixed_mm_key, mixed_mm_val = ( + ("mixed_mm_choice", "triton") + if TORCH_VERSION_AT_LEAST_2_5 + else ("force_mixed_mm", True) + ) with config.patch( { @@ -1300,10 +1394,18 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): + if ( + not TORCH_VERSION_AT_LEAST_2_6 + and dtype in (torch.float16, torch.bfloat16) + and device == "cpu" + ): + self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() 
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch 2.3+.") + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1313,6 +1415,9 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." + ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1361,6 +1466,7 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") @@ -1455,6 +1561,7 @@ class TestAutoQuant(unittest.TestCase): ], ) ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) @@ -1496,6 +1603,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ], ) ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1517,6 +1625,9 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") + # Skip certain shapes on older PyTorch versions + if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") # TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} is flaky, skipping") @@ -1545,6 +1656,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1572,6 +1684,7 @@ def forward(self, x): assert len(_AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1621,6 +1734,7 @@ def test_autoquant_manual(self, device, dtype): ], ) ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1630,27 +1744,9 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest("bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} 
requires sm80+") - - # Note: This test was incorrectly written before with this skip condition: - # - # m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: - # - # This is actually equivalent to: - # - # m1 == 1 or (m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5) - # - # which means we always skips the test as long as `m1 == 1` regardless of - # the pytorch version, which was not the intended behavior. Unfortunately, - # unskipping this test now leads to the following error when calling - # `aten._int_mm`: - # - # RuntimeError: self.size(0) needs to be greater than 16, but got 1 - # - # Therefore, we keep around this skip condition for now since it doesn't - # change the test behavior from before. For more details, please see - # https://github.com/pytorch/ao/pull/2720. - if m1 == 1: - self.skipTest(f"Shape {(m1, m2, k, n)} is not supported") + # This test fails on v0.4.0 and torch 2.4, so skipping for now. + if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: + self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") class NeedsKwargs(torch.nn.Module): def __init__(self): @@ -1685,6 +1781,7 @@ def forward(self, x, y): ], ) ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1737,6 +1834,9 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." + ) def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 @@ -1767,6 +1867,9 @@ def test_autoquant_hp_float(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." + ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_autoquant_int4wo(self, device, dtype): if device == "cpu": @@ -1802,6 +1905,9 @@ def test_autoquant_int4wo(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." 
+ ) @unittest.skipIf( True, "Skipping for now, do to lowering bug in inductor" ) # TODO unblock when fixed @@ -1841,6 +1947,7 @@ def test_autoquant_float8(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" @@ -1851,11 +1958,6 @@ class TestAOTI(unittest.TestCase): list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) def test_aoti(self, api, test_device, test_dtype): - if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": - self.skipTest( - f"{api} in {test_device} is not support for aoti compilation yet" - ) - if ( test_device == "cuda" and torch.cuda.is_available() @@ -1903,6 +2005,7 @@ def forward(self, x): ) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( @@ -1958,9 +2061,12 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() + if TORCH_VERSION_AT_LEAST_2_5: + model = torch.export.export_for_training( + model, example_inputs, strict=True + ).module() + else: + model = torch._export.capture_pre_autograd_graph(model, example_inputs) after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: @@ -1999,6 +2105,7 @@ class TestUtils(unittest.TestCase): @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") From 68df77e7b5f0ffdefb997ea15216b7582695df3b Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 12 Aug 2025 14:56:35 -0700 Subject: [PATCH 12/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned] --- test/integration/test_integration.py | 198 +++++++-------------------- 1 file changed, 46 insertions(+), 152 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 09af28ba47..5c29f0b8ad 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -40,6 +40,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, _replace_with_custom_fn_if_matches_filter, + change_linear_weights_to_int8_dqtensors, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -75,13 +76,7 @@ compute_error as SQNR, ) from torchao.testing.utils import skip_if_rocm - -# TODO: stop using these from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - 
TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, benchmark_model, check_cpu_version, @@ -115,14 +110,7 @@ def _int8wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5 or ( - not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing - ): - unwrap_tensor_subclass(mod) - else: - raise ValueError("should not be here") + quantize_(mod, int8_weight_only(set_inductor_config=False)) def _int8wo_groupwise_api(mod): @@ -134,18 +122,13 @@ def _int8da_int8w_api( mod, act_mapping_type=MappingType.SYMMETRIC, ): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight( - act_mapping_type=act_mapping_type, - set_inductor_config=False, - ), - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - raise ValueError("should not be here") + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + set_inductor_config=False, + ), + ) def _int4wo_api(mod, use_hqq=False): @@ -162,18 +145,12 @@ def _int4wo_api(mod, use_hqq=False): mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False ) unwrap_tensor_subclass(mod) - elif TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int4_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) else: - raise ValueError("should not be here") + quantize_(mod, int4_weight_only(set_inductor_config=False)) def _int8da_int4w_api(mod): quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) # TODO: use this to reduce the number of tests @@ -392,7 +369,6 @@ def test_swap(self): assert torch.allclose(y_ref, y) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_weight_t_and_non_t_numerics_match(self): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format @@ -709,8 +685,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -729,8 +703,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": @@ -788,9 +760,6 @@ def _test_lin_weight_subclass_impl( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" - ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8DynamicallyQuantizedLinearWeight.from_float, @@ -807,9 +776,6 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs 
newer pytorch" - ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8DynamicallyQuantizedLinearWeight.from_float, @@ -819,9 +785,6 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" @@ -835,9 +798,6 @@ def test_aq_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight2.from_float, @@ -847,9 +807,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight3.from_float, @@ -859,9 +816,6 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -891,9 +845,6 @@ def test_autoquantizable_flatten_unflatten(self): for device, dtype in COMMON_DEVICE_DTYPE ] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( @@ -918,9 +869,6 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): @@ -932,8 +880,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -952,8 +898,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") @unittest.skip("Skip to fix CI until we deprecate these APIs long term") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): @@ -1024,14 +968,8 @@ def _test_lin_weight_subclass_api_impl( ) ) ) + @unittest.skip("skip because there is some bug in inductor codegen") def 
test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): - if ( - not TORCH_VERSION_AT_LEAST_2_5 - and dtype in (torch.float16, torch.bfloat16) - and act_mapping is MappingType.ASYMMETRIC - and device == "cpu" - ): - self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, @@ -1041,12 +979,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype @@ -1054,9 +986,6 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." - ) @skip_if_rocm("Test flaky on ROCm, under investigation") def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() @@ -1065,8 +994,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1078,7 +1005,6 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.") def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1092,9 +1018,6 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater" - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): if dtype != torch.float16: @@ -1138,8 +1061,6 @@ def test_gemlite_layout(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: @@ -1161,16 +1082,9 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): def api(mod): kwargs_copy = kwargs.copy() - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy["group_size"] = groupsize - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - kwargs_copy["inner_k_tiles"] = inner_k_tiles - del kwargs_copy["layout"] - raise ValueError("should not be here") + kwargs_copy["group_size"] = groupsize + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy)) self._test_lin_weight_subclass_api_impl( api, @@ -1251,11 
+1165,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1288,11 +1198,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1394,18 +1300,10 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch 2.3+.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1415,9 +1313,6 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." 
- ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1466,7 +1361,7 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") + @unittest.skip("Seg fault?") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") @@ -1561,7 +1456,6 @@ class TestAutoQuant(unittest.TestCase): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) @@ -1603,7 +1497,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1625,9 +1518,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # Skip certain shapes on older PyTorch versions - if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") # TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} is flaky, skipping") @@ -1656,7 +1546,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1684,7 +1573,6 @@ def forward(self, x): assert len(_AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1734,7 +1622,6 @@ def test_autoquant_manual(self, device, dtype): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1744,9 +1631,27 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest("bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. - if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") + + # Note: This test was incorrectly written before with this skip condition: + # + # m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: + # + # This is actually equivalent to: + # + # m1 == 1 or (m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5) + # + # which means we always skips the test as long as `m1 == 1` regardless of + # the pytorch version, which was not the intended behavior. 
Unfortunately, + # unskipping this test now leads to the following error when calling + # `aten._int_mm`: + # + # RuntimeError: self.size(0) needs to be greater than 16, but got 1 + # + # Therefore, we keep around this skip condition for now since it doesn't + # change the test behavior from before. For more details, please see + # https://github.com/pytorch/ao/pull/2720. + if m1 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} is not supported") class NeedsKwargs(torch.nn.Module): def __init__(self): @@ -1781,7 +1686,6 @@ def forward(self, x, y): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1834,9 +1738,6 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." - ) def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 @@ -1867,9 +1768,6 @@ def test_autoquant_hp_float(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_autoquant_int4wo(self, device, dtype): if device == "cpu": @@ -1905,9 +1803,6 @@ def test_autoquant_int4wo(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." 
- ) @unittest.skipIf( True, "Skipping for now, do to lowering bug in inductor" ) # TODO unblock when fixed @@ -1947,7 +1842,6 @@ def test_autoquant_float8(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" @@ -1958,6 +1852,11 @@ class TestAOTI(unittest.TestCase): list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) def test_aoti(self, api, test_device, test_dtype): + if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": + self.skipTest( + f"{api} in {test_device} is not support for aoti compilation yet" + ) + if ( test_device == "cuda" and torch.cuda.is_available() @@ -2005,7 +1904,6 @@ def forward(self, x): ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( @@ -2061,12 +1959,9 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - if TORCH_VERSION_AT_LEAST_2_5: - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() - else: - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export_for_training( + model, example_inputs, strict=True + ).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: @@ -2105,7 +2000,6 @@ class TestUtils(unittest.TestCase): @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") From 158a8bcaf1cf5bdc925f7de140f3a3f10892d591 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 12 Aug 2025 16:03:07 -0700 Subject: [PATCH 13/13] Update base for Update on "Remove old `change_linear_weights_to_*` APIs" **Summary:** This commit removes these super old quantization APIs that aren't even accessible by the user: ``` change_linear_weights_to_int8_dqtensors change_linear_weights_to_int8_woqtensors change_linear_weights_to_int4_woqtensors ``` **Test Plan:** CI [ghstack-poisoned]
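For reference, a minimal migration sketch (not part of the patch itself) showing how callers of the removed helpers can switch to the `quantize_` configs that the updated tests in this series already exercise; the `model` below is a hypothetical stand-in, and only one of the three replacement calls would be applied to a given module:

```python
# Illustrative sketch only: the old change_linear_weights_to_* helpers are gone;
# linear weights are now swapped in place via quantize_ plus a config, mirroring
# the _int8wo_api / _int8da_int8w_api / _int4wo_api helpers in the updated tests.
import torch
import torch.nn as nn

from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)

# Hypothetical model used only for illustration.
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16)

# change_linear_weights_to_int8_woqtensors(model)  ->
quantize_(model, int8_weight_only(set_inductor_config=False))

# change_linear_weights_to_int8_dqtensors(model)   ->
# quantize_(model, int8_dynamic_activation_int8_weight(set_inductor_config=False))

# change_linear_weights_to_int4_woqtensors(model)  ->
# quantize_(model, int4_weight_only(set_inductor_config=False))
```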