From fc2251b8a184b780061162d3cca6b1f9fd9bde6f Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 30 Mar 2026 19:53:29 +0000 Subject: [PATCH 1/9] feat: shutdown replication manager and delete mlf checkpoint dir on train end. --- src/ml_flashpoint/adapter/nemo/checkpoint_callback.py | 11 +++++++++++ src/ml_flashpoint/adapter/nemo/wrapper_util.py | 7 +++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py index 0e45470..6d46f6a 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py @@ -73,6 +73,7 @@ def __init__( self.every_n_steps = every_n_steps self.skip_every_n_steps = skip_every_n_steps if skip_every_n_steps is not None else 0 self._enabled = enabled + self.replication_manager = None self._validate() def _validate(self): @@ -151,3 +152,13 @@ def on_train_batch_end( ckpt_options, ) trainer.save_checkpoint(ckpt_version_container.data, storage_options={ML_FLASHPOINT_OPTS_KEY: ckpt_options}) + + @override + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.replication_manager is not None: + _LOGGER.info("Training ended. 
Shutting down Replication Manager...") + self.replication_manager.shutdown() + + if trainer.local_rank == 0: + _LOGGER.info("Rank 0: Performing final checkpoint cleanup...") + trainer.strategy.checkpoint_io.remove_checkpoint(self.base_container.data) diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index f2284d8..564c70a 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -201,10 +201,13 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( raise ValueError(f"initial_write_buffer_size_bytes must be > 0, got {initial_write_buffer_size_bytes}.") callbacks = trainer.callbacks - mlflashpoint_enabled = any(isinstance(cb, MLFlashpointCheckpointCallback) for cb in callbacks) - if not mlflashpoint_enabled: + mlf_callbacks = [cb for cb in callbacks if isinstance(cb, MLFlashpointCheckpointCallback)] + if not mlf_callbacks: return + for cb in mlf_callbacks: + cb.replication_manager = replication_manager + if not isinstance(trainer.strategy, nl_strategies.MegatronStrategy): raise ValueError( "Only MegatronStrategy is supported for ML Flashpoint, but got " From aaf42d0c1313bfe3f17ceec906af67da92d1694e Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 30 Mar 2026 21:09:42 +0000 Subject: [PATCH 2/9] ensure mlf checkpoint done before shutdown replication manager --- .../adapter/nemo/checkpoint_callback.py | 18 +++++- .../adapter/nemo/test_checkpoint_callback.py | 64 +++++++++++++++++++ tests/adapter/nemo/test_wrapper_util.py | 30 +++++++++ 3 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py index 6d46f6a..d845e1c 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py @@ -154,11 +154,23 @@ def on_train_batch_end( trainer.save_checkpoint(ckpt_version_container.data, 
storage_options={ML_FLASHPOINT_OPTS_KEY: ckpt_options}) @override + @log_execution_time(logger=_LOGGER, name="on_train_end", level=logging.INFO) def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + _LOGGER.info("Training ended. Synchronizing and finalizing checkpoints...") + + # 1. Wait for async checkpoint saves to finish locally + checkpoint_io = getattr(trainer.strategy, "checkpoint_io", None) + if hasattr(checkpoint_io, "maybe_finalize_save_checkpoint"): + checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) + + # 2. Synchronize all ranks to ensure background writes are done everywhere before deletion + if hasattr(trainer.strategy, "barrier"): + trainer.strategy.barrier("mlf_cleanup_barrier") + if self.replication_manager is not None: _LOGGER.info("Training ended. Shutting down Replication Manager...") self.replication_manager.shutdown() - if trainer.local_rank == 0: - _LOGGER.info("Rank 0: Performing final checkpoint cleanup...") - trainer.strategy.checkpoint_io.remove_checkpoint(self.base_container.data) + if trainer.local_rank == 0: + _LOGGER.info("Local rank 0: Performing final checkpoint cleanup...") + trainer.strategy.checkpoint_io.remove_checkpoint(self.base_container.data) diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index f564580..3b0d4f1 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -352,3 +352,67 @@ def test_on_train_batch_end_when_enabled(mocker): # Then # Should save trainer.save_checkpoint.assert_called_once() + + +def test_on_train_end_cleans_up_on_rank_zero(mocker): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + base_container = CheckpointContainerId("/test/base") + callback = 
MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + + # Then + callback.replication_manager.shutdown.assert_called_once() + checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) + + +def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 1 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + base_container = CheckpointContainerId("/test/base") + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + + # Then + callback.replication_manager.shutdown.assert_called_once() + checkpoint_io.remove_checkpoint.assert_not_called() + + +def test_on_train_end_no_replication_manager_skips_shutdown(mocker): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + base_container = CheckpointContainerId("/test/base") + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + # self.replication_manager is inherently None initialized in __init__ + + # When + callback.on_train_end(trainer, pl_module) + + # Then + # replication_manager doesn't crash since it's None and checked. 
+ # still cleans up on rank 0 + checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 2ad6794..487fe0a 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -878,6 +878,36 @@ def test_mlflashpoint_enabled_with_multiple_callbacks( assert isinstance(trainer.strategy.checkpoint_io, MLFlashpointCheckpointIO) assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io + def test_replication_manager_injected_into_callbacks(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): + """Tests that the ReplicationManager is injected into all MLFlashpointCheckpointCallback instances.""" + # Given + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) + mock_mlf_callback1 = mocker.MagicMock(spec=MLFlashpointCheckpointCallback) + mock_mlf_callback2 = mocker.MagicMock(spec=MLFlashpointCheckpointCallback) + trainer.callbacks = [ + mocker.MagicMock(), + mock_mlf_callback1, + mock_mlf_callback2, + ] + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) + trainer.strategy.checkpoint_io = original_checkpoint_io + base_container = "/test_base_container" + + # When + wrap_trainer_checkpoint_io_with_mlflashpoint( + trainer, + base_container, + mock_ckpt_obj_manager, + replication_manager=mock_replication_manager, + async_save=True, + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), + ) + + # Then + assert mock_mlf_callback1.replication_manager is mock_replication_manager + assert mock_mlf_callback2.replication_manager is mock_replication_manager + def test_invalid_config_with_mlf_async_wrapper_and_async_save_false( self, mocker, mock_ckpt_obj_manager, mock_replication_manager ): From 1897d3df8a89d2e0b3671fcf3ebd14b62757f0e5 Mon Sep 17 00:00:00 2001 From: leahlijuan 
Date: Thu, 2 Apr 2026 15:41:23 +0000 Subject: [PATCH 3/9] resolve comments --- .../adapter/nemo/checkpoint_callback.py | 23 ++++++++++++++----- .../adapter/nemo/wrapper_util.py | 2 +- .../adapter/nemo/test_checkpoint_callback.py | 4 ++-- tests/adapter/nemo/test_wrapper_util.py | 4 ++-- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py index d845e1c..c89b69b 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py @@ -73,9 +73,23 @@ def __init__( self.every_n_steps = every_n_steps self.skip_every_n_steps = skip_every_n_steps if skip_every_n_steps is not None else 0 self._enabled = enabled - self.replication_manager = None + self._replication_manager = None self._validate() + @property + def replication_manager(self): + """Returns the ReplicationManager instance if one has been set.""" + return self._replication_manager + + def set_replication_manager(self, manager): + """ + Sets the ReplicationManager instance. + + This is typically called by the ML Flashpoint wrapper to inject the managers + because the callback is instantiated by the user prior to wrapper initialization. + """ + self._replication_manager = manager + def _validate(self): """Ensures this instance passes validity checks and expectations. Expected to be used by __init__. @@ -159,13 +173,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - _LOGGER.info("Training ended. Synchronizing and finalizing checkpoints...") # 1. Wait for async checkpoint saves to finish locally - checkpoint_io = getattr(trainer.strategy, "checkpoint_io", None) - if hasattr(checkpoint_io, "maybe_finalize_save_checkpoint"): - checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) + trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) # 2. 
Synchronize all ranks to ensure background writes are done everywhere before deletion - if hasattr(trainer.strategy, "barrier"): - trainer.strategy.barrier("mlf_cleanup_barrier") + trainer.strategy.barrier("mlf_cleanup_barrier") if self.replication_manager is not None: _LOGGER.info("Training ended. Shutting down Replication Manager...") diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index 564c70a..9cb84a4 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -206,7 +206,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( return for cb in mlf_callbacks: - cb.replication_manager = replication_manager + cb.set_replication_manager(replication_manager) if not isinstance(trainer.strategy, nl_strategies.MegatronStrategy): raise ValueError( diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index 3b0d4f1..12f2c59 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -365,7 +365,7 @@ def test_on_train_end_cleans_up_on_rank_zero(mocker): base_container = CheckpointContainerId("/test/base") callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) - callback.replication_manager = mocker.MagicMock() + callback.set_replication_manager(mocker.MagicMock()) # When callback.on_train_end(trainer, pl_module) @@ -386,7 +386,7 @@ def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): base_container = CheckpointContainerId("/test/base") callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) - callback.replication_manager = mocker.MagicMock() + callback.set_replication_manager(mocker.MagicMock()) # When callback.on_train_end(trainer, pl_module) diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 
487fe0a..5b5219b 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -905,8 +905,8 @@ def test_replication_manager_injected_into_callbacks(self, mocker, mock_ckpt_obj ) # Then - assert mock_mlf_callback1.replication_manager is mock_replication_manager - assert mock_mlf_callback2.replication_manager is mock_replication_manager + mock_mlf_callback1.set_replication_manager.assert_called_once_with(mock_replication_manager) + mock_mlf_callback2.set_replication_manager.assert_called_once_with(mock_replication_manager) def test_invalid_config_with_mlf_async_wrapper_and_async_save_false( self, mocker, mock_ckpt_obj_manager, mock_replication_manager From 66786fb835ba0be34fe916d12905376c8d135e9f Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 6 Apr 2026 17:13:10 +0000 Subject: [PATCH 4/9] resolve comments --- .../adapter/nemo/checkpoint_callback.py | 3 +- .../adapter/nemo/wrapper_util.py | 2 +- .../adapter/nemo/test_checkpoint_callback.py | 31 +++++++++++++++++-- tests/adapter/nemo/test_wrapper_util.py | 4 +-- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py index c89b69b..6435c7b 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py @@ -81,7 +81,8 @@ def replication_manager(self): """Returns the ReplicationManager instance if one has been set.""" return self._replication_manager - def set_replication_manager(self, manager): + @replication_manager.setter + def replication_manager(self, manager): """ Sets the ReplicationManager instance. 
diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index 9cb84a4..564c70a 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -206,7 +206,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( return for cb in mlf_callbacks: - cb.set_replication_manager(replication_manager) + cb.replication_manager = replication_manager if not isinstance(trainer.strategy, nl_strategies.MegatronStrategy): raise ValueError( diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index 12f2c59..724061d 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -365,12 +365,14 @@ def test_on_train_end_cleans_up_on_rank_zero(mocker): base_container = CheckpointContainerId("/test/base") callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) - callback.set_replication_manager(mocker.MagicMock()) + callback.replication_manager = mocker.MagicMock() # When callback.on_train_end(trainer, pl_module) # Then + checkpoint_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True) + trainer.strategy.barrier.assert_called_once_with("mlf_cleanup_barrier") callback.replication_manager.shutdown.assert_called_once() checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) @@ -386,7 +388,7 @@ def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): base_container = CheckpointContainerId("/test/base") callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) - callback.set_replication_manager(mocker.MagicMock()) + callback.replication_manager = mocker.MagicMock() # When callback.on_train_end(trainer, pl_module) @@ -416,3 +418,28 @@ def test_on_train_end_no_replication_manager_skips_shutdown(mocker): # replication_manager doesn't crash since it's None 
and checked. # still cleans up on rank 0 checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) + + +def test_on_train_end_is_idempotent(mocker): + """Tests that calling on_train_end twice is safe.""" + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + base_container = CheckpointContainerId("/test/base") + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + callback.on_train_end(trainer, pl_module) + + # Then + assert callback.replication_manager.shutdown.call_count == 2 + assert checkpoint_io.remove_checkpoint.call_count == 2 + assert checkpoint_io.maybe_finalize_save_checkpoint.call_count == 2 + assert trainer.strategy.barrier.call_count == 2 diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 5b5219b..1a16b17 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -905,8 +905,8 @@ def test_replication_manager_injected_into_callbacks(self, mocker, mock_ckpt_obj ) # Then - mock_mlf_callback1.set_replication_manager.assert_called_once_with(mock_replication_manager) - mock_mlf_callback2.set_replication_manager.assert_called_once_with(mock_replication_manager) + assert mock_mlf_callback1.replication_manager == mock_replication_manager + assert mock_mlf_callback2.replication_manager == mock_replication_manager def test_invalid_config_with_mlf_async_wrapper_and_async_save_false( self, mocker, mock_ckpt_obj_manager, mock_replication_manager From 9afab16566b914bbc8b3516d26709c6d6b479d5d Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 6 Apr 2026 17:18:30 +0000 Subject: [PATCH 5/9] resolve comments --- 
tests/adapter/nemo/test_checkpoint_callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index 724061d..288fdf5 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -394,6 +394,7 @@ def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): callback.on_train_end(trainer, pl_module) # Then + checkpoint_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True) callback.replication_manager.shutdown.assert_called_once() checkpoint_io.remove_checkpoint.assert_not_called() @@ -409,7 +410,7 @@ def test_on_train_end_no_replication_manager_skips_shutdown(mocker): base_container = CheckpointContainerId("/test/base") callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) - # self.replication_manager is inherently None initialized in __init__ + assert callback.replication_manager is None, "replication_manager is expected to be None initially" # When callback.on_train_end(trainer, pl_module) From cb755a1f5bc85248b67c427b95263317ffe7acd4 Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Mon, 6 Apr 2026 17:28:42 +0000 Subject: [PATCH 6/9] resolve comments --- .../adapter/nemo/test_checkpoint_callback.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index 288fdf5..3fb8ef2 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import shutil + import lightning.pytorch as pl import pytest @@ -354,16 +356,25 @@ def test_on_train_batch_end_when_enabled(mocker): trainer.save_checkpoint.assert_called_once() -def test_on_train_end_cleans_up_on_rank_zero(mocker): +def test_on_train_end_cleans_up_on_rank_zero(mocker, tmp_path): # Given trainer = mocker.MagicMock(spec=pl.Trainer) trainer.local_rank = 0 checkpoint_io = mocker.MagicMock() trainer.strategy.checkpoint_io = checkpoint_io + # Make remove_checkpoint actually delete the directory + checkpoint_io.remove_checkpoint.side_effect = lambda path: shutil.rmtree(path) + pl_module = mocker.MagicMock(spec=pl.LightningModule) - base_container = CheckpointContainerId("/test/base") + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) callback.replication_manager = mocker.MagicMock() @@ -376,8 +387,11 @@ def test_on_train_end_cleans_up_on_rank_zero(mocker): callback.replication_manager.shutdown.assert_called_once() checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) + # Verify file deletion + assert not base_container_path.exists(), "Base container directory should have been deleted" -def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): + +def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker, tmp_path): # Given trainer = mocker.MagicMock(spec=pl.Trainer) trainer.local_rank = 1 @@ -386,7 +400,13 @@ def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): pl_module = mocker.MagicMock(spec=pl.LightningModule) - base_container = CheckpointContainerId("/test/base") + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + 
base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) callback.replication_manager = mocker.MagicMock() @@ -398,6 +418,10 @@ def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker): callback.replication_manager.shutdown.assert_called_once() checkpoint_io.remove_checkpoint.assert_not_called() + # Verify file retention + assert base_container_path.exists(), "Base container directory should NOT have been deleted" + assert dummy_file.exists(), "Dummy file should NOT have been deleted" + def test_on_train_end_no_replication_manager_skips_shutdown(mocker): # Given From 3b11591ab7bd3d61940b4e855c75141fe12887a3 Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Tue, 7 Apr 2026 13:47:41 +0000 Subject: [PATCH 7/9] Need to shutdown all connections when shutdown transfer service --- .../checkpoint_object_manager.py | 9 +++- .../object_manager/object_manager.cpp | 30 +++++++++++++- .../transfer_service/connection_pool.cpp | 16 ++++++++ .../transfer_service/connection_pool.h | 4 +- .../transfer_service/transfer_service.cpp | 14 +++---- .../test_checkpoint_object_manager.py | 35 +++++++++++++++- .../transfer_service/mlf_log_sink_test.cpp | 6 --- .../transfer_service_p2p_test.cpp | 41 +++++++++++++++++++ 8 files changed, 135 insertions(+), 20 deletions(-) diff --git a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py index 6f6008e..8dd03d8 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py +++ b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py @@ -239,10 +239,17 @@ def delete_container(self, container_id: CheckpointContainerId) -> None: """ container_id = str(container_id) _LOGGER.info("Starting 
deletion process for container: '%s'", container_id) + + def _onerror(func, path, exc_info): + exc = exc_info[1] + if isinstance(exc, FileNotFoundError): + return + raise exc + try: if os.path.isdir(container_id): # Use shutil.rmtree for recursive deletion. - shutil.rmtree(container_id) + shutil.rmtree(container_id, onerror=_onerror) _LOGGER.info("Successfully deleted container directory: '%s'", container_id) else: # This is not an error; the directory might have already been deleted. diff --git a/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp b/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp index 29dd3c7..f92f9a8 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp +++ b/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp @@ -14,6 +14,10 @@ #include "object_manager.h" +#include +#include +#include + #include #include #include @@ -27,13 +31,35 @@ namespace ml_flashpoint::checkpoint_object_manager::object_manager { namespace fs = std::filesystem; namespace { +// We use a fork/exec approach calling 'rm -rf' here instead of +// std::filesystem::remove_all to address a Segmentation Fault +// observed in multi-threaded environments. This should be safer +// and avoids the crash experienced with std::filesystem operations. 
+// // The actual deletion logic void delete_directories_task(const std::vector& directories) { for (const std::string& dir_path : directories) { try { if (fs::is_directory(dir_path)) { - LOG(INFO) << "Removing directory " << dir_path << " ..."; - fs::remove_all(dir_path); + LOG(INFO) << "Removing directory " << dir_path << " via fork/exec..."; + pid_t pid = fork(); + if (pid == 0) { + // Child process + execlp("rm", "rm", "-rf", dir_path.c_str(), (char*)NULL); + // If execlp returns, it failed + std::cerr << "Failed to exec rm -rf for " << dir_path << std::endl; + exit(1); + } else if (pid > 0) { + // Parent process + int status; + waitpid(pid, &status, 0); + if (status != 0) { + LOG(ERROR) << "rm -rf failed for " << dir_path << " with status " + << status; + } + } else { + LOG(ERROR) << "Failed to fork for deleting " << dir_path; + } } } catch (const fs::filesystem_error& e) { // It's important to handle errors inside the thread, diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp index e205d33..2eb6e1f 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp @@ -110,6 +110,20 @@ void ConnectionPool::Shutdown() { std::unique_lock lock(mtx_); stopping_ = true; cv_.notify_all(); + + for (int fd : active_connections_) { + LOG(INFO) << "Force shutting down active connection " << fd; + shutdown(fd, SHUT_RDWR); + } + active_connections_.clear(); + + while (!available_connections_.empty()) { + int fd = available_connections_.front(); + available_connections_.pop(); + LOG(INFO) << "Shutting down available connection " << fd; + shutdown(fd, SHUT_RDWR); + close(fd); + } } int ConnectionPool::CreateConnection() { @@ -174,6 +188,7 @@ std::optional ConnectionPool::GetConnection(int timeout_ms) { } int fd = available_connections_.front(); available_connections_.pop(); + 
active_connections_.insert(fd); return ScopedConnection(fd, this); } @@ -187,6 +202,7 @@ void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) { return; } std::unique_lock lock(mtx_); + active_connections_.erase(sockfd); if (stopping_) { LOG(WARNING) << "ConnectionPool::ReleaseConnection: stopping, close connection"; diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.h b/src/ml_flashpoint/replication/transfer_service/connection_pool.h index 229cac5..52c0f9e 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.h +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.h @@ -37,6 +37,7 @@ #include #include #include +#include namespace ml_flashpoint::replication::transfer_service { @@ -116,7 +117,8 @@ class ConnectionPool { std::string peer_host_; int peer_port_; size_t max_size_; - std::queue available_connections_; // Guarded by mtx_. + std::queue available_connections_; // Guarded by mtx_. + std::unordered_set active_connections_; // Guarded by mtx_. std::mutex mtx_; // Protects available_connections_ and stopping_. std::condition_variable cv_; // Signaled when a connection is released or the pool is stopping. diff --git a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp index ea7f735..9d0d7b6 100644 --- a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp +++ b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp @@ -211,14 +211,7 @@ void TransferService::Shutdown() { epoll_thread_.join(); } - // 6. Stop the thread pool. This will wait for all currently executing tasks - // to complete. Now that both the task queue and epoll threads are stopped, - // no new tasks can be enqueued. - if (thread_pool_) { - thread_pool_->stop(); - } - - // 7. Clean up connection pools. + // 6. Clean up connection pools. 
{ std::unique_lock write_lock(connection_pools_mutex_); for (auto const& [peer_addr, pool] : connection_pools_) { @@ -229,6 +222,11 @@ void TransferService::Shutdown() { connection_pools_.clear(); } + // 7. Stop the thread pool. + if (thread_pool_) { + thread_pool_->stop(); + } + // 8. Clean up epoll fd. if (epoll_fd_ != -1) { if (close(epoll_fd_) == -1) { diff --git a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py index 3d8508e..2e26731 100644 --- a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py +++ b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py @@ -577,7 +577,7 @@ def test_delete_container_success(self, manager_setup, mocker): if is_mock: # In mock mode, verify that shutil.rmtree was called. - mock_rmtree.assert_called_once_with(str(container_path)) + mock_rmtree.assert_called_once_with(str(container_path), onerror=mocker.ANY) else: # In real mode, verify the directory and its content are gone. assert not container_path.exists() @@ -599,7 +599,7 @@ def test_delete_container_with_no_content(self, manager_setup, mocker): manager.delete_container(container_path) if is_mock: - mock_rmtree.assert_called_once_with(container_path) + mock_rmtree.assert_called_once_with(container_path, onerror=mocker.ANY) else: assert not os.path.exists(str(container_path)) @@ -636,6 +636,37 @@ def test_delete_container_propagates_os_error_on_rmtree_failure(self, mocker): with pytest.raises(OSError, match="Permission denied"): manager.delete_container(fake_container_path) + def test_delete_container_ignores_file_not_found_on_rmtree(self, mocker): + """ + Unit Test: Verifies that delete_container ignores FileNotFoundError + during shutil.rmtree by passing a proper onerror handler. 
+ """ + manager = CheckpointObjectManager() + fake_container_path = CheckpointContainerId("/a/fake/path") + + mocker.patch("os.path.isdir", return_value=True) + mock_rmtree = mocker.patch("shutil.rmtree") + + manager.delete_container(fake_container_path) + + mock_rmtree.assert_called_once() + kwargs = mock_rmtree.call_args.kwargs + onerror_handler = kwargs.get("onerror") + assert onerror_handler is not None + + # Test the onerror handler + # Should not raise FileNotFoundError + try: + exc = FileNotFoundError("file missing") + onerror_handler(None, None, (type(exc), exc, None)) + except FileNotFoundError: + pytest.fail("onerror raised FileNotFoundError it should have ignored") + + # Should raise other exceptions + with pytest.raises(ValueError): + exc = ValueError("some other error") + onerror_handler(None, None, (type(exc), exc, None)) + def test_delete_container_on_file_path_does_nothing(self, real_buffer_manager): """ Tests that calling delete_container on a file path does not delete the diff --git a/tests/replication/transfer_service/mlf_log_sink_test.cpp b/tests/replication/transfer_service/mlf_log_sink_test.cpp index 0069ed3..4e2e5ac 100644 --- a/tests/replication/transfer_service/mlf_log_sink_test.cpp +++ b/tests/replication/transfer_service/mlf_log_sink_test.cpp @@ -16,11 +16,9 @@ #include -#include #include #include "absl/log/globals.h" -#include "absl/log/initialize.h" #include "absl/log/log.h" #include "absl/log/log_sink.h" #include "absl/log/log_sink_registry.h" @@ -31,10 +29,6 @@ namespace { class MLFLogSinkTest : public ::testing::Test { protected: - static void SetUpTestSuite() { - static std::once_flag flag; - std::call_once(flag, []() { absl::InitializeLog(); }); - } void SetUp() override { original_threshold_ = absl::StderrThreshold(); diff --git a/tests/replication/transfer_service/transfer_service_p2p_test.cpp b/tests/replication/transfer_service/transfer_service_p2p_test.cpp index 53ef622..33f5cd9 100644 --- 
a/tests/replication/transfer_service/transfer_service_p2p_test.cpp +++ b/tests/replication/transfer_service/transfer_service_p2p_test.cpp @@ -153,6 +153,47 @@ TEST(TransferServiceP2PTest, PutLargeObject) { service2.Shutdown(); } +TEST(TransferServiceP2PTest, ShutdownInterruptsTransfer) { + TransferService service1; + int port1 = service1.Initialize(); + ASSERT_GT(port1, 0); + + TransferService service2; + int port2 = service2.Initialize(); + ASSERT_GT(port2, 0); + + // Create a large string (100MB) to make sure it takes time to send + const size_t large_size = 100 * 1024 * 1024; + std::string large_data(large_size, 'A'); + + std::string obj_id = "my_interrupt_object"; + + auto put_future = + service1.AsyncPut((void*)large_data.c_str(), large_data.size(), + "127.0.0.1:" + std::to_string(port2), obj_id); + + // Wait a small amount of time to let the transfer start + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Trigger shutdown! + service1.Shutdown(); + + try { + auto put_result = put_future.get(); + EXPECT_FALSE(put_result.success); + LOG(INFO) << "Transfer failed as expected after shutdown."; + } catch (const std::runtime_error& e) { + EXPECT_THAT(e.what(), testing::HasSubstr("Service is shutting down")); + LOG(INFO) << "Transfer threw exception as expected after shutdown: " << e.what(); + } + + // Cleanup file if it was partially created + std::remove(obj_id.c_str()); + std::remove((obj_id + ".tmp").c_str()); + + service2.Shutdown(); +} + TEST(TransferServiceP2PTest, AsyncPutLargeMmapData) { TransferService service1; int port1 = service1.Initialize(); From 406937450a90b6324cfc03ed6bb98a5afa8c269a Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Tue, 7 Apr 2026 14:04:10 +0000 Subject: [PATCH 8/9] revert mlf_log_sink_test.cpp --- tests/replication/transfer_service/mlf_log_sink_test.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/replication/transfer_service/mlf_log_sink_test.cpp 
b/tests/replication/transfer_service/mlf_log_sink_test.cpp index 4e2e5ac..0069ed3 100644 --- a/tests/replication/transfer_service/mlf_log_sink_test.cpp +++ b/tests/replication/transfer_service/mlf_log_sink_test.cpp @@ -16,9 +16,11 @@ #include +#include #include #include "absl/log/globals.h" +#include "absl/log/initialize.h" #include "absl/log/log.h" #include "absl/log/log_sink.h" #include "absl/log/log_sink_registry.h" @@ -29,6 +31,10 @@ namespace { class MLFLogSinkTest : public ::testing::Test { protected: + static void SetUpTestSuite() { + static std::once_flag flag; + std::call_once(flag, []() { absl::InitializeLog(); }); + } void SetUp() override { original_threshold_ = absl::StderrThreshold(); From d6792bb11cce7a67c23c4dd4b22ac5d6d637d35e Mon Sep 17 00:00:00 2001 From: leahlijuan Date: Tue, 7 Apr 2026 14:11:23 +0000 Subject: [PATCH 9/9] resolve comment --- .../adapter/nemo/test_checkpoint_callback.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index 3fb8ef2..114aa41 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import shutil import lightning.pytorch as pl import pytest @@ -23,6 +22,8 @@ ML_FLASHPOINT_TYPE, MLFlashpointCheckpointCallback, ) +from ml_flashpoint.adapter.nemo.checkpoint_io import MLFlashpointCheckpointIO +from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId from ml_flashpoint.core.mlf_logging import _TRAINING_STEP @@ -360,12 +361,20 @@ def test_on_train_end_cleans_up_on_rank_zero(mocker, tmp_path): # Given trainer = mocker.MagicMock(spec=pl.Trainer) trainer.local_rank = 0 - checkpoint_io = mocker.MagicMock() + chkpt_obj_manager = CheckpointObjectManager() + + checkpoint_io = MLFlashpointCheckpointIO( + flashpoint_base_path=str(tmp_path / "ckpt_base"), + alt_checkpoint_io=mocker.MagicMock(), + chkpt_obj_manager=chkpt_obj_manager, + save_strategy=mocker.MagicMock(), + load_strategy=mocker.MagicMock(), + trainer=trainer, + ) + checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock() + mocker.spy(checkpoint_io, "remove_checkpoint") trainer.strategy.checkpoint_io = checkpoint_io - # Make remove_checkpoint actually delete the directory - checkpoint_io.remove_checkpoint.side_effect = lambda path: shutil.rmtree(path) - pl_module = mocker.MagicMock(spec=pl.LightningModule) # Create a base container directory and a dummy file inside it @@ -445,17 +454,34 @@ def test_on_train_end_no_replication_manager_skips_shutdown(mocker): checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) -def test_on_train_end_is_idempotent(mocker): +def test_on_train_end_is_idempotent(mocker, tmp_path): """Tests that calling on_train_end twice is safe.""" # Given trainer = mocker.MagicMock(spec=pl.Trainer) trainer.local_rank = 0 - checkpoint_io = mocker.MagicMock() + chkpt_obj_manager = CheckpointObjectManager() + + checkpoint_io = MLFlashpointCheckpointIO( + flashpoint_base_path=str(tmp_path / "ckpt_base"), + 
alt_checkpoint_io=mocker.MagicMock(), + chkpt_obj_manager=chkpt_obj_manager, + save_strategy=mocker.MagicMock(), + load_strategy=mocker.MagicMock(), + trainer=trainer, + ) + checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock() + mocker.spy(checkpoint_io, "remove_checkpoint") trainer.strategy.checkpoint_io = checkpoint_io pl_module = mocker.MagicMock(spec=pl.LightningModule) - base_container = CheckpointContainerId("/test/base") + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) callback.replication_manager = mocker.MagicMock() @@ -468,3 +494,6 @@ def test_on_train_end_is_idempotent(mocker): assert checkpoint_io.remove_checkpoint.call_count == 2 assert checkpoint_io.maybe_finalize_save_checkpoint.call_count == 2 assert trainer.strategy.barrier.call_count == 2 + + # Verify file deletion + assert not base_container_path.exists(), "Base container directory should have been deleted"