Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ module = [
"joblib",
"h5py",
"ConfigSpace",
"scipy.*",
"sklearn.*",
"skimage.*",
"pandas",
"attrs",
]
ignore_missing_imports = true

Expand Down
81 changes: 74 additions & 7 deletions src/nifreeze/model/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
from scipy.optimize import Bounds
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
RBF,
Hyperparameter,
Kernel,
KernelOperator,
)
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.utils._param_validation import Interval, StrOptions
Expand Down Expand Up @@ -94,8 +96,9 @@ class DiffusionGPR(GaussianProcessRegressor):
selection through gradient-descent with analytical gradient calculations
would not work (the derivative of the kernel w.r.t. ``alpha`` is zero).

This might have been overlooked in :footcite:p:`andersson_non-parametric_2015` or else they actually did
not use analytical gradient-descent:
This might have been overlooked in
:footcite:p:`andersson_non-parametric_2015` or else they actually did not
use analytical gradient-descent:

*A note on optimisation*

Expand Down Expand Up @@ -246,7 +249,7 @@ def _constrained_optimization(
options=options,
args=(self.eval_gradient,),
tol=self.tol,
)
) # type: ignore[call-overload]
return opt_res.x, opt_res.fun

if callable(self.optimizer):
Expand Down Expand Up @@ -470,11 +473,73 @@ def __repr__(self) -> str:
return f"SphericalKriging (a={self.beta_a}, λ={self.beta_l})"


class MultiShellKernel(KernelOperator):
"""Composite kernel for multi-shell diffusion data."""

k1: Kernel
k2: Kernel

def __init__(
self,
orientation_kernel: Kernel | None = None,
radial_kernel: Kernel | None = None,
orientation_dims: Sequence[int] = (0, 1, 2),
bval_index: int = 3,
) -> None:
super().__init__(
orientation_kernel if orientation_kernel is not None else SphericalKriging(),
radial_kernel if radial_kernel is not None else RBF(length_scale=1.0),
)
self.orientation_dims = tuple(orientation_dims)
self.bval_index = bval_index

def get_params(self, deep: bool = True): # noqa: D401
params = super().get_params(deep=deep)
params.update({"orientation_dims": self.orientation_dims, "bval_index": self.bval_index})
return params

def _split(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
X = np.asarray(X)
orient = X[:, self.orientation_dims]
bvals = np.log(X[:, self.bval_index]).reshape(-1, 1)
return orient, bvals

def __call__(
self,
X: np.ndarray,
Y: np.ndarray | None = None,
eval_gradient: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
X_o, X_b = self._split(X)
Y_o: np.ndarray | None
Y_b: np.ndarray | None
if Y is not None:
Y_o, Y_b = self._split(Y)
else:
Y_o = None
Y_b = None

if eval_gradient:
K1, g1 = self.k1(X_o, Y_o, eval_gradient=True)
K2, g2 = self.k2(X_b, Y_b, eval_gradient=True)
return K1 * K2, np.dstack((g1 * K2[:, :, np.newaxis], g2 * K1[:, :, np.newaxis]))

return self.k1(X_o, Y_o) * self.k2(X_b, Y_b)

def diag(self, X: npt.ArrayLike) -> np.ndarray:
X_o, X_b = self._split(np.asarray(X))
return self.k1.diag(X_o) * self.k2.diag(X_b)

def __repr__(self) -> str: # pragma: no cover - simple representation
return f"MultiShellKernel({self.k1} * {self.k2})"


def exponential_covariance(theta: np.ndarray, a: float) -> np.ndarray:
r"""
Compute the exponential covariance for given distances and scale parameter.

Implements :math:`C_{\theta}`, following Eq. (9) in :footcite:p:`andersson_non-parametric_2015`:
Implements :math:`C_{\theta}`, following Eq. (9) in
:footcite:p:`andersson_non-parametric_2015`:

.. math::
\begin{equation}
Expand Down Expand Up @@ -510,7 +575,8 @@ def spherical_covariance(theta: np.ndarray, a: float) -> np.ndarray:
r"""
Compute the spherical covariance for given distances and scale parameter.

Implements :math:`C_{\theta}`, following Eq. (10) in :footcite:p:`andersson_non-parametric_2015`:
Implements :math:`C_{\theta}`, following Eq. (10) in
:footcite:p:`andersson_non-parametric_2015`:

.. math::
\begin{equation}
Expand Down Expand Up @@ -554,8 +620,9 @@ def compute_pairwise_angles(
) -> np.ndarray:
r"""Compute pairwise angles across diffusion gradient encoding directions.

Following :footcite:p:`andersson_non-parametric_2015`:, it computes the smallest of the angles between
each pair if ``closest_polarity`` is ``True``, i.e.,
Following :footcite:p:`andersson_non-parametric_2015`, it computes the
smallest of the angles between each pair if ``closest_polarity`` is ``True``,
i.e.,

.. math::

Expand Down