@@ -72,6 +72,7 @@ def __init__(
         optimizer=None,
         scheduler=None,
         weighting=None,
+        loss=None,
         eta=0.001,
         gamma=0.999,
     ):
@@ -88,6 +89,9 @@ def __init__(
8889 scheduler is used. Default is ``None``.
8990 :param WeightingInterface weighting: The weighting schema to be used.
9091 If ``None``, no weighting schema is used. Default is ``None``.
92+ :param torch.nn.Module loss: The loss function to be minimized.
93+ If ``None``, the :class:`torch.nn.MSELoss` loss is used.
94+ Default is `None`.
9195 :param float | int eta: The learning rate for the weights of the
9296 residuals. Default is ``0.001``.
9397 :param float gamma: The decay parameter in the update of the weights
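
As a quick illustration of the interface the new docstring assumes (plain ``torch`` only; the names below are chosen for the example and are not part of this diff), any standard loss module is a ``torch.nn.Module`` that exposes a ``reduction`` attribute and can be passed in place of the default ``torch.nn.MSELoss``:

```python
import torch

# Candidate loss modules for the new ``loss`` argument; all are torch.nn.Module
# subclasses and all default to reduction="mean".
for candidate in (torch.nn.MSELoss(), torch.nn.L1Loss(), torch.nn.HuberLoss()):
    print(type(candidate).__name__, candidate.reduction)
```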
@@ -102,7 +106,7 @@ def __init__(
             optimizer=optimizer,
             scheduler=scheduler,
             weighting=weighting,
-            loss=torch.nn.MSELoss(reduction="none"),
+            loss=loss,
         )

         # check consistency
@@ -130,6 +134,12 @@ def __init__(
             self.register_buffer(f"weight_{cond}", buffer_tensor)
             self.weights[cond] = getattr(self, f"weight_{cond}")

+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
+
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
+
     def training_step(self, batch, batch_idx, **kwargs):
         """
         Solver training step. It computes the optimization cycle and aggregates
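
For clarity, a minimal standalone sketch of the re-instantiation trick added above, using only ``torch`` (variable names here are illustrative, not taken from the diff): the reduction requested by the user is remembered, while the loss object itself is rebuilt with ``reduction="none"`` so that it returns one value per point:

```python
import torch

# A user-supplied loss with its own reduction setting.
user_loss = torch.nn.MSELoss(reduction="sum")

# Remember the requested reduction, then rebuild the same loss class unreduced.
reduction = user_loss.reduction                      # "sum"
pointwise_loss = type(user_loss)(reduction="none")   # same class, per-point output

pred, target = torch.randn(4, 1), torch.zeros(4, 1)
print(pointwise_loss(pred, target).shape)            # torch.Size([4, 1])
```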
@@ -166,7 +176,7 @@ def validation_step(self, batch, **kwargs):

         # Aggregate losses for each condition
         for cond, loss in losses.items():
-            losses[cond] = losses[cond].mean()
+            losses[cond] = self._apply_reduction(loss=losses[cond])

         loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -189,7 +199,7 @@ def test_step(self, batch, **kwargs):

         # Aggregate losses for each condition
         for cond, loss in losses.items():
-            losses[cond] = losses[cond].mean()
+            losses[cond] = self._apply_reduction(loss=losses[cond])

         loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
@@ -228,7 +238,9 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
                 device=res.device,
             ) % len(self.problem.input_pts[cond])

-            losses[cond] = (res * self.weights[cond][idx]).mean()
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
+            )

         # store log
         self.store_log(
@@ -275,3 +287,26 @@ def _update_weights(self, batch, batch_idx, residuals):
             weights = self.weights[cond]
             update = self.gamma * weights[idx] + r_norm
             weights[idx] = update.detach()
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
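
To show why the reduction is deferred, here is a small self-contained sketch (the tensors and names are assumed for illustration, not part of the diff): the per-point losses are first scaled by the residual-based weights, and only then reduced with the reduction the user originally requested:

```python
import torch

loss_fn = torch.nn.MSELoss(reduction="none")   # always unreduced internally
requested_reduction = "mean"                   # reduction taken from the user's loss

# Per-point losses and per-point residual-based weights (illustrative values).
per_point = loss_fn(torch.randn(8, 1), torch.zeros(8, 1))
weights = torch.rand(8, 1)

# Weight each point, then apply the deferred reduction.
weighted = per_point * weights
loss = weighted.mean() if requested_reduction == "mean" else weighted.sum()
print(loss)
```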