diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index db3f5334ed162..5b025f0398153 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -415,6 +415,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:func", "lightning.pytorch.callbacks.RichProgressBar.configure_columns"),
     ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_load_checkpoint"),
     ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_save_checkpoint"),
+    ("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_checkpoint_write_end"),
     ("py:class", "lightning.pytorch.callbacks.checkpoint.Checkpoint"),
     ("py:meth", "lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics"),
     ("py:class", "lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme"),
diff --git a/docs/source-pytorch/extensions/callbacks.rst b/docs/source-pytorch/extensions/callbacks.rst
index 7ed285591c4dc..729111a7578cb 100644
--- a/docs/source-pytorch/extensions/callbacks.rst
+++ b/docs/source-pytorch/extensions/callbacks.rst
@@ -344,6 +344,12 @@ on_load_checkpoint
 .. automethod:: lightning.pytorch.callbacks.Callback.on_load_checkpoint
     :noindex:
 
+on_checkpoint_write_end
+^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: lightning.pytorch.callbacks.Callback.on_checkpoint_write_end
+    :noindex:
+
 on_before_backward
 ^^^^^^^^^^^^^^^^^^
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b1d60ca48847e..5ea9f7502a88e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
+- Added `Callback.on_checkpoint_write_end` hook that triggers after checkpoint files are fully written to disk ([#XXXXX](https://github.com/Lightning-AI/pytorch-lightning/pull/XXXXX))
+
 ### Changed
 
 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 3bfb609465a83..f1d25b8db0636 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -271,6 +271,18 @@ def on_load_checkpoint(
 
         """
 
+    def on_checkpoint_write_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", filepath: str) -> None:
+        r"""Called after a checkpoint file has been fully written to disk.
+
+        Use this hook to perform any post-save actions such as logging, uploading, or cleanup.
+
+        Args:
+            trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance.
+            filepath: the path to the checkpoint file that was written.
+
+        """
+
     def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
         """Called before ``loss.backward()``."""
 
diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py
index f04b2d777deb3..15f0e925a9fdb 100644
--- a/src/lightning/pytorch/callbacks/lambda_function.py
+++ b/src/lightning/pytorch/callbacks/lambda_function.py
@@ -67,6 +67,7 @@ def __init__(
         on_exception: Optional[Callable] = None,
         on_save_checkpoint: Optional[Callable] = None,
         on_load_checkpoint: Optional[Callable] = None,
+        on_checkpoint_write_end: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..24e4cc17b0423 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -709,3 +709,27 @@ def on_save_checkpoint(self, checkpoint):
         There is no need for you to store anything about training.
 
         """
+
+    def on_checkpoint_write_end(self, filepath: str) -> None:
+        r"""Called after a checkpoint file has been fully written to disk.
+
+        This hook is triggered after the checkpoint saving process completes,
+        ensuring the file exists and is readable. Unlike :meth:`on_save_checkpoint`,
+        which is called before the checkpoint is written, this hook guarantees
+        the file is available on disk.
+
+        Args:
+            filepath: Path to the checkpoint file that was written.
+
+        Example::
+
+            class MyModel(LightningModule):
+                def on_checkpoint_write_end(self, filepath):
+                    print(f"Checkpoint saved at: {filepath}")
+                    upload_to_s3(filepath)
+
+        Note:
+            In distributed training, this hook is called on all ranks after
+            the barrier synchronization completes.
+
+        """
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index f2f59e396ab23..74ed45a07f2a2 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -1462,6 +1462,9 @@ def save_checkpoint(
         self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
         self.strategy.barrier("Trainer.save_checkpoint")
 
+        call._call_callback_hooks(self, "on_checkpoint_write_end", filepath)
+        call._call_lightning_module_hook(self, "on_checkpoint_write_end", filepath)
+
     """
     State properties
     """
diff --git a/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py b/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py
new file mode 100644
index 0000000000000..18968553aa746
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_checkpoint_write_end_callback.py
@@ -0,0 +1,41 @@
+import os
+
+import torch
+
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class CheckpointWriteEndCallback(Callback):
+    def __init__(self):
+        self.called = False
+        self.filepath = None
+        self.file_existed = False
+        self.checkpoint_valid = False
+
+    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
+        """Verify that the hook triggers after checkpoint is written."""
+        self.called = True
+        self.filepath = str(filepath)
+        self.file_existed = os.path.exists(filepath)
+
+        checkpoint = torch.load(filepath, map_location="cpu")
+        self.checkpoint_valid = "state_dict" in checkpoint
+
+
+def test_on_checkpoint_write_end_called(tmp_path):
+    """Test that on_checkpoint_write_end is called after saving a checkpoint."""
+    model = BoringModel()
+    callback = CheckpointWriteEndCallback()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)
+
+    trainer.fit(model)
+
+    checkpoint_path = tmp_path / "test_checkpoint.ckpt"
+    trainer.save_checkpoint(checkpoint_path)
+
+    assert checkpoint_path.exists()
+    assert callback.called
+    assert callback.file_existed
+    assert callback.checkpoint_valid
+    assert callback.filepath == str(checkpoint_path)
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index d3d355edb003b..b6573a5799c1b 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -54,6 +54,7 @@ def test_fx_validator():
         "on_sanity_check_start",
         "state_dict",
         "on_save_checkpoint",
+        "on_checkpoint_write_end",
         "on_test_batch_end",
         "on_test_batch_start",
         "on_test_end",
@@ -87,6 +88,7 @@
         "on_fit_start",
         "on_exception",
         "on_load_checkpoint",
+        "on_checkpoint_write_end",
         "load_state_dict",
         "on_sanity_check_end",
         "on_sanity_check_start",