Skip to content

Commit e9eb9d6

Browse files
devyanic11, g-husam, and leahlijuan
authored
feat(adapter/nemo): add fully parallel save/load wrapper option (#49)
Fixes #48 Add optional `use_fully_parallel_wrapper` flag to the NeMo wrapper utility. When enabled, save/load strategies are wrapped with `FullyParallelSaveStrategyWrapper` and `FullyParallelLoadStrategyWrapper`. Default behavior remains unchanged. - [x] Tests pass --------- Co-authored-by: g-husam <husameldawi@google.com> Co-authored-by: leahlijuan <leahlijuan@google.com>
1 parent 5d0e076 commit e9eb9d6

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
import torch
2121
import torch.distributed as dist
22+
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
23+
FullyParallelLoadStrategyWrapper,
24+
FullyParallelSaveStrategyWrapper,
25+
)
2226
from nemo import lightning as nl
2327
from nemo.lightning.io.pl import MegatronCheckpointIO
2428
from nemo.lightning.pytorch import strategies as nl_strategies
@@ -53,6 +57,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
5357
initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
5458
use_optimized_save: bool = True,
5559
use_cached_ckpt_structure: bool = False,
60+
use_fully_parallel_wrapper: bool = False,
5661
) -> MLFlashpointAutoResume:
5762
"""Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.
5863
@@ -72,6 +77,10 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
7277
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly.
7378
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
7479
Defaults to False.
80+
use_fully_parallel_wrapper: Whether to use the fully parallel wrapper for save and load.
81+
This will evenly distribute checkpoint data across all ranks.
82+
Defaults to False.
83+
7584
Returns:
7685
An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`.
7786
"""
@@ -114,6 +123,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
114123
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
115124
use_optimized_save=use_optimized_save,
116125
use_cached_ckpt_structure=use_cached_ckpt_structure,
126+
use_fully_parallel_wrapper=use_fully_parallel_wrapper,
117127
)
118128

