Merged
Commits (38)
665825a
initial commit
phoeenniixx Nov 9, 2024
5e57d34
linting
phoeenniixx Nov 9, 2024
e498848
adding some tests and a little in debug in `sLSTM` structure
phoeenniixx Nov 9, 2024
38e4c9c
new baseclass implementation
phoeenniixx Dec 12, 2024
a72c8c6
Update __init__.py
phoeenniixx Dec 13, 2024
b3b3e55
little debug in `predict` method
phoeenniixx Dec 23, 2024
87f4ff4
trying the baseclass predict function and removing the test files
phoeenniixx Dec 24, 2024
a6b2da9
refactor `__init__.py`
phoeenniixx Jan 6, 2025
39e2b6f
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 6, 2025
f67509a
linting
phoeenniixx Jan 6, 2025
46a9e74
Update layer.py
phoeenniixx Jan 6, 2025
7e7d915
docs
phoeenniixx Jan 6, 2025
31cd4de
linting
phoeenniixx Jan 6, 2025
c72bff9
Update __init__.py
phoeenniixx Jan 6, 2025
62e97ae
Update __init__.py
phoeenniixx Jan 6, 2025
93f0913
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 13, 2025
66900bc
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 21, 2025
acb23e7
Adding tests
phoeenniixx Jan 21, 2025
0b85284
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 24, 2025
b01754e
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Feb 10, 2025
5e666b4
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Mar 7, 2025
0a149a7
Merge branch 'main' into pr/1709
fkiraly Jun 6, 2025
9b21892
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 29, 2025
2eda66f
refactor code
phoeenniixx Jul 30, 2025
942e717
add pkg class
phoeenniixx Jul 30, 2025
5556d71
linting
phoeenniixx Jul 30, 2025
2dca593
add docstrings and debug
phoeenniixx Jul 31, 2025
8adcb31
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 31, 2025
1bc559c
add GH credits
phoeenniixx Jul 31, 2025
fd4b2ba
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 31, 2025
6a7cc23
update documentation
phoeenniixx Jul 31, 2025
60d1651
Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime
phoeenniixx Jul 31, 2025
96ec23d
add TriangularCausalMask
phoeenniixx Jul 31, 2025
6a40b7a
refactor files
phoeenniixx Aug 5, 2025
40beee8
Merge branch 'main' into xLSTMTime
phoeenniixx Aug 5, 2025
1cfaf9c
refactor files
phoeenniixx Aug 6, 2025
ed189de
Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime
phoeenniixx Aug 6, 2025
7addfad
update models.rst
phoeenniixx Aug 6, 2025
105 changes: 105 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import math


class mLSTMCell(nn.Module):
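    """mLSTM cell used by xLSTMTime.

    Besides the usual hidden and cell states it keeps a normalizer state ``n`` and
    computes query/key/value projections of the input at every step.
    """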
def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None):
super(mLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.layer_norm = layer_norm

self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.Wq = nn.Linear(input_size, hidden_size)
self.Wk = nn.Linear(input_size, hidden_size)
self.Wv = nn.Linear(input_size, hidden_size)

self.Wi = nn.Linear(input_size, hidden_size)
self.Wf = nn.Linear(input_size, hidden_size)
self.Wo = nn.Linear(input_size, hidden_size)

self.Wq.to(self.device)
self.Wk.to(self.device)
self.Wv.to(self.device)
self.Wi.to(self.device)
self.Wf.to(self.device)
self.Wo.to(self.device)

self.dropout = nn.Dropout(dropout)
self.dropout.to(self.device)

if layer_norm:
self.ln_q = nn.LayerNorm(hidden_size)
self.ln_k = nn.LayerNorm(hidden_size)
self.ln_v = nn.LayerNorm(hidden_size)
self.ln_i = nn.LayerNorm(hidden_size)
self.ln_f = nn.LayerNorm(hidden_size)
self.ln_o = nn.LayerNorm(hidden_size)

self.ln_q.to(self.device)
self.ln_k.to(self.device)
self.ln_v.to(self.device)
self.ln_i.to(self.device)
self.ln_f.to(self.device)
self.ln_o.to(self.device)

self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()

    def forward(self, x, h_prev, c_prev, n_prev):
        """Compute one mLSTM step and return the updated ``(h, c, n)`` states."""
x = x.to(self.device)
h_prev = h_prev.to(self.device)
c_prev = c_prev.to(self.device)
n_prev = n_prev.to(self.device)

batch_size = x.size(0)
assert x.dim() == 2, f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
assert h_prev.size() == (batch_size, self.hidden_size), f"h_prev shape mismatch: {h_prev.size()}"
assert c_prev.size() == (batch_size, self.hidden_size), f"c_prev shape mismatch: {c_prev.size()}"
assert n_prev.size() == (batch_size, self.hidden_size), f"n_prev shape mismatch: {n_prev.size()}"

x = self.dropout(x)
h_prev = self.dropout(h_prev)

q = self.Wq(x)
k = self.Wk(x) / math.sqrt(self.hidden_size)
v = self.Wv(x)

if self.layer_norm:
q = self.ln_q(q)
k = self.ln_k(k)
v = self.ln_v(v)

i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

k_expanded = k.unsqueeze(-1)
v_expanded = v.unsqueeze(-2)

kv_interaction = k_expanded @ v_expanded

kv_sum = kv_interaction.sum(dim=1)

c = f * c_prev + i * kv_sum
n = f * n_prev + i * k

epsilon = 1e-8
normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
h = o * self.tanh(c * normalized_n)

return h, c, n

def init_hidden(self, batch_size):
"""
Initialize hidden, cell, and normalization states.
"""
shape = (batch_size, self.hidden_size)
return (
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
)
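
For reviewers, a minimal sketch of driving the new cell for a single recurrent step (the sizes below are arbitrary illustration values; the import path follows the file added above):

import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell

# one mLSTM step on a toy batch, pinned to CPU for the example
cell = mLSTMCell(input_size=8, hidden_size=16, device=torch.device("cpu"))
x = torch.randn(4, 8)          # (batch_size, input_size)
h, c, n = cell.init_hidden(4)  # zero hidden, cell, and normalizer states
h, c, n = cell(x, h, c, n)     # each returned tensor has shape (4, 16)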
83 changes: 83 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell


class mLSTMLayer(nn.Module):
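    """Stack of ``mLSTMCell`` modules applied step by step over a sequence.

    Supports optional layer normalization and residual connections between layers.
    """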
def __init__(
self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, device=None
):
super(mLSTMLayer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.layer_norm = layer_norm
self.residual_conn = residual_conn
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.dropout = nn.Dropout(dropout).to(self.device)

self.cells = nn.ModuleList(
[
mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device)
for i in range(num_layers)
]
)

def init_hidden(self, batch_size):
"""
Initialize hidden, cell, and normalization states for all layers.
"""
hidden_states, cell_states, norm_states = zip(
*[self.cells[i].init_hidden(batch_size) for i in range(self.num_layers)]
)

return (
torch.stack(hidden_states).to(self.device),
torch.stack(cell_states).to(self.device),
torch.stack(norm_states).to(self.device),
)

def forward(self, x, h=None, c=None, n=None):
"""
Forward pass for the mLSTM layer.
"""

x = x.to(self.device).transpose(0, 1)
batch_size, seq_len, _ = x.size()

if h is None or c is None or n is None:
h, c, n = self.init_hidden(batch_size)

outputs = []

for t in range(seq_len):
layer_input = x[:, t, :]
next_hidden_states = []
next_cell_states = []
next_norm_states = []

for i, cell in enumerate(self.cells):

h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i])

if self.residual_conn and i > 0:
h_i = h_i + layer_input

layer_input = h_i

next_hidden_states.append(h_i)
next_cell_states.append(c_i)
next_norm_states.append(n_i)

h = torch.stack(next_hidden_states).to(self.device)
c = torch.stack(next_cell_states).to(self.device)
n = torch.stack(next_norm_states).to(self.device)

outputs.append(h[-1])

output = torch.stack(outputs, dim=1)

output = output.transpose(0, 1)

return output, (h, c, n)
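
A corresponding sketch for the layer, mainly to record the sequence-first input convention implied by the transposes in forward (toy sizes again):

import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer

layer = mLSTMLayer(input_size=8, hidden_size=16, num_layers=2, device=torch.device("cpu"))
x = torch.randn(24, 4, 8)     # (seq_len, batch_size, input_size), sequence-first
output, (h, c, n) = layer(x)  # states default to zeros via init_hidden
# output: (24, 4, 16); h, c and n: (2, 4, 16), i.e. one state per layer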
38 changes: 38 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/mLSTM/network.py
@@ -0,0 +1,38 @@
import torch.nn as nn
import torch
from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer


class mLSTMNetwork(nn.Module):
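    """mLSTM network: stacked mLSTM layers followed by a linear output head.

    The head maps the hidden state of the last time step to ``output_size``.
    """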
def __init__(
self,
input_size,
hidden_size,
num_layers,
output_size,
dropout=0.0,
use_layer_norm=True,
use_residual=True,
device=None,
):
super(mLSTMNetwork, self).__init__()
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.mlstm_layer = mLSTMLayer(
input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, self.device
)
self.fc = nn.Linear(hidden_size, output_size)

def forward(self, x, h=None, c=None, n=None):
"""
Forward pass through the mLSTM network.
"""
output, (h, c, n) = self.mlstm_layer(x, h, c, n)

output = self.fc(output[-1])

return output, (h, c, n)

def init_hidden(self, batch_size):
"""Initialize hidden, cell, and normalization states."""
return self.mlstm_layer.init_hidden(batch_size)
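
And for the full network, which applies the linear head to the last time step (same toy sizes):

import torch

from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork

net = mLSTMNetwork(input_size=8, hidden_size=16, num_layers=2, output_size=1, device=torch.device("cpu"))
x = torch.randn(24, 4, 8)  # (seq_len, batch_size, input_size)
y, (h, c, n) = net(x)      # y has shape (4, 1): one prediction per batch element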
94 changes: 94 additions & 0 deletions pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import math


class sLSTMCell(nn.Module):
"""Stabilized LSTM Cell"""

def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None):
super(sLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.use_layer_norm = use_layer_norm
self.eps = 1e-6

self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device)
self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device)

if use_layer_norm:
self.ln_cell = nn.LayerNorm(hidden_size).to(self.device)
self.ln_hidden = nn.LayerNorm(hidden_size).to(self.device)
self.ln_input = nn.LayerNorm(4 * hidden_size).to(self.device)
self.ln_hidden_update = nn.LayerNorm(4 * hidden_size).to(self.device)

self.dropout_layer = nn.Dropout(dropout).to(self.device)

self.reset_parameters()

self.grad_clip = 5.0

self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()

self.to(self.device)

def reset_parameters(self):
"""Initialize parameters using Xavier/Glorot initialization"""
std = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-std, std)

def normalized_exp_gate(self, pre_gate):
"""Compute normalized exponential gate activation"""
centered = pre_gate - torch.mean(pre_gate, dim=1, keepdim=True)
exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0))
normalizer = torch.sum(exp_val, dim=1, keepdim=True) + self.eps
return exp_val / normalizer

def forward(self, x, h_prev, c_prev):
"""Forward pass with stabilized exponential gating"""
x = x.to(self.device)
h_prev = h_prev.to(self.device)
c_prev = c_prev.to(self.device)

x = self.dropout_layer(x)
h_prev = self.dropout_layer(h_prev)

gates_x = self.input_weights(x)
gates_h = self.hidden_weights(h_prev)

if self.use_layer_norm:
gates_x = self.ln_input(gates_x)
gates_h = self.ln_hidden_update(gates_h)

gates = gates_x + gates_h
i, f, g, o = gates.chunk(4, dim=1)

i = self.normalized_exp_gate(i)
f = self.normalized_exp_gate(f)
gate_sum = i + f
i = i / (gate_sum + self.eps)
f = f / (gate_sum + self.eps)

c_tilde = self.tanh(g)
c = f * c_prev + i * c_tilde
if self.use_layer_norm:
c = self.ln_cell(c)

o = self.sigmoid(o)
c_out = self.tanh(c)
if self.use_layer_norm:
c_out = self.ln_hidden(c_out)
h = o * c_out

return h, c

def init_hidden(self, batch_size):
return (
torch.zeros(batch_size, self.hidden_size, device=self.device),
torch.zeros(batch_size, self.hidden_size, device=self.device),
)
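
As with the mLSTM cell, a minimal single-step sketch of the stabilized cell (arbitrary toy sizes):

import torch

from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell

cell = sLSTMCell(input_size=8, hidden_size=16, device=torch.device("cpu"))
x = torch.randn(4, 8)       # (batch_size, input_size)
h, c = cell.init_hidden(4)  # zero hidden and cell states
h, c = cell(x, h, c)        # both returned tensors have shape (4, 16)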