
Commit 2676332

Fix integration of LearningRateFinder with EarlyStopping (#21056)
* reset data fetching after learning rate finder
* reinit at the same time to still support multiclass lrfind
1 parent 6f93a90 commit 2676332

File tree

3 files changed (+67, -1):

- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/tuner/lr_finder.py
- tests/tests_pytorch/tuner/test_lr_finder.py
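For orientation, here is a minimal sketch (not part of the commit) of the callback combination this fix targets, modeled on the regression test added below; the `Model` subclass and the `val_loss` metric name are illustrative:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.demos.boring_classes import BoringModel


class Model(BoringModel):
    def validation_step(self, batch, batch_idx):
        # EarlyStopping needs a logged validation metric to monitor.
        loss = self.step(batch)
        self.log("val_loss", loss, on_epoch=True)
        return loss


trainer = Trainer(
    max_epochs=10,
    callbacks=[
        # Runs an LR sweep when fit starts, then restores the pre-sweep trainer state.
        LearningRateFinder(num_training_steps=100, update_attr=False),
        # Watches the logged metric for the rest of the fit.
        EarlyStopping(monitor="val_loss", patience=3),
    ],
)
trainer.fit(Model())
```

Before this change, the state restore that `LearningRateFinder` performs after its sweep left the fit loop holding a stale data fetcher, which the code comment below describes as causing issues with the next fit.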

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))


+- Fixed integration between `LearningRateFinder` and `EarlyStopping` ([#21056](https://github.com/Lightning-AI/pytorch-lightning/pull/21056))
+
+
 - Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 2 additions & 1 deletion

@@ -292,7 +292,8 @@ def _lr_find(
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-
+    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+    trainer.fit_loop.setup_data()
     return lr_finder

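The two added lines also matter for the manual tuner path: `Tuner.lr_find` (which calls `_lr_find` under the hood) restores a checkpoint taken before the sweep, and clearing `fit_loop._combined_loader` plus calling `setup_data()` makes the subsequent `fit` rebuild its dataloaders rather than reuse the sweep's. A sketch of that flow follows; the `learning_rate` attribute is an assumption of this example, not something `BoringModel` defines:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner

model = BoringModel()
trainer = Trainer(max_epochs=5)
tuner = Tuner(trainer)

# Checkpoints, sweeps learning rates, then restores trainer state. With this
# commit the restore also resets fit_loop._combined_loader and re-runs
# setup_data(), so the fit below starts from freshly built dataloaders.
lr_finder = tuner.lr_find(model, update_attr=False)
if lr_finder is not None and lr_finder.suggestion() is not None:
    model.learning_rate = lr_finder.suggestion()  # assumed attribute; wire it into configure_optimizers

trainer.fit(model)
```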
tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 62 additions & 0 deletions

@@ -23,6 +23,7 @@
 from lightning_utilities.test.warning import no_warning_call

 from lightning.pytorch import Trainer, seed_everything
+from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder

@@ -540,6 +541,67 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     assert math.isclose(model.lr, suggested_lr)


+def test_lr_finder_with_early_stopping(tmp_path):
+    class ModelWithValidation(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.1
+
+        def validation_step(self, batch, batch_idx):
+            output = self.step(batch)
+            # Log validation loss that EarlyStopping will monitor
+            self.log("val_loss", output, on_epoch=True)
+            return output
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+            # Add ReduceLROnPlateau scheduler that monitors val_loss (issue #20355)
+            plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer, mode="min", factor=0.5, patience=2
+            )
+            scheduler_config = {"scheduler": plateau_scheduler, "interval": "epoch", "monitor": "val_loss"}
+
+            return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
+
+    model = ModelWithValidation()
+
+    # Both callbacks that previously caused issues
+    callbacks = [
+        LearningRateFinder(num_training_steps=100, update_attr=False),
+        EarlyStopping(monitor="val_loss", patience=3),
+    ]
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=10,
+        callbacks=callbacks,
+        limit_train_batches=5,
+        limit_val_batches=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+    assert trainer.state.finished
+
+    # Verify that both callbacks were active
+    lr_finder_callback = None
+    early_stopping_callback = None
+    for callback in trainer.callbacks:
+        if isinstance(callback, LearningRateFinder):
+            lr_finder_callback = callback
+        elif isinstance(callback, EarlyStopping):
+            early_stopping_callback = callback
+
+    assert lr_finder_callback is not None, "LearningRateFinder callback should be present"
+    assert early_stopping_callback is not None, "EarlyStopping callback should be present"
+
+    # Verify learning rate finder ran and has results
+    assert lr_finder_callback.optimal_lr is not None, "Learning rate finder should have results"
+    assert lr_finder_callback.optimal_lr.suggestion() > 0, "Learning rate suggestion should be positive"
+
+
 def test_gradient_correctness():
     """Test that torch.gradient uses correct spacing parameter."""
     lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)
