Skip to content

Commit 31051a2

Browse files
authored
fix missing y_shift in predict_trend (#390)
1 parent 6338c3c commit 31051a2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

neuralprophet/forecaster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ def predict_trend(self, df):
10201020
df = df_utils.normalize(df, self.data_params)
10211021
t = torch.from_numpy(np.expand_dims(df["t"].values, 1))
10221022
trend = self.model.trend(t).squeeze().detach().numpy()
1023-
trend = trend * self.data_params["y"].scale
1023+
trend = trend * self.data_params["y"].scale + self.data_params["y"].shift
10241024
return pd.DataFrame({"ds": df["ds"], "trend": trend})
10251025

10261026
def predict_seasonal_components(self, df):

0 commit comments

Comments
 (0)