
Commit da539f2

Author: yangcheng
Commit message: add ut
Signed-off-by: yangcheng <[email protected]>
Parent: 5f0b42e

File tree: 1 file changed (+358, -0 lines)


tests/ut/ops/test_fused_ops.py

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
from pytest_mock import MockerFixture

from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                       AscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch  # noqa E402

adapt_patch(True)

def mock_ep_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.rank = 0
    mock_group.world_size = 4
    mock_group.device_group = "mock_group_ep"
    mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
    return mock_group


def mock_dp_and_tp_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.world_size = 2
    mock_group.device_group = "mock_group"
    mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
    return mock_group

@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    # init dist env patch

    with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
         patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10, 32))), \
         patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
         patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
               return_value=torch.randn(5, 32)), \
         patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter',
               return_value=torch.randn(5, 32)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
                   expert_map_path=None
               )), \
         patch('vllm_ascend.ops.fused_moe.determine_expert_map',
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=MagicMock(
                   attn_metadata=MagicMock(max_num_tokens_across_dp=10),
                   dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
               )), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
               return_value=MagicMock(
                   parallel_config=MagicMock(tensor_parallel_size=2),
                   scheduler_config=MagicMock(max_num_seqs=4),
                   model_config=MagicMock(max_model_len=2048)
               )):
        yield

@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
    # init moe env patch

    with patch('torch_npu.npu_moe_gating_top_k', return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            None
    )), \
         patch('torch_npu.npu_moe_init_routing', return_value=(
             torch.randn(8, 2),
             torch.randint(0, 8, (8, 2)),
             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
         )), \
         patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
             torch.randn(8, 2)
         )), \
         patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_moe_distribute_combine", return_value=(
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_grouped_matmul", return_value=(
             (torch.randn(8, 2), torch.randn(8, 2))
         )), \
         patch("torch_npu.npu_swiglu", return_value=(
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
             torch.randn(8, 2),
             torch.randint(0, 8, (8, 2)),
             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
         )), \
         patch("torch_npu.npu_moe_finalize_routing", return_value=(
             torch.randn(16, 2)
         )):
        yield

@pytest.fixture
def default_moe_config():
    """default moe config"""
    return {
        'num_experts': 8,
        'top_k': 2,
        'hidden_size': 512,
        'intermediate_size': 1024
    }


@pytest.fixture
def moe_method(mock_dist_env):
    moe = MagicMock()
    moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
    return AscendUnquantizedFusedMoEMethod(moe)

class Device(TypedDict):
    device_id: int
    device_expert: List[int]


class Layer(TypedDict):
    layer_id: int
    device_count: int
    device_list: List[Device]


class MockData(TypedDict):
    moe_layer_count: int
    layer_list: List[Layer]


class MockQuantMethod(nn.Module):

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        if shared_experts:
            self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
                                                 torch.randn(num_tokens, 10)))
        else:
            self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))

    def forward(self, *args, **kwargs):
        return self.apply(*args, **kwargs)

class TestAscendFusedMoe:

    def test_init_no_quant(self, mock_dist_env, default_moe_config):
        layer = AscendFusedMoE(**default_moe_config)

        layer.w13_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['intermediate_size'] * 2,
                        default_moe_config['hidden_size']))
        layer.w2_weight = nn.Parameter(
            torch.randn(default_moe_config['num_experts'],
                        default_moe_config['hidden_size'],
                        default_moe_config['intermediate_size']))

        assert layer.num_experts == default_moe_config['num_experts']
        assert layer.top_k == default_moe_config['top_k']
        assert hasattr(layer, 'w13_weight')
        assert hasattr(layer, 'w2_weight')
        assert layer.moe_instance_id == 0

        # check group_topk
        with pytest.raises(AssertionError):
            error_config = default_moe_config.copy()
            error_config['use_grouped_topk'] = True
            layer = AscendFusedMoE(**error_config)

        # check scoring_func
        with pytest.raises(ValueError):
            error_config = default_moe_config.copy()
            error_config['scoring_func'] = "random"
            layer = AscendFusedMoE(**error_config)

    def test_init_with_quant(self, mock_dist_env, default_moe_config):
        mock_quant_config = MagicMock()
        mock_quant_method = MagicMock()
        mock_quant_config.get_quant_method.return_value = mock_quant_method

        moe = AscendFusedMoE(**default_moe_config,
                             quant_config=mock_quant_config)

        assert moe.quant_method is not None

    @pytest.mark.parametrize(
        "others_param",
        [[None,
          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
         [2, None, False, 5, None], [None, None, True, 5, None],
         [None, None, False, 1, None], [None, None, True, 5, 1],
         [None, None, False, 5, 1]])
    def test_forward(self, mock_dist_env, default_moe_config, others_param):
        """
        1 test has shared_experts
        2 test has top_k
        3 test is_prefill is true
        4 test single num_tokens (decode)
        5 test ep_size is 1 and is_prefill is true
        6 test ep_size is 1 and is_prefill is false
        """
        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
        inputs = torch.randn(num_tokens, 32)
        router_logits = torch.randn(num_tokens, 8)
        moe = AscendFusedMoE(**default_moe_config)

        if ep_size == 1:
            moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
        output = moe.forward(inputs,
                             router_logits,
                             is_prefill=is_prefill,
                             top_k=top_k,
                             shared_experts=shared_experts)

        moe.quant_method.apply.assert_called_once()

        if shared_experts:
            assert output[0].shape == (num_tokens, 32)
            assert output[1].shape == (num_tokens, 10)
        else:
            assert output.shape == (num_tokens, 32)

    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
                                       default_moe_config):
        inputs = torch.randn(5, 32)
        router_logits = torch.randn(5, 8)
        moe = AscendFusedMoE(**default_moe_config)

        moe.quant_method = MockQuantMethod(None, 5)
        output = moe._forward_ms_fused_moe_comp(inputs,
                                                router_logits,
                                                is_prefill=False,
                                                real_top_k=1)

        moe.quant_method.apply.assert_called_once()

        assert output.shape == (5, 32)

class TestAscendUnquantizedFusedMoEMethod:

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        layer = MagicMock()
        layer.w13_weight.data = torch.randn(16, 32)
        layer.w2_weight.data = torch.randn(16, 32)

        moe_method.process_weights_after_loading(layer)

        assert isinstance(layer.w13_weight, torch.nn.Parameter)
        assert isinstance(layer.w2_weight, torch.nn.Parameter)
        assert not layer.w13_weight.requires_grad
        assert not layer.w2_weight.requires_grad

    @pytest.mark.parametrize(
        "others_param",
        [[256, 4, False], [128, 1, False], [128, 1, True], [128, 4, False]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
        """
        1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
        2 test use select_experts and fused_experts
        3 test use select_gating_topk_softmax_experts and fused_experts
        4 test use select_experts and fused_experts_with_all2all_buffer
        """
        global_num_experts, ep_size, select_softmax = others_param
        with patch(
                "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
                select_softmax):
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()
            layer.w13_weight = torch.randn(8, 16, 1)
            layer.w2_weight = torch.randn(16, 8, 1)
            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=global_num_experts,
                                      is_prefill=False)

            if ep_size == 1:
                assert result.shape == (16, 2)
            else:
                assert result.shape == x.shape

@pytest.mark.parametrize("others_param",
323+
[[16, False], [1, True], [1, False], [4, False]])
324+
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
325+
mock_moe_env, others_param):
326+
"""
327+
1 test use_select_experts and use fused_expters_with_mc2
328+
2 test use_select_experts and fused_experts_with_all2all_buffer
329+
3 test use_select_experts and fused_experts_with_all2all
330+
4 test use_select_experts and fused_experts
331+
"""
332+
ep_size, alltoall_buffer = others_param
333+
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
334+
alltoall_buffer):
335+
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
336+
moe_method.ep_size = ep_size
337+
x = torch.randn(8, 2, 2)
338+
if ep_size == 1:
339+
x = x.view(-1, 2)
340+
router_logits = torch.randn(8, 8)
341+
if alltoall_buffer:
342+
moe_method.max_model_len = 1
343+
layer = MagicMock()
344+
layer.w13_weight = torch.randn(8, 16, 1)
345+
layer.w2_weight = torch.randn(16, 8, 1)
346+
result = moe_method.apply(layer=layer,
347+
x=x,
348+
router_logits=router_logits,
349+
top_k=2,
350+
renormalize=True,
351+
global_num_experts=128,
352+
expert_map=expert_map,
353+
is_prefill=False)
354+
355+
if ep_size == 16 or ep_size == 1:
356+
assert result.shape == (16, 2)
357+
else:
358+
assert result.shape == x.shape
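
Note (not part of the commit): a minimal sketch of how this new module could be exercised locally, assuming pytest, pytest-mock, torch, and an importable torch_npu are available in the environment. The fixtures above patch the torch_npu kernels and the distributed groups, so no NPU device should be strictly required for these unit tests; the runner file name below is hypothetical.

# run_fused_ops_ut.py -- hypothetical convenience runner, not part of this commit.
# Assumes pytest, pytest-mock, torch and torch_npu are importable from the repo root.
import sys

import pytest

if __name__ == "__main__":
    # -v reports each parametrized case of test_forward / test_apply_* separately.
    sys.exit(pytest.main(["-v", "tests/ut/ops/test_fused_ops.py"]))
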
