Commit 3698465

tushar00jain authored and facebook-github-bot committed
reset flight recorder trace (#283)
Summary:
- Call the FR (Flight Recorder) API to reset the trace after every quorum.
- We reset so that after every quorum we start a fresh FR trace, since the PGs could have changed and we have already dumped the FR trace from previous errors.
- Change the env var that's used to determine the dump file after every quorum.

Reviewed By: d4l3k

Differential Revision: D84260745
1 parent 73dafea commit 3698465
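In effect, after each quorum the Manager points Flight Recorder at a fresh per-quorum dump location and clears the recorded entries. A minimal standalone sketch of that flow (the free function and its parameters are hypothetical; in the diff below this logic lives in Manager._update_fr_path and _async_quorum):

import os

import torch

TORCH_FR_DUMP_TEMP_FILE_ENV = "TORCH_FR_DUMP_TEMP_FILE"

def reset_flight_recorder(base_path: str, quorum_id: int, global_rank: int) -> None:
    # Redirect future Flight Recorder dumps to a per-quorum folder, one file per rank.
    folder = f"{base_path}_quorum_{quorum_id}"
    os.makedirs(folder, exist_ok=True)
    os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{global_rank}"
    # Drop trace entries that refer to the previous quorum's process groups.
    torch._C._distributed_c10d._reset_fr_recording_nccl()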

6 files changed (+175 −13 lines)

proto/torchft.proto

Lines changed: 1 addition & 0 deletions

@@ -96,6 +96,7 @@ message ManagerQuorumResponse {
   int64 replica_world_size = 10;
   bool heal = 11;
   int64 commit_failures = 12;
+  repeated string replica_ids = 13;
 }

 message CheckpointMetadataRequest {

src/lib.rs

Lines changed: 3 additions & 0 deletions

@@ -213,6 +213,7 @@ impl ManagerClient {
                 max_replica_rank: resp.max_replica_rank,
                 max_world_size: resp.max_world_size,
                 heal: resp.heal,
+                replica_ids: resp.replica_ids,
             })
         })
     }
@@ -293,6 +294,7 @@ struct QuorumResult {
     max_replica_rank: Option<i64>,
     max_world_size: i64,
     heal: bool,
+    replica_ids: Vec<String>,
 }

 #[pymethods]
@@ -311,6 +313,7 @@ impl QuorumResult {
             max_replica_rank: None,
             max_world_size: 1,
             heal: false,
+            replica_ids: Vec::new(),
         }
     }
 }

src/manager.rs

Lines changed: 1 addition & 0 deletions

@@ -620,6 +620,7 @@ fn compute_quorum_results(
             .map(|p| p.commit_failures)
             .max()
             .unwrap_or(0),
+        replica_ids: participants.iter().map(|p| p.replica_id.clone()).collect(),
     })
 }

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool
     commit_failures: int
+    replica_ids: list[str]

 class ManagerServer:
     def __init__(

torchft/manager.py

Lines changed: 62 additions & 3 deletions

@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")

@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec


+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+    """
+    i = len(s) - 1
+    while i >= 0 and s[i].isdigit():
+        i -= 1
+    return int(s[i + 1 :]) if i < len(s) - 1 else 0
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
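A few examples of the helper's behavior (inputs are hypothetical):

extract_trailing_digits("train_ddp_7")  # -> 7
extract_trailing_digits("replica42")    # -> 42
extract_trailing_digits("no_digits")    # -> 0 (no trailing digits)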
@@ -223,6 +236,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id

         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size

         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True

+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
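A worked example of the global-rank computation, with hypothetical values:

# replica_id "replica3", 8 ranks per replica, local group_rank 2:
global_rank = extract_trailing_digits("replica3") * 8 + 2  # 3 * 8 + 2 == 26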
@@ -634,6 +661,13 @@ def _async_quorum(
         max_replica_rank = quorum.max_replica_rank
         max_replica_world_size = quorum.max_world_size
         heal = quorum.heal
+        replica_ids = quorum.replica_ids
+
+        ranks_in_quorum = [
+            extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+            + self._group_rank
+            for replica_id in replica_ids
+        ]

         # When using async quorum we need to take the recovered workers.
         # When not using async quorum we need to take the max world size as all
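To illustrate the mapping from quorum replica IDs to global ranks (IDs and sizes hypothetical; note the code strips any ":"-suffix before parsing):

replica_ids = ["replica0:uuid-a", "replica2:uuid-b"]
group_world_size, group_rank = 4, 1
ranks_in_quorum = [
    extract_trailing_digits(r.split(":")[0]) * group_world_size + group_rank
    for r in replica_ids
]  # -> [1, 9]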
@@ -674,16 +708,30 @@ def _async_quorum(
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
         try:
+            self._quorum_id = quorum_id
             with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                # Reset GPU state for Flight Recorder
                 if torch.accelerator.is_available():
                     torch.accelerator.synchronize()
+
                 self._pg.configure(
                     store_prefixed_addr,
                     self._replica_id if self._replica_id is not None else "0",
                     replica_rank,
                     replica_world_size,
+                    quorum_id,
+                    self._group_rank,
+                    self._group_world_size,
+                    ranks_in_quorum,
                 )
-            self._quorum_id = quorum_id
+
+            # We need to reset the trace after reconfiguring the PG because that
+            # calls abort which may trigger a dump
+            self._logger.info(
+                f"resetting fr recording for quorum id {self._quorum_id}"
+            )
+            self._update_fr_path()
+            torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
         except Exception as e:
             self._logger.exception(f"got exception in pg configure: {e}")
             self.report_error(e)
@@ -758,6 +806,17 @@ def _async_quorum(
             else None
         )

+    def _update_fr_path(self) -> None:
+        """
+        Update the path that flight recorder will dump the traces to.
+        The format is
+        <TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
+        """
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
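For example (path hypothetical): if a launcher sets TORCH_FR_DUMP_TEMP_FILE=/tmp/fr_trace, then a process with _quorum_id == 5 and _global_rank == 3 will dump to /tmp/fr_trace_quorum_5/3 — one folder per quorum, one file per rank.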

torchft/process_group.py

Lines changed: 107 additions & 10 deletions

@@ -278,7 +278,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         """
         This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -294,6 +302,10 @@ def configure(
             replica_id: the replica_id for this group
             rank: rank of this process
             world_size: world size of this process group
+            quorum_id: current quorum's identifier
+            group_rank: local rank within the replica group
+            group_world_size: the number of ranks within a replica
+            global_ranks: the global ranks part of this process group
         """
         raise NotImplementedError("not implemented")

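A hedged sketch of how a caller might invoke the extended signature (all values hypothetical; compare the _async_quorum hunk above, which passes ranks_in_quorum as global_ranks):

pg.configure(
    "localhost:29500/torchft/quorum_5",  # store_addr (hypothetical prefix)
    "replica3",                          # replica_id
    1,                                   # rank of this replica within the quorum PG
    2,                                   # world_size: number of replicas in the quorum
    quorum_id=5,
    group_rank=0,
    group_world_size=8,
    global_ranks=[8, 24],  # e.g. replicas "replica1", "replica3": 1*8+0, 3*8+0
)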
@@ -408,6 +420,10 @@ def __init__(
         self._timeout = timeout
         self._replica_id: str | None = None
         self._rank: int | None = None
+        self._quorum_id: int | None = None
+        self._group_rank: int | None = None
+        self._group_world_size: int | None = None
+        self._global_ranks: list[int] | None = None

         self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

@@ -419,13 +435,34 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         pg = self._pg
         self._replica_id = replica_id
+        self._quorum_id = quorum_id
+        self._group_rank = group_rank
+        self._group_world_size = group_world_size
         self._rank = rank
+        self._global_ranks = global_ranks
         if isinstance(pg, ProcessGroup):
-            pg.configure(store_addr, replica_id, rank, world_size)
+            pg.configure(
+                store_addr,
+                replica_id,
+                rank,
+                world_size,
+                quorum_id,
+                group_rank,
+                group_world_size,
+                global_ranks,
+            )
             return

         # abort if already initialized
@@ -443,6 +480,7 @@ def abort(self, errored: bool = True) -> None:
                 "job_id": os.environ.get("JOB_ID", "unknown"),
                 "replica_id": self._replica_id,
                 "rank": self._rank,
+                "quorum_id": self._quorum_id,
                 "error": "process_group_abort",
             },
         )
@@ -615,6 +653,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupGloo
         backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
         backend_class._set_sequence_number_for_group()
+
+        if self._global_ranks:
+            backend_class.options.global_ranks_in_group = self._global_ranks
+        if self._group_rank and self._group_world_size:
+            backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
+
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
@@ -812,7 +856,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
         opts = BaseProcessGroupNCCL.Options()
         opts.config.blocking = False
-        opts.global_ranks_in_group = list(range(world_size))
+        if self._global_ranks:
+            opts.global_ranks_in_group = self._global_ranks
+        if self._group_rank and self._group_world_size:
+            opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
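For the group naming added in both backends, a quick worked example (values hypothetical): with quorum_id == 5, group_rank == 6 and group_world_size == 4, the backend's group_name becomes "torchft_quorum_5_rank_2", since 6 % 4 == 2.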
@@ -979,7 +1026,15 @@ def __init__(self, rank: int, world: int) -> None:
         self.configure_count = 0

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self.configure_count += 1

@@ -1138,11 +1193,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )

     def report_error(self, e: Exception) -> None:
         """
@@ -1194,11 +1266,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._future_error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._future_error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )

     def report_future_error(self, e: Exception) -> None:
         """
@@ -1412,7 +1501,15 @@ def shutdown(self) -> None:
         self._p.kill()

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._world_size = world_size
