
[0.9.1][Bugfix]Support Qwen3-MOE on aclgraph mode in no dp case #1940

Merged: 1 commit, Jul 24, 2025
16 changes: 15 additions & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -18,11 +18,21 @@
from typing import Callable, Optional

import torch
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod

from vllm_ascend.ops.fused_moe import fused_experts, select_experts

original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__


def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
vllm_config = get_current_vllm_config()
self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager


def forward_oot(
self,
@@ -55,6 +65,8 @@ def forward_oot(
e_score_correction_bias=e_score_correction_bias,
)

max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -63,7 +75,9 @@ def forward_oot(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
max_num_tokens=max_num_tokens)


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.forward_oot = forward_oot
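
Note: the hunk above monkey-patches UnquantizedFusedMoEMethod.__init__ so every instance records the scheduler's max_num_batched_tokens and whether aclgraph is in play (piecewise compilation without enforce_eager); forward_oot then forwards that cap into fused_experts only when aclgraph is active. Below is a minimal standalone sketch of that init-wrapping pattern; the Stub* classes are hypothetical stand-ins for vLLM's config objects and for get_current_vllm_config(), used only so the sketch runs outside vLLM.

from dataclasses import dataclass


# Hypothetical stand-ins for vLLM's config objects -- illustrative only,
# not vLLM's real classes.
@dataclass
class StubSchedulerConfig:
    max_num_batched_tokens: int = 8192


@dataclass
class StubCompilationConfig:
    piecewise: bool = True  # stands in for `level == CompilationLevel.PIECEWISE`


@dataclass
class StubModelConfig:
    enforce_eager: bool = False


@dataclass
class StubVllmConfig:
    scheduler_config: StubSchedulerConfig
    compilation_config: StubCompilationConfig
    model_config: StubModelConfig


class StubFusedMoEMethod:
    """Stand-in for UnquantizedFusedMoEMethod."""

    def __init__(self):
        self.initialized = True


original_init = StubFusedMoEMethod.__init__


def patched_init(self, *args, vllm_config: StubVllmConfig, **kwargs):
    # Run the original constructor first, then attach the extra state the
    # Ascend forward path reads later (token cap + aclgraph flag).
    original_init(self, *args, **kwargs)
    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
    self.use_aclgraph = (vllm_config.compilation_config.piecewise
                         and not vllm_config.model_config.enforce_eager)


StubFusedMoEMethod.__init__ = patched_init

if __name__ == "__main__":
    cfg = StubVllmConfig(StubSchedulerConfig(), StubCompilationConfig(),
                         StubModelConfig())
    method = StubFusedMoEMethod(vllm_config=cfg)
    print(method.max_num_batched_tokens, method.use_aclgraph)  # 8192 True
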
60 changes: 34 additions & 26 deletions vllm_ascend/ops/fused_moe.py
@@ -354,6 +354,7 @@ def fused_experts_with_all2all(
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
max_num_tokens: Optional[int] = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
@@ -372,12 +373,13 @@
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens,
active_num=active_num,
))

global_expert_tokens = torch.bincount(expanded_expert_idx,
@@ -413,12 +415,13 @@
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens,
active_num=active_num,
))

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
@@ -497,6 +500,7 @@ def fused_experts_with_all2all_buffer(
global_batch_size: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
max_num_tokens: Optional[int] = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
@@ -511,11 +515,12 @@
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens))
active_num=active_num))

max_row_per_ep_rank = (
(-(-global_batch_size // ep_group.world_size) * max_model_len *
@@ -650,6 +655,7 @@ def fused_experts(
top_k: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
max_num_tokens: Optional[int] = None,
) -> torch.Tensor:
"""
Fused experts with top-k routing.
@@ -743,12 +749,13 @@ def fused_experts(
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens,
active_num=active_num,
))

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
@@ -944,7 +951,7 @@ def __init__(self, moe: MoEConfig = None):

self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len

self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

@@ -1023,7 +1030,7 @@ def apply(
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
if enable_force_load_balance and self.use_aclgraph is None:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

fused_moe_state = get_forward_context().fused_moe_state
@@ -1043,16 +1050,17 @@
mc2_mask=mc2_mask,
)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
)
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
max_num_tokens=max_num_tokens)
elif VLLM_ASCEND_MOE_ALL2ALL_BUFFER:
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
return fused_experts_with_all2all_buffer(
hidden_states=x,
w1=layer.w13_weight,
@@ -1064,7 +1072,7 @@
global_batch_size=self.global_batch_size,
expert_map=expert_map,
ep_group=get_ep_group(),
)
max_num_tokens=max_num_tokens)
elif fused_moe_state == FusedMoEState.All2AllSeq:
token_dispatcher = kwargs.get("token_dispatcher")
return fused_experts_with_all2allv(
@@ -1076,16 +1084,16 @@
w2=layer.w2_weight,
)
else:
return fused_experts_with_all2all(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=get_ep_group(),
)
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=get_ep_group(),
max_num_tokens=max_num_tokens)


class AscendFusedMoE(FusedMoE):
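
Note: every call site above applies the same fallback, active_num = max_num_tokens if max_num_tokens is not None else num_tokens, so npu_moe_init_routing is sized by the static scheduler cap while an aclgraph is being captured and by the live per-step token count otherwise. A minimal sketch of that two-step selection follows; the helper names are illustrative and do not appear in the PR.

from typing import Optional


def caller_side_cap(use_aclgraph: bool,
                    max_num_batched_tokens: int) -> Optional[int]:
    # Mirrors the caller-side choice in apply()/forward_oot(): only pass a cap
    # when aclgraph is active, so eager mode keeps the previous behaviour.
    return max_num_batched_tokens if use_aclgraph else None


def pick_active_num(num_tokens: int, max_num_tokens: Optional[int]) -> int:
    # Mirrors the callee-side fallback inside the fused_experts* functions:
    # a fixed cap when one is supplied, the live token count otherwise.
    return max_num_tokens if max_num_tokens is not None else num_tokens


if __name__ == "__main__":
    # With aclgraph, the routing op is sized for the worst case the captured
    # graph must handle, independent of this step's real batch size.
    assert pick_active_num(17, caller_side_cap(True, 8192)) == 8192
    # Without aclgraph, the behaviour matches the pre-PR code path.
    assert pick_active_num(17, caller_side_cap(False, 8192)) == 17
    print("ok")
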