Commit a4bb618

Merge branch 'vllm-project:main' into main_gmmswigluquant
2 parents: 7e83993 + e14f2ef
File tree: 10 files changed, +359 −370 lines

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,8 @@
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul

-from vllm_ascend.ops.fused_moe import fused_experts, select_experts
+from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts

 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -142,7 +143,7 @@ def test_select_experts(
                              dtype=torch.int32)
        custom_routing_function.return_value = (mock_weights, mock_ids)

-    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
+    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
               ) as mock_native_grouped_topk:
        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
            x)

tests/ut/ops/test_fused_ops.py

Lines changed: 26 additions & 0 deletions
@@ -25,6 +25,7 @@
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402

 adapt_patch(True)
@@ -389,3 +390,28 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
             assert result.shape == (16, 2)
         else:
             assert result.shape == x.shape
+
+
+class TestExpertsSelector:
+
+    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    def test_select_experts(self, mock_dist_env, mock_moe_env,
+                            global_num_experts):
+
+        x = torch.randn(8, 2)
+        router_logits = torch.randn(8, 2)
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=2,
+            use_grouped_topk=False,
+            renormalize=True,
+            topk_group=None,
+            num_expert_group=None,
+            custom_routing_function=None,
+            scoring_func="softmax",
+            e_score_correction_bias=None,
+            global_num_experts=global_num_experts)
+
+        assert topk_weights.shape == (8, 2)
+        assert topk_ids.shape == (8, 2)
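As a shape reference for these assertions: with use_grouped_topk=False and no custom routing function, the selector reduces to a softmax over the router logits followed by a top-k (fused into torch_npu.npu_moe_gating_top_k_softmax on NPU, per the implementation removed from fused_moe.py below). A pure-torch sketch of that path, for illustration only; softmax_topk_reference is a hypothetical stand-in, not part of this commit:

import torch

def softmax_topk_reference(router_logits: torch.Tensor, top_k: int,
                           renormalize: bool = True):
    # Pure-torch stand-in for the fused softmax + top-k gating kernel;
    # shapes match what TestExpertsSelector asserts above.
    scores = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

weights, ids = softmax_topk_reference(torch.randn(8, 2), top_k=2)
assert weights.shape == (8, 2) and ids.shape == (8, 2)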

tests/ut/quantization/test_w8a8.py

Lines changed: 13 additions & 12 deletions
@@ -5,12 +5,13 @@

 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
+                                                     select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
                                            fused_experts, fused_experts_310p,
-                                           native_grouped_topk,
-                                           quant_per_tensor, select_experts)
+                                           quant_per_tensor)


 class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ def test_grouped_topk(self, mock_topk):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)

-    @patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
+    @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ def test_basic_group_selection(self):

         with patch('torch.topk',
                    return_value=(None, expected_topk_indices)) as mock_topk:
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=2)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=2)

         mock_topk.assert_called_once()

@@ -885,9 +886,9 @@ def test_partial_group_selection(self):
         expected_topk_indices = torch.tensor([[0], [1]])

         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=1)

         expected_result = torch.tensor(
             [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ def test_single_group(self):
         expected_topk_indices = torch.tensor([[0], [0]])

         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=1,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=1,
+                                          topk_group=1)
         self.assertTrue(result.numel() > 0)
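For reference, the grouped top-k masking these tests exercise can be run standalone with the implementation that this commit removes from fused_moe.py (the relocated _native_grouped_topk is assumed to behave identically); the example data below is illustrative, not taken from the test file:

import torch

def _native_grouped_topk_reference(topk_weights, num_expert_group, topk_group):
    # Copied from the native_grouped_topk body deleted from
    # vllm_ascend/ops/fused_moe.py in this commit.
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group
    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group, dim=-1, sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    return topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

# 2 tokens, 8 experts split into 2 groups of 4; keep only the best group.
weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.1, 0.2, 0.1],
                        [0.1, 0.2, 0.1, 0.1, 0.7, 0.6, 0.1, 0.1]])
print(_native_grouped_topk_reference(weights, num_expert_group=2, topk_group=1))
# Token 0 keeps group 0 (max 0.9), token 1 keeps group 1 (max 0.7);
# weights outside the winning group are zeroed.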

vllm_ascend/ops/common_fused_moe.py

Lines changed: 4 additions & 5 deletions
@@ -24,8 +24,8 @@
                                UnquantizedFusedMoEMethod

 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
-                                       unified_fused_experts)
+from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import is_310p

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -59,7 +59,7 @@ def forward_oot(
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
-        global_num_experts: Optional[int] = None,
+        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
@@ -69,7 +69,6 @@ def forward_oot(
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
@@ -80,7 +79,7 @@ def forward_oot(
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
-    )
+        global_num_experts=global_num_experts)

    if topk_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None
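Note: the default for global_num_experts changes from None to the integer sentinel -1, presumably to line up with the integer default used on the surrounding vLLM forward path; the later `assert global_num_experts is not None` therefore still passes trivially, and the argument simply moves to the end of the select_experts call.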

vllm_ascend/ops/fused_moe.py

Lines changed: 14 additions & 167 deletions
@@ -48,6 +48,7 @@
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -922,143 +923,6 @@ def fused_experts(
     return final_hidden_states


-def native_grouped_topk(
-    topk_weights: torch.Tensor,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
-):
-    topk_group = 0 if topk_group is None else topk_group
-    num_expert_group = 0 if num_expert_group is None else num_expert_group
-
-    num_token = topk_weights.shape[0]
-    grouped_weights = topk_weights.view(num_token, num_expert_group,
-                                        -1).max(dim=-1).values
-    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
-                                    k=topk_group,
-                                    dim=-1,
-                                    sorted=False)[1]
-    topk_group_mask = torch.zeros_like(grouped_weights)
-    topk_group_mask.scatter_(1, topk_group_indices, 1)
-    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
-    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
-
-    return topk_weights
-
-
-def select_experts(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[torch.Tensor] = None
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        use_grouped_topk: Whether to group experts before selecting top-k.
-        renormalize: Whether to renormalize the routing weights.
-        topk_group: Number of expert groups to select from.
-        num_expert_group: Number of experts in each group.
-        custom_routing_function: Custom routing function.
-        scoring_func: Scoring function to use.
-        e_score_correction_bias: Correction bias to apply to expert scores.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-
-    def _renormalize_topk_weights(
-        topk_weights: torch.Tensor,
-        renormalize: bool,
-    ):
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-        return topk_weights
-
-    if scoring_func == "softmax":
-        # NOTE: vLLM use dtype=torch.float here
-        if not use_grouped_topk and custom_routing_function is None:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-                x=router_logits, finished=None, k=top_k)
-            topk_ids = topk_ids.to(torch.int32)
-            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-            return topk_weights, topk_ids
-
-        topk_weights = router_logits.softmax(dim=-1)
-    elif scoring_func == "sigmoid":
-        topk_weights = router_logits.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use biased
-            # scores for expert selection but original scores for routing weights
-            original_weights = topk_weights
-            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
-
-        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
-        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
-                                           topk_group)
-        # TODO bfloat16 is not supported in torch.topk with ge graph.
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                  k=top_k,
-                                  dim=-1,
-                                  sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_weights.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                                k=top_k,
-                                                dim=-1,
-                                                sorted=False)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-        return topk_weights, topk_ids
-
-    if custom_routing_function is not None:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            global_num_experts=global_num_experts)
-        # Required by npu_moe_init_routing
-        topk_ids = topk_ids.to(torch.int32)
-        return topk_weights, topk_ids
-
-    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-    topk_weights = topk_weights.to(hidden_states.dtype)
-
-    # Required by npu_moe_init_routing
-    topk_ids = topk_ids.to(torch.int32)
-    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-
-    return topk_weights, topk_ids
-
-
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

     def __init__(self, moe: FusedMoEConfig = None):
@@ -1113,36 +977,19 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:

-        is_deepseek_v3_r1 = global_num_experts == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False,  # todo new api; should the third output be output
-                # y2_flag=False,  # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts,
+            is_unquantized=True)

         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
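The relocated module itself (vllm_ascend/ops/layers/experts_selector.py) is not shown in this diff, but the call sites above suggest it absorbs both the generic select_experts body deleted here and the DeepSeek-V3/R1 npu_moe_gating_top_k branch deleted from apply(), with the new is_unquantized flag gating the latter. A rough sketch of that dispatch, inferred from the removed code; this is an assumption about the new module's contents, not its actual source:

import torch
import torch_npu  # requires an Ascend NPU environment


def select_experts(hidden_states, router_logits, top_k, use_grouped_topk,
                   renormalize, topk_group=None, num_expert_group=None,
                   custom_routing_function=None, scoring_func="softmax",
                   e_score_correction_bias=None, global_num_experts=-1,
                   is_unquantized=False):
    # Hypothetical consolidation: the 256-expert (DeepSeek-V3/R1) fast path
    # that apply() used to special-case, now gated behind is_unquantized.
    # Per the removed NOTE, the npu kernel only supports this pattern today.
    if is_unquantized and global_num_experts == 256:
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,
            bias=e_score_correction_bias,
            k_group=topk_group,
            group_count=num_expert_group,
            group_select_mode=1,   # 1: top-2 sum within each group
            renorm=0,              # softmax -> topk
            norm_type=1,           # sigmoid scoring
            routed_scaling_factor=1,
            eps=float(1e-20))
        return topk_weights, topk_ids
    # Otherwise: the generic path removed from fused_moe.py above
    # (softmax/sigmoid scoring, optional grouping via _native_grouped_topk,
    # custom routing function, int32 ids for npu_moe_init_routing).
    ...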

vllm_ascend/ops/layers/__init__.py

Whitespace-only changes.
