From b13b14c8e80a17a6c65fae35d4f3fc96a3180a49 Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Wed, 1 Oct 2025 06:36:13 +0530 Subject: [PATCH 1/3] Standardize output format of NBeats and NBeatsKAN estimators --- pytorch_forecasting/models/nbeats/_nbeats_adapter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index 9ec8d6324..2f8d41d0e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -93,20 +93,20 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: forecast = forecast + forecast_block return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + prediction=self.transform_output(forecast.unsqueeze(-1), target_scale=x["target_scale"]), backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] + prediction=(target - backcast).unsqueeze(-1), target_scale=x["target_scale"] ), trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), + torch.stack(trend_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), + torch.stack(seasonal_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), + torch.stack(generic_forecast, dim=0).sum(0).unsqueeze(-1), target_scale=x["target_scale"], ), ) From f67a8286c696dba7e140098c1c0af56405773333 Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Wed, 1 Oct 2025 06:42:27 +0530 Subject: [PATCH 2/3] fixing code quality --- pytorch_forecasting/models/nbeats/_nbeats_adapter.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py index 2f8d41d0e..85e42954e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -93,9 +93,12 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: forecast = forecast + forecast_block return self.to_network_output( - prediction=self.transform_output(forecast.unsqueeze(-1), target_scale=x["target_scale"]), + prediction=self.transform_output( + forecast.unsqueeze(-1), target_scale=x["target_scale"] + ), backcast=self.transform_output( - prediction=(target - backcast).unsqueeze(-1), target_scale=x["target_scale"] + prediction=(target - backcast).unsqueeze(-1), + target_scale=x["target_scale"], ), trend=self.transform_output( torch.stack(trend_forecast, dim=0).sum(0).unsqueeze(-1), From c82b4ed592c1f49a01757dcd7d6c7e31795e74a4 Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Wed, 1 Oct 2025 08:14:57 +0530 Subject: [PATCH 3/3] removed test skip tags for NBeats_pkg and NBeatsKAN_pkg --- pytorch_forecasting/models/nbeats/_nbeats_pkg.py | 1 - pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py index acf300439..daeab1c4e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py @@ -17,7 +17,6 @@ class NBeats_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, - "tests:skip_by_name": "test_integration", } @classmethod diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py index 1ccad7b72..2cda8c996 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py @@ -17,7 +17,6 @@ class NBeatsKAN_pkg(_BasePtForecaster): "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, - "tests:skip_by_name": "test_integration", } @classmethod