From 4dc00e8927916cc28a0dfeef14a92cdebad5362c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 23 Jul 2025 16:20:41 +0200 Subject: [PATCH 1/2] Add MultiShellKernel for multi-shell GP --- src/nifreeze/model/gpr.py | 73 +++++++++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 6 deletions(-) diff --git a/src/nifreeze/model/gpr.py b/src/nifreeze/model/gpr.py index 786545a1..397e28cc 100644 --- a/src/nifreeze/model/gpr.py +++ b/src/nifreeze/model/gpr.py @@ -33,8 +33,10 @@ from scipy.optimize import Bounds from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ( + RBF, Hyperparameter, Kernel, + KernelOperator, ) from sklearn.metrics.pairwise import cosine_similarity from sklearn.utils._param_validation import Interval, StrOptions @@ -94,8 +96,9 @@ class DiffusionGPR(GaussianProcessRegressor): selection through gradient-descent with analytical gradient calculations would not work (the derivative of the kernel w.r.t. ``alpha`` is zero). - This might have been overlooked in :footcite:p:`andersson_non-parametric_2015` or else they actually did - not use analytical gradient-descent: + This might have been overlooked in + :footcite:p:`andersson_non-parametric_2015` or else they actually did not + use analytical gradient-descent: *A note on optimisation* @@ -470,11 +473,67 @@ def __repr__(self) -> str: return f"SphericalKriging (a={self.beta_a}, λ={self.beta_l})" +class MultiShellKernel(KernelOperator): + """Composite kernel for multi-shell diffusion data.""" + + def __init__( + self, + orientation_kernel: Kernel | None = None, + radial_kernel: Kernel | None = None, + orientation_dims: Sequence[int] = (0, 1, 2), + bval_index: int = 3, + ) -> None: + super().__init__( + orientation_kernel if orientation_kernel is not None else SphericalKriging(), + radial_kernel if radial_kernel is not None else RBF(length_scale=1.0), + ) + self.orientation_dims = tuple(orientation_dims) + self.bval_index = bval_index + + def get_params(self, deep: bool = True): # noqa: D401 + params = super().get_params(deep=deep) + params.update({"orientation_dims": self.orientation_dims, "bval_index": self.bval_index}) + return params + + def _split(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + X = np.asarray(X) + orient = X[:, self.orientation_dims] + bvals = np.log(X[:, self.bval_index]).reshape(-1, 1) + return orient, bvals + + def __call__( + self, + X: np.ndarray, + Y: np.ndarray | None = None, + eval_gradient: bool = False, + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + X_o, X_b = self._split(X) + if Y is not None: + Y_o, Y_b = self._split(Y) + else: + Y_o = Y_b = None + + if eval_gradient: + K1, g1 = self.k1(X_o, Y_o, eval_gradient=True) + K2, g2 = self.k2(X_b, Y_b, eval_gradient=True) + return K1 * K2, np.dstack((g1 * K2[:, :, np.newaxis], g2 * K1[:, :, np.newaxis])) + + return self.k1(X_o, Y_o) * self.k2(X_b, Y_b) + + def diag(self, X: npt.ArrayLike) -> np.ndarray: + X_o, X_b = self._split(np.asarray(X)) + return self.k1.diag(X_o) * self.k2.diag(X_b) + + def __repr__(self) -> str: # pragma: no cover - simple representation + return f"MultiShellKernel({self.k1} * {self.k2})" + + def exponential_covariance(theta: np.ndarray, a: float) -> np.ndarray: r""" Compute the exponential covariance for given distances and scale parameter. - Implements :math:`C_{\theta}`, following Eq. (9) in :footcite:p:`andersson_non-parametric_2015`: + Implements :math:`C_{\theta}`, following Eq. (9) in + :footcite:p:`andersson_non-parametric_2015`: .. math:: \begin{equation} @@ -510,7 +569,8 @@ def spherical_covariance(theta: np.ndarray, a: float) -> np.ndarray: r""" Compute the spherical covariance for given distances and scale parameter. - Implements :math:`C_{\theta}`, following Eq. (10) in :footcite:p:`andersson_non-parametric_2015`: + Implements :math:`C_{\theta}`, following Eq. (10) in + :footcite:p:`andersson_non-parametric_2015`: .. math:: \begin{equation} @@ -554,8 +614,9 @@ def compute_pairwise_angles( ) -> np.ndarray: r"""Compute pairwise angles across diffusion gradient encoding directions. - Following :footcite:p:`andersson_non-parametric_2015`:, it computes the smallest of the angles between - each pair if ``closest_polarity`` is ``True``, i.e., + Following :footcite:p:`andersson_non-parametric_2015`, it computes the + smallest of the angles between each pair if ``closest_polarity`` is ``True``, + i.e., .. math:: From 94a6838333c82508584f3b8ed23485298f22cc93 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 23 Jul 2025 16:43:44 +0200 Subject: [PATCH 2/2] Fix MultiShellKernel type issues --- pyproject.toml | 5 +++++ src/nifreeze/model/gpr.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9f970f47..5fbde4c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,6 +162,11 @@ module = [ "joblib", "h5py", "ConfigSpace", + "scipy.*", + "sklearn.*", + "skimage.*", + "pandas", + "attrs", ] ignore_missing_imports = true diff --git a/src/nifreeze/model/gpr.py b/src/nifreeze/model/gpr.py index 397e28cc..ea64d035 100644 --- a/src/nifreeze/model/gpr.py +++ b/src/nifreeze/model/gpr.py @@ -249,7 +249,7 @@ def _constrained_optimization( options=options, args=(self.eval_gradient,), tol=self.tol, - ) + ) # type: ignore[call-overload] return opt_res.x, opt_res.fun if callable(self.optimizer): @@ -476,6 +476,9 @@ def __repr__(self) -> str: class MultiShellKernel(KernelOperator): """Composite kernel for multi-shell diffusion data.""" + k1: Kernel + k2: Kernel + def __init__( self, orientation_kernel: Kernel | None = None, @@ -508,10 +511,13 @@ def __call__( eval_gradient: bool = False, ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: X_o, X_b = self._split(X) + Y_o: np.ndarray | None + Y_b: np.ndarray | None if Y is not None: Y_o, Y_b = self._split(Y) else: - Y_o = Y_b = None + Y_o = None + Y_b = None if eval_gradient: K1, g1 = self.k1(X_o, Y_o, eval_gradient=True)