Skip to content
4 changes: 2 additions & 2 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_tp_group
+from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.platforms import current_platform

Expand Down Expand Up @@ -63,7 +63,7 @@ def set_ascend_forward_context(
):
forward_context = get_forward_context()
forward_context.with_prefill = with_prefill
-ep_size = (torch.distributed.get_world_size() if
+ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)

fused_moe_state = get_fused_moe_state(ep_size, with_prefill)
Expand Down