Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
3bab673
Refactored NBeats and added comments for KAN block and NBeats.
Sohaib-Ahmed21 Jan 13, 2025
acfa626
Refactored NBeats and added comments for KAN block and NBeats.
Sohaib-Ahmed21 Jan 13, 2025
41d7403
End to end integrated Kolmogorov Arnold Networks in NBeats. Also refa…
Sohaib-Ahmed21 Jan 13, 2025
53fb126
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jan 13, 2025
594102d
Resolved import error.
Sohaib-Ahmed21 Jan 13, 2025
45c63f6
Merge branch 'kan-nbeats' of github.com:Sohaib-Ahmed21/pytorch-foreca…
Sohaib-Ahmed21 Jan 13, 2025
88de705
Merge branch 'main' of https://github.com/Sohaib-Ahmed21/pytorch-fore…
Sohaib-Ahmed21 Jan 22, 2025
c8ccfaf
Refactored NBEATS and added support for grid updation during training…
Sohaib-Ahmed21 Jan 23, 2025
348da97
Refactored comments.
Sohaib-Ahmed21 Jan 23, 2025
09facba
Merge branch 'sktime:main' into kan-nbeats
Sohaib-Ahmed21 Feb 1, 2025
1ab0da0
Added example to use grid_update_callback and added correct device to…
Sohaib-Ahmed21 Feb 1, 2025
05350c2
Refactored code for NBEATSKAN and introduced it as separate model/ent…
Sohaib-Ahmed21 Feb 20, 2025
7070f8b
Made modules private.
Sohaib-Ahmed21 Feb 23, 2025
0219fc3
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Feb 25, 2025
ca78516
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Apr 5, 2025
e4f8790
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 May 11, 2025
dd8358d
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 May 22, 2025
315b819
Resolved merge conflicts with main
Sohaib-Ahmed21 Jul 5, 2025
2a4d3ec
Merge branch 'sktime:main' into kan-nbeats
Sohaib-Ahmed21 Jul 5, 2025
14ca66f
Address deprecated typing classes
Sohaib-Ahmed21 Jul 5, 2025
89a9a4f
Refactor code with proper docstrings and cleaner structure
Sohaib-Ahmed21 Jul 5, 2025
0c43448
Refactor examples in docstring
Sohaib-Ahmed21 Jul 5, 2025
2da4d13
Include NBEATSKAN package container
Sohaib-Ahmed21 Jul 6, 2025
eb9c79d
Refactor and enhance docstrings to follow NumPy style, include KAN re…
Sohaib-Ahmed21 Jul 7, 2025
cc819ff
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jul 10, 2025
5a31a58
Merge branch 'main' into kan-nbeats
fkiraly Jul 10, 2025
58e14d8
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jul 12, 2025
c77a44e
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 12, 2025
f533241
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 16, 2025
ec4844e
Restructure KAN and NBeats layers to include them in pytorch_forecast…
Sohaib-Ahmed21 Aug 17, 2025
c792dc3
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 21, 2025
698f242
rename get_cls
fkiraly Aug 26, 2025
1570e02
add _pkg pointer
fkiraly Aug 26, 2025
33284b0
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 26, 2025
a6929b1
Update _nbeatskan_pkg.py
fkiraly Aug 27, 2025
8554c8a
Merge branch 'kan-nbeats' of https://github.com/Sohaib-Ahmed21/pytorc…
fkiraly Aug 27, 2025
d61b2b5
Update _nbeatskan_pkg.py
fkiraly Aug 27, 2025
92213aa
Solve failing TweedieLoss test with NBeatsKAN
Sohaib-Ahmed21 Aug 27, 2025
6bb93a7
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 27, 2025
8212b2d
Adjust docstring example of b_batch function
Sohaib-Ahmed21 Aug 28, 2025
0861322
Add compatibility imports for NBEATS' blocks
Sohaib-Ahmed21 Sep 1, 2025
eeff6d8
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Sep 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and you should take into account. Here is an overview over the pros and cons of
:py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2
:py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1
:py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
Expand Down
105 changes: 105 additions & 0 deletions examples/nbeats_with_kan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import sys

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
import pandas as pd

