diff --git a/pina/callback/optimizer_callback.py b/pina/callback/optimizer_callback.py index 1b518406b..6455bf861 100644 --- a/pina/callback/optimizer_callback.py +++ b/pina/callback/optimizer_callback.py @@ -62,7 +62,10 @@ def on_train_epoch_start(self, trainer, __): # Hook the new optimizers to the model parameters for idx, optim in enumerate(self._new_optimizers): - optim.hook(trainer.solver._pina_models[idx].parameters()) + optim._register_hooks( + parameters=trainer.solver._pina_models[idx].parameters(), + solver=trainer.solver, + ) optims.append(optim) # Update the solver's optimizers diff --git a/pina/optim/__init__.py b/pina/optim/__init__.py index 8266c8ca1..562f452be 100644 --- a/pina/optim/__init__.py +++ b/pina/optim/__init__.py @@ -1,13 +1,9 @@ """Module for the Optimizers and Schedulers.""" __all__ = [ - "Optimizer", "TorchOptimizer", - "Scheduler", "TorchScheduler", ] -from .optimizer_interface import Optimizer from .torch_optimizer import TorchOptimizer -from .scheduler_interface import Scheduler -from .torch_scheduler import TorchScheduler +from .torch_scheduler import TorchScheduler \ No newline at end of file diff --git a/pina/optim/core/optim_connector_interface.py b/pina/optim/core/optim_connector_interface.py new file mode 100644 index 000000000..a6b36b3cc --- /dev/null +++ b/pina/optim/core/optim_connector_interface.py @@ -0,0 +1,98 @@ +"""Module for the PINA Optimizer and Scheduler Connectors Interface.""" + +from abc import ABCMeta, abstractmethod +from functools import wraps + + +class OptimizerConnectorInterface(metaclass=ABCMeta): + """ + Interface class for method definitions in the Optimizer classes. + """ + + @abstractmethod + def parameter_hook(self, parameters): + """ + Abstract method to define the hook logic for the optimizer. This hook + is used to initialize the optimizer instance with the given parameters. + + :param dict parameters: The parameters of the model to be optimized. + """ + + @abstractmethod + def solver_hook(self, solver): + """ + Abstract method to define the hook logic for the optimizer. This hook + is used to hook the optimizer instance with the given parameters. + + :param SolverInterface solver: The solver to hook. + """ + + +class SchedulerConnectorInterface(metaclass=ABCMeta): + """ + Abstract base class for defining a scheduler. All specific schedulers should + inherit form this class and implement the required methods. + """ + + @abstractmethod + def optimizer_hook(self): + """ + Abstract method to define the hook logic for the scheduler. This hook + is used to hook the scheduler instance with the given optimizer. + """ + + +class _HooksOptim: + """ + Mixin class to manage and track the execution of hook methods in optimizer + or scheduler classes. + + This class automatically detects methods ending with `_hook` and tracks + whether they have been executed for a given instance. Subclasses defining + `_hook` methods benefit from automatic tracking without additional + boilerplate. + """ + def __init__(self, *args, **kwargs): + """ + Initialize the hooks tracking dictionary `hooks_done` for this instance. + + Each hook method detected in the class hierarchy is added to + `hooks_done` with an initial value of False (not executed). 
+ """ + super().__init__(*args, **kwargs) + # Initialize hooks_done per instance + self.hooks_done = {} + for cls in self.__class__.__mro__: + for attr_name, attr_value in cls.__dict__.items(): + if callable(attr_value) and attr_name.endswith("_hook"): + self.hooks_done.setdefault(attr_name, False) + + def __init_subclass__(cls, **kwargs): + """ + Hook called when a subclass of _HooksOptim is created. + + Wraps all concrete `_hook` methods defined in the subclass so that + executing the method automatically updates `hooks_done`. + """ + super().__init_subclass__(**kwargs) + # Wrap only concrete _hook methods defined in this subclass + for attr_name, attr_value in cls.__dict__.items(): + if callable(attr_value) and attr_name.endswith("_hook"): + setattr(cls, attr_name, cls.hook_wrapper(attr_name, attr_value)) + + @staticmethod + def hook_wrapper(name, func): + """ + Wrap a hook method to mark it as executed after calling it. + + :param str name: The name of the hook method. + :param callable func: The original hook method to wrap. + :return: The wrapped hook method that updates `hooks_done`. + """ + @wraps(func) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + self.hooks_done[name] = True + return result + + return wrapper diff --git a/pina/optim/core/optimizer_connector.py b/pina/optim/core/optimizer_connector.py new file mode 100644 index 000000000..75e1624e7 --- /dev/null +++ b/pina/optim/core/optimizer_connector.py @@ -0,0 +1,100 @@ +"""Module for the PINA Optimizer.""" + +from .optim_connector_interface import OptimizerConnectorInterface, _HooksOptim + + +class OptimizerConnector(OptimizerConnectorInterface, _HooksOptim): + """ + Abstract base class for defining an optimizer connector. All specific + optimizers connectors should inherit form this class and implement the + required methods. + """ + + def __init__(self, optimizer_class, **optimizer_class_kwargs): + """ + Initialize connector parameters + + :param torch.optim.Optimizer optimizer_class: The torch optimizer class. + :param dict optimizer_class_kwargs: The optimizer kwargs. + """ + super().__init__() + self._optimizer_class = optimizer_class + self._optimizer_instance = None + self._optim_kwargs = optimizer_class_kwargs + self._solver = None + + def parameter_hook(self, parameters): + """ + Abstract method to define the hook logic for the optimizer. This hook + is used to initialize the optimizer instance with the given parameters. + + :param dict parameters: The parameters of the model to be optimized. + """ + self._optimizer_instance = self._optimizer_class( + parameters, **self._optim_kwargs + ) + + def solver_hook(self, solver): + """ + Method to define the hook logic for the optimizer. This hook + is used to hook the optimizer instance with the given parameters. + + :param SolverInterface solver: The solver to hook. + """ + if not self.hooks_done["parameter_hook"]: + raise RuntimeError( + "Cannot run 'solver_hook' before 'parameter_hook'. " + "Please call 'parameter_hook' first to initialize " + "the solver parameters." + ) + # hook to both instance and connector the solver + self._solver = solver + self._optimizer_instance.solver = solver + + def _register_hooks(self, **kwargs): + """ + Register the optimizers hooks. This method inspects keyword arguments + for known keys (`parameters`, `solver`, ...) and applies the + corresponding hooks. + + It allows flexible integration with + different workflows without enforcing a strict method signature. 
+ + This method is used inside the + :class:`~pina.solver.solver.SolverInterface` class. + + :param kwargs: Expected keys may include: + - ``parameters``: Parameters to be registered for optimization. + - ``solver``: Solver instance. + """ + # parameter hook + parameters = kwargs.get("parameters", None) + if parameters is not None: + self.parameter_hook(parameters) + # solver hook + solver = kwargs.get("solver", None) + if solver is not None: + self.solver_hook(solver) + + @property + def solver(self): + """ + Get the solver hooked to the optimizer. + """ + if not self.hooks_done["solver_hook"]: + raise RuntimeError( + "Solver has not been hooked." + "Override the method solver_hook to hook the solver to " + "the optimizer." + ) + return self._solver + + @property + def instance(self): + """ + Get the optimizer instance. + + :return: The optimizer instance + :rtype: torch.optim.Optimizer + """ + return self._optimizer_instance diff --git a/pina/optim/core/scheduler_connector.py b/pina/optim/core/scheduler_connector.py new file mode 100644 index 000000000..a8a877706 --- /dev/null +++ b/pina/optim/core/scheduler_connector.py @@ -0,0 +1,75 @@ +"""Module for the PINA Scheduler.""" + +from .optim_connector_interface import SchedulerConnectorInterface, _HooksOptim +from .optimizer_connector import OptimizerConnector +from ...utils import check_consistency + + +class SchedulerConnector(SchedulerConnectorInterface, _HooksOptim): + """ + Class for defining a scheduler connector. All specific schedulers connectors + should inherit form this class and implement the required methods. + """ + + def __init__(self, scheduler_class, **scheduler_kwargs): + """ + Initialize connector parameters + + :param torch.optim.lr_scheduler.LRScheduler scheduler_class: The torch + scheduler class. + :param dict scheduler_kwargs: The scheduler kwargs. + """ + super().__init__() + self._scheduler_class = scheduler_class + self._scheduler_instance = None + self._scheduler_kwargs = scheduler_kwargs + + def optimizer_hook(self, optimizer): + """ + Abstract method to define the hook logic for the scheduler. This hook + is used to hook the scheduler instance with the given optimizer. + + :param Optimizer optimizer: The optimizer to hook. + """ + check_consistency(optimizer, OptimizerConnector) + if not optimizer.hooks_done["parameter_hook"]: + raise RuntimeError( + "Scheduler cannot be set, Optimizer not hooked " + "to model parameters. " + "Please call Optimizer.parameter_hook()." + ) + self._scheduler_instance = self._scheduler_class( + optimizer.instance, **self._scheduler_kwargs + ) + + def _register_hooks(self, **kwargs): + """ + Register the optimizers hooks. This method inspects keyword arguments + for known keys (`parameters`, `solver`, ...) and applies the + corresponding hooks. + + It allows flexible integration with + different workflows without enforcing a strict method signature. + + This method is used inside the + :class:`~pina.solver.solver.SolverInterface` class. + + :param kwargs: Expected keys may include: + - ``parameters``: Parameters to be registered for optimization. + - ``solver``: Solver instance. + """ + # optimizer hook + optimizer = kwargs.get("optimizer", None) + if optimizer is not None: + check_consistency(optimizer, OptimizerConnector) + self.optimizer_hook(optimizer) + + @property + def instance(self): + """ + Get the scheduler instance. 
+ + :return: The scheduler instance + :rtype: torch.optim.lr_scheduler.LRScheduler + """ + return self._scheduler_instance diff --git a/pina/optim/optimizer_interface.py b/pina/optim/optimizer_interface.py deleted file mode 100644 index 5f2fbe66a..000000000 --- a/pina/optim/optimizer_interface.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Module for the PINA Optimizer.""" - -from abc import ABCMeta, abstractmethod - - -class Optimizer(metaclass=ABCMeta): - """ - Abstract base class for defining an optimizer. All specific optimizers - should inherit form this class and implement the required methods. - """ - - @property - @abstractmethod - def instance(self): - """ - Abstract property to retrieve the optimizer instance. - """ - - @abstractmethod - def hook(self): - """ - Abstract method to define the hook logic for the optimizer. - """ diff --git a/pina/optim/scheduler_interface.py b/pina/optim/scheduler_interface.py deleted file mode 100644 index 5ae5d8b99..000000000 --- a/pina/optim/scheduler_interface.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Module for the PINA Scheduler.""" - -from abc import ABCMeta, abstractmethod - - -class Scheduler(metaclass=ABCMeta): - """ - Abstract base class for defining a scheduler. All specific schedulers should - inherit form this class and implement the required methods. - """ - - @property - @abstractmethod - def instance(self): - """ - Abstract property to retrieve the scheduler instance. - """ - - @abstractmethod - def hook(self): - """ - Abstract method to define the hook logic for the scheduler. - """ diff --git a/pina/optim/torch_optimizer.py b/pina/optim/torch_optimizer.py index 7163c295e..a3fd404f3 100644 --- a/pina/optim/torch_optimizer.py +++ b/pina/optim/torch_optimizer.py @@ -3,15 +3,15 @@ import torch from ..utils import check_consistency -from .optimizer_interface import Optimizer +from .core.optimizer_connector import OptimizerConnector -class TorchOptimizer(Optimizer): +class TorchOptimizer(OptimizerConnector): """ A wrapper class for using PyTorch optimizers. """ - def __init__(self, optimizer_class, **kwargs): + def __init__(self, optimizer_class, **optimizer_class_kwargs): """ Initialization of the :class:`TorchOptimizer` class. @@ -21,28 +21,7 @@ def __init__(self, optimizer_class, **kwargs): see more `here `_. """ + # external checks check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True) - - self.optimizer_class = optimizer_class - self.kwargs = kwargs - self._optimizer_instance = None - - def hook(self, parameters): - """ - Initialize the optimizer instance with the given parameters. - - :param dict parameters: The parameters of the model to be optimized. - """ - self._optimizer_instance = self.optimizer_class( - parameters, **self.kwargs - ) - - @property - def instance(self): - """ - Get the optimizer instance. - - :return: The optimizer instance. 
- :rtype: torch.optim.Optimizer - """ - return self._optimizer_instance + check_consistency(optimizer_class_kwargs, dict) + super().__init__(optimizer_class, **optimizer_class_kwargs) diff --git a/pina/optim/torch_scheduler.py b/pina/optim/torch_scheduler.py index ff12300a1..5d8038bc6 100644 --- a/pina/optim/torch_scheduler.py +++ b/pina/optim/torch_scheduler.py @@ -1,5 +1,7 @@ """Module for the PINA Torch Optimizer""" +import copy + try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 except ImportError: @@ -8,16 +10,15 @@ ) # torch < 2.0 from ..utils import check_consistency -from .optimizer_interface import Optimizer -from .scheduler_interface import Scheduler +from .core.scheduler_connector import SchedulerConnector -class TorchScheduler(Scheduler): +class TorchScheduler(SchedulerConnector): """ A wrapper class for using PyTorch schedulers. """ - def __init__(self, scheduler_class, **kwargs): + def __init__(self, scheduler_class, **scheduler_kwargs): """ Initialization of the :class:`TorchScheduler` class. @@ -28,28 +29,5 @@ def __init__(self, scheduler_class, **kwargs): `here _`. """ check_consistency(scheduler_class, LRScheduler, subclass=True) - - self.scheduler_class = scheduler_class - self.kwargs = kwargs - self._scheduler_instance = None - - def hook(self, optimizer): - """ - Initialize the scheduler instance with the given parameters. - - :param dict parameters: The parameters of the optimizer. - """ - check_consistency(optimizer, Optimizer) - self._scheduler_instance = self.scheduler_class( - optimizer.instance, **self.kwargs - ) - - @property - def instance(self): - """ - Get the scheduler instance. - - :return: The scheduelr instance. - :rtype: torch.optim.LRScheduler - """ - return self._scheduler_instance + check_consistency(scheduler_kwargs, dict) + super().__init__(scheduler_class, **scheduler_kwargs) diff --git a/pina/solver/ensemble_solver/ensemble_solver_interface.py b/pina/solver/ensemble_solver/ensemble_solver_interface.py index 6d874e1bf..a64e78878 100644 --- a/pina/solver/ensemble_solver/ensemble_solver_interface.py +++ b/pina/solver/ensemble_solver/ensemble_solver_interface.py @@ -120,15 +120,15 @@ def training_step(self, batch): """ # zero grad for optimizer for opt in self.optimizers: - opt.instance.zero_grad() + opt.zero_grad() # perform forward passes and aggregate losses loss = super().training_step(batch) # perform backpropagation self.manual_backward(loss) # optimize for opt, sched in zip(self.optimizers, self.schedulers): - opt.instance.step() - sched.instance.step() + opt.step() + sched.step() return loss @property diff --git a/pina/solver/garom.py b/pina/solver/garom.py index 372eeddfa..8f646f882 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -151,7 +151,7 @@ def _train_generator(self, parameters, snapshots): :return: The residual loss and the generator loss. :rtype: tuple[torch.Tensor, torch.Tensor] """ - self.optimizer_generator.instance.zero_grad() + self.optimizer_generator.zero_grad() # Generate a batch of images generated_snapshots = self.sample(parameters) @@ -166,8 +166,8 @@ def _train_generator(self, parameters, snapshots): # backward step g_loss.backward() - self.optimizer_generator.instance.step() - self.scheduler_generator.instance.step() + self.optimizer_generator.step() + self.scheduler_generator.step() return r_loss, g_loss @@ -180,7 +180,7 @@ def _train_discriminator(self, parameters, snapshots): :return: The residual loss and the generator loss. 
:rtype: tuple[torch.Tensor, torch.Tensor] """ - self.optimizer_discriminator.instance.zero_grad() + self.optimizer_discriminator.zero_grad() # Generate a batch of images generated_snapshots = self.sample(parameters) @@ -196,8 +196,8 @@ def _train_discriminator(self, parameters, snapshots): # backward step d_loss.backward() - self.optimizer_discriminator.instance.step() - self.scheduler_discriminator.instance.step() + self.optimizer_discriminator.step() + self.scheduler_discriminator.step() return d_loss_real, d_loss_fake, d_loss diff --git a/pina/solver/physics_informed_solver/competitive_pinn.py b/pina/solver/physics_informed_solver/competitive_pinn.py index 5375efba1..bab8fff1b 100644 --- a/pina/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/solver/physics_informed_solver/competitive_pinn.py @@ -123,18 +123,18 @@ def training_step(self, batch): :rtype: LabelTensor """ # train model - self.optimizer_model.instance.zero_grad() + self.optimizer_model.zero_grad() loss = super().training_step(batch) self.manual_backward(loss) - self.optimizer_model.instance.step() - self.scheduler_model.instance.step() + self.optimizer_model.step() + self.scheduler_model.step() # train discriminator - self.optimizer_discriminator.instance.zero_grad() + self.optimizer_discriminator.zero_grad() loss = super().training_step(batch) self.manual_backward(-loss) - self.optimizer_discriminator.instance.step() - self.scheduler_discriminator.instance.step() + self.optimizer_discriminator.step() + self.scheduler_discriminator.step() return loss @@ -184,12 +184,10 @@ def configure_optimizers(self): :return: The optimizers and the schedulers :rtype: tuple[list[Optimizer], list[Scheduler]] """ - # If the problem is an InverseProblem, add the unknown parameters - # to the parameters to be optimized - self.optimizer_model.hook(self.neural_net.parameters()) - self.optimizer_discriminator.hook(self.discriminator.parameters()) + super().configure_optimizers() + # Add unknown parameters to optimization list in case of InverseProblem if isinstance(self.problem, InverseProblem): - self.optimizer_model.instance.add_param_group( + self.optimizer_model.add_param_group( { "params": [ self._params[var] @@ -197,19 +195,7 @@ def configure_optimizers(self): ] } ) - self.scheduler_model.hook(self.optimizer_model) - self.scheduler_discriminator.hook(self.optimizer_discriminator) - return ( - [ - self.optimizer_model.instance, - self.optimizer_discriminator.instance, - ], - [ - self.scheduler_model.instance, - self.scheduler_discriminator.instance, - ], - ) - + return self.optimizers, self.schedulers @property def neural_net(self): """ diff --git a/pina/solver/physics_informed_solver/pinn.py b/pina/solver/physics_informed_solver/pinn.py index 914d01451..4cb7dc0b6 100644 --- a/pina/solver/physics_informed_solver/pinn.py +++ b/pina/solver/physics_informed_solver/pinn.py @@ -4,7 +4,6 @@ from .pinn_interface import PINNInterface from ..solver import SingleSolverInterface -from ...problem import InverseProblem class PINN(PINNInterface, SingleSolverInterface): @@ -109,25 +108,3 @@ def loss_phys(self, samples, equation): """ residuals = self.compute_residual(samples, equation) return self._loss_fn(residuals, torch.zeros_like(residuals)) - - def configure_optimizers(self): - """ - Optimizer configuration for the PINN solver. - - :return: The optimizers and the schedulers - :rtype: tuple[list[Optimizer], list[Scheduler]] - """ - # If the problem is an InverseProblem, add the unknown parameters - # to the parameters to be optimized. 
- self.optimizer.hook(self.model.parameters()) - if isinstance(self.problem, InverseProblem): - self.optimizer.instance.add_param_group( - { - "params": [ - self._params[var] - for var in self.problem.unknown_variables - ] - } - ) - self.scheduler.hook(self.optimizer) - return ([self.optimizer.instance], [self.scheduler.instance]) diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/solver/physics_informed_solver/self_adaptive_pinn.py index b1d2a2cb4..4e2f0ea65 100644 --- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physics_informed_solver/self_adaptive_pinn.py @@ -183,22 +183,22 @@ def training_step(self, batch, batch_idx, **kwargs): :rtype: torch.Tensor """ # Weights optimization - self.optimizer_weights.instance.zero_grad() + self.optimizer_weights.zero_grad() loss = self._optimization_cycle( batch=batch, batch_idx=batch_idx, **kwargs ) self.manual_backward(-loss) - self.optimizer_weights.instance.step() - self.scheduler_weights.instance.step() + self.optimizer_weights.step() + self.scheduler_weights.step() # Model optimization - self.optimizer_model.instance.zero_grad() + self.optimizer_model.zero_grad() loss = self._optimization_cycle( batch=batch, batch_idx=batch_idx, **kwargs ) self.manual_backward(loss) - self.optimizer_model.instance.step() - self.scheduler_model.instance.step() + self.optimizer_model.step() + self.scheduler_model.step() # Log the loss self.store_log("train_loss", loss, self.get_batch_size(batch)) @@ -297,13 +297,10 @@ def configure_optimizers(self): :return: The optimizers and the schedulers :rtype: tuple[list[Optimizer], list[Scheduler]] """ - # Hook the optimizers to the models - self.optimizer_model.hook(self.model.parameters()) - self.optimizer_weights.hook(self.weights.parameters()) - + super().configure_optimizers() # Add unknown parameters to optimization list in case of InverseProblem if isinstance(self.problem, InverseProblem): - self.optimizer_model.instance.add_param_group( + self.optimizer_model.add_param_group( { "params": [ self._params[var] @@ -311,15 +308,7 @@ def configure_optimizers(self): ] } ) - - # Hook the schedulers to the optimizers - self.scheduler_model.hook(self.optimizer_model) - self.scheduler_weights.hook(self.optimizer_weights) - - return ( - [self.optimizer_model.instance, self.optimizer_weights.instance], - [self.scheduler_model.instance, self.scheduler_weights.instance], - ) + return self.optimizers, self.schedulers def _optimization_cycle(self, batch, batch_idx, **kwargs): """ diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 6948ec664..8aed69bb1 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -6,7 +6,9 @@ from torch._dynamo import OptimizedModule from ..problem import AbstractProblem, InverseProblem -from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler +from ..optim.core.optimizer_connector import OptimizerConnector as Optimizer +from ..optim.core.scheduler_connector import SchedulerConnector as Scheduler +from ..optim import TorchOptimizer, TorchScheduler from ..loss import WeightingInterface from ..loss.scalar_weighting import _NoWeighting from ..utils import check_consistency, labelize_forward @@ -59,6 +61,10 @@ def __init__(self, problem, weighting, use_lt): ) # PINA private attributes (some are overridden by derived classes) + #### self._pina_problem ---> link to AbstractProblem (or derived) + #### self._pina_models ---> link to torch.nn.Module (or derived) + #### self._pina_optimizers ---> link to 
OptimizerConnector (or derived) + #### self._pina_schedulers ---> link to SchedulerConnector (or derived) self._pina_problem = problem self._pina_models = None self._pina_optimizers = None @@ -107,6 +113,7 @@ def training_step(self, batch, **kwargs): :return: The loss of the training step. :rtype: torch.Tensor """ + self.current_batch = batch loss = self._optimization_cycle(batch=batch, **kwargs) self.store_log("train_loss", loss, self.get_batch_size(batch)) return loss @@ -124,6 +131,7 @@ def validation_step(self, batch, **kwargs): :return: The loss of the training step. :rtype: torch.Tensor """ + self.current_batch = batch losses = self.optimization_cycle(batch=batch, **kwargs) loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor) self.store_log("val_loss", loss, self.get_batch_size(batch)) @@ -142,6 +150,7 @@ def test_step(self, batch, **kwargs): :return: The loss of the training step. :rtype: torch.Tensor """ + self.current_batch = batch losses = self.optimization_cycle(batch=batch, **kwargs) loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor) self.store_log("test_loss", loss, self.get_batch_size(batch)) @@ -163,6 +172,21 @@ def store_log(self, name, value, batch_size): **self.trainer.logging_kwargs, ) + def lr_scheduler_step(self, scheduler, metric=None): + """ + The lr scheduler step overriden. This method is overridden in order to + ensure :class:`pina.optim.scheduler.Scheduler` + objects can be safely used. + + :param Scheduler scheduler: The scheduler instance. + :param str metric: The metric to track for the scheduling, + defaults to None. + """ + if metric is None: + scheduler.step() + else: + scheduler.step(metric) + def setup(self, stage): """ This method is called at the start of the train and test process to @@ -243,7 +267,7 @@ def _optimization_cycle(self, batch, **kwargs): :rtype: dict """ # compute losses - losses = self.optimization_cycle(batch) + losses = self.optimization_cycle(batch, **kwargs) # clamp unknown parameters in InverseProblem (if needed) self._clamp_params() # store log @@ -430,9 +454,18 @@ def configure_optimizers(self): :return: The optimizer and the scheduler :rtype: tuple[list[Optimizer], list[Scheduler]] """ - self.optimizer.hook(self.model.parameters()) + # get connector + optimizer_connector = self._pina_optimizers[0] + scheduler_connector = self._pina_schedulers[0] + # set the hooks + optimizer_connector._register_hooks( + parameters=self.model.parameters(), + solver=self, + ) + scheduler_connector._register_hooks(optimizer=optimizer_connector) + # only for inverse problems if isinstance(self.problem, InverseProblem): - self.optimizer.instance.add_param_group( + self.optimizer.add_param_group( { "params": [ self._params[var] @@ -440,8 +473,7 @@ def configure_optimizers(self): ] } ) - self.scheduler.hook(self.optimizer) - return ([self.optimizer.instance], [self.scheduler.instance]) + return ([self.optimizer], [self.scheduler]) @property def model(self): @@ -461,7 +493,7 @@ def scheduler(self): :return: The scheduler used for training. :rtype: Scheduler """ - return self._pina_schedulers[0] + return self._pina_schedulers[0].instance @property def optimizer(self): @@ -471,7 +503,7 @@ def optimizer(self): :return: The optimizer used for training. 
:rtype: Optimizer """ - return self._pina_optimizers[0] + return self._pina_optimizers[0].instance class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): @@ -601,15 +633,17 @@ def configure_optimizers(self): :return: The optimizer and the scheduler :rtype: tuple[list[Optimizer], list[Scheduler]] """ - for optimizer, scheduler, model in zip( - self.optimizers, self.schedulers, self.models + for optimizer_connector, scheduler_connector, model in zip( + self._pina_optimizers, self._pina_schedulers, self.models ): - optimizer.hook(model.parameters()) - scheduler.hook(optimizer) + optimizer_connector._register_hooks( + parameters=model.parameters(), solver=self + ) + scheduler_connector._register_hooks(optimizer=optimizer_connector) return ( - [optimizer.instance for optimizer in self.optimizers], - [scheduler.instance for scheduler in self.schedulers], + [optimizer for optimizer in self.optimizers], + [scheduler for scheduler in self.schedulers], ) @property @@ -630,7 +664,7 @@ def optimizers(self): :return: The optimizers used for training. :rtype: list[Optimizer] """ - return self._pina_optimizers + return [optimizer.instance for optimizer in self._pina_optimizers] @property def schedulers(self): @@ -640,4 +674,4 @@ def schedulers(self): :return: The schedulers used for training. :rtype: list[Scheduler] """ - return self._pina_schedulers + return [scheduler.instance for scheduler in self._pina_schedulers] diff --git a/tests/test_callback/test_optimizer_callback.py b/tests/test_callback/test_optimizer_callback.py index 3383c792c..54d9bd747 100644 --- a/tests/test_callback/test_optimizer_callback.py +++ b/tests/test_callback/test_optimizer_callback.py @@ -57,7 +57,7 @@ def test_switch_optimizer_routine(new_opt, epoch_switch): trainer.train() # Check that the trainer strategy optimizers have been updated - assert solver.optimizer.instance.__class__ == new_opt.instance.__class__ + assert solver.optimizer.__class__ == new_opt.instance.__class__ assert ( trainer.strategy.optimizers[0].__class__ == new_opt.instance.__class__ ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 037de9929..b1aec99f1 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,7 +1,12 @@ import torch import pytest +from pina.solver import SupervisedSolver +from pina.problem.zoo import SupervisedProblem from pina.optim import TorchOptimizer + +problem = SupervisedProblem(torch.randn(2, 10), torch.randn(2, 10)) +model = torch.nn.Linear(10, 10) opt_list = [ torch.optim.Adam, torch.optim.AdamW, @@ -16,6 +21,27 @@ def test_constructor(optimizer_class): @pytest.mark.parametrize("optimizer_class", opt_list) -def test_hook(optimizer_class): +def test_parameter_hook(optimizer_class): + opt = TorchOptimizer(optimizer_class, lr=1e-3) + assert opt.hooks_done["parameter_hook"] is False + opt.parameter_hook(model.parameters()) + assert opt.hooks_done["parameter_hook"] is True + + +@pytest.mark.parametrize("optimizer_class", opt_list) +def test_solver_hook(optimizer_class): + opt = TorchOptimizer(optimizer_class, lr=1e-3) + solver = SupervisedSolver(problem=problem, model=model, optimizer=opt) + assert opt.hooks_done["solver_hook"] is False + with pytest.raises(RuntimeError): + opt.solver_hook(solver) + solver.configure_optimizers() + assert opt.hooks_done["solver_hook"] is True + assert opt.solver is solver + + +@pytest.mark.parametrize("optimizer_class", opt_list) +def test_instance(optimizer_class): opt = TorchOptimizer(optimizer_class, lr=1e-3) - opt.hook(torch.nn.Linear(10, 
10).parameters()) + opt.parameter_hook(model.parameters()) + assert isinstance(opt.instance, optimizer_class) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 157a818d2..47bc9afc1 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,6 +2,7 @@ import pytest from pina.optim import TorchOptimizer, TorchScheduler +model = torch.nn.Linear(10, 10) opt_list = [ torch.optim.Adam, torch.optim.AdamW, @@ -21,6 +22,18 @@ def test_constructor(scheduler_class): @pytest.mark.parametrize("scheduler_class", sch_list) def test_hook(optimizer_class, scheduler_class): opt = TorchOptimizer(optimizer_class, lr=1e-3) - opt.hook(torch.nn.Linear(10, 10).parameters()) + opt.parameter_hook(model.parameters()) sch = TorchScheduler(scheduler_class) - sch.hook(opt) + assert sch.hooks_done["optimizer_hook"] is False + sch.optimizer_hook(opt) + assert sch.hooks_done["optimizer_hook"] is True + + +@pytest.mark.parametrize("optimizer_class", opt_list) +@pytest.mark.parametrize("scheduler_class", sch_list) +def test_instance(optimizer_class, scheduler_class): + opt = TorchOptimizer(optimizer_class, lr=1e-3) + opt.parameter_hook(model.parameters()) + sch = TorchScheduler(scheduler_class) + sch.optimizer_hook(opt) + assert isinstance(sch.instance, scheduler_class)
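
Usage note (not part of the patch): the sketch below shows, under stated assumptions, how the new connector API introduced here is exercised end to end. It mirrors the updated tests; the toy problem, the Linear model, and the choice of ConstantLR are illustrative stand-ins, not anything mandated by the patch.

import torch

from pina.optim import TorchOptimizer, TorchScheduler
from pina.problem.zoo import SupervisedProblem
from pina.solver import SupervisedSolver

# Toy problem and model, as in the updated tests (illustrative only).
model = torch.nn.Linear(10, 10)
problem = SupervisedProblem(torch.randn(2, 10), torch.randn(2, 10))

optimizer = TorchOptimizer(torch.optim.Adam, lr=1e-3)
solver = SupervisedSolver(problem=problem, model=model, optimizer=optimizer)

# Nothing is hooked yet: the torch optimizer is built lazily by the connector.
assert optimizer.hooks_done["parameter_hook"] is False
assert optimizer.hooks_done["solver_hook"] is False

# configure_optimizers() calls _register_hooks(parameters=..., solver=...):
# parameter_hook instantiates torch.optim.Adam on the model parameters, and
# solver_hook attaches the solver to both the connector and the instance.
solver.configure_optimizers()
assert optimizer.hooks_done["parameter_hook"] is True
assert optimizer.solver is solver
assert isinstance(optimizer.instance, torch.optim.Adam)

# A scheduler connector hooks onto an already-hooked optimizer connector;
# ConstantLR is an arbitrary scheduler picked for the example.
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
scheduler.optimizer_hook(optimizer)
assert isinstance(scheduler.instance, torch.optim.lr_scheduler.ConstantLR)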
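
A second sketch (also illustrative, not part of the patch) of the _HooksOptim tracking mechanism: any concrete method whose name ends in "_hook" is wrapped by __init_subclass__, so simply calling it flips the matching hooks_done entry to True. MyConnector is a hypothetical subclass used only to demonstrate this, and importing the private optim_connector_interface module is done here purely for illustration.

from pina.optim.core.optim_connector_interface import _HooksOptim


class MyConnector(_HooksOptim):
    # __init_subclass__ wraps this method because its name ends in "_hook",
    # so executing it updates hooks_done["parameter_hook"] automatically.
    def parameter_hook(self, parameters):
        self._params = list(parameters)


conn = MyConnector()
assert conn.hooks_done == {"parameter_hook": False}

# Any iterable of parameters works; an empty list keeps the sketch minimal.
conn.parameter_hook([])
assert conn.hooks_done["parameter_hook"] is True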