import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
    DeviceCommunicatorBase
+from vllm.forward_context import get_forward_context


class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,20 @@ def __init__(self,
        # init device according to rank
        self.device = torch.npu.current_device()

+        # Adapted from vllm/distributed/device_communicators/base_device_communicator.py
+        if self.use_all2all:
+            # compute some common properties
+            from vllm.distributed.parallel_state import (get_dp_group,
+                                                         get_tp_group)
+
+            # all2all lives in ep group, which is merged from dp and tp group
+            self.dp_group = get_dp_group()
+            self.tp_group = get_tp_group()
+            # no self.ep_group since self.ep_group is still in construction
+            # when we create this object
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size
+
    def all_to_all(self,
                   input_: torch.Tensor,
                   scatter_dim: int = 0,
@@ -73,3 +88,43 @@ def all_to_all(self,
        dist.all_to_all(output_list, input_list, group=self.device_group)
        output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
        return output_tensor
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
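
Taken together, dispatch() and combine() are the two halves of the data-parallel MoE path this communicator now supports: dispatch() replicates every DP rank's tokens and router logits so the expert computation sees the full cross-DP batch, and combine() all-reduces the expert outputs and hands back only this rank's rows. The sketch below only illustrates that intended pairing; moe_forward, experts, and comm are made-up names for this note, not code from the patch or from vLLM.

import torch

def moe_forward(comm, hidden_states: torch.Tensor,
                router_logits: torch.Tensor, experts) -> torch.Tensor:
    # Hypothetical call pattern around an MoE block, assuming `comm` is an
    # NPUCommunicator constructed with use_all2all enabled.
    # 1) Replicate this rank's tokens to every DP rank.
    hidden_states, router_logits = comm.dispatch(hidden_states, router_logits)
    # 2) Run the expert computation on the replicated batch.
    hidden_states = experts(hidden_states, router_logits)
    # 3) All-reduce across DP ranks and keep only this rank's slice again.
    return comm.combine(hidden_states)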
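
The index bookkeeping the new methods share: cu_tokens_across_dp_cpu (read from get_forward_context().dp_metadata in the patch) is a CPU tensor of cumulative token counts across DP ranks, so rank i owns rows [cu_tokens[i-1], cu_tokens[i]) of the concatenated buffer that naive_multicast() fills and combine() slices after the all-reduce. A standalone illustration with made-up token counts follows; it runs without any process group, so the broadcast/all-reduce themselves are replaced by a local stand-in.

import torch

# Three hypothetical DP ranks contributing 2, 3 and 1 tokens respectively.
tokens_per_rank = torch.tensor([2, 3, 1])
cu_tokens_across_dp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # tensor([2, 5, 6])

hidden_size = 4
# The multicast buffer concatenates every rank's tokens, so its first
# dimension is the total token count across DP ranks (here 6).
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden_size))

for dp_rank in range(len(tokens_per_rank)):
    # Same slicing as naive_multicast()/combine(): rank i owns the rows
    # [cu_tokens[i - 1], cu_tokens[i]) of the concatenated buffer.
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    # Local stand-in for self.dp_group.broadcast(buffer[start:end, :], dp_rank).
    buffer[start:end, :] = dp_rank
    print(f"dp_rank {dp_rank} owns buffer[{start}:{end}]")
# dp_rank 0 owns buffer[0:2]
# dp_rank 1 owns buffer[2:5]
# dp_rank 2 owns buffer[5:6]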