-
Notifications
You must be signed in to change notification settings - Fork 81
Improve and fix spline #610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,9 @@ class Spline(torch.nn.Module): | |
Spline model class. | ||
""" | ||
|
||
def __init__(self, order=4, knots=None, control_points=None) -> None: | ||
def __init__( | ||
self, order=4, knots=None, control_points=None, grid_extension=True | ||
): | ||
""" | ||
Initialization of the :class:`Spline` class. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing doc string for |
||
|
@@ -20,7 +22,7 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: | |
``None``. | ||
:raises ValueError: If the order is negative. | ||
:raises ValueError: If both knots and control points are ``None``. | ||
:raises ValueError: If the knot tensor is not one-dimensional. | ||
:raises ValueError: If the knot tensor is not one or two dimensional. | ||
""" | ||
super().__init__() | ||
|
||
|
@@ -33,6 +35,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: | |
|
||
self.order = order | ||
self.k = order - 1 | ||
self.grid_extension = grid_extension | ||
|
||
# Cache for performance optimization | ||
self._boundary_interval_idx = None | ||
|
||
if knots is not None and control_points is not None: | ||
self.knots = knots | ||
|
@@ -65,45 +71,154 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: | |
else: | ||
raise ValueError("Knots and control points cannot be both None.") | ||
|
||
if self.knots.ndim != 1: | ||
raise ValueError("Knot vector must be one-dimensional.") | ||
if self.knots.ndim > 2: | ||
raise ValueError("Knot vector must be one or two-dimensional.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message is not consistent with the documentation. Please, update |
||
|
||
# Precompute boundary interval index for performance | ||
self._compute_boundary_interval() | ||
|
||
def basis(self, x, k, i, t): | ||
def _compute_boundary_interval(self): | ||
""" | ||
Precompute the rightmost non-degenerate interval index for performance. | ||
This avoids the search loop in the basis function on every call. | ||
""" | ||
Recursive method to compute the basis functions of the spline. | ||
# Handle multi-dimensional knots | ||
if self.knots.ndim > 1: | ||
# For multi-dimensional knots, we'll handle boundary detection in | ||
# the basis function | ||
self._boundary_interval_idx = None | ||
return | ||
|
||
# For 1D knots, find the rightmost non-degenerate interval | ||
for i in range(len(self.knots) - 2, -1, -1): | ||
if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found | ||
self._boundary_interval_idx = i | ||
return | ||
|
||
self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0 | ||
|
||
def basis(self, x, k, knots): | ||
""" | ||
Compute the basis functions for the spline using an iterative approach. | ||
This is a vectorized implementation based on the Cox-de Boor recursion. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we still using recursion? |
||
|
||
:param torch.Tensor x: The points to be evaluated. | ||
:param int k: The spline degree. | ||
:param int i: The index of the interval. | ||
:param torch.Tensor t: The tensor of knots. | ||
:param torch.Tensor knots: The tensor of knots. | ||
:return: The basis functions evaluated at x | ||
:rtype: torch.Tensor | ||
""" | ||
|
||
if k == 0: | ||
a = torch.where( | ||
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 | ||
if x.ndim == 1: | ||
x = x.unsqueeze(1) # (batch_size, 1) | ||
if x.ndim == 2: | ||
x = x.unsqueeze(2) # (batch_size, in_dim, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of checking dimensions, we could simply add |
||
|
||
if knots.ndim == 1: | ||
knots = knots.unsqueeze(0) # (1, n_knots) | ||
if knots.ndim == 2: | ||
knots = knots.unsqueeze(0) # (1, in_dim, n_knots) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for checks, just type |
||
|
||
# Base case: k=0 | ||
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) | ||
basis = basis.to(x.dtype) | ||
|
||
if self._boundary_interval_idx is not None: | ||
i = self._boundary_interval_idx | ||
tolerance = 1e-10 | ||
x_squeezed = x.squeeze(-1) | ||
knot_left = knots[..., i] | ||
knot_right = knots[..., i + 1] | ||
|
||
at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance | ||
in_rightmost_interval = ( | ||
x_squeezed >= knot_left | ||
) & at_right_boundary | ||
|
||
if torch.any(in_rightmost_interval): | ||
# For points at the boundary, ensure they're included in the | ||
# rightmost interval | ||
basis[..., i] = torch.logical_or( | ||
basis[..., i].bool(), in_rightmost_interval | ||
).to(basis.dtype) | ||
|
||
# Iterative step (Cox-de Boor recursion) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not seem we are actually using recursion |
||
for i in range(1, k + 1): | ||
# First term of the recursion | ||
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] | ||
denom1 = torch.where( | ||
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 | ||
) | ||
if i == len(t) - self.order - 1: | ||
a = torch.where(x == t[-1], 1.0, a) | ||
a.requires_grad_(True) | ||
return a | ||
numer1 = x - knots[..., : -(i + 1)] | ||
term1 = (numer1 / denom1) * basis[..., :-1] | ||
|
||
if t[i + k] == t[i]: | ||
c1 = torch.tensor([0.0] * len(x), requires_grad=True) | ||
else: | ||
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) | ||
denom2 = knots[..., i + 1 :] - knots[..., 1:-i] | ||
denom2 = torch.where( | ||
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 | ||
) | ||
numer2 = knots[..., i + 1 :] - x | ||
term2 = (numer2 / denom2) * basis[..., 1:] | ||
|
||
basis = term1 + term2 | ||
|
||
return basis | ||
|
||
def compute_control_points(self, x_eval, y_eval): | ||
""" | ||
Compute control points from given evaluations using least squares. | ||
This method fits the control points to match the target y_eval values. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing param description |
||
""" | ||
# (batch, in_dim) | ||
A = self.basis(x_eval, self.k, self.knots) | ||
# (batch, in_dim, n_basis) | ||
|
||
if t[i + k + 1] == t[i + 1]: | ||
c2 = torch.tensor([0.0] * len(x), requires_grad=True) | ||
in_dim = A.shape[1] | ||
out_dim = y_eval.shape[2] | ||
n_basis = A.shape[2] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variables |
||
c = torch.zeros(in_dim, out_dim, n_basis).to(A.device) | ||
|
||
for i in range(in_dim): | ||
# A_i is (batch, n_basis) | ||
# y_i is (batch, out_dim) | ||
A_i = A[:, i, :] | ||
y_i = y_eval[:, i, :] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using |
||
c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim) | ||
c[i, :, :] = c_i.T # (out_dim, n_basis) | ||
|
||
self.control_points = torch.nn.Parameter(c) | ||
|
||
def forward(self, x): | ||
""" | ||
Forward pass for the :class:`Spline` model. | ||
|
||
:param torch.Tensor x: The input tensor. | ||
:return: The output tensor. | ||
:rtype: torch.Tensor | ||
""" | ||
t = self.knots | ||
k = self.k | ||
c = self.control_points | ||
|
||
# Create the basis functions | ||
# B will have shape (batch, in_dim, n_basis) | ||
B = self.basis(x, k, t) | ||
|
||
# KAN case where control points are (in_dim, out_dim, n_basis) | ||
if c.ndim == 3: | ||
y_ij = torch.einsum( | ||
"bil,iol->bio", B, c | ||
) # (batch, in_dim, out_dim) | ||
# sum over input dimensions | ||
y = torch.sum(y_ij, dim=1) # (batch, out_dim) | ||
# Original test case | ||
else: | ||
c2 = ( | ||
(t[i + k + 1] - x) | ||
/ (t[i + k + 1] - t[i + 1]) | ||
* self.basis(x, k - 1, i + 1, t) | ||
) | ||
B = B.squeeze(1) # (batch, n_basis) | ||
if c.ndim == 1: | ||
y = torch.einsum("bi,i->b", B, c) | ||
else: | ||
y = torch.einsum("bi,ij->bj", B, c) | ||
|
||
return c1 + c2 | ||
return y | ||
|
||
@property | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please, move all properties and setters to the very end of the file. |
||
def control_points(self): | ||
|
@@ -131,9 +246,12 @@ def control_points(self, value): | |
dim = value.get("dim", 1) | ||
value = torch.zeros(n, dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we setting |
||
|
||
if not isinstance(value, torch.nn.Parameter): | ||
value = torch.nn.Parameter(value) | ||
|
||
if not isinstance(value, torch.Tensor): | ||
raise ValueError("Invalid value for control_points") | ||
self._control_points = torch.nn.Parameter(value, requires_grad=True) | ||
self._control_points = value | ||
|
||
@property | ||
def knots(self): | ||
|
@@ -181,19 +299,6 @@ def knots(self, value): | |
|
||
self._knots = value | ||
|
||
def forward(self, x): | ||
""" | ||
Forward pass for the :class:`Spline` model. | ||
|
||
:param torch.Tensor x: The input tensor. | ||
:return: The output tensor. | ||
:rtype: torch.Tensor | ||
""" | ||
t = self.knots | ||
k = self.k | ||
c = self.control_points | ||
|
||
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) | ||
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) | ||
|
||
return y | ||
# Recompute boundary interval when knots change | ||
if hasattr(self, "_boundary_interval_idx"): | ||
self._compute_boundary_interval() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, include a short description of the B-spline model