Commit 6074011

[ENH] Tide model in v2 interface (#1889)
Adds the `Tide` model from `dsipts` to `ptf-v2`. Creates a new folder `tide_dsipts` in `tide` that contains all the necessary parts for the TIDE model.
1 parent b5eb779 commit 6074011

File tree: 11 files changed, +690 -9 lines


pytorch_forecasting/data/data_module.py

Lines changed: 4 additions & 2 deletions
@@ -430,8 +430,8 @@ def __getitem__(self, idx):
         encoder_indices = slice(start_idx, start_idx + enc_length)
         decoder_indices = slice(start_idx + enc_length, end_idx)

-        target_scale = data["target"][encoder_indices]
-        target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()
+        target_past = data["target"][encoder_indices]
+        target_scale = target_past[~torch.isnan(target_past)].abs().mean()
         if torch.isnan(target_scale) or target_scale == 0:
             target_scale = torch.tensor(1.0)

@@ -503,6 +503,7 @@ def __getitem__(self, idx):
             "decoder_lengths": torch.tensor(pred_length),
             "decoder_target_lengths": torch.tensor(pred_length),
             "groups": data["group"],
+            "target_past": target_past,
             "encoder_time_idx": torch.arange(enc_length),
             "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length),
             "target_scale": target_scale,
@@ -713,6 +714,7 @@ def collate_fn(batch):
                 [x["decoder_target_lengths"] for x, _ in batch]
             ),
             "groups": torch.stack([x["groups"] for x, _ in batch]),
+            "target_past": torch.stack([x["target_past"] for x, _ in batch]),
             "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
             "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
             "target_scale": torch.stack([x["target_scale"] for x, _ in batch]),

pytorch_forecasting/layers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -7,11 +7,13 @@
     FullAttention,
     TriangularCausalMask,
 )
+from pytorch_forecasting.layers._blocks import ResidualBlock
 from pytorch_forecasting.layers._decomposition import SeriesDecomposition
 from pytorch_forecasting.layers._embeddings import (
     DataEmbedding_inverted,
     EnEmbedding,
     PositionalEmbedding,
+    embedding_cat_variables,
 )
 from pytorch_forecasting.layers._encoders import (
     Encoder,
@@ -48,4 +50,6 @@
     "sLSTMLayer",
     "sLSTMNetwork",
     "SeriesDecomposition",
+    "ResidualBlock",
+    "embedding_cat_variables",
 ]
pytorch_forecasting/layers/_blocks/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from pytorch_forecasting.layers._blocks._residual_block_dsipts import ResidualBlock
+
+__all__ = ["ResidualBlock"]
pytorch_forecasting/layers/_blocks/_residual_block_dsipts.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ""
+    ):
+        """Residual Block as basic layer of the archetecture.
+
+        MLP with one hidden layer, activation and skip connection
+        Basically dimension d_model, but better if input_dim and output_dim are explicit
+
+        in_size and out_size to handle dimensions at different stages of the NN
+
+        Parameters
+        ----------
+        in_size: int
+            input size
+        out_size: int
+            output size
+        dropout_rate: float
+            dropout
+        activation_fun: str, Optional
+            activation function to use in the Residual Block. Defaults to nn.ReLU.
+        """  # noqa: E501
+        import ast
+
+        super().__init__()
+
+        self.direct_linear = nn.Linear(in_size, out_size, bias=False)
+
+        if activation_fun == "":
+            self.act = nn.ReLU()
+        else:
+            activation = ast.literal_eval(activation_fun)
+            self.act = activation()
+        self.lin = nn.Linear(in_size, out_size)
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self.final_norm = nn.LayerNorm(out_size)
+
+    def forward(self, x, apply_final_norm=True):
+        direct_x = self.direct_linear(x)
+
+        x = self.dropout(self.lin(self.act(x)))
+
+        out = x + direct_x
+        if apply_final_norm:
+            return self.final_norm(out)
+        return out
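
A short usage sketch of the new block, importing it via the re-export added to `pytorch_forecasting.layers` in this commit. The sizes below are arbitrary, and the default empty `activation_fun` (plain `nn.ReLU`) is the path exercised:

import torch

from pytorch_forecasting.layers import ResidualBlock

# in_size != out_size is allowed: the skip connection has its own linear projection
block = ResidualBlock(in_size=32, out_size=64, dropout_rate=0.1)

x = torch.randn(8, 32)                    # [batch, in_size]
y = block(x)                              # [batch, out_size], with final LayerNorm
y_raw = block(x, apply_final_norm=False)  # same, but skipping the final LayerNorm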

pytorch_forecasting/layers/_embeddings/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -9,5 +9,11 @@
 from pytorch_forecasting.layers._embeddings._positional_embedding import (
     PositionalEmbedding,
 )
+from pytorch_forecasting.layers._embeddings._sub_nn import embedding_cat_variables

-__all__ = ["PositionalEmbedding", "DataEmbedding_inverted", "EnEmbedding"]
+__all__ = [
+    "PositionalEmbedding",
+    "DataEmbedding_inverted",
+    "EnEmbedding",
+    "embedding_cat_variables",
+]
pytorch_forecasting/layers/_embeddings/_sub_nn.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+
+class embedding_cat_variables(nn.Module):
+    # at the moment cat_past and cat_fut together
+    def __init__(self, seq_len: int, lag: int, d_model: int, emb_dims: list, device):
+        """Class for embedding categorical variables, adding 3 positional variables during forward
+
+        Parameters
+        ----------
+        seq_len: int
+            length of the sequence (sum of past and future steps)
+        lag: (int):
+            number of future step to be predicted
+        hiden_size: int
+            dimension of all variables after they are embedded
+        emb_dims: list
+            size of the dictionary for embedding. One dimension for each categorical variable
+        device : torch.device
+        """  # noqa: E501
+        super().__init__()
+        self.seq_len = seq_len
+        self.lag = lag
+        self.device = device
+        self.cat_embeds = emb_dims + [seq_len, lag + 1, 2]  #
+        self.cat_n_embd = nn.ModuleList(
+            [nn.Embedding(emb_dim, d_model) for emb_dim in self.cat_embeds]
+        )
+
+    def forward(
+        self, x: Union[torch.Tensor, int], device: torch.device
+    ) -> torch.Tensor:
+        """All components of x are concatenated with 3 new variables for data augmentation, in the order:
+
+        - pos_seq: assign at each step its time-position
+        - pos_fut: assign at each step its future position. 0 if it is a past step
+        - is_fut: explicit for each step if it is a future(1) or past one(0)
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            `[bs, seq_len, num_vars]`
+
+        Returns
+        ------
+        torch.Tensor:
+            `[bs, seq_len, num_vars+3, n_embd]`
+        """  # noqa: E501
+        if isinstance(x, int):
+            no_emb = True
+            B = x
+        else:
+            no_emb = False
+            B, _, _ = x.shape
+
+        pos_seq = self.get_pos_seq(bs=B).to(device)
+        pos_fut = self.get_pos_fut(bs=B).to(device)
+        is_fut = self.get_is_fut(bs=B).to(device)
+
+        if no_emb:
+            cat_vars = torch.cat((pos_seq, pos_fut, is_fut), dim=2)
+        else:
+            cat_vars = torch.cat((x, pos_seq, pos_fut, is_fut), dim=2)
+        cat_vars = cat_vars.long()
+        cat_n_embd = self.get_cat_n_embd(cat_vars)
+        return cat_n_embd
+
+    def get_pos_seq(self, bs):
+        pos_seq = torch.arange(0, self.seq_len)
+        pos_seq = pos_seq.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return pos_seq
+
+    def get_pos_fut(self, bs):
+        pos_fut = torch.cat(
+            (
+                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
+                torch.arange(1, self.lag + 1),
+            )
+        )
+        pos_fut = pos_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return pos_fut
+
+    def get_is_fut(self, bs):
+        is_fut = torch.cat(
+            (
+                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
+                torch.ones((self.lag), dtype=torch.long),
+            )
+        )
+        is_fut = is_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return is_fut
+
+    def get_cat_n_embd(self, cat_vars):
+        cat_n_embd = torch.Tensor().to(cat_vars.device)
+        for index, layer in enumerate(self.cat_n_embd):
+            emb = layer(cat_vars[:, :, index])
+            cat_n_embd = torch.cat((cat_n_embd, emb.unsqueeze(2)), dim=2)
+        return cat_n_embd
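
A hedged sketch of how the new embedding module can be called, using the re-export added to `pytorch_forecasting.layers`. The sequence length, lag, hidden size, and categorical cardinalities below are illustrative assumptions, not values used by the TIDE model itself:

import torch

from pytorch_forecasting.layers import embedding_cat_variables

seq_len, lag, d_model = 16, 4, 8   # 12 past steps followed by 4 future steps
emb_dims = [7, 12]                 # cardinalities of two categorical variables
device = torch.device("cpu")

emb = embedding_cat_variables(seq_len, lag, d_model, emb_dims, device)

# integer category codes for the two variables over the full sequence
cats = torch.stack(
    (
        torch.randint(0, 7, (2, seq_len)),
        torch.randint(0, 12, (2, seq_len)),
    ),
    dim=2,
)                                  # [bs=2, seq_len, num_vars=2]

out = emb(cats, device)            # [2, seq_len, num_vars + 3, d_model]

# with no categorical variables, only the batch size is passed and just the
# three generated positional variables (pos_seq, pos_fut, is_fut) are embedded
emb_pos = embedding_cat_variables(seq_len, lag, d_model, [], device)
pos_only = emb_pos(2, device)      # [2, seq_len, 3, d_model]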
pytorch_forecasting/models/tide/__init__.py

Lines changed: 2 additions & 5 deletions
@@ -1,11 +1,8 @@
 """Tide model."""

 from pytorch_forecasting.models.tide._tide import TiDEModel
+from pytorch_forecasting.models.tide._tide_dsipts import TIDE, TIDE_pkg_v2
 from pytorch_forecasting.models.tide._tide_pkg import TiDEModel_pkg
 from pytorch_forecasting.models.tide.sub_modules import _TideModule

-__all__ = [
-    "_TideModule",
-    "TiDEModel",
-    "TiDEModel_pkg",
-]
+__all__ = ["_TideModule", "TiDEModel", "TiDEModel_pkg", "TIDE", "TIDE_pkg_v2"]
pytorch_forecasting/models/tide/_tide_dsipts/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""DSIPTS Tide Implementation for V2"""
+
+from pytorch_forecasting.models.tide._tide_dsipts._tide_v2 import TIDE
+from pytorch_forecasting.models.tide._tide_dsipts._tide_v2_pkg import TIDE_pkg_v2
+
+__all__ = ["TIDE", "TIDE_pkg_v2"]
