-
Notifications
You must be signed in to change notification settings - Fork 81
Improve and fix spline #610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @FilippoOlivo,
I left some comments to be addressed before the approval.
Please, add some comments in the for loop of the basis
function to explain better what is being computed. Also, I would remove the comments on shapes all over the code.
At last, I would update the tests since we changed a bit the logic.
@@ -9,7 +9,9 @@ class Spline(torch.nn.Module): | |||
Spline model class. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, include a short description of the B-spline model
if self.knots.ndim != 1: | ||
raise ValueError("Knot vector must be one-dimensional.") | ||
if self.knots.ndim > 2: | ||
raise ValueError("Knot vector must be one or two-dimensional.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message is not consistent with the documentation. Please, update
c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim) | ||
c[i, :, :] = c_i.T # (out_dim, n_basis) | ||
|
||
self.control_points = torch.nn.Parameter(c) | ||
|
||
@property |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, move all properties and setters to the very end of the file.
return | ||
|
||
# Find the rightmost interval with positive width | ||
knots = self.knots |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would delete this line and keep using self.knots
""" | ||
Recursive method to compute the basis functions of the spline. | ||
if not isinstance(self.knots, torch.Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we perform this check in the __init__
? Is it necessary?
return c1 + c2 | ||
in_dim = A.shape[1] | ||
out_dim = y_eval.shape[2] | ||
n_basis = A.shape[2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variables in_dim
, out_dim
, n_basis
are used just once. Consider using the shapes directly where needed
# A_i is (batch, n_basis) | ||
# y_i is (batch, out_dim) | ||
A_i = A[:, i, :] | ||
y_i = y_eval[:, i, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using A[:, i, :]
and y_eval[:, i, :]
where needed, without saving additional variables
@@ -131,9 +220,12 @@ def control_points(self, value): | |||
dim = value.get("dim", 1) | |||
value = torch.zeros(n, dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we setting value
to zero instead of using the argument passed to the setter?
@@ -193,7 +289,23 @@ def forward(self, x): | |||
k = self.k | |||
c = self.control_points |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why use t
, k
, c
instead of the class attributes?
def __init__(self, order=4, knots=None, control_points=None) -> None: | ||
def __init__( | ||
self, order=4, knots=None, control_points=None, grid_extension=True | ||
): | ||
""" | ||
Initialization of the :class:`Spline` class. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing doc string for grid_extension
This PR fixes #609.
Checklist