From 5ffde172b51b3a1013e983963d2e85b1484694ab Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Fri, 15 Aug 2025 15:03:47 +0800 Subject: [PATCH 1/3] [main][quantization] Adapt to the new format of ds w4a8 quantization weights Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/quantization/test_w4a8_dynamic.py | 121 ++++++++++++++++----- vllm_ascend/quantization/quant_config.py | 7 +- vllm_ascend/quantization/w4a8_dynamic.py | 118 +++++++++++++------- 3 files changed, 177 insertions(+), 69 deletions(-) diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index 8c52e3252f..c4366703b7 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -1,3 +1,4 @@ +import copy from unittest.mock import Mock, patch import torch @@ -31,79 +32,139 @@ def test_get_pergroup_param(self): class TestAscendW4A8DynamicFusedMoEMethod(TestBase): + experts = 8 + input_size = 16 + output_size = 56 + group_size = 2 + @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group') @patch("vllm_ascend.ascend_config.get_ascend_config") @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group') @patch('torch.distributed.get_rank', return_value=0) def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config, - mock_get_ep_group): + mock_get_ep_group, get_current_vllm_config): mock_ascend_config = Mock() mock_ascend_config.torchair_graph_config = Mock(enabled=False) mock_get_ascend_config.return_value = mock_ascend_config + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock(quant_description={ + "group_size": self.group_size, + "version": "0.0.0" + }) + mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True) + get_current_vllm_config.return_value = mock_vllm_config self.quant_method = AscendW4A8DynamicFusedMoEMethod() def test_get_weight(self): - param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16) + # old quant version weight + param_dict = self.quant_method.get_weight(self.experts, + self.input_size, + self.output_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) + self.assertEqual(param_dict["w13_weight"].shape, + (self.experts, 2 * self.input_size, self.output_size)) + # new quant version weight + self.quant_method.new_quant_version = True + param_dict = self.quant_method.get_weight(self.experts, + self.input_size, + self.output_size, + torch.bfloat16) self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) - self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14)) + self.assertEqual(param_dict["w13_weight"].shape, + (self.experts, self.input_size, self.output_size)) - @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') - def test_get_dynamic_quant_param(self, mock_get_current_vllm_config): - mock_vllm_config = Mock() - mock_vllm_config.quant_config = Mock( - quant_description={"group_size": 2}) - mock_get_current_vllm_config.return_value = mock_vllm_config + def test_get_dynamic_quant_param(self): + # old quant version weight param_dict = self.quant_method.get_dynamic_quant_param( - 8, 4, 14, torch.bfloat16) + self.experts, self.input_size, self.output_size, torch.bfloat16) self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) - self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1)) + self.assertEqual(param_dict["w13_weight_scale"].shape, + (self.experts, 2 * self.input_size, 
1)) self.assertEqual(param_dict["w13_weight_scale_second"].dtype, torch.bfloat16) self.assertEqual(param_dict["w13_weight_scale_second"].shape, - (8, 8, 7)) + (self.experts, 2 * self.input_size, + self.output_size // self.group_size)) self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) - self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1)) + self.assertEqual(param_dict["w2_weight_scale"].shape, + (self.experts, self.output_size, 1)) self.assertEqual(param_dict["w2_weight_scale_second"].dtype, torch.bfloat16) self.assertEqual(param_dict["w2_weight_scale_second"].shape, - (8, 14, 2)) + (self.experts, self.output_size, + self.input_size // self.group_size)) + # new quant version weight + self.quant_method.new_quant_version = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32) + self.assertEqual( + param_dict["w2_scale_bias"].shape, + (self.experts, self.output_size, 16 // self.quant_method.tp_size)) @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): + # old quant version weight layer = torch.nn.Module() - layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14), - dtype=torch.int8), + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4), - dtype=torch.int8), + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + dtype=torch.int8), requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (8, 8, 1), dtype=torch.bfloat16), + (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), requires_grad=False) - layer.w13_weight_offset = torch.nn.Parameter(torch.zeros( - (8, 8, 1), dtype=torch.bfloat16), - requires_grad=False) layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( - (8, 8, 7), dtype=torch.bfloat16), + (self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.bfloat16), requires_grad=False) layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (8, 14, 1), dtype=torch.bfloat16), + (self.experts, self.output_size, 1), dtype=torch.bfloat16), requires_grad=False) - layer.w2_weight_offset = torch.nn.Parameter(torch.zeros( - (8, 14, 1), dtype=torch.bfloat16), - requires_grad=False) layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( - (8, 14, 2), dtype=torch.bfloat16), + (self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.bfloat16), requires_grad=False) + new_layer = copy.deepcopy(layer) mock_npu.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor() self.quant_method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "w13_scale_bias")) - self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8)) + self.assertEqual(layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32) self.assertTrue(hasattr(layer, "w2_scale_bias")) - self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14)) + self.assertEqual(layer.w2_scale_bias.data.shape, + (self.experts, self.output_size)) self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.quant_method.new_quant_version = 
True + new_layer.w13_weight.data = torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8) + new_layer.w2_weight.data = torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8) + w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), + dtype=torch.float32) + new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros( + (self.experts, self.output_size, 16 // self.quant_method.tp_size), + dtype=torch.float32) + new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + self.quant_method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) + self.assertEqual(new_layer.w2_scale_bias.data.shape, + (self.experts, self.output_size)) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index abd7625ec1..ee6793b718 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -44,7 +44,7 @@ @register_quantization_config(ASCEND_QUATIZATION_METHOD) class AscendQuantConfig(QuantizationConfig): """Config class for Ascend - + This class is a general class that parse quantization configs that are supported on ascend hardware. """ @@ -295,6 +295,9 @@ def create_weights( extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + per_group_param = [ + "weight_scale_second", "weight_offset_second", "scale_bias" + ] dynamic_quant_param = self.quant_method.get_dynamic_quant_param( num_experts, intermediate_size_per_partition, hidden_size, params_dtype) @@ -302,7 +305,7 @@ def create_weights( param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) - if "weight_scale_second" in param_key or "weight_offset_second" in param_key: + if any(fields in param_key for fields in per_group_param): setattr(param, "quant_method", FusedMoeWeightScaleSupported.GROUP.value) diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index ce238f40ad..f7d838dd32 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -136,6 +136,18 @@ def __init__(self): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + # NOTE: new quantize weights: 2 int4 pack into int8 + self.new_quant_version = quant_version == "1.0.0" + self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size + if self.new_quant_version and self.tp_size > 16: + raise ValueError( + "The current weight does not support moe part tp>16.") + try: device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group @@ -146,32 +158,32 @@ def __init__(self): except AttributeError: self.moe_all_to_all_group_name = "" - @staticmethod - def get_weight(num_experts: int, intermediate_size_per_partition: int, - hidden_sizes: int, + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: param_dict = {} + if self.new_quant_version: + 
w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_sizes // 2 + else: + w13_output_size = 2 * intermediate_size_per_partition + w2_output_size = hidden_sizes + param_dict["w13_weight"] = torch.empty(num_experts, - 2 * - intermediate_size_per_partition, + w13_output_size, hidden_sizes, dtype=torch.int8) param_dict["w2_weight"] = torch.empty(num_experts, - hidden_sizes, + w2_output_size, intermediate_size_per_partition, dtype=torch.int8) return param_dict - @staticmethod - def get_dynamic_quant_param(num_experts: int, + def get_dynamic_quant_param(self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: param_dict = {} - config = get_current_vllm_config() - group_size = config.quant_config.quant_description.get( - "group_size", 256) - param_dict["w13_weight_scale"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, @@ -187,13 +199,13 @@ def get_dynamic_quant_param(num_experts: int, param_dict["w13_weight_scale_second"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, - hidden_sizes // group_size, + hidden_sizes // self.group_size, dtype=params_dtype) param_dict["w13_weight_offset_second"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, - hidden_sizes // group_size, + hidden_sizes // self.group_size, dtype=params_dtype) param_dict["w2_weight_scale"] = torch.empty(num_experts, @@ -207,14 +219,25 @@ def get_dynamic_quant_param(num_experts: int, param_dict["w2_weight_scale_second"] = torch.empty( num_experts, hidden_sizes, - intermediate_size_per_partition // group_size, + intermediate_size_per_partition // self.group_size, dtype=params_dtype) param_dict["w2_weight_offset_second"] = torch.empty( num_experts, hidden_sizes, - intermediate_size_per_partition // group_size, + intermediate_size_per_partition // self.group_size, dtype=params_dtype) + if self.new_quant_version: + param_dict["w13_scale_bias"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w2_scale_bias"] = torch.empty(num_experts, + hidden_sizes, + 16 // self.tp_size, + dtype=torch.float32) + return param_dict def apply( @@ -320,12 +343,17 @@ def apply( def process_scale(self, weight: torch.Tensor, scale, per_group_scale): group_num, k, n = weight.shape + # the weight of the new version is reduced by half by pack n, so it needs to be restored + if self.new_quant_version: + n = n * 2 per_group_scale = per_group_scale.reshape(group_num, -1, n) group_num, quantgroup_num, n = per_group_scale.shape - weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ - per_group_scale.reshape([group_num, quantgroup_num, 1, n]) - weight_high = weight_high.reshape([group_num, k, n]) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + bias = None + if not self.new_quant_version: + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) scale_fp32 = (scale * per_group_scale).to(torch.float16).to( torch.float32) scale_fp32_np = scale_fp32.cpu().numpy() @@ -342,6 +370,32 @@ def process_scale(self, weight: torch.Tensor, scale, per_group_scale): sscale_uint64_tensor = sscale_uint64_tensor.npu() return sscale_uint64_tensor, bias + def update_bias(self, layer, w13_bias, w2_bias): + if self.new_quant_version: 
+ layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + else: + w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + + def pack_to_int32(self, weight: torch.Tensor): + if self.new_quant_version: + group_num, k, n = weight.shape + assert n % 4 == 0, "the last dim of weight needs to be divided by 4" + packed_n = n // 4 + # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 + packed_weight = torch.from_numpy( + np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) + return packed_weight.reshape(group_num, k, packed_n).npu() + else: + return torch_npu.npu_quantize(weight.to(torch.float32), + torch.tensor([1.]).npu(), None, + torch.quint4x2, -1, False) + def process_weights_after_loading(self, layer): if self.transpose_weight: layer.w13_weight.data = layer.w13_weight.data.transpose( @@ -352,29 +406,19 @@ def process_weights_after_loading(self, layer): 1, 2).contiguous() layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( 1, 2).contiguous() - layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( - layer.w13_weight_offset.data.shape[0], -1) - layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( - layer.w2_weight_offset.data.shape[0], -1) layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( 1, 2).contiguous() layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( 1, 2).contiguous() - layer.w13_weight_scale_second.data, bias = self.process_scale( + layer.w13_weight_scale_second.data, w13_bias = self.process_scale( layer.w13_weight, layer.w13_weight_scale.data, layer.w13_weight_scale_second.data) - param = torch.nn.Parameter(bias, requires_grad=False) - layer.register_parameter("w13_scale_bias", param) - layer.w2_weight_scale_second.data, bias1 = self.process_scale( + layer.w2_weight_scale_second.data, w2_bias = self.process_scale( layer.w2_weight, layer.w2_weight_scale.data, layer.w2_weight_scale_second.data) - param = torch.nn.Parameter(bias1, requires_grad=False) - layer.register_parameter("w2_scale_bias", param) - - layer.w13_weight.data = torch_npu.npu_quantize( - layer.w13_weight.data.to(torch.float32), - torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) - layer.w2_weight.data = torch_npu.npu_quantize( - layer.w2_weight.data.to(torch.float32), - torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) + + self.update_bias(layer, w13_bias, w2_bias) + + layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) From 13d6fe0d1b33998b89cca5dd639bf1f1187781b5 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Fri, 15 Aug 2025 22:34:23 +0800 Subject: [PATCH 2/3] add e2e Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- .../e2e/multicard/test_offline_inference_distributed.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 459d826972..e869c2d599 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ 
b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -31,6 +31,10 @@ from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +DEEPSEEK_W4A8_MODELS = [ + "vllm-ascend/DeepSeek-V3-W4A8-Pruing", + "vllm-ascend/DeepSeek-R1-w4a8-pruning" +] def test_models_distributed_QwQ(): @@ -211,14 +215,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC(): vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS) @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"}) -def test_models_distributed_DeepSeek_W4A8DYNAMIC(): +def test_models_distributed_DeepSeek_W4A8DYNAMIC(model): prompts = [ "Hello, my name is", ] max_tokens = 5 with VllmRunner( - snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning"), + snapshot_download(model), dtype="auto", tensor_parallel_size=2, quantization="ascend", From 547075a04c0bf59fd31fd5742c786c0130517d11 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Wed, 20 Aug 2025 18:39:49 +0800 Subject: [PATCH 3/3] trigger ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/quantization/test_w4a8_dynamic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index c4366703b7..7bee119312 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -57,7 +57,7 @@ def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config, self.quant_method = AscendW4A8DynamicFusedMoEMethod() def test_get_weight(self): - # old quant version weight + # old quant version w4a8 weight param_dict = self.quant_method.get_weight(self.experts, self.input_size, self.output_size,
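
Note on the packing step above: in the new w4a8 weight format each int8 in the checkpoint already carries two int4 values, and pack_to_int32 then reinterprets every four such bytes as one int32 (eight int4 weights per int32), because PyTorch has no native int4 storage. The following is a minimal standalone sketch of that byte reinterpretation using NumPy only; the shapes and the pack_int8_to_int32 helper name are illustrative and not part of the patch, and it mirrors the np.frombuffer(..., dtype=np.int32) trick rather than calling any vllm-ascend or torch_npu API.

import numpy as np

def pack_int8_to_int32(weight: np.ndarray) -> np.ndarray:
    """Reinterpret groups of 4 int8 bytes as one int32 along the last axis."""
    assert weight.dtype == np.int8
    experts, k, n = weight.shape
    assert n % 4 == 0, "the last dim of weight must be divisible by 4"
    # A C-contiguous view lets NumPy reinterpret the raw bytes in place:
    # 4 int8 bytes (holding 8 packed int4 values) become one int32,
    # equivalent to np.frombuffer(weight.tobytes(), dtype=np.int32).reshape(...).
    packed = np.ascontiguousarray(weight).view(np.int32)
    return packed.reshape(experts, k, n // 4)

if __name__ == "__main__":
    # Illustrative shape only: 8 experts, k=16, n=28 already-halved int8 columns.
    w = np.random.randint(-128, 128, size=(8, 16, 28), dtype=np.int8)
    packed = pack_int8_to_int32(w)
    print(packed.shape)  # (8, 16, 7)
    # Round trip: viewing the int32 result as int8 recovers the original bytes.
    assert np.array_equal(packed.view(np.int8).reshape(w.shape), w)

Like the patch's implementation, this is a pure reinterpretation of memory (byte-order dependent, no values are rescaled), which is why the unit tests only need to check the resulting dtype and the n // 4 shape reduction.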