@@ -611,6 +611,112 @@ def reduce_scatter_tensor_coalesced(
         )
 
 
+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._base = base
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def getBackendName(self) -> str:
+        return f"{self._base.getBackendName()}-parallel"
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
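
For readers skimming the diff, a small self-contained sketch (not part of the commit) of the mechanism `_split_tensors` relies on: `torch.tensor_split` on a flattened view returns chunks that share storage with the original tensor, so an in-place collective on each shard also updates the full tensor. The `count` value and the in-place `add_` below are stand-ins for the per-shard process groups and their collectives.

```python
# Illustrative sketch only (not part of the commit).
import torch

count = 4  # stand-in for ParallelProcessGroup's `count`
t = torch.arange(10, dtype=torch.float32)

# tensor_split returns views of the flattened tensor
chunks = torch.tensor_split(t.view(-1), count, dim=0)
for chunk in chunks:
    chunk.add_(100.0)  # stand-in for an in-place allreduce on one sub-group

print(t)  # the original tensor reflects every per-shard update
```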
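Similarly, a hedged sketch of what `_ParallelWork.get_future` hands back: `torch.futures.collect_all` wraps the per-shard futures in a single future that completes once every constituent future has completed.

```python
# Illustrative sketch only. collect_all returns one future over a list of
# futures; waiting on it yields the completed sub-futures.
import torch

f1: torch.futures.Future[int] = torch.futures.Future()
f2: torch.futures.Future[int] = torch.futures.Future()
combined = torch.futures.collect_all([f1, f2])

f1.set_result(1)
f2.set_result(2)

print([f.value() for f in combined.wait()])  # [1, 2]
```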