@@ -278,7 +278,13 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278        raise  NotImplementedError ("not implemented" )
279279
280280    def  configure (
281-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
281+         self ,
282+         store_addr : str ,
283+         replica_id : str ,
284+         rank : int ,
285+         world_size : int ,
286+         quorum_id : Optional [int ] =  None ,
287+         group_world_size : int  =  1 ,
282288    ) ->  None :
283289        """ 
284290        This reconfigures the ProcessGroup to use a new store, rank and world size. 
@@ -408,6 +414,8 @@ def __init__(
408414        self ._timeout  =  timeout 
409415        self ._replica_id : str  |  None  =  None 
410416        self ._rank : int  |  None  =  None 
417+         self ._quorum_id : int  |  None  =  None 
418+         self ._group_world_size : int  =  1 
411419
412420        self .errors_logger : logging .Logger  =  logging .getLogger ("torchft_errors" )
413421
@@ -419,13 +427,23 @@ def getBackendName(self) -> str:
419427        raise  NotImplementedError ("not implemented" )
420428
421429    def  configure (
422-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
430+         self ,
431+         store_addr : str ,
432+         replica_id : str ,
433+         rank : int ,
434+         world_size : int ,
435+         quorum_id : Optional [int ] =  None ,
436+         group_world_size : int  =  1 ,
423437    ) ->  None :
424438        pg  =  self ._pg 
425439        self ._replica_id  =  replica_id 
440+         self ._quorum_id  =  quorum_id 
441+         self ._group_world_size  =  group_world_size 
426442        self ._rank  =  rank 
427443        if  isinstance (pg , ProcessGroup ):
428-             pg .configure (store_addr , replica_id , rank , world_size )
444+             pg .configure (
445+                 store_addr , replica_id , rank , world_size , quorum_id , group_world_size 
446+             )
429447            return 
430448
431449        # abort if already initialized 
@@ -443,6 +461,7 @@ def abort(self, errored: bool = True) -> None:
443461                    "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444462                    "replica_id" : self ._replica_id ,
445463                    "rank" : self ._rank ,
464+                     "quorum_id" : self ._quorum_id ,
446465                    "error" : "process_group_abort" ,
447466                },
448467            )
@@ -615,6 +634,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615634        # pyre-fixme[16]: no attribute ProcessGroupGloo 
616635        backend_class  =  BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617636        backend_class ._set_sequence_number_for_group ()
637+         backend_class .options .global_ranks_in_group  =  list (range (world_size ))
638+         backend_class .options .group_name  =  (
639+             f"torchft_quorum_{ self ._quorum_id } { rank  %  self ._group_world_size }  
640+         )
618641        pg ._register_backend (
619642            torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class 
620643        )
@@ -813,6 +836,9 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813836        opts  =  BaseProcessGroupNCCL .Options ()
814837        opts .config .blocking  =  False 
815838        opts .global_ranks_in_group  =  list (range (world_size ))
839+         opts .group_name  =  (
840+             f"torchft_quorum_{ self ._quorum_id } { rank  %  self ._group_world_size }  
841+         )
816842
817843        pg  =  BaseProcessGroup (store , rank , world_size )
818844        pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +1005,12 @@ def __init__(self, rank: int, world: int) -> None:
9791005        self .configure_count  =  0 
9801006
9811007    def  configure (
982-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1008+         self ,
1009+         store_addr : str ,
1010+         replica_id : str ,
1011+         rank : int ,
1012+         world_size : int ,
1013+         quorum_id : Optional [int ] =  None ,
9831014    ) ->  None :
9841015        self .configure_count  +=  1 
9851016
@@ -1138,11 +1169,19 @@ def __init__(self, pg: ProcessGroup) -> None:
11381169        self ._error : Optional [Exception ] =  None 
11391170
11401171    def  configure (
1141-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1172+         self ,
1173+         store_addr : str ,
1174+         replica_id : str ,
1175+         rank : int ,
1176+         world_size : int ,
1177+         quorum_id : Optional [int ] =  None ,
1178+         group_world_size : int  =  1 ,
11421179    ) ->  None :
11431180        self ._error  =  None 
11441181
1145-         super ().configure (store_addr , replica_id , rank , world_size )
1182+         super ().configure (
1183+             store_addr , replica_id , rank , world_size , quorum_id , group_world_size 
1184+         )
11461185
11471186    def  report_error (self , e : Exception ) ->  None :
11481187        """ 
@@ -1194,11 +1233,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941233        self ._future_error : Optional [Exception ] =  None 
11951234
11961235    def  configure (
1197-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1236+         self ,
1237+         store_addr : str ,
1238+         replica_id : str ,
1239+         rank : int ,
1240+         world_size : int ,
1241+         quorum_id : Optional [int ] =  None ,
11981242    ) ->  None :
11991243        self ._future_error  =  None 
12001244
1201-         super ().configure (store_addr , replica_id , rank , world_size )
1245+         super ().configure (store_addr , replica_id , rank , world_size ,  quorum_id )
12021246
12031247    def  report_future_error (self , e : Exception ) ->  None :
12041248        """ 
@@ -1412,7 +1456,12 @@ def shutdown(self) -> None:
14121456            self ._p .kill ()
14131457
14141458    def  configure (
1415-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1459+         self ,
1460+         store_addr : str ,
1461+         replica_id : str ,
1462+         rank : int ,
1463+         world_size : int ,
1464+         quorum_id : Optional [int ] =  None ,
14161465    ) ->  None :
14171466        self ._world_size  =  world_size 
14181467
0 commit comments