5 changes: 4 additions & 1 deletion pina/callback/optimizer_callback.py
@@ -62,7 +62,10 @@ def on_train_epoch_start(self, trainer, __):

# Hook the new optimizers to the model parameters
for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver._pina_models[idx].parameters())
optim._register_hooks(
parameters=trainer.solver._pina_models[idx].parameters(),
solver=trainer.solver,
)
optims.append(optim)

# Update the solver's optimizers
6 changes: 1 addition & 5 deletions pina/optim/__init__.py
@@ -1,13 +1,9 @@
"""Module for the Optimizers and Schedulers."""

__all__ = [
"Optimizer",
"TorchOptimizer",
"Scheduler",
"TorchScheduler",
]

from .optimizer_interface import Optimizer
from .torch_optimizer import TorchOptimizer
from .scheduler_interface import Scheduler
from .torch_scheduler import TorchScheduler
from .torch_scheduler import TorchScheduler
98 changes: 98 additions & 0 deletions pina/optim/core/optim_connector_interface.py
@@ -0,0 +1,98 @@
"""Module for the PINA Optimizer and Scheduler Connectors Interface."""

from abc import ABCMeta, abstractmethod
from functools import wraps


class OptimizerConnectorInterface(metaclass=ABCMeta):
"""
Interface class for method definitions in the Optimizer classes.
"""

@abstractmethod
def parameter_hook(self, parameters):
"""
Abstract method to define the hook logic for the optimizer. This hook
is used to initialize the optimizer instance with the given parameters.

:param parameters: The parameters of the model to be optimized.
"""

@abstractmethod
def solver_hook(self, solver):
"""
Abstract method to define the hook logic for the optimizer. This hook
is used to hook the optimizer instance to the given solver.

:param SolverInterface solver: The solver to hook.
"""


class SchedulerConnectorInterface(metaclass=ABCMeta):
"""
Abstract base class for defining a scheduler connector. All specific
scheduler connectors should inherit from this class and implement the
required methods.
"""

@abstractmethod
def optimizer_hook(self, optimizer):
"""
Abstract method to define the hook logic for the scheduler. This hook
is used to hook the scheduler instance to the given optimizer.

:param OptimizerConnector optimizer: The optimizer connector to hook.
"""


class _HooksOptim:
"""
Mixin class to manage and track the execution of hook methods in optimizer
or scheduler classes.

This class automatically detects methods ending with `_hook` and tracks
whether they have been executed for a given instance. Subclasses defining
`_hook` methods benefit from automatic tracking without additional
boilerplate.
"""
def __init__(self, *args, **kwargs):
"""
Initialize the hooks tracking dictionary `hooks_done` for this instance.

Each hook method detected in the class hierarchy is added to
`hooks_done` with an initial value of False (not executed).
"""
super().__init__(*args, **kwargs)
# Initialize hooks_done per instance
self.hooks_done = {}
for cls in self.__class__.__mro__:
for attr_name, attr_value in cls.__dict__.items():
if callable(attr_value) and attr_name.endswith("_hook"):
self.hooks_done.setdefault(attr_name, False)

def __init_subclass__(cls, **kwargs):
"""
Hook called when a subclass of _HooksOptim is created.

Wraps all concrete `_hook` methods defined in the subclass so that
executing the method automatically updates `hooks_done`.
"""
super().__init_subclass__(**kwargs)
# Wrap only concrete _hook methods defined in this subclass
for attr_name, attr_value in cls.__dict__.items():
if callable(attr_value) and attr_name.endswith("_hook"):
setattr(cls, attr_name, cls.hook_wrapper(attr_name, attr_value))

@staticmethod
def hook_wrapper(name, func):
"""
Wrap a hook method to mark it as executed after calling it.

:param str name: The name of the hook method.
:param callable func: The original hook method to wrap.
:return: The wrapped hook method that updates `hooks_done`.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
self.hooks_done[name] = True
return result

return wrapper
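For reference, a minimal sketch (not part of this diff) of how the `_HooksOptim` tracking is expected to behave; `DemoConnector` and `demo_hook` are hypothetical names, and the import path assumes the module layout added in this PR:

from pina.optim.core.optim_connector_interface import _HooksOptim

class DemoConnector(_HooksOptim):
    def demo_hook(self, value):
        # Any concrete method whose name ends in "_hook" is wrapped by
        # __init_subclass__, so calling it flips the corresponding flag.
        return value * 2

connector = DemoConnector()
print(connector.hooks_done)   # {'demo_hook': False}
connector.demo_hook(21)
print(connector.hooks_done)   # {'demo_hook': True}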
100 changes: 100 additions & 0 deletions pina/optim/core/optimizer_connector.py
@@ -0,0 +1,100 @@
"""Module for the PINA Optimizer."""

from .optim_connector_interface import OptimizerConnectorInterface, _HooksOptim


class OptimizerConnector(OptimizerConnectorInterface, _HooksOptim):
"""
Base class for defining an optimizer connector. All specific optimizer
connectors should inherit from this class and implement the required
methods.
"""

def __init__(self, optimizer_class, **optimizer_class_kwargs):
"""
Initialize the connector parameters.

:param torch.optim.Optimizer optimizer_class: The torch optimizer class.
:param dict optimizer_class_kwargs: The optimizer kwargs.
"""
super().__init__()
self._optimizer_class = optimizer_class
self._optimizer_instance = None
self._optim_kwargs = optimizer_class_kwargs
self._solver = None

def parameter_hook(self, parameters):
"""
Hook the optimizer to the model parameters by initializing the optimizer
instance with the given parameters.

:param parameters: The parameters of the model to be optimized.
"""
self._optimizer_instance = self._optimizer_class(
parameters, **self._optim_kwargs
)

def solver_hook(self, solver):
"""
Hook the optimizer instance and the connector to the given solver.

:param SolverInterface solver: The solver to hook.
"""
if not self.hooks_done["parameter_hook"]:
raise RuntimeError(
"Cannot run 'solver_hook' before 'parameter_hook'. "
"Please call 'parameter_hook' first to initialize "
"the solver parameters."
)
# hook the solver to both the connector and the optimizer instance
self._solver = solver
self._optimizer_instance.solver = solver

def _register_hooks(self, **kwargs):
"""
Register the optimizer hooks. This method inspects the keyword arguments
for known keys (``parameters``, ``solver``, ...) and applies the
corresponding hooks. It allows flexible integration with different
workflows without enforcing a strict method signature.

This method is used inside the
:class:`~pina.solver.solver.SolverInterface` class.

:param kwargs: Expected keys may include:
- ``parameters``: Parameters to be registered for optimization.
- ``solver``: Solver instance.
"""
# parameter hook
parameters = kwargs.get("parameters", None)
if parameters is not None:
self.parameter_hook(parameters)
# solver hook
solver = kwargs.get("solver", None)
if solver is not None:
self.solver_hook(solver)

@property
def solver(self):
"""
Get the solver hooked to the optimizer.
"""
if not self.hooks_done["solver_hook"]:
raise RuntimeError(
"Solver has not been hooked."
"Override the method solver_hook to hook the solver to "
"the optimizer."
)
return self._solver

@property
def instance(self):
"""
Get the optimizer instance.

:return: The optimizer instance
:rtype: torch.optim.Optimizer
"""
return self._optimizer_instance
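As a usage illustration (not part of this diff), the connector could be exercised roughly as below; the tiny model and the plain stand-in solver object are assumptions, since in PINA the hooks are normally registered by the solver through `_register_hooks`:

import torch
from pina.optim.core.optimizer_connector import OptimizerConnector

model = torch.nn.Linear(2, 1)
connector = OptimizerConnector(torch.optim.Adam, lr=1e-3)

# 'parameters' is applied before 'solver', which satisfies the ordering
# check enforced by solver_hook.
connector._register_hooks(parameters=model.parameters(), solver=object())

print(connector.instance)     # a torch.optim.Adam bound to the model parameters
print(connector.hooks_done)   # {'parameter_hook': True, 'solver_hook': True}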
75 changes: 75 additions & 0 deletions pina/optim/core/scheduler_connector.py
@@ -0,0 +1,75 @@
"""Module for the PINA Scheduler."""

from .optim_connector_interface import SchedulerConnectorInterface, _HooksOptim
from .optimizer_connector import OptimizerConnector
from ...utils import check_consistency


class SchedulerConnector(SchedulerConnectorInterface, _HooksOptim):
"""
Class for defining a scheduler connector. All specific scheduler
connectors should inherit from this class and implement the required
methods.
"""

def __init__(self, scheduler_class, **scheduler_kwargs):
"""
Initialize the connector parameters.

:param torch.optim.lr_scheduler.LRScheduler scheduler_class: The torch
scheduler class.
:param dict scheduler_kwargs: The scheduler kwargs.
"""
super().__init__()
self._scheduler_class = scheduler_class
self._scheduler_instance = None
self._scheduler_kwargs = scheduler_kwargs

def optimizer_hook(self, optimizer):
"""
Hook the scheduler instance to the given optimizer by initializing the
scheduler with the optimizer instance.

:param OptimizerConnector optimizer: The optimizer connector to hook.
"""
check_consistency(optimizer, OptimizerConnector)
if not optimizer.hooks_done["parameter_hook"]:
raise RuntimeError(
"Scheduler cannot be set, Optimizer not hooked "
"to model parameters. "
"Please call Optimizer.parameter_hook()."
)
self._scheduler_instance = self._scheduler_class(
optimizer.instance, **self._scheduler_kwargs
)

def _register_hooks(self, **kwargs):
"""
Register the scheduler hooks. This method inspects the keyword arguments
for known keys (``optimizer``, ...) and applies the corresponding hooks.
It allows flexible integration with different workflows without
enforcing a strict method signature.

This method is used inside the
:class:`~pina.solver.solver.SolverInterface` class.

:param kwargs: Expected keys may include:
- ``optimizer``: The optimizer connector to hook the scheduler to.
"""
# optimizer hook
optimizer = kwargs.get("optimizer", None)
if optimizer is not None:
check_consistency(optimizer, OptimizerConnector)
self.optimizer_hook(optimizer)

@property
def instance(self):
"""
Get the scheduler instance.

:return: The scheduler instance
:rtype: torch.optim.lr_scheduler.LRScheduler
"""
return self._scheduler_instance
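A corresponding sketch (again not part of this diff) chaining the two connectors; the SGD and StepLR choices are arbitrary and the import paths assume the layout added in this PR:

import torch
from pina.optim.core.optimizer_connector import OptimizerConnector
from pina.optim.core.scheduler_connector import SchedulerConnector

model = torch.nn.Linear(2, 1)
optimizer = OptimizerConnector(torch.optim.SGD, lr=1e-2)
scheduler = SchedulerConnector(torch.optim.lr_scheduler.StepLR, step_size=10)

# The optimizer must be hooked to the model parameters first, otherwise
# optimizer_hook raises a RuntimeError.
optimizer.parameter_hook(model.parameters())
scheduler._register_hooks(optimizer=optimizer)

print(scheduler.instance)     # a StepLR wrapping the underlying SGD instance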
23 changes: 0 additions & 23 deletions pina/optim/optimizer_interface.py

This file was deleted.

23 changes: 0 additions & 23 deletions pina/optim/scheduler_interface.py

This file was deleted.
