Skip to content

Commit fd849cc

Browse files
authored
Standardize output format of NBeats and NBeatsKAN estimators (#1977)
This PR updates the NBeatsAdapter to ensure that both NBeats and NBeatsKAN models return 3D tensors (batch_size, prediction_length, 1) for point predictions, aligning with the expected output format #1975.
1 parent b75ea42 commit fd849cc

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

pytorch_forecasting/models/nbeats/_nbeats_adapter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,23 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
9393
forecast = forecast + forecast_block
9494

9595
return self.to_network_output(
96-
prediction=self.transform_output(forecast, target_scale=x["target_scale"]),
96+
prediction=self.transform_output(
97+
forecast.unsqueeze(-1), target_scale=x["target_scale"]
98+
),
9799
backcast=self.transform_output(
98-
prediction=target - backcast, target_scale=x["target_scale"]
100+
prediction=(target - backcast).unsqueeze(-1),
101+
target_scale=x["target_scale"],
99102
),
100103
trend=self.transform_output(
101-
torch.stack(trend_forecast, dim=0).sum(0),
104+
torch.stack(trend_forecast, dim=0).sum(0).unsqueeze(-1),
102105
target_scale=x["target_scale"],
103106
),
104107
seasonality=self.transform_output(
105-
torch.stack(seasonal_forecast, dim=0).sum(0),
108+
torch.stack(seasonal_forecast, dim=0).sum(0).unsqueeze(-1),
106109
target_scale=x["target_scale"],
107110
),
108111
generic=self.transform_output(
109-
torch.stack(generic_forecast, dim=0).sum(0),
112+
torch.stack(generic_forecast, dim=0).sum(0).unsqueeze(-1),
110113
target_scale=x["target_scale"],
111114
),
112115
)

pytorch_forecasting/models/nbeats/_nbeats_pkg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class NBeats_pkg(_BasePtForecaster):
1717
"capability:pred_int": False,
1818
"capability:flexible_history_length": False,
1919
"capability:cold_start": False,
20-
"tests:skip_by_name": "test_integration",
2120
}
2221

2322
@classmethod

pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class NBeatsKAN_pkg(_BasePtForecaster):
1717
"capability:pred_int": False,
1818
"capability:flexible_history_length": False,
1919
"capability:cold_start": False,
20-
"tests:skip_by_name": "test_integration",
2120
}
2221

2322
@classmethod

0 commit comments

Comments
 (0)