@@ -278,7 +278,14 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278        raise  NotImplementedError ("not implemented" )
279279
280280    def  configure (
281-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
281+         self ,
282+         store_addr : str ,
283+         replica_id : str ,
284+         rank : int ,
285+         world_size : int ,
286+         quorum_id : Optional [int ] =  None ,
287+         group_rank : int  =  0 ,
288+         group_world_size : int  =  1 ,
282289    ) ->  None :
283290        """ 
284291        This reconfigures the ProcessGroup to use a new store, rank and world size. 
@@ -408,6 +415,9 @@ def __init__(
408415        self ._timeout  =  timeout 
409416        self ._replica_id : str  |  None  =  None 
410417        self ._rank : int  |  None  =  None 
418+         self ._quorum_id : int  |  None  =  None 
419+         self ._group_rank : int  =  0 
420+         self ._group_world_size : int  =  1 
411421
412422        self .errors_logger : logging .Logger  =  logging .getLogger ("torchft_errors" )
413423
@@ -419,13 +429,31 @@ def getBackendName(self) -> str:
419429        raise  NotImplementedError ("not implemented" )
420430
421431    def  configure (
422-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
432+         self ,
433+         store_addr : str ,
434+         replica_id : str ,
435+         rank : int ,
436+         world_size : int ,
437+         quorum_id : Optional [int ] =  None ,
438+         group_rank : int  =  0 ,
439+         group_world_size : int  =  1 ,
423440    ) ->  None :
424441        pg  =  self ._pg 
425442        self ._replica_id  =  replica_id 
443+         self ._quorum_id  =  quorum_id 
444+         self ._group_rank  =  group_rank 
445+         self ._group_world_size  =  group_world_size 
426446        self ._rank  =  rank 
427447        if  isinstance (pg , ProcessGroup ):
428-             pg .configure (store_addr , replica_id , rank , world_size )
448+             pg .configure (
449+                 store_addr ,
450+                 replica_id ,
451+                 rank ,
452+                 world_size ,
453+                 quorum_id ,
454+                 group_rank ,
455+                 group_world_size ,
456+             )
429457            return 
430458
431459        # abort if already initialized 
@@ -443,6 +471,7 @@ def abort(self, errored: bool = True) -> None:
443471                    "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444472                    "replica_id" : self ._replica_id ,
445473                    "rank" : self ._rank ,
474+                     "quorum_id" : self ._quorum_id ,
446475                    "error" : "process_group_abort" ,
447476                },
448477            )
@@ -615,6 +644,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615644        # pyre-fixme[16]: no attribute ProcessGroupGloo 
616645        backend_class  =  BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617646        backend_class ._set_sequence_number_for_group ()
647+         backend_class .options .global_ranks_in_group  =  list (range (world_size ))
648+         backend_class .options .group_name  =  f"torchft_quorum_{ self ._quorum_id } { self ._group_rank  %  self ._group_world_size }  
618649        pg ._register_backend (
619650            torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class 
620651        )
@@ -813,6 +844,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813844        opts  =  BaseProcessGroupNCCL .Options ()
814845        opts .config .blocking  =  False 
815846        opts .global_ranks_in_group  =  list (range (world_size ))
847+         opts .group_name  =  f"torchft_quorum_{ self ._quorum_id } { self ._group_rank  %  self ._group_world_size }  
816848
817849        pg  =  BaseProcessGroup (store , rank , world_size )
818850        pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +1011,12 @@ def __init__(self, rank: int, world: int) -> None:
9791011        self .configure_count  =  0 
9801012
9811013    def  configure (
982-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1014+         self ,
1015+         store_addr : str ,
1016+         replica_id : str ,
1017+         rank : int ,
1018+         world_size : int ,
1019+         quorum_id : Optional [int ] =  None ,
9831020    ) ->  None :
9841021        self .configure_count  +=  1 
9851022
@@ -1138,11 +1175,26 @@ def __init__(self, pg: ProcessGroup) -> None:
11381175        self ._error : Optional [Exception ] =  None 
11391176
11401177    def  configure (
1141-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1178+         self ,
1179+         store_addr : str ,
1180+         replica_id : str ,
1181+         rank : int ,
1182+         world_size : int ,
1183+         quorum_id : Optional [int ] =  None ,
1184+         group_rank : int  =  0 ,
1185+         group_world_size : int  =  1 ,
11421186    ) ->  None :
11431187        self ._error  =  None 
11441188
1145-         super ().configure (store_addr , replica_id , rank , world_size )
1189+         super ().configure (
1190+             store_addr ,
1191+             replica_id ,
1192+             rank ,
1193+             world_size ,
1194+             quorum_id ,
1195+             group_rank ,
1196+             group_world_size ,
1197+         )
11461198
11471199    def  report_error (self , e : Exception ) ->  None :
11481200        """ 
@@ -1194,11 +1246,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941246        self ._future_error : Optional [Exception ] =  None 
11951247
11961248    def  configure (
1197-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1249+         self ,
1250+         store_addr : str ,
1251+         replica_id : str ,
1252+         rank : int ,
1253+         world_size : int ,
1254+         quorum_id : Optional [int ] =  None ,
11981255    ) ->  None :
11991256        self ._future_error  =  None 
12001257
1201-         super ().configure (store_addr , replica_id , rank , world_size )
1258+         super ().configure (store_addr , replica_id , rank , world_size ,  quorum_id )
12021259
12031260    def  report_future_error (self , e : Exception ) ->  None :
12041261        """ 
@@ -1412,7 +1469,12 @@ def shutdown(self) -> None:
14121469            self ._p .kill ()
14131470
14141471    def  configure (
1415-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1472+         self ,
1473+         store_addr : str ,
1474+         replica_id : str ,
1475+         rank : int ,
1476+         world_size : int ,
1477+         quorum_id : Optional [int ] =  None ,
14161478    ) ->  None :
14171479        self ._world_size  =  world_size 
14181480
0 commit comments