Merged
Changes from 6 commits
42 commits
3bab673
Refactored NBeats and added comments for KAN block and NBeats.
Sohaib-Ahmed21 Jan 13, 2025
acfa626
Refactored NBeats and added comments for KAN block and NBeats.
Sohaib-Ahmed21 Jan 13, 2025
41d7403
End to end integrated Kolmogorov Arnold Networks in NBeats. Also refa…
Sohaib-Ahmed21 Jan 13, 2025
53fb126
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jan 13, 2025
594102d
Resolved import error.
Sohaib-Ahmed21 Jan 13, 2025
45c63f6
Merge branch 'kan-nbeats' of github.com:Sohaib-Ahmed21/pytorch-foreca…
Sohaib-Ahmed21 Jan 13, 2025
88de705
Merge branch 'main' of https://github.com/Sohaib-Ahmed21/pytorch-fore…
Sohaib-Ahmed21 Jan 22, 2025
c8ccfaf
Refactored NBEATS and added support for grid updation during training…
Sohaib-Ahmed21 Jan 23, 2025
348da97
Refactored comments.
Sohaib-Ahmed21 Jan 23, 2025
09facba
Merge branch 'sktime:main' into kan-nbeats
Sohaib-Ahmed21 Feb 1, 2025
1ab0da0
Added example to use grid_update_callback and added correct device to…
Sohaib-Ahmed21 Feb 1, 2025
05350c2
Refactored code for NBEATSKAN and introduced it as separate model/ent…
Sohaib-Ahmed21 Feb 20, 2025
7070f8b
Made modules private.
Sohaib-Ahmed21 Feb 23, 2025
0219fc3
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Feb 25, 2025
ca78516
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Apr 5, 2025
e4f8790
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 May 11, 2025
dd8358d
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 May 22, 2025
315b819
Resolved merge conflicts with main
Sohaib-Ahmed21 Jul 5, 2025
2a4d3ec
Merge branch 'sktime:main' into kan-nbeats
Sohaib-Ahmed21 Jul 5, 2025
14ca66f
Address deprecated typing classes
Sohaib-Ahmed21 Jul 5, 2025
89a9a4f
Refactor code with proper docstrings and cleaner structure
Sohaib-Ahmed21 Jul 5, 2025
0c43448
Refactor examples in docstring
Sohaib-Ahmed21 Jul 5, 2025
2da4d13
Include NBEATSKAN package container
Sohaib-Ahmed21 Jul 6, 2025
eb9c79d
Refactor and enhance docstrings to follow NumPy style, include KAN re…
Sohaib-Ahmed21 Jul 7, 2025
cc819ff
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jul 10, 2025
5a31a58
Merge branch 'main' into kan-nbeats
fkiraly Jul 10, 2025
58e14d8
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Jul 12, 2025
c77a44e
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 12, 2025
f533241
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 16, 2025
ec4844e
Restructure KAN and NBeats layers to include them in pytorch_forecast…
Sohaib-Ahmed21 Aug 17, 2025
c792dc3
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 21, 2025
698f242
rename get_cls
fkiraly Aug 26, 2025
1570e02
add _pkg pointer
fkiraly Aug 26, 2025
33284b0
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 26, 2025
a6929b1
Update _nbeatskan_pkg.py
fkiraly Aug 27, 2025
8554c8a
Merge branch 'kan-nbeats' of https://github.com/Sohaib-Ahmed21/pytorc…
fkiraly Aug 27, 2025
d61b2b5
Update _nbeatskan_pkg.py
fkiraly Aug 27, 2025
92213aa
Solve failing TweedieLoss test with NBeatsKAN
Sohaib-Ahmed21 Aug 27, 2025
6bb93a7
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Aug 27, 2025
8212b2d
Adjust docstring example of b_batch function
Sohaib-Ahmed21 Aug 28, 2025
0861322
Add compatibility imports for NBEATS' blocks
Sohaib-Ahmed21 Sep 1, 2025
eeff6d8
Merge branch 'main' into kan-nbeats
Sohaib-Ahmed21 Sep 1, 2025
153 changes: 113 additions & 40 deletions pytorch_forecasting/models/nbeats/_nbeats.py
@@ -30,6 +30,19 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
prediction_length: int = 1,
context_length: int = 1,
use_kan: bool = False,
num_grids: int = 5,
k: int = 3,
noise_scale: float = 0.5,
scale_base_mu: float = 0.0,
scale_base_sigma: float = 1.0,
scale_sp: float = 1.0,
base_fun: callable = torch.nn.SiLU(),
grid_eps: float = 0.02,
grid_range: List[int] = [-1, 1],
sp_trainable: bool = True,
sb_trainable: bool = True,
sparse_init: bool = False,
dropout: float = 0.1,
learning_rate: float = 1e-2,
log_interval: int = -1,
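Review note on the new signature: `grid_range` takes a list literal as its default. Python evaluates default values once, at function definition, so every call that omits the argument receives the same list object. Whether this matters here depends on whether any downstream code mutates the list, which the diff does not show. The sketch below only illustrates the general behaviour; `shared_default` and `fresh_default` are hypothetical names, not part of the PR.

```python
from typing import List, Optional


def shared_default(grid_range: List[int] = [-1, 1]) -> List[int]:
    # The list literal is evaluated once, when the function is defined,
    # so every call without an argument receives the same list object.
    return grid_range


def fresh_default(grid_range: Optional[List[int]] = None) -> List[int]:
    # A None sentinel builds a new list on each call instead.
    if grid_range is None:
        grid_range = [-1, 1]
    return grid_range


first = shared_default()
first[0] = -5  # in-place edit by one caller
leaked = shared_default()  # a later caller sees the edited bounds

safe_first = fresh_default()
safe_first[0] = -5
safe_second = fresh_default()  # unaffected by the edit above
```

The same reasoning applies to `base_fun: callable = torch.nn.SiLU()`, where one module instance is shared by all models built with the default. `SiLU` holds no parameters, so sharing it is likely harmless.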
@@ -47,48 +60,86 @@ def __init__(

Based on the article
`N-BEATS: Neural basis expansion analysis for interpretable time series
forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if used as ensemble) outperformed all
other methods
including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably
the most
important benchmark for univariate time series forecasting.
forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if
used as an ensemble) outperformed all other methods, including ensembles of
traditional statistical methods, in the M4 competition. The M4 competition is
arguably the most important benchmark for univariate time series forecasting.

The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform
N-BEATS.
The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
been shown to consistently outperform N-BEATS.

Args:
stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings
of length 1 or ‘num_stacks’. Default and recommended value
for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”]
num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’.
Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3]
num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length
1 or ‘num_stacks’.
Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4]
width: Widths of the fully connected layers with ReLu activation in the blocks.
A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512]
Recommended value for interpretable mode: [256, 2048]
stack_types: One of the following values: “generic”, “seasonality” or
“trend”. A list of strings of length 1 or ‘num_stacks’. Default and
recommended value for generic mode: [“generic”]. Recommended value for
interpretable mode: [“trend”, “seasonality”].
num_blocks: The number of blocks per stack. A list of ints of length 1 or
‘num_stacks’. Default and recommended value for generic mode: [1].
Recommended value for interpretable mode: [3].
num_block_layers: Number of fully connected layers with ReLU activation per
block. A list of ints of length 1 or ‘num_stacks’. Default and
recommended value for generic mode: [4]. Recommended value for
interpretable mode: [4].
width: Widths of the fully connected layers with ReLU activation in the
blocks. A list of ints of length 1 or ‘num_stacks’. Default and
recommended value for generic mode: [512]. Recommended value for
interpretable mode: [256, 2048].
sharing: Whether the weights are shared with the other blocks per stack.
A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False]
Recommended value for interpretable mode: [True]
expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion
coefficient.
If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S”
(seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep.
A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for
A list of bools of length 1 or ‘num_stacks’. Default and recommended
value for generic mode: [False]. Recommended value for interpretable
mode: [True].
expansion_coefficient_lengths: If the type is “G” (generic), then the length
of the expansion coefficient.
If the type is “T” (trend), then it corresponds to the degree of the
polynomial.
If the type is “S” (seasonal), then this is the minimum period allowed,
e.g. 2 for changes every timestep. A list of ints of length 1 or
‘num_stacks’. Default value for generic mode: [32]. Recommended value for
interpretable mode: [3].
prediction_length: Length of the prediction. Also known as 'horizon'.
context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
context_length: Number of time units that condition the predictions.
Also known as 'lookback period'.
Should be between 1-10 times the prediction length.
backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and
forecast lengths). Defaults to 0.0, i.e. no weight.
num_grids : Parameter for KAN layer. The number of grid intervals (G).
Default: 5.
k : Parameter for KAN layer. The order of the piecewise polynomial. Default: 3.
noise_scale : Parameter for KAN layer. The scale of noise injected at
initialization. Default: 0.5.
scale_base_mu : Parameter for KAN layer. The scale of the residual
function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
Default: 0.0.
scale_base_sigma : Parameter for KAN layer. The scale of the residual
function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
Default: 1.0.
scale_sp : Parameter for KAN layer. The scale of the base function
spline(x). Default: 1.0.
base_fun : Parameter for KAN layer. The residual function b(x).
Default: torch.nn.SiLU().
grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
when grid_eps = 0, the grid is partitioned using percentiles of samples.
0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
grid_range : Parameter for KAN layer. A list or np.array of shape (2,)
setting the range of the grid.
Default: [-1, 1].
sp_trainable : Parameter for KAN layer. If True, scale_sp is trainable.
Default: True.
sb_trainable : Parameter for KAN layer. If True, scale_base is trainable.
Default: True.
sparse_init : Parameter for KAN layer. If True, sparse
initialization is applied. Default: False.
backcast_loss_ratio: weight of backcast in comparison to forecast when
calculating the loss. A weight of 1.0 means that forecast and
backcast loss are weighted the same (regardless of backcast and forecast
lengths). Defaults to 0.0, i.e. no weight.
loss: loss to optimize. Defaults to MASE().
log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training
failures
reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
log_gradient_flow: whether to log gradient flow. This takes time and should
only be done to diagnose training failures.
reduce_on_plateau_patience (int): patience after which learning rate is
reduced by a factor of 10
logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
are logged during training. Defaults to
nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
**kwargs: additional arguments to :py:class:`~BaseModel`.
""" # noqa: E501
if expansion_coefficient_lengths is None:
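Review note on the `grid_eps` entry in the docstring above: the interpolation it describes can be pictured as a convex blend of a uniform grid and a grid placed at sample quantiles. The following is a minimal pure-Python sketch of that idea, not the KAN layer code used in this PR. `blended_grid` is a hypothetical name, and a real layer would presumably work on tensors and may extend the grid beyond the data range.

```python
from typing import List, Sequence


def blended_grid(samples: Sequence[float], num_grids: int, grid_eps: float) -> List[float]:
    """Blend a uniform grid with a sample-quantile grid (illustrative only)."""
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    grid = []
    for i in range(num_grids + 1):
        # Uniform knot: evenly spaced between the smallest and largest sample.
        uniform = lo + (hi - lo) * i / num_grids
        # Adaptive knot: the sample found at an evenly spaced rank.
        adaptive = xs[(i * (n - 1)) // num_grids]
        # grid_eps = 1 keeps only the uniform knot, grid_eps = 0 only the adaptive one.
        grid.append(grid_eps * uniform + (1 - grid_eps) * adaptive)
    return grid


samples = [0.0, 0.1, 0.2, 0.3, 10.0]  # most of the mass sits near zero
uniform_grid = blended_grid(samples, num_grids=2, grid_eps=1.0)   # [0.0, 5.0, 10.0]
adaptive_grid = blended_grid(samples, num_grids=2, grid_eps=0.0)  # [0.0, 0.2, 10.0]
default_grid = blended_grid(samples, num_grids=2, grid_eps=0.02)
```

With the default `grid_eps = 0.02` the middle knot stays close to the quantile position, so the grid follows the data distribution almost entirely.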
@@ -107,7 +158,24 @@ def __init__(
logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
if loss is None:
loss = MASE()
self.save_hyperparameters()
# Bundle KAN parameters into a dictionary
self.kan_params = {
"use_kan": use_kan,
"num_grids": num_grids,
"k": k,
"noise_scale": noise_scale,
"scale_base_mu": scale_base_mu,
"scale_base_sigma": scale_base_sigma,
"scale_sp": scale_sp,
"base_fun": base_fun,
"grid_eps": grid_eps,
"grid_range": grid_range,
"sp_trainable": sp_trainable,
"sb_trainable": sb_trainable,
"sparse_init": sparse_init,
}

self.save_hyperparameters(ignore=["loss", "logging_metrics"])
super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

# setup stacks
@@ -122,6 +190,7 @@ def __init__(
backcast_length=context_length,
forecast_length=prediction_length,
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
elif stack_type == "seasonality":
net_block = NBEATSSeasonalBlock(
@@ -131,6 +200,7 @@ def __init__(
forecast_length=prediction_length,
min_period=self.hparams.expansion_coefficient_lengths[stack_id],
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
elif stack_type == "trend":
net_block = NBEATSTrendBlock(
@@ -140,6 +210,7 @@ def __init__(
backcast_length=context_length,
forecast_length=prediction_length,
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
else:
raise ValueError(f"Unknown stack type {stack_type}")
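Review note on the stack loop above: every block receives the same `self.kan_params` dictionary object. The diff does not show how the blocks read it, so the following is only a sketch under that assumption. `split_kan_params` is a hypothetical helper, not part of the PR. It shows one way a consumer could separate the `use_kan` flag from the layer settings without mutating the shared dictionary.

```python
from typing import Any, Dict, Tuple


def split_kan_params(kan_params: Dict[str, Any]) -> Tuple[bool, Dict[str, Any]]:
    """Hypothetical helper: separate the on/off flag from the layer settings."""
    layer_kwargs = dict(kan_params)  # shallow copy; the shared dict stays intact
    use_kan = layer_kwargs.pop("use_kan", False)
    return use_kan, layer_kwargs


# A shortened stand-in for the dictionary bundled in __init__.
shared = {"use_kan": True, "num_grids": 5, "k": 3}

# Two blocks reading the same dictionary, as in the stack loop above.
flag_a, kwargs_a = split_kan_params(shared)
flag_b, kwargs_b = split_kan_params(shared)
```

Popping from the shared dictionary directly would remove `use_kan` for every block constructed afterwards, which is why the sketch copies first.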
@@ -223,7 +294,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
@classmethod
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
"""
Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
Convenience function to create network from
:py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.

Args:
dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
@@ -359,10 +431,11 @@ def plot_interpretation(
x (Dict[str, torch.Tensor]): network input
output (Dict[str, torch.Tensor]): network output
idx (int): index of sample for which to plot the interpretation.
ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation.
Defaults to None.
plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and
generic forecast on secondary axis in second panel. Defaults to False.
ax (List[matplotlib axes], optional): list of two matplotlib axes onto which
to plot the interpretation. Defaults to None.
plot_seasonality_and_generic_on_secondary_axis (bool, optional): whether to
plot seasonality and generic forecast on secondary axis in second panel.
Defaults to False.

Returns:
plt.Figure: matplotlib figure