diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index becb44a5e0..1b17c40fb0 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import tempfile
-import unittest
+import pytest
 
 import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     TestCase,
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
@@ -33,9 +33,19 @@ def get_config(group_size):
     )
 
 
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
 class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
+    def setUp(self):
+        min_req = type(self)._MIN_VER.get(self.device_type)
+        if not torch_version_at_least(min_req):
+            self.skipTest(
+                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
+            )
+
     @parametrize(
         "sizes",
         [
@@ -46,24 +56,35 @@ class Int4PlainInt32Tensor(TestCase):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
+        if "npu" in device and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
+        threshold = thresholds.get(device.split(":")[0])
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
         quantize_(linear, get_config(group_size))
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        self.assertTrue(compute_error(original, quantized) > threshold)
 
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+        if "xpu" in device:
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
 
     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
+    def test_module_path(self, device, dtype):
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
@@ -78,13 +99,21 @@ def test_module_path(self, dtype):
                 "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
             )
 
-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
+        if "xpu" in device and dtype == torch.float16:
+            pytest.skip(f"{device} test_activation_prescaling does not test {dtype}")
+
+        threshold = thresholds.get(device.split(":")[0])
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        input = torch.randn(1, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(128))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
@@ -95,10 +124,12 @@ def test_activation_prescaling(self):
         quantized = linear(input)
 
         # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
 
 
-instantiate_parametrized_tests(Int4PlainInt32Tensor)
+instantiate_device_type_tests(
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
+)
 
 
 if __name__ == "__main__":
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index dbd7983b8e..a4b5d2801e 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -71,8 +71,12 @@ use_hqq = False
 quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
 ```
 
-Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
-
+Note:
+- The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
+- Third-party backend CI status:
+  - Ascend NPU (requires torch_npu ≥ 2.7.1)
+    [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml)
+
 #### A16W8 Int8 WeightOnly Quantization
 
 ```python
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index 2c8de1a2d0..51e09dbe9c 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -14,9 +14,7 @@
     choose_qparams_affine,
     quantize_affine,
 )
-from torchao.utils import (
-    TorchAOBaseTensor,
-)
+from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 __all__ = [
     "Int4PlainInt32Tensor",
@@ -91,58 +89,158 @@ def from_hp(
         w: torch.Tensor,
         block_size: List[int],
     ):
-        assert w.ndim == 2 and w.device.type == "xpu", (
-            f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
-        )
-        assert len(block_size) == w.ndim
-        assert w.dtype in [torch.float16, torch.bfloat16], (
-            f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
-        )
-        original_shape = w.shape
-        mapping_type = MappingType.ASYMMETRIC
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        scale_dtype = None
-        zero_point_dtype = torch.int32
-        scale, zero_point = choose_qparams_affine(
-            w,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype,
-            zero_point_dtype,
-        )
-        int_data = quantize_affine(
-            w,
-            block_size,
-            scale,
-            zero_point,
-            target_dtype,
-            quant_min,
-            quant_max,
-        )
-        assert int_data.dtype == torch.int32, (
-            "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
-        )
-        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            packed_weight.contiguous(), 8
-        )
-        scale = scale.reshape(int_data.shape[0], -1)
-        zero_point = zero_point.reshape(int_data.shape[0], -1)
-        return Int4PlainInt32Tensor(
-            packed_weight,
-            scale.transpose(0, 1).contiguous(),
-            zero_point.transpose(0, 1).contiguous().to(torch.int8),
-            block_size,
-            original_shape,
-            act_pre_scale=None,
-        )
+        if w.device.type == "xpu":
+            return _from_hp_xpu(cls, w, block_size)
+        elif w.device.type == "npu":
+            return _from_hp_npu(cls, w, block_size)
+        else:
+            raise NotImplementedError(
+                f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet."
+            )
+
+
+def _from_hp_xpu(
+    cls,
+    w: torch.Tensor,
+    block_size: List[int],
+):
+    assert w.ndim == 2 and w.device.type == "xpu", (
+        f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
+    )
+    assert len(block_size) == w.ndim
+    assert w.dtype in [torch.float16, torch.bfloat16], (
+        f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
+    )
+    original_shape = w.shape
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int32
+    quant_min = 0
+    quant_max = 15
+    eps = 1e-6
+    scale_dtype = None
+    zero_point_dtype = torch.int32
+    scale, zero_point = choose_qparams_affine(
+        w,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+    )
+    int_data = quantize_affine(
+        w,
+        block_size,
+        scale,
+        zero_point,
+        target_dtype,
+        quant_min,
+        quant_max,
+    )
+    assert int_data.dtype == torch.int32, (
+        "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
+    )
+    packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+    packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+        packed_weight.contiguous(), 8
+    )
+    scale = scale.reshape(int_data.shape[0], -1)
+    zero_point = zero_point.reshape(int_data.shape[0], -1)
+    return Int4PlainInt32Tensor(
+        packed_weight,
+        scale.transpose(0, 1).contiguous(),
+        zero_point.transpose(0, 1).contiguous().to(torch.int8),
+        block_size,
+        original_shape,
+        act_pre_scale=None,
+    )
+
+
+def _from_hp_npu(
+    cls,
+    w: torch.Tensor,
+    block_size: List[int],
+):
+    assert (
+        torch.accelerator.is_available()
+        and torch.accelerator.current_accelerator().type == "npu"
+        and torch_version_at_least("2.7.1")
+    ), (
+        f"PyTorch NPU 2.7.1+ needed for int4 packing and matmul ops, but found {torch.__version__}"
+    )
+
+    assert w.ndim == 2 and w.device.type == "npu", (
+        f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
+    )
+    assert len(block_size) == w.ndim
+    assert w.dtype in [torch.float16, torch.bfloat16], (
+        f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
+    )
+
+    group_size = block_size[1]
+    k_dim = w.shape[-1]
+    assert group_size >= 32 and group_size % 32 == 0 and group_size < k_dim, (
+        f"Invalid group_size={group_size}: "
+        f"expected a multiple of 32 "
+        f"in the range [32, {k_dim - 1}] for per-group quantization, "
+        f"but got group_size={group_size} (k_dim={k_dim})."
+    )
+
+    original_shape = w.shape
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int32
+    quant_min = -8
+    quant_max = 7
+    eps = 1e-6
+    scale_dtype = w.dtype
+    zero_point_dtype = w.dtype
+
+    scale, zero_point = choose_qparams_affine(
+        w,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+    )
+
+    int_data = quantize_affine(
+        w,
+        block_size,
+        scale,
+        zero_point,
+        target_dtype,
+        quant_min,
+        quant_max,
+    )
+
+    assert int_data.dtype == torch.int32, (
+        "torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
+    )
+    assert int_data.shape[-1] % 8 == 0, (
+        f"torch.ops.npu.npu_convert_weight_to_int4pack expects the last dim to be aligned to 8, but got {int_data.shape[-1]}"
+    )
+
+    packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
+        int_data.contiguous(), 0
+    )
+
+    scale = scale.reshape(int_data.shape[0], -1)
+    zero_point = zero_point.reshape(int_data.shape[0], -1)
+
+    return Int4PlainInt32Tensor(
+        packed_weight.contiguous(),
+        scale.transpose(0, 1).contiguous(),
+        zero_point.transpose(0, 1).contiguous(),
+        block_size,
+        original_shape,
+        act_pre_scale=None,
+    )
 
 
 implements = Int4PlainInt32Tensor.implements
@@ -157,6 +255,22 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
+
+    if input_tensor.device.type == "xpu":
+        return _linear_xpu(input_tensor, weight_tensor, bias)
+    elif input_tensor.device.type == "npu":
+        return _linear_npu(input_tensor, weight_tensor, bias)
+    else:
+        raise NotImplementedError(
+            f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet."
+        )
+
+
+def _linear_xpu(
+    input_tensor,
+    weight_tensor,
+    bias,
+):
     assert input_tensor.device.type == "xpu", (
         f"For XPU device only but got: {input_tensor.device}"
     )
@@ -201,6 +315,71 @@
     return y.to(orig_dtype)
 
 
+def _linear_npu(
+    input_tensor,
+    weight_tensor,
+    bias,
+):
+    assert input_tensor.device.type == "npu", (
+        f"For NPU device only but got: {input_tensor.device.type}"
+    )
+    assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
+        f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
+    )
+    assert weight_tensor.block_size[0] == 1, (
+        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    )
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"Shapes of input and weight do not match, input: {input_tensor.shape}, weight: {weight_tensor.shape}"
+    )
+
+    if weight_tensor.act_pre_scale is not None:
+        input_tensor = input_tensor * weight_tensor.act_pre_scale
+
+    act_mat = input_tensor
+    packed_weight = weight_tensor.qdata
+    scale = weight_tensor.scale
+    zero_point = weight_tensor.zero_point
+
+    orig_act_size = act_mat.shape
+    orig_dtype = act_mat.dtype
+
+    # dtype alignment
+    if act_mat.dtype == torch.float16:
+        scale = scale.to(torch.float16)
+        zero_point = zero_point.to(torch.float16)
+        if bias is not None:
+            bias = bias.to(torch.float16)
+    elif act_mat.dtype == torch.bfloat16:
+        scale = scale.to(torch.bfloat16)
+        zero_point = zero_point.to(torch.bfloat16)
+        if bias is not None:
+            bias = bias.to(torch.float32)
+
+    # reshape to 2D
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+
+    # groupwise int4 quantization
+    groupsize = weight_tensor.block_size[1]
+
+    y = torch.ops.npu.npu_weight_quant_batchmatmul(
+        x=act_mat,
+        weight=packed_weight.transpose(-1, -2),
+        antiquant_scale=scale,
+        antiquant_offset=zero_point,
+        antiquant_group_size=groupsize,
+        bias=bias,
+    )
+
+    # remove out_feature padding
+    assert weight_tensor.ndim == 2
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    return y.to(orig_dtype)
+
+
 Int4PlainInt32Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
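
For reviewers, a minimal usage sketch of the NPU path added by this diff (not part of the patch; it mirrors the updated test and assumes an Ascend NPU device with torch_npu >= 2.7.1, and that `Int4PlainInt32Tensor` is importable from `torchao.quantization`):

```python
# Hypothetical end-to-end check of the new NPU branch; names and import path are assumptions.
import torch
import torch.nn.functional as F

from torchao.quantization import Int4PlainInt32Tensor  # import path assumed

# On NPU, group_size must be a multiple of 32 and strictly smaller than K.
K, N, group_size = 128, 256, 64
w = torch.randn(N, K, dtype=torch.bfloat16, device="npu")

# block_size[0] == 1 selects per-group quantization along K, which routes
# through _from_hp_npu and torch.ops.npu.npu_convert_weight_to_int4pack.
qw = Int4PlainInt32Tensor.from_hp(w, block_size=[1, group_size])

x = torch.randn(1, K, dtype=torch.bfloat16, device="npu")
# F.linear dispatches to the linear override, which on NPU tensors calls
# torch.ops.npu.npu_weight_quant_batchmatmul with the packed weight.
y = F.linear(x, qw)
print(y.shape, y.dtype)  # expected: torch.Size([1, 256]) torch.bfloat16
```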