Skip to content

Commit bce6888

Browse files
committed
Merge branch 'main' of https://github.com/google/ml-flashpoint into rename-thread-count-to-files-per-rank
2 parents 353a891 + ae95fb6 commit bce6888

14 files changed

+590
-305
lines changed

docs/user-guide.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
9292
# always_save_context=False, # Optional, defaults to False
9393
# write_files_per_rank=1, # Optional, defaults to 1
9494
# initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
95+
# use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time.
96+
# use_cached_ckpt_structure=True, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
9597
)
9698
```
9799

@@ -126,6 +128,7 @@ from ml_flashpoint.adapter.megatron.save_strategies import (
126128
)
127129

128130
# Loading
131+
import torch.distributed as dist
129132
from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
130133
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
131134
from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
@@ -148,6 +151,7 @@ memory_storage_writer = MemoryStorageWriter(...)
148151
# Use it to instantiate the Save Strategy
149152
megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
150153
storage_writer=memory_storage_writer,
154+
# use_cached_ckpt_structure=True, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
151155
)
152156
```
153157

@@ -167,7 +171,7 @@ async_request = save_local_aware_megatron_checkpoint(
167171

168172
!!! note
169173

170-
Make sure to specify the checkpoint ID/path when saving based on the current step using:
174+
Make sure to specify the checkpoint ID/path when saving based on the current step using:
171175
`CheckpointContainerId.create_child(base_container, CheckpointContainerId.format_version_container(current_step))`
172176
where `base_container` is the base path CheckpointContainerId used for all checkpoints for the current job, e.g. `"/tmp/mlf-checkpoints/job123"`.
173177

@@ -188,6 +192,11 @@ replication_manager.initialize(checkpoint_object_manager)
188192
checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
189193
checkpoint_object_manager=checkpoint_object_manager,
190194
replication_manager=replication_manager,
195+
global_rank_getter=dist.get_rank,
196+
local_rank_getter=dist.get_node_local_rank,
197+
broadcast_object_list_func=dist.broadcast_object_list,
198+
all_gather_object_func=dist.all_gather_object,
199+
world_size_getter=dist.get_world_size,
191200
)
192201

