@@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278        raise  NotImplementedError ("not implemented" )
279279
280280    def  configure (
281-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
281+         self ,
282+         store_addr : str ,
283+         replica_id : str ,
284+         rank : int ,
285+         world_size : int ,
286+         quorum_id : Optional [int ] =  None ,
282287    ) ->  None :
283288        """ 
284289        This reconfigures the ProcessGroup to use a new store, rank and world size. 
@@ -408,6 +413,7 @@ def __init__(
408413        self ._timeout  =  timeout 
409414        self ._replica_id : str  |  None  =  None 
410415        self ._rank : int  |  None  =  None 
416+         self ._quorum_id : int  |  None  =  None 
411417
412418        self .errors_logger : logging .Logger  =  logging .getLogger ("torchft_errors" )
413419
@@ -419,13 +425,19 @@ def getBackendName(self) -> str:
419425        raise  NotImplementedError ("not implemented" )
420426
421427    def  configure (
422-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
428+         self ,
429+         store_addr : str ,
430+         replica_id : str ,
431+         rank : int ,
432+         world_size : int ,
433+         quorum_id : Optional [int ] =  None ,
423434    ) ->  None :
424435        pg  =  self ._pg 
425436        self ._replica_id  =  replica_id 
437+         self ._quorum_id  =  quorum_id 
426438        self ._rank  =  rank 
427439        if  isinstance (pg , ProcessGroup ):
428-             pg .configure (store_addr , replica_id , rank , world_size )
440+             pg .configure (store_addr , replica_id , rank , world_size ,  quorum_id )
429441            return 
430442
431443        # abort if already initialized 
@@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None:
443455                    "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444456                    "replica_id" : self ._replica_id ,
445457                    "rank" : self ._rank ,
458+                     "quorum_id" : self ._quorum_id ,
446459                    "error" : "process_group_abort" ,
447460                },
448461            )
@@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615628        # pyre-fixme[16]: no attribute ProcessGroupGloo 
616629        backend_class  =  BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617630        backend_class ._set_sequence_number_for_group ()
631+         backend_class .options .global_ranks_in_group  =  list (range (world_size ))
632+         backend_class .options .group_name  =  f"torchft_quorum_{ self ._quorum_id }  
618633        pg ._register_backend (
619634            torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class 
620635        )
@@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813828        opts  =  BaseProcessGroupNCCL .Options ()
814829        opts .config .blocking  =  False 
815830        opts .global_ranks_in_group  =  list (range (world_size ))
831+         opts .group_name  =  f"torchft_quorum_{ self ._quorum_id }  
816832
817833        pg  =  BaseProcessGroup (store , rank , world_size )
818834        pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None:
979995        self .configure_count  =  0 
980996
981997    def  configure (
982-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
998+         self ,
999+         store_addr : str ,
1000+         replica_id : str ,
1001+         rank : int ,
1002+         world_size : int ,
1003+         quorum_id : Optional [int ] =  None ,
9831004    ) ->  None :
9841005        self .configure_count  +=  1 
9851006
@@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11381159        self ._error : Optional [Exception ] =  None 
11391160
11401161    def  configure (
1141-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1162+         self ,
1163+         store_addr : str ,
1164+         replica_id : str ,
1165+         rank : int ,
1166+         world_size : int ,
1167+         quorum_id : Optional [int ] =  None ,
11421168    ) ->  None :
11431169        self ._error  =  None 
11441170
1145-         super ().configure (store_addr , replica_id , rank , world_size )
1171+         super ().configure (store_addr , replica_id , rank , world_size ,  quorum_id )
11461172
11471173    def  report_error (self , e : Exception ) ->  None :
11481174        """ 
@@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941220        self ._future_error : Optional [Exception ] =  None 
11951221
11961222    def  configure (
1197-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1223+         self ,
1224+         store_addr : str ,
1225+         replica_id : str ,
1226+         rank : int ,
1227+         world_size : int ,
1228+         quorum_id : Optional [int ] =  None ,
11981229    ) ->  None :
11991230        self ._future_error  =  None 
12001231
1201-         super ().configure (store_addr , replica_id , rank , world_size )
1232+         super ().configure (store_addr , replica_id , rank , world_size ,  quorum_id )
12021233
12031234    def  report_future_error (self , e : Exception ) ->  None :
12041235        """ 
@@ -1412,7 +1443,12 @@ def shutdown(self) -> None:
14121443            self ._p .kill ()
14131444
14141445    def  configure (
1415-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1446+         self ,
1447+         store_addr : str ,
1448+         replica_id : str ,
1449+         rank : int ,
1450+         world_size : int ,
1451+         quorum_id : Optional [int ] =  None ,
14161452    ) ->  None :
14171453        self ._world_size  =  world_size 
14181454
0 commit comments