Commit 627f20c

[BugFix]Fix group list type of mc2. (#3864)
### What this PR does / why we need it?

Fix the precision issue caused by the inconsistency between the group list type used by mc2 and that of eplb.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@83f478b

Signed-off-by: offline0806 <[email protected]>
1 parent 655a229 commit 627f20c
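The fix hinges on the two group-list formats that appear in this code: a cumulative list of token offsets (group_list_type == 0) versus raw per-expert token counts (group_list_type == 1). The sketch below is illustrative only (it is not code from this commit) and shows why mixing the two formats corrupts the EPLB load counters.

```python
import torch

tokens_per_expert = torch.tensor([3, 0, 5, 2])        # per-expert counts (group_list_type == 1)
cumulative = torch.cumsum(tokens_per_expert, dim=0)   # tensor([3, 3, 8, 10]) (group_list_type == 0)

# EPLB load accounting needs per-expert counts; recovering them from a
# cumulative list is exactly the torch.cat(...) expression in fused_moe.py:
recovered = torch.cat([cumulative[:1], cumulative[1:] - cumulative[:-1]])
assert torch.equal(recovered, tokens_per_expert)

# The bug: if mc2 hands back a cumulative list but the EPLB path treats it as
# per-expert counts, the cumulative values are added to moe_load directly and
# the load statistics used for expert rebalancing drift.
```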

3 files changed: +44 -40 lines changed


vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -266,7 +266,8 @@ def __init__(self, *args, **kwargs):
             self.expert_map != -1) if self.expert_map is not None else
             self.global_num_experts)
         if self.dynamic_eplb:
-            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
+            self.moe_load = torch.zeros(local_num_experts,
+                                        dtype=torch.int64).npu()

         self.moe_config.num_experts = self.global_num_experts
         self.moe_config.num_local_experts = self.local_num_experts
@@ -362,10 +363,9 @@ def forward_impl(self, hidden_states: torch.Tensor,

         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states
-
-        if self.dynamic_eplb:
-            self.moe_load += expert_tokens if group_list_type else \
-                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            if self.dynamic_eplb:
+                self.moe_load += expert_tokens if group_list_type == 1 else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

         final_hidden_states = forward_context.moe_comm_method.finalize(
             hidden_states=final_hidden_states,
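For reference, the accumulation the updated forward_impl performs can be read as the standalone helper below; update_moe_load is a hypothetical name used only for illustration, and expert_tokens is assumed to be the device tensor returned by the dispatcher.

```python
import torch

def update_moe_load(moe_load: torch.Tensor, expert_tokens: torch.Tensor,
                    group_list_type: int) -> torch.Tensor:
    """Accumulate per-expert token counts into moe_load in place."""
    if group_list_type == 1:
        counts = expert_tokens  # already per-expert counts
    else:
        # group_list_type == 0: expert_tokens is a cumulative list, so take
        # first-order differences to recover per-expert counts.
        counts = torch.cat([expert_tokens[:1],
                            expert_tokens[1:] - expert_tokens[:-1]])
    moe_load += counts
    return moe_load
```

Allocating moe_load with .npu() in __init__ keeps this accumulation on-device, matching where expert_tokens lives.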

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 2 additions & 1 deletion
@@ -130,7 +130,8 @@ def fused_experts(
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=use_int8_w8a8 or use_int4_w4a8)
+            with_quant=use_int8_w8a8 or use_int4_w4a8,
+            dynamic_eplb=dynamic_eplb)

         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 37 additions & 34 deletions
@@ -69,7 +69,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
@@ -156,21 +157,20 @@ def get_dispatch_mc2_kwargs(
         kwargs_mc2.update(stage1_kwargs)
         return kwargs_mc2

-    def token_dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        expert_map: Optional[torch.Tensor] = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
-        mc2_mask: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        with_quant: bool = False,
-    ):
+    def token_dispatch(self,
+                       hidden_states: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       expert_map: Optional[torch.Tensor] = None,
+                       log2phy: Optional[torch.Tensor] = None,
+                       global_redundant_expert_num: int = 0,
+                       shared_experts: Optional[Any] = None,
+                       quantized_x_for_share: Optional[Any] = None,
+                       dynamic_scale_for_share: Optional[Any] = None,
+                       mc2_mask: Optional[torch.Tensor] = None,
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant

         # Apply log2phy if needed
@@ -221,8 +221,10 @@ def token_dispatch(
             "expand_scales": expand_scales
         }

+        group_list_type = 1 if dynamic_eplb else 0
+
         return {
-            "group_list_type": 0,
+            "group_list_type": group_list_type,
             "hidden_states": expand_x,
             "group_list": expert_token_nums,
             "dynamic_scale": dynamic_scale,
@@ -336,7 +338,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant
         self.original_shape = hidden_states.shape

@@ -426,7 +429,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.bsz, _ = hidden_states.shape
         flatten_topk_ids = topk_ids.view(-1)
         self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -501,21 +505,20 @@ def __init__(self, **kwargs):
                 self.local_expert_indices[i + 1] -
                 1), "local_expert_indices must be continuous"

-    def token_dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        expert_map: Optional[torch.Tensor] = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
-        mc2_mask: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        with_quant: bool = False,
-    ):
+    def token_dispatch(self,
+                       hidden_states: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       expert_map: Optional[torch.Tensor] = None,
+                       log2phy: Optional[torch.Tensor] = None,
+                       global_redundant_expert_num: int = 0,
+                       shared_experts: Optional[Any] = None,
+                       quantized_x_for_share: Optional[Any] = None,
+                       dynamic_scale_for_share: Optional[Any] = None,
+                       mc2_mask: Optional[torch.Tensor] = None,
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
