Commit 2cacf20

GiovanniCanali authored and dario-coscia committed
allow arbitrary loss and reduction
1 parent 53fea6f commit 2cacf20


2 files changed: +54 -10 lines changed


pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 39 additions & 4 deletions
@@ -72,6 +72,7 @@ def __init__(
         optimizer=None,
         scheduler=None,
         weighting=None,
+        loss=None,
         eta=0.001,
         gamma=0.999,
     ):
@@ -88,6 +89,9 @@ def __init__(
             scheduler is used. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
             If ``None``, no weighting schema is used. Default is ``None``.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
+            Default is ``None``.
         :param float | int eta: The learning rate for the weights of the
             residuals. Default is ``0.001``.
         :param float gamma: The decay parameter in the update of the weights
@@ -102,7 +106,7 @@ def __init__(
             optimizer=optimizer,
             scheduler=scheduler,
             weighting=weighting,
-            loss=torch.nn.MSELoss(reduction="none"),
+            loss=loss,
         )

         # check consistency
@@ -130,6 +134,12 @@ def __init__(
             self.register_buffer(f"weight_{cond}", buffer_tensor)
             self.weights[cond] = getattr(self, f"weight_{cond}")

+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
+
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
+
     def training_step(self, batch, batch_idx, **kwargs):
         """
         Solver training step. It computes the optimization cycle and aggregates
@@ -166,7 +176,7 @@ def validation_step(self, batch, **kwargs):

         # Aggregate losses for each condition
         for cond, loss in losses.items():
-            losses[cond] = losses[cond].mean()
+            losses[cond] = self._apply_reduction(loss=losses[cond])

         loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -189,7 +199,7 @@ def test_step(self, batch, **kwargs):

         # Aggregate losses for each condition
         for cond, loss in losses.items():
-            losses[cond] = losses[cond].mean()
+            losses[cond] = self._apply_reduction(loss=losses[cond])

         loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
@@ -228,7 +238,9 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
                 device=res.device,
             ) % len(self.problem.input_pts[cond])

-            losses[cond] = (res * self.weights[cond][idx]).mean()
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
+            )

         # store log
         self.store_log(
@@ -275,3 +287,26 @@ def _update_weights(self, batch, batch_idx, residuals):
             weights = self.weights[cond]
             update = self.gamma * weights[idx] + r_norm
             weights[idx] = update.detach()
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
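The pattern this diff introduces can be reproduced in isolation: the user's loss is re-instantiated with reduction="none", the per-point residuals are weighted, and only then is the originally requested reduction applied. The following is a minimal standalone sketch of that idea, not PINA code; the tensor shapes and the weights tensor are illustrative assumptions.

import torch

# Any torch loss the user might pass, e.g. with reduction="sum"
user_loss = torch.nn.L1Loss(reduction="sum")

# Remember the requested reduction, then rebuild the same loss class so it
# returns one value per point instead of a single aggregate
reduction = user_loss.reduction
pointwise_loss = type(user_loss)(reduction="none")

pred = torch.randn(10, 1)
target = torch.zeros(10, 1)
weights = torch.rand(10, 1)  # stands in for the residual-based weights

res = pointwise_loss(pred, target)  # shape (10, 1): one entry per point
weighted = res * weights            # weights applied before any reduction

# Deferred reduction, mirroring the new _apply_reduction helper
if reduction == "mean":
    total = weighted.mean()
elif reduction == "sum":
    total = weighted.sum()
else:
    raise ValueError(f"Unknown reduction: {reduction}.")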

tests/test_solver/test_rba_pinn.py

Lines changed: 15 additions & 6 deletions
@@ -60,8 +60,11 @@ def test_constructor(problem, eta, gamma):
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, batch_size, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_train(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
@@ -80,8 +83,11 @@ def test_solver_train(problem, batch_size, compile):
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, batch_size, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_validation(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
@@ -100,8 +106,11 @@ def test_solver_validation(problem, batch_size, compile):
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, batch_size, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_test(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,

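The updated tests suggest how the new parameter is meant to be used: any torch loss carrying a "mean" or "sum" reduction can be passed to the solver, which defers the reduction internally. Below is a hypothetical usage sketch, assuming problem and model are set up as in the test module, and assuming the import paths based on the file layout (both are assumptions, not confirmed by this diff).

import torch
from pina.solver import RBAPINN   # assumed import path
from pina.trainer import Trainer  # assumed import path

# The requested reduction ("sum" here) is honored only after the
# residual-based weights have been applied point-wise
solver = RBAPINN(
    model=model,
    problem=problem,
    loss=torch.nn.L1Loss(reduction="sum"),
)
trainer = Trainer(solver=solver, max_epochs=2, accelerator="cpu")
trainer.train()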