 DS_COMM_REDUCE_OFF = False


+def disable_compiler_collective(func):
+    if required_torch_version(min_version=2.3):
+        return func
+    return compiler.disable(func)
+
+
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
     if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
         self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
         return None

     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
             return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
         else:
             return torch.ops.deepspeed.inference_all_reduce_(tensor)

-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -169,15 +175,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
                 utils.logger.warning("REDUCE is OFF")
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                             group=group,
                                             async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -208,15 +214,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
                                             input_tensor=input_tensor,
                                             group=group,
                                             async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
234240 "please consider upgrading your pytorch installation." )
235241 pass
236242
237- @compiler . disable
243+ @disable_compiler_collective
238244 def all_gather_coalesced (self , output_tensors , input_tensors , group = None , async_op = False ):
239245 """"""
240246 assert len (output_tensors ) == len (input_tensors ), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
         else:
             reqs[-1].wait()

-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
272278 "please consider upgrading your pytorch installation." )
273279 pass
274280
275- @compiler . disable
281+ @disable_compiler_collective
276282 def all_to_all_single (self ,
277283 output ,
278284 input ,
@@ -287,49 +293,49 @@ def all_to_all_single(self,
                                                    group=group,
                                                    async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
                                         dst=dst,
                                         group=group,
                                         async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
                                          src=src,
                                          group=group,
                                          async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

-    @compiler.disable
+    @disable_compiler_collective
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
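
The diff swaps the unconditional `@compiler.disable` on each collective for a version gate: on torch >= 2.3 the method is returned untouched, so the collective stays visible to `torch.compile`, while older torch versions keep the disable wrapper. Below is a minimal standalone sketch of that gating, assuming `required_torch_version` compares against `torch.__version__` and that `compiler.disable` corresponds to `torch.compiler.disable`; both helpers live elsewhere in DeepSpeed and are not part of this diff.

```python
# Sketch only: mirrors the decorator's behavior under the stated assumptions,
# not the exact DeepSpeed helpers.
from packaging.version import Version

import torch


def required_torch_version(min_version):
    # Compare the installed torch version, ignoring local build tags like "+cu121".
    return Version(torch.__version__.split("+")[0]) >= Version(str(min_version))


def disable_compiler_collective(func):
    # torch >= 2.3: leave the collective traceable by torch.compile.
    if required_torch_version(min_version=2.3):
        return func
    # Older torch: keep the old behavior and skip Dynamo tracing for this function.
    # (torch.compiler.disable is assumed here; DeepSpeed's compiler module may differ.)
    return torch.compiler.disable(func)


@disable_compiler_collective
def all_reduce(tensor, group=None, async_op=False):
    # Hypothetical free-function stand-in for TorchBackend.all_reduce.
    return torch.distributed.all_reduce(tensor=tensor, group=group, async_op=async_op)
```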