Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
1. Ensure you have sufficient space on your base container mount.
1. If you have enough memory, but are running out of buffer space during writes, you can:
1. Increase the default initial buffer size via `initial_write_buffer_size_bytes` in the `wrap` API you are using (the default is 16 GB).
1. Increase the write thread count, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_thread_count` in the `wrap` API you are using (the default is 1).
1. Increase the number of files per rank, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_files_per_rank` in the `wrap` API you are using (the default is 1).

### How can I clean up ML Flashpoint checkpoints after job completion?

Expand Down
2 changes: 1 addition & 1 deletion docs/user-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save=not args.sync_save,
default_auto_resume=auto_resume, # Optional
# always_save_context=False, # Optional, defaults to False
# write_thread_count=1, # Optional, defaults to 1
# write_files_per_rank=1, # Optional, defaults to 1
# initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
# use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time.
# use_cached_ckpt_structure=False, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
Expand Down
8 changes: 4 additions & 4 deletions src/ml_flashpoint/adapter/megatron/save_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def __init__(
self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure

@property
def thread_count(self) -> int:
"""Returns the number of threads used by the storage writer."""
return self._storage_writer._thread_count
def files_per_rank(self) -> int:
"""Returns the number of files per rank used by the storage writer."""
return self._storage_writer._files_per_rank

@override
def can_handle_sharded_objects(self) -> bool:
Expand Down Expand Up @@ -145,7 +145,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
self._storage_writer = MemoryStorageWriter(
checkpoint_saver=self._checkpoint_saver,
mp_manager_future=self._storage_writer._main_process_torchmp_manager_future,
thread_count=self._storage_writer._thread_count,
files_per_rank=self._storage_writer._files_per_rank,
)
# 1c. Reset the StorageWriter for this checkpoint version.
self._storage_writer.reset(checkpoint_id.data)
Expand Down
20 changes: 11 additions & 9 deletions src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save: bool,
default_auto_resume: nl.AutoResume = None,
always_save_context: bool = False,
write_thread_count: int = 1,
write_files_per_rank: int = 1,
initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_cached_ckpt_structure: bool = False,
Expand All @@ -72,7 +72,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save: Whether to enable asynchronous saving for checkpoints.
default_auto_resume: The default AutoResume configuration to inherit from.
always_save_context: Whether to always save the context. Defaults to `False`.
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data.
Checkpoint data will be split roughly evenly among the files (per rank). Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly.
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
Expand All @@ -92,7 +93,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
pool_config = BufferPoolConfig(
pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool"),
rank=trainer.global_rank,
num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT,
num_buffers=write_files_per_rank * NUM_OF_BUFFERS_PER_OBJECT,
buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
)

Expand All @@ -119,7 +120,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save=async_save,
checkpoint_loader=ckpt_loader,
always_save_context=always_save_context,
write_thread_count=write_thread_count,
write_files_per_rank=write_files_per_rank,
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
use_optimized_save=use_optimized_save,
use_cached_ckpt_structure=use_cached_ckpt_structure,
Expand All @@ -142,7 +143,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
async_save: bool,
checkpoint_loader: DefaultMLFlashpointCheckpointLoader,
always_save_context: bool = False,
write_thread_count: int = 1,
write_files_per_rank: int = 1,
initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_cached_ckpt_structure: bool = False,
Expand Down Expand Up @@ -171,7 +172,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
async_save: Whether to enable asynchronous saving.
checkpoint_loader: The checkpoint loader to use.
always_save_context: Whether to always save the context. Defaults to `False`.
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data.
Checkpoint data will be split roughly evenly among the files (per rank). Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly.
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
Expand All @@ -193,8 +195,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
raise ValueError("The 'ckpt_obj_manager' argument cannot be None.")
if replication_manager is None:
raise ValueError("The 'replication_manager' argument cannot be None.")
if write_thread_count < 1:
raise ValueError(f"write_thread_count must be >= 1, got {write_thread_count}.")
if write_files_per_rank < 1:
raise ValueError(f"write_files_per_rank must be >= 1, got {write_files_per_rank}.")
if initial_write_buffer_size_bytes is None:
initial_write_buffer_size_bytes = DEFAULT_INITIAL_BUFFER_SIZE_BYTES
if initial_write_buffer_size_bytes <= 0:
Expand Down Expand Up @@ -268,7 +270,7 @@ def start_manager():
use_optimized_save=use_optimized_save,
),
mp_manager_future=mp_manager_future,
thread_count=write_thread_count,
files_per_rank=write_files_per_rank,
),
use_cached_ckpt_structure=use_cached_ckpt_structure,
)
Expand Down
29 changes: 15 additions & 14 deletions src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,44 +88,45 @@ def __init__(
self,
checkpoint_saver: MLFlashpointCheckpointSaver,
mp_manager_future: concurrent.futures.Future,
thread_count: int = 1,
files_per_rank: int = 1,
):
"""Initializes the MemoryStorageWriter.

Args:
checkpoint_saver: An instance of `MLFlashpointCheckpointSaver` used for
handling the actual checkpoint saving logic.
mp_manager: A `torch.multiprocessing.Manager` instance for managing
shared state across processes, particularly for write results and events.
mp_manager_future: A `concurrent.futures.Future` that resolves to a
`torch.multiprocessing.Manager` instance for managing shared state
across processes, particularly for write results and events.
It is highly recommended to create this manager using a 'spawn'
multiprocessing context to avoid inheriting the parent's CUDA context,
which prevents CUDA OOM errors during failure recoveries.
thread_count: Optional. The number of threads to use for writing checkpoint data.
files_per_rank: Optional. The number of files each rank writes to for checkpoint data.
Defaults to 1. If a value less than 1 is provided, it will be reset to 1,
and a warning will be logged.
"""
super().__init__()
self._current_checkpoint_id: CheckpointContainerId | None = None
self._current_save_id: str | None = None
self._checkpoint_saver: MLFlashpointCheckpointSaver = checkpoint_saver
if thread_count < 1:
_LOGGER.warning("thread_count must be >= 1, but was %d. Setting to 1.", thread_count)
thread_count = 1
self._thread_count = thread_count
# _main_process_torchmp_manager should only be used in the main process, not in the spawned processes.
# This is because mp_manager is not picklable.
if files_per_rank < 1:
_LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank)
files_per_rank = 1
self._files_per_rank = files_per_rank
# _main_process_torchmp_manager_future should only be used in the main process, not in the spawned processes.
# This is because the mp_manager it resolves to is not picklable.
self._main_process_torchmp_manager_future = mp_manager_future
self._write_events_per_checkpoint_id: Optional[dict[CheckpointContainerId, torch_mp.Event]] = None
self._write_results_per_checkpoint_id: Optional[dict[CheckpointContainerId, list[WriteResult]]] = None

