[DO NOT MERGE] Experimental PR to demonstrate TimeXer model usage with v2 design #1830

Draft · wants to merge 47 commits into base: main
Commits (47)
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
0319c29
Example notebook
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
5d80532
Merge branch 'refactor-d1-d2' into refactor-notebook
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
6290dc2
Merge branch 'refactor-d1-d2' into refactor-notebook
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
a23ad8a
Merge branch 'refactor-d1-d2' into refactor-notebook
phoeenniixx Apr 10, 2025
25bc7ee
update notebook as well
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
7a175e9
Merge branch 'refactor-d1-d2' into refactor-notebook
phoeenniixx Apr 11, 2025
8dfcac1
update comments in nb
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
e91cd16
Merge branch 'refactor-model' into refactor-notebook
phoeenniixx Apr 19, 2025
9392c81
update notebook
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
6cb7496
Merge branch 'refactor-model' into refactor-notebook
phoeenniixx May 11, 2025
d6e62bb
update notebook
phoeenniixx May 11, 2025
f195716
add usage notebook for v2 version of timexer
May 12, 2025
2c14517
dummy commit
May 13, 2025
53b7db6
Merge branch 'main' into tslib-v2-expt
May 13, 2025
d84e1a0
dummy commit to trigger code quality checks
May 13, 2025
f87842f
fix lint issues
May 14, 2025
845c3b0
Merge branch 'main' into 'tslib-v2-expt'
PranavBhatP May 30, 2025
21e1c63
fix deprecated syntax to comply with latest code-quality checks
PranavBhatP May 30, 2025
Files changed
1,018 changes: 1,018 additions & 0 deletions examples/ptf_V2_example.ipynb

Large diffs are not rendered by default.

1,677 changes: 1,677 additions & 0 deletions examples/tslib_v2_example.ipynb

Large diffs are not rendered by default.
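Since large notebook diffs are not rendered, here is a minimal sketch of the kind of v2 workflow that examples/tslib_v2_example.ipynb demonstrates. The `TimeSeries` constructor arguments below match the signature visible in this diff; the toy data and the internal import path are illustrative assumptions, not the notebook's exact contents.

```python
# Hedged sketch of the v2 data layer demonstrated in this PR.
# Assumption: TimeSeries is importable from the v2 module path below.
import numpy as np
import pandas as pd

from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries

# Toy panel: two series, one target, one covariate known in advance.
df = pd.DataFrame(
    {
        "time": np.tile(np.arange(100), 2),
        "series": np.repeat(["a", "b"], 100),
        "value": np.random.randn(200).cumsum(),
        "price": np.random.rand(200),
    }
)

dataset = TimeSeries(
    data=df,
    time="time",
    target="value",
    group=["series"],
    num=["value", "price"],
    known=["price"],
)

sample = dataset[0]  # dict with keys "t", "y", "x", "group", "st", "cutoff_time"
print(sample["y"].shape)  # torch.Size([100, 1]) -- one series, one target
```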

110 changes: 48 additions & 62 deletions pytorch_forecasting/data/data_module.py
@@ -108,50 +108,33 @@ def __init__(
num_workers: int = 0,
train_val_test_split: tuple = (0.7, 0.15, 0.15),
):

super().__init__()
self.time_series_dataset = time_series_dataset
self.time_series_metadata = time_series_dataset.get_metadata()

self.max_encoder_length = max_encoder_length
self.min_encoder_length = min_encoder_length
self.min_encoder_length = min_encoder_length or max_encoder_length
self.max_prediction_length = max_prediction_length
self.min_prediction_length = min_prediction_length
self.min_prediction_length = min_prediction_length or max_prediction_length
self.min_prediction_idx = min_prediction_idx

self.allow_missing_timesteps = allow_missing_timesteps
self.add_relative_time_idx = add_relative_time_idx
self.add_target_scales = add_target_scales
self.add_encoder_length = add_encoder_length
self.randomize_length = randomize_length
self.target_normalizer = target_normalizer
self.categorical_encoders = categorical_encoders
self.scalers = scalers

self.batch_size = batch_size
self.num_workers = num_workers
self.train_val_test_split = train_val_test_split

warn(
"TimeSeries is part of an experimental rework of the "
"pytorch-forecasting data layer, "
"scheduled for release with v2.0.0. "
"The API is not stable and may change without prior warning. "
"For beta testing, but not for stable production use. "
"Feedback and suggestions are very welcome in "
"pytorch-forecasting issue 1736, "
"https://github.com/sktime/pytorch-forecasting/issues/1736",
UserWarning,
)

super().__init__()

# handle defaults and derived attributes
if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto":
self._target_normalizer = RobustScaler()
self.target_normalizer = RobustScaler()
else:
self._target_normalizer = target_normalizer
self.target_normalizer = target_normalizer

self.time_series_metadata = time_series_dataset.get_metadata()
self._min_prediction_length = min_prediction_length or max_prediction_length
self._min_encoder_length = min_encoder_length or max_encoder_length
self._categorical_encoders = _coerce_to_dict(categorical_encoders)
self._scalers = _coerce_to_dict(scalers)
self.categorical_encoders = _coerce_to_dict(categorical_encoders)
self.scalers = _coerce_to_dict(scalers)

self.categorical_indices = []
self.continuous_indices = []
@@ -171,38 +154,39 @@ def _prepare_metadata(self):
dict
dictionary containing the following keys:

* ``encoder_cat``: Number of categorical variables in the encoder.
Computed as ``len(self.categorical_indices)``, which counts the
categorical feature indices.
* ``encoder_cont``: Number of continuous variables in the encoder.
Computed as ``len(self.continuous_indices)``, which counts the
continuous feature indices.
* ``decoder_cat``: Number of categorical variables in the decoder that
are known in advance.
Computed by filtering ``self.time_series_metadata["cols"]["x"]``
where ``col_type == "C"`` (categorical) and ``col_known == "K"`` (known).
* ``decoder_cont``: Number of continuous variables in the decoder that
are known in advance.
Computed by filtering ``self.time_series_metadata["cols"]["x"]``
where ``col_type == "F"`` (continuous) and ``col_known == "K"`` (known).
* ``target``: Number of target variables.
Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
gives the number of output target columns.
* ``static_categorical_features``: Number of static categorical features.
Computed by filtering ``self.time_series_metadata["cols"]["st"]``
(static features) where ``col_type == "C"`` (categorical).
* ``static_continuous_features``: Number of static continuous features.
Computed as the difference between
``len(self.time_series_metadata["cols"]["st"])`` (static features)
and ``static_categorical_features``, which gives the number of
static continuous features.
* ``max_encoder_length``: Maximum encoder length.
Taken directly from ``self.max_encoder_length``.
* ``max_prediction_length``: Maximum prediction length.
Taken directly from ``self.max_prediction_length``.
* ``min_encoder_length``: Minimum encoder length.
Taken directly from ``self.min_encoder_length``.
* ``min_prediction_length``: Minimum prediction length.
Taken directly from ``self.min_prediction_length``.
* ``encoder_cat``: Number of categorical variables in the encoder.
Computed as ``len(self.categorical_indices)``, which counts the
categorical feature indices.
* ``encoder_cont``: Number of continuous variables in the encoder.
Computed as ``len(self.continuous_indices)``, which counts the
continuous feature indices.
* ``decoder_cat``: Number of categorical variables in the decoder that
are known in advance.
Computed by filtering ``self.time_series_metadata["cols"]["x"]``
where ``col_type == "C"`` (categorical) and ``col_known == "K"`` (known).
* ``decoder_cont``: Number of continuous variables in the decoder that
are known in advance.
Computed by filtering ``self.time_series_metadata["cols"]["x"]``
where ``col_type == "F"`` (continuous) and ``col_known == "K"`` (known).
* ``target``: Number of target variables.
Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
gives the number of output target columns.
* ``static_categorical_features``: Number of static categorical features.
Computed by filtering ``self.time_series_metadata["cols"]["st"]``
(static features) where ``col_type == "C"`` (categorical).
* ``static_continuous_features``: Number of static continuous features.
Computed as the difference between
``len(self.time_series_metadata["cols"]["st"])`` (static features)
and ``static_categorical_features``, which gives the number of
static continuous features.
* ``max_encoder_length``: Maximum encoder length.
Taken directly from ``self.max_encoder_length``.
* ``max_prediction_length``: Maximum prediction length.
Taken directly from ``self.max_prediction_length``.
* ``min_encoder_length``: Minimum encoder length.
Taken directly from ``self.min_encoder_length``.
* ``min_prediction_length``: Minimum prediction length.
Taken directly from ``self.min_prediction_length``.

"""
encoder_cat_count = len(self.categorical_indices)
encoder_cont_count = len(self.continuous_indices)
@@ -254,8 +238,8 @@ def _prepare_metadata(self):
{
"max_encoder_length": self.max_encoder_length,
"max_prediction_length": self.max_prediction_length,
"min_encoder_length": self._min_encoder_length,
"min_prediction_length": self._min_prediction_length,
"min_encoder_length": self.min_encoder_length,
"min_prediction_length": self.min_prediction_length,
}
)
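
For orientation, the shape of the dict this method returns — the counts below are illustrative, assuming two continuous encoder features, one known categorical feature, and a single target:

```python
# Illustrative only: example return value of _prepare_metadata(), assuming
# 2 continuous features, 1 known categorical feature, and 1 target column.
metadata = {
    "encoder_cat": 1,             # len(self.categorical_indices)
    "encoder_cont": 2,            # len(self.continuous_indices)
    "decoder_cat": 1,             # categorical cols marked known ("K")
    "decoder_cont": 0,            # continuous cols marked known ("K")
    "target": 1,                  # len(cols["y"])
    "static_categorical_features": 0,
    "static_continuous_features": 0,
    "max_encoder_length": 30,
    "max_prediction_length": 1,
    "min_encoder_length": 30,     # falls back to max_encoder_length when None
    "min_prediction_length": 1,   # falls back to max_prediction_length when None
}
```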

@@ -504,6 +488,7 @@ def __getitem__(self, idx):
"decoder_lengths": torch.tensor(pred_length),
"decoder_target_lengths": torch.tensor(pred_length),
"groups": data["group"],
"target": data["target"][encoder_indices],
"encoder_time_idx": torch.arange(enc_length),
"decoder_time_idx": torch.arange(enc_length, enc_length + pred_length),
"target_scale": target_scale,
@@ -714,6 +699,7 @@ def collate_fn(batch):
[x["decoder_target_lengths"] for x, _ in batch]
),
"groups": torch.stack([x["groups"] for x, _ in batch]),
"target": torch.stack([x["target"] for x, _ in batch]),
"encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
"decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
"target_scale": torch.stack([x["target_scale"] for x, _ in batch]),
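The two `"target"` additions above mean each sample, and hence each collated batch, now carries the encoder-window target values. A sketch of the assumed batch layout for batch size B, encoder length E, and prediction length P — shapes are inferred from the code above, not taken from the PR's tests:

```python
# Assumed layout of one collated batch; key names follow the diff above,
# tensor shapes are inferred for a single target variable.
batch = {
    "encoder_lengths": "(B,)",
    "decoder_lengths": "(B,)",
    "decoder_target_lengths": "(B,)",
    "groups": "(B, 1)",
    "target": "(B, E, 1)",  # new: encoder-window slice of the target
    "encoder_time_idx": "(B, E)",
    "decoder_time_idx": "(B, P)",
    "target_scale": "(B, n_targets)",
}
```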
124 changes: 44 additions & 80 deletions pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -94,14 +94,14 @@ def __init__(
self.data = data
self.data_future = data_future
self.time = time
self.target = target
self.group = group
self.target = _coerce_to_list(target)
self.group = _coerce_to_list(group)
self.weight = weight
self.num = num
self.cat = cat
self.known = known
self.unknown = unknown
self.static = static
self.num = _coerce_to_list(num)
self.cat = _coerce_to_list(cat)
self.known = _coerce_to_list(known)
self.unknown = _coerce_to_list(unknown)
self.static = _coerce_to_list(static)

warn(
"TimeSeries is part of an experimental rework of the "
@@ -115,41 +115,20 @@
UserWarning,
)

super().__init__()

# handle defaults, coercion, and derived attributes
self._target = _coerce_to_list(target)
self._group = _coerce_to_list(group)
self._num = _coerce_to_list(num)
self._cat = _coerce_to_list(cat)
self._known = _coerce_to_list(known)
self._unknown = _coerce_to_list(unknown)
self._static = _coerce_to_list(static)

self.feature_cols = [
col
for col in data.columns
if col not in [self.time] + self._group + [self.weight] + self._target
if col not in [self.time] + self.group + [self.weight] + self.target
]
if self._group:
self._groups = self.data.groupby(self._group).groups
if self.group:
self._groups = self.data.groupby(self.group).groups
self._group_ids = list(self._groups.keys())
else:
self._groups = {"_single_group": self.data.index}
self._group_ids = ["_single_group"]

self._prepare_metadata()

# overwrite __init__ params for upwards compatibility with AS PRs
# todo: should we avoid this and ensure classes are dataclass-like?
self.group = self._group
self.target = self._target
self.num = self._num
self.cat = self._cat
self.known = self._known
self.unknown = self._unknown
self.static = self._static

def _prepare_metadata(self):
"""Prepare metadata for the dataset.

@@ -169,19 +148,19 @@ def _prepare_metadata(self):
"""
self.metadata = {
"cols": {
"y": self._target,
"y": self.target,
"x": self.feature_cols,
"st": self._static,
"st": self.static,
},
"col_type": {},
"col_known": {},
}

all_cols = self._target + self.feature_cols + self._static
all_cols = self.target + self.feature_cols + self.static
for col in all_cols:
self.metadata["col_type"][col] = "C" if col in self._cat else "F"
self.metadata["col_type"][col] = "C" if col in self.cat else "F"

self.metadata["col_known"][col] = "K" if col in self._known else "U"
self.metadata["col_known"][col] = "K" if col in self.known else "U"

def __len__(self) -> int:
"""Return number of time series in the dataset."""
@@ -216,69 +195,54 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
weights : torch.Tensor of shape (n_timepoints,), optional
Only included if weights are not `None`.
"""
time = self.time
feature_cols = self.feature_cols
_target = self._target
_known = self._known
_static = self._static
_group = self._group
_groups = self._groups
_group_ids = self._group_ids
weight = self.weight
data_future = self.data_future

group_id = _group_ids[index]

if _group:
mask = _groups[group_id]
group_id = self._group_ids[index]

if self.group:
mask = self._groups[group_id]
data = self.data.loc[mask]
else:
data = self.data

cutoff_time = data[time].max()

data_vals = data[time].values
data_tgt_vals = data[_target].values
data_feat_vals = data[feature_cols].values
cutoff_time = data[self.time].max()

result = {
"t": data_vals,
"y": torch.tensor(data_tgt_vals),
"x": torch.tensor(data_feat_vals),
"t": data[self.time].values,
"y": torch.tensor(data[self.target].values),
"x": torch.tensor(data[self.feature_cols].values),
"group": torch.tensor([hash(str(group_id))]),
"st": torch.tensor(data[_static].iloc[0].values if _static else []),
"st": torch.tensor(data[self.static].iloc[0].values if self.static else []),
"cutoff_time": cutoff_time,
}

if data_future is not None:
if _group:
future_mask = self.data_future.groupby(_group).groups[group_id]
if self.data_future is not None:
if self.group:
future_mask = self.data_future.groupby(self.group).groups[group_id]
future_data = self.data_future.loc[future_mask]
else:
future_data = self.data_future

data_fut_vals = future_data[time].values

combined_times = np.concatenate([data_vals, data_fut_vals])
combined_times = np.concatenate(
[data[self.time].values, future_data[self.time].values]
)
combined_times = np.unique(combined_times)
combined_times.sort()

num_timepoints = len(combined_times)
x_merged = np.full((num_timepoints, len(feature_cols)), np.nan)
y_merged = np.full((num_timepoints, len(_target)), np.nan)
x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)
y_merged = np.full((num_timepoints, len(self.target)), np.nan)

current_time_indices = {t: i for i, t in enumerate(combined_times)}
for i, t in enumerate(data_vals):
for i, t in enumerate(data[self.time].values):
idx = current_time_indices[t]
x_merged[idx] = data_feat_vals[i]
y_merged[idx] = data_tgt_vals[i]
x_merged[idx] = data[self.feature_cols].values[i]
y_merged[idx] = data[self.target].values[i]

for i, t in enumerate(data_fut_vals):
for i, t in enumerate(future_data[self.time].values):
if t in current_time_indices:
idx = current_time_indices[t]
for j, col in enumerate(_known):
if col in feature_cols:
feature_idx = feature_cols.index(col)
for j, col in enumerate(self.known):
if col in self.feature_cols:
feature_idx = self.feature_cols.index(col)
x_merged[idx, feature_idx] = future_data[col].values[i]

result.update(
@@ -289,17 +253,17 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
}
)

if weight:
if self.weight:
if self.data_future is not None and self.weight in self.data_future.columns:
weights_merged = np.full(num_timepoints, np.nan)
for i, t in enumerate(data_vals):
for i, t in enumerate(data[self.time].values):
idx = current_time_indices[t]
weights_merged[idx] = data[weight].values[i]
weights_merged[idx] = data[self.weight].values[i]

for i, t in enumerate(data_fut_vals):
for i, t in enumerate(future_data[self.time].values):
if t in current_time_indices and self.weight in future_data.columns:
idx = current_time_indices[t]
weights_merged[idx] = future_data[weight].values[i]
weights_merged[idx] = future_data[self.weight].values[i]

result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
else:
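The refactor above drops the `_target`/`_group`-style shadow attributes and instead coerces the public attributes once in `__init__`. The helper was moved to utils in commit af44474 ("move coercion to utils"); a minimal sketch of its assumed behavior, not the PR's exact implementation:

```python
# Assumed semantics of _coerce_to_list -- a sketch for illustration only.
def _coerce_to_list(obj):
    """Coerce None to [], a single str to [str], pass other iterables through."""
    if obj is None:
        return []
    if isinstance(obj, str):
        return [obj]
    return list(obj)

assert _coerce_to_list(None) == []
assert _coerce_to_list("value") == ["value"]
assert _coerce_to_list(["a", "b"]) == ["a", "b"]
```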
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/__init__.py
@@ -19,6 +19,7 @@
TemporalFusionTransformer,
)
from pytorch_forecasting.models.tide import TiDEModel
from pytorch_forecasting.models.timexer import TimeXer

__all__ = [
"NBeats",
@@ -37,4 +38,5 @@
"MultiEmbedding",
"DecoderMLP",
"TiDEModel",
"TimeXer",
]
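
With the new export in place, TimeXer resolves from the models namespace. The import below is confirmed by this diff; anything beyond it (constructor arguments, training wiring) lives in the example notebooks:

```python
# Confirmed by this diff: TimeXer is exported from pytorch_forecasting.models.
from pytorch_forecasting.models import TimeXer

print(TimeXer.__name__)  # "TimeXer"
```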