diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index 64e2ce30..ead7420b 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -1,10 +1,12 @@
 from typing import Any, Dict, Optional
 
+from pytorch_pfn_extras._torch_version import requires
 from pytorch_pfn_extras.training import extension
 from pytorch_pfn_extras.training import trigger as trigger_module
 from pytorch_pfn_extras.training._manager_protocol import (
     ExtensionsManagerProtocol,
 )
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 
@@ -33,6 +35,22 @@ def _default_stepper(
     scheduler.step()
 
 
+def check_optimizer_is_called(optimizer: Optimizer) -> bool:
+    if requires("2.4.0.dev"):
+        # https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/optim/lr_scheduler.py#L172-L191
+        # TODO: Rewrite this URL when pytorch 2.4.0 is released.
+        if hasattr(optimizer.step, "_wrapped_by_lr_sched"):
+            return getattr(optimizer, "_opt_called", False)
+        else:
+            return True
+    else:
+        # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
+        if hasattr(optimizer.step, "_with_counter"):
+            return bool(optimizer._step_count >= 1)  # type: ignore[attr-defined]
+        else:
+            return True
+
+
 class LRScheduler(extension.Extension):
     """Trainer extension to adjust the learning rate using PyTorch's
     learning rate scheduler.
@@ -72,8 +90,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
         # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
         if (
             self.wait_for_first_optimizer_step
-            and hasattr(self.scheduler.optimizer.step, "_with_counter")
-            and self.scheduler.optimizer._step_count < 1
+            and not check_optimizer_is_called(self.scheduler.optimizer)
        ):
             return
         self.stepper(manager, self.scheduler)
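
A minimal usage sketch (not part of the diff) of what the new helper observes, assuming this change is applied; the import path follows the file modified above, and the SGD/ExponentialLR choices are arbitrary placeholders:

import torch
from torch.optim.lr_scheduler import ExponentialLR

from pytorch_pfn_extras.training.extensions.lr_scheduler import (
    check_optimizer_is_called,
)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# Constructing an LR scheduler wraps optimizer.step, which is what the
# helper inspects (via _wrapped_by_lr_sched on >=2.4.0.dev, _with_counter
# otherwise) to tell whether the optimizer has been stepped yet.
scheduler = ExponentialLR(optimizer, gamma=0.9)

# Before the first optimizer.step() the helper reports False, so the
# LRScheduler extension with wait_for_first_optimizer_step=True skips
# scheduler.step() and avoids PyTorch's step-ordering warning.
assert not check_optimizer_is_called(optimizer)

param.grad = torch.zeros(1)
optimizer.step()
assert check_optimizer_is_called(optimizer)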