Skip to content

Commit ec2462a

Browse files
authored
feat(replication): support single-node execution by disabling replication (#92)
Currently, ML Flashpoint requires an even number of nodes due to the pairwise replication strategy. This prevents the library from running in single-node environments, which are often required in testing scenarios. This commit adds support for single-node execution by safely bypassing replication logic while keeping the core save/load functionality intact.
1 parent ceffdab commit ec2462a

File tree

3 files changed

+136
-3
lines changed

3 files changed

+136
-3
lines changed

src/ml_flashpoint/replication/replication_manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,12 @@ def __init__(self, replication_service_addresses: List[str], processes_per_node:
108108
)
109109

110110
self._num_nodes = utils.get_num_of_nodes()
111-
if self._num_nodes % 2 != 0:
112-
raise ValueError(f"The total number of nodes ({self._num_nodes}) must be even.")
111+
# Allow 1-node execution by bypassing the even-number node check.
112+
# Pairwise replication requires an even number of nodes, but we allow exactly 1 node
113+
# to run without any replication.
114+
if self._num_nodes > 1 and self._num_nodes % 2 != 0:
115+
raise ValueError(f"The total number of nodes ({self._num_nodes}) must be even for pairwise replication.")
116+
self._disable_replication = self._num_nodes == 1
113117

114118
if self._num_nodes != self._world_size // self._processes_per_node:
115119
raise ValueError(
@@ -119,6 +123,9 @@ def __init__(self, replication_service_addresses: List[str], processes_per_node:
119123

120124
@override
121125
def get_destination_addresses(self, global_rank: int) -> List[str]:
126+
if self._disable_replication:
127+
return []
128+
122129
if not 0 <= global_rank < self._world_size:
123130
raise ValueError(f"global_rank {global_rank} is out of valid range [0, {self._world_size - 1}].")
124131

@@ -441,6 +448,7 @@ def sync_bulk_retrieve(
441448
Returns:
442449
`True` if the bulk retrieval was successful, `False` otherwise.
443450
"""
451+
444452
if self._retry_config is None:
445453
_LOGGER.error("ReplicationManager is not initialized. Cannot retrieve.")
446454
return False

tests/core/test_checkpoint_loader.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,71 @@ def test_no_retrieval_needed(self, loader, mocker):
998998
assert result == ckpt_id
999999
mock_retrieve.assert_not_called()
10001000

1001+
def test_single_node_shared_storage_no_retrieval_needed(self, loader, mocker):
    """
    Tests that in a 1-node setup (e.g., fallback to NFS or local shared storage),
    the loader correctly computes an empty retrieval plan. It does not attempt to
    fetch from non-existent peer nodes, allowing the process to execute correctly
    without throwing an error.
    """
    # Given: a single-node environment (1 node, 2 ranks).
    self.mock_dist_get_rank.return_value = 0
    self.mock_get_num_nodes.return_value = 1
    self.mock_dist_get_world_size.return_value = 2

    base_container = CheckpointContainerId("/tmp/checkpoints")
    ckpt_id = CheckpointContainerId.create_child(base_container, "step-100_ckpt")
    mocker.patch.object(loader, "get_candidate_checkpoints", return_value=[ckpt_id])

    # Metadata describing what each rank needs to load.
    metadata_stub = mocker.MagicMock()
    metadata_stub.storage_data = {
        0: _StorageInfo(relative_path="/src0/obj", offset=0, length=100),
        1: _StorageInfo(relative_path="/src1/obj", offset=0, length=100),
    }
    mocker.patch.object(loader, "read_metadata", return_value=metadata_stub)

    # Visible objects per rank: rank 0 sees src0, rank 1 sees src1. Because both
    # ranks live on the same node (num_nodes=1), all files are reachable through
    # NFS/local disk — no peer transfer should ever be needed.
    objects_by_rank = {
        0: [
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/src0/obj"),
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/.metadata"),
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/common.pt"),
        ],
        1: [
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/src1/obj"),
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/.metadata"),
            CheckpointObjectId("/tmp/checkpoints/step-100_ckpt/common.pt"),
        ],
    }
    mocker.patch.object(
        loader,
        "get_checkpoint_objects_by_rank",
        return_value=objects_by_rank,
    )

    retrieve_spy = mocker.patch.object(loader, "retrieve_checkpoint", return_value=True)

    # When: resolving the latest complete checkpoint.
    result = loader.get_latest_complete_checkpoint(base_container)

    # Then: the call completes successfully instead of crashing.
    assert result == ckpt_id

    # The broadcast retrieval plan must be empty — no network fetch to
    # non-existent peers (perfect fallback to NFS/local read).
    broadcast_args, _ = self.mock_dist_broadcast_object_list.call_args
    broadcast_plan = broadcast_args[0][0]
    assert not broadcast_plan.get(0)
    assert not broadcast_plan.get(1)

    # The subsequent flow skips retrieval entirely since data is locally available.
    retrieve_spy.assert_not_called()
1065+
10011066
def test_non_rank0_success(self, loader, mocker):
10021067
"""Tests successful retrieval flow on non-Rank 0."""
10031068
# Given

tests/replication/test_replication_manager.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
import pytest
1818

19-
from ml_flashpoint.replication.replication_manager import ReplicationManager, ReplicationRetryConfig
19+
from ml_flashpoint.replication.replication_manager import (
20+
PairwiseReplicationStrategy,
21+
ReplicationManager,
22+
ReplicationRetryConfig,
23+
)
2024

2125

2226
@pytest.fixture
@@ -338,3 +342,59 @@ def test_sync_bulk_retrieve_invalid_rank(replication_manager):
338342

339343
# Then
340344
assert result is False
345+
346+
347+
def test_pairwise_strategy_single_node_initialization(mocker):
    """Tests that PairwiseReplicationStrategy successfully initializes for a single node without raising an error."""
    # Given: one node reported, two replication endpoints (one per local process/GPU).
    mocker.patch("ml_flashpoint.core.utils.get_num_of_nodes", return_value=1)
    endpoints = ["127.0.0.1:8000", "127.0.0.1:8001"]

    # When: constructing the strategy (must not raise the even-node ValueError).
    strategy = PairwiseReplicationStrategy(
        replication_service_addresses=endpoints,
        processes_per_node=2,
    )

    # Then: the strategy flags replication as disabled for the single-node case.
    replication_disabled = getattr(strategy, "_disable_replication", False)
    assert replication_disabled is True
359+
360+
361+
def test_pairwise_strategy_single_node_get_destination(mocker):
    """Tests that get_destination_addresses returns an empty list when running on a single node."""
    # Given: a single-node deployment with one process.
    mocker.patch("ml_flashpoint.core.utils.get_num_of_nodes", return_value=1)
    strategy = PairwiseReplicationStrategy(
        replication_service_addresses=["127.0.0.1:8000"],
        processes_per_node=1,
    )

    # When / Then: with replication disabled there are no peers to replicate to.
    assert strategy.get_destination_addresses(global_rank=0) == []
373+
374+
375+
def test_async_replicate_single_node_skips(replication_manager, mocker):
    """Tests that async_replicate does nothing and returns empty futures in a single-node environment."""
    # Given: swap in a strategy built for exactly one node so replication is disabled.
    mocker.patch("ml_flashpoint.core.utils.get_num_of_nodes", return_value=1)
    single_node_strategy = PairwiseReplicationStrategy(
        replication_service_addresses=["127.0.0.1:8000"],
        processes_per_node=1,
    )
    replication_manager._repl_strategy = single_node_strategy

    mocker.patch("torch.distributed.get_rank", return_value=0)

    # A fake checkpoint buffer for the replicate call.
    fake_buffer = mocker.MagicMock()
    fake_buffer.get_id.return_value = "test_single_node_obj"
    fake_buffer_io = mocker.MagicMock(buffer_obj=fake_buffer)

    # When
    futures = replication_manager.async_replicate(fake_buffer_io)

    # Then: no futures are produced, the transfer service is never touched,
    # and the buffer is still closed properly.
    assert futures == []
    replication_manager._transfer_service.async_put.assert_not_called()
    replication_manager._checkpoint_object_manager.close_buffer.assert_called_once_with(
        fake_buffer_io, skip_close_if_symlink=True
    )

0 commit comments

Comments
 (0)