diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py index 041d24044..d757be67b 100644 --- a/pytorch_forecasting/metrics/base_metrics.py +++ b/pytorch_forecasting/metrics/base_metrics.py @@ -822,6 +822,9 @@ def mask_losses(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: st """ if reduction is None: reduction = self.reduction + + if isinstance(losses, list): losses = losses[0] + if losses.ndim > 0: # mask loss mask = torch.arange(losses.size(1), device=losses.device).unsqueeze(0) >= lengths.unsqueeze(-1)