diff --git a/docs/source/models.rst b/docs/source/models.rst index cd9048b0a..f53ae2670 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -27,6 +27,7 @@ and you should take into account. Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 + :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py new file mode 100644 index 000000000..6a018ce5d --- /dev/null +++ b/examples/nbeats_with_kan.py @@ -0,0 +1,105 @@ +import sys + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +import pandas as pd + +from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet +from pytorch_forecasting.data import NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data +from pytorch_forecasting.models.nbeats import GridUpdateCallback + +sys.path.append("..") + + +print("load data") +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") +validation = data.series.sample(20) + + +max_encoder_length = 150 +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +context_length = max_encoder_length +prediction_length = max_prediction_length + +training = TimeSeriesDataSet( + data[lambda x: x.time_idx < training_cutoff], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + min_encoder_length=context_length, + max_encoder_length=context_length, + max_prediction_length=prediction_length, + min_prediction_length=prediction_length, + time_varying_unknown_reals=["value"], + randomize_length=None, + add_relative_time_idx=False, + add_target_scales=False, +) + +validation = TimeSeriesDataSet.from_dataset( + training, data, min_prediction_idx=training_cutoff +) +batch_size = 128 +train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 +) +val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 +) + + +early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" +) +# updates KAN layers' grid after every 3 steps during training +grid_update_callback = GridUpdateCallback(update_interval=3) + +trainer = pl.Trainer( + max_epochs=1, + accelerator="auto", + gradient_clip_val=0.1, + callbacks=[early_stop_callback, grid_update_callback], + limit_train_batches=15, + # limit_val_batches=1, + # fast_dev_run=True, + # logger=logger, + # profiler=True, +) + + +net = NBeatsKAN.from_dataset( + training, + learning_rate=3e-2, + log_interval=10, + log_val_interval=1, + log_gradient_flow=False, + weight_decay=1e-2, +) +print(f"Number of parameters in network: 
{net.size() / 1e3:.1f}k")

+# # find optimal learning rate
+# # remove logging and artificial epoch size
+# net.hparams.log_interval = -1
+# net.hparams.log_val_interval = -1
+# trainer.limit_train_batches = 1.0
+# # run learning rate finder
+# res = Tuner(trainer).lr_find(
+#     net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501
+# )
+# print(f"suggested learning rate: {res.suggestion()}")
+# fig = res.plot(show=True, suggest=True)
+# fig.show()
+# net.hparams.learning_rate = res.suggestion()
+
+trainer.fit(
+    net,
+    train_dataloaders=train_dataloader,
+    val_dataloaders=val_dataloader,
+)
diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py
index 8cbb8e9d1..8c8da17d2 100644
--- a/pytorch_forecasting/__init__.py
+++ b/pytorch_forecasting/__init__.py
@@ -43,6 +43,7 @@
     DeepAR,
     MultiEmbedding,
     NBeats,
+    NBeatsKAN,
     NHiTS,
     RecurrentNetwork,
     TemporalFusionTransformer,
@@ -73,6 +74,7 @@
     "TemporalFusionTransformer",
     "TiDEModel",
     "NBeats",
+    "NBeatsKAN",
     "NHiTS",
     "Baseline",
     "DeepAR",
diff --git a/pytorch_forecasting/layers/_kan/__init__.py b/pytorch_forecasting/layers/_kan/__init__.py
new file mode 100644
index 000000000..55e296e87
--- /dev/null
+++ b/pytorch_forecasting/layers/_kan/__init__.py
@@ -0,0 +1,7 @@
+"""
+KAN (Kolmogorov Arnold Network) layer implementation.
+"""
+
+from pytorch_forecasting.layers._kan._kan_layer import KANLayer
+
+__all__ = ["KANLayer"]
diff --git a/pytorch_forecasting/layers/_kan/_kan_layer.py b/pytorch_forecasting/layers/_kan/_kan_layer.py
new file mode 100644
index 000000000..217d92e92
--- /dev/null
+++ b/pytorch_forecasting/layers/_kan/_kan_layer.py
@@ -0,0 +1,237 @@
+# The following implementation of KANLayer is inspired by the pykan library.
+# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from pytorch_forecasting.layers._kan._utils import (
+    coef2curve,
+    curve2coef,
+    extend_grid,
+    sparse_mask,
+)
+
+
+class KANLayer(nn.Module):
+    """
+    Initialize a KANLayer.
+
+    Parameters
+    ----------
+    in_dim : int
+        input dimension. Default: 3.
+    out_dim : int
+        output dimension. Default: 2.
+    num : int
+        the number of grid intervals = G. Default: 5.
+    k : int
+        the order of piecewise polynomial. Default: 3.
+    noise_scale : float
+        the scale of noise injected at initialization. Default: 0.5.
+    scale_base_mu : float
+        the scale of the residual function b(x) is initialized to be
+        N(scale_base_mu, scale_base_sigma^2).
+    scale_base_sigma : float
+        the scale of the residual function b(x) is initialized to be
+        N(scale_base_mu, scale_base_sigma^2).
+    scale_sp : float
+        the scale of the spline function spline(x).
+    base_fun : function
+        residual function b(x). Default: None, which resolves to
+        ``torch.nn.SiLU()``.
+    grid_eps : float
+        When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is
+        partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates
+        between the two extremes.
+    grid_range : list or np.array of shape (2,)
+        setting the range of grids. Default: None, which resolves to [-1, 1].
+    sp_trainable : bool
+        If true, scale_sp is trainable.
+    sb_trainable : bool
+        If true, scale_base is trainable.
+    sparse_init : bool
+        if sparse_init = True, sparse initialization is applied.
+
+    Returns
+    -------
+    self : reference to self
+
+    Examples
+    --------
+    The examples below illustrate the layer: first a self-contained sketch
+    using this module directly, then the original example from the `pykan`
+    library that this implementation is adapted from. 
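+
+    A shape-level sketch (illustrative only; import path per
+    ``layers/_kan/__init__.py`` above):
+
+    >>> import torch
+    >>> from pytorch_forecasting.layers._kan import KANLayer
+    >>> layer = KANLayer(in_dim=3, out_dim=5)
+    >>> layer(torch.rand(8, 3)).shape
+    torch.Size([8, 5])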
+ + Install the `pykan` package first: + pip install pykan + Then use: + + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> (model.in_dim, model.out_dim) + """ + + def __init__( + self, + in_dim=3, + out_dim=2, + num=5, + k=3, + noise_scale=0.5, + scale_base_mu=0.0, + scale_base_sigma=1.0, + scale_sp=1.0, + base_fun=None, + grid_eps=0.02, + grid_range=None, + sp_trainable=True, + sb_trainable=True, + sparse_init=False, + ): + super().__init__() + + # Handle mutable parameters + if grid_range is None: + grid_range = [-1, 1] + if base_fun is None: + base_fun = torch.nn.SiLU() + # size + self.out_dim = out_dim + self.in_dim = in_dim + self.num = num + self.k = k + + grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[ + None, : + ].expand(self.in_dim, num + 1) + grid = extend_grid(grid, k_extend=k) + self.grid = torch.nn.Parameter(grid).requires_grad_(False) + noises = ( + (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) + * noise_scale + / num + ) + + self.coef = torch.nn.Parameter( + curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k) + ) + + if sparse_init: + self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_( + False + ) + else: + self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_( + False + ) + + self.scale_base = torch.nn.Parameter( + scale_base_mu * 1 / np.sqrt(in_dim) + + scale_base_sigma + * (torch.rand(in_dim, out_dim) * 2 - 1) + * 1 + / np.sqrt(in_dim) + ).requires_grad_(sb_trainable) + self.scale_sp = torch.nn.Parameter( + torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask + ).requires_grad_(sp_trainable) # make scale trainable + self.base_fun = base_fun + + self.grid_eps = grid_eps + + def forward(self, x): + """ + KANLayer forward given input x + + Parameters + ----- + x : torch.Tensor + Input tensor of shape (batch_size, in_dim), where: + - batch_size is the number of input samples. + - in_dim is the input feature dimension. + + Returns + -------- + y : torch.Tensor + Output tensor, the result of applying spline and residual + transformations followed by weighted summation. + + Examples + -------- + The following is an example from the original `pykan` library, adapted here + for illustration within the PyTorch Forecasting integration. 
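+
+        Note that this implementation returns only ``y`` (a single tensor),
+        not the 4-tuple returned by ``pykan``. A self-contained sketch against
+        this module (illustrative only):
+
+        >>> import torch
+        >>> from pytorch_forecasting.layers._kan import KANLayer
+        >>> model = KANLayer(in_dim=3, out_dim=5)
+        >>> y = model(torch.normal(0, 1, size=(100, 3)))
+        >>> y.shape
+        torch.Size([100, 5])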
+
+        Install the `pykan` package first:
+            pip install pykan
+        Then use:
+
+        >>> from kan.KANLayer import *
+        >>> model = KANLayer(in_dim=3, out_dim=5)
+        >>> x = torch.normal(0,1,size=(100,3))
+        >>> y, _, _, _ = model(x)
+        >>> y.shape
+        torch.Size([100, 5])
+        """
+
+        base = self.base_fun(x)  # (batch, in_dim)
+        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
+        y = (
+            self.scale_base[None, :, :] * base[:, :, None]
+            + self.scale_sp[None, :, :] * y
+        )
+        y = self.mask[None, :, :] * y
+        y = torch.sum(y, dim=1)
+        return y
+
+    def update_grid_from_samples(self, x):
+        """
+        Update grid from samples.
+
+        Parameters
+        ----------
+        x : 2D torch.float
+            inputs, shape (number of samples, input dimension)
+
+        Returns
+        -------
+        None
+
+        Examples
+        --------
+        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
+        >>> print(model.grid.data)
+        >>> x = torch.linspace(-3,3,steps=100)[:,None]
+        >>> model.update_grid_from_samples(x)
+        >>> print(model.grid.data)
+        """
+
+        batch = x.shape[0]
+        x_pos = torch.sort(x, dim=0)[0]
+        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
+        num_interval = self.grid.shape[1] - 1 - 2 * self.k
+
+        def get_grid(num_interval):
+            """
+            Generate adaptive or uniform grid points from sorted input samples.
+
+            Parameters
+            ----------
+            num_interval : int
+                Number of intervals between grid points.
+
+            Returns
+            -------
+            grid : torch.Tensor
+                New grid of shape (in_dim, num_interval + 1).
+            """
+            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
+            grid_adaptive = x_pos[ids, :].permute(1, 0)
+            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
+            grid_uniform = (
+                grid_adaptive[:, [0]]
+                + h * torch.arange(num_interval + 1, device=h.device)[None, :]
+            )
+            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
+            return grid
+
+        grid = get_grid(num_interval)
+        self.grid.data = extend_grid(grid, k_extend=self.k)
+        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
diff --git a/pytorch_forecasting/layers/_kan/_utils.py b/pytorch_forecasting/layers/_kan/_utils.py
new file mode 100644
index 000000000..fbbcb65d4
--- /dev/null
+++ b/pytorch_forecasting/layers/_kan/_utils.py
@@ -0,0 +1,186 @@
+"""
+Utility functions for KAN (Kolmogorov Arnold Network) Layer.
+Contains B-spline computations, curve transformations, and grid manipulation functions.
+"""
+
+import torch
+
+
+def b_batch(x, grid, k=0):
+    """
+    Evaluate x on B-spline bases.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        2D tensor of inputs, shape (batch, in_dim).
+    grid : torch.Tensor
+        2D tensor of grids, shape (in_dim, number of grid points).
+    k : int
+        The piecewise polynomial order of splines.
+
+    Returns
+    -------
+    spline values : torch.Tensor
+        3D tensor of shape (batch, in_dim, G+k), where G is the number of
+        grid intervals and k is the spline order.
+
+    Examples
+    --------
+    The following examples, adapted from the original `pykan` library,
+    illustrate the B-spline bases using this module directly. 
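+
+    As a sketch of the recursion's base case: with ``k=0`` the bases reduce
+    to indicator functions on the ``G`` grid intervals.
+
+    >>> import torch
+    >>> from pytorch_forecasting.layers._kan._utils import b_batch
+    >>> x = torch.rand(100, 2)
+    >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
+    >>> b_batch(x, grid, k=0).shape
+    torch.Size([100, 2, 10])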
+
+    >>> from pytorch_forecasting.layers._kan._utils import b_batch
+    >>> import torch
+    >>> x = torch.rand(100, 2)
+    >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
+    >>> b_batch(x, grid, k=3).shape
+    torch.Size([100, 2, 7])
+    """
+
+    x = x.unsqueeze(dim=2)
+    grid = grid.unsqueeze(dim=0)
+
+    if k == 0:
+        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
+    else:
+        B_km1 = b_batch(x[:, :, 0], grid=grid[0], k=k - 1)
+
+        value = (x - grid[:, :, : -(k + 1)]) / (
+            grid[:, :, k:-1] - grid[:, :, : -(k + 1)]
+        ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / (
+            grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]
+        ) * B_km1[:, :, 1:]
+
+    # in case grid is degenerate
+    value = torch.nan_to_num(value)
+    return value
+
+
+def coef2curve(x_eval, grid, coef, k):
+    """
+    Converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves
+    (summing up b_batch results over B-spline basis).
+
+    Parameters
+    ----------
+    x_eval : torch.Tensor
+        2D tensor of shape (batch, in_dim).
+    grid : torch.Tensor
+        2D tensor of shape (in_dim, G + 2k + 1). G: the number of grid
+        intervals; k: spline order.
+    coef : torch.Tensor
+        3D tensor of shape (in_dim, out_dim, G+k).
+    k : int
+        The piecewise polynomial order of splines.
+
+    Returns
+    -------
+    y_eval : torch.Tensor
+        3D tensor of shape (batch, in_dim, out_dim).
+    """
+
+    b_splines = b_batch(x_eval, grid, k=k)
+    y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines))
+
+    return y_eval
+
+
+def curve2coef(x_eval, y_eval, grid, k):
+    """
+    Estimate spline coefficients via batched least squares.
+
+    Parameters
+    ----------
+    x_eval : torch.Tensor
+        2D tensor of shape (batch, in_dim).
+    y_eval : torch.Tensor
+        3D tensor of shape (batch, in_dim, out_dim).
+    grid : torch.Tensor
+        2D tensor of shape (in_dim, G + 2k + 1).
+    k : int
+        Spline order.
+
+    Returns
+    -------
+    coef : torch.Tensor
+        3D tensor of shape (in_dim, out_dim, G + k).
+    """
+    batch = x_eval.shape[0]
+    in_dim = x_eval.shape[1]
+    out_dim = y_eval.shape[2]
+    n_coef = grid.shape[1] - k - 1
+    mat = b_batch(x_eval, grid, k)
+    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
+    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
+    try:
+        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
+    except Exception as e:
+        print(f"lstsq failed with error: {e}")
+        raise  # re-raise: coef would otherwise be undefined below
+
+    return coef
+
+
+def extend_grid(grid, k_extend=0):
+    """
+    Extend a grid tensor by padding both ends with equal spacing.
+
+    Parameters
+    ----------
+    grid : torch.Tensor
+        Grid of shape (in_dim, grid_points).
+    k_extend : int
+        Number of points to extend on both ends.
+
+    Returns
+    -------
+    grid : torch.Tensor
+        Extended grid of shape (in_dim, grid_points + 2 * k_extend).
+    """
+    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
+
+    for i in range(k_extend):
+        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
+        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
+
+    return grid
+
+
+def sparse_mask(in_dim, out_dim):
+    """
+    Generate a sparse connection mask between input and output units.
+
+    Parameters
+    ----------
+    in_dim : int
+        Number of input units.
+    out_dim : int
+        Number of output units.
+
+    Returns
+    -------
+    mask : torch.Tensor
+        Sparse binary mask of shape (in_dim, out_dim). 
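+
+    Examples
+    --------
+    A small sketch; each input and each output unit keeps at least one
+    connection by construction:
+
+    >>> mask = sparse_mask(4, 2)
+    >>> mask.shape
+    torch.Size([4, 2])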
+ """ + in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim) + out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim) + + dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :]) + in_nearest = torch.argmin(dist_mat, dim=0) + in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0) + out_nearest = torch.argmin(dist_mat, dim=1) + out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0) + all_connection = torch.cat([in_connection, out_connection], dim=0) + mask = torch.zeros(in_dim, out_dim) + mask[all_connection[:, 0], all_connection[:, 1]] = 1.0 + + return mask diff --git a/pytorch_forecasting/layers/_nbeats/__init__.py b/pytorch_forecasting/layers/_nbeats/__init__.py new file mode 100644 index 000000000..daf47de2d --- /dev/null +++ b/pytorch_forecasting/layers/_nbeats/__init__.py @@ -0,0 +1,17 @@ +""" +Implementation of N-BEATS model blocks and utilities. +""" + +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSBlock, + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) + +__all__ = [ + "NBEATSBlock", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", +] diff --git a/pytorch_forecasting/layers/_nbeats/_blocks.py b/pytorch_forecasting/layers/_nbeats/_blocks.py new file mode 100644 index 000000000..80e7657da --- /dev/null +++ b/pytorch_forecasting/layers/_nbeats/_blocks.py @@ -0,0 +1,438 @@ +""" +Implementation of ``nn.Modules`` for N-Beats model. +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_forecasting.layers._kan._kan_layer import KANLayer +from pytorch_forecasting.layers._nbeats._utils import linear, linspace + + +class NBEATSBlock(nn.Module): + """ + Initialize an N-BEATS block using either MLP or KAN layers. + + Parameters + ---------- + units : int + Number of units in each layer. + thetas_dim : int + Output dimension of the theta layers. + num_block_layers : int + Number of hidden layers in the block. Default is 4. + backcast_length : int + Length of the input (past) sequence. Default is 10. + forecast_length : int + Length of the output (future) sequence. Default is 5. + dropout : float + Dropout rate for regularization. Default is 0.1. + kan_params : dict + Dictionary of parameters for KAN layers. Only required if `use_kan=True`. + Default values will be used if not provided. Includes: + - num : int, default=5 + Number of grid intervals. + - k : int, default=3 + Order of piecewise polynomial. + - noise_scale : float, default=0.5 + Initialization noise scale. + - scale_base_mu : float, default=0.0 + Mean for residual function initialization. + - scale_base_sigma : float, default=1.0 + Std deviation for residual function initialization. + - scale_sp : float, default=1.0 + Scale for the spline function. + - base_fun : nn.Module, default=torch.nn.SiLU() + Base function module. + - grid_eps : float, default=0.02 + Determines grid spacing (0 for quantile, 1 for uniform). + - grid_range : list of float, default=[-1, 1] + Range of the spline grid. + - sp_trainable : bool, default=True + Whether scale_sp is trainable. + - sb_trainable : bool, default=True + Whether scale_base is trainable. + - sparse_init : bool, default=False + Whether to apply sparse initialization. + use_kan : bool + If True, uses KAN layers instead of MLP. Default is False. 
+ """ + + def __init__( + self, + units, + thetas_dim, + num_block_layers=4, + backcast_length=10, + forecast_length=5, + dropout=0.1, + kan_params=None, + use_kan=False, + ): + super().__init__() + + if use_kan and kan_params is None: + # Define default parameters for KAN if not provided + kan_params = dict( + num=5, + k=3, + noise_scale=0.5, + scale_base_mu=0.0, + scale_base_sigma=1.0, + scale_sp=1.0, + base_fun=torch.nn.SiLU(), + grid_eps=0.02, + grid_range=[-1, 1], + sp_trainable=True, + sb_trainable=True, + sparse_init=False, + ) + + self.units = units + self.thetas_dim = thetas_dim + self.backcast_length = backcast_length + self.forecast_length = forecast_length + self.kan_params = kan_params + self.use_kan = use_kan + + if self.use_kan: + layers = [ + KANLayer( + in_dim=backcast_length, + out_dim=units, + **self.kan_params, + ) + ] + + # Add additional layers for deeper structure + for _ in range(num_block_layers - 1): + layers.append( + KANLayer( + in_dim=units, + out_dim=units, + **self.kan_params, + ) + ) + + # Define the fully connected layers + self.fc = nn.Sequential(*layers) + + # Define the theta layers + self.theta_f_fc = self.theta_b_fc = KANLayer( + in_dim=units, + out_dim=thetas_dim, + **self.kan_params, + ) + + else: + fc_stack = [ + nn.Linear(backcast_length, units), + nn.ReLU(), + ] + for _ in range(num_block_layers - 1): + fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) + self.fc = nn.Sequential(*fc_stack) + self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) + + def forward(self, x): + """ + Forward pass through the block using either MLP or KAN layers. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor after processing through the block. + """ + if self.use_kan: + # save outputs to be used in updating grid in kan layers during training + # outputs logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + self.outputs = [] + self.outputs.append(x.clone().detach()) + for layer in self.fc: + x = layer(x) # Pass data through the current layer + # storing outputs for updating grids of self.fc when using KAN + self.outputs.append(x.clone().detach()) + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN + self.outputs.append(x.clone().detach()) + return x # Return final output + return self.fc(x) + + +class NBEATSSeasonalBlock(NBEATSBlock): + """ + Initialize a Seasonal N-BEATS block with Fourier-based seasonality modeling. + + Parameters + ---------- + units : int + Number of units in each hidden layer. + thetas_dim : int + Output dimension of theta layers. Inferred from harmonics if not provided. + num_block_layers : int + Number of layers in the block. Default is 4. + backcast_length : int + Length of the input (past) sequence. Default is 10. + forecast_length : int + Length of the output (future) sequence. Default is 5. + nb_harmonics : int + Number of harmonics for Fourier features. Default is None. + min_period : int + Minimum period for seasonality. Default is 1. + dropout : float + Dropout rate. Default is 0.1. + kan_params : dict + Dictionary of KAN layer parameters. See NBEATSBlock for details. + use_kan : bool + If True, uses KAN instead of MLP. Default is False. 
+ """ + + def __init__( + self, + units, + thetas_dim=None, + num_block_layers=4, + backcast_length=10, + forecast_length=5, + nb_harmonics=None, + min_period=1, + dropout=0.1, + kan_params=None, + use_kan=False, + ): + if nb_harmonics: + thetas_dim = nb_harmonics + else: + thetas_dim = forecast_length + self.min_period = min_period + + super().__init__( + units=units, + thetas_dim=thetas_dim, + num_block_layers=num_block_layers, + backcast_length=backcast_length, + forecast_length=forecast_length, + dropout=dropout, + kan_params=kan_params, + use_kan=use_kan, + ) + + backcast_linspace, forecast_linspace = linspace( + backcast_length, forecast_length, centered=False + ) + + p1, p2 = ( + (thetas_dim // 2, thetas_dim // 2) + if thetas_dim % 2 == 0 + else (thetas_dim // 2, thetas_dim // 2 + 1) + ) + s1_b = torch.tensor( + np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * backcast_linspace), + dtype=torch.float32, + ) # H/2-1 + s2_b = torch.tensor( + np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * backcast_linspace), + dtype=torch.float32, + ) + self.register_buffer("S_backcast", torch.cat([s1_b, s2_b])) + + s1_f = torch.tensor( + np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * forecast_linspace), + dtype=torch.float32, + ) # H/2-1 + s2_f = torch.tensor( + np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * forecast_linspace), + dtype=torch.float32, + ) + self.register_buffer("S_forecast", torch.cat([s1_f, s2_f])) + + def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute seasonal backcast and forecast outputs using input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, backcast_length). + + Returns + ------- + tuple of torch.Tensor + Tuple (backcast, forecast), each of shape (batch_size, time_steps). + """ + x = super().forward(x) + amplitudes_backward = self.theta_b_fc(x) + backcast = amplitudes_backward.mm(self.S_backcast) + amplitudes_forward = self.theta_f_fc(x) + forecast = amplitudes_forward.mm(self.S_forecast) + + return backcast, forecast + + def get_frequencies(self, n): + """ + Generates frequency values based on the backcast and forecast lengths. + """ + return np.linspace( + 0, (self.backcast_length + self.forecast_length) / self.min_period, n + ) + + +class NBEATSTrendBlock(NBEATSBlock): + """ + Initialize a Trend N-BEATS block using polynomial basis functions. + + Parameters + ---------- + units : int + Number of units in each hidden layer. + thetas_dim : int + Output dimension of theta layers (number of polynomial terms). + num_block_layers : int + Number of hidden layers. Default is 4. + backcast_length : int + Length of input sequence. Default is 10. + forecast_length : int + Length of output sequence. Default is 5. + dropout : float + Dropout rate. Default is 0.1. + kan_params : dict + KAN layer parameters. See NBEATSBlock for details. + use_kan : bool + If True, uses KAN instead of MLP. Default is False. 
+ """ + + def __init__( + self, + units, + thetas_dim, + num_block_layers=4, + backcast_length=10, + forecast_length=5, + dropout=0.1, + kan_params=None, + use_kan=False, + ): + super().__init__( + units=units, + thetas_dim=thetas_dim, + num_block_layers=num_block_layers, + backcast_length=backcast_length, + forecast_length=forecast_length, + dropout=dropout, + kan_params=kan_params, + use_kan=use_kan, + ) + + backcast_linspace, forecast_linspace = linspace( + backcast_length, forecast_length, centered=True + ) + norm = np.sqrt( + forecast_length / thetas_dim + ) # ensure range of predictions is comparable to input + thetas_dims_range = np.array(range(thetas_dim)) + coefficients = torch.tensor( + backcast_linspace ** thetas_dims_range[:, None], + dtype=torch.float32, + ) + self.register_buffer("T_backcast", coefficients * norm) + coefficients = torch.tensor( + forecast_linspace ** thetas_dims_range[:, None], + dtype=torch.float32, + ) + self.register_buffer("T_forecast", coefficients * norm) + + def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute backcast and forecast outputs using input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, backcast_length). + + Returns + ------- + tuple of torch.Tensor + Tuple (backcast, forecast). + """ + + x = super().forward(x) + backcast = self.theta_b_fc(x).mm(self.T_backcast) + forecast = self.theta_f_fc(x).mm(self.T_forecast) + return backcast, forecast + + +class NBEATSGenericBlock(NBEATSBlock): + """ + Initialize a Generic N-BEATS block using linear mapping of theta outputs. + + Parameters + ---------- + units : int + Number of units in each hidden layer. + thetas_dim : int + Dimension of the theta parameter. + num_block_layers : int + Number of hidden layers. Default is 4. + backcast_length : int + Length of past input. Default is 10. + forecast_length : int + Length of future prediction. Default is 5. + dropout : float + Dropout rate. Default is 0.1. + kan_params : dict + KAN layer parameters. See NBEATSBlock for details. + use_kan : bool + If True, uses KAN instead of MLP. Default is False. + """ + + def __init__( + self, + units, + thetas_dim, + num_block_layers=4, + backcast_length=10, + forecast_length=5, + dropout=0.1, + kan_params=None, + use_kan=False, + ): + super().__init__( + units=units, + thetas_dim=thetas_dim, + num_block_layers=num_block_layers, + backcast_length=backcast_length, + forecast_length=forecast_length, + dropout=dropout, + kan_params=kan_params, + use_kan=use_kan, + ) + + self.backcast_fc = nn.Linear(thetas_dim, backcast_length) + self.forecast_fc = nn.Linear(thetas_dim, forecast_length) + + def forward(self, x): + """ + Compute backcast and forecast using using input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, backcast_length). + + Returns + ------- + tuple of torch.Tensor + Tuple (backcast, forecast). + """ + x = super().forward(x) + theta_b = F.relu(self.theta_b_fc(x)) + theta_f = F.relu(self.theta_f_fc(x)) + return self.backcast_fc(theta_b), self.forecast_fc(theta_f) diff --git a/pytorch_forecasting/layers/_nbeats/_utils.py b/pytorch_forecasting/layers/_nbeats/_utils.py new file mode 100644 index 000000000..0b884d4e1 --- /dev/null +++ b/pytorch_forecasting/layers/_nbeats/_utils.py @@ -0,0 +1,39 @@ +""" +Utility functions for N-BEATS model implementation. 
+""" + +import numpy as np +import torch.nn as nn + + +def linear(input_size, output_size, bias=True, dropout: int = None): + """ + Initialize linear layers for MLP block layers. + """ + lin = nn.Linear(input_size, output_size, bias=bias) + if dropout is not None: + return nn.Sequential(nn.Dropout(dropout), lin) + else: + return lin + + +def linspace( + backcast_length: int, forecast_length: int, centered: bool = False +) -> tuple[np.ndarray, np.ndarray]: + """ + Generate linear spaced values for backcast and forecast. + """ + if centered: + norm = max(backcast_length, forecast_length) + start = -backcast_length + stop = forecast_length - 1 + else: + norm = backcast_length + forecast_length + start = 0 + stop = backcast_length + forecast_length - 1 + lin_space = np.linspace( + start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32 + ) + b_ls = lin_space[:backcast_length] + f_ls = lin_space[backcast_length:] + return b_ls, f_ls diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 07335a08f..dc635b261 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -11,7 +11,7 @@ from pytorch_forecasting.models.baseline import Baseline from pytorch_forecasting.models.deepar import DeepAR from pytorch_forecasting.models.mlp import DecoderMLP -from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.nbeats import NBeats, NBeatsKAN from pytorch_forecasting.models.nhits import NHiTS from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn from pytorch_forecasting.models.rnn import RecurrentNetwork @@ -24,6 +24,7 @@ __all__ = [ "NBeats", + "NBeatsKAN", "NHiTS", "TemporalFusionTransformer", "RecurrentNetwork", diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index f4ee55230..5377ff67f 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,17 +1,30 @@ -"""N-Beats model for timeseries forecasting without covariates.""" +""" +N-Beats model for timeseries forecasting without covariates. -from pytorch_forecasting.models.nbeats._nbeats import NBeats -from pytorch_forecasting.models.nbeats._nbeats_pkg import NBeats_pkg -from pytorch_forecasting.models.nbeats.sub_modules import ( +# TODO v2: remove compatibility imports, kept to avoid breaking existing code. 
+""" + +# Import blocks from new location for backward compatibility +from pytorch_forecasting.layers._nbeats._blocks import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) +from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback +from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeats_pkg import NBeats_pkg +from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN +from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg __all__ = [ "NBeats", - "NBEATSGenericBlock", + "NBeatsKAN", "NBeats_pkg", + "NBeatsKAN_pkg", + "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock", + "NBeatsAdapter", + "GridUpdateCallback", ] diff --git a/pytorch_forecasting/models/nbeats/_grid_callback.py b/pytorch_forecasting/models/nbeats/_grid_callback.py new file mode 100644 index 000000000..dabba0fb2 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_grid_callback.py @@ -0,0 +1,46 @@ +from lightning.pytorch.callbacks import Callback + + +class GridUpdateCallback(Callback): + """ + Custom callback to update the grid of the model during training at regular + intervals. + + Parameters + ---------- + update_interval : int + The frequency at which the grid is updated. + + Examples + -------- + See the full example in: + `examples/nbeats_with_kan.py` + """ + + def __init__(self, update_interval): + self.update_interval = update_interval + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """ + Hook called at the end of each training batch. + + Updates the grid of KAN layers if the current step is a multiple of the update + interval. + + Parameters + ---------- + trainer : Trainer + The PyTorch Lightning Trainer object. + pl_module : LightningModule + The model being trained (LightningModule). + outputs : Any + Outputs from the model for the current batch. + batch : Any + The current batch of data. + batch_idx : int + Index of the current batch. + """ + # Check if the current step is a multiple of the update interval + if (trainer.global_step + 1) % self.update_interval == 0: + # Call the model's update_kan_grid method + pl_module.update_kan_grid() diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index e87fa649c..5a160e2cc 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -4,23 +4,89 @@ from typing import Optional -import torch from torch import nn -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base import BaseModel -from pytorch_forecasting.models.nbeats.sub_modules import ( +from pytorch_forecasting.layers._nbeats._blocks import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -from pytorch_forecasting.utils._dependencies import _check_matplotlib - - -class NBeats(BaseModel): - """N-Beats model for timeseries forecasting without covariates.""" +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter + + +class NBeats(NBeatsAdapter): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. 
+
+    Based on the article
+    `N-BEATS: Neural basis expansion analysis for interpretable time series
+    forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if
+    used as ensemble) outperformed all other methods including ensembles of
+    traditional statistical methods in the M4 competition. The M4 competition is
+    arguably the most important benchmark for univariate time series forecasting.
+
+    The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
+    been shown to consistently outperform N-BEATS.
+
+    Parameters
+    ----------
+    stack_types : list of str
+        One of the following values: “generic”, “seasonality” or “trend”.
+        A list of strings of length 1 or `num_stacks`. Default and recommended
+        value for generic mode is ["generic"]. Recommended value for interpretable
+        mode is ["trend","seasonality"].
+    num_blocks : list of int
+        The number of blocks per stack. Length 1 or `num_stacks`. Default for
+        generic mode is [1], interpretable mode is [3].
+    num_block_layers : list of int
+        Number of fully connected layers with ReLU activation per block. Length 1
+        or `num_stacks`. Default [4] for both modes.
+    widths : list of int
+        Widths of fully connected layers with ReLU activation. List length 1 or
+        `num_stacks`. Default [512] for generic; [256, 2048] for interpretable.
+    sharing : list of bool
+        Whether weights are shared across blocks in a stack. List length 1 or
+        `num_stacks`. Default [False] for generic; [True] for interpretable.
+    expansion_coefficient_lengths : list of int
+        If type is "G", length of expansion coefficient; if "T", degree of
+        polynomial; if "S", minimum period (e.g., 2 for every timestep). List
+        length 1 or `num_stacks`. Default [32] for generic; [3] for interpretable.
+    prediction_length : int
+        Length of the forecast horizon.
+    context_length : int
+        Number of time units conditioning the predictions (lookback period).
+        Should be between 1-10x `prediction_length`.
+    dropout : float
+        Dropout probability applied in the network. Helps prevent overfitting.
+        Default is 0.1.
+    learning_rate : float
+        Learning rate used by the optimizer during training. Default is 1e-2.
+    log_interval : int
+        Interval (in steps) at which training logs are recorded. If -1, logging
+        is disabled. Default is -1.
+    log_gradient_flow : bool
+        Whether to log gradient flow during training. Useful for diagnosing
+        vanishing/exploding gradients. Default is False.
+    log_val_interval : int
+        Interval (in steps) at which validation metrics are logged. If None,
+        uses default logging behavior. Default is None.
+    weight_decay : float
+        Weight decay (L2 regularization) coefficient used by the optimizer to
+        reduce overfitting. Default is 1e-3.
+    loss : MultiHorizonMetric
+        Loss to optimize. Defaults to `MASE()`.
+    reduce_on_plateau_patience : int
+        Patience after which learning rate is reduced by factor of 10.
+    backcast_loss_ratio : float
+        Weight of backcast loss relative to forecast loss. 1.0 gives equal weight;
+        default 0.0 means no backcast loss.
+    logging_metrics : nn.ModuleList of MultiHorizonMetric
+        List of metrics logged during training. Defaults to
+        nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
+    **kwargs
+        Additional arguments forwarded to :py:class:`~BaseModel`.
+    """  # noqa: E501
 
     @classmethod
     def _pkg(cls):
@@ -51,55 +117,6 @@ def __init__(
         logging_metrics: nn.ModuleList = None,
         **kwargs,
     ):
-        """
-        Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible.
-
-        Based on the article
-        `N-BEATS: Neural basis expansion analysis for interpretable time series
-        forecasting <http://arxiv.org/abs/1905.10437>`_. 
The network has (if used as ensemble) outperformed all - other methods - including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably - the most - important benchmark for univariate time series forecasting. - - The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform - N-BEATS. - - Args: - stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings - of length 1 or ‘num_stacks’. Default and recommended value - for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”] - num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. - Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] - num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length - 1 or ‘num_stacks’. - Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4] - width: Widths of the fully connected layers with ReLu activation in the blocks. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512] - Recommended value for interpretable mode: [256, 2048] - sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False] - Recommended value for interpretable mode: [True] - expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion - coefficient. - If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S” - (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. - A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for - interpretable mode: [3] - prediction_length: Length of the prediction. Also known as 'horizon'. - context_length: Number of time units that condition the predictions. Also known as 'lookback period'. - Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. - A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and - forecast lengths). Defaults to 0.0, i.e. no weight. - loss: loss to optimize. Defaults to MASE(). - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - **kwargs: additional arguments to :py:class:`~BaseModel`. 
- """ # noqa: E501 if expansion_coefficient_lengths is None: expansion_coefficient_lengths = [3, 7] if sharing is None: @@ -116,9 +133,9 @@ def __init__( logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() - self.save_hyperparameters() - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + self.save_hyperparameters(ignore=["loss", "logging_metrics"]) + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) # setup stacks self.net_blocks = nn.ModuleList() for stack_id, stack_type in enumerate(stack_types): @@ -130,7 +147,7 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, + dropout=dropout, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -138,8 +155,8 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - min_period=self.hparams.expansion_coefficient_lengths[stack_id], - dropout=self.hparams.dropout, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -148,296 +165,9 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, + dropout=dropout, ) else: raise ValueError(f"Unknown stack type {stack_type}") self.net_blocks.append(net_block) - - def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - target = x["encoder_cont"][..., 0] - - timesteps = self.hparams.context_length + self.hparams.prediction_length - generic_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - trend_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - seasonal_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - forecast = torch.zeros( - (target.size(0), self.hparams.prediction_length), - dtype=torch.float32, - device=self.device, - ) - - backcast = target # initialize backcast - for i, block in enumerate(self.net_blocks): - # evaluate block - backcast_block, forecast_block = block(backcast) - - # add for interpretation - full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) - if isinstance(block, NBEATSTrendBlock): - trend_forecast.append(full) - elif isinstance(block, NBEATSSeasonalBlock): - seasonal_forecast.append(full) - else: - generic_forecast.append(full) - - # update backcast and forecast - backcast = ( - backcast - backcast_block - ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 - forecast = forecast + forecast_block - - return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), - backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] - ), - trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - ) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
- - Returns: - NBeats - """ # noqa: E501 - new_kwargs = { - "prediction_length": dataset.max_prediction_length, - "context_length": dataset.max_encoder_length, - } - new_kwargs.update(kwargs) - - # validate arguments - assert isinstance( - dataset.target, str - ), "only one target is allowed (passed as string to dataset)" - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert dataset.min_encoder_length == dataset.max_encoder_length, ( - "only fixed encoder length is allowed," - " but min_encoder_length != max_encoder_length" - ) - - assert dataset.max_prediction_length == dataset.min_prediction_length, ( - "only fixed prediction length is allowed," - " but max_prediction_length != min_prediction_length" - ) - - assert ( - dataset.randomize_length is None - ), "length has to be fixed, but randomize_length is not None" - assert ( - not dataset.add_relative_time_idx - ), "add_relative_time_idx has to be False" - - assert ( - len(dataset.flat_categoricals) == 0 - and len(dataset.reals) == 1 - and len(dataset._time_varying_unknown_reals) == 1 - and dataset._time_varying_unknown_reals[0] == dataset.target - ), ( - "The only variable as input should be the" - " target which is part of time_varying_unknown_reals" - ) - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if ( - self.hparams.backcast_loss_ratio > 0 and not self.predicting - ): # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio - * self.hparams.prediction_length - / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, MASE): - backcast_loss = ( - self.loss(backcast, x["encoder_target"], x["decoder_target"]) - * backcast_weight - ) - else: - backcast_loss = ( - self.loss(backcast, x["encoder_target"]) * backcast_weight - ) - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. 
- """ - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - label = ["val", "train"][self.training] - if self.log_interval > 0 and batch_idx % self.log_interval == 0: - fig = self.plot_interpretation(x, out, idx=0) - name = f"{label.capitalize()} interpretation of item 0 in " - if self.training: - name += f"step {self.global_step}" - else: - name += f"batch {batch_idx}" - self.logger.experiment.add_figure(name, fig, global_step=self.global_step) - - def plot_interpretation( - self, - x: dict[str, torch.Tensor], - output: dict[str, torch.Tensor], - idx: int, - ax=None, - plot_seasonality_and_generic_on_secondary_axis: bool = False, - ): - """ - Plot interpretation. - - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into trend, seasonality and generic forecast. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. - Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and - generic forecast on secondary axis in second panel. Defaults to False. - - Returns: - plt.Figure: matplotlib figure - """ # noqa: E501 - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - else: - fig = ax[0].get_figure() - - time = torch.arange( - -self.hparams.context_length, self.hparams.prediction_length - ) - - # plot target vs prediction - ax[0].plot( - time, - torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) - .detach() - .cpu(), - label="target", - ) - ax[0].plot( - time, - torch.cat( - [ - output["backcast"][idx].detach(), - output["prediction"][idx].detach(), - ], - dim=0, - ).cpu(), - label="prediction", - ) - ax[0].set_xlabel("Time") - - # plot blocks - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - next(prop_cycle) # prediction - next(prop_cycle) # observations - if plot_seasonality_and_generic_on_secondary_axis: - ax2 = ax[1].twinx() - ax2.set_ylabel("Seasonality / Generic") - else: - ax2 = ax[1] - for title in ["trend", "seasonality", "generic"]: - if title not in self.hparams.stack_types: - continue - if title == "trend": - ax[1].plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - else: - ax2.plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - ax[1].set_xlabel("Time") - ax[1].set_ylabel("Decomposition") - - fig.legend() - return fig diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py new file mode 100644 index 000000000..9ec8d6324 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -0,0 +1,337 @@ +""" +N-Beats model adapter for timeseries forecasting without covariates. 
+""" + +from typing import Optional + +import torch + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) +from pytorch_forecasting.metrics import MASE +from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class NBeatsAdapter(BaseModel): + """ + Initialize NBeats Adapter. + + Parameters + ---------- + **kwargs + additional arguments to :py:class:`~BaseModel`. + """ # noqa: E501 + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Pass forward of network. + + Parameters + ---------- + x : dict of str to torch.Tensor + input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Returns + ------- + dict of str to torch.Tensor + output of model + """ + target = x["encoder_cont"][..., 0] + + timesteps = self.hparams.context_length + self.hparams.prediction_length + generic_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + trend_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + seasonal_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + forecast = torch.zeros( + (target.size(0), self.hparams.prediction_length), + dtype=torch.float32, + device=self.device, + ) + + backcast = target # initialize backcast + for i, block in enumerate(self.net_blocks): + # evaluate block + backcast_block, forecast_block = block(backcast) + + # add for interpretation + full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) + if isinstance(block, NBEATSTrendBlock): + trend_forecast.append(full) + elif isinstance(block, NBEATSSeasonalBlock): + seasonal_forecast.append(full) + else: + generic_forecast.append(full) + + # update backcast and forecast + backcast = ( + backcast - backcast_block + ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 + forecast = forecast + forecast_block + + return self.to_network_output( + prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + backcast=self.transform_output( + prediction=target - backcast, target_scale=x["target_scale"] + ), + trend=self.transform_output( + torch.stack(trend_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + seasonality=self.transform_output( + torch.stack(seasonal_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + generic=self.transform_output( + torch.stack(generic_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class + `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Parameters + ---------- + dataset : TimeSeriesDataSet + dataset where sole predictor is the target. + **kwargs + additional arguments to be passed to ``__init__`` method. 
+ + Returns + ------- + NBeats + """ # noqa: E501 + new_kwargs = { + "prediction_length": dataset.max_prediction_length, + "context_length": dataset.max_encoder_length, + } + new_kwargs.update(kwargs) + + # validate arguments + assert isinstance( + dataset.target, str + ), "only one target is allowed (passed as string to dataset)" + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert dataset.min_encoder_length == dataset.max_encoder_length, ( + "only fixed encoder length is allowed," + " but min_encoder_length != max_encoder_length" + ) + + assert dataset.max_prediction_length == dataset.min_prediction_length, ( + "only fixed prediction length is allowed," + " but max_prediction_length != min_prediction_length" + ) + + assert ( + dataset.randomize_length is None + ), "length has to be fixed, but randomize_length is not None" + assert ( + not dataset.add_relative_time_idx + ), "add_relative_time_idx has to be False" + + assert ( + len(dataset.flat_categoricals) == 0 + and len(dataset.reals) == 1 + and len(dataset._time_varying_unknown_reals) == 1 + and dataset._time_varying_unknown_reals[0] == dataset.target + ), ( + "The only variable as input should be the" + " target which is part of time_varying_unknown_reals" + ) + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if ( + self.hparams.backcast_loss_ratio > 0 and not self.predicting + ): # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio + * self.hparams.prediction_length + / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, MASE): + backcast_loss = ( + self.loss(backcast, x["encoder_target"], x["decoder_target"]) + * backcast_weight + ) + else: + backcast_loss = ( + self.loss(backcast, x["encoder_target"]) * backcast_weight + ) + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions in tensorboard. 
+        """
+        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
+
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
+            return
+
+        label = ["val", "train"][self.training]
+        if self.log_interval > 0 and batch_idx % self.log_interval == 0:
+            fig = self.plot_interpretation(x, out, idx=0)
+            name = f"{label.capitalize()} interpretation of item 0 in "
+            if self.training:
+                name += f"step {self.global_step}"
+            else:
+                name += f"batch {batch_idx}"
+            self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
+
+    def plot_interpretation(
+        self,
+        x: dict[str, torch.Tensor],
+        output: dict[str, torch.Tensor],
+        idx: int,
+        ax=None,
+        plot_seasonality_and_generic_on_secondary_axis: bool = False,
+    ):
+        """
+        Plot interpretation.
+
+        Plot two panels: prediction and backcast vs actuals, and
+        decomposition of the prediction into trend, seasonality and generic forecast.
+
+        Parameters
+        ----------
+        x : dict of str to torch.Tensor
+            network input
+        output : dict of str to torch.Tensor
+            network output
+        idx : int
+            index of sample for which to plot the interpretation.
+        ax : list of matplotlib.axes
+            list of two matplotlib axes onto which to plot the interpretation. Defaults to None.
+        plot_seasonality_and_generic_on_secondary_axis : bool
+            whether to plot seasonality and generic forecast on a secondary axis in the second panel.
+            Defaults to False.
+
+        Returns
+        -------
+        matplotlib.figure.Figure
+            matplotlib figure
+        """  # noqa: E501
+        _check_matplotlib("plot_interpretation")
+
+        import matplotlib.pyplot as plt
+
+        if ax is None:
+            fig, ax = plt.subplots(2, 1, figsize=(6, 8))
+        else:
+            fig = ax[0].get_figure()
+
+        time = torch.arange(
+            -self.hparams.context_length, self.hparams.prediction_length
+        )
+
+        # plot target vs prediction
+        ax[0].plot(
+            time,
+            torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]])
+            .detach()
+            .cpu(),
+            label="target",
+        )
+        ax[0].plot(
+            time,
+            torch.cat(
+                [
+                    output["backcast"][idx].detach(),
+                    output["prediction"][idx].detach(),
+                ],
+                dim=0,
+            ).cpu(),
+            label="prediction",
+        )
+        ax[0].set_xlabel("Time")
+
+        # plot blocks
+        prop_cycle = iter(plt.rcParams["axes.prop_cycle"])
+        next(prop_cycle)  # skip color used for target
+        next(prop_cycle)  # skip color used for prediction
+        if plot_seasonality_and_generic_on_secondary_axis:
+            ax2 = ax[1].twinx()
+            ax2.set_ylabel("Seasonality / Generic")
+        else:
+            ax2 = ax[1]
+        for title in ["trend", "seasonality", "generic"]:
+            if title not in self.hparams.stack_types:
+                continue
+            if title == "trend":
+                ax[1].plot(
+                    time,
+                    output[title][idx].detach().cpu(),
+                    label=title.capitalize(),
+                    c=next(prop_cycle)["color"],
+                )
+            else:
+                ax2.plot(
+                    time,
+                    output[title][idx].detach().cpu(),
+                    label=title.capitalize(),
+                    c=next(prop_cycle)["color"],
+                )
+        ax[1].set_xlabel("Time")
+        ax[1].set_ylabel("Decomposition")
+
+        fig.legend()
+        return fig
diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py
new file mode 100644
index 000000000..555483734
--- /dev/null
+++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py
@@ -0,0 +1,290 @@
+"""
+N-Beats model with KAN blocks for timeseries forecasting without covariates.
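+
+Example (a minimal sketch, not a drop-in script): assuming ``training`` is an
+existing :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`
+and ``train_dataloader`` was created from it, the model can be trained with
+periodic KAN grid refinement via ``GridUpdateCallback``::
+
+    import lightning.pytorch as pl
+
+    from pytorch_forecasting import NBeatsKAN
+    from pytorch_forecasting.models.nbeats import GridUpdateCallback
+
+    net = NBeatsKAN.from_dataset(training, learning_rate=3e-2)
+    trainer = pl.Trainer(
+        max_epochs=1,
+        # refine the KAN spline grids every 3 training steps
+        callbacks=[GridUpdateCallback(update_interval=3)],
+    )
+    trainer.fit(net, train_dataloaders=train_dataloader)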
+""" + +from typing import Optional + +import torch +from torch import nn + +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter + + +class NBeatsKAN(NBeatsAdapter): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. + + Parameters + ---------- + stack_types : list of str + One of the following values: “generic”, “seasonality" or + “trend". A list of strings of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [“generic”] Recommended value for + interpretable mode: [“trend”,”seasonality”]. + num_blocks : list of int + The number of blocks per stack. A list of ints of length 1 or + 'num_stacks'. Default and recommended value for generic mode: [1] + Recommended value for interpretable mode: [3] + num_block_layers : list of int + Number of fully connected layers with ReLu activation per block. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [4] Recommended value for interpretable mode: + [4]. + widths : list of int + Widths of the fully connected layers with ReLu activation in the + blocks. A list of ints of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [512]. Recommended value for + interpretable mode: [256, 2048] + sharing : list of bool + Whether the weights are shared with the other blocks per stack. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [False]. Recommended value for interpretable + mode: [True]. + expansion_coefficient_lengths : list of int + If the type is “G” (generic), then the length of the expansion coefficient. + If type is “T” (trend), then it corresponds to the degree of the + polynomial. + If the type is “S” (seasonal) then this is the minimum period allowed, + e.g. 2 for changes every timestep. A list of ints of length 1 or + 'num_stacks'. Default value for generic mode: [32] Recommended value for + interpretable mode: [3] + prediction_length : int + Length of the prediction. Also known as 'horizon'. + context_length : int + Number of time units that condition the predictions. + Also known as 'lookback period'. + Should be between 1-10 times the prediction length. + backcast_loss_ratio : float + Weight of backcast in comparison to forecast when calculating the loss. + A weight of 1.0 means that forecast and backcast loss is weighted the same + (regardless of backcast and forecast lengths). Defaults to 0.0, i.e. no weight. + loss : MultiHorizonMetric + Loss to optimize. Defaults to MASE(). + log_gradient_flow : bool + If to log gradient flow, this takes time and should be only done to diagnose + training failures. 
+    reduce_on_plateau_patience : int
+        Patience after which learning rate is reduced by a factor of 10.
+    logging_metrics : nn.ModuleList of MultiHorizonMetric
+        List of metrics that are logged during training. Defaults to
+        nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
+    num : int
+        Parameter for KAN layer. The number of grid intervals = G.
+        Default: 5.
+    k : int
+        Parameter for KAN layer. The order of the piecewise polynomial. Default: 3.
+    noise_scale : float
+        Parameter for KAN layer. The scale of noise injected at initialization.
+        Default: 0.5.
+    scale_base_mu : float
+        Parameter for KAN layer. The scale of the residual function b(x) is
+        initialized to be N(scale_base_mu, scale_base_sigma^2). Default: 0.0.
+    scale_base_sigma : float
+        Parameter for KAN layer. The scale of the residual function b(x) is
+        initialized to be N(scale_base_mu, scale_base_sigma^2). Default: 1.0.
+    scale_sp : float
+        Parameter for KAN layer. The scale of the base function spline(x).
+        Default: 1.0.
+    base_fun : callable
+        Parameter for KAN layer. Residual function b(x). Default: None
+        (interpreted as torch.nn.SiLU()).
+    grid_eps : float
+        Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
+        when grid_eps = 0, the grid is partitioned using percentiles of samples.
+        0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+    grid_range : list of int
+        Parameter for KAN layer. List or np.array of shape (2,) setting the range
+        of the grids. Default: None (interpreted as [-1, 1]).
+    sp_trainable : bool
+        Parameter for KAN layer. If true, scale_sp is trainable. Default: True.
+    sb_trainable : bool
+        Parameter for KAN layer. If true, scale_base is trainable. Default: True.
+    sparse_init : bool
+        Parameter for KAN layer. If sparse_init = True, sparse initialization is
+        applied. Default: False.
+    **kwargs
+        Additional arguments to :py:class:`~BaseModel`.
+
+    Notes
+    -----
+    The KAN blocks are based on the Kolmogorov-Arnold representation theorem and
+    replace fixed MLP edge weights with learnable univariate spline functions.
+    This allows KAN-augmented N-BEATS to better capture complex patterns, improve
+    interpretability, and achieve parameter efficiency. Additionally, when applied
+    in a doubly-residual adversarial framework, the model excels at zero-shot
+    time-series forecasting across markets.
+
+    Key differences from original N-BEATS:
+
+    - MLP layers are replaced by KAN layers with spline-based edge functions.
+    - Each weight is a trainable function, not a scalar.
+    - Enables visualization of learned functions and better domain adaptation.
+    - Yields improved accuracy and interpretability with fewer parameters.
+
+    References
+    ----------
+    .. [1] Z. Liu et al. (2024), “KAN: Kolmogorov-Arnold Networks” propose
+       replacing MLP weights with spline-based learnable edge functions, enabling
+       improved accuracy, interpretability, and scaling behavior compared to
+       standard MLPs.
+    .. [2] A. Bhattacharya & N. Haq (2024), “Zero Shot Time Series Forecasting
+       Using Kolmogorov Arnold Networks” incorporate KAN layers into a
+       doubly-residual N-BEATS architecture with adversarial domain adaptation,
+       achieving strong zero-shot cross-market electricity price forecasting
+       performance.
+
+    Examples
+    --------
+    See the full example in `examples/nbeats_with_kan.py`, or the minimal
+    construction sketch below.
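+
+    A minimal sketch (assuming ``training`` is an existing
+    :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`;
+    the hyperparameter values below are illustrative, not recommendations)::
+
+        from pytorch_forecasting import NBeatsKAN
+
+        net = NBeatsKAN.from_dataset(
+            training,
+            learning_rate=3e-2,
+            num=5,  # number of grid intervals per KAN layer
+            k=3,  # order of the piecewise polynomial
+            grid_range=[-1, 1],  # range of the spline grid
+        )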
+    """  # noqa: E501
+
+    @classmethod
+    def _pkg(cls):
+        """Package for the model."""
+        from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg
+
+        return NBeatsKAN_pkg
+
+    def __init__(
+        self,
+        stack_types: Optional[list[str]] = None,
+        num_blocks: Optional[list[int]] = None,
+        num_block_layers: Optional[list[int]] = None,
+        widths: Optional[list[int]] = None,
+        sharing: Optional[list[bool]] = None,
+        expansion_coefficient_lengths: Optional[list[int]] = None,
+        prediction_length: int = 1,
+        context_length: int = 1,
+        dropout: float = 0.1,
+        learning_rate: float = 1e-2,
+        log_interval: int = -1,
+        log_gradient_flow: bool = False,
+        log_val_interval: int = None,
+        weight_decay: float = 1e-3,
+        loss: MultiHorizonMetric = None,
+        reduce_on_plateau_patience: int = 1000,
+        backcast_loss_ratio: float = 0.0,
+        logging_metrics: nn.ModuleList = None,
+        num: int = 5,
+        k: int = 3,
+        noise_scale: float = 0.5,
+        scale_base_mu: float = 0.0,
+        scale_base_sigma: float = 1.0,
+        scale_sp: float = 1.0,
+        base_fun: callable = None,
+        grid_eps: float = 0.02,
+        grid_range: list[int] = None,
+        sp_trainable: bool = True,
+        sb_trainable: bool = True,
+        sparse_init: bool = False,
+        **kwargs,
+    ):
+        # resolve mutable / module defaults here rather than in the signature
+        if base_fun is None:
+            base_fun = torch.nn.SiLU()
+        if grid_range is None:
+            grid_range = [-1, 1]
+        if expansion_coefficient_lengths is None:
+            expansion_coefficient_lengths = [3, 7]
+        if sharing is None:
+            sharing = [True, True]
+        if widths is None:
+            widths = [32, 512]
+        if num_block_layers is None:
+            num_block_layers = [3, 3]
+        if num_blocks is None:
+            num_blocks = [3, 3]
+        if stack_types is None:
+            stack_types = ["trend", "seasonality"]
+        if logging_metrics is None:
+            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+        if loss is None:
+            loss = MASE()
+
+        self.save_hyperparameters(ignore=["loss", "logging_metrics"])
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
+
+        # bundle KAN parameters into a dictionary
+        kan_params = {
+            "num": num,
+            "k": k,
+            "noise_scale": noise_scale,
+            "scale_base_mu": scale_base_mu,
+            "scale_base_sigma": scale_base_sigma,
+            "scale_sp": scale_sp,
+            "base_fun": base_fun,
+            "grid_eps": grid_eps,
+            "grid_range": grid_range,
+            "sp_trainable": sp_trainable,
+            "sb_trainable": sb_trainable,
+            "sparse_init": sparse_init,
+        }
+        self.kan_params = kan_params
+        # setup stacks
+        self.net_blocks = nn.ModuleList()
+        for stack_id, stack_type in enumerate(stack_types):
+            for _ in range(num_blocks[stack_id]):
+                if stack_type == "generic":
+                    net_block = NBEATSGenericBlock(
+                        units=self.hparams.widths[stack_id],
+                        thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                elif stack_type == "seasonality":
+                    net_block = NBEATSSeasonalBlock(
+                        units=self.hparams.widths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        min_period=expansion_coefficient_lengths[stack_id],
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                elif stack_type == "trend":
+                    net_block = NBEATSTrendBlock(
+                        units=self.hparams.widths[stack_id],
+                        thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                else:
+                    raise ValueError(f"Unknown stack type {stack_type}")
+
+                self.net_blocks.append(net_block)
+
+    def update_kan_grid(self):
+        """
+        Update the grids of the KAN layers used in the N-BEATS blocks.
+
+        Examples
+        --------
+        See the full example in:
+        `examples/nbeats_with_kan.py`
+        """
+        for block in self.net_blocks:
+            # grid update logic adapted from
+            # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682
+            for i, layer in enumerate(block.fc):
+                # update basis KAN layers' grid
+                layer.update_grid_from_samples(block.outputs[i])
+            # update theta backward and theta forward KAN layers' grid
+            block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1])
+            block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1])
diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py
new file mode 100644
index 000000000..2cda8c996
--- /dev/null
+++ b/pytorch_forecasting/models/nbeats/_nbeatskan_pkg.py
@@ -0,0 +1,83 @@
+"""NBeatsKAN package container."""
+
+from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+
+
+class NBeatsKAN_pkg(_BasePtForecaster):
+    """NBeatsKAN package container."""
+
+    _tags = {
+        "info:name": "NBeatsKAN",
+        "info:compute": 1,
+        "info:pred_type": ["point"],
+        "info:y_type": ["numeric"],
+        "authors": ["Sohaib-Ahmed21"],
+        "capability:exogenous": False,
+        "capability:multivariate": False,
+        "capability:pred_int": False,
+        "capability:flexible_history_length": False,
+        "capability:cold_start": False,
+    }
+
+    @classmethod
+    def get_cls(cls):
+        """Get model class."""
+        from pytorch_forecasting.models import NBeatsKAN
+
+        return NBeatsKAN
+
+    @classmethod
+    def get_base_test_params(cls):
+        """Return testing parameter settings for the trainer.
+
+        Returns
+        -------
+        params : dict or list of dict, default = {}
+            Parameters to create testing instances of the class.
+            Each dict contains parameters to construct an "interesting" test
+            instance, i.e., `MyClass(**params)` or `MyClass(**params[i])`
+            creates a valid test instance.
+        `create_test_instance` uses the first (or only) dictionary in `params`.
+        """
+        return [
+            {"backcast_loss_ratio": 0.0},  # pure forecast loss
+            {"backcast_loss_ratio": 1.0},  # equal forecast/backcast
+            {
+                "stack_types": ["generic"],
+                "expansion_coefficient_lengths": [16],
+            },
+            {
+                "num_blocks": [1, 2],
+                "num_block_layers": [2, 3],
+            },  # varying block structure
+            {
+                "num": 7,
+                "k": 4,
+                "sparse_init": True,
+                "grid_range": [-0.5, 0.5],
+                "sp_trainable": False,
+            },  # complex KAN config
+        ]
+
+    @classmethod
+    def _get_test_dataloaders_from(cls, params):
+        """Create test dataloaders: covariate data for TweedieLoss, else fixed-window data."""
+        loss = params.get("loss", None)
+        data_loader_kwargs = params.get("data_loader_kwargs", {})
+        from pytorch_forecasting.metrics import TweedieLoss
+        from pytorch_forecasting.tests._data_scenarios import (
+            data_with_covariates,
+            dataloaders_fixed_window_without_covariates,
+            make_dataloaders,
+        )
+
+        if isinstance(loss, TweedieLoss):
+            dwc = data_with_covariates()
+            dl_default_kwargs = dict(
+                target="target",
+                time_varying_unknown_reals=["target"],
+                add_relative_time_idx=False,
+            )
+            dl_default_kwargs.update(data_loader_kwargs)
+            dataloaders_with_covariates = make_dataloaders(dwc, **dl_default_kwargs)
+            return dataloaders_with_covariates
+
+        return dataloaders_fixed_window_without_covariates()
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index 6abb637f5..c7ec58972 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -1,219 +1,12 @@
 """
-Implementation of ``nn.Modules`` for N-Beats model.
-"""
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def linear(input_size, output_size, bias=True, dropout: int = None):
-    lin = nn.Linear(input_size, output_size, bias=bias)
-    if dropout is not None:
-        return nn.Sequential(nn.Dropout(dropout), lin)
-    else:
-        return lin
-
-
-def linspace(
-    backcast_length: int, forecast_length: int, centered: bool = False
-) -> tuple[np.ndarray, np.ndarray]:
-    if centered:
-        norm = max(backcast_length, forecast_length)
-        start = -backcast_length
-        stop = forecast_length - 1
-    else:
-        norm = backcast_length + forecast_length
-        start = 0
-        stop = backcast_length + forecast_length - 1
-    lin_space = np.linspace(
-        start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32
-    )
-    b_ls = lin_space[:backcast_length]
-    f_ls = lin_space[backcast_length:]
-    return b_ls, f_ls
-
-
-class NBEATSBlock(nn.Module):
-    def __init__(
-        self,
-        units,
-        thetas_dim,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        share_thetas=False,
-        dropout=0.1,
-    ):
-        super().__init__()
-        self.units = units
-        self.thetas_dim = thetas_dim
-        self.backcast_length = backcast_length
-        self.forecast_length = forecast_length
-        self.share_thetas = share_thetas
-
-        fc_stack = [
-            nn.Linear(backcast_length, units),
-            nn.ReLU(),
-        ]
-        for _ in range(num_block_layers - 1):
-            fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()])
-        self.fc = nn.Sequential(*fc_stack)
-
-        if share_thetas:
-            self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
-        else:
-            self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
-            self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False)
-
-    def forward(self, x):
-        return self.fc(x)
-
-
-class NBEATSSeasonalBlock(NBEATSBlock):
-    def __init__(
-        self,
-        units,
-        thetas_dim=None,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        nb_harmonics=None,
min_period=1, - dropout=0.1, - ): - if nb_harmonics: - thetas_dim = nb_harmonics - else: - thetas_dim = forecast_length - self.min_period = min_period - - super().__init__( - units=units, - thetas_dim=thetas_dim, - num_block_layers=num_block_layers, - backcast_length=backcast_length, - forecast_length=forecast_length, - share_thetas=True, - dropout=dropout, - ) - - backcast_linspace, forecast_linspace = linspace( - backcast_length, forecast_length, centered=False - ) +Backward-compatibility shim for N-BEATS blocks. +Real implementations live in `pytorch_forecasting.layers._nbeats._blocks`. - p1, p2 = ( - (thetas_dim // 2, thetas_dim // 2) - if thetas_dim % 2 == 0 - else (thetas_dim // 2, thetas_dim // 2 + 1) - ) - s1_b = torch.tensor( - np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * backcast_linspace), - dtype=torch.float32, - ) # H/2-1 - s2_b = torch.tensor( - np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * backcast_linspace), - dtype=torch.float32, - ) - self.register_buffer("S_backcast", torch.cat([s1_b, s2_b])) - - s1_f = torch.tensor( - np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * forecast_linspace), - dtype=torch.float32, - ) # H/2-1 - s2_f = torch.tensor( - np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * forecast_linspace), - dtype=torch.float32, - ) - self.register_buffer("S_forecast", torch.cat([s1_f, s2_f])) - - def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: - x = super().forward(x) - amplitudes_backward = self.theta_b_fc(x) - backcast = amplitudes_backward.mm(self.S_backcast) - amplitudes_forward = self.theta_f_fc(x) - forecast = amplitudes_forward.mm(self.S_forecast) - - return backcast, forecast - - def get_frequencies(self, n): - return np.linspace( - 0, (self.backcast_length + self.forecast_length) / self.min_period, n - ) - - -class NBEATSTrendBlock(NBEATSBlock): - def __init__( - self, - units, - thetas_dim, - num_block_layers=4, - backcast_length=10, - forecast_length=5, - dropout=0.1, - ): - super().__init__( - units=units, - thetas_dim=thetas_dim, - num_block_layers=num_block_layers, - backcast_length=backcast_length, - forecast_length=forecast_length, - share_thetas=True, - dropout=dropout, - ) - - backcast_linspace, forecast_linspace = linspace( - backcast_length, forecast_length, centered=True - ) - norm = np.sqrt( - forecast_length / thetas_dim - ) # ensure range of predictions is comparable to input - thetas_dims_range = np.array(range(thetas_dim)) - coefficients = torch.tensor( - backcast_linspace ** thetas_dims_range[:, None], - dtype=torch.float32, - ) - self.register_buffer("T_backcast", coefficients * norm) - coefficients = torch.tensor( - forecast_linspace ** thetas_dims_range[:, None], - dtype=torch.float32, - ) - self.register_buffer("T_forecast", coefficients * norm) - - def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: - x = super().forward(x) - backcast = self.theta_b_fc(x).mm(self.T_backcast) - forecast = self.theta_f_fc(x).mm(self.T_forecast) - return backcast, forecast - - -class NBEATSGenericBlock(NBEATSBlock): - def __init__( - self, - units, - thetas_dim, - num_block_layers=4, - backcast_length=10, - forecast_length=5, - dropout=0.1, - ): - super().__init__( - units=units, - thetas_dim=thetas_dim, - num_block_layers=num_block_layers, - backcast_length=backcast_length, - forecast_length=forecast_length, - dropout=dropout, - ) - - self.backcast_fc = nn.Linear(thetas_dim, backcast_length) - self.forecast_fc = nn.Linear(thetas_dim, forecast_length) - - def forward(self, x): - x = super().forward(x) - 
- theta_b = F.relu(self.theta_b_fc(x)) - theta_f = F.relu(self.theta_f_fc(x)) +# TODO v2: remove this file. +""" - return self.backcast_fc(theta_b), self.forecast_fc(theta_f) +from pytorch_forecasting.layers._nbeats._blocks import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +)
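+
+# Hypothetical convenience (not part of the original shim): an explicit
+# re-export list, mirroring the ``__all__`` convention used elsewhere in the
+# package, so linters do not flag the shim imports above as unused.
+__all__ = ["NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock"]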