|
17 | 17 | import dataclasses |
18 | 18 |
|
19 | 19 | import pytest |
| 20 | +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( |
| 21 | + FullyParallelLoadStrategyWrapper, |
| 22 | + FullyParallelSaveStrategyWrapper, |
| 23 | +) |
20 | 24 | from nemo import lightning as nl |
21 | 25 | from nemo.lightning.io.pl import MegatronCheckpointIO |
22 | 26 | from nemo.lightning.pytorch import strategies as nl_strategies |
@@ -128,6 +132,7 @@ def test_successful_wrap_and_resume_creation(self, mocker, mock_ckpt_obj_manager |
128 | 132 | initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES, |
129 | 133 | use_optimized_save=True, |
130 | 134 | use_cached_ckpt_structure=False, |
| 135 | + use_fully_parallel_wrapper=False, |
131 | 136 | ) |
132 | 137 |
|
133 | 138 | # 3. Result is correct type and has correct attributes |
@@ -343,6 +348,58 @@ def test_use_cached_ckpt_structure_default_value(self, mocker, mock_ckpt_obj_man |
343 | 348 | _, kwargs = mock_wrap_trainer.call_args |
344 | 349 | assert kwargs["use_cached_ckpt_structure"] is False |
345 | 350 |
|
| 351 | + @pytest.mark.parametrize("use_fully_parallel_wrapper", [True, False]) |
| 352 | + def test_use_fully_parallel_wrapper_forwarding(self, mocker, use_fully_parallel_wrapper): |
| 353 | + """Tests that use_fully_parallel_wrapper is forwarded correctly.""" |
| 354 | + # Given |
| 355 | + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager") |
| 356 | + mock_wrap_trainer = mocker.patch( |
| 357 | + "ml_flashpoint.adapter.nemo.wrapper_util.wrap_trainer_checkpoint_io_with_mlflashpoint" |
| 358 | + ) |
| 359 | + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) |
| 360 | + trainer.global_rank = 0 |
| 361 | + flashpoint_base_container = "/tmp/test_container" |
| 362 | + default_auto_resume = nl.AutoResume() |
| 363 | + |
| 364 | + # When |
| 365 | + wrap_trainer_and_auto_resume_with_mlflashpoint( |
| 366 | + trainer, |
| 367 | + flashpoint_base_container, |
| 368 | + async_save=True, |
| 369 | + default_auto_resume=default_auto_resume, |
| 370 | + use_fully_parallel_wrapper=use_fully_parallel_wrapper, |
| 371 | + ) |
| 372 | + |
| 373 | + # Then |
| 374 | + mock_wrap_trainer.assert_called_once() |
| 375 | + _, kwargs = mock_wrap_trainer.call_args |
| 376 | + assert kwargs["use_fully_parallel_wrapper"] is use_fully_parallel_wrapper |
| 377 | + |
| 378 | + def test_use_fully_parallel_wrapper_default_value(self, mocker): |
| 379 | + """Tests that use_fully_parallel_wrapper defaults to False.""" |
| 380 | + # Given |
| 381 | + mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager") |
| 382 | + mock_wrap_trainer = mocker.patch( |
| 383 | + "ml_flashpoint.adapter.nemo.wrapper_util.wrap_trainer_checkpoint_io_with_mlflashpoint" |
| 384 | + ) |
| 385 | + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) |
| 386 | + trainer.global_rank = 0 |
| 387 | + flashpoint_base_container = "/tmp/test_container" |
| 388 | + default_auto_resume = nl.AutoResume() |
| 389 | + |
| 390 | + # When |
| 391 | + wrap_trainer_and_auto_resume_with_mlflashpoint( |
| 392 | + trainer, |
| 393 | + flashpoint_base_container, |
| 394 | + async_save=True, |
| 395 | + default_auto_resume=default_auto_resume, |
| 396 | + ) |
| 397 | + |
| 398 | + # Then |
| 399 | + mock_wrap_trainer.assert_called_once() |
| 400 | + _, kwargs = mock_wrap_trainer.call_args |
| 401 | + assert kwargs["use_fully_parallel_wrapper"] is False |
| 402 | + |
346 | 403 |
|
347 | 404 | class TestWrapTrainerCheckpointIOWithMLFlashpoint: |
348 | 405 | """Tests for the wrap_trainer_checkpoint_io_with_mlflashpoint function.""" |
@@ -547,6 +604,76 @@ def test_successful_wrapping_no_async_wrapper(self, mocker, mock_ckpt_obj_manage |
547 | 604 | assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io |
548 | 605 | assert trainer.strategy.checkpoint_io.async_save is True |
549 | 606 |
|
| 607 | + def test_fully_parallel_wrapper_enabled(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): |
| 608 | + """Tests that FullyParallel wrappers are applied when flag=True.""" |
| 609 | + |
| 610 | + # Given |
| 611 | + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) |
| 612 | + trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] |
| 613 | + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) |
| 614 | + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) |
| 615 | + trainer.strategy.checkpoint_io = original_checkpoint_io |
| 616 | + base_container = "/test_base_container" |
| 617 | + |
| 618 | + # When |
| 619 | + wrap_trainer_checkpoint_io_with_mlflashpoint( |
| 620 | + trainer, |
| 621 | + base_container, |
| 622 | + mock_ckpt_obj_manager, |
| 623 | + mock_replication_manager, |
| 624 | + async_save=True, |
| 625 | + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), |
| 626 | + use_fully_parallel_wrapper=True, # 🔥 enable it |
| 627 | + ) |
| 628 | + |
| 629 | + # Then |
| 630 | + wrapped_io = trainer.strategy.checkpoint_io |
| 631 | + assert isinstance(wrapped_io, MLFlashpointCheckpointIO) |
| 632 | + |
| 633 | + assert isinstance( |
| 634 | + wrapped_io.save_strategy, |
| 635 | + FullyParallelSaveStrategyWrapper, |
| 636 | + ) |
| 637 | + assert isinstance( |
| 638 | + wrapped_io.load_strategy, |
| 639 | + FullyParallelLoadStrategyWrapper, |
| 640 | + ) |
| 641 | + |
| 642 | + def test_fully_parallel_wrapper_disabled_by_default(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): |
| 643 | + """Tests that FullyParallel wrappers are NOT applied when flag=False.""" |
| 644 | + |
| 645 | + # Given |
| 646 | + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) |
| 647 | + trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] |
| 648 | + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) |
| 649 | + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) |
| 650 | + trainer.strategy.checkpoint_io = original_checkpoint_io |
| 651 | + base_container = "/test_base_container" |
| 652 | + |
| 653 | + # When |
| 654 | + wrap_trainer_checkpoint_io_with_mlflashpoint( |
| 655 | + trainer, |
| 656 | + base_container, |
| 657 | + mock_ckpt_obj_manager, |
| 658 | + mock_replication_manager, |
| 659 | + async_save=True, |
| 660 | + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), |
| 661 | + use_fully_parallel_wrapper=False, # default behavior |
| 662 | + ) |
| 663 | + |
| 664 | + # Then |
| 665 | + wrapped_io = trainer.strategy.checkpoint_io |
| 666 | + assert isinstance(wrapped_io, MLFlashpointCheckpointIO) |
| 667 | + |
| 668 | + assert not isinstance( |
| 669 | + wrapped_io.save_strategy, |
| 670 | + FullyParallelSaveStrategyWrapper, |
| 671 | + ) |
| 672 | + assert not isinstance( |
| 673 | + wrapped_io.load_strategy, |
| 674 | + FullyParallelLoadStrategyWrapper, |
| 675 | + ) |
| 676 | + |
550 | 677 | def test_successful_wrapping_with_async_wrapper(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): |
551 | 678 | """Tests successful wrapping when an async wrapper is present.""" |
552 | 679 | # Given |
|
0 commit comments