diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 43c9517844..135a191aeb 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -18,11 +18,21 @@
 from typing import Callable, Optional
 
 import torch
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
 from vllm_ascend.ops.fused_moe import fused_experts, select_experts
 
+original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
+
+
+def unquantized_fused_moe_init_func(self, *args, **kwargs):
+    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
+    vllm_config = get_current_vllm_config()
+    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+
 
 def forward_oot(
     self,
@@ -55,6 +65,8 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )
 
+    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+
     return fused_experts(
         hidden_states=x,
         w1=layer.w13_weight,
@@ -63,7 +75,9 @@ def forward_oot(
         topk_ids=topk_ids,
         top_k=top_k,
         expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input)
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        max_num_tokens=max_num_tokens)
 
 
+UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 37edb9767a..c5fff4bf21 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -354,6 +354,7 @@ def fused_experts_with_all2all(
     top_k: int,
     expert_map: torch.Tensor = None,
     ep_group: GroupCoordinator = None,
+    max_num_tokens: Optional[int] = None,
 ):
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
@@ -372,12 +373,13 @@ def fused_experts_with_all2all(
                                 dtype=torch.int32,
                                 device=device).view(top_k, -1).permute(
                                     1, 0).contiguous())
+        active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         hidden_states, expanded_row_idx, expanded_expert_idx = (
             torch_npu.npu_moe_init_routing(
                 hidden_states,
                 row_idx=row_idx,
                 expert_idx=topk_ids,
-                active_num=num_tokens,
+                active_num=active_num,
             ))
 
         global_expert_tokens = torch.bincount(expanded_expert_idx,
@@ -413,12 +415,13 @@ def fused_experts_with_all2all(
                                 dtype=torch.int32,
                                 device=topk_weights.device).view(
                                     top_k, -1).permute(1, 0).contiguous())
+        active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         hidden_states, expanded_row_idx, expanded_expert_idx = (
             torch_npu.npu_moe_init_routing(
                 hidden_states,
                 row_idx=row_idx,
                 expert_idx=topk_ids,
-                active_num=num_tokens,
+                active_num=active_num,
             ))
 
         expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
@@ -497,6 +500,7 @@ def fused_experts_with_all2all_buffer(
     global_batch_size: int,
     expert_map: torch.Tensor = None,
     ep_group: GroupCoordinator = None,
+    max_num_tokens: Optional[int] = None,
 ):
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
@@ -511,11 +515,12 @@ def fused_experts_with_all2all_buffer(
     row_idx = (torch.arange(0,
                             row_idx_len,
                             dtype=torch.int32,
                             device=device).view(top_k, -1).permute(1, 0).contiguous())
+    active_num = max_num_tokens if max_num_tokens is not None else num_tokens
     hidden_states, expanded_row_idx, expanded_expert_idx = (
         torch_npu.npu_moe_init_routing(hidden_states,
                                        row_idx=row_idx,
                                        expert_idx=topk_ids,
-                                       active_num=num_tokens))
+                                       active_num=active_num))
 
     max_row_per_ep_rank = (
         (-(-global_batch_size // ep_group.world_size) * max_model_len *
@@ -650,6 +655,7 @@ def fused_experts(
     top_k: int,
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,
+    max_num_tokens: Optional[int] = None,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -743,12 +749,13 @@ def fused_experts(
                                 dtype=torch.int32,
                                 device=device).view(top_k, -1).permute(
                                     1, 0).contiguous())
+        active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = (
             torch_npu.npu_moe_init_routing(
                 hidden_states,
                 row_idx=row_idx,
                 expert_idx=topk_ids,
-                active_num=num_tokens,
+                active_num=active_num,
             ))
 
         expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
@@ -944,7 +951,7 @@ def __init__(self, moe: MoEConfig = None):
 
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.max_model_len = vllm_config.model_config.max_model_len
-
+        self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
         ascend_config = get_ascend_config()
 
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -1023,7 +1030,7 @@ def apply(
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
+        if enable_force_load_balance and self.use_aclgraph is None:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         fused_moe_state = get_forward_context().fused_moe_state
@@ -1043,16 +1050,17 @@ def apply(
                 mc2_mask=mc2_mask,
             )
         elif fused_moe_state == FusedMoEState.AllGather:
-            return fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-            )
+            max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+            return fused_experts(hidden_states=x,
+                                 w1=layer.w13_weight,
+                                 w2=layer.w2_weight,
+                                 topk_weights=topk_weights,
+                                 topk_ids=topk_ids,
+                                 top_k=top_k,
+                                 expert_map=expert_map,
+                                 max_num_tokens=max_num_tokens)
         elif VLLM_ASCEND_MOE_ALL2ALL_BUFFER:
+            max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
             return fused_experts_with_all2all_buffer(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -1064,7 +1072,7 @@ def apply(
                 global_batch_size=self.global_batch_size,
                 expert_map=expert_map,
                 ep_group=get_ep_group(),
-            )
+                max_num_tokens=max_num_tokens)
         elif fused_moe_state == FusedMoEState.All2AllSeq:
             token_dispatcher = kwargs.get("token_dispatcher")
             return fused_experts_with_all2allv(
@@ -1076,16 +1084,16 @@ def apply(
                 w2=layer.w2_weight,
             )
         else:
-            return fused_experts_with_all2all(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                ep_group=get_ep_group(),
-            )
+            max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+            return fused_experts_with_all2all(hidden_states=x,
+                                              w1=layer.w13_weight,
+                                              w2=layer.w2_weight,
+                                              topk_weights=topk_weights,
+                                              topk_ids=topk_ids,
+                                              top_k=top_k,
+                                              expert_map=expert_map,
+                                              ep_group=get_ep_group(),
+                                              max_num_tokens=max_num_tokens)
 
 
 class AscendFusedMoE(FusedMoE):
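The sketch below is illustrative only and is not part of the patch. It restates the `active_num` selection that the diff threads through every fused-experts path: when ACL graph capture is in use (piecewise compilation without `enforce_eager`), the routing buffers handed to `npu_moe_init_routing` are sized to the scheduler's `max_num_batched_tokens` so the captured graph keeps a fixed shape, while eager mode keeps using the dynamic per-batch token count.

```python
# Illustrative sketch, not part of the patch: mirrors the repeated
# "active_num = max_num_tokens if max_num_tokens is not None else num_tokens"
# lines added in the diff.
from typing import Optional


def resolve_active_num(num_tokens: int, max_num_tokens: Optional[int]) -> int:
    """Pick the token capacity used for expert routing.

    max_num_tokens is only passed when ACL graph capture is enabled, in which
    case the capacity is pinned to the worst case; otherwise the current
    batch size is used.
    """
    return max_num_tokens if max_num_tokens is not None else num_tokens


# Eager mode: capacity follows the current batch.
assert resolve_active_num(57, None) == 57
# ACL graph mode: capacity is pinned to max_num_batched_tokens (e.g. 8192).
assert resolve_active_num(57, 8192) == 8192
```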