193202
# Instantiate the Load Strategy with the dependencies
@@ -229,11 +238,12 @@ Code: See the [`ml_flashpoint.adapter.pytorch`](https://github.com/google/ml-fla
229238
To use directly with PyTorch DCP, use the provided `StorageWriter` and `StorageReader` implementations.
230239
You can use whatever `Planner` implementations work for your use case, or resort to the defaults.
231240

232-
If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
241+
If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
233242

234243
#### Imports
235244
```python
236245
import torch
246+
import torch.distributed as dist
237247
from torch import multiprocessing as torch_mp
238248
import torch.distributed.checkpoint as dcp
239249

@@ -262,6 +272,7 @@ memory_storage_writer = MemoryStorageWriter(
262272
ckpt_obj_manager=checkpoint_object_manager,
263273
replication_manager=replication_manager,
264274
# initial_buffer_size_bytes=initial_write_buffer_size_bytes, # Optional - increase for larger checkpoint sizes per rank
275+
# use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time.
265276
),
266277
mp_manager=torch_mp.Manager(),
267278
)
@@ -270,6 +281,11 @@ memory_storage_writer = MemoryStorageWriter(
270281
checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
271282
checkpoint_object_manager=checkpoint_object_manager,
272283
replication_manager=replication_manager,
284+
global_rank_getter=dist.get_rank,
285+
local_rank_getter=dist.get_node_local_rank,
286+
broadcast_object_list_func=dist.broadcast_object_list,
287+
all_gather_object_func=dist.all_gather_object,
288+
world_size_getter=dist.get_world_size,
273289
)
274290
memory_storage_reader = MemoryStorageReader(
275291
path=checkpoint_dir,

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
139139
# 1b. Re-initialize the StorageWriter to use a new instance per save to avoid hangs from shared state.
140140
self._storage_writer = MemoryStorageWriter(
141141
checkpoint_saver=self._checkpoint_saver,
142-
mp_manager=self._storage_writer._main_process_torchmp_manager,
142+
mp_manager_future=self._storage_writer._main_process_torchmp_manager_future,
143143
files_per_rank=self._storage_writer._files_per_rank,
144144
)
145145
# 1c. Reset the StorageWriter for this checkpoint version.

src/ml_flashpoint/adapter/nemo/nemo_checkpoint_loader.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import os
1616
from pathlib import Path
17-
from typing import List, Set
17+
from typing import Callable, List, Set
1818

1919
from typing_extensions import override
2020

@@ -33,6 +33,12 @@ def __init__(
3333
self,
3434
checkpoint_object_manager: CheckpointObjectManager,
3535
replication_manager: ReplicationManager,
36+
*,
37+
global_rank_getter: Callable[[], int],
38+
local_rank_getter: Callable[[], int],
39+
broadcast_object_list_func: Callable[..., None],
40+
all_gather_object_func: Callable[..., None],
41+
world_size_getter: Callable[[], int],
3642
recover_context: bool = False,
3743
):
3844
"""Initializes the NeMoMLFlashpointCheckpointLoader.
@@ -42,9 +48,24 @@ def __init__(
4248
reading data.
4349
replication_manager: The replication manager to use for retrieving
4450
missing checkpoint objects from peer nodes.
51+
global_rank_getter: A callable that returns the global rank.
52+
local_rank_getter: A callable that returns the node-local rank.
53+
broadcast_object_list_func: A callable with the same signature as
54+
``torch.distributed.broadcast_object_list``.
55+
all_gather_object_func: A callable with the same signature as
56+
``torch.distributed.all_gather_object``.
57+
world_size_getter: A callable that returns the world size.
4558
recover_context: Whether to recover the context directory if missing.
4659
"""
47-
super().__init__(checkpoint_object_manager, replication_manager)
60+
super().__init__(
61+
checkpoint_object_manager,
62+
replication_manager,
63+
global_rank_getter=global_rank_getter,
64+
local_rank_getter=local_rank_getter,
65+
broadcast_object_list_func=broadcast_object_list_func,
66+
all_gather_object_func=all_gather_object_func,
67+
world_size_getter=world_size_getter,
68+
)
4869
self._recover_context = recover_context
4970

5071
@override

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import concurrent.futures
16+
import threading
1517
from typing import Union
1618

1719
import torch
20+
import torch.distributed as dist
1821
from nemo import lightning as nl
1922
from nemo.lightning.io.pl import MegatronCheckpointIO
2023
from nemo.lightning.pytorch import strategies as nl_strategies
@@ -79,6 +82,11 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
7982
ckpt_loader = NeMoMLFlashpointCheckpointLoader(
8083
checkpoint_object_manager=ckpt_obj_manager,
8184
replication_manager=replication_manager,
85+
global_rank_getter=dist.get_rank,
86+
local_rank_getter=dist.get_node_local_rank,
87+
broadcast_object_list_func=dist.broadcast_object_list,
88+
all_gather_object_func=dist.all_gather_object,
89+
world_size_getter=dist.get_world_size,
8290
recover_context=always_save_context,
8391
)
8492

@@ -212,6 +220,14 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
212220
# (OOM) errors upon restart. 'spawn' launches a clean interpreter without
213221
# the inherited CUDA state, allowing the GPU memory to be freed instantly.
214222
ctx = torch_mp.get_context("spawn")
223+
mp_manager_future = concurrent.futures.Future()
224+
225+
def start_manager():
226+
mp_manager_future.set_result(ctx.Manager())
227+
228+
thread = threading.Thread(target=start_manager, daemon=True)
229+
thread.start()
230+
215231
save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
216232
storage_writer=MemoryStorageWriter(
217233
checkpoint_saver=DefaultMLFlashpointCheckpointSaver(
@@ -223,7 +239,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
223239
initial_buffer_size_bytes=initial_write_buffer_size_bytes,
224240
use_optimized_save=use_optimized_save,
225241
),
226-
mp_manager=ctx.Manager(),
242+
mp_manager_future=mp_manager_future,
227243
files_per_rank=write_files_per_rank,
228244
),
229245
use_cached_ckpt_structure=use_cached_ckpt_structure,

src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Optional, Union
2121

2222
import torch
23-
from torch import multiprocessing as torch_mp
23+
import torch.multiprocessing as torch_mp
2424
from torch.distributed.checkpoint import Metadata, SavePlan, SavePlanner, StorageWriter, staging
2525
from torch.distributed.checkpoint.filesystem import _StorageInfo
2626
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, MetadataIndex, StorageMeta
@@ -87,16 +87,17 @@ class MemoryStorageWriter(StorageWriter, staging.BlockingAsyncStager):
8787
def __init__(
8888
self,
8989
checkpoint_saver: MLFlashpointCheckpointSaver,
90-
mp_manager: torch_mp.Manager,
90+
mp_manager_future: concurrent.futures.Future,
9191
files_per_rank: int = 1,
9292
):
9393
"""Initializes the MemoryStorageWriter.
9494
9595
Args:
9696
checkpoint_saver: An instance of `MLFlashpointCheckpointSaver` used for
9797
handling the actual checkpoint saving logic.
98-
mp_manager: A `torch.multiprocessing.Manager` instance for managing
99-
shared state across processes, particularly for write results and events.
98+
mp_manager_future: A `concurrent.futures.Future` that resolves to a
99+
`torch.multiprocessing.Manager` instance for managing shared state
100+
across processes, particularly for write results and events.
100101
It is highly recommended to create this manager using a 'spawn'
101102
multiprocessing context to avoid inheriting the parent's CUDA context,
102103
which prevents CUDA OOM errors during failure recoveries
@@ -112,23 +113,22 @@ def __init__(
112113
_LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank)
113114
files_per_rank = 1
114115
self._files_per_rank = files_per_rank
115-
# _main_process_torchmp_manager should only be used in the main process, not in the spawned processes.
116-
# This is because mp_manager is not picklable.
117-
self._main_process_torchmp_manager = mp_manager
118-
self._write_events_per_checkpoint_id: dict[CheckpointContainerId, torch_mp.Event] = mp_manager.dict()
119-
self._write_results_per_checkpoint_id: dict[CheckpointContainerId, list[WriteResult]] = mp_manager.dict()
116+
# _main_process_torchmp_manager_future should only be used in the main process, not in the spawned processes.
117+
# This is because the mp_manager it resolves to is not picklable.
118+
self._main_process_torchmp_manager_future = mp_manager_future
119+
self._write_events_per_checkpoint_id: Optional[dict[CheckpointContainerId, torch_mp.Event]] = None
120+
self._write_results_per_checkpoint_id: Optional[dict[CheckpointContainerId, list[WriteResult]]] = None
120121

121122
def __getstate__(self):
122-
"""Custom pickling to exclude unpicklable mp_manager."""
123+
"""Custom pickling to exclude unpicklable mp_manager_future."""
123124
state = self.__dict__.copy()
124-
if "_main_process_torchmp_manager" in state:
125-
del state["_main_process_torchmp_manager"]
125+
state.pop("_main_process_torchmp_manager_future", None)
126126
return state
127127

128128
def __setstate__(self, state):
129-
"""Custom unpickling to restore state and set mp_manager to None."""
129+
"""Custom unpickling to restore state and set mp_manager_future to None."""
130130
self.__dict__.update(state)
131-
self._main_process_torchmp_manager = None
131+
self._main_process_torchmp_manager_future = None
132132

133133
def _check_checkpoint_id(self) -> None:
134134
if self._current_checkpoint_id is None:
@@ -154,6 +154,11 @@ def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
154154
# Mimicking existing StorageWriter impls (e.g. `_FileSystemWriter`) by using a random ID as the save ID.
155155
self._current_save_id = generate_hfid("memwritersave")
156156

157+
if self._write_events_per_checkpoint_id is None and self._main_process_torchmp_manager_future is not None:
158+
mp_manager = self._main_process_torchmp_manager_future.result()
159+
self._write_events_per_checkpoint_id = mp_manager.dict()
160+
self._write_results_per_checkpoint_id = mp_manager.dict()
161+
157162
def storage_meta(self) -> Optional[StorageMeta]:
158163
self._check_checkpoint_id()
159164
return StorageMeta(checkpoint_id=self._current_checkpoint_id.data, save_id=self._current_save_id)
@@ -194,7 +199,9 @@ def prepare_write_data_buckets(
194199
) -> list[ObjectWriteBucket]:
195200
# Create a new, unset Event for this specific checkpoint save
196201
if checkpoint_id not in self._write_events_per_checkpoint_id:
197-
self._write_events_per_checkpoint_id[checkpoint_id] = self._main_process_torchmp_manager.Event()
202+
self._write_events_per_checkpoint_id[checkpoint_id] = (
203+
self._main_process_torchmp_manager_future.result().Event()
204+
)
198205

199206
write_buckets = self.checkpoint_saver.prepare_write_data(
200207
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._files_per_rank

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
import struct
2323
from collections import defaultdict
2424
from pathlib import Path
25-
from typing import IO, List, Optional, Set, Tuple, TypeVar, cast
25+
from typing import IO, Callable, List, Optional, Set, Tuple, TypeVar, cast
2626

2727
import torch
28-
import torch.distributed as dist
2928
from torch.distributed._shard._utils import narrow_tensor_by_index
3029
from torch.distributed.checkpoint import Metadata
3130
from torch.distributed.checkpoint.filesystem import _StorageInfo
@@ -128,6 +127,12 @@ def __init__(
128127
self,
129128
checkpoint_object_manager: CheckpointObjectManager,
130129
replication_manager: ReplicationManager,
130+
*,
131+
global_rank_getter: Callable[[], int],
132+
local_rank_getter: Callable[[], int],
133+
broadcast_object_list_func: Callable[..., None],
134+
all_gather_object_func: Callable[..., None],
135+
world_size_getter: Callable[[], int],
131136
):
132137
"""Initializes the DefaultMLFlashpointCheckpointLoader.
133138
@@ -136,9 +141,21 @@ def __init__(
136141
reading data.
137142
replication_manager: The replication manager to use for retrieving
138143
missing checkpoint objects from peer nodes.
144+
global_rank_getter: A callable that returns the global rank.
145+
local_rank_getter: A callable that returns the node-local rank.
146+
broadcast_object_list_func: A callable with the same signature as
147+
``torch.distributed.broadcast_object_list``.
148+
all_gather_object_func: A callable with the same signature as
149+
``torch.distributed.all_gather_object``.
150+
world_size_getter: A callable that returns the world size.
139151
"""
140152
self._checkpoint_object_manager = checkpoint_object_manager
141153
self._replication_manager = replication_manager
154+
self._global_rank_getter = global_rank_getter
155+
self._local_rank_getter = local_rank_getter
156+
self._broadcast_object_list_func = broadcast_object_list_func
157+
self._all_gather_object_func = all_gather_object_func
158+
self._world_size_getter = world_size_getter
142159
# Cache for available objects: CheckpointContainerId -> dict[object_path, list[rank]]
143160
self._available_objects_cache: dict[CheckpointContainerId, dict[str, List[int]]] = {}
144161

@@ -337,8 +354,7 @@ def get_latest_complete_checkpoint(
337354
else continue to the next candidate checkpoint
338355
- return the checkpoint container id of the latest complete checkpoint
339356
"""
340-
# TODO: use global_rank_getter and local_rank_getter.
341-
rank = dist.get_rank()
357+
rank = self._global_rank_getter()
342358
_LOGGER.debug(
343359
"Rank %s: Getting latest complete checkpoint for '%s'",
344360
rank,
@@ -382,7 +398,7 @@ def get_latest_complete_checkpoint(
382398
retrieval_plan = self._compute_retrieval_plan(checkpoint, available_objects_by_rank)
383399
# Broadcast the retrieval plan to all ranks.
384400
plan_container = [retrieval_plan]
385-
dist.broadcast_object_list(plan_container, src=planner_rank)
401+
self._broadcast_object_list_func(plan_container, src=planner_rank)
386402
retrieval_plan = plan_container[0]
387403

388404
if retrieval_plan is None:
@@ -451,7 +467,7 @@ def _compute_retrieval_plan(
451467

452468
objects_needed_by_local_rank_0.update(self._get_extra_needed_objects(checkpoint, available_objects_by_rank))
453469

454-
world_size = dist.get_world_size()
470+
world_size = self._world_size_getter()
455471
num_nodes = get_num_of_nodes()
456472
ranks_per_node = world_size // num_nodes
457473

@@ -507,8 +523,8 @@ def get_candidate_checkpoints(
507523

508524
# Scan locally only on the first rank of each node
509525
base_path = Path(checkpoint_base_container.data)
510-
rank = dist.get_rank()
511-
local_rank = dist.get_node_local_rank()
526+
rank = self._global_rank_getter()
527+
local_rank = self._local_rank_getter()
512528

513529
local_candidate_ckpt_ids = []
514530

@@ -532,8 +548,8 @@ def get_candidate_checkpoints(
532548
else:
533549
_LOGGER.debug("Rank %s: Base path '%s' is not a directory or does not exist.", rank, base_path)
534550

535-
all_checkpoint_container_path_lists = [None for _ in range(dist.get_world_size())]
536-
dist.all_gather_object(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
551+
all_checkpoint_container_path_lists = [None for _ in range(self._world_size_getter())]
552+
self._all_gather_object_func(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
537553
_LOGGER.debug(
538554
"Rank %s: Gathered checkpoint container paths from all ranks: '%s'",
539555
rank,
@@ -589,8 +605,8 @@ def get_checkpoint_objects_by_rank(
589605

590606
local_objects.extend(self._get_extra_local_objects(container_path))
591607

592-
all_objects_by_rank_paths = [None for _ in range(dist.get_world_size())]
593-
dist.all_gather_object(all_objects_by_rank_paths, local_objects)
608+
all_objects_by_rank_paths = [None for _ in range(self._world_size_getter())]
609+
self._all_gather_object_func(all_objects_by_rank_paths, local_objects)
594610

595611
result = {}
596612
object_locations = defaultdict(list)
@@ -620,7 +636,7 @@ def retrieve_checkpoint(
620636
If empty for this rank, no retrieval is needed.
621637
"""
622638

623-
rank = dist.get_rank()
639+
rank = self._global_rank_getter()
624640
all_success = True
625641

626642
# Only proceed with retrieval if we have items to retrieve
@@ -656,8 +672,8 @@ def retrieve_checkpoint(
656672

657673
# Gather success status from all ranks
658674
_LOGGER.debug("Gathering success status from all ranks")
659-
all_success_list = [None for _ in range(dist.get_world_size())]
660-
dist.all_gather_object(all_success_list, all_success)
675+
all_success_list = [None for _ in range(self._world_size_getter())]
676+
self._all_gather_object_func(all_success_list, all_success)
661677
_LOGGER.debug("All success list: '%s'", all_success_list)
662678
return all(all_success_list)
663679

0 commit comments

Comments
 (0)