From 251ec38c71c16485a675bdeeff2bb89779cb53e1 Mon Sep 17 00:00:00 2001 From: ajacoby9 Date: Thu, 24 Jul 2025 12:19:08 +0200 Subject: [PATCH 1/3] Improve spline --- pina/model/spline.py | 165 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 134 insertions(+), 31 deletions(-) diff --git a/pina/model/spline.py b/pina/model/spline.py index c22c7937c..bc854ec56 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -9,7 +9,7 @@ class Spline(torch.nn.Module): Spline model class. """ - def __init__(self, order=4, knots=None, control_points=None) -> None: + def __init__(self, order=4, knots=None, control_points=None, grid_extension=True) -> None: """ Initialization of the :class:`Spline` class. @@ -33,6 +33,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: self.order = order self.k = order - 1 + self.grid_extension = grid_extension + + # Cache for performance optimization + self._boundary_interval_idx = None if knots is not None and control_points is not None: self.knots = knots @@ -65,45 +69,123 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: else: raise ValueError("Knots and control points cannot be both None.") - if self.knots.ndim != 1: - raise ValueError("Knot vector must be one-dimensional.") + if self.knots.ndim > 2: + raise ValueError("Knot vector must be one or two-dimensional.") + + # Precompute boundary interval index for performance + self._compute_boundary_interval() - def basis(self, x, k, i, t): + def _compute_boundary_interval(self): """ - Recursive method to compute the basis functions of the spline. + Precompute the rightmost non-degenerate interval index for performance. + This avoids the search loop in the basis function on every call. + """ + if not isinstance(self.knots, torch.Tensor): + self._boundary_interval_idx = None + return + + # Find the rightmost interval with positive width + knots = self.knots + + # Handle multi-dimensional knots + if knots.ndim > 1: + # For multi-dimensional knots, we'll handle boundary detection in the basis function + self._boundary_interval_idx = None + return + + # For 1D knots, find the rightmost non-degenerate interval + for i in range(len(knots) - 2, -1, -1): + if knots[i] < knots[i + 1]: # Non-degenerate interval found + self._boundary_interval_idx = i + return + + self._boundary_interval_idx = len(knots) - 2 if len(knots) > 1 else 0 + + def basis(self, x, k, knots): + """ + Compute the basis functions for the spline using an iterative approach. + This is a vectorized implementation based on the Cox-de Boor recursion. :param torch.Tensor x: The points to be evaluated. :param int k: The spline degree. - :param int i: The index of the interval. - :param torch.Tensor t: The tensor of knots. + :param torch.Tensor knots: The tensor of knots. :return: The basis functions evaluated at x :rtype: torch.Tensor """ - if k == 0: - a = torch.where( - torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 + if x.ndim == 1: + x = x.unsqueeze(1) # (batch_size, 1) + if x.ndim == 2: + x = x.unsqueeze(2) # (batch_size, in_dim, 1) + + if knots.ndim == 1: + knots = knots.unsqueeze(0) # (1, n_knots) + if knots.ndim == 2: + knots = knots.unsqueeze(0) # (1, in_dim, n_knots) + + # Base case: k=0 + basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) + basis = basis.to(x.dtype) + + + if self._boundary_interval_idx is not None: + i = self._boundary_interval_idx + tolerance = 1e-10 + x_squeezed = x.squeeze(-1) + knot_left = knots[..., i] + knot_right = knots[..., i + 1] + + at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance + in_rightmost_interval = (x_squeezed >= knot_left) & at_right_boundary + + if torch.any(in_rightmost_interval): + # For points at the boundary, ensure they're included in the rightmost interval + basis[..., i] = torch.logical_or(basis[..., i].bool(), in_rightmost_interval).to(basis.dtype) + + # Iterative step (Cox-de Boor recursion) + for i in range(1, k + 1): + # First term of the recursion + denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] + denom1 = torch.where( + torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 ) - if i == len(t) - self.order - 1: - a = torch.where(x == t[-1], 1.0, a) - a.requires_grad_(True) - return a - - if t[i + k] == t[i]: - c1 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) + numer1 = x - knots[..., : -(i + 1)] + term1 = (numer1 / denom1) * basis[..., :-1] - if t[i + k + 1] == t[i + 1]: - c2 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c2 = ( - (t[i + k + 1] - x) - / (t[i + k + 1] - t[i + 1]) - * self.basis(x, k - 1, i + 1, t) + denom2 = knots[..., i + 1 :] - knots[..., 1:-i] + denom2 = torch.where( + torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 ) + numer2 = knots[..., i + 1 :] - x + term2 = (numer2 / denom2) * basis[..., 1:] + + basis = term1 + term2 + + return basis - return c1 + c2 + def compute_control_points(self, x_eval, y_eval): + """ + Compute control points from given evaluations using least squares. + This method fits the control points to match the target y_eval values. + """ + # (batch, in_dim) + A = self.basis(x_eval, self.k, self.knots) + # (batch, in_dim, n_basis) + + in_dim = A.shape[1] + out_dim = y_eval.shape[2] + n_basis = A.shape[2] + c = torch.zeros(in_dim, out_dim, n_basis).to(A.device) + + for i in range(in_dim): + # A_i is (batch, n_basis) + # y_i is (batch, out_dim) + A_i = A[:, i, :] + y_i = y_eval[:, i, :] + c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim) + c[i, :, :] = c_i.T # (out_dim, n_basis) + + self.control_points = torch.nn.Parameter(c) @property def control_points(self): @@ -131,9 +213,12 @@ def control_points(self, value): dim = value.get("dim", 1) value = torch.zeros(n, dim) + if not isinstance(value, torch.nn.Parameter): + value = torch.nn.Parameter(value) + if not isinstance(value, torch.Tensor): raise ValueError("Invalid value for control_points") - self._control_points = torch.nn.Parameter(value, requires_grad=True) + self._control_points = value @property def knots(self): @@ -180,6 +265,10 @@ def knots(self, value): raise ValueError("Invalid value for knots") self._knots = value + + # Recompute boundary interval when knots change + if hasattr(self, '_boundary_interval_idx'): + self._compute_boundary_interval() def forward(self, x): """ @@ -193,7 +282,21 @@ def forward(self, x): k = self.k c = self.control_points - basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) - y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) + # Create the basis functions + # B will have shape (batch, in_dim, n_basis) + B = self.basis(x, k, t) + + # KAN case where control points are (in_dim, out_dim, n_basis) + if c.ndim == 3: + y_ij = torch.einsum("bil,iol->bio", B, c) # (batch, in_dim, out_dim) + # sum over input dimensions + y = torch.sum(y_ij, dim=1) # (batch, out_dim) + # Original test case + else: + B = B.squeeze(1) # (batch, n_basis) + if c.ndim == 1: + y = torch.einsum("bi,i->b", B, c) + else: + y = torch.einsum("bi,ij->bj", B, c) - return y + return y \ No newline at end of file From 774542ffd84d02b3821a932142600d8960d63f18 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 4 Aug 2025 12:18:20 +0200 Subject: [PATCH 2/3] refact --- pina/model/spline.py | 49 ++++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/pina/model/spline.py b/pina/model/spline.py index bc854ec56..3f7eb99cb 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -9,7 +9,9 @@ class Spline(torch.nn.Module): Spline model class. """ - def __init__(self, order=4, knots=None, control_points=None, grid_extension=True) -> None: + def __init__( + self, order=4, knots=None, control_points=None, grid_extension=True + ): """ Initialization of the :class:`Spline` class. @@ -34,7 +36,7 @@ def __init__(self, order=4, knots=None, control_points=None, grid_extension=True self.order = order self.k = order - 1 self.grid_extension = grid_extension - + # Cache for performance optimization self._boundary_interval_idx = None @@ -71,7 +73,7 @@ def __init__(self, order=4, knots=None, control_points=None, grid_extension=True if self.knots.ndim > 2: raise ValueError("Knot vector must be one or two-dimensional.") - + # Precompute boundary interval index for performance self._compute_boundary_interval() @@ -83,22 +85,23 @@ def _compute_boundary_interval(self): if not isinstance(self.knots, torch.Tensor): self._boundary_interval_idx = None return - + # Find the rightmost interval with positive width knots = self.knots - + # Handle multi-dimensional knots if knots.ndim > 1: - # For multi-dimensional knots, we'll handle boundary detection in the basis function + # For multi-dimensional knots, we'll handle boundary detection in + # the basis function self._boundary_interval_idx = None return - + # For 1D knots, find the rightmost non-degenerate interval for i in range(len(knots) - 2, -1, -1): if knots[i] < knots[i + 1]: # Non-degenerate interval found self._boundary_interval_idx = i return - + self._boundary_interval_idx = len(knots) - 2 if len(knots) > 1 else 0 def basis(self, x, k, knots): @@ -126,21 +129,25 @@ def basis(self, x, k, knots): # Base case: k=0 basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) basis = basis.to(x.dtype) - if self._boundary_interval_idx is not None: i = self._boundary_interval_idx - tolerance = 1e-10 + tolerance = 1e-10 x_squeezed = x.squeeze(-1) knot_left = knots[..., i] knot_right = knots[..., i + 1] - + at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance - in_rightmost_interval = (x_squeezed >= knot_left) & at_right_boundary - + in_rightmost_interval = ( + x_squeezed >= knot_left + ) & at_right_boundary + if torch.any(in_rightmost_interval): - # For points at the boundary, ensure they're included in the rightmost interval - basis[..., i] = torch.logical_or(basis[..., i].bool(), in_rightmost_interval).to(basis.dtype) + # For points at the boundary, ensure they're included in the + # rightmost interval + basis[..., i] = torch.logical_or( + basis[..., i].bool(), in_rightmost_interval + ).to(basis.dtype) # Iterative step (Cox-de Boor recursion) for i in range(1, k + 1): @@ -215,7 +222,7 @@ def control_points(self, value): if not isinstance(value, torch.nn.Parameter): value = torch.nn.Parameter(value) - + if not isinstance(value, torch.Tensor): raise ValueError("Invalid value for control_points") self._control_points = value @@ -265,9 +272,9 @@ def knots(self, value): raise ValueError("Invalid value for knots") self._knots = value - + # Recompute boundary interval when knots change - if hasattr(self, '_boundary_interval_idx'): + if hasattr(self, "_boundary_interval_idx"): self._compute_boundary_interval() def forward(self, x): @@ -288,7 +295,9 @@ def forward(self, x): # KAN case where control points are (in_dim, out_dim, n_basis) if c.ndim == 3: - y_ij = torch.einsum("bil,iol->bio", B, c) # (batch, in_dim, out_dim) + y_ij = torch.einsum( + "bil,iol->bio", B, c + ) # (batch, in_dim, out_dim) # sum over input dimensions y = torch.sum(y_ij, dim=1) # (batch, out_dim) # Original test case @@ -299,4 +308,4 @@ def forward(self, x): else: y = torch.einsum("bi,ij->bj", B, c) - return y \ No newline at end of file + return y From 4b97efc767955ae544433aa2f3f01445be160063 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 4 Sep 2025 12:58:22 +0200 Subject: [PATCH 3/3] start fixing --- pina/model/spline.py | 85 ++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/pina/model/spline.py b/pina/model/spline.py index 3f7eb99cb..c6f3c5579 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -22,7 +22,7 @@ def __init__( ``None``. :raises ValueError: If the order is negative. :raises ValueError: If both knots and control points are ``None``. - :raises ValueError: If the knot tensor is not one-dimensional. + :raises ValueError: If the knot tensor is not one or two dimensional. """ super().__init__() @@ -82,27 +82,20 @@ def _compute_boundary_interval(self): Precompute the rightmost non-degenerate interval index for performance. This avoids the search loop in the basis function on every call. """ - if not isinstance(self.knots, torch.Tensor): - self._boundary_interval_idx = None - return - - # Find the rightmost interval with positive width - knots = self.knots - # Handle multi-dimensional knots - if knots.ndim > 1: + if self.knots.ndim > 1: # For multi-dimensional knots, we'll handle boundary detection in # the basis function self._boundary_interval_idx = None return # For 1D knots, find the rightmost non-degenerate interval - for i in range(len(knots) - 2, -1, -1): - if knots[i] < knots[i + 1]: # Non-degenerate interval found + for i in range(len(self.knots) - 2, -1, -1): + if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found self._boundary_interval_idx = i return - self._boundary_interval_idx = len(knots) - 2 if len(knots) > 1 else 0 + self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0 def basis(self, x, k, knots): """ @@ -194,6 +187,39 @@ def compute_control_points(self, x_eval, y_eval): self.control_points = torch.nn.Parameter(c) + def forward(self, x): + """ + Forward pass for the :class:`Spline` model. + + :param torch.Tensor x: The input tensor. + :return: The output tensor. + :rtype: torch.Tensor + """ + t = self.knots + k = self.k + c = self.control_points + + # Create the basis functions + # B will have shape (batch, in_dim, n_basis) + B = self.basis(x, k, t) + + # KAN case where control points are (in_dim, out_dim, n_basis) + if c.ndim == 3: + y_ij = torch.einsum( + "bil,iol->bio", B, c + ) # (batch, in_dim, out_dim) + # sum over input dimensions + y = torch.sum(y_ij, dim=1) # (batch, out_dim) + # Original test case + else: + B = B.squeeze(1) # (batch, n_basis) + if c.ndim == 1: + y = torch.einsum("bi,i->b", B, c) + else: + y = torch.einsum("bi,ij->bj", B, c) + + return y + @property def control_points(self): """ @@ -275,37 +301,4 @@ def knots(self, value): # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): - self._compute_boundary_interval() - - def forward(self, x): - """ - Forward pass for the :class:`Spline` model. - - :param torch.Tensor x: The input tensor. - :return: The output tensor. - :rtype: torch.Tensor - """ - t = self.knots - k = self.k - c = self.control_points - - # Create the basis functions - # B will have shape (batch, in_dim, n_basis) - B = self.basis(x, k, t) - - # KAN case where control points are (in_dim, out_dim, n_basis) - if c.ndim == 3: - y_ij = torch.einsum( - "bil,iol->bio", B, c - ) # (batch, in_dim, out_dim) - # sum over input dimensions - y = torch.sum(y_ij, dim=1) # (batch, out_dim) - # Original test case - else: - B = B.squeeze(1) # (batch, n_basis) - if c.ndim == 1: - y = torch.einsum("bi,i->b", B, c) - else: - y = torch.einsum("bi,ij->bj", B, c) - - return y + self._compute_boundary_interval() \ No newline at end of file