|
24 | 24 |
|
25 | 25 | from tests.ut.base import TestBase |
26 | 26 | from vllm_ascend.ascend_forward_context import MoECommType |
27 | | -from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod |
28 | | -from vllm_ascend.ops.moe.experts_selector import select_experts |
29 | | -from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp |
| 27 | +from vllm_ascend.ops.fused_moe.experts_selector import select_experts |
| 28 | +from vllm_ascend.ops.fused_moe.fused_moe import ( |
| 29 | + AscendFusedMoE, AscendUnquantizedFusedMoEMethod) |
| 30 | +from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list, |
| 31 | + unified_apply_mlp) |
30 | 32 | from vllm_ascend.utils import AscendSocVersion, adapt_patch |
31 | 33 |
|
32 | 34 | adapt_patch(True) |
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture): |
69 | 71 | mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) |
70 | 72 | mock_vllm_config.model_config.max_model_len = 2048 |
71 | 73 |
|
72 | | - mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config', |
73 | | - return_value=mock_vllm_config) |
74 | | - mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config', |
| 74 | + mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config', |
75 | 75 | return_value=mock_vllm_config) |
| 76 | + mocker.patch( |
| 77 | + 'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config', |
| 78 | + return_value=mock_vllm_config) |
76 | 79 |
|
77 | 80 |
|
78 | 81 | @pytest.fixture |
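
Note on the patch targets in the hunk above: unittest.mock replaces a name in the namespace where it is looked up, not where it is defined, so once these helpers are imported by vllm_ascend.ops.fused_moe.fused_moe, every patch path has to follow that importing module. A minimal, self-contained illustration of the rule using only the standard library (not the repo's code):

from unittest.mock import patch
import os.path

# os.path holds its own reference to exists(), so the patch must name the
# lookup site (os.path.exists), not the module that originally defined it.
with patch('os.path.exists', return_value=True):
    assert os.path.exists('/definitely/not/there')  # sees the mock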
@@ -105,37 +108,37 @@ def mock_finalize(hidden_states, **kwargs): |
105 | 108 |
|
106 | 109 | with patch('torch.distributed.get_rank', return_value=0), \ |
107 | 110 | patch('torch.distributed.get_world_size', return_value=4), \ |
108 | | - patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
109 | | - patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
110 | | - patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
111 | | - patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
| 111 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
| 112 | + patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
| 113 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
| 114 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
112 | 115 | patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
113 | | - patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
| 116 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
114 | 117 | patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
115 | 118 | patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', |
116 | 119 | return_value=mock_dp_and_tp_group(mocker)), \ |
117 | | - patch('vllm_ascend.ops.common_fused_moe.get_ascend_config', |
| 120 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config', |
118 | 121 | return_value=MagicMock( |
119 | 122 | torchair_graph_config=MagicMock(enabled=False), |
120 | 123 | enable_multistream_moe=False, |
121 | 124 | expert_map_path=None |
122 | 125 | )), \ |
123 | | - patch('vllm_ascend.ops.common_fused_moe.determine_expert_map', |
| 126 | + patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map', |
124 | 127 | return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ |
125 | | - patch('vllm_ascend.ops.common_fused_moe.get_forward_context', |
| 128 | + patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context', |
126 | 129 | return_value=mock_forward_context_obj), \ |
127 | | - patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', |
| 130 | + patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context', |
128 | 131 | return_value=mock_forward_context_obj), \ |
129 | 132 | patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ |
130 | | - patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', |
| 133 | + patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context', |
131 | 134 | return_value=mock_forward_context_obj), \ |
132 | | - patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', |
| 135 | + patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', |
133 | 136 | return_value=None), \ |
134 | | - patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', |
| 137 | + patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', |
135 | 138 | return_value=None), \ |
136 | | - patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', |
| 139 | + patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', |
137 | 140 | return_value=None), \ |
138 | | - patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', |
| 141 | + patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context', |
139 | 142 | return_value=mock_forward_context_obj): |
140 | 143 |
|
141 | 144 | yield { |
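
The fixture above chains a dozen-plus patch(...) context managers with backslash continuations. An equivalent pattern, shown here only as a sketch and not what this PR does, is contextlib.ExitStack, which enters each patch in a loop and unwinds them all together on exit; it assumes vllm_ascend is importable so patch can resolve the targets:

from contextlib import ExitStack
from unittest.mock import MagicMock, patch

with ExitStack() as stack:
    for target in (
        'vllm_ascend.ops.fused_moe.fused_moe.get_ep_group',
        'vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group',
    ):
        stack.enter_context(patch(target, return_value=MagicMock()))
    # every patch is active inside this block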
@@ -319,8 +322,8 @@ def test_cumsum_group_list_with_type_2(self): |
319 | 322 |
|
320 | 323 | class TestUnifiedApplyMLP(TestBase): |
321 | 324 |
|
322 | | - @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') |
323 | | - @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') |
| 325 | + @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context') |
| 326 | + @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p') |
324 | 327 | @patch('torch_npu.npu_grouped_matmul') |
325 | 328 | @patch('torch_npu.npu_dynamic_quant') |
326 | 329 | @patch('torch_npu.npu_dequant_swiglu_quant') |
@@ -384,7 +387,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant, |
384 | 387 |
|
385 | 388 | self.assertEqual(result.dtype, torch.bfloat16) |
386 | 389 |
|
387 | | - @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') |
| 390 | + @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p') |
388 | 391 | @patch('torch_npu.npu_grouped_matmul') |
389 | 392 | @patch('torch_npu.npu_swiglu') |
390 | 393 | @patch('torch_npu.npu_dynamic_quant') |
@@ -426,7 +429,7 @@ def test_unified_apply_mlp_without_quantization(self, |
426 | 429 | self.assertEqual(result.shape, hidden_states.shape) |
427 | 430 | self.assertEqual(result.dtype, torch.float16) |
428 | 431 |
|
429 | | - @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') |
| 432 | + @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context') |
430 | 433 | @patch('torch_npu.npu_grouped_matmul') |
431 | 434 | @patch('torch_npu.npu_swiglu') |
432 | 435 | @patch('torch_npu.npu_dynamic_quant') |
@@ -486,7 +489,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale( |
486 | 489 | self.assertEqual(result.shape, hidden_states.shape) |
487 | 490 | self.assertEqual(result.dtype, torch.bfloat16) |
488 | 491 |
|
489 | | - @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') |
| 492 | + @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p') |
490 | 493 | @patch('torch_npu.npu_grouped_matmul') |
491 | 494 | @patch('torch_npu.npu_swiglu') |
492 | 495 | @patch('torch_npu.npu_dynamic_quant') |
@@ -531,7 +534,7 @@ def test_unified_apply_mlp_without_quantization_310p( |
531 | 534 | self.assertEqual(result.shape, hidden_states.shape) |
532 | 535 | self.assertEqual(result.dtype, torch.float16) |
533 | 536 |
|
534 | | - @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context") |
| 537 | + @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context") |
535 | 538 | @patch("torch_npu.npu_grouped_matmul") |
536 | 539 | @patch("torch_npu.npu_swiglu") |
537 | 540 | @patch("torch_npu.npu_grouped_matmul_swiglu_quant") |
@@ -595,3 +598,39 @@ def test_unified_apply_mlp_with_quantization_and_fusion_mlp( |
595 | 598 | self.assertTrue(mock_forward_context.with_quant) |
596 | 599 | self.assertEqual(result.shape, hidden_states.shape) |
597 | 600 | self.assertEqual(result.dtype, torch.bfloat16) |
| 601 | + |
| 602 | + |
| 603 | +class TestLoadWeight(TestBase): |
| 604 | + |
| 605 | + def test_load_w13_transpose(self): |
| 606 | + with patch.object(AscendFusedMoE, "__init__", |
| 607 | + lambda self, *args, **kwargs: None): |
| 608 | + moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8) |
| 609 | + |
| 610 | + expert_data = torch.randn(128, 8) |
| 611 | + loaded_weight = torch.randn(128, 4) |
| 612 | + moe._load_w13(expert_data, 1, "w1", loaded_weight, 0) |
| 613 | + |
| 614 | + expert_data = torch.randn(8, 128) |
| 615 | + loaded_weight = torch.randn(128, 4) |
| 616 | + moe._load_w13(expert_data, 1, "w1", loaded_weight, 0) |
| 617 | + |
| 618 | + expert_data = torch.randn(128, 8) |
| 619 | + loaded_weight = torch.randn(128, 4) |
| 620 | + moe._load_w13(expert_data, 1, "w3", loaded_weight, 0) |
| 621 | + |
| 622 | + expert_data = torch.randn(8, 128) |
| 623 | + loaded_weight = torch.randn(128, 4) |
| 624 | + moe._load_w13(expert_data, 1, "w3", loaded_weight, 0) |
| 625 | + |
| 626 | + def test_load_w2_transpose(self): |
| 627 | + with patch.object(AscendFusedMoE, "__init__", |
| 628 | + lambda self, *args, **kwargs: None): |
| 629 | + moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8) |
| 630 | + expert_data = torch.randn(128, 4) |
| 631 | + loaded_weight = torch.randn(128, 8) |
| 632 | + moe._load_w2(expert_data, 1, loaded_weight, 0) |
| 633 | + |
| 634 | + expert_data = torch.randn(4, 128) |
| 635 | + loaded_weight = torch.randn(128, 8) |
| 636 | + moe._load_w2(expert_data, 1, loaded_weight, 0) |
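
The new TestLoadWeight cases construct AscendFusedMoE without running its real __init__ (patched to a no-op lambda), then call _load_w13/_load_w2 directly with both row-major and transposed expert_data shapes, using "no exception raised" as the pass condition. A toy sketch of that bypass pattern, with a hypothetical Layer class standing in for AscendFusedMoE:

from unittest.mock import patch
import torch

class Layer:
    def __init__(self):
        raise RuntimeError("heavy distributed setup we skip in unit tests")

    def _load(self, expert_data: torch.Tensor, loaded_weight: torch.Tensor):
        # The real helpers would decide here whether loaded_weight needs a
        # transpose before the copy; this toy version copies a plain slice.
        expert_data[:, :loaded_weight.shape[1]].copy_(loaded_weight)

with patch.object(Layer, "__init__", lambda self, *args, **kwargs: None):
    layer = Layer()                       # heavy __init__ never runs
    data = torch.zeros(8, 4)
    layer._load(data, torch.ones(8, 2))   # helper exercised in isolation
    assert data[:, :2].eq(1).all()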