Skip to content

Commit dc808c1

Browse files
dario-cosciaFilippoOlivoGiovanniCanali
authored
Add Normalizer Callback (#631)
* add normalizer callback * implement shift and scale parameters computation * change name files normalizer data callback * reduce tests * fix documentation * add NotImplementedError for PinaGraphDataset --------- Co-authored-by: FilippoOlivo <[email protected]> Co-authored-by: giovanni <[email protected]>
1 parent ef75f13 commit dc808c1

File tree

6 files changed

+483
-1
lines changed

6 files changed

+483
-1
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ Callbacks
253253
Optimizer callback <callback/optimizer_callback.rst>
254254
R3 Refinement callback <callback/refinement/r3_refinement.rst>
255255
Refinement Interface callback <callback/refinement/refinement_interface.rst>
256+
Normalizer callback <callback/normalizer_data_callback.rst>
256257

257258
Losses and Weightings
258259
---------------------
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Normalizer callbacks
2+
=======================
3+
4+
.. currentmodule:: pina.callback.normalizer_data_callback
5+
.. autoclass:: NormalizerDataCallback
6+
:members:
7+
:show-inheritance:

pina/callback/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
"MetricTracker",
66
"PINAProgressBar",
77
"R3Refinement",
8+
"NormalizerDataCallback",
89
]
910

