@@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278 raise NotImplementedError ("not implemented" )
279279
280280 def configure (
281- self , store_addr : str , replica_id : str , rank : int , world_size : int
281+ self ,
282+ store_addr : str ,
283+ replica_id : str ,
284+ rank : int ,
285+ world_size : int ,
286+ quorum_id : Optional [int ] = None ,
282287 ) -> None :
283288 """
284289 This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +413,7 @@ def __init__(
408413 self ._timeout = timeout
409414 self ._replica_id : str | None = None
410415 self ._rank : int | None = None
416+ self ._quorum_id : int | None = None
411417
412418 self .errors_logger : logging .Logger = logging .getLogger ("torchft_errors" )
413419
@@ -419,13 +425,19 @@ def getBackendName(self) -> str:
419425 raise NotImplementedError ("not implemented" )
420426
421427 def configure (
422- self , store_addr : str , replica_id : str , rank : int , world_size : int
428+ self ,
429+ store_addr : str ,
430+ replica_id : str ,
431+ rank : int ,
432+ world_size : int ,
433+ quorum_id : Optional [int ] = None ,
423434 ) -> None :
424435 pg = self ._pg
425436 self ._replica_id = replica_id
437+ self ._quorum_id = quorum_id
426438 self ._rank = rank
427439 if isinstance (pg , ProcessGroup ):
428- pg .configure (store_addr , replica_id , rank , world_size )
440+ pg .configure (store_addr , replica_id , rank , world_size , quorum_id )
429441 return
430442
431443 # abort if already initialized
@@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None:
443455 "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444456 "replica_id" : self ._replica_id ,
445457 "rank" : self ._rank ,
458+ "quorum_id" : self ._quorum_id ,
446459 "error" : "process_group_abort" ,
447460 },
448461 )
@@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615628 # pyre-fixme[16]: no attribute ProcessGroupGloo
616629 backend_class = BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617630 backend_class ._set_sequence_number_for_group ()
631+ backend_class .options .global_ranks_in_group = list (range (world_size ))
632+ backend_class .options .group_name = f"torchft_quorum_{ self ._quorum_id } "
618633 pg ._register_backend (
619634 torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class
620635 )
@@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813828 opts = BaseProcessGroupNCCL .Options ()
814829 opts .config .blocking = False
815830 opts .global_ranks_in_group = list (range (world_size ))
831+ opts .group_name = f"torchft_quorum_{ self ._quorum_id } "
816832
817833 pg = BaseProcessGroup (store , rank , world_size )
818834 pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None:
979995 self .configure_count = 0
980996
981997 def configure (
982- self , store_addr : str , replica_id : str , rank : int , world_size : int
998+ self ,
999+ store_addr : str ,
1000+ replica_id : str ,
1001+ rank : int ,
1002+ world_size : int ,
1003+ quorum_id : Optional [int ] = None ,
9831004 ) -> None :
9841005 self .configure_count += 1
9851006
@@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11381159 self ._error : Optional [Exception ] = None
11391160
11401161 def configure (
1141- self , store_addr : str , replica_id : str , rank : int , world_size : int
1162+ self ,
1163+ store_addr : str ,
1164+ replica_id : str ,
1165+ rank : int ,
1166+ world_size : int ,
1167+ quorum_id : Optional [int ] = None ,
11421168 ) -> None :
11431169 self ._error = None
11441170
1145- super ().configure (store_addr , replica_id , rank , world_size )
1171+ super ().configure (store_addr , replica_id , rank , world_size , quorum_id )
11461172
11471173 def report_error (self , e : Exception ) -> None :
11481174 """
@@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941220 self ._future_error : Optional [Exception ] = None
11951221
11961222 def configure (
1197- self , store_addr : str , replica_id : str , rank : int , world_size : int
1223+ self ,
1224+ store_addr : str ,
1225+ replica_id : str ,
1226+ rank : int ,
1227+ world_size : int ,
1228+ quorum_id : Optional [int ] = None ,
11981229 ) -> None :
11991230 self ._future_error = None
12001231
1201- super ().configure (store_addr , replica_id , rank , world_size )
1232+ super ().configure (store_addr , replica_id , rank , world_size , quorum_id )
12021233
12031234 def report_future_error (self , e : Exception ) -> None :
12041235 """
@@ -1412,7 +1443,12 @@ def shutdown(self) -> None:
14121443 self ._p .kill ()
14131444
14141445 def configure (
1415- self , store_addr : str , replica_id : str , rank : int , world_size : int
1446+ self ,
1447+ store_addr : str ,
1448+ replica_id : str ,
1449+ rank : int ,
1450+ world_size : int ,
1451+ quorum_id : Optional [int ] = None ,
14161452 ) -> None :
14171453 self ._world_size = world_size
14181454
0 commit comments