from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.models.nbeats import GridUpdateCallback

sys.path.append("..")


print("load data")
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
data["static"] = 2
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
validation = data.series.sample(20)


max_encoder_length = 150
max_prediction_length = 20

training_cutoff = data["time_idx"].max() - max_prediction_length

context_length = max_encoder_length
prediction_length = max_prediction_length

training = TimeSeriesDataSet(
data[lambda x: x.time_idx < training_cutoff],
time_idx="time_idx",
target="value",
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
group_ids=["series"],
min_encoder_length=context_length,
max_encoder_length=context_length,
max_prediction_length=prediction_length,
min_prediction_length=prediction_length,
time_varying_unknown_reals=["value"],
randomize_length=None,
add_relative_time_idx=False,
add_target_scales=False,
)

validation = TimeSeriesDataSet.from_dataset(
training, data, min_prediction_idx=training_cutoff
)
batch_size = 128
train_dataloader = training.to_dataloader(
train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
train=False, batch_size=batch_size, num_workers=0
)


early_stop_callback = EarlyStopping(
monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
# updates KAN layers' grid after every 3 steps during training
grid_update_callback = GridUpdateCallback(update_interval=3)

trainer = pl.Trainer(
max_epochs=1,
accelerator="auto",
gradient_clip_val=0.1,
callbacks=[early_stop_callback, grid_update_callback],
limit_train_batches=15,
# limit_val_batches=1,
# fast_dev_run=True,
# logger=logger,
# profiler=True,
)


net = NBeatsKAN.from_dataset(
training,
learning_rate=3e-2,
log_interval=10,
log_val_interval=1,
log_gradient_flow=False,
weight_decay=1e-2,
)
print(f"Number of parameters in network: {net.size() / 1e3:.1f}k")

# # find optimal learning rate
# # remove logging and artificial epoch size
# net.hparams.log_interval = -1
# net.hparams.log_val_interval = -1
# trainer.limit_train_batches = 1.0
# # run learning rate finder
# res = Tuner(trainer).lr_find(
# net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501
# )
# print(f"suggested learning rate: {res.suggestion()}")
# fig = res.plot(show=True, suggest=True)
# fig.show()
# net.hparams.learning_rate = res.suggestion()

trainer.fit(
net,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
2 changes: 2 additions & 0 deletions pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DeepAR,
MultiEmbedding,
NBeats,
NBeatsKAN,
NHiTS,
RecurrentNetwork,
TemporalFusionTransformer,
Expand Down Expand Up @@ -73,6 +74,7 @@
"TemporalFusionTransformer",
"TiDEModel",
"NBeats",
"NBeatsKAN",
"NHiTS",
"Baseline",
"DeepAR",
Expand Down
7 changes: 7 additions & 0 deletions pytorch_forecasting/layers/_kan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
KAN (Kolmogorov Arnold Network) layer implementation.
"""

from pytorch_forecasting.layers._kan._kan_layer import KANLayer

__all__ = ["KANLayer"]
237 changes: 237 additions & 0 deletions pytorch_forecasting/layers/_kan/_kan_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# The following implementation of KANLayer is inspired by the pykan library.
# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py

import numpy as np
import torch
import torch.nn as nn

from pytorch_forecasting.layers._kan._utils import (
coef2curve,
curve2coef,
extend_grid,
sparse_mask,
)


class KANLayer(nn.Module):
"""
Initialize a KANLayer

Parameters
----------
in_dim : int
input dimension. Default: 2.
out_dim : int
output dimension. Default: 3.
num : int
the number of grid intervals = G. Default: 5.
k : int
the order of piecewise polynomial. Default: 3.
noise_scale : float
the scale of noise injected at initialization. Default: 0.1.
scale_base_mu : float
the scale of the residual function b(x) is intialized to be
N(scale_base_mu, scale_base_sigma^2).
scale_base_sigma : float
the scale of the residual function b(x) is intialized to be
N(scale_base_mu, scale_base_sigma^2).
scale_sp : float
the scale of the base function spline(x).
base_fun : function
residual function b(x). Default: None
grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is
partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates
between the two extremes.
grid_range : list or np.array of shape (2,)
setting the range of grids. Default: None.
sp_trainable : bool
If true, scale_sp is trainable.
sb_trainable : bool
If true, scale_base is trainable.
sparse_init : bool
if sparse_init = True, sparse initialization is applied.

Returns
-------
self : reference to self

Examples
--------
The following is an example from the original `pykan` library, adapted here
for illustration within the PyTorch Forecasting integration.

Install the `pykan` package first:
pip install pykan
Then use:

>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> (model.in_dim, model.out_dim)
"""

def __init__(
self,
in_dim=3,
out_dim=2,
num=5,
k=3,
noise_scale=0.5,
scale_base_mu=0.0,
scale_base_sigma=1.0,
scale_sp=1.0,
base_fun=None,
grid_eps=0.02,
grid_range=None,
sp_trainable=True,
sb_trainable=True,
sparse_init=False,
):
super().__init__()

# Handle mutable parameters
if grid_range is None:
grid_range = [-1, 1]
if base_fun is None:
base_fun = torch.nn.SiLU()
# size
self.out_dim = out_dim
self.in_dim = in_dim
self.num = num
self.k = k

grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[
None, :
].expand(self.in_dim, num + 1)
grid = extend_grid(grid, k_extend=k)
self.grid = torch.nn.Parameter(grid).requires_grad_(False)
noises = (
(torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2)
* noise_scale
/ num
)

self.coef = torch.nn.Parameter(
curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k)
)

if sparse_init:
self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(
False
)
else:
self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(
False
)

self.scale_base = torch.nn.Parameter(
scale_base_mu * 1 / np.sqrt(in_dim)
+ scale_base_sigma
* (torch.rand(in_dim, out_dim) * 2 - 1)
* 1
/ np.sqrt(in_dim)
).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(
torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask
).requires_grad_(sp_trainable) # make scale trainable
self.base_fun = base_fun

self.grid_eps = grid_eps

def forward(self, x):
"""
KANLayer forward given input x

Parameters
-----
x : torch.Tensor
Input tensor of shape (batch_size, in_dim), where:
- batch_size is the number of input samples.
- in_dim is the input feature dimension.

Returns
--------
y : torch.Tensor
Output tensor, the result of applying spline and residual
transformations followed by weighted summation.

Examples
--------
The following is an example from the original `pykan` library, adapted here
for illustration within the PyTorch Forecasting integration.

Install the `pykan` package first:
pip install pykan
Then use:

>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, _, _, _ = model(x)
>>> y.shape
"""

base = self.base_fun(x) # (batch, in_dim)
y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
y = (
self.scale_base[None, :, :] * base[:, :, None]
+ self.scale_sp[None, :, :] * y
)
y = self.mask[None, :, :] * y
y = torch.sum(y, dim=1)
return y

def update_grid_from_samples(self, x):
"""
Update grid from samples

Parameters
-----
x : 2D torch.float
inputs, shape (number of samples, input dimension)

Returns:
--------
None

Examples
-------
>>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
>>> print(model.grid.data)
>>> x = torch.linspace(-3,3,steps=100)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.grid.data)
"""

batch = x.shape[0]
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
num_interval = self.grid.shape[1] - 1 - 2 * self.k

def get_grid(num_interval):
"""
Generate adaptive or uniform grid points from sorted input samples.

Parameters
-----
num_interval : int
Number of intervals between grid points.

Returns:
--------
grid : torch.Tensor
New grid of shape (in_dim, num_interval + 1).
"""
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1, 0)
h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
grid_uniform = (
grid_adaptive[:, [0]]
+ h * torch.arange(num_interval + 1, device=h.device)[None, :]
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)
self.grid.data = extend_grid(grid, k_extend=self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
Loading
Loading