diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py
index 7f74c53fac..ec91e52268 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/moe/token_dispatcher.py
@@ -689,33 +689,16 @@ def _dispatch_postprocess(self, global_input_tokens,
         if self.with_quant:
             assert global_input_tokens_local_experts_indices is not None, \
                 "global_input_tokens_local_experts_indices must be provided"
-            expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
+            dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
+                dynamic_scale_after_all2all.unsqueeze(-1),
+                global_input_tokens_local_experts_indices)
+            dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(
                 -1)
-            active_num = global_input_tokens_local_experts_indices.numel()
-
-            if active_num <= 0:
-                reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
-                return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
-
-            global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
-                global_input_tokens,
-                expert_idx_2d,
-                scale=dynamic_scale_after_all2all,
-                active_num=active_num,
-                expert_capacity=0,
-                expert_num=self.num_local_experts,
-                expert_tokens_num_type=1,
-                expert_tokens_num_flag=True,
-                active_expert_range=[0, self.num_local_experts],
-                quant_mode=-1,
-                row_idx_type=0,
-            )
-            return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
 
         # Non-quantized case
         global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
             global_input_tokens, global_input_tokens_local_experts_indices)
-        return global_input_tokens, None, reversed_global_input_permutation_mapping
+        return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
 
     def _combine_preprocess(self, hidden_states: torch.Tensor,
                             context_metadata: dict) -> torch.Tensor:
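
This hunk drops the `npu_moe_init_routing_v2` call on the quantized path and instead reorders the per-token dynamic scales with the same `npu_moe_token_permute` op already used for the tokens, so tokens and scales stay aligned through one shared permutation and both branches fall through to a single return. Below is a minimal CPU sketch of the permute semantics the patch relies on; `token_permute` is a hypothetical stand-in for `torch_npu.npu_moe_token_permute` (whose exact return contract may differ), and the shapes are illustrative assumptions, not taken from the source.

```python
# Hypothetical plain-PyTorch emulation of the token-permute semantics;
# not the torch_npu kernel itself.
import torch


def token_permute(tokens: torch.Tensor, expert_indices: torch.Tensor):
    """Reorder rows of `tokens` so tokens routed to the same local expert
    become contiguous; also return the permutation for later un-permute."""
    # Stable sort preserves the original intra-expert token order.
    order = torch.argsort(expert_indices.flatten(), stable=True)
    return tokens[order], order


tokens = torch.randn(4, 8)                   # [num_tokens, hidden] (assumed)
scales = torch.rand(4)                       # per-token dynamic quant scales
expert_indices = torch.tensor([2, 0, 1, 0])  # local expert id per token

permuted_tokens, mapping = token_permute(tokens, expert_indices)

# As in the patch: give the 1-D scales a trailing feature dim so the same
# row permutation applies, then squeeze it back off.
permuted_scales, _ = token_permute(scales.unsqueeze(-1), expert_indices)
permuted_scales = permuted_scales.squeeze(-1)

# Each scale still sits next to its token after the reorder.
assert torch.equal(permuted_tokens, tokens[mapping])
assert torch.equal(permuted_scales, scales[mapping])
```

Reusing one permutation primitive for both tensors avoids maintaining a second code path through the routing kernel and removes the `active_num <= 0` special case, since an empty index tensor simply yields an empty permutation.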