Skip to content

Commit 9397313

Browse files
tushar00jainfacebook-github-bot
authored and committed
reset flight recorder trace (meta-pytorch#283)
Summary: - call FR api to reset the trace after every quorum - we reset so that after every quorum, we start a fresh FR trace since the pg's could have changed and we already dumped FR trace from previous errors - change the env var that's used to determine the file after every quorum Reviewed By: d4l3k Differential Revision: D84260745
1 parent c8000cf commit 9397313

File tree

2 files changed

+86
-12
lines changed

2 files changed

+86
-12
lines changed

torchft/manager.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@
8888
# crash if call to quorum fails, all replicas will crash.
8989
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
9090

91+
TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
92+
9193
T = TypeVar("T")
9294

9395

@@ -109,6 +111,17 @@ def get_timeout(
109111
return default_timeout_sec
110112

111113

114+
def extract_trailing_digits(s: str) -> int:
115+
"""
116+
Extracts the trailing digits from the end of the string s.
117+
Returns 0 if no trailing digits are found.
118+
"""
119+
i = len(s) - 1
120+
while i >= 0 and s[i].isdigit():
121+
i -= 1
122+
return int(s[i + 1 :]) if i < len(s) - 1 else 0
123+
124+
112125
class WorldSizeMode(Enum):
113126
"""
114127
This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
223236
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
224237
self._user_state_dicts: Dict[str, Callable[[], object]] = {}
225238

239+
self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
240+
TORCH_FR_DUMP_TEMP_FILE_ENV
241+
)
226242
self._replica_id = replica_id
227243

228244
# Protects state dict
@@ -257,7 +273,7 @@ def __init__(
257273
store_port = store_port or int(os.environ["MASTER_PORT"])
258274
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
259275
group_rank = self._group_rank
260-
group_world_size = world_size or int(os.environ["WORLD_SIZE"])
276+
self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
261277
self._min_replica_size = min_replica_size
262278

263279
if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
310326
hostname=hostname,
311327
bind=bind,
312328
store_addr=f"{store_addr}:{store_port}",
313-
world_size=group_world_size,
329+
world_size=self._group_world_size,
314330
heartbeat_interval=heartbeat_interval,
315331
connect_timeout=connect_timeout,
316332
quorum_retries=self._quorum_retries,
@@ -338,6 +354,8 @@ def __init__(
338354
self._participating_replica_world_size: int = 0
339355
self._is_state_dict_read_allowed = True
340356

357+
self._update_fr_path()
358+
341359
def allow_state_dict_read(self) -> None:
342360
if self._is_state_dict_read_allowed:
343361
return
@@ -674,16 +692,21 @@ def _async_quorum(
674692
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
675693
# We use the replica rank and world as we want all replicas in the PG.
676694
try:
695+
self._quorum_id = quorum_id
677696
with torch.profiler.record_function("torchft::manager::_pg::configure"):
697+
# Reset GPU state for Flight Recorder
678698
if torch.accelerator.is_available():
679699
torch.accelerator.synchronize()
700+
torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
701+
self._update_fr_path()
702+
680703
self._pg.configure(
681704
store_prefixed_addr,
682705
self._replica_id if self._replica_id is not None else "0",
683706
replica_rank,
684707
replica_world_size,
708+
quorum_id,
685709
)
686-
self._quorum_id = quorum_id
687710
except Exception as e:
688711
self._logger.exception(f"got exception in pg configure: {e}")
689712
self.report_error(e)
@@ -758,6 +781,21 @@ def _async_quorum(
758781
else None
759782
)
760783

784+
def _update_fr_path(self) -> None:
785+
if self._original_fr_dump_temp_file is not None:
786+
folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
787+
os.makedirs(folder, exist_ok=True)
788+
789+
filename = (
790+
self._group_rank
791+
if self._replica_id is None
792+
else (
793+
extract_trailing_digits(self._replica_id) * self._group_world_size
794+
+ self._group_rank
795+
)
796+
)
797+
os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{filename}"
798+
761799
def _apply_pending_state_dict(self) -> None:
762800
assert self._healing, "must be in healing state"
763801

torchft/process_group.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278
raise NotImplementedError("not implemented")
279279

280280
def configure(
281-
self, store_addr: str, replica_id: str, rank: int, world_size: int
281+
self,
282+
store_addr: str,
283+
replica_id: str,
284+
rank: int,
285+
world_size: int,
286+
quorum_id: Optional[int] = None,
282287
) -> None:
283288
"""
284289
This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +413,7 @@ def __init__(
408413
self._timeout = timeout
409414
self._replica_id: str | None = None
410415
self._rank: int | None = None
416+
self._quorum_id: int | None = None
411417

412418
self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
413419

@@ -419,13 +425,19 @@ def getBackendName(self) -> str:
419425
raise NotImplementedError("not implemented")
420426

421427
def configure(
422-
self, store_addr: str, replica_id: str, rank: int, world_size: int
428+
self,
429+
store_addr: str,
430+
replica_id: str,
431+
rank: int,
432+
world_size: int,
433+
quorum_id: Optional[int] = None,
423434
) -> None:
424435
pg = self._pg
425436
self._replica_id = replica_id
437+
self._quorum_id = quorum_id
426438
self._rank = rank
427439
if isinstance(pg, ProcessGroup):
428-
pg.configure(store_addr, replica_id, rank, world_size)
440+
pg.configure(store_addr, replica_id, rank, world_size, quorum_id)
429441
return
430442

431443
# abort if already initialized
@@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None:
443455
"job_id": os.environ.get("JOB_ID", "unknown"),
444456
"replica_id": self._replica_id,
445457
"rank": self._rank,
458+
"quorum_id": self._quorum_id,
446459
"error": "process_group_abort",
447460
},
448461
)
@@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615628
# pyre-fixme[16]: no attribute ProcessGroupGloo
616629
backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
617630
backend_class._set_sequence_number_for_group()
631+
backend_class.options.global_ranks_in_group = list(range(world_size))
632+
backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}"
618633
pg._register_backend(
619634
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
620635
)
@@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813828
opts = BaseProcessGroupNCCL.Options()
814829
opts.config.blocking = False
815830
opts.global_ranks_in_group = list(range(world_size))
831+
opts.group_name = f"torchft_quorum_{self._quorum_id}"
816832

