From 369846528cae995d518c0a07631c2f544eb7a69d Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Fri, 24 Oct 2025 11:44:16 -0700
Subject: [PATCH] reset flight recorder trace (#283)

Summary:
- call the FR API to reset the trace after every quorum
- we reset so that after every quorum we start a fresh FR trace, since the
  PGs could have changed and we already dumped the FR trace from previous
  errors
- change the env var that's used to determine the dump file after every
  quorum

Reviewed By: d4l3k

Differential Revision: D84260745
---
 proto/torchft.proto      |   1 +
 src/lib.rs               |   3 +
 src/manager.rs           |   1 +
 torchft/_torchft.pyi     |   1 +
 torchft/manager.py       |  65 +++++++++++++++++++++-
 torchft/process_group.py | 117 +++++++++++++++++++++++++++++++++++----
 6 files changed, 175 insertions(+), 13 deletions(-)

diff --git a/proto/torchft.proto b/proto/torchft.proto
index 194ea44c..12fce286 100644
--- a/proto/torchft.proto
+++ b/proto/torchft.proto
@@ -96,6 +96,7 @@ message ManagerQuorumResponse {
   int64 replica_world_size = 10;
   bool heal = 11;
   int64 commit_failures = 12;
+  repeated string replica_ids = 13;
 }
 
 message CheckpointMetadataRequest {
diff --git a/src/lib.rs b/src/lib.rs
index 4b4da042..7291c09f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -213,6 +213,7 @@ impl ManagerClient {
                 max_replica_rank: resp.max_replica_rank,
                 max_world_size: resp.max_world_size,
                 heal: resp.heal,
+                replica_ids: resp.replica_ids,
             })
         })
     }
@@ -293,6 +294,7 @@ struct QuorumResult {
     max_replica_rank: Option<i64>,
     max_world_size: i64,
     heal: bool,
+    replica_ids: Vec<String>,
 }
 
 #[pymethods]
@@ -311,6 +313,7 @@ impl QuorumResult {
             max_replica_rank: None,
             max_world_size: 1,
             heal: false,
+            replica_ids: Vec::new(),
         }
     }
 }
diff --git a/src/manager.rs b/src/manager.rs
index d901caa6..816e06ab 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -620,6 +620,7 @@ fn compute_quorum_results(
             .map(|p| p.commit_failures)
             .max()
             .unwrap_or(0),
+        replica_ids: participants.iter().map(|p| p.replica_id.clone()).collect(),
     })
 }
 
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index ff175bf0..95dbfd5b 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -36,6 +36,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool
     commit_failures: int
+    replica_ids: list[str]
 
 class ManagerServer:
     def __init__(
diff --git a/torchft/manager.py b/torchft/manager.py
index f3aaff75..5a930a2b 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
 
+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")
 
 
@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec
 
 
+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+    """
+    i = len(s) - 1
+    while i >= 0 and s[i].isdigit():
+        i -= 1
+    return int(s[i + 1 :]) if i < len(s) - 1 else 0
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}
 
+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id
 
         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
 
         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True
 
+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -634,6 +661,13 @@ def _async_quorum(
         max_replica_rank = quorum.max_replica_rank
         max_replica_world_size = quorum.max_world_size
         heal = quorum.heal
+        replica_ids = quorum.replica_ids
+
+        ranks_in_quorum = [
+            extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+            + self._group_rank
+            for replica_id in replica_ids
+        ]
 
         # When using async quorum we need to take the recovered workers.
         # When not using async quorum we need to take the max world size as all
@@ -674,16 +708,30 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             try:
+                self._quorum_id = quorum_id
                 with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                    # Reset GPU state for Flight Recorder
                     if torch.accelerator.is_available():
                         torch.accelerator.synchronize()
+
                     self._pg.configure(
                         store_prefixed_addr,
                         self._replica_id if self._replica_id is not None else "0",
                         replica_rank,
                         replica_world_size,
+                        quorum_id,
+                        self._group_rank,
+                        self._group_world_size,
+                        ranks_in_quorum,
                     )
-                self._quorum_id = quorum_id
+
+                # We need to reset the trace after reconfiguring the PG because
+                # that calls abort which may trigger a dump
+                self._logger.info(
+                    f"resetting fr recording for quorum id {self._quorum_id}"
+                )
+                self._update_fr_path()
+                torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
             except Exception as e:
                 self._logger.exception(f"got exception in pg configure: {e}")
                 self.report_error(e)
@@ -758,6 +806,17 @@ def _async_quorum(
             else None
         )
 
+    def _update_fr_path(self) -> None:
+        """
+        Update the path that flight recorder will dump the traces to.
+        The format is
+        <TORCH_FR_DUMP_TEMP_FILE>_quorum_<quorum_id>/<global_rank>
+        """
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
 
diff --git a/torchft/process_group.py b/torchft/process_group.py
index c462928e..770d441c 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -278,7 +278,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         raise NotImplementedError("not implemented")
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         """
         This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -294,6 +302,10 @@ def configure(
             replica_id: the replica_id for this group
             rank: rank of this process
             world_size: world size of this process group
+            quorum_id: the current quorum's identifier
+            group_rank: the local rank within the replica group
+            group_world_size: the number of ranks within a replica group
+            global_ranks: the global ranks that are part of this process group
         """
         raise NotImplementedError("not implemented")
 
@@ -408,6 +420,10 @@ def __init__(
         self._timeout = timeout
         self._replica_id: str | None = None
         self._rank: int | None = None
+        self._quorum_id: int | None = None
+        self._group_rank: int | None = None
+        self._group_world_size: int | None = None
+        self._global_ranks: list[int] | None = None
 
         self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
 
@@ -419,13 +435,34 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         pg = self._pg
         self._replica_id = replica_id
+        self._quorum_id = quorum_id
+        self._group_rank = group_rank
+        self._group_world_size = group_world_size
         self._rank = rank
+        self._global_ranks = global_ranks
         if isinstance(pg, ProcessGroup):
-            pg.configure(store_addr, replica_id, rank, world_size)
+            pg.configure(
+                store_addr,
+                replica_id,
+                rank,
+                world_size,
+                quorum_id,
+                group_rank,
+                group_world_size,
+                global_ranks,
+            )
             return
 
         # abort if already initialized
@@ -443,6 +480,7 @@ def abort(self, errored: bool = True) -> None:
                 "job_id": os.environ.get("JOB_ID", "unknown"),
                 "replica_id": self._replica_id,
                 "rank": self._rank,
+                "quorum_id": self._quorum_id,
                 "error": "process_group_abort",
             },
         )
@@ -615,6 +653,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupGloo
         backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
         backend_class._set_sequence_number_for_group()
+
+        if self._global_ranks:
+            backend_class.options.global_ranks_in_group = self._global_ranks
+        if self._group_rank is not None and self._group_world_size is not None:
+            backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
+
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
@@ -812,7 +856,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
         opts = BaseProcessGroupNCCL.Options()
         opts.config.blocking = False
-        opts.global_ranks_in_group = list(range(world_size))
+        if self._global_ranks:
+            opts.global_ranks_in_group = self._global_ranks
+        if self._group_rank is not None and self._group_world_size is not None:
+            opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
 
         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +1026,15 @@ def __init__(self, rank: int, world: int) -> None:
         self.configure_count = 0
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self.configure_count += 1
 
@@ -1138,11 +1193,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._error: Optional[Exception] = None
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._error = None
 
-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )
 
     def report_error(self, e: Exception) -> None:
         """
@@ -1194,11 +1266,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._future_error: Optional[Exception] = None
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
    ) -> None:
         self._future_error = None
 
-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )
 
     def report_future_error(self, e: Exception) -> None:
         """
@@ -1412,7 +1501,15 @@ def shutdown(self) -> None:
         self._p.kill()
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._world_size = world_size
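
To make the rank math in _async_quorum concrete, here is a minimal,
self-contained Python sketch of the global-rank mapping this patch introduces.
The replica-id format ("replica_<n>:<uuid>") and the sizes below are assumed
values for illustration only; the patch does not specify them.

    # Sketch of the global-rank mapping used by _async_quorum above.
    def extract_trailing_digits(s: str) -> int:
        """Returns the trailing digits of s as an int, or 0 if there are none."""
        i = len(s) - 1
        while i >= 0 and s[i].isdigit():
            i -= 1
        return int(s[i + 1 :]) if i < len(s) - 1 else 0

    group_world_size = 8  # ranks per replica group (assumed)
    group_rank = 3        # this process's rank within its replica group (assumed)
    replica_ids = ["replica_0:aaaa", "replica_7:bbbb"]  # hypothetical quorum members

    # Each replica contributes group_world_size consecutive global ranks; this
    # process occupies slot group_rank within each participating replica's block.
    ranks_in_quorum = [
        extract_trailing_digits(rid.split(":")[0]) * group_world_size + group_rank
        for rid in replica_ids
    ]
    print(ranks_in_quorum)  # -> [3, 59]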
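
Similarly, a minimal sketch of the per-quorum Flight Recorder dump path that
_update_fr_path computes. The base path and the ids below are assumed values,
not taken from the patch.

    # Sketch of the per-quorum FR dump path scheme from _update_fr_path above.
    import os

    TORCH_FR_DUMP_TEMP_FILE_ENV = "TORCH_FR_DUMP_TEMP_FILE"

    original_fr_dump_temp_file = "/tmp/fr_trace"  # env value captured at startup (assumed)
    quorum_id = 2
    global_rank = 59

    # Each quorum gets its own folder, so a trace dumped during a PG abort
    # doesn't clobber the fresh recording started for the next quorum.
    folder = f"{original_fr_dump_temp_file}_quorum_{quorum_id}"
    os.makedirs(folder, exist_ok=True)
    os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{global_rank}"
    print(os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV])  # -> /tmp/fr_trace_quorum_2/59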