
Commit e4626e4

[MoE][Dist] Fix Qwen MoE accuracy bug in DP scenario
Co-authored-by: Yan Zhang <[email protected]>
Signed-off-by: Yan Zhang <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
1 parent 19e37cd commit e4626e4

File tree (1 file changed, +55 −0)

vllm_ascend/distributed/communicator.py

Lines changed: 55 additions & 0 deletions
@@ -20,6 +20,7 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.forward_context import get_forward_context


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,20 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()

+        # Adapted from vllm/distributed/device_communicators/base_device_communicator.py
+        if self.use_all2all:
+            # compute some common properties
+            from vllm.distributed.parallel_state import (get_dp_group,
+                                                         get_tp_group)
+
+            # all2all lives in ep group, which is merged from dp and tp group
+            self.dp_group = get_dp_group()
+            self.tp_group = get_tp_group()
+            # no self.ep_group since self.ep_group is still in construction
+            # when we create this object
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
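The methods added in the next hunk all index into a flat, cross-DP buffer via cu_tokens_across_dp_cpu, the cumulative per-rank token counts carried in the forward context: DP rank i owns rows [cu[i-1], cu[i]). Below is a minimal single-process sketch of that indexing with made-up token counts; tokens_per_rank and row_range are illustrative names, not part of this patch.

import torch

# Hypothetical token counts for a 4-way data-parallel setup.
tokens_per_rank = torch.tensor([3, 5, 2, 4])

# Analogous to forward_context.dp_metadata.cu_tokens_across_dp_cpu.
cu_tokens_across_dp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # tensor([3, 8, 10, 14])

def row_range(dp_rank: int):
    # Rows of the flat cross-DP buffer owned by `dp_rank`.
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    return start, end

print([row_range(r) for r in range(4)])  # [(0, 3), (3, 8), (8, 10), (10, 14)]

naive_multicast fills its own slice of that buffer and broadcasts each rank's slice in turn, so after dispatch() every DP rank holds the same [cu[-1], hidden_size] buffer.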
@@ -73,3 +88,43 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
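For intuition about the data-parallel round trip: dispatch() gives every DP rank the hidden states and router logits of all ranks before expert routing, and combine() hands each rank back only its own rows afterwards. The following is a minimal single-process sketch of that round trip with the collectives replaced by plain tensor ops: concatenation stands in for naive_multicast's broadcasts, and the all_reduce summation of partial expert outputs is omitted. fake_moe and local_states are illustrative names only, not part of the patch.

import torch

torch.manual_seed(0)
hidden = 16
tokens_per_rank = [3, 5, 2, 4]
cu = torch.cumsum(torch.tensor(tokens_per_rank), dim=0)

# Each DP rank starts with only its own tokens' hidden states.
local_states = [torch.randn(n, hidden) for n in tokens_per_rank]

def fake_moe(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the routed expert computation; shape-preserving.
    return x * 2.0

# dispatch(): after naive_multicast, every rank holds the same flat buffer
# containing all ranks' tokens, laid out by the cumulative offsets.
global_buffer = torch.cat(local_states, dim=0)   # shape [14, hidden]
global_out = fake_moe(global_buffer)

# combine(): (all_reduce, then) slice back this rank's rows [cu[r-1], cu[r]).
for rank, local in enumerate(local_states):
    start = 0 if rank == 0 else int(cu[rank - 1])
    end = int(cu[rank])
    assert torch.allclose(global_out[start:end], fake_moe(local))
print("each DP rank recovers exactly its own tokens' outputs")

The slice at the end mirrors combine(): without it, each rank would return rows that belong to other DP ranks.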