119129
default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
@@ -136,6 +146,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
136146
initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
137147
use_optimized_save: bool = True,
138148
use_cached_ckpt_structure: bool = False,
149+
use_fully_parallel_wrapper: bool = False,
139150
):
140151
"""Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.
141152
@@ -165,6 +176,9 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
165176
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly.
166177
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
167178
Defaults to False.
179+
use_fully_parallel_wrapper: Whether to use the fully parallel wrapper for save and load.
180+
This will evenly distribute checkpoint data across all ranks.
181+
Defaults to False.
168182
169183
Returns:
170184
None. The trainer's checkpoint_io is modified in-place.
@@ -263,6 +277,10 @@ def start_manager():
263277
checkpoint_loader=checkpoint_loader,
264278
)
265279

280+
if use_fully_parallel_wrapper:
281+
save_strategy = FullyParallelSaveStrategyWrapper(save_strategy)
282+
load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)
283+
266284
ml_flashpoint_checkpoint_io = MLFlashpointCheckpointIO(
267285
flashpoint_base_path=flashpoint_base_container,
268286
alt_checkpoint_io=checkpoint_io,

tests/adapter/nemo/test_wrapper_util.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
import dataclasses
1818

1919
import pytest
20+
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
21+
FullyParallelLoadStrategyWrapper,
22+
FullyParallelSaveStrategyWrapper,
23+
)
2024
from nemo import lightning as nl
2125
from nemo.lightning.io.pl import MegatronCheckpointIO
2226
from nemo.lightning.pytorch import strategies as nl_strategies
@@ -128,6 +132,7 @@ def test_successful_wrap_and_resume_creation(self, mocker, mock_ckpt_obj_manager
128132
initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
129133
use_optimized_save=True,
130134
use_cached_ckpt_structure=False,
135+
use_fully_parallel_wrapper=False,
131136
)
132137

133138
# 3. Result is correct type and has correct attributes
@@ -343,6 +348,58 @@ def test_use_cached_ckpt_structure_default_value(self, mocker, mock_ckpt_obj_man
343348
_, kwargs = mock_wrap_trainer.call_args
344349
assert kwargs["use_cached_ckpt_structure"] is False
345350

351+
@pytest.mark.parametrize("use_fully_parallel_wrapper", [True, False])
352+
def test_use_fully_parallel_wrapper_forwarding(self, mocker, use_fully_parallel_wrapper):
353+
"""Tests that use_fully_parallel_wrapper is forwarded correctly."""
354+
# Given
355+
mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager")
356+
mock_wrap_trainer = mocker.patch(
357+
"ml_flashpoint.adapter.nemo.wrapper_util.wrap_trainer_checkpoint_io_with_mlflashpoint"
358+
)
359+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
360+
trainer.global_rank = 0
361+
flashpoint_base_container = "/tmp/test_container"
362+
default_auto_resume = nl.AutoResume()
363+
364+
# When
365+
wrap_trainer_and_auto_resume_with_mlflashpoint(
366+
trainer,
367+
flashpoint_base_container,
368+
async_save=True,
369+
default_auto_resume=default_auto_resume,
370+
use_fully_parallel_wrapper=use_fully_parallel_wrapper,
371+
)
372+
373+
# Then
374+
mock_wrap_trainer.assert_called_once()
375+
_, kwargs = mock_wrap_trainer.call_args
376+
assert kwargs["use_fully_parallel_wrapper"] is use_fully_parallel_wrapper
377+
378+
def test_use_fully_parallel_wrapper_default_value(self, mocker):
379+
"""Tests that use_fully_parallel_wrapper defaults to False."""
380+
# Given
381+
mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager")
382+
mock_wrap_trainer = mocker.patch(
383+
"ml_flashpoint.adapter.nemo.wrapper_util.wrap_trainer_checkpoint_io_with_mlflashpoint"
384+
)
385+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
386+
trainer.global_rank = 0
387+
flashpoint_base_container = "/tmp/test_container"
388+
default_auto_resume = nl.AutoResume()
389+
390+
# When
391+
wrap_trainer_and_auto_resume_with_mlflashpoint(
392+
trainer,
393+
flashpoint_base_container,
394+
async_save=True,
395+
default_auto_resume=default_auto_resume,
396+
)
397+
398+
# Then
399+
mock_wrap_trainer.assert_called_once()
400+
_, kwargs = mock_wrap_trainer.call_args
401+
assert kwargs["use_fully_parallel_wrapper"] is False
402+
346403

347404
class TestWrapTrainerCheckpointIOWithMLFlashpoint:
348405
"""Tests for the wrap_trainer_checkpoint_io_with_mlflashpoint function."""
@@ -547,6 +604,76 @@ def test_successful_wrapping_no_async_wrapper(self, mocker, mock_ckpt_obj_manage
547604
assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io
548605
assert trainer.strategy.checkpoint_io.async_save is True
549606

607+
def test_fully_parallel_wrapper_enabled(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
608+
"""Tests that FullyParallel wrappers are applied when flag=True."""
609+
610+
# Given
611+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
612+
trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
613+
trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
614+
original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
615+
trainer.strategy.checkpoint_io = original_checkpoint_io
616+
base_container = "/test_base_container"
617+
618+
# When
619+
wrap_trainer_checkpoint_io_with_mlflashpoint(
620+
trainer,
621+
base_container,
622+
mock_ckpt_obj_manager,
623+
mock_replication_manager,
624+
async_save=True,
625+
checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
626+
use_fully_parallel_wrapper=True, # 🔥 enable it
627+
)
628+
629+
# Then
630+
wrapped_io = trainer.strategy.checkpoint_io
631+
assert isinstance(wrapped_io, MLFlashpointCheckpointIO)
632+
633+
assert isinstance(
634+
wrapped_io.save_strategy,
635+
FullyParallelSaveStrategyWrapper,
636+
)
637+
assert isinstance(
638+
wrapped_io.load_strategy,
639+
FullyParallelLoadStrategyWrapper,
640+
)
641+
642+
def test_fully_parallel_wrapper_disabled_by_default(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
643+
"""Tests that FullyParallel wrappers are NOT applied when flag=False."""
644+
645+
# Given
646+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
647+
trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
648+
trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
649+
original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
650+
trainer.strategy.checkpoint_io = original_checkpoint_io
651+
base_container = "/test_base_container"
652+
653+
# When
654+
wrap_trainer_checkpoint_io_with_mlflashpoint(
655+
trainer,
656+
base_container,
657+
mock_ckpt_obj_manager,
658+
mock_replication_manager,
659+
async_save=True,
660+
checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
661+
use_fully_parallel_wrapper=False, # default behavior
662+
)
663+
664+
# Then
665+
wrapped_io = trainer.strategy.checkpoint_io
666+
assert isinstance(wrapped_io, MLFlashpointCheckpointIO)
667+
668+
assert not isinstance(
669+
wrapped_io.save_strategy,
670+
FullyParallelSaveStrategyWrapper,
671+
)
672+
assert not isinstance(
673+
wrapped_io.load_strategy,
674+
FullyParallelLoadStrategyWrapper,
675+
)
676+
550677
def test_successful_wrapping_with_async_wrapper(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
551678
"""Tests successful wrapping when an async wrapper is present."""
552679
# Given

0 commit comments

Comments (0)