Skip to content

Commit 8cff86c

Browse files
committed
enh: enable B-Splines X5 i/o and DenseFields' is_deltas
1 parent aecf69f commit 8cff86c

File tree

3 files changed

+78
-42
lines changed

3 files changed

+78
-42
lines changed

nitransforms/io/x5.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ class X5Transform:
7373
For parametric models it is generally possible to obtain it analytically, so this dataset
7474
could not be as useful in that case.
7575
"""
76-
# additional_parameters: Optional[np.ndarray] = None
77-
# AdditionalParameters is empty in the draft spec - ignore for now.
78-
# Only documentation ATM is for SubType:
79-
# The SubType setting enables setting the additional parameters on a dataset called
80-
# "AdditionalParameters" that hangs directly from this transform node.
76+
additional_parameters: Optional[np.ndarray] = None
77+
"""
78+
An OPTIONAL field to store additional parameters, depending on the SubType of the
79+
transform.
80+
"""
8181
array_length: int = 1
8282
"""Undocumented field in the draft to enable a single transform group for 4D transforms."""
8383

@@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
130130
g.create_dataset("Inverse", data=node.inverse)
131131
if node.jacobian is not None:
132132
g.create_dataset("Jacobian", data=node.jacobian)
133-
# Disabled until we need SubType and AdditionalParameters
134-
# if node.additional_parameters is not None:
135-
# g.create_dataset(
136-
# "AdditionalParameters", data=node.additional_parameters
137-
# )
133+
if node.additional_parameters is not None:
134+
g.create_dataset(
135+
"AdditionalParameters", data=node.additional_parameters
136+
)
138137
return fname
139138

140139

@@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform:
174173
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
175174
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
176175
array_length=int(node.attrs.get("ArrayLength", 1)),
176+
additional_parameters=np.asarray(node["AdditionalParameters"])
177+
if "AdditionalParameters" in node
178+
else None,
177179
)
178180

179181
if "Domain" in node:

nitransforms/nonlinear.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from functools import partial
1313
from collections import namedtuple
1414
import numpy as np
15+
import nibabel as nb
1516

1617
from nitransforms import io
1718
from nitransforms.io.base import _ensure_image
19+
from nitransforms.io.x5 import from_filename as load_x5
1820
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
1921
from nitransforms.base import (
2022
TransformBase,
@@ -34,7 +36,7 @@
3436
class DenseFieldTransform(TransformBase):
3537
"""Represents dense field (voxel-wise) transforms."""
3638

37-
__slots__ = ("_field", "_deltas")
39+
__slots__ = ("_field", "_deltas", "_is_deltas")
3840

3941
def __init__(self, field=None, is_deltas=True, reference=None):
4042
"""
@@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6870

6971
super().__init__()
7072

71-
if field is not None:
72-
field = _ensure_image(field)
73-
self._field = np.squeeze(
74-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
75-
)
76-
else:
77-
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
78-
is_deltas = True
73+
self._is_deltas = is_deltas
7974

8075
try:
8176
self.reference = ImageGrid(reference if reference is not None else field)
@@ -86,24 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
8681
else "Reference is not a spatial image"
8782
)
8883

84+
fieldshape = (*self.reference.shape, self.reference.ndim)
85+
if field is not None:
86+
field = _ensure_image(field)
87+
self._field = np.squeeze(
88+
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
89+
)
90+
if fieldshape != self._field.shape:
91+
raise TransformError(
92+
f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) "
93+
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
94+
)
95+
else:
96+
self._field = np.zeros(fieldshape, dtype="float32")
97+
self._is_deltas = True
98+
8999
if self._field.shape[-1] != self.ndim:
90100
raise TransformError(
91101
"The number of components of the field (%d) does not match "
92102
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
93103
)
94104

95-
if is_deltas:
105+
if self.is_deltas:
96106
self._deltas = (
97107
self._field.copy()
98108
) # IMPORTANT: you don't want to update deltas
99109
# Convert from displacements (deltas) to deformations fields
100110
# (just add its origin to each delta vector)
101-
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
111+
self._field += self.reference.ndcoords.T.reshape(fieldshape)
102112

103113
def __repr__(self):
104114
"""Beautify the python representation."""
105115
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
106116

