Commit 3ed9d4e

littlebullGit, SkafteNicki, Borda, and pre-commit-ci[bot] authored
Fix LR not being correctly set after using LearningRateFinder callback (#21068)
* fix(tuner/lr_finder): apply LR suggestion after checkpoint restore when used as callback

  Previously, LearningRateFinder applied the suggested LR before restoring the checkpoint, so the optimizer LR was reverted by the restore step. This caused the callback to print “Learning rate set to …” without persisting the change.

  Change:
  - Move LR application to after the checkpoint restore and update both the LightningModule attribute and the active optimizer param groups so the LR persists for training.

  Tests:
  - Add unit test `test_lr_finder_callback_applies_lr_after_restore` to assert the optimizer LR matches the LR Finder suggestion after the search completes.

* changelog

* Apply suggestions from code review

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 119a640 commit 3ed9d4e
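
For context on the user-visible effect, here is a minimal illustrative sketch (not part of this commit) of the callback flow the fix targets. The toy ToyModel, the random dataset, and the final print are assumptions for demonstration only; LearningRateFinder with its default update_attr=True, optimal_lr.suggestion(), and the trainer.optimizers[0].param_groups[0]["lr"] check mirror the new unit test in this diff. After this fix, the LR read from the active optimizer after fit() should match the suggestion rather than the original value.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateFinder


class ToyModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-5):
        super().__init__()
        self.lr = lr  # attribute the finder updates when update_attr=True (the default)
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    train = DataLoader(TensorDataset(torch.randn(256, 32), torch.randn(256, 1)), batch_size=32)
    finder = LearningRateFinder()
    trainer = pl.Trainer(max_epochs=1, callbacks=[finder], logger=False, enable_progress_bar=False)
    trainer.fit(ToyModel(), train)

    # Before this fix, the checkpoint restore at the end of the LR search reverted the
    # optimizer LR back to 1e-5 even though "Learning rate set to ..." was logged.
    # With the fix, the active optimizer carries the suggestion (when one was found).
    suggestion = finder.optimal_lr.suggestion()
    print("suggested:", suggestion, "optimizer lr:", trainer.optimizers[0].param_groups[0]["lr"])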

File tree

3 files changed: +89 −9 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -
 
 
+- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
+
 ---
 
 ## [2.5.3] - 2025-08-13

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 15 additions & 9 deletions
@@ -276,24 +276,30 @@ def _lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()
 
-    # Update lr attr if required
+    # Update results across ranks
     lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-    if update_attr:
-        lr = lr_finder.suggestion()
-
-        # TODO: log lr.results to self.logger
-        if lr is not None:
-            lightning_setattr(model, attr_name, lr)
-            log.info(f"Learning rate set to {lr}")
 
-    # Restore initial state of model
+    # Restore initial state of model (this will also restore the original optimizer state)
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
     trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
     trainer.fit_loop.setup_data()
+
+    # Apply LR suggestion after restoring so it persists for the real training run
+    # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
+    if update_attr:
+        lr = lr_finder.suggestion()
+        if lr is not None:
+            # update the attribute on the LightningModule (e.g., lr or learning_rate)
+            lightning_setattr(model, attr_name, lr)
+            # also update the currently active optimizer(s) so training continues with the suggested LR
+            for opt in trainer.optimizers or []:
+                for pg in opt.param_groups:
+                    pg["lr"] = lr
+            log.info(f"Learning rate set to {lr}")
     return lr_finder

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 72 additions & 0 deletions
@@ -619,6 +619,78 @@ def test_gradient_correctness():
     assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"
 
 
+def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
+    """LearningRateFinder used as a callback should apply its suggested LR to the optimizer used after state
+    restoration."""
+
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import DataLoader, Dataset
+
+    from lightning.pytorch.callbacks import LearningRateMonitor
+
+    class RandomDataset(Dataset):
+        def __init__(self, n: int = 256, in_dim: int = 28 * 28):
+            self.x = torch.randn(n, in_dim)
+            self.y = torch.randn(n, in_dim)
+
+        def __len__(self) -> int:
+            return len(self.x)
+
+        def __getitem__(self, idx):
+            return self.x[idx], self.y[idx]
+
+    class TinyAE(BoringModel):
+        def __init__(self, lr: float = 1e-5):
+            super().__init__()
+            self.save_hyperparameters()
+            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            x, y = batch
+            z = self.encoder(x)
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, y)
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+    seed_everything(123)
+
+    ds = RandomDataset(n=512)
+    train_loader = DataLoader(ds, batch_size=64, shuffle=False)
+
+    model = TinyAE(lr=1e-5)
+
+    lr_finder_cb = LearningRateFinder()  # default update_attr=True should apply suggestion
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        callbacks=[lr_finder_cb, lr_monitor],
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model, train_loader)
+    assert model.hparams.lr is not None
+    # Ensure LR Finder produced a suggestion for this setup; if not, the test can't assert application
+    assert lr_finder_cb.optimal_lr is not None, "LR Finder should have computed results"
+    suggestion = lr_finder_cb.optimal_lr.suggestion()
+    assert suggestion is not None, "LR Finder should produce a suggestion for this setup"
+
+    # Verify that the optimizer used for subsequent training has the suggested LR applied
+    assert trainer.optimizers, "Trainer should have an optimizer after fit"
+    current_lr = trainer.optimizers[0].param_groups[0]["lr"]
+    assert current_lr == pytest.approx(suggestion), (
+        f"LR Finder suggestion {suggestion} should be applied to optimizer, but got {current_lr}"
+    )
+
+
 def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
     """Test that exponential and linear modes produce different but valid suggestions.
