@@ -23,6 +23,7 @@
 from lightning_utilities.test.warning import no_warning_call

 from lightning.pytorch import Trainer, seed_everything
+from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -540,6 +541,67 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     assert math.isclose(model.lr, suggested_lr)


+def test_lr_finder_with_early_stopping(tmp_path):
+    """Test that ``LearningRateFinder`` runs alongside ``EarlyStopping`` and a ``ReduceLROnPlateau`` scheduler."""
+
+    class ModelWithValidation(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.1
+
+        def validation_step(self, batch, batch_idx):
+            output = self.step(batch)
+            # Log the validation loss that EarlyStopping will monitor
+            self.log("val_loss", output, on_epoch=True)
+            return output
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+            # Add a ReduceLROnPlateau scheduler that monitors val_loss (issue #20355)
+            plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer, mode="min", factor=0.5, patience=2
+            )
+            scheduler_config = {"scheduler": plateau_scheduler, "interval": "epoch", "monitor": "val_loss"}
+
+            return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
+
+    model = ModelWithValidation()
+
+    # Both callbacks that previously caused issues when combined
+    callbacks = [
+        LearningRateFinder(num_training_steps=100, update_attr=False),
+        EarlyStopping(monitor="val_loss", patience=3),
+    ]
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=10,
+        callbacks=callbacks,
+        limit_train_batches=5,
+        limit_val_batches=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+    assert trainer.state.finished
+
+    # Verify that both callbacks were active
+    lr_finder_callback = None
+    early_stopping_callback = None
+    for callback in trainer.callbacks:
+        if isinstance(callback, LearningRateFinder):
+            lr_finder_callback = callback
+        elif isinstance(callback, EarlyStopping):
+            early_stopping_callback = callback
+
+    assert lr_finder_callback is not None, "LearningRateFinder callback should be present"
+    assert early_stopping_callback is not None, "EarlyStopping callback should be present"
+
+    # Verify that the learning rate finder ran and produced results
+    assert lr_finder_callback.optimal_lr is not None, "Learning rate finder should have results"
+    assert lr_finder_callback.optimal_lr.suggestion() > 0, "Learning rate suggestion should be positive"
+
+
 def test_gradient_correctness():
     """Test that torch.gradient uses correct spacing parameter."""
     lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)