@@ -278,7 +278,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278        raise  NotImplementedError ("not implemented" )
279279
280280    def  configure (
281-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
281+         self ,
282+         store_addr : str ,
283+         replica_id : str ,
284+         rank : int ,
285+         world_size : int ,
286+         quorum_id : Optional [int ] =  None ,
287+         group_rank : Optional [int ] =  None ,
288+         group_world_size : Optional [int ] =  None ,
289+         global_ranks : Optional [list [int ]] =  None ,
282290    ) ->  None :
283291        """ 
284292        This reconfigures the ProcessGroup to use a new store, rank and world size. 
@@ -294,6 +302,10 @@ def configure(
294302            replica_id: the replica_id for this group 
295303            rank: rank of this process 
296304            world_size: world size of this process group 
305+             quorum_id: current quorum's identifier 
306+             group_rank: local rank within the replica group 
307+             group_world_size: the number of ranks within a replica 
308+             global_ranks: the global ranks part of this process group 
297309        """ 
298310        raise  NotImplementedError ("not implemented" )
299311
@@ -408,6 +420,10 @@ def __init__(
408420        self ._timeout  =  timeout 
409421        self ._replica_id : str  |  None  =  None 
410422        self ._rank : int  |  None  =  None 
423+         self ._quorum_id : int  |  None  =  None 
424+         self ._group_rank : int  |  None  =  None 
425+         self ._group_world_size : int  |  None  =  None 
426+         self ._global_ranks : list [int ] |  None  =  None 
411427
412428        self .errors_logger : logging .Logger  =  logging .getLogger ("torchft_errors" )
413429
@@ -419,13 +435,34 @@ def getBackendName(self) -> str:
419435        raise  NotImplementedError ("not implemented" )
420436
421437    def  configure (
422-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
438+         self ,
439+         store_addr : str ,
440+         replica_id : str ,
441+         rank : int ,
442+         world_size : int ,
443+         quorum_id : Optional [int ] =  None ,
444+         group_rank : Optional [int ] =  None ,
445+         group_world_size : Optional [int ] =  None ,
446+         global_ranks : Optional [list [int ]] =  None ,
423447    ) ->  None :
424448        pg  =  self ._pg 
425449        self ._replica_id  =  replica_id 
450+         self ._quorum_id  =  quorum_id 
451+         self ._group_rank  =  group_rank 
452+         self ._group_world_size  =  group_world_size 
426453        self ._rank  =  rank 
454+         self ._global_ranks  =  global_ranks 
427455        if  isinstance (pg , ProcessGroup ):
428-             pg .configure (store_addr , replica_id , rank , world_size )
456+             pg .configure (
457+                 store_addr ,
458+                 replica_id ,
459+                 rank ,
460+                 world_size ,
461+                 quorum_id ,
462+                 group_rank ,
463+                 group_world_size ,
464+                 global_ranks ,
465+             )
429466            return 
430467
431468        # abort if already initialized 
@@ -443,6 +480,7 @@ def abort(self, errored: bool = True) -> None:
443480                    "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444481                    "replica_id" : self ._replica_id ,
445482                    "rank" : self ._rank ,
483+                     "quorum_id" : self ._quorum_id ,
446484                    "error" : "process_group_abort" ,
447485                },
448486            )
@@ -615,6 +653,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615653        # pyre-fixme[16]: no attribute ProcessGroupGloo 
616654        backend_class  =  BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617655        backend_class ._set_sequence_number_for_group ()
656+ 
657+         if  self ._global_ranks :
658+             backend_class .options .global_ranks_in_group  =  self ._global_ranks 
659+         if  self ._group_rank  and  self ._group_world_size :
660+             backend_class .options .group_name  =  f"torchft_quorum_{ self ._quorum_id } { self ._group_rank  %  self ._group_world_size }  
661+ 
618662        pg ._register_backend (
619663            torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class 
620664        )
@@ -812,7 +856,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
812856        # pyre-fixme[16]: no attribute ProcessGroupNCCL 
813857        opts  =  BaseProcessGroupNCCL .Options ()
814858        opts .config .blocking  =  False 
815-         opts .global_ranks_in_group  =  list (range (world_size ))
859+         if  self ._global_ranks :
860+             opts .global_ranks_in_group  =  self ._global_ranks 
861+         if  self ._group_rank  and  self ._group_world_size :
862+             opts .group_name  =  f"torchft_quorum_{ self ._quorum_id } { self ._group_rank  %  self ._group_world_size }  
816863
817864        pg  =  BaseProcessGroup (store , rank , world_size )
818865        pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +1026,15 @@ def __init__(self, rank: int, world: int) -> None:
9791026        self .configure_count  =  0 
9801027
9811028    def  configure (
982-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1029+         self ,
1030+         store_addr : str ,
1031+         replica_id : str ,
1032+         rank : int ,
1033+         world_size : int ,
1034+         quorum_id : Optional [int ] =  None ,
1035+         group_rank : Optional [int ] =  None ,
1036+         group_world_size : Optional [int ] =  None ,
1037+         global_ranks : Optional [list [int ]] =  None ,
9831038    ) ->  None :
9841039        self .configure_count  +=  1 
9851040
@@ -1138,11 +1193,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11381193        self ._error : Optional [Exception ] =  None 
11391194
11401195    def  configure (
1141-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1196+         self ,
1197+         store_addr : str ,
1198+         replica_id : str ,
1199+         rank : int ,
1200+         world_size : int ,
1201+         quorum_id : Optional [int ] =  None ,
1202+         group_rank : Optional [int ] =  None ,
1203+         group_world_size : Optional [int ] =  None ,
1204+         global_ranks : Optional [list [int ]] =  None ,
11421205    ) ->  None :
11431206        self ._error  =  None 
11441207
1145-         super ().configure (store_addr , replica_id , rank , world_size )
1208+         super ().configure (
1209+             store_addr ,
1210+             replica_id ,
1211+             rank ,
1212+             world_size ,
1213+             quorum_id ,
1214+             group_rank ,
1215+             group_world_size ,
1216+             global_ranks ,
1217+         )
11461218
11471219    def  report_error (self , e : Exception ) ->  None :
11481220        """ 
@@ -1194,11 +1266,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11941266        self ._future_error : Optional [Exception ] =  None 
11951267
11961268    def  configure (
1197-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1269+         self ,
1270+         store_addr : str ,
1271+         replica_id : str ,
1272+         rank : int ,
1273+         world_size : int ,
1274+         quorum_id : Optional [int ] =  None ,
1275+         group_rank : Optional [int ] =  None ,
1276+         group_world_size : Optional [int ] =  None ,
1277+         global_ranks : Optional [list [int ]] =  None ,
11981278    ) ->  None :
11991279        self ._future_error  =  None 
12001280
1201-         super ().configure (store_addr , replica_id , rank , world_size )
1281+         super ().configure (
1282+             store_addr ,
1283+             replica_id ,
1284+             rank ,
1285+             world_size ,
1286+             quorum_id ,
1287+             group_rank ,
1288+             group_world_size ,
1289+             global_ranks ,
1290+         )
12021291
12031292    def  report_future_error (self , e : Exception ) ->  None :
12041293        """ 
@@ -1412,7 +1501,15 @@ def shutdown(self) -> None:
14121501            self ._p .kill ()
14131502
14141503    def  configure (
1415-         self , store_addr : str , replica_id : str , rank : int , world_size : int 
1504+         self ,
1505+         store_addr : str ,
1506+         replica_id : str ,
1507+         rank : int ,
1508+         world_size : int ,
1509+         quorum_id : Optional [int ] =  None ,
1510+         group_rank : Optional [int ] =  None ,
1511+         group_world_size : Optional [int ] =  None ,
1512+         global_ranks : Optional [list [int ]] =  None ,
14161513    ) ->  None :
14171514        self ._world_size  =  world_size 
14181515
0 commit comments