
Commit 895e0e6

tushar00jain authored and facebook-github-bot committed
reset flight recorder trace (#283)
Summary:
- Call the FR API to reset the trace after every quorum. We reset so that each quorum starts with a fresh FR trace, since the PGs could have changed and the FR trace from previous errors has already been dumped.
- Change the env var that's used to determine the dump file after every quorum.

Reviewed By: d4l3k

Differential Revision: D84260745
1 parent 73dafea commit 895e0e6
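
As a rough standalone sketch of what the change below does (not torchft's actual Manager code; the helper names here are illustrative, while the env var, the path scheme, and the reset call are taken from the diff): after each quorum the Flight Recorder dump path is re-derived from the originally configured TORCH_FR_DUMP_TEMP_FILE value, and the recorder is reset so the next dump only covers collectives issued against the new process groups.

    import os
    from typing import Optional

    import torch

    TORCH_FR_DUMP_TEMP_FILE_ENV = "TORCH_FR_DUMP_TEMP_FILE"

    # Capture the user-provided dump location once at startup; the per-quorum path
    # is always derived from this original value, never from the rewritten env var.
    _ORIGINAL_FR_DUMP_TEMP_FILE: Optional[str] = os.environ.get(
        TORCH_FR_DUMP_TEMP_FILE_ENV
    )

    def update_fr_dump_path(quorum_id: int, global_rank: int) -> None:
        # Mirrors Manager._update_fr_path: one folder per quorum, one file per rank.
        if _ORIGINAL_FR_DUMP_TEMP_FILE is None:
            return
        folder = f"{_ORIGINAL_FR_DUMP_TEMP_FILE}_quorum_{quorum_id}"
        os.makedirs(folder, exist_ok=True)
        os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{global_rank}"

    def reset_flight_recorder(quorum_id: int, global_rank: int) -> None:
        # After the PG has been reconfigured for a new quorum, point FR at a
        # quorum-specific file and drop entries recorded against the old PGs.
        update_fr_dump_path(quorum_id, global_rank)
        torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore

Note that the per-quorum path is always built from the value captured at startup; deriving it from the already-rewritten env var would compound the "_quorum_" suffix across quorums.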

2 files changed (+107, −12)


torchft/manager.py

Lines changed: 49 additions & 3 deletions

@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")


@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec


+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+    """
+    i = len(s) - 1
+    while i >= 0 and s[i].isdigit():
+        i -= 1
+    return int(s[i + 1 :]) if i < len(s) - 1 else 0
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id

         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size

         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True

+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -674,16 +701,29 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             try:
+                self._quorum_id = quorum_id
                 with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                    # Reset GPU state for Flight Recorder
                     if torch.accelerator.is_available():
                         torch.accelerator.synchronize()
+
                     self._pg.configure(
                         store_prefixed_addr,
                         self._replica_id if self._replica_id is not None else "0",
                         replica_rank,
                         replica_world_size,
+                        quorum_id,
+                        self._group_world_size,
                     )
-                self._quorum_id = quorum_id
+
+                self._update_fr_path()
+
+                # We need to reset the trace after reconfiguring the PG because that
+                # calls abort which may trigger a dump
+                self._logger.info(
+                    f"resetting fr recording for quorum id {self._quorum_id}"
+                )
+                torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
             except Exception as e:
                 self._logger.exception(f"got exception in pg configure: {e}")
                 self.report_error(e)
@@ -758,6 +798,12 @@ def _async_quorum(
             else None
         )

+    def _update_fr_path(self) -> None:
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"

torchft/process_group.py

Lines changed: 58 additions & 9 deletions

@@ -278,7 +278,13 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_world_size: int = 1,
     ) -> None:
         """
         This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +414,8 @@ def __init__(
         self._timeout = timeout
         self._replica_id: str | None = None
         self._rank: int | None = None
+        self._quorum_id: int | None = None
+        self._group_world_size: int = 1

         self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

@@ -419,13 +427,23 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_world_size: int = 1,
     ) -> None:
         pg = self._pg
         self._replica_id = replica_id
+        self._quorum_id = quorum_id
+        self._group_world_size = group_world_size
         self._rank = rank
         if isinstance(pg, ProcessGroup):
-            pg.configure(store_addr, replica_id, rank, world_size)
+            pg.configure(
+                store_addr, replica_id, rank, world_size, quorum_id, group_world_size
+            )
             return

         # abort if already initialized
@@ -443,6 +461,7 @@ def abort(self, errored: bool = True) -> None:
                 "job_id": os.environ.get("JOB_ID", "unknown"),
                 "replica_id": self._replica_id,
                 "rank": self._rank,
+                "quorum_id": self._quorum_id,
                 "error": "process_group_abort",
             },
         )
@@ -615,6 +634,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupGloo
         backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
         backend_class._set_sequence_number_for_group()
+        backend_class.options.global_ranks_in_group = list(range(world_size))
+        backend_class.options.group_name = (
+            f"torchft_quorum_{self._quorum_id}_rank_{rank % self._group_world_size}"
+        )
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
@@ -813,6 +836,9 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         opts = BaseProcessGroupNCCL.Options()
         opts.config.blocking = False
         opts.global_ranks_in_group = list(range(world_size))
+        opts.group_name = (
+            f"torchft_quorum_{self._quorum_id}_rank_{rank % self._group_world_size}"
+        )

         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +1005,12 @@ def __init__(self, rank: int, world: int) -> None:
         self.configure_count = 0

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self.configure_count += 1

@@ -1138,11 +1169,19 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_world_size: int = 1,
     ) -> None:
         self._error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr, replica_id, rank, world_size, quorum_id, group_world_size
+        )

     def report_error(self, e: Exception) -> None:
         """
@@ -1194,11 +1233,16 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._future_error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self._future_error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(store_addr, replica_id, rank, world_size, quorum_id)

     def report_future_error(self, e: Exception) -> None:
         """
@@ -1412,7 +1456,12 @@ def shutdown(self) -> None:
         self._p.kill()

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self._world_size = world_size

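One more illustration of the naming scheme added in _create_pg above: both the Gloo and NCCL backend options now get a deterministic group_name built from the quorum id and the rank modulo group_world_size, which keeps Flight Recorder entries from different quorums and process groups distinguishable. The helper below is not part of torchft; only the format string is taken from the diff.

    from typing import Optional

    def torchft_group_name(
        quorum_id: Optional[int], rank: int, group_world_size: int
    ) -> str:
        # Same format string _create_pg now assigns to the Gloo and NCCL
        # backend options' group_name.
        return f"torchft_quorum_{quorum_id}_rank_{rank % group_world_size}"

    # e.g. rank 3 with a group world size of 4, during quorum 7:
    assert torchft_group_name(7, 3, 4) == "torchft_quorum_7_rank_3"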