Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 149 additions & 44 deletions pina/model/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ class Spline(torch.nn.Module):
Spline model class.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, include a short description of the B-spline model

"""

def __init__(self, order=4, knots=None, control_points=None) -> None:
def __init__(
self, order=4, knots=None, control_points=None, grid_extension=True
):
"""
Initialization of the :class:`Spline` class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing doc string for grid_extension

Expand All @@ -20,7 +22,7 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
``None``.
:raises ValueError: If the order is negative.
:raises ValueError: If both knots and control points are ``None``.
:raises ValueError: If the knot tensor is not one-dimensional.
:raises ValueError: If the knot tensor is not one or two dimensional.
"""
super().__init__()

Expand All @@ -33,6 +35,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:

self.order = order
self.k = order - 1
self.grid_extension = grid_extension

# Cache for performance optimization
self._boundary_interval_idx = None

if knots is not None and control_points is not None:
self.knots = knots
Expand Down Expand Up @@ -65,45 +71,154 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
else:
raise ValueError("Knots and control points cannot be both None.")

if self.knots.ndim != 1:
raise ValueError("Knot vector must be one-dimensional.")
if self.knots.ndim > 2:
raise ValueError("Knot vector must be one or two-dimensional.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message is not consistent with the documentation. Please, update


# Precompute boundary interval index for performance
self._compute_boundary_interval()

def basis(self, x, k, i, t):
def _compute_boundary_interval(self):
"""
Precompute the rightmost non-degenerate interval index for performance.
This avoids the search loop in the basis function on every call.
"""
Recursive method to compute the basis functions of the spline.
# Handle multi-dimensional knots
if self.knots.ndim > 1:
# For multi-dimensional knots, we'll handle boundary detection in
# the basis function
self._boundary_interval_idx = None
return

# For 1D knots, find the rightmost non-degenerate interval
for i in range(len(self.knots) - 2, -1, -1):
if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found
self._boundary_interval_idx = i
return

self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0

def basis(self, x, k, knots):
"""
Compute the basis functions for the spline using an iterative approach.
This is a vectorized implementation based on the Cox-de Boor recursion.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we still using recursion?


:param torch.Tensor x: The points to be evaluated.
:param int k: The spline degree.
:param int i: The index of the interval.
:param torch.Tensor t: The tensor of knots.
:param torch.Tensor knots: The tensor of knots.
:return: The basis functions evaluated at x
:rtype: torch.Tensor
"""

if k == 0:
a = torch.where(
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
if x.ndim == 1:
x = x.unsqueeze(1) # (batch_size, 1)
if x.ndim == 2:
x = x.unsqueeze(2) # (batch_size, in_dim, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking dimensions, we could simply add x = x.unsqueeze(-1) to add a trailing dimension.


if knots.ndim == 1:
knots = knots.unsqueeze(0) # (1, n_knots)
if knots.ndim == 2:
knots = knots.unsqueeze(0) # (1, in_dim, n_knots)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for checks, just type knots = knots.unsqueeze(0)


# Base case: k=0
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
basis = basis.to(x.dtype)

if self._boundary_interval_idx is not None:
i = self._boundary_interval_idx
tolerance = 1e-10
x_squeezed = x.squeeze(-1)
knot_left = knots[..., i]
knot_right = knots[..., i + 1]

at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance
in_rightmost_interval = (
x_squeezed >= knot_left
) & at_right_boundary

if torch.any(in_rightmost_interval):
# For points at the boundary, ensure they're included in the
# rightmost interval
basis[..., i] = torch.logical_or(
basis[..., i].bool(), in_rightmost_interval
).to(basis.dtype)

# Iterative step (Cox-de Boor recursion)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not seem we are actually using recursion

for i in range(1, k + 1):
# First term of the recursion
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
denom1 = torch.where(
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1
)
if i == len(t) - self.order - 1:
a = torch.where(x == t[-1], 1.0, a)
a.requires_grad_(True)
return a
numer1 = x - knots[..., : -(i + 1)]
term1 = (numer1 / denom1) * basis[..., :-1]

if t[i + k] == t[i]:
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
else:
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
denom2 = torch.where(
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2
)
numer2 = knots[..., i + 1 :] - x
term2 = (numer2 / denom2) * basis[..., 1:]

basis = term1 + term2

return basis

def compute_control_points(self, x_eval, y_eval):
"""
Compute control points from given evaluations using least squares.
This method fits the control points to match the target y_eval values.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing param description

"""
# (batch, in_dim)
A = self.basis(x_eval, self.k, self.knots)
# (batch, in_dim, n_basis)

if t[i + k + 1] == t[i + 1]:
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
in_dim = A.shape[1]
out_dim = y_eval.shape[2]
n_basis = A.shape[2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variables in_dim, out_dim, n_basis are used just once. Consider using the shapes directly where needed

c = torch.zeros(in_dim, out_dim, n_basis).to(A.device)

for i in range(in_dim):
# A_i is (batch, n_basis)
# y_i is (batch, out_dim)
A_i = A[:, i, :]
y_i = y_eval[:, i, :]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using A[:, i, :] and y_eval[:, i, :] where needed, without saving additional variables

c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim)
c[i, :, :] = c_i.T # (out_dim, n_basis)

self.control_points = torch.nn.Parameter(c)

def forward(self, x):
"""
Forward pass for the :class:`Spline` model.

:param torch.Tensor x: The input tensor.
:return: The output tensor.
:rtype: torch.Tensor
"""
t = self.knots
k = self.k
c = self.control_points

# Create the basis functions
# B will have shape (batch, in_dim, n_basis)
B = self.basis(x, k, t)

# KAN case where control points are (in_dim, out_dim, n_basis)
if c.ndim == 3:
y_ij = torch.einsum(
"bil,iol->bio", B, c
) # (batch, in_dim, out_dim)
# sum over input dimensions
y = torch.sum(y_ij, dim=1) # (batch, out_dim)
# Original test case
else:
c2 = (
(t[i + k + 1] - x)
/ (t[i + k + 1] - t[i + 1])
* self.basis(x, k - 1, i + 1, t)
)
B = B.squeeze(1) # (batch, n_basis)
if c.ndim == 1:
y = torch.einsum("bi,i->b", B, c)
else:
y = torch.einsum("bi,ij->bj", B, c)

return c1 + c2
return y

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, move all properties and setters to the very end of the file.

def control_points(self):
Expand Down Expand Up @@ -131,9 +246,12 @@ def control_points(self, value):
dim = value.get("dim", 1)
value = torch.zeros(n, dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we setting value to zero instead of using the argument passed to the setter?


if not isinstance(value, torch.nn.Parameter):
value = torch.nn.Parameter(value)

if not isinstance(value, torch.Tensor):
raise ValueError("Invalid value for control_points")
self._control_points = torch.nn.Parameter(value, requires_grad=True)
self._control_points = value

@property
def knots(self):
Expand Down Expand Up @@ -181,19 +299,6 @@ def knots(self, value):

self._knots = value

def forward(self, x):
"""
Forward pass for the :class:`Spline` model.

:param torch.Tensor x: The input tensor.
:return: The output tensor.
:rtype: torch.Tensor
"""
t = self.knots
k = self.k
c = self.control_points

basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)

return y
# Recompute boundary interval when knots change
if hasattr(self, "_boundary_interval_idx"):
self._compute_boundary_interval()