
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in W8A8_DYNAMIC quantized MoE layers #2275

Open · wants to merge 43 commits into vllm-project:main from zhoux77899:main_gmmswigluquant
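Per the title and the test changes below, the core change replaces the separate `npu_grouped_matmul` → `npu_swiglu` → `npu_dynamic_quant` sequence between the two grouped matmuls of the quantized MoE MLP with the single fused `torch_npu.npu_grouped_matmul_swiglu_quant` kernel. A minimal before/after sketch of that call pattern follows; argument names, ordering, and scale handling are assumptions for illustration based on how the unit tests mock these kernels, not the authoritative `torch_npu` signatures.

```python
# Illustrative sketch only: argument names and exact signatures are assumed
# from how the tests below mock these torch_npu kernels.
import torch
import torch_npu


def moe_mlp_unfused(hidden_states, w1, w1_scale, w2, w2_scale, group_list):
    # Previous path: three separate kernels between the two grouped matmuls
    # (per-channel/per-token scale handling omitted for brevity).
    x_int8, x_scale = torch_npu.npu_dynamic_quant(hidden_states)
    gate_up = torch_npu.npu_grouped_matmul([x_int8], [w1],
                                           group_list=group_list)[0]
    act = torch_npu.npu_swiglu(gate_up)                       # activation
    act_int8, act_scale = torch_npu.npu_dynamic_quant(act)    # re-quantize
    return torch_npu.npu_grouped_matmul([act_int8], [w2],
                                        group_list=group_list)[0]


def moe_mlp_fused(hidden_states, w1, w1_scale, w2, w2_scale, group_list):
    # New path: one fused kernel emits the already re-quantized activation.
    # The tests mock it with a 3-tuple (int8 activation, scale, extra output).
    x_int8, x_scale = torch_npu.npu_dynamic_quant(hidden_states)
    act_int8, act_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
        x_int8, w1, group_list, x_scale, w1_scale)            # assumed order
    return torch_npu.npu_grouped_matmul([act_int8], [w2],
                                        group_list=group_list)[0]
```

Collapsing these steps removes two intermediate kernel launches and the extra high-precision activation round-trip per expert group, which is the performance win the `feat(performance)` commit targets.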
Commits (43)
a5aefdf
feat(performance): support `GroupedMatmulSwigluQuant` in `W8A8_DYNAMI…
zhoux77899 Aug 8, 2025
0f688cd
fix(lint): fix lint
zhoux77899 Aug 8, 2025
840c03f
fix(bug): fix bug
zhoux77899 Aug 8, 2025
cdf5e1e
feat(ops): enable grouped_matmul_swiglu_quant by default
zhoux77899 Aug 8, 2025
c3c0913
fix(lint): fix lint
zhoux77899 Aug 8, 2025
f05687f
fix(test): fix broken test
zhoux77899 Aug 8, 2025
4f3afe6
fix(lint): fix lint
zhoux77899 Aug 8, 2025
3b32dc8
fix(test): temporally skip broken test due to oom
zhoux77899 Aug 9, 2025
a3c9b44
fix(test): change bias1 to tensor
zhoux77899 Aug 9, 2025
67e9872
Merge branch 'main' into main_gmmswigluquant
zhoux77899 Aug 11, 2025
68e31db
fix(bug): update group_list handling and weight scale in dynamic methods
zhoux77899 Aug 11, 2025
a3715ec
fix(lint): fix lint
zhoux77899 Aug 11, 2025
58d6371
fix(lint): fix lint
zhoux77899 Aug 11, 2025
a46315d
feat(ops): replace all splited gmm and swiglu
zhoux77899 Aug 12, 2025
5ee5a83
Merge branch 'main_gmmswigluquant' of https://github.com/zhoux77899/v…
zhoux77899 Aug 12, 2025
0ea5246
fix(lint): fix lint
zhoux77899 Aug 12, 2025
d9b16fc
feat(quantization): split w4a8 and w8a8 apply
zhoux77899 Aug 12, 2025
9ade98e
fix(test): replace w8a8 function in apply
zhoux77899 Aug 12, 2025
6af87be
feat(cumsum): add cumsum_group_list function for group list processing
zhoux77899 Aug 13, 2025
aed8264
fix(lint): fix lint
zhoux77899 Aug 13, 2025
c40f61d
fix(lint): fix lint
zhoux77899 Aug 13, 2025
c437d39
feat(torchair): consider not using gmmswigluquant when torchair enabled
zhoux77899 Aug 14, 2025
3309579
fix(lint): fix lint
zhoux77899 Aug 14, 2025
04523e6
Merge branch 'main' into main_gmmswigluquant
zhoux77899 Aug 14, 2025
3d2b849
fix(dtype): unify `w1_scale` dtype
zhoux77899 Aug 14, 2025
51ec3d8
fix(lint): fix lint
zhoux77899 Aug 14, 2025
4705afb
fix(bias): unify `bias1`
zhoux77899 Aug 14, 2025
bd74a40
test(ut): add w8a8_dynamic ut
zhoux77899 Aug 14, 2025
b627283
fix(lint): fix lint
zhoux77899 Aug 14, 2025
cddcf20
fix(ut): fix broken ut
zhoux77899 Aug 14, 2025
402760a
fix(lint): fix lint
zhoux77899 Aug 14, 2025
cda8d3b
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 15, 2025
6a5fb6f
test(ci): add `w4a8_dynamic` ut
zhoux77899 Aug 15, 2025
b341141
fix(test): fix broken ut
zhoux77899 Aug 15, 2025
064217e
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 15, 2025
826d713
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 19, 2025
cbbaae4
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 20, 2025
e648aa1
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 20, 2025
eb6b0d5
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 20, 2025
fd675eb
Merge branch 'vllm-project:main' into main_gmmswigluquant
zhoux77899 Aug 20, 2025
b0585aa
Merge branch 'main' into main_gmmswigluquant
zhoux77899 Aug 20, 2025
e2d7cc3
fix(test): fix broken ut
zhoux77899 Aug 20, 2025
90ea998
fix(lint): fix lint
zhoux77899 Aug 20, 2025
Files changed
73 changes: 71 additions & 2 deletions tests/ut/quantization/test_w4a8_dynamic.py
@@ -5,7 +5,8 @@

from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import (
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod, apply_mlp,
apply_mlp_decode)


class TestAscendW4A8DynamicLinearMethod(TestBase):
@@ -104,9 +105,11 @@ def test_get_dynamic_quant_param(self):
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))

@patch('torch_npu.npu_format_cast_')
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
def test_process_weights_after_loading(self, mock_npu_format_cast,
mock_npu, mock_npu_quantize):
# old quant version weight
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros(
@@ -137,6 +140,7 @@ def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):

mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
mock_npu_format_cast.return_value = torch.Tensor()
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -168,3 +172,68 @@ def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))

@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_dynamic_quant")
def test_apply_mlp(
self,
mock_dynamic_quant,
mock_grouped_matmul,
mock_swiglu,
):
placeholder = torch.randn(128, 128, dtype=torch.bfloat16)
placeholder_int8 = torch.randint(0, 100, (128, 128), dtype=torch.int8)
placeholder_ones = torch.ones(128, dtype=torch.int32)

mock_dynamic_quant.return_value = (
placeholder_int8,
placeholder_ones,
)
mock_grouped_matmul.return_value = [placeholder]
mock_swiglu.return_value = placeholder

result = apply_mlp(
hidden_states=placeholder,
w1=placeholder,
w1_scale=placeholder,
w2=placeholder,
w2_scale=placeholder,
group_list=placeholder,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)

@patch("torch_npu.npu_dequant_swiglu_quant")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_dynamic_quant")
def test_apply_mlp_decode(
self,
mock_dynamic_quant,
mock_grouped_matmul,
mock_dequant_swiglu_quant,
):
placeholder = torch.randn(128, 128, dtype=torch.bfloat16)
placeholder_int8 = torch.randint(0, 100, (128, 128), dtype=torch.int8)
placeholder_ones = torch.ones(128, dtype=torch.int32)

mock_dynamic_quant.return_value = (
placeholder_int8,
placeholder_ones,
)
mock_grouped_matmul.return_value = [placeholder]
mock_dequant_swiglu_quant.return_value = (
placeholder_int8,
placeholder_ones,
)

result = apply_mlp_decode(
hidden_states=placeholder,
w1=placeholder,
w1_scale=placeholder,
w2=placeholder,
w2_scale=placeholder,
group_list=placeholder,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
209 changes: 199 additions & 10 deletions tests/ut/quantization/test_w8a8_dynamic.py
@@ -3,31 +3,159 @@
import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
from vllm_ascend.quantization.w8a8_dynamic import (
fused_experts, fused_experts_with_all2all, fused_experts_with_allgather,
fused_experts_with_mc2)


class TestAscendW8A8FusedMoEMethod(TestBase):

def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.top_k = 8
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)

@patch("torch_npu.npu_grouped_matmul_finalize_routing")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch.zeros")
@patch("torch_npu.npu_moe_init_routing_v2")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch.distributed.get_world_size")
@patch("torch.distributed.get_rank")
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
def test_fused_experts_with_allgather(
self,
mock_get_ep_group,
mock_get_rank,
mock_get_world_size,
mock_dynamic_quant,
mock_moe_init_routing_v2,
mock_zeros,
mock_grouped_matmul_swiglu_quant,
mock_grouped_matmul_finalize_routing,
):
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)

expert_map = MagicMock()
ep_group = MagicMock()
ep_group.device_group = "ep_group"

mock_get_ep_group.return_value = ep_group
mock_get_rank.return_value = 0
mock_get_world_size.return_value = 1
mock_dynamic_quant.return_value = (
placeholder_int8,
placeholder_ones,
)
mock_moe_init_routing_v2.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
self.placeholder,
)
mock_zeros.return_value = torch.zeros(
(self.num_tokens, self.hidden_size), dtype=torch.bfloat16)
mock_grouped_matmul_swiglu_quant.return_value = (
placeholder_int8,
self.placeholder,
self.placeholder,
)
mock_grouped_matmul_finalize_routing.return_value = self.placeholder

result = fused_experts_with_allgather(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=self.top_k,
expert_map=expert_map,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))

@patch("torch_npu.npu_moe_distribute_combine_v2", create=True)
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch_npu.npu_moe_distribute_dispatch_v2", create=True)
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_soc_version")
@patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
def test_fused_experts_with_mc2(
self,
mock_get_mc2_group,
mock_get_ascend_soc_version,
mock_dispatch_v2,
mock_grouped_matmul_swiglu_quant,
mock_grouped_matmul,
mock_combine_v2,
):
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)

expert_map = MagicMock()
ep_group = MagicMock()
ep_group.rank_in_group = 0
ep_group.world_size = 1
mock_get_mc2_group.return_value = ep_group
mock_get_ascend_soc_version.return_value = MagicMock()
mock_dispatch_v2.return_value = (
self.placeholder,
self.placeholder,
self.placeholder,
placeholder_ones,
placeholder_ones,
)
mock_grouped_matmul_swiglu_quant.return_value = (
placeholder_int8,
self.placeholder,
self.placeholder,
)
mock_grouped_matmul.return_value = self.placeholder
mock_combine_v2.return_value = self.placeholder

result = fused_experts_with_mc2(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=self.top_k,
expert_map=expert_map,
moe_all_to_all_group_name="group",
log2phy=None,
global_redundant_expert_num=256,
mc2_mask=self.placeholder,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))

@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
def test_fused_experts_with_all2all(
self, mock_moe_init_routing, mock_moe_finalize_routing,
mock_dynamic_quant, mock_grouped_matmul_swiglu_quant,
mock_grouped_matmul, mock_moe_re_routing, mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
@@ -49,7 +177,11 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_grouped_matmul_swiglu_quant.return_value = (
placeholder_int8,
self.placeholder,
self.placeholder,
)
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
Expand All @@ -64,7 +196,7 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
top_k=self.top_k,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
@@ -73,3 +205,60 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))

@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_compute_expert_tokens")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts(
self,
mock_moe_init_routing,
mock_moe_compute_expert_tokens,
mock_dynamic_quant,
mock_grouped_matmul_swiglu_quant,
mock_grouped_matmul,
mock_moe_finalize_routing,
):
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)

mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_compute_expert_tokens.return_value = placeholder_ones
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_grouped_matmul_swiglu_quant.return_value = (
placeholder_int8,
self.placeholder,
self.placeholder,
)
mock_moe_finalize_routing.return_value = self.placeholder

result = fused_experts(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=self.top_k,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))