1 change: 1 addition & 0 deletions docs/source-pytorch/conf.py
@@ -415,6 +415,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:func", "lightning.pytorch.callbacks.RichProgressBar.configure_columns"),
("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_load_checkpoint"),
("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_save_checkpoint"),
("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_checkpoint_write_end"),
("py:class", "lightning.pytorch.callbacks.checkpoint.Checkpoint"),
("py:meth", "lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics"),
("py:class", "lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme"),
6 changes: 6 additions & 0 deletions docs/source-pytorch/extensions/callbacks.rst
@@ -344,6 +344,12 @@ on_load_checkpoint
.. automethod:: lightning.pytorch.callbacks.Callback.on_load_checkpoint
:noindex:

on_checkpoint_write_end
^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: lightning.pytorch.callbacks.Callback.on_checkpoint_write_end
:noindex:

on_before_backward
^^^^^^^^^^^^^^^^^^

2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


- Added `Callback.on_checkpoint_write_end` hook that triggers after checkpoint files are fully written to disk ([#XXXXX](https://github.com/Lightning-AI/pytorch-lightning/pull/XXXXX))

### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
12 changes: 12 additions & 0 deletions src/lightning/pytorch/callbacks/callback.py
@@ -271,6 +271,18 @@ def on_load_checkpoint(

"""

def on_checkpoint_write_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", filepath: str) -> None:
r"""Called after a checkpoint file has been fully written to disk.

Use this hook to perform any post-save actions such as logging, uploading, or cleanup.

Args:
trainer: the current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.LightningModule` instance.
filepath: the path to the checkpoint file that was written.

"""

def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
"""Called before ``loss.backward()``."""

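As a usage sketch for the new hook: a minimal callback that uploads each finished checkpoint. The `upload_to_checkpoint_store` helper is hypothetical, and the rank-zero guard follows the all-ranks note in the hooks.py docstring below.

```python
from lightning.pytorch import Callback


class UploadCheckpointCallback(Callback):
    """Upload each checkpoint once it is fully written to disk."""

    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
        # The hook fires on every rank after the save barrier, so restrict
        # side effects to rank zero to avoid duplicate uploads.
        if trainer.is_global_zero:
            upload_to_checkpoint_store(filepath)  # hypothetical helper
```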
1 change: 1 addition & 0 deletions src/lightning/pytorch/callbacks/lambda_function.py
@@ -67,6 +67,7 @@ def __init__(
on_exception: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_checkpoint_write_end: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
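`LambdaCallback` forwards each provided callable to the hook of the same name, so the new parameter can be used like the existing ones — a small sketch, assuming the `(trainer, pl_module, filepath)` signature from callback.py above:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LambdaCallback

# Log the path of every fully written checkpoint.
log_writes = LambdaCallback(
    on_checkpoint_write_end=lambda trainer, pl_module, filepath: print(
        f"checkpoint fully written: {filepath}"
    )
)
trainer = Trainer(callbacks=[log_writes])
```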
24 changes: 24 additions & 0 deletions src/lightning/pytorch/core/hooks.py
@@ -709,3 +709,27 @@ def on_save_checkpoint(self, checkpoint):
There is no need for you to store anything about training.

"""

def on_checkpoint_write_end(self, filepath: str) -> None:
r"""Called after a checkpoint file has been fully written to disk.

This hook is triggered once the checkpoint saving process has completed.
Unlike :meth:`on_save_checkpoint`, which is called before the checkpoint
is written, this hook guarantees that the file exists on disk and is
readable.

Args:
filepath: Path to the checkpoint file that was written.

Example::

class MyModel(LightningModule):
def on_checkpoint_write_end(self, filepath):
print(f"Checkpoint saved at: {filepath}")
upload_to_s3(filepath)

Note:
In distributed training, this hook is called on all ranks after
the barrier synchronization completes.

"""
3 changes: 3 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
@@ -1413,6 +1413,9 @@ def save_checkpoint(
self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
self.strategy.barrier("Trainer.save_checkpoint")

call._call_callback_hooks(self, "on_checkpoint_write_end", filepath)
call._call_lightning_module_hook(self, "on_checkpoint_write_end", filepath)

"""
State properties
"""
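The two calls above run only after `strategy.barrier`, callbacks before the LightningModule. A minimal end-to-end sketch of the resulting order, assuming this diff is applied (checkpointing during `fit` is disabled so the hook fires only for the explicit save):

```python
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class Announce(Callback):
    def on_checkpoint_write_end(self, trainer, pl_module, filepath):
        print(f"callback saw finished checkpoint: {filepath}")


class AnnouncingModel(BoringModel):
    def on_checkpoint_write_end(self, filepath):
        print(f"module saw finished checkpoint: {filepath}")


trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    enable_checkpointing=False,
    logger=False,
    callbacks=[Announce()],
)
trainer.fit(AnnouncingModel())
trainer.save_checkpoint("model.ckpt")  # callback line prints first, then the module line
```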
@@ -0,0 +1,41 @@
import os

import torch

from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class CheckpointWriteEndCallback(Callback):
def __init__(self):
self.called = False
self.filepath = None
self.file_existed = False
self.checkpoint_valid = False

def on_checkpoint_write_end(self, trainer, pl_module, filepath):
"""Verify that the hook triggers after checkpoint is written."""
self.called = True
self.filepath = str(filepath)
self.file_existed = os.path.exists(filepath)

checkpoint = torch.load(filepath, map_location="cpu")
self.checkpoint_valid = "state_dict" in checkpoint


def test_on_checkpoint_write_end_called(tmp_path):
"""Test that on_checkpoint_write_end is called after saving a checkpoint."""
model = BoringModel()
callback = CheckpointWriteEndCallback()
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, callbacks=[callback], logger=False)

trainer.fit(model)

checkpoint_path = tmp_path / "test_checkpoint.ckpt"
trainer.save_checkpoint(checkpoint_path)

assert checkpoint_path.exists()
assert callback.called
assert callback.file_existed
assert callback.checkpoint_valid
assert callback.filepath == str(checkpoint_path)
@@ -54,6 +54,7 @@ def test_fx_validator():
"on_sanity_check_start",
"state_dict",
"on_save_checkpoint",
"on_checkpoint_write_end",
"on_test_batch_end",
"on_test_batch_start",
"on_test_end",
@@ -87,6 +88,7 @@ def test_fx_validator():
"on_fit_start",
"on_exception",
"on_load_checkpoint",
"on_checkpoint_write_end",
"load_state_dict",
"on_sanity_check_end",
"on_sanity_check_start",