Conversation

@dario-coscia
Collaborator

Description

This PR fixes #585.

Summary

This PR introduces a new Lightning callback – `NormalizerDataCallback` – that normalizes dataset inputs or targets during training, validation, or testing. The transformation applied is:

$$x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}$$

Usage

Users can provide either a single normalizer for all conditions, or condition-specific normalizers:

# Apply the same normalization everywhere
NormalizerDataCallback({"scale": 1, "shift": 0})

# Apply different normalization per condition "a" and "b" on the target training data
NormalizerDataCallback({
    "a": {"scale": 2.0, "shift": 1.0},
    "b": {"scale": 0.5, "shift": 0.0},
}, stage="train", apply_to="target")
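A minimal sketch of attaching the callback to a training run; the solver construction, the `Trainer` arguments, and the import path are illustrative assumptions, not part of this PR:

from pina import Trainer
from pina.callback import NormalizerDataCallback

# `solver` is any configured PINA solver; its setup is omitted here.
normalizer = NormalizerDataCallback({"scale": 1.0, "shift": 0.0})
trainer = Trainer(solver=solver, callbacks=[normalizer], max_epochs=10)
trainer.train()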

Checklist

  • Code follows the project’s Code Style Guidelines
  • Tests have been added or updated
  • Documentation has been updated if necessary
  • Pull request is linked to an open issue

@dario-coscia dario-coscia self-assigned this Sep 9, 2025
@dario-coscia dario-coscia added the pr-to-review label Sep 9, 2025
@github-actions
Contributor

github-actions bot commented Sep 9, 2025

Code Coverage Summary

Filename                                                       Stmts    Miss  Cover    Missing
-----------------------------------------------------------  -------  ------  -------  -------------------------------------------------------------------------------------------------------
__init__.py                                                        7       0  100.00%
graph.py                                                         114      11  90.35%   99-100, 112, 124, 126, 142, 144, 166, 169, 182, 271
label_tensor.py                                                  251      28  88.84%   81, 121, 144-148, 165, 177, 182, 188-193, 273, 280, 332, 334, 348, 444-447, 490, 537, 629, 664-673, 710
operator.py                                                       72       5  93.06%   250-268, 459
trainer.py                                                        75       5  93.33%   195-204, 293, 314, 318
type_checker.py                                                   22       0  100.00%
utils.py                                                          73      11  84.93%   59, 75, 141, 178, 181, 184, 220-223, 268
adaptive_function/__init__.py                                      3       0  100.00%
adaptive_function/adaptive_function.py                            55       0  100.00%
adaptive_function/adaptive_function_interface.py                  51       6  88.24%   98, 141, 148-151
callback/__init__.py                                               5       0  100.00%
callback/normalizer_data_callback.py                              68       1  98.53%   141
callback/optimizer_callback.py                                    23       0  100.00%
callback/processing_callback.py                                   49       5  89.80%   42-43, 73, 168, 171
callback/refinement/__init__.py                                    3       0  100.00%
callback/refinement/r3_refinement.py                              28       1  96.43%   88
callback/refinement/refinement_interface.py                       50       5  90.00%   32, 59, 67, 72, 78
condition/__init__.py                                              7       0  100.00%
condition/condition.py                                            19       1  94.74%   141
condition/condition_interface.py                                  37       4  89.19%   32, 76, 95, 125
condition/data_condition.py                                       26       1  96.15%   78
condition/domain_equation_condition.py                            19       0  100.00%
condition/input_equation_condition.py                             43       1  97.67%   157
condition/input_target_condition.py                               44       1  97.73%   172
data/__init__.py                                                   3       0  100.00%
data/data_module.py                                              201      22  89.05%   41-52, 132, 172, 193, 232, 313-317, 323-327, 399, 466, 546, 637, 639
data/dataset.py                                                   82       7  91.46%   42, 123-126, 256, 293
domain/__init__.py                                                10       0  100.00%
domain/cartesian.py                                              112      10  91.07%   37, 47, 75-76, 92, 97, 103, 246, 256, 264
domain/difference_domain.py                                       25       2  92.00%   54, 87
domain/domain_interface.py                                        20       5  75.00%   37-41
domain/ellipsoid.py                                              104      24  76.92%   52, 56, 127, 250-257, 269-282, 286-287, 290, 295
domain/exclusion_domain.py                                        28       1  96.43%   86
domain/intersection_domain.py                                     28       1  96.43%   85
domain/operation_interface.py                                     26       1  96.15%   88
domain/simplex.py                                                 72      14  80.56%   62, 207-225, 246-247, 251, 256
domain/union_domain.py                                            25       1  96.00%   43
equation/__init__.py                                               4       0  100.00%
equation/equation.py                                              15       1  93.33%   56
equation/equation_factory.py                                      24       8  66.67%   37, 73, 97-110, 132-145
equation/equation_interface.py                                     4       0  100.00%
equation/system_equation.py                                       20       0  100.00%
loss/__init__.py                                                   9       0  100.00%
loss/linear_weighting.py                                          14       0  100.00%
loss/loss_interface.py                                            17       2  88.24%   45, 51
loss/lp_loss.py                                                   15       0  100.00%
loss/ntk_weighting.py                                             18       0  100.00%
loss/power_loss.py                                                15       0  100.00%
loss/scalar_weighting.py                                          16       0  100.00%
loss/self_adaptive_weighting.py                                   12       0  100.00%
loss/weighting_interface.py                                       29       3  89.66%   35, 41-42
model/__init__.py                                                 11       0  100.00%
model/average_neural_operator.py                                  31       2  93.55%   73, 82
model/deeponet.py                                                 93      13  86.02%   187-190, 209, 240, 283, 293, 303, 313, 323, 333, 488, 498
model/feed_forward.py                                             89      11  87.64%   58, 195, 200, 278-292
model/fourier_neural_operator.py                                  78      10  87.18%   96-100, 110, 155-159, 218, 220, 242, 342
model/graph_neural_operator.py                                    40       2  95.00%   58, 60
model/kernel_neural_operator.py                                   34       6  82.35%   83-84, 103-104, 123-124
model/low_rank_neural_operator.py                                 27       2  92.59%   89, 98
model/multi_feed_forward.py                                       12       5  58.33%   25-31
model/pirate_network.py                                           27       1  96.30%   118
model/spline.py                                                   89      37  58.43%   30, 41-66, 69, 128-132, 135, 159-177, 180
model/block/__init__.py                                           13       0  100.00%
model/block/average_neural_operator_block.py                      12       0  100.00%
model/block/convolution.py                                        64      13  79.69%   77, 81, 85, 91, 97, 111, 114, 151, 161, 171, 181, 191, 201
model/block/convolution_2d.py                                    146      27  81.51%   155, 162, 282, 314, 379-433, 456
model/block/embedding.py                                          48       7  85.42%   93, 143-146, 155, 168
model/block/fourier_block.py                                      31       0  100.00%
model/block/gno_block.py                                          22       4  81.82%   73-77, 87
model/block/integral.py                                           18       4  77.78%   22-25, 71
model/block/low_rank_block.py                                     24       0  100.00%
model/block/orthogonal.py                                         37       0  100.00%
model/block/pirate_network_block.py                               25       1  96.00%   89
model/block/pod_block.py                                          73      10  86.30%   55-58, 70, 83, 113, 148-153, 187, 212
model/block/rbf_block.py                                         179      25  86.03%   18, 42, 53, 64, 75, 86, 97, 223, 280, 282, 298, 301, 329, 335, 363, 367, 511-524
model/block/residual.py                                           46       0  100.00%
model/block/spectral.py                                           83       4  95.18%   132, 140, 262, 270
model/block/stride.py                                             28       7  75.00%   55, 58, 61, 67, 72-74
model/block/utils_convolution.py                                  22       3  86.36%   58-60
model/block/message_passing/__init__.py                            5       0  100.00%
model/block/message_passing/deep_tensor_network_block.py          21       0  100.00%
model/block/message_passing/en_equivariant_network_block.py       39       0  100.00%
model/block/message_passing/interaction_network_block.py          23       0  100.00%
model/block/message_passing/radial_field_network_block.py         20       0  100.00%
optim/__init__.py                                                  5       0  100.00%
optim/optimizer_interface.py                                       7       0  100.00%
optim/scheduler_interface.py                                       7       0  100.00%
optim/torch_optimizer.py                                          14       0  100.00%
optim/torch_scheduler.py                                          19       2  89.47%   5-6
problem/__init__.py                                                6       0  100.00%
problem/abstract_problem.py                                      117      12  89.74%   39-40, 59-70, 149, 161, 179, 253, 257, 286
problem/inverse_problem.py                                        22       0  100.00%
problem/parametric_problem.py                                      8       1  87.50%   29
problem/spatial_problem.py                                         8       0  100.00%
problem/time_dependent_problem.py                                  8       0  100.00%
problem/zoo/__init__.py                                            8       0  100.00%
problem/zoo/advection.py                                          33       7  78.79%   36-38, 52, 108-110
problem/zoo/allen_cahn.py                                         20       6  70.00%   20-22, 34-36
problem/zoo/diffusion_reaction.py                                 29       5  82.76%   94-104
problem/zoo/helmholtz.py                                          30       6  80.00%   36-42, 103-107
problem/zoo/inverse_poisson_2d_square.py                          48       3  93.75%   44-50
problem/zoo/poisson_2d_square.py                                  19       3  84.21%   65-70
problem/zoo/supervised_problem.py                                 11       0  100.00%
solver/__init__.py                                                 6       0  100.00%
solver/garom.py                                                  107       2  98.13%   129-130
solver/solver.py                                                 188      10  94.68%   195, 218, 290, 293-294, 353, 435, 518, 559, 565
solver/ensemble_solver/__init__.py                                 4       0  100.00%
solver/ensemble_solver/ensemble_pinn.py                           23       1  95.65%   104
solver/ensemble_solver/ensemble_solver_interface.py               27       0  100.00%
solver/ensemble_solver/ensemble_supervised.py                      9       0  100.00%
solver/physics_informed_solver/__init__.py                         8       0  100.00%
solver/physics_informed_solver/causal_pinn.py                     47       3  93.62%   157, 166-167
solver/physics_informed_solver/competitive_pinn.py                58       0  100.00%
solver/physics_informed_solver/gradient_pinn.py                   17       0  100.00%
solver/physics_informed_solver/pinn.py                            18       0  100.00%
solver/physics_informed_solver/pinn_interface.py                  54       3  94.44%   75, 166, 222
solver/physics_informed_solver/rba_pinn.py                        74       1  98.65%   324
solver/physics_informed_solver/self_adaptive_pinn.py             104       1  99.04%   392
solver/supervised_solver/__init__.py                               4       0  100.00%
solver/supervised_solver/reduced_order_model.py                   24       1  95.83%   137
solver/supervised_solver/supervised.py                             7       0  100.00%
solver/supervised_solver/supervised_solver_interface.py           25       1  96.00%   90
TOTAL                                                           4825     450  90.67%

Results for commit: 7f6316d

Minimum allowed coverage is 80.123%

♻️ This comment has been updated with latest results

Collaborator

@GiovanniCanali GiovanniCanali left a comment

Hi @dario-coscia, thank you for the nice PR.

I have left some comments that should be addressed before we can merge.

I also have a few doubts regarding the logic of this callback:

  • Scope of normalization: normalization should be applicable only to InputTargetCondition conditions. Indeed, changing the scale in physics-related problems can be a major source of error and often leads to inconsistent results. In these cases, it’s usually better if the user handles scaling explicitly for more transparent management of the physics.
  • Graph support: InputTargetCondition conditions can also work with Graph objects. This scenario doesn’t seem to be covered in the `NormalizerDataCallback` implementation or in the related tests. Could you please ensure that this is handled?
  • Tensor-based scale and shift: I am not entirely sure I understand the case where scale or shift are tensors. Could you clarify the intended usage here?

Overall, the current approach feels a bit too complex for managing a relatively simple normalization. Personally, I think it might be cleaner and safer to let the user normalize the data before passing it to the PINA problem. What are your thoughts on this?

@FilippoOlivo
Member

> Hi @dario-coscia, thank you for the nice PR. […] Overall, the current approach feels a bit too complex for managing a relatively simple normalization. Personally, I think it might be cleaner and safer to let the user normalize the data before passing it to the PINA problem. What are your thoughts on this?

Hi, I think we should keep normalization inside the training cycle. I have a doubt regarding how the normalization strategy is applied. In particular, how can I perform a normalization of the type

$$ z = \frac{x-\mu_{train}}{\sigma_{train}} $$

Note that this should be applied to all three dataset splits, and when I define the callback I do not have access to them.
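For concreteness, a minimal sketch of that standardization with hypothetical split tensors (not code from this PR):

import torch

# Hypothetical split tensors; in PINA the splits only exist once training starts.
x_train = torch.randn(80, 3)
x_val = torch.randn(10, 3)
x_test = torch.randn(10, 3)

# Statistics computed on the training split only...
mu, sigma = x_train.mean(dim=0), x_train.std(dim=0)

# ...but applied to all three splits.
z_train = (x_train - mu) / sigma
z_val = (x_val - mu) / sigma
z_test = (x_test - mu) / sigma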

@GiovanniCanali
Collaborator

> Hi @dario-coscia, thank you for the nice PR. […] What are your thoughts on this?

> Hi, I think we should keep normalization inside the training cycle. […] In particular, how can I perform a normalization of the type
>
> $$z = \frac{x - \mu_{\text{train}}}{\sigma_{\text{train}}}$$
>
> Note that this should be applied to all three dataset splits, and when I define the callback I do not have access to them.

I share the same concern as @FilippoOlivo.

Note that to apply the standard normalization, the mean and the standard deviation should be computed outside the PINA pipeline. This way, applying normalization externally would be more straightforward and natural than doing it through a Callback.

@dario-coscia
Collaborator Author

I agree with you both. If we could access the data before training, normalizing would be super easy. Unfortunately we cannot, as the datasets are created when training starts (do you agree?). I don't know what the best way to do normalization is, then.

@GiovanniCanali
Collaborator

> I agree with you both. If we could access the data before training, normalizing would be super easy. […]

I realize this somewhat defeats the purpose of the PR, but it might be simpler to let users handle normalization themselves. In most cases, this shouldn’t be a difficult task, and trying to cover all possible scenarios and data structures here feels like overkill. What do you think @dario-coscia @FilippoOlivo?

@FilippoOlivo
Member

FilippoOlivo commented Sep 11, 2025

I guess the datasets are created during setup. In my opinion, we can either accept scale and shift parameters from the user or compute them inside setup before training starts.

@FilippoOlivo FilippoOlivo added the pr-to-fix label and removed the pr-to-review label Sep 11, 2025
@dario-coscia
Collaborator Author

dario-coscia commented Sep 11, 2025

> I guess the datasets are created during setup. In my opinion, we can either accept scale and shift parameters from the user or compute them inside setup before training starts.

Which scale and shift? Mean and variance? Sometimes you want mean and MAD. I don't think it is super straightforward... Maybe it is better to let the user pass a scaling function which, given a tensor, returns scale and shift, or something similar.

@FilippoOlivo
Member

Sure! At this stage, I suggest giving the user only the possibility to provide a function instead of a set of parameters (scale and shift).
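A minimal sketch of what such a user-provided function might look like; the name and signature are a suggestion, not the final API:

import torch

# Hypothetical scaling function: given the training tensor, return the
# (scale, shift) pair the callback should apply.
def robust_scaling(x):
    shift = x.median(dim=0).values
    scale = (x - shift).abs().median(dim=0).values  # median absolute deviation
    return scale, shift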

@dario-coscia
Collaborator Author

@FilippoOlivo I added you as an assignee as well. I agree with letting the user pass a function directly. Are you on it? Let me know how much time it will take, since I would like to merge #636 as soon as possible.

@FilippoOlivo FilippoOlivo requested review from FilippoOlivo and GiovanniCanali and removed request for FilippoOlivo September 15, 2025 12:16
@FilippoOlivo FilippoOlivo added the enhancement label and removed the pr-to-fix label Sep 15, 2025
@FilippoOlivo FilippoOlivo added the pr-to-review label Sep 15, 2025
)
return stage

def setup(self, trainer, pl_module, stage):
Collaborator Author

I prefer `solver` instead of `pl_module`.

}

@staticmethod
def _norm_fn(value, scale, shift):
Collaborator Author

Make it public and a property.

Member

It is a bit difficult, since `norm_fn` must at least have access to the dataset you want to normalize.

Collaborator Author

From what I see, `_norm_fn` only needs `value`, `scale`, and `shift`. By making it public and a property, we structure the code so that in the future the user could define `norm_fn`.
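For reference, a minimal sketch of the suggestion with illustrative names (not the PR's final code):

@property
def norm_fn(self):
    # Public hook: expose the normalization function so a user could
    # later swap in their own implementation.
    return self._norm_fn

@norm_fn.setter
def norm_fn(self, fn):
    self._norm_fn = fn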

scaled_value = LabelTensor(scaled_value, value.labels)
return scaled_value

def _scale_data(self, dataset):
Collaborator Author

Rename to `normalize_dataset`.

@dario-coscia
Collaborator Author

@FilippoOlivo I left some comments, but it looks very good!

FilippoOlivo and others added 2 commits September 15, 2025 17:19
* change name files normalizer data callback
@dario-coscia
Collaborator Author

Hi @FilippoOlivo! I updated `is_function` and also changed the names as suggested by @GiovanniCanali. Just one thing: the tests take around 14 s, which is a bit too much IMHO. Can we get them down to 5 s max?

@FilippoOlivo
Member

> Hi @FilippoOlivo! […] the tests take around 14 s, which is a bit too much IMHO. Can we get them down to 5 s max?

Hi @dario-coscia, I reduced the number of tests. Now they take around 3 s to run.

Collaborator

@GiovanniCanali GiovanniCanali left a comment

Everything is fine, but the documentation needs to be updated. I will take care of it as soon as possible.

target_2 = torch.rand(20, 1) * 5


class LabelTensorProblem(AbstractProblem):
Collaborator

Isn't it easier to use `SupervisedProblem` and add a condition? The same applies to `TensorProblem`.

@FilippoOlivo
Member

Hi @GiovanniCanali, thank you for the review. The callback normalizes only `InputTargetCondition` conditions; other types of conditions are automatically discarded.

@dario-coscia
Collaborator Author

dario-coscia commented Sep 16, 2025

Indeed

https://github.com/mathLab/PINA/blob/data_normalizer/pina/callback/normalizer_data_callback.py#L112-L116

@GiovanniCanali
Collaborator

> The callback normalizes only `InputTargetCondition` conditions; other types of conditions are automatically discarded.

Sorry, I missed it.

@FilippoOlivo FilippoOlivo force-pushed the data_normalizer branch 2 times, most recently from 7f49e59 to 787fa0a, September 16, 2025 12:53
@GiovanniCanali
Collaborator

I think we can merge @dario-coscia @FilippoOlivo!

@FilippoOlivo
Member

I added a test to check that the normalizer raises a `NotImplementedError` when graphs are present!
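A minimal sketch of such a check, assuming a pytest-style suite; the fixture name and the constructor arguments are illustrative, not the PR's actual test:

import pytest

def test_normalizer_raises_on_graphs(graph_solver):
    # `graph_solver` is a hypothetical fixture whose conditions hold Graph
    # data; `scaling_fn` is a user-provided function as discussed above.
    trainer = Trainer(solver=graph_solver,
                      callbacks=[NormalizerDataCallback(scaling_fn)],
                      max_epochs=1)
    with pytest.raises(NotImplementedError):
        trainer.train()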

@FilippoOlivo FilippoOlivo self-requested a review September 16, 2025 15:04
@dario-coscia dario-coscia merged commit dc808c1 into dev Sep 16, 2025
19 checks passed
@dario-coscia
Collaborator Author

Great job all!

@FilippoOlivo FilippoOlivo deleted the data_normalizer branch September 23, 2025 17:16