53 changes: 53 additions & 0 deletions pytorch_forecasting/models/base/_base_object.py
@@ -118,3 +118,56 @@ class _BasePtForecasterV2(_BasePtForecaster_Common):
    _tags = {
        "object_type": "forecaster_pytorch_v2",
    }


class _EncoderDecoderConfigBase(_BasePtForecasterV2):
    """Base package class for models fed by encoder/decoder-style metadata."""

    @classmethod
    def _check_metadata(cls, metadata):
        assert isinstance(metadata, dict)
        required_keys = [
            "encoder_cat",
            "encoder_cont",
            "decoder_cat",
            "decoder_cont",
            "target",
            "max_encoder_length",
            "min_encoder_length",
            "max_prediction_length",
            "min_prediction_length",
            "static_categorical_features",
            "static_continuous_features",
        ]

        for key in required_keys:
            assert key in metadata, f"Key {key} missing in metadata"

        assert metadata["encoder_cat"] >= 0
        assert metadata["encoder_cont"] >= 0
        assert metadata["decoder_cat"] >= 0
        assert metadata["decoder_cont"] >= 0
        assert metadata["target"] > 0


class _TSlibConfigBase(_BasePtForecasterV2):
    """Base package class for models fed by tslib-style metadata."""

    @classmethod
    def _check_metadata(cls, metadata):
        assert isinstance(metadata, dict)
        required_keys = [
            "feature_names",
            "feature_indices",
            "n_features",
            "context_length",
            "prediction_length",
            "freq",
            "features",
        ]

        for key in required_keys:
            assert key in metadata, f"Key {key} missing in metadata"

        assert (
            metadata["n_features"]
            == len(metadata["feature_names"])
            == len(metadata["feature_indices"])
        )
        assert metadata["context_length"] > 0
        assert metadata["prediction_length"] > 0
        assert metadata["freq"] is not None
11 changes: 10 additions & 1 deletion pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py
@@ -2,7 +2,10 @@
Package container for the DLinear model.
"""

from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
from pytorch_forecasting.models.base._base_object import (
    _BasePtForecasterV2,
    _TSlibConfigBase,
)


class DLinear_pkg_v2(_BasePtForecasterV2):
@@ -125,3 +128,9 @@ def get_test_train_params(cls):
                logging_metrics=[SMAPE()],
            ),
        ]


class DLinear_pkg_v2_metadata(_TSlibConfigBase):
    @classmethod
    def _check_metadata_dlinear(cls, metadata):
        super()._check_metadata(metadata)
11 changes: 10 additions & 1 deletion pytorch_forecasting/models/samformer/_samformer_v2_pkg.py
@@ -2,7 +2,10 @@
Samformer package container.
"""

from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
from pytorch_forecasting.models.base._base_object import (
    _BasePtForecasterV2,
    _EncoderDecoderConfigBase,
)


class Samformer_pkg_v2(_BasePtForecasterV2):
@@ -134,3 +137,9 @@ def get_test_train_params(cls):
"use_revin": False,
},
]


class Samformer_pkg_v2_metadata(_EncoderDecoderConfigBase):
    @classmethod
    def _check_metadata_samformer(cls, metadata):
        super()._check_metadata(metadata)
@@ -1,6 +1,7 @@
"""TFT package container."""

from pytorch_forecasting.models.base import _BasePtForecasterV2
from pytorch_forecasting.models.base._base_object import _EncoderDecoderConfigBase


class TFT_pkg_v2(_BasePtForecasterV2):
@@ -137,3 +138,9 @@ def get_test_train_params(cls):
            ),
            dict(attention_head_size=2),
        ]


class TFT_pkg_v2_metadata(_EncoderDecoderConfigBase):
    @classmethod
    def _check_metadata_tft(cls, metadata):
        super()._check_metadata(metadata)
11 changes: 10 additions & 1 deletion pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py
@@ -1,6 +1,9 @@
"""TIDE package container."""

from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
from pytorch_forecasting.models.base._base_object import (
    _BasePtForecasterV2,
    _EncoderDecoderConfigBase,
)


class TIDE_pkg_v2(_BasePtForecasterV2):
@@ -138,3 +141,9 @@ def get_test_train_params(cls):
                loss=MAPE(),
            ),
        ]


class TIDE_pkg_v2_metadata(_EncoderDecoderConfigBase):
    @classmethod
    def _check_metadata_tide(cls, metadata):
        super()._check_metadata(metadata)
11 changes: 10 additions & 1 deletion pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -2,7 +2,10 @@
Metadata container for TimeXer v2.
"""

from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
from pytorch_forecasting.models.base._base_object import (
    _BasePtForecasterV2,
    _TSlibConfigBase,
)


class TimeXer_pkg_v2(_BasePtForecasterV2):
@@ -163,3 +166,9 @@ def get_test_train_params(cls):
                loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
            ),
        ]


class TimeXer_pkg_v2_metadata(_TSlibConfigBase):
    @classmethod
    def _check_metadata_timexer(cls, metadata):
        super()._check_metadata(metadata)
15 changes: 15 additions & 0 deletions pytorch_forecasting/tests/test_all_estimators_v2.py
@@ -135,3 +135,18 @@ def test_pkg_linkage(self, object_pkg, object_class):
f"{object_class.__name__}_pkg."
)
assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg

    def test_d2_metadata(self, object_pkg, trainer_kwargs):
        object_class = object_pkg.get_cls()
        dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs)
        data_module = dataloaders.get("data_module")
        metadata = data_module.metadata

        model_kwargs = dict(trainer_kwargs)
        model_kwargs.pop("data_loader_kwargs", None)

        model_name = object_class.__name__

        check_method_name = f"_check_metadata_{model_name.lower()}"
        if hasattr(object_pkg, check_method_name):
            getattr(object_pkg, check_method_name)(metadata)
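
The name-based dispatch in this test can also be exercised stand-alone. The sketch below is hypothetical: it uses the `TimeXer_pkg_v2_metadata` class added in this diff and reuses the `tslib_metadata` dict from the earlier sketch in place of `data_module.metadata`.

from pytorch_forecasting.models.timexer._timexer_pkg_v2 import (
    TimeXer_pkg_v2_metadata,
)

# "TimeXer" stands in for object_class.__name__; the derived hook name is
# "_check_metadata_timexer", matching the classmethod on the metadata class.
check_method_name = "_check_metadata_" + "TimeXer".lower()
if hasattr(TimeXer_pkg_v2_metadata, check_method_name):
    getattr(TimeXer_pkg_v2_metadata, check_method_name)(tslib_metadata)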