Skip to content

Conversation

ajacoby9
Copy link

@ajacoby9 ajacoby9 commented Jul 24, 2025

This PR fixes #609.

Checklist

  • Code follows the project’s Code Style Guidelines
  • Tests have been added or updated
  • Documentation has been updated if necessary
  • Pull request is linked to an open issue

@ajacoby9 ajacoby9 changed the base branch from master to dev July 24, 2025 10:27
@FilippoOlivo FilippoOlivo added bug Something isn't working enhancement New feature or request pr-to-fix Label for PR that needs modification labels Jul 24, 2025
@FilippoOlivo FilippoOlivo marked this pull request as ready for review August 4, 2025 10:28
@FilippoOlivo FilippoOlivo added pr-to-review Label for PR that are ready to been reviewed and removed pr-to-fix Label for PR that needs modification labels Aug 4, 2025
@FilippoOlivo FilippoOlivo requested review from GiovanniCanali, dario-coscia and ndem0 and removed request for GiovanniCanali August 4, 2025 10:28
Copy link
Collaborator

@GiovanniCanali GiovanniCanali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @FilippoOlivo,
I left some comments to be addressed before the approval.

Please, add some comments in the for loop of the basis function to explain better what is being computed. Also, I would remove the comments on shapes all over the code.

At last, I would update the tests since we changed a bit the logic.

@@ -9,7 +9,9 @@ class Spline(torch.nn.Module):
Spline model class.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, include a short description of the B-spline model

if self.knots.ndim != 1:
raise ValueError("Knot vector must be one-dimensional.")
if self.knots.ndim > 2:
raise ValueError("Knot vector must be one or two-dimensional.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message is not consistent with the documentation. Please, update

c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim)
c[i, :, :] = c_i.T # (out_dim, n_basis)

self.control_points = torch.nn.Parameter(c)

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, move all properties and setters to the very end of the file.

return

# Find the rightmost interval with positive width
knots = self.knots
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would delete this line and keep using self.knots

"""
Recursive method to compute the basis functions of the spline.
if not isinstance(self.knots, torch.Tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we perform this check in the __init__? Is it necessary?

return c1 + c2
in_dim = A.shape[1]
out_dim = y_eval.shape[2]
n_basis = A.shape[2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variables in_dim, out_dim, n_basis are used just once. Consider using the shapes directly where needed

# A_i is (batch, n_basis)
# y_i is (batch, out_dim)
A_i = A[:, i, :]
y_i = y_eval[:, i, :]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using A[:, i, :] and y_eval[:, i, :] where needed, without saving additional variables

@@ -131,9 +220,12 @@ def control_points(self, value):
dim = value.get("dim", 1)
value = torch.zeros(n, dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we setting value to zero instead of using the argument passed to the setter?

@@ -193,7 +289,23 @@ def forward(self, x):
k = self.k
c = self.control_points
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use t, k, c instead of the class attributes?

def __init__(self, order=4, knots=None, control_points=None) -> None:
def __init__(
self, order=4, knots=None, control_points=None, grid_extension=True
):
"""
Initialization of the :class:`Spline` class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing doc string for grid_extension

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request pr-to-review Label for PR that are ready to been reviewed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve spline efficiency and correct unexpected behavior
3 participants