 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-import platform
-import sys
 from copy import deepcopy
 
 import pytest
 import torch
 
 from torchao.prototype.parq.quant import (
     StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
 )
-from torchao.prototype.quantization.dynamic_activation_lut import (
-    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
+from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import (
+    _is_kernel_library_loaded,
 )
+from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
 from torchao.quantization import quantize_
 from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.quant_api import _is_linear
 from torchao.quantization.utils import compute_error
 
-is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64"
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, d1=512, d2=256, d3=128, d4=8):
@@ -59,7 +55,9 @@ def run_before_and_after_tests():
 @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
 @pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
 @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
-@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
+@pytest.mark.skipif(
+    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
+)
 def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
     torch.manual_seed(0)
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
@@ -68,38 +66,22 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
-        activation_quantization=None,
-        version=1,
+        activation_quantization="int8_asym_per_token",
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype)
     activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype)
-    parq_model_with_dyn_quant = deepcopy(parq_model)
     quantize_(parq_model, config)
 
-    # Apply dynamic activation to parq model. This will serve as the LUT reference
-    dyn_act_config = deepcopy(config)
-    dyn_act_config.activation_quantization = "int8_asym_per_token"
-    quantize_(parq_model_with_dyn_quant, dyn_act_config, filter_fn=_is_linear)
-
     # Convert PARQ model to lowbit LUT model
     lut_model = deepcopy(parq_model)
-    conversion_config = (
-        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
-            config.b, config.granularity
-        )
-    )
-    quantize_(lut_model, conversion_config, filter_fn=conversion_config.get_filter_fn())
+    _convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor")
 
     # Run both models and compare
     parq_out = parq_model(activations)
-    parq_with_dyn_quant_out = parq_model_with_dyn_quant(activations)
     lut_out = lut_model(activations)
 
-    sqnr = compute_error(parq_out, parq_with_dyn_quant_out).item()
-    assert sqnr > 20.0, f"sqnr {sqnr} is too low"
-
-    sqnr = compute_error(lut_out, parq_with_dyn_quant_out).item()
+    sqnr = compute_error(parq_out, lut_out).item()
     if dtype == torch.float32:
         assert sqnr > 40.0, f"sqnr {sqnr} is too low"
     elif dtype == torch.bfloat16:
@@ -112,32 +94,27 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
 @pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
 @pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
 @pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
-@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
+@pytest.mark.skipif(
+    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
+)
 def test_export(dtype, granularity, bit_width, lead_dim):
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
     config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
-        activation_quantization=None,
-        version=1,
+        activation_quantization="int8_asym_per_token",
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype)
     activations = parq_model.example_inputs(lead_dim=lead_dim)
     quantize_(parq_model, config)
 
-    conversion_config = (
-        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
-            config.b, config.granularity
-        )
-    )
-    quantize_(
-        parq_model, conversion_config, filter_fn=conversion_config.get_filter_fn()
-    )
+    _convert_model_for_aarch64(parq_model)
 
     ep = torch.export.export(parq_model, (activations,))
+
     assert (
         f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default"
         in ep.graph_module.code
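For readers following along, here is a minimal standalone sketch of the flow the updated tests now exercise: quantize with `StretchedIntxWeightConfig` (with dynamic int8 per-token activation quantization baked into the config), then convert with `_convert_model_for_aarch64`. It uses only symbols that appear in this diff, which live under `torchao.prototype` and may change; the behavior on a plain `nn.Linear` stack is assumed to match the test's `ToyLinearModel`.

```python
# Sketch of the quantize-then-convert flow from the tests above.
# All imports mirror the file's import block; these are prototype APIs.
from copy import deepcopy

import torch

from torchao.prototype.parq.quant import (
    StretchedIntxWeightConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import (
    _is_kernel_library_loaded,
)
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.utils import compute_error

if _is_kernel_library_loaded():  # lowbit kernel library must be available
    bit_width = 4
    quantizer = StretchedUnifTorchaoQuantizer(bit_width)

    model = torch.nn.Sequential(torch.nn.Linear(128, 256))

    # Low-bit stretched-uniform weights plus dynamic int8 asymmetric
    # per-token activation quantization, matching the config in the tests.
    quantize_(
        model,
        StretchedIntxWeightConfig(
            b=bit_width,
            quant_min=quantizer.quant_min,
            quant_max=quantizer.quant_max,
            granularity=PerGroup(32),
            activation_quantization="int8_asym_per_token",
        ),
    )

    # Convert a copy in place to the int8-activation LUT tensors for aarch64.
    lut_model = deepcopy(model)
    _convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor")

    # The two models should agree closely; the tests assert SQNR > 40 dB
    # for float32 inputs.
    x = torch.randn(5, 128)
    print(compute_error(model(x), lut_model(x)).item())
```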