
Commit 7c72764

Author: yangcheng

add ut

Signed-off-by: yangcheng <[email protected]>

1 parent: 5f0b42e

File tree

1 file changed: +365 -0 lines changed


tests/ut/ops/test_fused_ops.py

Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
from pytest_mock import MockerFixture

from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch  # noqa E402

adapt_patch(True)


def mock_ep_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.rank = 0
    mock_group.world_size = 4
    mock_group.device_group = "mock_group_ep"
    mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
    return mock_group


def mock_etp_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.world_size = 1
    mock_group.device_group = "mock_group_etp"
    return mock_group


def mock_dp_and_tp_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.world_size = 2
    mock_group.device_group = "mock_group"
    mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
    return mock_group


@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    # Patch the distributed groups, collectives and config helpers so that
    # AscendFusedMoE can be built and run without a real NPU/HCCL environment.
    with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
         patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_etp_group', return_value=mock_etp_group(mocker)), \
         patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10, 32))), \
         patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
         patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
               return_value=torch.randn(5, 32)), \
         patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
               return_value=torch.randn(5, 32)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False,
                                                   enable_multistream_moe=False),
                   expert_map_path=None)), \
         patch('vllm_ascend.ops.fused_moe.determine_expert_map',
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=MagicMock(
                   attn_metadata=MagicMock(max_num_tokens_across_dp=10),
                   dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]))), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
               return_value=MagicMock(
                   parallel_config=MagicMock(tensor_parallel_size=2),
                   scheduler_config=MagicMock(max_num_seqs=4),
                   model_config=MagicMock(max_model_len=2048))):
        yield


@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
    # Patch the torch_npu MoE kernels with CPU tensors of the expected shapes
    # so the routing/dispatch/combine code paths can run on the host.
    with patch('torch_npu.npu_moe_gating_top_k',
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             None)), \
         patch('torch_npu.npu_moe_init_routing',
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]))), \
         patch("torch_npu.npu_moe_compute_expert_tokens",
               return_value=torch.randn(8, 2)), \
         patch("torch_npu.npu_moe_distribute_dispatch",
               return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_moe_distribute_combine",
               return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_grouped_matmul",
               return_value=(torch.randn(8, 2), torch.randn(8, 2))), \
         patch("torch_npu.npu_swiglu", return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_moe_gating_top_k_softmax",
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]))), \
         patch("torch_npu.npu_moe_finalize_routing",
               return_value=torch.randn(16, 2)):
        yield


@pytest.fixture
def default_moe_config():
    """Default MoE config used by the tests below."""
    return {
        'num_experts': 8,
        'top_k': 2,
        'hidden_size': 512,
        'intermediate_size': 1024
    }


@pytest.fixture
def moe_method(mock_dist_env):
    return AscendUnquantizedFusedMoEMethod()


class Device(TypedDict):
    device_id: int
    device_expert: List[int]


class Layer(TypedDict):
    layer_id: int
    device_count: int
    device_list: List[Device]


class MockData(TypedDict):
    moe_layer_count: int
    layer_list: List[Layer]

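
# Illustrative example (unused by the tests below): a value matching the
# Device/Layer/MockData TypedDicts above, showing the per-layer expert
# placement structure they describe. The ids and expert lists here are
# arbitrary placeholders, not values taken from a real expert map file.
_EXAMPLE_EXPERT_MAP: MockData = {
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [
            {"device_id": 0, "device_expert": [0, 1, 2, 3]},
            {"device_id": 1, "device_expert": [4, 5, 6, 7]},
        ],
    }],
}
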
class MockQuantMethod(nn.Module):

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        if shared_experts:
            self.apply = MagicMock(
                return_value=(torch.randn(num_tokens, 32),
                              torch.randn(num_tokens, 10)))
        else:
            self.apply = MagicMock(return_value=torch.randn(num_tokens, 32))

    def forward(self, *args, **kwargs):
        return self.apply(*args, **kwargs)


class TestAscendFusedMoe:

    def test_init_no_quant(self, mock_dist_env, default_moe_config):
        layer = AscendFusedMoE(**default_moe_config)

        layer.w13_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['intermediate_size'] * 2,
                        default_moe_config['hidden_size']))
        layer.w2_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['hidden_size'],
                        default_moe_config['intermediate_size']))

        assert layer.num_experts == default_moe_config['num_experts']
        assert layer.top_k == default_moe_config['top_k']
        assert hasattr(layer, 'w13_weight')
        assert hasattr(layer, 'w2_weight')
        assert layer.moe_instance_id == 0

        # use_grouped_topk without the extra group parameters should raise
        with pytest.raises(AssertionError):
            error_config = default_moe_config.copy()
            error_config['use_grouped_topk'] = True
            layer = AscendFusedMoE(**error_config)

        # an unsupported scoring_func should be rejected
        with pytest.raises(ValueError):
            error_config = default_moe_config.copy()
            error_config['scoring_func'] = "random"
            layer = AscendFusedMoE(**error_config)

    def test_init_with_quant(self, mock_dist_env, default_moe_config):
        mock_quant_config = MagicMock()
        mock_quant_method = MagicMock()
        mock_quant_config.get_quant_method.return_value = mock_quant_method

        moe = AscendFusedMoE(**default_moe_config,
                             quant_config=mock_quant_config)

        assert moe.quant_method is not None

    @pytest.mark.parametrize(
        "others_param",
        [[None, MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
         [2, None, False, 5, None],
         [None, None, True, 5, None],
         [None, None, False, 1, None],
         [None, None, True, 5, 1],
         [None, None, False, 5, 1]])
    def test_forward(self, mock_dist_env, default_moe_config, others_param):
        """
        1. with shared_experts
        2. with an explicit top_k
        3. is_prefill is True
        4. a single token (decode)
        5. ep_size is 1 and is_prefill is True
        6. ep_size is 1 and is_prefill is False
        """
        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
        inputs = torch.randn(num_tokens, 32)
        router_logits = torch.randn(num_tokens, 8)
        moe = AscendFusedMoE(**default_moe_config)

        if ep_size == 1:
            moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
        output = moe.forward(inputs,
                             router_logits,
                             is_prefill=is_prefill,
                             top_k=top_k,
                             shared_experts=shared_experts)

        moe.quant_method.apply.assert_called_once()

        if shared_experts:
            assert output[0].shape == (num_tokens, 32)
            assert output[1].shape == (num_tokens, 10)
        else:
            assert output.shape == (num_tokens, 32)

    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
                                       default_moe_config):
        inputs = torch.randn(5, 32)
        router_logits = torch.randn(5, 8)
        moe = AscendFusedMoE(**default_moe_config)

        moe.quant_method = MockQuantMethod(None, 5)
        output = moe._forward_ms_fused_moe_comp(inputs,
                                                router_logits,
                                                is_prefill=False,
                                                real_top_k=1)

        moe.quant_method.apply.assert_called_once()

        assert output.shape == (5, 32)


class TestAscendUnquantizedFusedMoEMethod:

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        layer = MagicMock()
        layer.w13_weight.data = torch.randn(16, 32)
        layer.w2_weight.data = torch.randn(16, 32)

        moe_method.process_weights_after_loading(layer)

        assert isinstance(layer.w13_weight, torch.nn.Parameter)
        assert isinstance(layer.w2_weight, torch.nn.Parameter)
        assert not layer.w13_weight.requires_grad
        assert not layer.w2_weight.requires_grad

    @pytest.mark.parametrize(
        "others_param",
        [[256, 4, False], [128, 1, False], [128, 1, True], [128, 4, False]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
        """
        1. is_deepseek_v3_r1 is True and fused_experts_with_all2all is used
        2. select_experts and fused_experts are used
        3. select_gating_topk_softmax_experts and fused_experts are used
        4. select_experts and fused_experts_with_all2all_buffer are used
        """
        global_num_experts, ep_size, select_softmax = others_param
        with patch(
                "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
                select_softmax):
            moe_method.ep_group.world_size = ep_size
            x = torch.randn(8, 2, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()
            layer.w13_weight = torch.randn(8, 16, 1)
            layer.w2_weight = torch.randn(16, 8, 1)
            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=global_num_experts,
                                      is_prefill=False)

            if ep_size == 1:
                assert result.shape == (16, 2)
            else:
                assert result.shape == x.shape

    @pytest.mark.parametrize("others_param",
                             [[16, False], [1, True], [1, False], [4, False]])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
        """
        1. select_experts and fused_experts_with_mc2 are used
        2. select_experts and fused_experts_with_all2all_buffer are used
        3. select_experts and fused_experts_with_all2all are used
        4. select_experts and fused_experts are used
        """
        ep_size, alltoall_buffer = others_param
        with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
                   alltoall_buffer):
            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_group.world_size = ep_size
            x = torch.randn(8, 2, 2)
            if ep_size == 1:
                x = x.view(-1, 2)
            router_logits = torch.randn(8, 8)
            if alltoall_buffer:
                moe_method.max_model_len = 1
            layer = MagicMock()
            layer.w13_weight = torch.randn(8, 16, 1)
            layer.w2_weight = torch.randn(16, 8, 1)
            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=128,
                                      expert_map=expert_map,
                                      is_prefill=False)

            if ep_size == 16 or ep_size == 1:
                assert result.shape == (16, 2)
            else:
                assert result.shape == x.shape
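
For reference, a minimal sketch of how to run just this module locally, assuming a vllm-ascend checkout with pytest and pytest-mock installed (the exact invocation used by the project's CI may differ):

python -m pytest tests/ut/ops/test_fused_ops.py -v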
