8888# crash if call to quorum fails, all replicas will crash.
8989QUORUM_RETRIES_ENV : str = "TORCHFT_QUORUM_RETRIES"
9090
# Name of the environment variable holding the Flight Recorder dump file
# path. Its initial value is captured in __init__ and the variable is then
# overwritten by _update_fr_path with a per-quorum path.
91+ TORCH_FR_DUMP_TEMP_FILE_ENV : str = "TORCH_FR_DUMP_TEMP_FILE"
92+
9193T = TypeVar ("T" )
9294
9395
@@ -109,6 +111,17 @@ def get_timeout(
109111 return default_timeout_sec
110112
111113
def extract_trailing_digits(s: str) -> int:
    """
    Parse the run of decimal digits at the end of ``s`` as an integer.

    Args:
        s: arbitrary string, e.g. ``"replica_12"``.

    Returns:
        The integer value of the trailing digits (``12`` for
        ``"replica_12"``), or ``0`` when ``s`` is empty or does not end in a
        decimal digit.
    """
    # Walk backwards until the first character that is not a decimal digit.
    # ``isdecimal`` is used instead of ``isdigit`` because ``isdigit`` also
    # accepts characters such as superscripts ("²") that ``int`` cannot parse,
    # which would turn an odd-looking id into a ValueError.
    i = len(s) - 1
    while i >= 0 and s[i].isdecimal():
        i -= 1
    suffix = s[i + 1 :]
    return int(suffix) if suffix else 0
123+
124+
112125class WorldSizeMode (Enum ):
113126 """
114127 This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
223236 self ._load_state_dict_fns : Dict [str , Callable [[object ], None ]] = {}
224237 self ._user_state_dicts : Dict [str , Callable [[], object ]] = {}
225238
239+ self ._original_fr_dump_temp_file : Optional [str ] = os .environ .get (
240+ TORCH_FR_DUMP_TEMP_FILE_ENV
241+ )
226242 self ._replica_id = replica_id
227243
228244 # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
257273 store_port = store_port or int (os .environ ["MASTER_PORT" ])
258274 self ._group_rank : int = rank if rank is not None else int (os .environ ["RANK" ])
259275 group_rank = self ._group_rank
260- group_world_size = world_size or int (os .environ ["WORLD_SIZE" ])
276+ self . _group_world_size : int = world_size or int (os .environ ["WORLD_SIZE" ])
261277 self ._min_replica_size = min_replica_size
262278
263279 if checkpoint_transport is None :
@@ -310,7 +326,7 @@ def __init__(
310326 hostname = hostname ,
311327 bind = bind ,
312328 store_addr = f"{ store_addr } :{ store_port } " ,
313- world_size = group_world_size ,
329+ world_size = self . _group_world_size ,
314330 heartbeat_interval = heartbeat_interval ,
315331 connect_timeout = connect_timeout ,
316332 quorum_retries = self ._quorum_retries ,
@@ -338,6 +354,8 @@ def __init__(
338354 self ._participating_replica_world_size : int = 0
339355 self ._is_state_dict_read_allowed = True
340356
357+ self ._update_fr_path ()
358+
341359 def allow_state_dict_read (self ) -> None :
342360 if self ._is_state_dict_read_allowed :
343361 return
@@ -665,16 +683,20 @@ def _async_quorum(
665683 self ._logger .info (f"reconfiguring for { quorum_id = } { store_prefixed_addr = } " )
666684 # We use the replica rank and world as we want all replicas in the PG.
667685 try :
686+ self ._quorum_id = quorum_id
668687 with torch .profiler .record_function ("torchft::manager::_pg::configure" ):
688+ # Reset GPU state for Flight Recorder
669689 if torch .accelerator .is_available ():
670690 torch .accelerator .synchronize ()
691+ torch ._C ._distributed_c10d ._reset_fr_recording_nccl () # pyre-ignore
692+ self ._update_fr_path ()
693+
671694 self ._pg .configure (
672695 store_prefixed_addr ,
673696 self ._replica_id if self ._replica_id is not None else "0" ,
674697 replica_rank ,
675698 replica_world_size ,
676699 )
677- self ._quorum_id = quorum_id
678700 except Exception as e :
679701 self ._logger .exception (f"got exception in pg configure: { e } " )
680702 self .report_error (e )
@@ -749,6 +771,21 @@ def _async_quorum(
749771 else None
750772 )
751773
def _update_fr_path(self) -> None:
    """
    Redirect the Flight Recorder dump path to a per-quorum location.

    Does nothing when the dump-path environment variable was not set at the
    time its value was captured in ``__init__``. Otherwise it creates the
    directory ``<original path>_quorum_<quorum id>`` (if missing) and sets the
    environment variable to a file inside it. The file name is the group rank
    when there is no replica id; otherwise it is the replica id's trailing
    number times the group world size, plus the group rank.
    """
    base_path = self._original_fr_dump_temp_file
    if base_path is None:
        return

    # One directory per quorum id.
    quorum_dir = f"{base_path}_quorum_{self._quorum_id}"
    os.makedirs(quorum_dir, exist_ok=True)

    if self._replica_id is None:
        rank_index = self._group_rank
    else:
        # Offset the group rank by the replica's numeric suffix scaled by the
        # group world size.
        replica_offset = (
            extract_trailing_digits(self._replica_id) * self._group_world_size
        )
        rank_index = replica_offset + self._group_rank

    os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{quorum_dir}/{rank_index}"
788+
752789 def _apply_pending_state_dict (self ) -> None :
753790 assert self ._healing , "must be in healing state"
754791
0 commit comments