diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index 9ec8d6324..85e42954e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -93,20 +93,23 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: forecast = forecast + forecast_block return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + prediction=self.transform_output( + forecast.unsqueeze(-1), target_scale=x["target_scale"] + ), backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] + prediction=(target - backcast).unsqueeze(-1), + target_scale=x["target_scale"], ), trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), + torch.stack(trend_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), + torch.stack(seasonal_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), + torch.stack(generic_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), ) diff --git a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py index acf300439..daeab1c4e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py @@ -17,7 +17,6 @@ class NBeats_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, - "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 1ccad7b72..2cda8c996 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -17,7 +17,6 @@ class NBeatsKAN_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, - "tests:skip_by_name": "test_integration", } @classmethod