117+
@property
118+
def is_deltas(self):
119+
"""Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+
return self._is_deltas
121+
107122
@property
108123
def ndim(self):
109124
"""Get the dimensions of the transform."""
@@ -232,7 +247,7 @@ def __eq__(self, other):
232247
True
233248
234249
"""
235-
_eq = np.array_equal(self._field, other._field)
250+
_eq = np.allclose(self._field, other._field)
236251
if _eq and self._reference != other._reference:
237252
warnings.warn("Fields are equal, but references do not match.")
238253
return _eq
@@ -255,9 +270,9 @@ def to_x5(self, metadata=None):
255270
return io.x5.X5Transform(
256271
type="nonlinear",
257272
subtype="densefield",
258-
representation="displacements",
273+
representation="displacements" if self.is_deltas else "deformations",
259274
metadata=metadata,
260-
transform=self._deltas,
275+
transform=self._deltas if self.is_deltas else self._field,
261276
dimension_kinds=kinds,
262277
domain=domain,
263278
)
@@ -275,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"):
275290
raise NotImplementedError(f"Unsupported format <{fmt}>")
276291

277292
if fmt == "X5":
278-
from .io.x5 import from_filename as load_x5
279-
280293
x5_xfm = load_x5(filename)[0]
281294
Domain = namedtuple("Domain", "affine shape")
282295
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
283-
return cls(x5_xfm.transform, is_deltas=True, reference=reference)
296+
field = nb.Nifti1Image(x5_xfm.transform, reference.affine)
297+
return cls(
298+
field,
299+
is_deltas=x5_xfm.representation == "displacements",
300+
reference=reference,
301+
)
284302

285303
return cls(_factory[fmt.lower()].from_filename(filename))
286304

@@ -317,6 +335,24 @@ def ndim(self):
317335
"""Get the dimensions of the transform."""
318336
return self._coeffs.ndim - 1
319337

338+
@classmethod
339+
def from_filename(cls, filename, fmt="X5"):
340+
_factory = {
341+
"X5": None,
342+
}
343+
fmt = fmt.upper()
344+
if fmt not in {k.upper() for k in _factory}:
345+
raise NotImplementedError(f"Unsupported format <{fmt}>")
346+
347+
x5_xfm = load_x5(filename)[0]
348+
Domain = namedtuple("Domain", "affine shape")
349+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
350+
351+
coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
352+
return cls(coefficients, reference=reference)
353+
354+
# return cls(_factory[fmt.lower()].from_filename(filename))
355+
320356
def to_field(self, reference=None, dtype="float32"):
321357
"""Generate a displacements deformation field from this B-Spline field."""
322358
_ref = (
@@ -351,21 +387,17 @@ def to_x5(self, metadata=None):
351387
coordinates="cartesian",
352388
)
353389

354-
meta = metadata | {
355-
"KnotsAffine": self._knots.affine.tolist(),
356-
"KnotsShape": self._knots.shape,
357-
}
358-
359390
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
360391

361392
return io.x5.X5Transform(
362393
type="nonlinear",
363394
subtype="bspline",
364395
representation="coefficients",
365-
metadata=meta,
396+
metadata=metadata,
366397
transform=self._coeffs,
367398
dimension_kinds=kinds,
368399
domain=domain,
400+
additional_parameters=self._knots.affine,
369401
)
370402

371403
def map(self, x, inverse=False):

nitransforms/tests/test_nonlinear.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,29 +122,29 @@ def test_bspline(tmp_path, testdata_path):
122122
)
123123

124124

125-
def test_densefield_x5_roundtrip(tmp_path):
125+
@pytest.mark.parametrize("is_deltas", [True, False])
126+
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
126127
"""Ensure dense field transforms roundtrip via X5."""
127128
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
128129
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
129130

130-
xfm = DenseFieldTransform(disp, reference=ref)
131+
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
131132

132133
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
133134
assert node.type == "nonlinear"
134135
assert node.subtype == "densefield"
135-
assert node.representation == "displacements"
136+
assert node.representation == "displacements" if is_deltas else "deformations"
136137
assert node.domain.size == ref.shape
137138
assert node.metadata["GeneratedBy"] == "pytest"
138139

139140
fname = tmp_path / "test.x5"
140141
io.x5.to_filename(fname, [node])
141142

142143
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
143-
diff = xfm2._deltas - xfm._deltas
144-
coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape)
145-
assert np.allclose(diff, coords)
144+
146145
assert xfm2.reference.shape == ref.shape
147146
assert np.allclose(xfm2.reference.affine, ref.affine)
147+
assert xfm == xfm2
148148

149149

150150
def test_bspline_to_x5(tmp_path):
@@ -161,6 +161,8 @@ def test_bspline_to_x5(tmp_path):
161161

162162
fname = tmp_path / "bspline.x5"
163163
io.x5.to_filename(fname, [node])
164-
node2 = io.x5.from_filename(fname)[0]
165-
assert np.allclose(node2.transform, node.transform)
166-
assert node2.metadata["tool"] == "pytest"
164+
165+
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
166+
assert np.allclose(xfm._coeffs, xfm2._coeffs)
167+
assert xfm2.reference.shape == ref.shape
168+
assert np.allclose(xfm2.reference.affine, ref.affine)

0 commit comments

Comments
 (0)