Commit 2693196

momo609 and wangxiaoxin-sherie authored
add gatherep select. (#2740)
### What this PR does / why we need it?

Adds gather-EP dispatcher selection: a `TokenDispatcherWithAllGather` path for expert-parallel MoE, enabled via the `VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` environment variable.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

Signed-off-by: wangxiaoxin-sherie <[email protected]>
Co-authored-by: wangxiaoxin-sherie <[email protected]>
1 parent 6666e52 commit 2693196

File tree: 3 files changed, +31 −6 lines
Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+import os
+import unittest
+from unittest import mock
+
+from vllm_ascend.ascend_forward_context import get_dispatcher_name
+
+
+class TestGetDispatcherName(unittest.TestCase):
+
+    def test_get_dispatcher_name(self):
+        result = get_dispatcher_name(1, False)
+        assert result == "TokenDispatcherWithAllGather"
+        result = get_dispatcher_name(4, False)
+        assert result == "TokenDispatcherWithAll2AllV"
+        result = get_dispatcher_name(16, True)
+        assert result == "TokenDispatcherWithAll2AllV"
+        result = get_dispatcher_name(16, False)
+        assert result == "TokenDispatcherWithMC2"
+        with mock.patch.dict(os.environ,
+                             {"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}):
+            result = get_dispatcher_name(16, False)
+            assert result == "TokenDispatcherWithAllGather"
```
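Note that the test flips the toggle by patching `os.environ` rather than a module attribute. That only works if `vllm_ascend.envs` resolves each variable lazily at attribute-access time, following vLLM's envs pattern. A minimal sketch of that assumed pattern (hypothetical module, not the actual `vllm_ascend.envs` source):

```python
# envs_sketch.py - assumed lazy-env pattern, mirroring vllm's envs module.
# Each variable is backed by a lambda that reads os.environ on every access,
# so mock.patch.dict(os.environ, ...) takes effect without reloading anything.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
    lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "0"))),
}


def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562): evaluated on each attribute access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```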

vllm_ascend/ascend_forward_context.py

Lines changed: 5 additions & 6 deletions
```diff
@@ -45,13 +45,12 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
 def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
     if ep_size == 1:
         return "TokenDispatcherWithAllGather"
-
-    if ep_size < 16:
-        return "TokenDispatcherWithAll2AllV"
-
-    if with_prefill:
+    elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1:
+        return "TokenDispatcherWithAllGather"
+    elif ep_size < 16 or with_prefill:
         return "TokenDispatcherWithAll2AllV"
-    return "TokenDispatcherWithMC2"
+    else:
+        return "TokenDispatcherWithMC2"


 @contextmanager
```
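For quick reference, the updated selection rules, exactly as exercised by the new unit test (a usage sketch; assumes a working vllm_ascend install):

```python
from vllm_ascend.ascend_forward_context import get_dispatcher_name

# Default behavior (VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP unset):
assert get_dispatcher_name(1, False) == "TokenDispatcherWithAllGather"  # no EP
assert get_dispatcher_name(4, False) == "TokenDispatcherWithAll2AllV"   # ep_size < 16
assert get_dispatcher_name(16, True) == "TokenDispatcherWithAll2AllV"   # prefill
assert get_dispatcher_name(16, False) == "TokenDispatcherWithMC2"       # decode, ep_size >= 16
# With VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1 in the environment, every
# ep_size > 1 case above selects "TokenDispatcherWithAllGather" instead.
```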

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
 import torch_npu
 from vllm.distributed.parallel_state import get_ep_group

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.distributed.tensor_parallel import \
     gather_from_sequence_parallel_region
@@ -50,6 +51,9 @@ def setup_token_dispatchers(ep_size: int, **kwargs):

     if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
         _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
+    elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \
+            and "TokenDispatcherWithAllGather" not in existing_dispatchers:
+        _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
     elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
         _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
     elif ep_size >= 16:
```
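`setup_token_dispatchers` applies the same precedence when registering dispatcher instances: first match wins, and the new gather-EP branch sits between the `ep_size == 1` case and the All2AllV case. A self-contained sketch of that ordering (hypothetical helper, not the real registration code):

```python
import os


def pick_dispatcher_name(ep_size: int) -> str:
    # Mirrors the branch order visible in setup_token_dispatchers; the real
    # code registers dispatcher objects and also skips already-registered ones.
    allgather_ep = bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "0")))
    if ep_size == 1:
        return "TokenDispatcherWithAllGather"
    if allgather_ep:  # new branch from this PR: gather-EP for any ep_size > 1
        return "TokenDispatcherWithAllGather"
    if ep_size < 16:
        return "TokenDispatcherWithAll2AllV"
    return "TokenDispatcherWithMC2"


assert pick_dispatcher_name(1) == "TokenDispatcherWithAllGather"
assert pick_dispatcher_name(8) == "TokenDispatcherWithAll2AllV"
assert pick_dispatcher_name(32) == "TokenDispatcherWithMC2"
```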
