Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/ml_flashpoint/adapter/nemo/checkpoint_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,24 @@ def __init__(
self.every_n_steps = every_n_steps
self.skip_every_n_steps = skip_every_n_steps if skip_every_n_steps is not None else 0
self._enabled = enabled
self._replication_manager = None
self._validate()

@property
def replication_manager(self):
    """Returns the ReplicationManager instance if one has been set.

    Returns None until a manager is injected via the setter below.
    """
    return self._replication_manager

@replication_manager.setter
def replication_manager(self, manager):
    """
    Sets the ReplicationManager instance.

    This is typically called by the ML Flashpoint wrapper to inject the managers
    because the callback is instantiated by the user prior to wrapper initialization.

    Args:
        manager: The ReplicationManager to use (or None to clear it).
    """
    self._replication_manager = manager

def _validate(self):
"""Ensures this instance passes validity checks and expectations. Expected to be used by __init__.

Expand Down Expand Up @@ -151,3 +167,22 @@ def on_train_batch_end(
ckpt_options,
)
trainer.save_checkpoint(ckpt_version_container.data, storage_options={ML_FLASHPOINT_OPTS_KEY: ckpt_options})

@override
@log_execution_time(logger=_LOGGER, name="on_train_end", level=logging.INFO)
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Finalizes pending checkpoint work at the end of training.

    The steps run in a deliberate order for a distributed job:
      1. Block until this rank's async checkpoint saves are finalized.
      2. Barrier across ranks so no rank starts deleting while another
         rank's background writes are still in flight.
      3. Shut down the injected ReplicationManager, if one was set.
      4. On local rank 0 only, remove the base checkpoint container
         (cleanup is node-local, hence local_rank rather than global_rank).

    Args:
        trainer: The Lightning trainer driving this run.
        pl_module: The LightningModule being trained (unused here).
    """
    _LOGGER.info("Training ended. Synchronizing and finalizing checkpoints...")

    # 1. Wait for async checkpoint saves to finish locally
    trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

    # 2. Synchronize all ranks to ensure background writes are done everywhere before deletion
    trainer.strategy.barrier("mlf_cleanup_barrier")

    # Shutdown is skipped when no manager was injected (e.g. replication disabled).
    if self.replication_manager is not None:
        _LOGGER.info("Training ended. Shutting down Replication Manager...")
        self.replication_manager.shutdown()

    # NOTE(review): self.base_container presumably holds the
    # checkpoint_base_container passed to __init__ — confirm against the
    # constructor (not fully visible here).
    if trainer.local_rank == 0:
        _LOGGER.info("Local rank 0: Performing final checkpoint cleanup...")
        trainer.strategy.checkpoint_io.remove_checkpoint(self.base_container.data)
7 changes: 5 additions & 2 deletions src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,13 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
raise ValueError(f"initial_write_buffer_size_bytes must be > 0, got {initial_write_buffer_size_bytes}.")

callbacks = trainer.callbacks
mlflashpoint_enabled = any(isinstance(cb, MLFlashpointCheckpointCallback) for cb in callbacks)
if not mlflashpoint_enabled:
mlf_callbacks = [cb for cb in callbacks if isinstance(cb, MLFlashpointCheckpointCallback)]
if not mlf_callbacks:
return

for cb in mlf_callbacks:
cb.replication_manager = replication_manager

if not isinstance(trainer.strategy, nl_strategies.MegatronStrategy):
raise ValueError(
"Only MegatronStrategy is supported for ML Flashpoint, but got "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,17 @@ def delete_container(self, container_id: CheckpointContainerId) -> None:
"""
container_id = str(container_id)
_LOGGER.info("Starting deletion process for container: '%s'", container_id)

def _onerror(func, path, exc_info):
exc = exc_info[1]
if isinstance(exc, FileNotFoundError):
return
raise exc

try:
if os.path.isdir(container_id):
# Use shutil.rmtree for recursive deletion.
shutil.rmtree(container_id)
shutil.rmtree(container_id, onerror=_onerror)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A note: this delete_container function is synchronous and doesn't use the async C++ implementation for deleting directories, so we should be careful about when we use each. We might want to make this call that one, and allow for blocking, for consistency.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we may call delete_container before everything is fully finished (save + replication), the transfer service might make changes to the directory at the same time (in the receiver, we save to a tmp file first and then rename it), so a file-not-found error can occur.

_LOGGER.info("Successfully deleted container directory: '%s'", container_id)
else:
# This is not an error; the directory might have already been deleted.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

#include "object_manager.h"

#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <filesystem>
#include <future>
#include <iostream>
Expand All @@ -27,13 +31,35 @@ namespace ml_flashpoint::checkpoint_object_manager::object_manager {
namespace fs = std::filesystem;

namespace {
// We use a fork/exec approach calling 'rm -rf' here instead of
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, is this the only way to get around that? do we know why there was a seg fault?

// std::filesystem::remove_all to address a Segmentation Fault
// observed in multi-threaded environments. This should be safer
// and avoids the crash experienced with std::filesystem operations.
//
// The actual deletion logic
void delete_directories_task(const std::vector<std::string>& directories) {
for (const std::string& dir_path : directories) {
try {
if (fs::is_directory(dir_path)) {
LOG(INFO) << "Removing directory " << dir_path << " ...";
fs::remove_all(dir_path);
LOG(INFO) << "Removing directory " << dir_path << " via fork/exec...";
pid_t pid = fork();
if (pid == 0) {
// Child process
execlp("rm", "rm", "-rf", dir_path.c_str(), (char*)NULL);
// If execlp returns, it failed
std::cerr << "Failed to exec rm -rf for " << dir_path << std::endl;
exit(1);
} else if (pid > 0) {
// Parent process
int status;
waitpid(pid, &status, 0);
if (status != 0) {
LOG(ERROR) << "rm -rf failed for " << dir_path << " with status "
<< status;
}
} else {
LOG(ERROR) << "Failed to fork for deleting " << dir_path;
}
}
} catch (const fs::filesystem_error& e) {
// It's important to handle errors inside the thread,
Expand Down
16 changes: 16 additions & 0 deletions src/ml_flashpoint/replication/transfer_service/connection_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,20 @@ void ConnectionPool::Shutdown() {
std::unique_lock<std::mutex> lock(mtx_);
stopping_ = true;
cv_.notify_all();

for (int fd : active_connections_) {
LOG(INFO) << "Force shutting down active connection " << fd;
shutdown(fd, SHUT_RDWR);
}
active_connections_.clear();

while (!available_connections_.empty()) {
int fd = available_connections_.front();
available_connections_.pop();
LOG(INFO) << "Shutting down available connection " << fd;
shutdown(fd, SHUT_RDWR);
close(fd);
}
}

int ConnectionPool::CreateConnection() {
Expand Down Expand Up @@ -174,6 +188,7 @@ std::optional<ScopedConnection> ConnectionPool::GetConnection(int timeout_ms) {
}
int fd = available_connections_.front();
available_connections_.pop();
active_connections_.insert(fd);
return ScopedConnection(fd, this);
}

Expand All @@ -187,6 +202,7 @@ void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) {
return;
}
std::unique_lock<std::mutex> lock(mtx_);
active_connections_.erase(sockfd);
if (stopping_) {
LOG(WARNING)
<< "ConnectionPool::ReleaseConnection: stopping, close connection";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <optional>
#include <queue>
#include <string>
#include <unordered_set>

namespace ml_flashpoint::replication::transfer_service {

Expand Down Expand Up @@ -116,7 +117,8 @@ class ConnectionPool {
std::string peer_host_;
int peer_port_;
size_t max_size_;
std::queue<int> available_connections_; // Guarded by mtx_.
std::queue<int> available_connections_; // Guarded by mtx_.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain the rationale for using a queue here and unordered set below? specifically why FIFO order is relevant for available_connections.

also are these two collections mutually exclusive?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they should be mutually exclusive. For available_connections_, there is ideally no difference between using a queue or a stack — it's just a way to track the connections and quickly get one to use.

We are adding active_connections_ because we want to destroy all of the live connections during shutdown.

std::unordered_set<int> active_connections_; // Guarded by mtx_.
std::mutex mtx_; // Protects available_connections_ and stopping_.
std::condition_variable
cv_; // Signaled when a connection is released or the pool is stopping.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,7 @@ void TransferService::Shutdown() {
epoll_thread_.join();
}

// 6. Stop the thread pool. This will wait for all currently executing tasks
// to complete. Now that both the task queue and epoll threads are stopped,
// no new tasks can be enqueued.
if (thread_pool_) {
thread_pool_->stop();
}

// 7. Clean up connection pools.
// 6. Clean up connection pools.
{
std::unique_lock write_lock(connection_pools_mutex_);
for (auto const& [peer_addr, pool] : connection_pools_) {
Expand All @@ -229,6 +222,11 @@ void TransferService::Shutdown() {
connection_pools_.clear();
}

// 7. Stop the thread pool.
if (thread_pool_) {
thread_pool_->stop();
}

// 8. Clean up epoll fd.
if (epoll_fd_ != -1) {
if (close(epoll_fd_) == -1) {
Expand Down
145 changes: 145 additions & 0 deletions tests/adapter/nemo/test_checkpoint_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import lightning.pytorch as pl
import pytest

Expand All @@ -21,6 +22,8 @@
ML_FLASHPOINT_TYPE,
MLFlashpointCheckpointCallback,
)
from ml_flashpoint.adapter.nemo.checkpoint_io import MLFlashpointCheckpointIO
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
from ml_flashpoint.core.mlf_logging import _TRAINING_STEP

Expand Down Expand Up @@ -352,3 +355,145 @@ def test_on_train_batch_end_when_enabled(mocker):
# Then
# Should save
trainer.save_checkpoint.assert_called_once()


def test_on_train_end_cleans_up_on_rank_zero(mocker, tmp_path):
    """Verifies the full rank-0 shutdown path of on_train_end.

    Uses a real MLFlashpointCheckpointIO over tmp_path (not a mock) so the
    final remove_checkpoint call actually deletes the container directory,
    which the last assertion checks on disk.
    """
    # Given
    trainer = mocker.MagicMock(spec=pl.Trainer)
    trainer.local_rank = 0  # rank 0 is the only rank that performs cleanup
    chkpt_obj_manager = CheckpointObjectManager()

    checkpoint_io = MLFlashpointCheckpointIO(
        flashpoint_base_path=str(tmp_path / "ckpt_base"),
        alt_checkpoint_io=mocker.MagicMock(),
        chkpt_obj_manager=chkpt_obj_manager,
        save_strategy=mocker.MagicMock(),
        load_strategy=mocker.MagicMock(),
        trainer=trainer,
    )
    # Stub out async finalization entirely, but only *spy* on
    # remove_checkpoint so the real deletion still happens while the
    # call is recorded.
    checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock()
    mocker.spy(checkpoint_io, "remove_checkpoint")
    trainer.strategy.checkpoint_io = checkpoint_io

    pl_module = mocker.MagicMock(spec=pl.LightningModule)

    # Create a base container directory and a dummy file inside it
    base_container_path = tmp_path / "ckpt_base"
    base_container_path.mkdir()
    dummy_file = base_container_path / "dummy.txt"
    dummy_file.write_text("dummy")

    base_container = CheckpointContainerId(str(base_container_path))
    callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1)
    callback.replication_manager = mocker.MagicMock()

    # When
    callback.on_train_end(trainer, pl_module)

    # Then: every shutdown step ran exactly once, in observable form.
    checkpoint_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True)
    trainer.strategy.barrier.assert_called_once_with("mlf_cleanup_barrier")
    callback.replication_manager.shutdown.assert_called_once()
    checkpoint_io.remove_checkpoint.assert_called_once_with(base_container.data)

    # Verify file deletion
    assert not base_container_path.exists(), "Base container directory should have been deleted"


def test_on_train_end_skips_cleanup_on_non_zero_rank(mocker, tmp_path):
    """A non-zero local rank finalizes saves and shuts down replication, but never deletes files."""
    # Given: a trainer mock reporting local rank 1 (i.e. not the cleanup rank).
    mock_io = mocker.MagicMock()
    mock_trainer = mocker.MagicMock(spec=pl.Trainer)
    mock_trainer.local_rank = 1
    mock_trainer.strategy.checkpoint_io = mock_io
    module = mocker.MagicMock(spec=pl.LightningModule)

    # A real on-disk container holding one file, to prove nothing is removed.
    container_dir = tmp_path / "ckpt_base"
    container_dir.mkdir()
    marker_file = container_dir / "dummy.txt"
    marker_file.write_text("dummy")

    cb = MLFlashpointCheckpointCallback(
        checkpoint_base_container=CheckpointContainerId(str(container_dir)),
        every_n_steps=1,
    )
    cb.replication_manager = mocker.MagicMock()

    # When
    cb.on_train_end(mock_trainer, module)

    # Then: saves were finalized and replication was shut down...
    mock_io.maybe_finalize_save_checkpoint.assert_called_once_with(blocking=True)
    cb.replication_manager.shutdown.assert_called_once()
    # ...but no deletion happened on this rank.
    mock_io.remove_checkpoint.assert_not_called()
    assert container_dir.exists(), "Base container directory should NOT have been deleted"
    assert marker_file.exists(), "Dummy file should NOT have been deleted"


def test_on_train_end_no_replication_manager_skips_shutdown(mocker):
    """on_train_end must tolerate the case where no ReplicationManager was ever injected."""
    # Given: a rank-0 trainer whose checkpoint IO is fully mocked.
    io_mock = mocker.MagicMock()
    rank0_trainer = mocker.MagicMock(spec=pl.Trainer)
    rank0_trainer.local_rank = 0
    rank0_trainer.strategy.checkpoint_io = io_mock
    lightning_module = mocker.MagicMock(spec=pl.LightningModule)

    container = CheckpointContainerId("/test/base")
    cb = MLFlashpointCheckpointCallback(checkpoint_base_container=container, every_n_steps=1)
    assert cb.replication_manager is None, "replication_manager is expected to be None initially"

    # When
    cb.on_train_end(rank0_trainer, lightning_module)

    # Then
    # replication_manager doesn't crash since it's None and checked.
    # still cleans up on rank 0
    io_mock.remove_checkpoint.assert_called_once_with(container.data)


def test_on_train_end_is_idempotent(mocker, tmp_path):
    """Tests that calling on_train_end twice is safe.

    The second call re-runs every shutdown step (finalize, barrier,
    replication shutdown, removal of an already-deleted container)
    without raising, so each step must tolerate already-cleaned-up state.
    """
    # Given
    trainer = mocker.MagicMock(spec=pl.Trainer)
    trainer.local_rank = 0  # rank 0 performs the cleanup on both calls
    chkpt_obj_manager = CheckpointObjectManager()

    checkpoint_io = MLFlashpointCheckpointIO(
        flashpoint_base_path=str(tmp_path / "ckpt_base"),
        alt_checkpoint_io=mocker.MagicMock(),
        chkpt_obj_manager=chkpt_obj_manager,
        save_strategy=mocker.MagicMock(),
        load_strategy=mocker.MagicMock(),
        trainer=trainer,
    )
    # Stub async finalization; spy (not stub) remove_checkpoint so real
    # deletion still occurs while calls are counted.
    checkpoint_io.maybe_finalize_save_checkpoint = mocker.MagicMock()
    mocker.spy(checkpoint_io, "remove_checkpoint")
    trainer.strategy.checkpoint_io = checkpoint_io

    pl_module = mocker.MagicMock(spec=pl.LightningModule)

    # Create a base container directory and a dummy file inside it
    base_container_path = tmp_path / "ckpt_base"
    base_container_path.mkdir()
    dummy_file = base_container_path / "dummy.txt"
    dummy_file.write_text("dummy")

    base_container = CheckpointContainerId(str(base_container_path))
    callback = MLFlashpointCheckpointCallback(checkpoint_base_container=base_container, every_n_steps=1)
    callback.replication_manager = mocker.MagicMock()

    # When: invoked twice back-to-back — second call must not raise.
    callback.on_train_end(trainer, pl_module)
    callback.on_train_end(trainer, pl_module)

    # Then: every step ran on both invocations.
    assert callback.replication_manager.shutdown.call_count == 2
    assert checkpoint_io.remove_checkpoint.call_count == 2
    assert checkpoint_io.maybe_finalize_save_checkpoint.call_count == 2
    assert trainer.strategy.barrier.call_count == 2

    # Verify file deletion
    assert not base_container_path.exists(), "Base container directory should have been deleted"
Loading
Loading