Sohaib-Ahmed21
Contributor

@Sohaib-Ahmed21 Sohaib-Ahmed21 commented Jan 13, 2025

Description

Fixes: #1741

This PR adds Kolmogorov-Arnold Network (KAN) blocks to NBeats and refactors NBeats. The implementation of the KAN block layers is taken from the original paper's code.

Changes in Structure

  • Introduced the NBEATSKAN module, which enables usage of KAN blocks within the NBEATS architecture.

  • Integrated KANLayer logic, implemented in kan_layer.py, which handles KAN-specific operations such as:

    • Spline coefficient computation,
    • Grid initialization and updates, etc.
  • Imported KANLayer to submodules.py for block operations, allowing NBEATSKAN to delegate block-level behavior through use_kan=True.

  • Added the NBEATSAdapter class to encapsulate common methods shared by both NBEATS and NBEATSKAN, including:

    • Standard training, forward logic, etc.
    • Block initialization (__init__) is excluded; each class defines it separately to maintain architectural flexibility.
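
The split described above (shared logic in one base class, block construction left to each subclass) can be sketched roughly as follows. This is a toy illustration only: `ToyAdapter`, `ToyNBeats`, and `ToyNBeatsKAN` are hypothetical names, and the lambda "blocks" merely stand in for real network blocks; none of this is the PR's actual code.

```python
class ToyAdapter:
    """Shared logic; subclasses are expected to set self.blocks in __init__."""

    def forward(self, x):
        # Run the input through every block and sum the block outputs.
        return sum(block(x) for block in self.blocks)


class ToyNBeats(ToyAdapter):
    def __init__(self, n_blocks=2):
        # Plain blocks: a scaling function stands in for an MLP block.
        self.blocks = [lambda x, k=k: (k + 1) * x for k in range(n_blocks)]


class ToyNBeatsKAN(ToyAdapter):
    def __init__(self, n_blocks=2):
        # "KAN-style" blocks: a power function stands in for a spline layer.
        self.blocks = [lambda x, k=k: x ** (k + 1) for k in range(n_blocks)]


plain_out = ToyNBeats().forward(3)    # 1*3 + 2*3
kan_out = ToyNBeatsKAN().forward(3)   # 3**1 + 3**2
```

Both subclasses reuse `forward` unchanged; only the block construction in `__init__` differs.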

GridUpdateCallback

When training KAN-based models, the grid can be iteratively refined during training for better performance.

To support this, logic from the original [pykan](https://github.com/KindXiaoming/pykan) implementation has been adapted to define a custom callback named GridUpdateCallback. This callback automatically updates the grid at specified training steps, improving model accuracy and convergence.

This callback has been tested successfully and demonstrates improved results in practice.

An example usage is provided in:
examples/nbeats_with_kan.py
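
As a rough sketch of the idea (not the PR's implementation), a step-based grid update can be expressed as a callback that asks the model to refresh its grid every few steps. Everything here is an assumption for illustration: the class names, the `update_interval` and `stop_after` parameters, the `update_grid` method, and the simplified hook signature, which only loosely mirrors trainer-style callbacks.

```python
class ToyModel:
    """Hypothetical model exposing a grid-update method."""

    def __init__(self):
        self.grid_updates = 0

    def update_grid(self, batch):
        # A real KAN layer would refit its spline grid to the batch here;
        # this toy version only counts how often it was asked to.
        self.grid_updates += 1


class ToyGridUpdateCallback:
    """Hypothetical callback: refresh the grid every `update_interval` steps."""

    def __init__(self, update_interval=10, stop_after=50):
        self.update_interval = update_interval
        self.stop_after = stop_after  # assumed cut-off; no updates after this step
        self.updated_at = []

    def on_train_batch_end(self, model, step, batch):
        # Update only on multiples of the interval, up to the cut-off step.
        if 0 < step <= self.stop_after and step % self.update_interval == 0:
            model.update_grid(batch)
            self.updated_at.append(step)


model = ToyModel()
callback = ToyGridUpdateCallback(update_interval=10, stop_after=30)
for step in range(1, 51):
    callback.on_train_batch_end(model, step, batch=None)
```

Keeping the schedule in the callback means the training loop itself does not need to know about grids.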

Checklist

  • Linked issues (if existing)
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

Make sure to have fun coding!

@Sohaib-Ahmed21
Contributor Author

@fkiraly @benHeid this PR is ready for review. Kindly review it. Thanks!

@Sohaib-Ahmed21
Contributor Author

@fkiraly @benHeid could you kindly review/merge this so that I can integrate the NBEATSX modification into NBEATS without conflicts? I asked @julian-fong, and he is not working on NBEATSX.

Collaborator

@benHeid benHeid left a comment

Not fully reviewed yet; I will continue in the next few days. I am sharing my current comments so that you already have some feedback.

When training KANs, the grid can be iteratively refined. I wonder if there is a way to implement this here as well. However, this is probably more difficult and might require changes to the trainer, so it is probably out of scope for this PR. Do you have an opinion on that?

@@ -0,0 +1,528 @@
import numpy as np
Collaborator

The license of the original implementation is MIT. So in theory it is okay to copy the file. However, please add some credits at the top of the file.

Alternatively, we could think about adding KAN as a dependency.

@fkiraly do you have any additions on that matter?

Contributor Author

Okay, I will add the appropriate credits at the top of the file. Additionally, I agree that adding KAN as a dependency—perhaps as a soft dependency—seems like a good idea, especially considering its increasing relevance in time series forecasting.

Collaborator

@fkiraly pinging you again to check if this is okay for you :)

Collaborator

which package are we exactly planning to add as a soft dep?

If it is a single layer, I think copying it over and including the license is perhaps better for now, because we do not have machinery to manage soft dependencies (like scikit-base or similar).

The design proposed in sktime/enhancement-proposals#39 would allow that, but right now I think this would require a significant amount of custom code to handle.

Or, is there an easy way that I am not seeing how the soft dependency import would work for part of the NN?

Contributor Author

which package are we exactly planning to add as a soft dep?

It is the pykan library; see https://pypi.org/project/pykan/

If it is a single layer, I think copying it over and including the license is perhaps better for now, because we do not have machinery to manage soft dependencies (like scikit-base or similar).

The design proposed in sktime/enhancement-proposals#39 would allow that, but right now I think this would require a significant amount of custom code to handle.

Or, is there an easy way that I am not seeing how the soft dependency import would work for part of the NN?

Yes, it is a single layer, used only in NBEATS. The pykan library contains much more, but we only need this layer. I have implemented what you suggested above: I copied the layer over and included the license.

Collaborator

OK, that makes sense, as long as the license points to the original source and is provided in full (assuming it has the usual requirement to reproduce it).

@Sohaib-Ahmed21
Contributor Author

Not fully reviewed yet; I will continue in the next few days. I am sharing my current comments so that you already have some feedback.

Thanks! Will address these reviews soon.

@Sohaib-Ahmed21
Contributor Author

When training KANs, the grid can be iteratively refined. I wonder if there is a way to implement this here as well. However, this is probably more difficult and might require changes to the trainer, so it is probably out of scope for this PR. Do you have an opinion on that?

I'll explore this and share my thoughts.

@Sohaib-Ahmed21
Contributor Author

Sohaib-Ahmed21 commented Jan 23, 2025

@benHeid I have addressed the reviews. Kindly review the updated PR.

When training KANs, the grid can be iteratively refined. I wonder if there is a way to implement this here as well. However, this is probably more difficult and might require changes to the trainer, so it is probably out of scope for this PR. Do you have an opinion on that?

To address this, I have adapted logic from the original pykan implementation into a custom callback, GridUpdateCallback, which updates the grid after a specified number of steps. It has been tested and works fine, with improved results. Your thoughts on this would be helpful.

@Sohaib-Ahmed21
Contributor Author

Sohaib-Ahmed21 commented Jan 26, 2025

@benHeid @fkiraly I have addressed the reviews. Kindly review the PR so that I can proceed with integrating NBEATSX without conflicts.

Collaborator

@benHeid benHeid left a comment

I only have one last comment. I would also like to hear @fkiraly's opinion on the license of the kan_layer.

@fkiraly
Collaborator

fkiraly commented Aug 26, 2025

FYI @Sohaib-Ahmed21, I have upgraded your code to the most recent package/test structure, so the tests should now genuinely run for the NBeatsKAN model. Previously, this was prevented because it was missing the (newly introduced) get_cls and _pkg methods

@Sohaib-Ahmed21
Contributor Author

FYI @Sohaib-Ahmed21, I have upgraded your code to the most recent package/test structure, so the tests should now genuinely run for the NBeatsKAN model. Previously, this was prevented because it was missing the (newly introduced) get_cls and _pkg methods

Thanks @fkiraly , kindly let me know if something else is needed here in this PR.

Collaborator

@fkiraly fkiraly left a comment

the tests are running now and this looks like a genuine failure in the new code

@Sohaib-Ahmed21
Contributor Author

the tests are running now and this looks like a genuine failure in the new code

Okay, I'll look into it and fix it.

@fkiraly
Collaborator

fkiraly commented Aug 27, 2025

FYI, I am working on a check_estimator utility that should help you with running tests selectively from a python prompt:
#1954

This is still WiP but might be helpful

@fkiraly
Collaborator

fkiraly commented Aug 28, 2025

ah, seems to work now! Can you explain why it was failing?

@Sohaib-Ahmed21
Contributor Author

ah, seems to work now! Can you explain why it was failing?

It failed because the test data used before could give TweedieLoss invalid targets; now the dataloader ensures valid non-negative targets, so the loss runs without errors.
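
To illustrate why target validity matters, here is a minimal sketch of a Tweedie-style negative log-likelihood (up to terms constant in the prediction) for a power parameter 1 < p < 2, a range in which the distribution models non-negative data. The function name, the explicit validity check, and the `abs`-based target generation are assumptions made for illustration; they are not taken from the PR or from the library's TweedieLoss.

```python
import math


def toy_tweedie_nll(y_true, log_mu, p=1.5):
    """Tweedie negative log-likelihood kernel for one observation.

    `log_mu` is the log of the predicted mean; terms that do not depend on
    the prediction are dropped.
    """
    if not 1 < p < 2:
        raise ValueError("this sketch only covers 1 < p < 2")
    if y_true < 0:
        # Negative targets lie outside the distribution's support.
        raise ValueError("target must be non-negative")
    a = y_true * math.exp(log_mu * (1 - p)) / (1 - p)
    b = math.exp(log_mu * (2 - p)) / (2 - p)
    return -a + b


# One hypothetical way to build valid test targets: take absolute values.
raw_values = [-1.5, 0.0, 2.0]
targets = [abs(v) for v in raw_values]
losses = [toy_tweedie_nll(y, log_mu=0.0) for y in targets]
```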

@Sohaib-Ahmed21
Contributor Author

@fkiraly kindly review/merge this pull request, thanks!


__all__ = [
"NBeats",
"NBEATSGenericBlock",
Collaborator

To avoid breaking existing code, we need to keep exporting these from the current location, so import them here from the networks location.

We should also add a "# todo v2: remove imports" comment to those; at a major release we can break.
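
A backward-compatible re-export of this kind can be sketched as below. The module names `toy_pkg_networks` and `toy_pkg_old` and the class `ToyGenericBlock` are hypothetical, and the modules are created in memory only so that the example is self-contained; in a real package the re-export line would simply live in the old module's source file.

```python
import sys
import types

# New location: the class now lives here.
new_mod = types.ModuleType("toy_pkg_networks")


class ToyGenericBlock:
    """Stand-in for a block class that was moved."""


new_mod.ToyGenericBlock = ToyGenericBlock
sys.modules["toy_pkg_networks"] = new_mod

# Old location: keeps exporting the name by importing it from the new one.
old_source = (
    "# todo v2: remove imports\n"
    "from toy_pkg_networks import ToyGenericBlock\n"
    "__all__ = ['ToyGenericBlock']\n"
)
old_mod = types.ModuleType("toy_pkg_old")
exec(old_source, old_mod.__dict__)
sys.modules["toy_pkg_old"] = old_mod

# User code written against the old import path keeps working.
from toy_pkg_old import ToyGenericBlock as BlockFromOldPath

same_object = BlockFromOldPath is new_mod.ToyGenericBlock
```

Because the old module re-exports the very same class object, `isinstance` checks and subclassing behave identically on both import paths.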

return b_ls, f_ls


class NBEATSBlock(nn.Module):
Collaborator

same comment here - I would also import the blocks here

Collaborator

@fkiraly fkiraly left a comment

Looks good to me!

Very small thing, I think we should restore the exports of the layers from their old locations, to avoid breaking user code that imported these previously public classes.

Collaborator

@fkiraly fkiraly left a comment

Looks good now, great addition!

@fkiraly fkiraly merged commit e1cc1ce into sktime:main Sep 9, 2025
34 of 35 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in May - Sep 2025 mentee projects Sep 9, 2025
Successfully merging this pull request may close these issues.

[ENH] Kolmogorov Arnold Block for NBeats network
4 participants