1 change: 1 addition & 0 deletions proto/torchft.proto
@@ -96,6 +96,7 @@ message ManagerQuorumResponse {
int64 replica_world_size = 10;
bool heal = 11;
int64 commit_failures = 12;
repeated string replica_ids = 13;
}

message CheckpointMetadataRequest {
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -213,6 +213,7 @@ impl ManagerClient {
max_replica_rank: resp.max_replica_rank,
max_world_size: resp.max_world_size,
heal: resp.heal,
replica_ids: resp.replica_ids,
})
})
}
@@ -293,6 +294,7 @@ struct QuorumResult {
max_replica_rank: Option<i64>,
max_world_size: i64,
heal: bool,
replica_ids: Vec<String>,
}

#[pymethods]
@@ -311,6 +313,7 @@ impl QuorumResult {
max_replica_rank: None,
max_world_size: 1,
heal: false,
replica_ids: Vec::new(),
}
}
}
1 change: 1 addition & 0 deletions src/manager.rs
@@ -620,6 +620,7 @@ fn compute_quorum_results(
.map(|p| p.commit_failures)
.max()
.unwrap_or(0),
replica_ids: participants.iter().map(|p| p.replica_id.clone()).collect(),
})
}

1 change: 1 addition & 0 deletions torchft/_torchft.pyi
@@ -36,6 +36,7 @@ class QuorumResult:
max_world_size: int
heal: bool
commit_failures: int
replica_ids: list[str]
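
The new `replica_ids` field exposes every member of the current quorum to Python callers. A minimal sketch of consuming it from a quorum result (the surrounding setup and the exact ID format are assumptions, not part of this PR):

```python
from torchft._torchft import QuorumResult

def log_quorum_members(quorum: QuorumResult) -> None:
    # replica_ids holds one entry per replica in the quorum,
    # e.g. ["replica_0:<uuid>", "replica_2:<uuid>"] (format assumed).
    for replica_id in quorum.replica_ids:
        print(f"quorum member: {replica_id}")
```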

class ManagerServer:
def __init__(
65 changes: 62 additions & 3 deletions torchft/manager.py
@@ -88,6 +88,8 @@
# crash if call to quorum fails, all replicas will crash.
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"

T = TypeVar("T")


@@ -109,6 +111,17 @@ def get_timeout(
return default_timeout_sec


def extract_trailing_digits(s: str) -> int:
"""
Extracts the trailing digits from the end of the string s.
Returns 0 if no trailing digits are found.
"""
i = len(s) - 1
while i >= 0 and s[i].isdigit():
i -= 1
return int(s[i + 1 :]) if i < len(s) - 1 else 0
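
A quick illustration of the helper's behaviour (input values are made up):

```python
assert extract_trailing_digits("replica_7") == 7
assert extract_trailing_digits("train_group_12") == 12
assert extract_trailing_digits("no_digits") == 0  # no trailing digits -> 0
```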


class WorldSizeMode(Enum):
"""
This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
self._user_state_dicts: Dict[str, Callable[[], object]] = {}

self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
TORCH_FR_DUMP_TEMP_FILE_ENV
)
self._replica_id = replica_id

# Protects state dict
@@ -257,7 +273,7 @@ def __init__(
store_port = store_port or int(os.environ["MASTER_PORT"])
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
group_rank = self._group_rank
group_world_size = world_size or int(os.environ["WORLD_SIZE"])
self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
hostname=hostname,
bind=bind,
store_addr=f"{store_addr}:{store_port}",
world_size=group_world_size,
world_size=self._group_world_size,
heartbeat_interval=heartbeat_interval,
connect_timeout=connect_timeout,
quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
self._participating_replica_world_size: int = 0
self._is_state_dict_read_allowed = True

self._global_rank: int = (
self._group_rank
if self._replica_id is None
else (
extract_trailing_digits(self._replica_id) * self._group_world_size
+ self._group_rank
)
)

self._update_fr_path()

def allow_state_dict_read(self) -> None:
if self._is_state_dict_read_allowed:
return
@@ -634,6 +661,13 @@ def _async_quorum(
max_replica_rank = quorum.max_replica_rank
max_replica_world_size = quorum.max_world_size
heal = quorum.heal
replica_ids = quorum.replica_ids

ranks_in_quorum = [
extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+ self._group_rank
for replica_id in replica_ids
]
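
A worked example of this arithmetic, using the same formula that computed `self._global_rank` above (all values are illustrative):

```python
group_world_size, group_rank = 8, 3  # assumed replica group shape
replica_ids = ["replica_0:abc", "replica_2:def"]  # illustrative quorum members

ranks_in_quorum = [
    extract_trailing_digits(rid.split(":")[0]) * group_world_size + group_rank
    for rid in replica_ids
]
assert ranks_in_quorum == [3, 19]  # this group rank's peers, one per replica
```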

# When using async quorum we need to take the recovered workers.
# When not using async quorum we need to take the max world size as all
@@ -674,16 +708,30 @@ def _async_quorum(
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
# We use the replica rank and world as we want all replicas in the PG.
try:
self._quorum_id = quorum_id
with torch.profiler.record_function("torchft::manager::_pg::configure"):
# Reset GPU state for Flight Recorder
if torch.accelerator.is_available():
torch.accelerator.synchronize()

self._pg.configure(
store_prefixed_addr,
self._replica_id if self._replica_id is not None else "0",
replica_rank,
replica_world_size,
quorum_id,
self._group_rank,
self._group_world_size,
ranks_in_quorum,
)
self._quorum_id = quorum_id

# We need to reset the trace after reconfiguring the PG because that
# calls abort which may trigger a dump
self._logger.info(
f"resetting fr recording for quorum id {self._quorum_id}"
)
self._update_fr_path()
torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
except Exception as e:
self._logger.exception(f"got exception in pg configure: {e}")
self.report_error(e)
@@ -758,6 +806,17 @@ def _async_quorum(
else None
)

def _update_fr_path(self) -> None:
"""
Update the path that Flight Recorder will dump traces to.
The format is
<TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
"""
if self._original_fr_dump_temp_file is not None:
folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
os.makedirs(folder, exist_ok=True)
os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
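
Restated as a standalone helper (hypothetical, not part of the PR) to make the naming scheme concrete:

```python
def fr_dump_path(base: str, quorum_id: int, global_rank: int) -> str:
    # One folder per quorum, one trace file per global rank.
    return f"{base}_quorum_{quorum_id}/{global_rank}"

assert fr_dump_path("/tmp/fr_trace", 5, 26) == "/tmp/fr_trace_quorum_5/26"
```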

def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"

117 changes: 107 additions & 10 deletions torchft/process_group.py
@@ -278,7 +278,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
raise NotImplementedError("not implemented")

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
"""
This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -294,6 +302,10 @@ def configure(
replica_id: the replica_id for this group
rank: rank of this process
world_size: world size of this process group
quorum_id: current quorum's identifier
group_rank: local rank within the replica group
group_world_size: the number of ranks within a replica group
global_ranks: the global ranks that are part of this process group
"""
raise NotImplementedError("not implemented")
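
An illustrative call, mirroring how `Manager._async_quorum` in torchft/manager.py above invokes this method (the concrete values are assumed):

```python
pg.configure(
    store_addr="localhost:29500/torchft/5/0",  # prefixed store address (path assumed)
    replica_id="replica_2",
    rank=1,                 # replica rank within the quorum
    world_size=2,           # number of replicas in the quorum
    quorum_id=5,
    group_rank=3,           # this process's rank inside its replica group
    group_world_size=8,
    global_ranks=[3, 19],   # global ranks participating in this PG
)
```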

@@ -408,6 +420,10 @@ def __init__(
self._timeout = timeout
self._replica_id: str | None = None
self._rank: int | None = None
self._quorum_id: int | None = None
self._group_rank: int | None = None
self._group_world_size: int | None = None
self._global_ranks: list[int] | None = None

self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

@@ -419,13 +435,34 @@ def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
pg = self._pg
self._replica_id = replica_id
self._quorum_id = quorum_id
self._group_rank = group_rank
self._group_world_size = group_world_size
self._rank = rank
self._global_ranks = global_ranks
if isinstance(pg, ProcessGroup):
pg.configure(store_addr, replica_id, rank, world_size)
pg.configure(
store_addr,
replica_id,
rank,
world_size,
quorum_id,
group_rank,
group_world_size,
global_ranks,
)
return

# abort if already initialized
Expand All @@ -443,6 +480,7 @@ def abort(self, errored: bool = True) -> None:
"job_id": os.environ.get("JOB_ID", "unknown"),
"replica_id": self._replica_id,
"rank": self._rank,
"quorum_id": self._quorum_id,
"error": "process_group_abort",
},
)
@@ -615,6 +653,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
backend_class._set_sequence_number_for_group()

if self._global_ranks:
backend_class.options.global_ranks_in_group = self._global_ranks
# Check `is not None` rather than truthiness: group rank 0 is valid.
if self._group_rank is not None and self._group_world_size is not None:
backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
@@ -812,7 +856,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
opts = BaseProcessGroupNCCL.Options()
opts.config.blocking = False
opts.global_ranks_in_group = list(range(world_size))
if self._global_ranks:
opts.global_ranks_in_group = self._global_ranks
# Check `is not None` rather than truthiness: group rank 0 is valid.
if self._group_rank is not None and self._group_world_size is not None:
opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +1026,15 @@ def __init__(self, rank: int, world: int) -> None:
self.configure_count = 0

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
self.configure_count += 1

@@ -1138,11 +1193,28 @@ def __init__(self, pg: ProcessGroup) -> None:
self._error: Optional[Exception] = None

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
self._error = None

super().configure(store_addr, replica_id, rank, world_size)
super().configure(
store_addr,
replica_id,
rank,
world_size,
quorum_id,
group_rank,
group_world_size,
global_ranks,
)

def report_error(self, e: Exception) -> None:
"""
@@ -1194,11 +1266,28 @@ def __init__(self, pg: ProcessGroup) -> None:
self._future_error: Optional[Exception] = None

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
self._future_error = None

super().configure(store_addr, replica_id, rank, world_size)
super().configure(
store_addr,
replica_id,
rank,
world_size,
quorum_id,
group_rank,
group_world_size,
global_ranks,
)

def report_future_error(self, e: Exception) -> None:
"""
@@ -1412,7 +1501,15 @@ def shutdown(self) -> None:
self._p.kill()

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
group_rank: Optional[int] = None,
group_world_size: Optional[int] = None,
global_ranks: Optional[list[int]] = None,
) -> None:
self._world_size = world_size
