Skip to content

Commit 92aede6

Browse files
committed
optim refactoring
* Add connectors for optimizers/schedulers
* Simplify the `configure_optimizers` logic
1 parent 6d7ce0e commit 92aede6

18 files changed

+413
-205
lines changed

pina/callback/optimizer_callback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def on_train_epoch_start(self, trainer, __):
6262

6363
# Hook the new optimizers to the model parameters
6464
for idx, optim in enumerate(self._new_optimizers):
65-
optim.hook(trainer.solver._pina_models[idx].parameters())
65+
optim._register_hooks(
66+
parameters=trainer.solver._pina_models[idx].parameters(),
67+
solver=trainer.solver,
68+
)
6669
optims.append(optim)
6770

6871
# Update the solver's optimizers

pina/optim/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
"""Module for the Optimizers and Schedulers."""
22

33
__all__ = [
4-
"Optimizer",
54
"TorchOptimizer",
6-
"Scheduler",
75
"TorchScheduler",
86
]
97

10-
from .optimizer_interface import Optimizer
118
from .torch_optimizer import TorchOptimizer
12-
from .scheduler_interface import Scheduler
13-
from .torch_scheduler import TorchScheduler
9+
from .torch_scheduler import TorchScheduler
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Module for the PINA Optimizer and Scheduler Connectors Interface."""
2+
3+
from abc import ABCMeta, abstractmethod
4+
from functools import wraps
5+
6+
7+
class OptimizerConnectorInterface(metaclass=ABCMeta):
    """
    Interface class for method definitions in the Optimizer classes.

    Concrete connectors must implement both hooks below: one wires the model
    parameters into the wrapped optimizer, the other attaches the solver.
    """

    @abstractmethod
    def parameter_hook(self, parameters):
        """
        Abstract method to define the hook logic for the optimizer. This hook
        is used to initialize the optimizer instance with the given parameters.

        :param dict parameters: The parameters of the model to be optimized.
        """

    @abstractmethod
    def solver_hook(self, solver):
        """
        Abstract method to define the hook logic for the optimizer. This hook
        is used to attach the given solver to the optimizer instance.

        :param SolverInterface solver: The solver to hook.
        """
29+
30+
31+
class SchedulerConnectorInterface(metaclass=ABCMeta):
    """
    Abstract base class for defining a scheduler. All specific schedulers
    should inherit from this class and implement the required methods.
    """

    # FIX: the abstract method previously took no ``optimizer`` argument,
    # while every concrete implementation (e.g. SchedulerConnector) is called
    # as ``optimizer_hook(optimizer)`` — the interface now matches its users.
    @abstractmethod
    def optimizer_hook(self, optimizer):
        """
        Abstract method to define the hook logic for the scheduler. This hook
        is used to hook the scheduler instance with the given optimizer.

        :param OptimizerConnector optimizer: The optimizer whose wrapped
            instance the scheduler is attached to.
        """
43+
44+
45+
class _HooksOptim:
46+
"""
47+
Mixin class to manage and track the execution of hook methods in optimizer
48+
or scheduler classes.
49+
50+
This class automatically detects methods ending with `_hook` and tracks
51+
whether they have been executed for a given instance. Subclasses defining
52+
`_hook` methods benefit from automatic tracking without additional
53+
boilerplate.
54+
"""
55+
def __init__(self, *args, **kwargs):
56+
"""
57+
Initialize the hooks tracking dictionary `hooks_done` for this instance.
58+
59+
Each hook method detected in the class hierarchy is added to
60+
`hooks_done` with an initial value of False (not executed).
61+
"""
62+
super().__init__(*args, **kwargs)
63+
# Initialize hooks_done per instance
64+
self.hooks_done = {}
65+
for cls in self.__class__.__mro__:
66+
for attr_name, attr_value in cls.__dict__.items():
67+
if callable(attr_value) and attr_name.endswith("_hook"):
68+
self.hooks_done.setdefault(attr_name, False)
69+
70+
def __init_subclass__(cls, **kwargs):
71+
"""
72+
Hook called when a subclass of _HooksOptim is created.
73+
74+
Wraps all concrete `_hook` methods defined in the subclass so that
75+
executing the method automatically updates `hooks_done`.
76+
"""
77+
super().__init_subclass__(**kwargs)
78+
# Wrap only concrete _hook methods defined in this subclass
79+
for attr_name, attr_value in cls.__dict__.items():
80+
if callable(attr_value) and attr_name.endswith("_hook"):
81+
setattr(cls, attr_name, cls.hook_wrapper(attr_name, attr_value))
82+
83+
@staticmethod
84+
def hook_wrapper(name, func):
85+
"""
86+
Wrap a hook method to mark it as executed after calling it.
87+
88+
:param str name: The name of the hook method.
89+
:param callable func: The original hook method to wrap.
90+
:return: The wrapped hook method that updates `hooks_done`.
91+
"""
92+
@wraps(func)
93+
def wrapper(self, *args, **kwargs):
94+
result = func(self, *args, **kwargs)
95+
self.hooks_done[name] = True
96+
return result
97+
98+
return wrapper
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Module for the PINA Optimizer."""
2+
3+
from .optim_connector_interface import OptimizerConnectorInterface, _HooksOptim
4+
5+
6+
class OptimizerConnector(OptimizerConnectorInterface, _HooksOptim):
    """
    Base class for defining an optimizer connector. All specific optimizer
    connectors should inherit from this class and implement the required
    methods.
    """

    def __init__(self, optimizer_class, **optimizer_class_kwargs):
        """
        Initialize connector parameters.

        :param torch.optim.Optimizer optimizer_class: The torch optimizer
            class.
        :param dict optimizer_class_kwargs: The optimizer kwargs.
        """
        super().__init__()
        self._optimizer_class = optimizer_class
        # The torch optimizer is only built once parameter_hook runs.
        self._optimizer_instance = None
        self._optim_kwargs = optimizer_class_kwargs
        self._solver = None

    def parameter_hook(self, parameters):
        """
        Hook that initializes the optimizer instance with the given
        model parameters.

        :param dict parameters: The parameters of the model to be optimized.
        """
        self._optimizer_instance = self._optimizer_class(
            parameters, **self._optim_kwargs
        )

    def solver_hook(self, solver):
        """
        Hook that attaches the given solver to the optimizer. Requires
        ``parameter_hook`` to have run first, since the wrapped optimizer
        instance must already exist.

        :param SolverInterface solver: The solver to hook.
        :raises RuntimeError: If ``parameter_hook`` has not been executed yet.
        """
        if not self.hooks_done["parameter_hook"]:
            raise RuntimeError(
                "Cannot run 'solver_hook' before 'parameter_hook'. "
                "Please call 'parameter_hook' first to initialize "
                "the solver parameters."
            )
        # Hook the solver to both the connector and the wrapped instance.
        self._solver = solver
        self._optimizer_instance.solver = solver

    def _register_hooks(self, **kwargs):
        """
        Register the optimizer hooks. This method inspects keyword arguments
        for known keys (``parameters``, ``solver``) and applies the
        corresponding hooks, in that order.

        It allows flexible integration with different workflows without
        enforcing a strict method signature.

        This method is used inside the
        :class:`~pina.solver.solver.SolverInterface` class.

        :param kwargs: Expected keys may include:
            - ``parameters``: Parameters to be registered for optimization.
            - ``solver``: Solver instance.
        """
        # parameter hook must run before solver hook (enforced in solver_hook)
        parameters = kwargs.get("parameters", None)
        if parameters is not None:
            self.parameter_hook(parameters)
        # solver hook
        solver = kwargs.get("solver", None)
        if solver is not None:
            self.solver_hook(solver)

    @property
    def solver(self):
        """
        Get the solver hooked to the optimizer.

        :raises RuntimeError: If ``solver_hook`` has not been executed yet.
        """
        if not self.hooks_done["solver_hook"]:
            # FIX: original message concatenated to "hooked.Override" — a
            # space was missing between the two string literals.
            raise RuntimeError(
                "Solver has not been hooked. "
                "Override the method solver_hook to hook the solver to "
                "the optimizer."
            )
        return self._solver

    @property
    def instance(self):
        """
        Get the optimizer instance.

        :return: The optimizer instance
        :rtype: torch.optim.Optimizer
        """
        return self._optimizer_instance
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Module for the PINA Scheduler."""
2+
3+
from .optim_connector_interface import SchedulerConnectorInterface, _HooksOptim
4+
from .optimizer_connector import OptimizerConnector
5+
from ...utils import check_consistency
6+
7+
8+
class SchedulerConnector(SchedulerConnectorInterface, _HooksOptim):
    """
    Class for defining a scheduler connector. All specific scheduler
    connectors should inherit from this class and implement the required
    methods.
    """

    def __init__(self, scheduler_class, **scheduler_kwargs):
        """
        Initialize connector parameters.

        :param torch.optim.lr_scheduler.LRScheduler scheduler_class: The torch
            scheduler class.
        :param dict scheduler_kwargs: The scheduler kwargs.
        """
        super().__init__()
        self._scheduler_class = scheduler_class
        # The torch scheduler is only built once optimizer_hook runs.
        self._scheduler_instance = None
        self._scheduler_kwargs = scheduler_kwargs

    def optimizer_hook(self, optimizer):
        """
        Hook that attaches the scheduler to the given optimizer. The optimizer
        must already be hooked to the model parameters, since the torch
        scheduler is built from the wrapped optimizer instance.

        :param OptimizerConnector optimizer: The optimizer to hook.
        :raises RuntimeError: If the optimizer's ``parameter_hook`` has not
            been executed yet.
        """
        check_consistency(optimizer, OptimizerConnector)
        if not optimizer.hooks_done["parameter_hook"]:
            raise RuntimeError(
                "Scheduler cannot be set, Optimizer not hooked "
                "to model parameters. "
                "Please call Optimizer.parameter_hook()."
            )
        self._scheduler_instance = self._scheduler_class(
            optimizer.instance, **self._scheduler_kwargs
        )

    def _register_hooks(self, **kwargs):
        """
        Register the scheduler hooks. This method inspects keyword arguments
        for known keys (``optimizer``) and applies the corresponding hooks.

        It allows flexible integration with different workflows without
        enforcing a strict method signature.

        This method is used inside the
        :class:`~pina.solver.solver.SolverInterface` class.

        :param kwargs: Expected keys may include:
            - ``optimizer``: The :class:`OptimizerConnector` the scheduler is
              attached to.
        """
        # optimizer hook; type checking is performed inside optimizer_hook,
        # so the duplicate check_consistency call is not repeated here
        optimizer = kwargs.get("optimizer", None)
        if optimizer is not None:
            self.optimizer_hook(optimizer)

    @property
    def instance(self):
        """
        Get the scheduler instance.

        :return: The scheduler instance
        :rtype: torch.optim.lr_scheduler.LRScheduler
        """
        return self._scheduler_instance

pina/optim/optimizer_interface.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

pina/optim/scheduler_interface.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)