File tree Expand file tree Collapse file tree 1 file changed +16
-1
lines changed
pytorch_forecasting/models/nbeats Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -60,9 +60,24 @@ def get_base_test_params(cls):
60
60
61
61
@classmethod
62
62
def _get_test_dataloaders_from (cls , params ):
63
- """Get dataloaders from parameters."""
63
+ loss = params .get ("loss" , None )
64
+ data_loader_kwargs = params .get ("data_loader_kwargs" , {})
65
+ from pytorch_forecasting .metrics import TweedieLoss
64
66
from pytorch_forecasting .tests ._data_scenarios import (
67
+ data_with_covariates ,
65
68
dataloaders_fixed_window_without_covariates ,
69
+ make_dataloaders ,
66
70
)
67
71
72
+ if isinstance (loss , TweedieLoss ):
73
+ dwc = data_with_covariates ()
74
+ dl_default_kwargs = dict (
75
+ target = "target" ,
76
+ time_varying_unknown_reals = ["target" ],
77
+ add_relative_time_idx = False ,
78
+ )
79
+ dl_default_kwargs .update (data_loader_kwargs )
80
+ dataloaders_with_covariates = make_dataloaders (dwc , ** dl_default_kwargs )
81
+ return dataloaders_with_covariates
82
+
68
83
return dataloaders_fixed_window_without_covariates ()
You can’t perform that action at this time.
0 commit comments