Skip to content

Commit dc0194c

Browse files
tushar00jain authored and facebook-github-bot committed
reset flight recorder trace (meta-pytorch#283)
Summary: - call FR api to reset the trace after every quorum - we reset so that after every quorum, we start a fresh FR trace since the pg's could have changed and we already dumped FR trace from previous errors - change the env var that's used to determine the file after every quorum Reviewed By: d4l3k Differential Revision: D84260745
1 parent 73dafea commit dc0194c

File tree

2 files changed

+121
-12
lines changed

2 files changed

+121
-12
lines changed

torchft/manager.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@
8888
# crash if call to quorum fails, all replicas will crash.
8989
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
9090

91+
TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
92+
9193
T = TypeVar("T")
9294

9395

@@ -109,6 +111,17 @@ def get_timeout(
109111
return default_timeout_sec
110112

111113

114+
def extract_trailing_digits(s: str) -> int:
    """
    Extract the run of digits at the end of ``s`` as an integer.

    Returns:
        The integer value of the trailing digit run (e.g. ``"replica12"``
        -> ``12``), or ``0`` if ``s`` has no trailing digits (including
        when ``s`` is empty).
    """
    # Walk backwards from the end while we keep seeing digits; after the
    # loop, i points at the last non-digit character (or -1).
    i = len(s) - 1
    while i >= 0 and s[i].isdigit():
        i -= 1
    # i < len(s) - 1 means at least one trailing digit was consumed.
    return int(s[i + 1 :]) if i < len(s) - 1 else 0
123+
124+
112125
class WorldSizeMode(Enum):
113126
"""
114127
This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
223236
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
224237
self._user_state_dicts: Dict[str, Callable[[], object]] = {}
225238

239+
self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
240+
TORCH_FR_DUMP_TEMP_FILE_ENV
241+
)
226242
self._replica_id = replica_id
227243

228244
# Protects state dict
@@ -257,7 +273,7 @@ def __init__(
257273
store_port = store_port or int(os.environ["MASTER_PORT"])
258274
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
259275
group_rank = self._group_rank
260-
group_world_size = world_size or int(os.environ["WORLD_SIZE"])
276+
self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
261277
self._min_replica_size = min_replica_size
262278

263279
if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
310326
hostname=hostname,
311327
bind=bind,
312328
store_addr=f"{store_addr}:{store_port}",
313-
world_size=group_world_size,
329+
world_size=self._group_world_size,
314330
heartbeat_interval=heartbeat_interval,
315331
connect_timeout=connect_timeout,
316332
quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
338354
self._participating_replica_world_size: int = 0
339355
self._is_state_dict_read_allowed = True
340356

357+
self._global_rank: int = (
358+
self._group_rank
359+
if self._replica_id is None
360+
else (
361+
extract_trailing_digits(self._replica_id) * self._group_world_size
362+
+ self._group_rank
363+
)
364+
)
365+
366+
self._update_fr_path()
367+
341368
def allow_state_dict_read(self) -> None:
342369
if self._is_state_dict_read_allowed:
343370
return
@@ -674,16 +701,30 @@ def _async_quorum(
674701
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
675702
# We use the replica rank and world as we want all replicas in the PG.
676703
try:
704+
self._quorum_id = quorum_id
677705
with torch.profiler.record_function("torchft::manager::_pg::configure"):
706+
# Reset GPU state for Flight Recorder
678707
if torch.accelerator.is_available():
679708
torch.accelerator.synchronize()
709+
680710
self._pg.configure(
681711
store_prefixed_addr,
682712
self._replica_id if self._replica_id is not None else "0",
683713
replica_rank,
684714
replica_world_size,
715+
quorum_id,
716+
self._group_rank,
717+
self._group_world_size,
685718
)
686-
self._quorum_id = quorum_id
719+
720+
self._update_fr_path()
721+
722+
# We need to reset the trace after reconfiguring the PG because that
723+
# calls abort which may trigger a dump
724+
self._logger.info(
725+
f"resetting fr recording for quorum id {self._quorum_id}"
726+
)
727+
torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
687728
except Exception as e:
688729
self._logger.exception(f"got exception in pg configure: {e}")
689730
self.report_error(e)
@@ -758,6 +799,12 @@ def _async_quorum(
758799
else None
759800
)
760801

802+
def _update_fr_path(self) -> None:
    """Redirect the Flight Recorder dump file to a per-quorum location.

    If the original dump-path env var was set at startup, create a folder
    suffixed with the current quorum id and point the env var at a file
    named after this rank's global rank inside that folder.
    """
    if self._original_fr_dump_temp_file is None:
        # Env var was never set; leave Flight Recorder config untouched.
        return
    folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
    os.makedirs(folder, exist_ok=True)
    os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
807+
761808
def _apply_pending_state_dict(self) -> None:
762809
assert self._healing, "must be in healing state"
763810

torchft/process_group.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,14 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278
raise NotImplementedError("not implemented")
279279

280280
def configure(
281-
self, store_addr: str, replica_id: str, rank: int, world_size: int
281+
self,
282+
store_addr: str,
283+
replica_id: str,
284+
rank: int,
285+
world_size: int,
286+
quorum_id: Optional[int] = None,
287+
group_rank: int = 0,
288+
group_world_size: int = 1,
282289
) -> None:
283290
"""
284291
This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +415,9 @@ def __init__(
408415
self._timeout = timeout
409416
self._replica_id: str | None = None
410417
self._rank: int | None = None
418+
self._quorum_id: int | None = None
419+
self._group_rank: int = 0
420+
self._group_world_size: int = 1
411421

412422
self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
413423

@@ -419,13 +429,31 @@ def getBackendName(self) -> str:
419429
raise NotImplementedError("not implemented")
420430

421431
def configure(
422-
self, store_addr: str, replica_id: str, rank: int, world_size: int
432+
self,
433+
store_addr: str,
434+
replica_id: str,
435+
rank: int,
436+
world_size: int,
437+
quorum_id: Optional[int] = None,
438+
group_rank: int = 0,
439+
group_world_size: int = 1,
423440
) -> None:
424441
pg = self._pg
425442
self._replica_id = replica_id
443+
self._quorum_id = quorum_id
444+
self._group_rank = group_rank
445+
self._group_world_size = group_world_size
426446
self._rank = rank
427447
if isinstance(pg, ProcessGroup):
428-
pg.configure(store_addr, replica_id, rank, world_size)
448+
pg.configure(
449+
store_addr,
450+
replica_id,
451+
rank,
452+
world_size,
453+
quorum_id,
454+
group_rank,
455+
group_world_size,
456+
)
429457
return
430458

431459
# abort if already initialized
@@ -443,6 +471,7 @@ def abort(self, errored: bool = True) -> None:
443471
"job_id": os.environ.get("JOB_ID", "unknown"),
444472
"replica_id": self._replica_id,
445473
"rank": self._rank,
474+
"quorum_id": self._quorum_id,
446475
"error": "process_group_abort",
447476
},
448477
)
@@ -615,6 +644,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615644
# pyre-fixme[16]: no attribute ProcessGroupGloo
616645
backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
617646
backend_class._set_sequence_number_for_group()
647+
backend_class.options.global_ranks_in_group = list(range(world_size))
648+
backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
618649
pg._register_backend(
619650
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
620651
)
@@ -813,6 +844,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813844
opts = BaseProcessGroupNCCL.Options()
814845
opts.config.blocking = False
815846
opts.global_ranks_in_group = list(range(world_size))
847+
opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
816848

817849
pg = BaseProcessGroup(store, rank, world_size)
818850
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +1011,12 @@ def __init__(self, rank: int, world: int) -> None:
9791011
self.configure_count = 0
9801012

9811013
def configure(
982-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1014+
self,
1015+
store_addr: str,
1016+
replica_id: str,
1017+
rank: int,
1018+
world_size: int,
1019+
quorum_id: Optional[int] = None,
9831020
) -> None:
9841021
self.configure_count += 1
9851022

@@ -1138,11 +1175,26 @@ def __init__(self, pg: ProcessGroup) -> None:
11381175
self._error: Optional[Exception] = None
11391176

11401177
def configure(
1141-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1178+
self,
1179+
store_addr: str,
1180+
replica_id: str,
1181+
rank: int,
1182+
world_size: int,
1183+
quorum_id: Optional[int] = None,
1184+
group_rank: int = 0,
1185+
group_world_size: int = 1,
11421186
) -> None:
11431187
self._error = None
11441188

1145-
super().configure(store_addr, replica_id, rank, world_size)
1189+
super().configure(
1190+
store_addr,
1191+
replica_id,
1192+
rank,
1193+
world_size,
1194+
quorum_id,
1195+
group_rank,
1196+
group_world_size,
1197+
)
11461198

11471199
def report_error(self, e: Exception) -> None:
11481200
"""
@@ -1194,11 +1246,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941246
self._future_error: Optional[Exception] = None
11951247

11961248
def configure(
1197-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1249+
self,
1250+
store_addr: str,
1251+
replica_id: str,
1252+
rank: int,
1253+
world_size: int,
1254+
quorum_id: Optional[int] = None,
11981255
) -> None:
11991256
self._future_error = None
12001257

1201-
super().configure(store_addr, replica_id, rank, world_size)
1258+
super().configure(store_addr, replica_id, rank, world_size, quorum_id)
12021259

12031260
def report_future_error(self, e: Exception) -> None:
12041261
"""
@@ -1412,7 +1469,12 @@ def shutdown(self) -> None:
14121469
self._p.kill()
14131470

14141471
def configure(
1415-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1472+
self,
1473+
store_addr: str,
1474+
replica_id: str,
1475+
rank: int,
1476+
world_size: int,
1477+
quorum_id: Optional[int] = None,
14161478
) -> None:
14171479
self._world_size = world_size
14181480

0 commit comments

Comments
 (0)