
Commit 63c363d

[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to vllm_ascend/ops/fused_moe/fused_moe.py.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@17c540a

Signed-off-by: Pr0Wh1teGivee <[email protected]>
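For downstream imports, the rename maps roughly as sketched below. This is an illustrative migration note, not an exhaustive list: the paths and class names are taken from this PR's description and the renamed test files, and the exact module that exports the renamed PrepareAndFinalize classes is an assumption.

```python
# Illustrative sketch of the import migration described in this PR.
# Symbol placement is inferred from the PR description and the test diffs
# below; it may not cover every module.

# Old layout (before this PR):
#   from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
#   from vllm_ascend.ops.moe.experts_selector import select_experts
#   from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import \
#       FusedMoEPrepareAndFinalize

# New layout (after this PR):
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
                                                 AscendUnquantizedFusedMoEMethod)
# Assumed: the renamed class lives in the renamed prepare_finalize module.
from vllm_ascend.ops.fused_moe.prepare_finalize import PrepareAndFinalize
```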
1 parent 0637e8f · commit 63c363d

25 files changed: +183 −199 lines changed

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 8 additions & 7 deletions
@@ -28,9 +28,10 @@
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.token_dispatcher import \
+    TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1]
@@ -182,7 +183,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
                return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
@@ -282,9 +283,9 @@ def test_select_experts(
                                dtype=torch.int32)
         custom_routing_function.return_value = (mock_weights, mock_ids)
 
-    with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
+    with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk, \
-        patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
               return_value=MagicMock(weight_prefetch_method=MagicMock())):
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
@@ -318,7 +319,7 @@ def test_select_experts(
 
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
                return_value=MagicMock(weight_prefetch_method=MagicMock())), \
         pytest.raises(ValueError,
                       match="Unsupported scoring function: invalid"):

tests/ut/models/conftest.py

Lines changed: 3 additions & 3 deletions
@@ -90,9 +90,9 @@ def mock_distributed():
     mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
     mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
 
-    with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
-            patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
-            patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
+    with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+            patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
+            patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
             patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                        _PP=pp_group), \
             patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \

tests/ut/ops/test_comm_utils.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from pytest_mock import MockerFixture
 
 from tests.ut.base import PytestBase
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
     _gather_along_first_dim, async_all_to_all,
     gather_from_sequence_parallel_region)
 

tests/ut/ops/test_common_fused_moe.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

tests/ut/ops/test_fused_ops.py renamed to tests/ut/ops/test_fused_moe.py

Lines changed: 65 additions & 26 deletions
@@ -24,9 +24,11 @@
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.fused_moe import (
+    AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
+                                               unified_apply_mlp)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch
 
 adapt_patch(True)
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
     mock_vllm_config.model_config.max_model_len = 2048
 
-    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
-                 return_value=mock_vllm_config)
-    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
+    mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
+    mocker.patch(
+        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
+        return_value=mock_vllm_config)
 
 
 @pytest.fixture
@@ -105,37 +108,37 @@ def mock_finalize(hidden_states, **kwargs):
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False),
                   enable_multistream_moe=False,
                   expert_map_path=None
              )), \
-        patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
              return_value=mock_forward_context_obj), \
         patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
              return_value=None), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
             return_value=None), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
             return_value=None), \
-        patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
             return_value=mock_forward_context_obj):
 
         yield {
@@ -319,8 +322,8 @@ def test_cumsum_group_list_with_type_2(self):
 
 class TestUnifiedApplyMLP(TestBase):
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_dynamic_quant')
     @patch('torch_npu.npu_dequant_swiglu_quant')
@@ -384,7 +387,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
 
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
@@ -426,7 +429,7 @@ def test_unified_apply_mlp_without_quantization(self,
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -486,7 +489,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -531,7 +534,7 @@ def test_unified_apply_mlp_without_quantization_310p(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
     @patch("torch_npu.npu_grouped_matmul")
     @patch("torch_npu.npu_swiglu")
     @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@@ -595,3 +598,39 @@ def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
         self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
+
+
+class TestLoadWeight(TestBase):
+
+    def test_load_w13_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+
+        expert_data = torch.randn(128, 8)
+        loaded_weight = torch.randn(128, 4)
+        moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+        expert_data = torch.randn(8, 128)
+        loaded_weight = torch.randn(128, 4)
+        moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+        expert_data = torch.randn(128, 8)
+        loaded_weight = torch.randn(128, 4)
+        moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+        expert_data = torch.randn(8, 128)
+        loaded_weight = torch.randn(128, 4)
+        moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+    def test_load_w2_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+        expert_data = torch.randn(128, 4)
+        loaded_weight = torch.randn(128, 8)
+        moe._load_w2(expert_data, 1, loaded_weight, 0)
+
+        expert_data = torch.randn(4, 128)
+        loaded_weight = torch.randn(128, 8)
+        moe._load_w2(expert_data, 1, loaded_weight, 0)

tests/ut/ops/test_moe_comm_method.py

Lines changed: 26 additions & 20 deletions
@@ -4,8 +4,9 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl)
+from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       AlltoAllCommImpl,
+                                                       MC2CommImpl)
 
 
 class TestMoECommMethod(TestBase):
@@ -24,12 +25,14 @@ def setUp(self):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
+    )
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
                                   mock_get_forward_context,
@@ -72,12 +75,11 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
-    )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
                            mock_get_forward_context,
                            mock_get_current_vllm_config):
@@ -121,12 +123,14 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
+    )
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
                                 mock_get_forward_context,
@@ -163,13 +167,15 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
+    )
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
    def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
                                   mock_get_forward_context,