def __getstate__(self):
"""Custom pickling to exclude unpicklable mp_manager."""
"""Custom pickling to exclude unpicklable mp_manager_future."""
state = self.__dict__.copy()
state.pop("_main_process_torchmp_manager_future", None)
return state

def __setstate__(self, state):
"""Custom unpickling to restore state and set mp_manager to None."""
"""Custom unpickling to restore state and set mp_manager_future to None."""
self.__dict__.update(state)
self._main_process_torchmp_manager_future = None

Expand Down Expand Up @@ -203,7 +204,7 @@ def prepare_write_data_buckets(
)

write_buckets = self.checkpoint_saver.prepare_write_data(
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._thread_count
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._files_per_rank
)
return write_buckets
# self._write_buckets_per_checkpoint_id[checkpoint_id] = write_buckets
Expand Down Expand Up @@ -243,7 +244,7 @@ def write_staged_data_buckets(
write_results = self._checkpoint_saver.write_data(
checkpoint_id,
write_buckets=staged_write_buckets,
thread_count=self._thread_count,
files_per_rank=self._files_per_rank,
replicate_after_write=replicate_after_write,
)
end_time = time.perf_counter()
Expand Down
22 changes: 11 additions & 11 deletions src/ml_flashpoint/core/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def write_data(
self,
checkpoint_id: CheckpointContainerId,
write_buckets: list[ObjectWriteBucket],
thread_count: int,
files_per_rank: int,
replicate_after_write: bool,
) -> list[WriteResult]:
"""Performs the core write logic for the given write items and checkpoint_id.
Expand All @@ -225,7 +225,7 @@ def write_data(
checkpoint_id: Unique hierarchical ID representing this checkpoint container.
This typically follows a directory path structure.
write_buckets: A list of `ObjectWriteBucket` objects, each containing resolved data ready for writing.
thread_count: The number of threads to use for writing data.
files_per_rank: The number of files each rank writes to.
replicate_after_write: Whether to trigger async replication of each object after it is written.

Returns:
Expand Down Expand Up @@ -371,7 +371,7 @@ def prepare_write_data(
) -> list[ObjectWriteBucket]:
bucket_count = max(bucket_count, 1)
_LOGGER.debug(
"%s prepare_write_data with prefix: '%s', thread_count: %d",
"%s prepare_write_data with prefix: '%s', files_per_rank: %d",
self.__class__.__name__,
object_name_prefix,
bucket_count,
Expand Down Expand Up @@ -403,7 +403,7 @@ def _clone_if_needed(tensor: torch.Tensor):
# NOTE: Multiple writer threads are supported (to make that setting easy to change),
# but we typically use only 1 thread.

# Group items into buckets, one bucket per file, up to thread_count files
# Group items into buckets, one bucket per file, up to files_per_rank files
buckets = _split_by_size_and_type(bucket_count, write_items)
for bucket in buckets:
if not bucket:
Expand Down Expand Up @@ -437,22 +437,22 @@ def write_data(
checkpoint_id: CheckpointContainerId,
write_buckets: list[ObjectWriteBucket],
replicate_after_write: bool,
thread_count: int = 1,
files_per_rank: int = 1,
) -> list[WriteResult]:
thread_count = max(thread_count, 1)
files_per_rank = max(files_per_rank, 1)
num_cpus = os.cpu_count() or 1
num_ranks = max(get_accelerator_count(), 1)
# Use 50% of available CPU cores for PyTorch intra-op threads and evenly distribute them across ranks.
torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank)
original_num_threads = torch.get_num_threads()
# Explicitly set PyTorch intra-op threads to optimize for performance.
# This also avoids potential runtime errors in tensor.copy_() with concurrent writers
torch.set_num_threads(torch_thread_count)
_LOGGER.debug(
"%s starting multi-threaded write_data. thread_count: %d, original_num_threads: %d, "
"%s starting multi-threaded write_data. files_per_rank: %d, original_num_threads: %d, "
"num_cpus: %d, num_ranks: %d, torch_thread_count: %d",
self.__class__.__name__,
thread_count,
files_per_rank,
original_num_threads,
num_cpus,
num_ranks,
Expand All @@ -471,8 +471,8 @@ def write_data(
threads = []

# Kick off additional threads to main thread, if any.
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1)
for i in range(1, thread_count):
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", files_per_rank - 1)
for i in range(1, files_per_rank):
thread = threading.Thread(
target=self._write_to_buffer_from_queue_worker,
args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
Expand Down
19 changes: 10 additions & 9 deletions tests/adapter/megatron/test_save_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def checkpoint_saver() -> MLFlashpointCheckpointSaver:
def storage_writer(mocker, checkpoint_saver) -> MemoryStorageWriter:
# Using a real MemoryStorageWriter instance instead of a mock.
# We can still spy on its methods if needed.
# The mp_manager is mocked as it's not relevant to these tests.
# The mp_manager_future is mocked as it's not relevant to these tests.
return MemoryStorageWriter(
checkpoint_saver=checkpoint_saver,
mp_manager_future=mocker.MagicMock(),
Expand Down Expand Up @@ -193,18 +193,19 @@ def test_async_save_initialization_calls_success(
mock_memory_storage_writer_cls.assert_called_once_with(
checkpoint_saver=checkpoint_saver,
mp_manager_future=storage_writer._main_process_torchmp_manager_future,
thread_count=storage_writer._thread_count,
files_per_rank=storage_writer._files_per_rank,
)
mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data)
mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with(
checkpoint_id, dummy_write_buckets, non_blocking=True
)

@pytest.mark.parametrize("expected_thread_count", [1, 2, 3, 5])
def test_async_save_reinitializes_storage_writer_with_thread_count(
self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, expected_thread_count
@pytest.mark.parametrize("expected_files_per_rank", [1, 2, 3, 5])
def test_async_save_reinitializes_storage_writer_with_files_per_rank(
self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets,
expected_files_per_rank,
):
"""Tests that the StorageWriter is re-initialized with the correct thread_count."""
"""Tests that the StorageWriter is re-initialized with the correct files_per_rank."""
# Given
mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
(
Expand All @@ -221,8 +222,8 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
False,
)

# Set a specific thread_count on the original storage_writer
storage_writer._thread_count = expected_thread_count
# Set a specific files_per_rank on the original storage_writer
storage_writer._files_per_rank = expected_files_per_rank

mock_memory_storage_writer_cls = mocker.patch(
"ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
Expand All @@ -235,7 +236,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
mock_memory_storage_writer_cls.assert_called_once_with(
checkpoint_saver=checkpoint_saver,
mp_manager_future=storage_writer._main_process_torchmp_manager_future,
thread_count=expected_thread_count,
files_per_rank=expected_files_per_rank,
)

def test_initialize_checkpoint_failure(self, mocker, async_save_setup, checkpoint_saver):
Expand Down
Loading