-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add save_on_exception
option to ModelCheckpoint
#20916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…ining part of callbacks in individal test for better overview
…lidation callback
for more information, see https://pre-commit.ci
…sly defined epoch lenght
6249794
to
f0502ec
Compare
…ntefere with current checkpoint behavoir
…in ModelCheckpoint
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #20916 +/- ##
=========================================
- Coverage 87% 79% -8%
=========================================
Files 268 265 -3
Lines 23442 23399 -43
=========================================
- Hits 20394 18398 -1996
- Misses 3048 5001 +1953 |
save_on_exception
option to ModelCheckpoint
…eaning of unused empty variable in function signutare
def test_model_checkpoint_on_exception_run_condition(tmp_path): | ||
"""Test that the checkpoint is saved when an exception is raised in a lightning module.""" | ||
|
||
# Don't save checkpoint if sanity check fails | ||
class TroubledModelSanityCheck(BoringModel): | ||
def on_validation_start(self) -> None: | ||
if self.trainer.sanity_checking: | ||
print("Trouble!") | ||
raise RuntimeError("Trouble!") | ||
|
||
model = TroubledModelSanityCheck() | ||
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True) | ||
trainer = Trainer( | ||
default_root_dir=tmp_path, | ||
num_sanity_val_steps=4, | ||
limit_train_batches=2, | ||
callbacks=[checkpoint_callback], | ||
max_epochs=2, | ||
logger=False, | ||
) | ||
|
||
with pytest.raises(RuntimeError, match="Trouble!"): | ||
trainer.fit(model) | ||
assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt") | ||
|
||
# Don't save checkpoint if fast dev run fails | ||
class TroubledModelFastDevRun(BoringModel): | ||
def on_train_batch_start(self, batch, batch_idx) -> None: | ||
if self.trainer.fast_dev_run and batch_idx == 1: | ||
raise RuntimeError("Trouble!") | ||
|
||
model = TroubledModelFastDevRun() | ||
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True) | ||
trainer = Trainer( | ||
default_root_dir=tmp_path, | ||
fast_dev_run=2, | ||
limit_train_batches=2, | ||
callbacks=[checkpoint_callback], | ||
max_epochs=2, | ||
logger=False, | ||
) | ||
|
||
with pytest.raises(RuntimeError, match="Trouble!"): | ||
trainer.fit(model) | ||
assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt") | ||
|
||
# Don't save checkpoint if already saved a checkpoint | ||
class TroubledModelAlreadySavedCheckpoint(BoringModel): | ||
def on_train_batch_start(self, batch, batch_idx) -> None: | ||
if self.trainer.global_step == 1: | ||
raise RuntimeError("Trouble!") | ||
|
||
model = TroubledModelAlreadySavedCheckpoint() | ||
checkpoint_callback = ModelCheckpoint( | ||
dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1 | ||
) | ||
trainer = Trainer( | ||
default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False | ||
) | ||
|
||
with pytest.raises(RuntimeError, match="Trouble!"): | ||
trainer.fit(model) | ||
|
||
assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt") | ||
assert os.path.isfile(tmp_path / "already_saved.ckpt") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_model_checkpoint_on_exception_run_condition(tmp_path): | |
"""Test that the checkpoint is saved when an exception is raised in a lightning module.""" | |
# Don't save checkpoint if sanity check fails | |
class TroubledModelSanityCheck(BoringModel): | |
def on_validation_start(self) -> None: | |
if self.trainer.sanity_checking: | |
print("Trouble!") | |
raise RuntimeError("Trouble!") | |
model = TroubledModelSanityCheck() | |
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True) | |
trainer = Trainer( | |
default_root_dir=tmp_path, | |
num_sanity_val_steps=4, | |
limit_train_batches=2, | |
callbacks=[checkpoint_callback], | |
max_epochs=2, | |
logger=False, | |
) | |
with pytest.raises(RuntimeError, match="Trouble!"): | |
trainer.fit(model) | |
assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt") | |
# Don't save checkpoint if fast dev run fails | |
class TroubledModelFastDevRun(BoringModel): | |
def on_train_batch_start(self, batch, batch_idx) -> None: | |
if self.trainer.fast_dev_run and batch_idx == 1: | |
raise RuntimeError("Trouble!") | |
model = TroubledModelFastDevRun() | |
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True) | |
trainer = Trainer( | |
default_root_dir=tmp_path, | |
fast_dev_run=2, | |
limit_train_batches=2, | |
callbacks=[checkpoint_callback], | |
max_epochs=2, | |
logger=False, | |
) | |
with pytest.raises(RuntimeError, match="Trouble!"): | |
trainer.fit(model) | |
assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt") | |
# Don't save checkpoint if already saved a checkpoint | |
class TroubledModelAlreadySavedCheckpoint(BoringModel): | |
def on_train_batch_start(self, batch, batch_idx) -> None: | |
if self.trainer.global_step == 1: | |
raise RuntimeError("Trouble!") | |
model = TroubledModelAlreadySavedCheckpoint() | |
checkpoint_callback = ModelCheckpoint( | |
dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1 | |
) | |
trainer = Trainer( | |
default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False | |
) | |
with pytest.raises(RuntimeError, match="Trouble!"): | |
trainer.fit(model) | |
assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt") | |
assert os.path.isfile(tmp_path / "already_saved.ckpt") |
This seems to be the same as the parametrization bellow
What does this PR do?
This PR adds a
save_on_exception
option to the ModelCheckpoint callback. This some of this functionality is already implemented in theOnExceptionCheckpoint
checkpoint, but I believe that bundling all checkpoint options in the ModelCheckpoint is more intuitive. Additionally, this leads to the same naming conventions and directory paths used for the exception checkpoint as for all the others.When enabled, this option serves as a contingency in case of any disruption during training, allowing one to continue from the last step before the exception occurs without losing too much progress. By printing the exception type and message, this also alleviates issue #20187.
Fixes #19686
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20916.org.readthedocs.build/en/20916/