diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 81c7bfc656885..aa1841935d35e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
 
+- Fixed `RichProgressBar` metrics to update according to the user-provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
+
 ---
 
 ## [2.5.2] - 2025-06-20
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index ff092fa99d825..e6f61174fb987 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -552,12 +552,12 @@ def on_train_batch_end(
             # can happen when resuming from a mid-epoch restart
             self._initialize_train_progress_bar_id()
         self._update(self.train_progress_bar_id, batch_idx + 1)
-        self._update_metrics(trainer, pl_module)
+        self._update_metrics(trainer, pl_module, batch_idx + 1)
         self.refresh()
 
     @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._update_metrics(trainer, pl_module)
+        self._update_metrics(trainer, pl_module, total_batches=True)
 
     @override
     def on_validation_batch_end(
@@ -632,7 +632,21 @@ def _reset_progress_bar_ids(self) -> None:
         self.test_progress_bar_id = None
         self.predict_progress_bar_id = None
 
-    def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def _update_metrics(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        current: Optional[int] = None,
+        total_batches: bool = False,
+    ) -> None:
+        if not self.is_enabled or self._metric_component is None:
+            return
+
+        if current is not None and not total_batches:
+            total = self.total_train_batches
+            if not self._should_update(current, total):
+                return
+
         metrics = self.get_metrics(trainer, pl_module)
         if self._metric_component:
             self._metric_component.update(metrics)
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index 430fb9842cddc..639414a797aa0 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -246,6 +246,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
     with mock.patch.object(
         trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
     ) as progress_update:
+        metrics_update = mock.MagicMock()
+        trainer.progress_bar_callback._update_metrics = metrics_update
+
         trainer.fit(model)
 
     assert progress_update.call_count == expected_call_count
@@ -260,6 +263,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
         assert fit_val_bar.total == val_batches
         assert not fit_val_bar.visible
 
+    # one call for each train batch + one at the end of the training epoch + one at the end of validation
+    assert metrics_update.call_count == train_batches + (1 if train_batches > 0 else 0) + (1 if val_batches > 0 else 0)
+
 
 @RunIf(rich=True)
 @pytest.mark.parametrize("limit_val_batches", [1, 5])
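For context on the change above: the new `current`/`total_batches` parameters route batch-level metric updates through the same throttle that already limits bar redraws, while epoch-end (`total_batches=True`) and validation-end calls refresh unconditionally. Below is a minimal, self-contained sketch of that behavior, assuming `_should_update` follows the modulo-based rule Lightning's progress bars use; the function names here are illustrative, not the library's API:

```python
# Standalone sketch of the throttling introduced by this patch.
# Assumption: `_should_update` is a modulo-based check against refresh_rate;
# `should_update`/`update_metrics` below are hypothetical names for illustration.

def should_update(current: int, total: int, refresh_rate: int) -> bool:
    # Refresh every `refresh_rate` batches, and always on the last batch
    # so the final state is shown.
    return current % refresh_rate == 0 or current == total


def update_metrics(current: int, total: int, refresh_rate: int, force: bool = False) -> None:
    # `force=True` mirrors `total_batches=True` at epoch end: the metrics
    # row is refreshed unconditionally so it is never left stale.
    if not force and not should_update(current, total, refresh_rate):
        return
    print(f"metrics refreshed at batch {current}/{total}")


if __name__ == "__main__":
    # With refresh_rate=3 and 10 batches, metrics refresh at batches 3, 6, 9, 10.
    for batch in range(1, 11):
        update_metrics(batch, total=10, refresh_rate=3)
```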
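From the user's side, no configuration changes: standard `RichProgressBar` usage picks up the new cadence. A quick sketch (the `refresh_rate=10` value is just an example) of what the fix means in practice:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

# With this patch, the metrics shown under the bar refresh at the same
# cadence as the bar itself: every 10 batches here, plus once at the end
# of each training epoch and once when validation finishes.
trainer = Trainer(callbacks=[RichProgressBar(refresh_rate=10)])
```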