diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py index 0e45470..6435c7b 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_callback.py @@ -73,8 +73,24 @@ def __init__( self.every_n_steps = every_n_steps self.skip_every_n_steps = skip_every_n_steps if skip_every_n_steps is not None else 0 self._enabled = enabled + self._replication_manager = None self._validate() + @property + def replication_manager(self): + """Returns the ReplicationManager instance if one has been set.""" + return self._replication_manager + + @replication_manager.setter + def replication_manager(self, manager): + """ + Sets the ReplicationManager instance. + + This is typically called by the ML Flashpoint wrapper to inject the managers + because the callback is instantiated by the user prior to wrapper initialization. + """ + self._replication_manager = manager + def _validate(self): """Ensures this instance passes validity checks and expectations. Expected to be used by __init__. @@ -151,3 +167,22 @@ def on_train_batch_end( ckpt_options, ) trainer.save_checkpoint(ckpt_version_container.data, storage_options={ML_FLASHPOINT_OPTS_KEY: ckpt_options}) + + @override + @log_execution_time(logger=_LOGGER, name="on_train_end", level=logging.INFO) + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + _LOGGER.info("Training ended. Synchronizing and finalizing checkpoints...") + + # 1. Wait for async checkpoint saves to finish locally + trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) + + # 2. Synchronize all ranks to ensure background writes are done everywhere before deletion + trainer.strategy.barrier("mlf_cleanup_barrier") + + if self.replication_manager is not None: + _LOGGER.info("Training ended. Shutting down Replication Manager...") + self.replication_manager.shutdown() + + if trainer.local_rank == 0: + _LOGGER.info("Local rank 0: Performing final checkpoint cleanup...") + trainer.strategy.checkpoint_io.remove_checkpoint(self.base_container.data) diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index f2284d8..564c70a 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -201,10 +201,13 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( raise ValueError(f"initial_write_buffer_size_bytes must be > 0, got {initial_write_buffer_size_bytes}.") callbacks = trainer.callbacks - mlflashpoint_enabled = any(isinstance(cb, MLFlashpointCheckpointCallback) for cb in callbacks) - if not mlflashpoint_enabled: + mlf_callbacks = [cb for cb in callbacks if isinstance(cb, MLFlashpointCheckpointCallback)] + if not mlf_callbacks: return + for cb in mlf_callbacks: + cb.replication_manager = replication_manager + if not isinstance(trainer.strategy, nl_strategies.MegatronStrategy): raise ValueError( "Only MegatronStrategy is supported for ML Flashpoint, but got " diff --git a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py index 6f6008e..8dd03d8 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py +++ b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py @@ -239,10 +239,17 @@ def delete_container(self, container_id: CheckpointContainerId) -> None: """ container_id = str(container_id) _LOGGER.info("Starting deletion process for container: '%s'", container_id) + + def _onerror(func, path, exc_info): + exc = exc_info[1] + if isinstance(exc, FileNotFoundError): + return + raise exc + try: if os.path.isdir(container_id): # Use shutil.rmtree for recursive deletion. 
- shutil.rmtree(container_id) + shutil.rmtree(container_id, onerror=_onerror) _LOGGER.info("Successfully deleted container directory: '%s'", container_id) else: # This is not an error; the directory might have already been deleted. diff --git a/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp b/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp index 29dd3c7..f92f9a8 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp +++ b/src/ml_flashpoint/checkpoint_object_manager/object_manager/object_manager.cpp @@ -14,6 +14,10 @@ #include "object_manager.h" +#include +#include +#include + #include #include #include @@ -27,13 +31,35 @@ namespace ml_flashpoint::checkpoint_object_manager::object_manager { namespace fs = std::filesystem; namespace { +// We use a fork/exec approach calling 'rm -rf' here instead of +// std::filesystem::remove_all to address a Segmentation Fault +// observed in multi-threaded environments. This should be safer +// and avoids the crash experienced with std::filesystem operations. 
+// // The actual deletion logic void delete_directories_task(const std::vector& directories) { for (const std::string& dir_path : directories) { try { if (fs::is_directory(dir_path)) { - LOG(INFO) << "Removing directory " << dir_path << " ..."; - fs::remove_all(dir_path); + LOG(INFO) << "Removing directory " << dir_path << " via fork/exec..."; + pid_t pid = fork(); + if (pid == 0) { + // Child process + execlp("rm", "rm", "-rf", dir_path.c_str(), (char*)NULL); + // If execlp returns, it failed + std::cerr << "Failed to exec rm -rf for " << dir_path << std::endl; + exit(1); + } else if (pid > 0) { + // Parent process + int status; + waitpid(pid, &status, 0); + if (status != 0) { + LOG(ERROR) << "rm -rf failed for " << dir_path << " with status " + << status; + } + } else { + LOG(ERROR) << "Failed to fork for deleting " << dir_path; + } } } catch (const fs::filesystem_error& e) { // It's important to handle errors inside the thread, diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp index e205d33..2eb6e1f 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp @@ -110,6 +110,20 @@ void ConnectionPool::Shutdown() { std::unique_lock lock(mtx_); stopping_ = true; cv_.notify_all(); + + for (int fd : active_connections_) { + LOG(INFO) << "Force shutting down active connection " << fd; + shutdown(fd, SHUT_RDWR); + } + active_connections_.clear(); + + while (!available_connections_.empty()) { + int fd = available_connections_.front(); + available_connections_.pop(); + LOG(INFO) << "Shutting down available connection " << fd; + shutdown(fd, SHUT_RDWR); + close(fd); + } } int ConnectionPool::CreateConnection() { @@ -174,6 +188,7 @@ std::optional ConnectionPool::GetConnection(int timeout_ms) { } int fd = available_connections_.front(); available_connections_.pop(); + 
active_connections_.insert(fd); return ScopedConnection(fd, this); } @@ -187,6 +202,7 @@ void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) { return; } std::unique_lock lock(mtx_); + active_connections_.erase(sockfd); if (stopping_) { LOG(WARNING) << "ConnectionPool::ReleaseConnection: stopping, close connection"; diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.h b/src/ml_flashpoint/replication/transfer_service/connection_pool.h index 229cac5..52c0f9e 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.h +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.h @@ -37,6 +37,7 @@ #include #include #include +#include namespace ml_flashpoint::replication::transfer_service { @@ -116,7 +117,8 @@ class ConnectionPool { std::string peer_host_; int peer_port_; size_t max_size_; - std::queue available_connections_; // Guarded by mtx_. + std::queue available_connections_; // Guarded by mtx_. + std::unordered_set active_connections_; // Guarded by mtx_. std::mutex mtx_; // Protects available_connections_ and stopping_. std::condition_variable cv_; // Signaled when a connection is released or the pool is stopping. diff --git a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp index ea7f735..9d0d7b6 100644 --- a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp +++ b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp @@ -211,14 +211,7 @@ void TransferService::Shutdown() { epoll_thread_.join(); } - // 6. Stop the thread pool. This will wait for all currently executing tasks - // to complete. Now that both the task queue and epoll threads are stopped, - // no new tasks can be enqueued. - if (thread_pool_) { - thread_pool_->stop(); - } - - // 7. Clean up connection pools. + // 6. Clean up connection pools. 
{ std::unique_lock write_lock(connection_pools_mutex_); for (auto const& [peer_addr, pool] : connection_pools_) { @@ -229,6 +222,11 @@ void TransferService::Shutdown() { connection_pools_.clear(); } + // 7. Stop the thread pool. + if (thread_pool_) { + thread_pool_->stop(); + } + // 8. Clean up epoll fd. if (epoll_fd_ != -1) { if (close(epoll_fd_) == -1) { diff --git a/tests/adapter/nemo/test_checkpoint_callback.py b/tests/adapter/nemo/test_checkpoint_callback.py index f564580..114aa41 100644 --- a/tests/adapter/nemo/test_checkpoint_callback.py +++ b/tests/adapter/nemo/test_checkpoint_callback.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import lightning.pytorch as pl import pytest @@ -21,6 +22,8 @@ ML_FLASHPOINT_TYPE, MLFlashpointCheckpointCallback, ) +from ml_flashpoint.adapter.nemo.checkpoint_io import MLFlashpointCheckpointIO +from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId from ml_flashpoint.core.mlf_logging import _TRAINING_STEP @@ -352,3 +355,145 @@ def test_on_train_batch_end_when_enabled(mocker): # Then # Should save trainer.save_checkpoint.assert_called_once() + + +def test_on_train_end_cleans_up_on_rank_zero(mocker, tmp_path): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + chkpt_obj_manager = CheckpointObjectManager() + + checkpoint_io = MLFlashpointCheckpointIO( + flashpoint_base_path=str(tmp_path / "ckpt_base"), + alt_checkpoint_io=mocker.MagicMock(), + chkpt_obj_manager=chkpt_obj_manager, + save_strategy=mocker.MagicMock(), + load_strategy=mocker.MagicMock(), + trainer=trainer, + ) + checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock() + mocker.spy(checkpoint_io, "remove_checkpoint") + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + + # Then + checkpoint_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True) + trainer.strategy.barrier.assert_called_once_with("mlf_cleanup_barrier") + callback.replication_manager.shutdown.assert_called_once() + checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) + + # Verify file deletion + assert not base_container_path.exists(), "Base container directory should have been deleted" + + +def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker, tmp_path): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 1 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + + # Then + checkpoint_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True) + callback.replication_manager.shutdown.assert_called_once() + checkpoint_io.remove_checkpoint.assert_not_called() + + # Verify file retention + assert base_container_path.exists(), "Base container directory should NOT have been deleted" + assert dummy_file.exists(), "Dummy file should NOT have been deleted" + + +def test_on_train_end_no_replication_manager_skips_shutdown(mocker): + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + checkpoint_io = mocker.MagicMock() + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + base_container = CheckpointContainerId("/test/base") + callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + assert callback.replication_manager is None, "replication_manager is expected to be None initially" + + # When + callback.on_train_end(trainer, pl_module) + + # Then + # replication_manager doesn't crash since it's None and checked. + # still cleans up on rank 0 + checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data) + + +def test_on_train_end_is_idempotent(mocker, tmp_path): + """Tests that calling on_train_end twice is safe.""" + # Given + trainer = mocker.MagicMock(spec=pl.Trainer) + trainer.local_rank = 0 + chkpt_obj_manager = CheckpointObjectManager() + + checkpoint_io = MLFlashpointCheckpointIO( + flashpoint_base_path=str(tmp_path / "ckpt_base"), + alt_checkpoint_io=mocker.MagicMock(), + chkpt_obj_manager=chkpt_obj_manager, + save_strategy=mocker.MagicMock(), + load_strategy=mocker.MagicMock(), + trainer=trainer, + ) + checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock() + mocker.spy(checkpoint_io, "remove_checkpoint") + trainer.strategy.checkpoint_io = checkpoint_io + + pl_module = mocker.MagicMock(spec=pl.LightningModule) + + # Create a base container directory and a dummy file inside it + base_container_path = tmp_path / "ckpt_base" + base_container_path.mkdir() + dummy_file = base_container_path / "dummy.txt" + dummy_file.write_text("dummy") + + base_container = CheckpointContainerId(str(base_container_path)) + 
callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1) + callback.replication_manager = mocker.MagicMock() + + # When + callback.on_train_end(trainer, pl_module) + callback.on_train_end(trainer, pl_module) + + # Then + assert callback.replication_manager.shutdown.call_count == 2 + assert checkpoint_io.remove_checkpoint.call_count == 2 + assert checkpoint_io.maybe_finalize_save_checkpoint.call_count == 2 + assert trainer.strategy.barrier.call_count == 2 + + # Verify file deletion + assert not base_container_path.exists(), "Base container directory should have been deleted" diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 2ad6794..1a16b17 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -878,6 +878,36 @@ def test_mlflashpoint_enabled_with_multiple_callbacks( assert isinstance(trainer.strategy.checkpoint_io, MLFlashpointCheckpointIO) assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io + def test_replication_manager_injected_into_callbacks(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): + """Tests that the ReplicationManager is injected into all MLFlashpointCheckpointCallback instances.""" + # Given + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) + mock_mlf_callback1 = mocker.MagicMock(spec=MLFlashpointCheckpointCallback) + mock_mlf_callback2 = mocker.MagicMock(spec=MLFlashpointCheckpointCallback) + trainer.callbacks = [ + mocker.MagicMock(), + mock_mlf_callback1, + mock_mlf_callback2, + ] + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) + trainer.strategy.checkpoint_io = original_checkpoint_io + base_container = "/test_base_container" + + # When + wrap_trainer_checkpoint_io_with_mlflashpoint( + trainer, + base_container, + mock_ckpt_obj_manager, + 
replication_manager=mock_replication_manager, + async_save=True, + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), + ) + + # Then + assert mock_mlf_callback1.replication_manager == mock_replication_manager + assert mock_mlf_callback2.replication_manager == mock_replication_manager + def test_invalid_config_with_mlf_async_wrapper_and_async_save_false( self, mocker, mock_ckpt_obj_manager, mock_replication_manager ): diff --git a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py index 3d8508e..2e26731 100644 --- a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py +++ b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py @@ -577,7 +577,7 @@ def test_delete_container_success(self, manager_setup, mocker): if is_mock: # In mock mode, verify that shutil.rmtree was called. - mock_rmtree.assert_called_once_with(str(container_path)) + mock_rmtree.assert_called_once_with(str(container_path), onerror=mocker.ANY) else: # In real mode, verify the directory and its content are gone. assert not container_path.exists() @@ -599,7 +599,7 @@ def test_delete_container_with_no_content(self, manager_setup, mocker): manager.delete_container(container_path) if is_mock: - mock_rmtree.assert_called_once_with(container_path) + mock_rmtree.assert_called_once_with(container_path, onerror=mocker.ANY) else: assert not os.path.exists(str(container_path)) @@ -636,6 +636,37 @@ def test_delete_container_propagates_os_error_on_rmtree_failure(self, mocker): with pytest.raises(OSError, match="Permission denied"): manager.delete_container(fake_container_path) + def test_delete_container_ignores_file_not_found_on_rmtree(self, mocker): + """ + Unit Test: Verifies that delete_container ignores FileNotFoundError + during shutil.rmtree by passing a proper onerror handler. 
+ """ + manager = CheckpointObjectManager() + fake_container_path = CheckpointContainerId("/a/fake/path") + + mocker.patch("os.path.isdir", return_value=True) + mock_rmtree = mocker.patch("shutil.rmtree") + + manager.delete_container(fake_container_path) + + mock_rmtree.assert_called_once() + kwargs = mock_rmtree.call_args.kwargs + onerror_handler = kwargs.get("onerror") + assert onerror_handler is not None + + # Test the onerror handler + # Should not raise FileNotFoundError + try: + exc = FileNotFoundError("file missing") + onerror_handler(None, None, (type(exc), exc, None)) + except FileNotFoundError: + pytest.fail("onerror raised FileNotFoundError it should have ignored") + + # Should raise other exceptions + with pytest.raises(ValueError): + exc = ValueError("some other error") + onerror_handler(None, None, (type(exc), exc, None)) + def test_delete_container_on_file_path_does_nothing(self, real_buffer_manager): """ Tests that calling delete_container on a file path does not delete the diff --git a/tests/replication/transfer_service/transfer_service_p2p_test.cpp b/tests/replication/transfer_service/transfer_service_p2p_test.cpp index 53ef622..33f5cd9 100644 --- a/tests/replication/transfer_service/transfer_service_p2p_test.cpp +++ b/tests/replication/transfer_service/transfer_service_p2p_test.cpp @@ -153,6 +153,47 @@ TEST(TransferServiceP2PTest, PutLargeObject) { service2.Shutdown(); } +TEST(TransferServiceP2PTest, ShutdownInterruptsTransfer) { + TransferService service1; + int port1 = service1.Initialize(); + ASSERT_GT(port1, 0); + + TransferService service2; + int port2 = service2.Initialize(); + ASSERT_GT(port2, 0); + + // Create a large string (100MB) to make sure it takes time to send + const size_t large_size = 100 * 1024 * 1024; + std::string large_data(large_size, 'A'); + + std::string obj_id = "my_interrupt_object"; + + auto put_future = + service1.AsyncPut((void*)large_data.c_str(), large_data.size(), + "127.0.0.1:" + std::to_string(port2), obj_id); + 
+ // Wait a small amount of time to let the transfer start + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Trigger shutdown! + service1.Shutdown(); + + try { + auto put_result = put_future.get(); + EXPECT_FALSE(put_result.success); + LOG(INFO) << "Transfer failed as expected after shutdown."; + } catch (const std::runtime_error& e) { + EXPECT_THAT(e.what(), testing::HasSubstr("Service is shutting down")); + LOG(INFO) << "Transfer threw exception as expected after shutdown: " << e.what(); + } + + // Cleanup file if it was partially created + std::remove(obj_id.c_str()); + std::remove((obj_id + ".tmp").c_str()); + + service2.Shutdown(); +} + TEST(TransferServiceP2PTest, AsyncPutLargeMmapData) { TransferService service1; int port1 = service1.Initialize();