diff --git a/vllm_ascend/distributed/communicator.py b/vllm_ascend/distributed/communicator.py
index 7c14befa80..bc68d2bf7f 100644
--- a/vllm_ascend/distributed/communicator.py
+++ b/vllm_ascend/distributed/communicator.py
@@ -20,6 +20,7 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()
 
+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +80,16 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
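
For context, below is a minimal single-process sketch of how the new dispatch/combine hooks are typically exercised around an expert-parallel MoE forward pass. The _IdentityAll2All, _FakeCommunicator, and moe_forward names are hypothetical stand-ins used only to illustrate the contract shown in the diff (dispatch before local expert computation, combine after); they are not vLLM or vllm-ascend APIs, and the real NaiveAll2AllManager performs collective communication across the data-parallel group rather than identity passes.

import torch


class _IdentityAll2All:
    """Hypothetical stand-in for NaiveAll2AllManager so the call pattern
    can be run without NPUs or an initialized torch.distributed group."""

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # The real manager exchanges tokens and routing logits across ranks here.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The real manager returns expert outputs to the ranks that own the tokens.
        return hidden_states


class _FakeCommunicator:
    """Mirrors the dispatch/combine methods added to NPUCommunicator in the diff."""

    def __init__(self):
        self.all2all_manager = _IdentityAll2All()

    def dispatch(self, hidden_states, router_logits):
        assert self.all2all_manager is not None
        return self.all2all_manager.dispatch(hidden_states, router_logits)

    def combine(self, hidden_states):
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(hidden_states)


def moe_forward(comm, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # 1) Route tokens (and their logits) toward the ranks holding the experts.
    hidden_states, router_logits = comm.dispatch(hidden_states, router_logits)
    # 2) Hypothetical local "expert" computation: weight tokens by the top
    #    routing probability, standing in for a real expert MLP.
    top_prob = torch.softmax(router_logits, dim=-1).max(-1, keepdim=True).values
    expert_out = hidden_states * top_prob
    # 3) Bring each token's expert output back to its original rank.
    return comm.combine(expert_out)


if __name__ == "__main__":
    comm = _FakeCommunicator()
    tokens = torch.randn(4, 8)   # [num_tokens, hidden_size]
    logits = torch.randn(4, 2)   # [num_tokens, num_experts]
    out = moe_forward(comm, tokens, logits)
    print(out.shape)             # torch.Size([4, 8])

The sketch keeps shapes unchanged end to end, which matches the contract the MoE layer relies on: whatever dispatch scatters, combine must gather back so the caller sees per-token outputs on its own rank.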