817833
pg = BaseProcessGroup(store, rank, world_size)
818834
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None:
979995
self.configure_count = 0
980996

981997
def configure(
982-
self, store_addr: str, replica_id: str, rank: int, world_size: int
998+
self,
999+
store_addr: str,
1000+
replica_id: str,
1001+
rank: int,
1002+
world_size: int,
1003+
quorum_id: Optional[int] = None,
9831004
) -> None:
9841005
self.configure_count += 1
9851006

@@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11381159
self._error: Optional[Exception] = None
11391160

11401161
def configure(
1141-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1162+
self,
1163+
store_addr: str,
1164+
replica_id: str,
1165+
rank: int,
1166+
world_size: int,
1167+
quorum_id: Optional[int] = None,
11421168
) -> None:
11431169
self._error = None
11441170

1145-
super().configure(store_addr, replica_id, rank, world_size)
1171+
super().configure(store_addr, replica_id, rank, world_size, quorum_id)
11461172

11471173
def report_error(self, e: Exception) -> None:
11481174
"""
@@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941220
self._future_error: Optional[Exception] = None
11951221

11961222
def configure(
1197-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1223+
self,
1224+
store_addr: str,
1225+
replica_id: str,
1226+
rank: int,
1227+
world_size: int,
1228+
quorum_id: Optional[int] = None,
11981229
) -> None:
11991230
self._future_error = None
12001231

1201-
super().configure(store_addr, replica_id, rank, world_size)
1232+
super().configure(store_addr, replica_id, rank, world_size, quorum_id)
12021233

12031234
def report_future_error(self, e: Exception) -> None:
12041235
"""
@@ -1412,7 +1443,12 @@ def shutdown(self) -> None:
14121443
self._p.kill()
14131444

14141445
def configure(
1415-
self, store_addr: str, replica_id: str, rank: int, world_size: int
1446+
self,
1447+
store_addr: str,
1448+
replica_id: str,
1449+
rank: int,
1450+
world_size: int,
1451+
quorum_id: Optional[int] = None,
14161452
) -> None:
14171453
self._world_size = world_size
14181454

0 commit comments

Comments
 (0)