1011
from .optimizer_callback import SwitchOptimizer
1112
from .processing_callback import MetricTracker, PINAProgressBar
1213
from .refinement import R3Refinement
14+
from .normalizer_data_callback import NormalizerDataCallback
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""Module for the Normalizer callback."""
2+
3+
import torch
4+
from lightning.pytorch import Callback
5+
from ..label_tensor import LabelTensor
6+
from ..utils import check_consistency, is_function
7+
from ..condition import InputTargetCondition
8+
from ..data.dataset import PinaGraphDataset
9+
10+
11+
class NormalizerDataCallback(Callback):
    r"""
    A Callback used to normalize the dataset inputs or targets according to
    user-provided scale and shift functions.

    The transformation is applied as:

    .. math::

        x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}

    The scale and shift parameters are computed once, from the training
    dataset, separately for every condition of the problem that is an
    ``InputTargetCondition``. The same parameters are then reused for every
    dataset that gets normalized.

    :Example:

        >>> NormalizerDataCallback()
        >>> NormalizerDataCallback(
        ...     scale_fn=torch.std,
        ...     shift_fn=torch.mean,
        ...     stage="all",
        ...     apply_to="input",
        ... )
    """

    def __init__(
        self,
        scale_fn=torch.std,
        shift_fn=torch.mean,
        stage="all",
        apply_to="input",
    ):
        """
        Initialization of the :class:`NormalizerDataCallback` class.

        :param Callable scale_fn: The function to compute the scaling factor.
            Default is ``torch.std``.
        :param Callable shift_fn: The function to compute the shifting factor.
            Default is ``torch.mean``.
        :param str stage: The stage in which normalization is applied.
            Accepted values are "train", "validate", "test", or "all".
            Default is ``"all"``.
        :param str apply_to: Whether to normalize "input" or "target" data.
            Default is ``"input"``.
        :raises ValueError: If ``scale_fn`` is not callable.
        :raises ValueError: If ``shift_fn`` is not callable.
        :raises ValueError: If ``stage`` or ``apply_to`` is not one of the
            accepted values.
        """
        super().__init__()

        # Validate parameters
        self.apply_to = self._validate_apply_to(apply_to)
        self.stage = self._validate_stage(stage)

        # Validate functions
        if not is_function(scale_fn):
            raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
        if not is_function(shift_fn):
            raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
        self.scale_fn = scale_fn
        self.shift_fn = shift_fn

        # Maps a condition name to its {"shift": ..., "scale": ...} entry;
        # filled lazily by ``setup``.
        self._normalizer = {}

    def _validate_apply_to(self, apply_to):
        """
        Validate the ``apply_to`` parameter.

        :param str apply_to: The candidate value for the ``apply_to`` parameter.
        :raises ValueError: If ``apply_to`` is neither "input" nor "target".
        :return: The validated ``apply_to`` value.
        :rtype: str
        """
        check_consistency(apply_to, str)
        if apply_to not in {"input", "target"}:
            raise ValueError(
                f"apply_to must be either 'input' or 'target', got {apply_to}"
            )

        return apply_to

    def _validate_stage(self, stage):
        """
        Validate the ``stage`` parameter.

        :param str stage: The candidate value for the ``stage`` parameter.
        :raises ValueError: If ``stage`` is not one of "train", "validate",
            "test", or "all".
        :return: The validated ``stage`` value.
        :rtype: str
        """
        check_consistency(stage, str)
        if stage not in {"train", "validate", "test", "all"}:
            raise ValueError(
                "stage must be one of 'train', 'validate', 'test', or 'all',"
                f" got {stage}"
            )

        return stage

    def setup(self, trainer, pl_module, stage):
        """
        Apply normalization during setup.

        The normalization parameters are computed from the training dataset
        the first time they are needed and kept for later calls. The datasets
        matching both the current ``stage`` and the ``stage`` chosen at
        initialization are then normalized in place.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: The current stage.
        :raises NotImplementedError: If the training dataset is graph-based.
        :raises RuntimeError: If the training dataset is not available when
            computing normalization parameters.
        :return: The result of the parent setup.
        :rtype: Any
        """
        datamodule = trainer.datamodule
        train_dataset = datamodule.train_dataset

        # Ensure datasets are not graph-based
        if isinstance(train_dataset, PinaGraphDataset):
            raise NotImplementedError(
                "NormalizerDataCallback is not compatible with "
                "graph-based datasets."
            )

        # Only input-target conditions carry data that can be normalized
        conditions_to_normalize = [
            name
            for name, cond in pl_module.problem.conditions.items()
            if isinstance(cond, InputTargetCondition)
        ]

        # Compute scale and shift parameters only once: an already populated
        # normalizer is reused as is on later setup calls.
        if not self.normalizer:
            if not train_dataset:
                raise RuntimeError(
                    "Training dataset is not available. Cannot compute "
                    "normalization parameters."
                )
            self._compute_scale_shift(conditions_to_normalize, train_dataset)

        # Apply normalization based on the specified stage. The validation
        # and test datasets are only looked up when they are needed.
        if stage == "fit":
            if self.stage in ("train", "all"):
                self.normalize_dataset(train_dataset)
            if self.stage in ("validate", "all"):
                self.normalize_dataset(datamodule.val_dataset)
        elif stage == "test" and self.stage in ("test", "all"):
            self.normalize_dataset(datamodule.test_dataset)

        return super().setup(trainer, pl_module, stage)

    def _compute_scale_shift(self, conditions, dataset):
        """
        Compute scale and shift parameters for each condition in the dataset.

        Conditions that are not stored in the dataset are ignored.

        :param list conditions: The list of condition names.
        :param dataset: The `~pina.data.dataset.PinaDataset` dataset.
        """
        # NOTE(review): a ``scale_fn`` returning zero (e.g. constant data
        # with ``torch.std``) leads to a division by zero in ``_norm_fn``;
        # confirm whether that case has to be rejected here.
        for cond in conditions:
            if cond not in dataset.conditions_dict:
                continue
            data = dataset.conditions_dict[cond][self.apply_to]
            self._normalizer[cond] = {
                "shift": self.shift_fn(data),
                "scale": self.scale_fn(data),
            }

    @staticmethod
    def _norm_fn(value, scale, shift):
        """
        Normalize a value according to the scale and shift parameters.

        :param value: The input tensor to normalize.
        :type value: torch.Tensor | LabelTensor
        :param scale: The scaling factor, as returned by ``scale_fn``.
        :param shift: The shifting factor, as returned by ``shift_fn``.
        :return: The normalized tensor. The labels of ``value`` are kept when
            it is a :class:`LabelTensor`.
        :rtype: torch.Tensor | LabelTensor
        """
        scaled_value = (value - shift) / scale
        if isinstance(value, LabelTensor):
            # Re-attach the original labels to the result of the arithmetic
            scaled_value = LabelTensor(scaled_value, value.labels)

        return scaled_value

    def normalize_dataset(self, dataset):
        """
        Apply in-place normalization to the dataset.

        Only the conditions for which normalization parameters are available
        and that are stored in ``dataset`` are updated. A ``None`` dataset is
        left untouched.

        :param PinaDataset dataset: The dataset to be normalized.
        """
        # Nothing to normalize when the requested split does not exist
        if dataset is None:
            return

        # Initialize update dictionary
        update_dataset_dict = {}

        # Iterate over conditions and apply normalization
        for cond, norm_params in self.normalizer.items():
            # Same guard as in ``_compute_scale_shift``: a split may not
            # contain every condition seen in the training dataset.
            if cond not in dataset.conditions_dict:
                continue
            points = dataset.conditions_dict[cond][self.apply_to]
            # ``_norm_fn`` already restores the labels of a LabelTensor
            update_dataset_dict[cond] = {
                self.apply_to: self._norm_fn(
                    points, norm_params["scale"], norm_params["shift"]
                )
            }

        # Update the dataset in-place
        dataset.update_data(update_dataset_dict)

    @property
    def normalizer(self):
        """
        Get the dictionary of normalization parameters.

        :return: The dictionary of normalization parameters, mapping each
            condition name to its ``"shift"`` and ``"scale"`` values.
        :rtype: dict
        """
        return self._normalizer

pina/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def is_function(f):
206206
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
207207
:rtype: bool
208208
"""
209-
return isinstance(f, (types.FunctionType, types.LambdaType))
209+
return callable(f)
210210

211211

212212
def chebyshev_roots(n):

0 commit comments

Comments
 (0)