Skip to content

Commit eaf4f9e

Browse files
tushar00jainfacebook-github-bot
authored and committed
reset flight recorder trace (meta-pytorch#283)
Summary: - call FR api to reset the trace after every quorum - we reset so that after every quorum, we start a fresh FR trace since the pg's could have changed and we already dumped FR trace from previous errors - change the env var that's used to determine the file after every quorum Differential Revision: D84260745
1 parent 1ed8309 commit eaf4f9e

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

torchft/manager.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@
8888
# crash if call to quorum fails, all replicas will crash.
8989
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
9090

91+
TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
92+
9193
T = TypeVar("T")
9294

9395

@@ -109,6 +111,17 @@ def get_timeout(
109111
return default_timeout_sec
110112

111113

114+
def extract_trailing_digits(s: str) -> int:
    """Extract the run of digits at the end of *s* as an ``int``.

    Args:
        s: Any string, e.g. a replica id such as ``"replica7"``.

    Returns:
        The integer value of the trailing digit run, or ``0`` when *s* is
        empty or does not end with a digit.  (The original docstring claimed
        an empty string was returned on no match, which contradicted both
        the annotation and the code — it has always returned ``0``.)
    """
    i = len(s) - 1
    # Walk backwards over the trailing digit run; stop at the first
    # non-digit character (or fall off the front of the string).
    while i >= 0 and s[i].isdigit():
        i -= 1
    # i now indexes the last non-digit char, or -1 if s is all digits.
    return int(s[i + 1 :]) if i < len(s) - 1 else 0
123+
124+
112125
class WorldSizeMode(Enum):
113126
"""
114127
This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
223236
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
224237
self._user_state_dicts: Dict[str, Callable[[], object]] = {}
225238

239+
self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
240+
TORCH_FR_DUMP_TEMP_FILE_ENV
241+
)
226242
self._replica_id = replica_id
227243

228244
# Protects state dict
@@ -257,7 +273,7 @@ def __init__(
257273
store_port = store_port or int(os.environ["MASTER_PORT"])
258274
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
259275
group_rank = self._group_rank
260-
group_world_size = world_size or int(os.environ["WORLD_SIZE"])
276+
self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
261277
self._min_replica_size = min_replica_size
262278

263279
if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
310326
hostname=hostname,
311327
bind=bind,
312328
store_addr=f"{store_addr}:{store_port}",
313-
world_size=group_world_size,
329+
world_size=self._group_world_size,
314330
heartbeat_interval=heartbeat_interval,
315331
connect_timeout=connect_timeout,
316332
quorum_retries=self._quorum_retries,
@@ -338,6 +354,8 @@ def __init__(
338354
self._participating_replica_world_size: int = 0
339355
self._is_state_dict_read_allowed = True
340356

357+
self._update_fr_path()
358+
341359
def allow_state_dict_read(self) -> None:
342360
if self._is_state_dict_read_allowed:
343361
return
@@ -665,16 +683,20 @@ def _async_quorum(
665683
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
666684
# We use the replica rank and world as we want all replicas in the PG.
667685
try:
686+
self._quorum_id = quorum_id
668687
with torch.profiler.record_function("torchft::manager::_pg::configure"):
688+
# Reset GPU state for Flight Recorder
669689
if torch.accelerator.is_available():
670690
torch.accelerator.synchronize()
691+
torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
692+
self._update_fr_path()
693+
671694
self._pg.configure(
672695
store_prefixed_addr,
673696
self._replica_id if self._replica_id is not None else "0",
674697
replica_rank,
675698
replica_world_size,
676699
)
677-
self._quorum_id = quorum_id
678700
except Exception as e:
679701
self._logger.exception(f"got exception in pg configure: {e}")
680702
self.report_error(e)
@@ -749,6 +771,21 @@ def _async_quorum(
749771
else None
750772
)
751773

774+
def _update_fr_path(self) -> None:
    """Point the Flight Recorder dump env var at a per-quorum file.

    No-op when the FR dump env var was unset at manager construction
    (``self._original_fr_dump_temp_file is None``).  Otherwise create a
    folder derived from the original value plus the current quorum id,
    and set the env var to a file in that folder named by this process's
    globally unique rank, so each rank dumps its FR trace to its own
    file for each quorum.
    """
    if self._original_fr_dump_temp_file is not None:
        folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
        os.makedirs(folder, exist_ok=True)

        # Global rank across replicas: replica_index * group_world_size
        # + group_rank.  Falls back to the plain group rank when there is
        # no replica id to derive an index from.
        filename = (
            self._group_rank
            if self._replica_id is None
            else (
                extract_trailing_digits(self._replica_id) * self._group_world_size
                + self._group_rank
            )
        )
        # Bug fix: the computed `filename` was previously discarded and a
        # garbled literal written instead — use it in the dump path.
        os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{filename}"
788+
752789
def _apply_pending_state_dict(self) -> None:
753790
assert self._healing, "must be in healing state"
754791

0 commit comments

Comments
 (0)