8888# crash if call to quorum fails, all replicas will crash. 
8989QUORUM_RETRIES_ENV : str  =  "TORCHFT_QUORUM_RETRIES" 
9090
91+ TORCH_FR_DUMP_TEMP_FILE_ENV : str  =  "TORCH_FR_DUMP_TEMP_FILE" 
92+ 
9193T  =  TypeVar ("T" )
9294
9395
@@ -109,6 +111,17 @@ def get_timeout(
109111    return  default_timeout_sec 
110112
111113
114+ def  extract_trailing_digits (s : str ) ->  int :
115+     """ 
116+     Extracts the trailing digits from the end of the string s. 
117+     Returns an empty string if no trailing digits are found. 
118+     """ 
119+     i  =  len (s ) -  1 
120+     while  i  >=  0  and  s [i ].isdigit ():
121+         i  -=  1 
122+     return  int (s [i  +  1  :]) if  i  <  len (s ) -  1  else  0 
123+ 
124+ 
112125class  WorldSizeMode (Enum ):
113126    """ 
114127    This controls the numerics for the job when doing allreduces across replicas 
@@ -223,6 +236,9 @@ def __init__(
223236        self ._load_state_dict_fns : Dict [str , Callable [[object ], None ]] =  {}
224237        self ._user_state_dicts : Dict [str , Callable [[], object ]] =  {}
225238
239+         self ._original_fr_dump_temp_file : Optional [str ] =  os .environ .get (
240+             TORCH_FR_DUMP_TEMP_FILE_ENV 
241+         )
226242        self ._replica_id  =  replica_id 
227243
228244        # Protects state dict 
@@ -257,7 +273,7 @@ def __init__(
257273        store_port  =  store_port  or  int (os .environ ["MASTER_PORT" ])
258274        self ._group_rank : int  =  rank  if  rank  is  not None  else  int (os .environ ["RANK" ])
259275        group_rank  =  self ._group_rank 
260-         group_world_size  =  world_size  or  int (os .environ ["WORLD_SIZE" ])
276+         self . _group_world_size :  int  =  world_size  or  int (os .environ ["WORLD_SIZE" ])
261277        self ._min_replica_size  =  min_replica_size 
262278
263279        if  checkpoint_transport  is  None :
@@ -310,7 +326,7 @@ def __init__(
310326                hostname = hostname ,
311327                bind = bind ,
312328                store_addr = f"{ store_addr } { store_port }  ,
313-                 world_size = group_world_size ,
329+                 world_size = self . _group_world_size ,
314330                heartbeat_interval = heartbeat_interval ,
315331                connect_timeout = connect_timeout ,
316332                quorum_retries = self ._quorum_retries ,
@@ -338,6 +354,8 @@ def __init__(
338354        self ._participating_replica_world_size : int  =  0 
339355        self ._is_state_dict_read_allowed  =  True 
340356
357+         self ._update_fr_path ()
358+ 
341359    def  allow_state_dict_read (self ) ->  None :
342360        if  self ._is_state_dict_read_allowed :
343361            return 
@@ -665,16 +683,20 @@ def _async_quorum(
665683            self ._logger .info (f"reconfiguring for { quorum_id = } { store_prefixed_addr = }  )
666684            # We use the replica rank and world as we want all replicas in the PG. 
667685            try :
686+                 self ._quorum_id  =  quorum_id 
668687                with  torch .profiler .record_function ("torchft::manager::_pg::configure" ):
688+                     # Reset GPU state for Flight Recorder 
669689                    if  torch .accelerator .is_available ():
670690                        torch .accelerator .synchronize ()
691+                     torch ._C ._distributed_c10d ._reset_fr_recording_nccl ()
692+                     self ._update_fr_path ()
693+ 
671694                    self ._pg .configure (
672695                        store_prefixed_addr ,
673696                        self ._replica_id  if  self ._replica_id  is  not None  else  "0" ,
674697                        replica_rank ,
675698                        replica_world_size ,
676699                    )
677-                 self ._quorum_id  =  quorum_id 
678700            except  Exception  as  e :
679701                self ._logger .exception (f"got exception in pg configure: { e }  )
680702                self .report_error (e )
@@ -749,6 +771,21 @@ def _async_quorum(
749771                    else  None 
750772                )
751773
774+     def  _update_fr_path (self ) ->  None :
775+         if  self ._original_fr_dump_temp_file  is  not None :
776+             folder  =  f"{ self ._original_fr_dump_temp_file } { self ._quorum_id }  
777+             os .makedirs (folder , exist_ok = True )
778+ 
779+             filename  =  (
780+                 self ._group_rank 
781+                 if  self ._replica_id  is  None 
782+                 else  (
783+                     extract_trailing_digits (self ._replica_id ) *  self ._group_world_size 
784+                     +  self ._group_rank 
785+                 )
786+             )
787+             os .environ [TORCH_FR_DUMP_TEMP_FILE_ENV ] =  f"{ folder } { filename }  
788+ 
752789    def  _apply_pending_state_dict (self ) ->  None :
753790        assert  self ._healing , "must be in healing state" 
754791
0 commit comments