Skip to content

[MoE][Dist] Fix Qwen MoE accuracy bug in DP senario #1856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions vllm_ascend/distributed/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.utils import logger


class NPUCommunicator(DeviceCommunicatorBase):
Expand All @@ -34,6 +35,12 @@ def __init__(self,
# init device according to rank
self.device = torch.npu.current_device()

if self.use_all2all:
from vllm.distributed.device_communicators.all2all import \
NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")

def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
Expand Down Expand Up @@ -73,3 +80,16 @@ def all_to_all(self,
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states
Loading