Skip to content

Commit cc228aa

Browse files
committed
fix: write tests to avoid regressions
Resolves: #267.
1 parent c3441e1 commit cc228aa

File tree

3 files changed

+123
-87
lines changed

3 files changed

+123
-87
lines changed

nitransforms/nonlinear.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def __init__(self, field=None, is_deltas=True, reference=None):
7575
field = _ensure_image(field)
7676
# Extract data if nibabel object otherwise assume numpy array
7777
_data = np.squeeze(
78-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field.copy()
78+
np.asanyarray(field.dataobj)
79+
if hasattr(field, "dataobj")
80+
else field.copy()
7981
)
8082

8183
try:
@@ -148,7 +150,7 @@ def map(self, x, inverse=False):
148150
... test_dir / "someones_displacement_field.nii.gz",
149151
... is_deltas=False,
150152
... )
151-
>>> xfm.map([-6.5, -36., -19.5]).tolist()
153+
>>> xfm.map([[-6.5, -36., -19.5]]).tolist()
152154
[[0.0, -0.47516798973083496, 0.0]]
153155
154156
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
@@ -165,8 +167,8 @@ def map(self, x, inverse=False):
165167
... test_dir / "someones_displacement_field.nii.gz",
166168
... is_deltas=True,
167169
... )
168-
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
169-
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
170+
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() # doctest: +ELLIPSIS
171+
[[-6.5, -36.475..., -19.5], [-1.0, -42.038..., -11.25]]
170172
171173
>>> np.array_str(
172174
... xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]),
@@ -183,19 +185,16 @@ def map(self, x, inverse=False):
183185
ijk = self.reference.index(np.array(x, dtype="float32"))
184186
indexes = np.round(ijk).astype("int")
185187
ongrid = np.where(np.linalg.norm(ijk - indexes, axis=1) < 1e-3)[0]
186-
mapped = np.empty_like(x, dtype="float32")
187-
188-
if ongrid.size:
189-
mapped[ongrid] = self._field[*indexes[ongrid].T, :]
190188

191-
if ongrid.size == x.shape[0]:
192-
return mapped
189+
if ongrid.size == np.shape(x)[0]:
190+
# return self._field[*indexes.T, :] # From Python 3.11
191+
return self._field[tuple(indexes.T) + (np.s_[:],)]
193192

194-
new_map = np.vstack(
193+
mapped_coords = np.vstack(
195194
tuple(
196195
map_coordinates(
197196
self._field[..., i],
198-
ijk,
197+
ijk.T,
199198
order=3,
200199
mode="constant",
201200
cval=np.nan,
@@ -206,8 +205,8 @@ def map(self, x, inverse=False):
206205
).T
207206

208207
# Set NaN values back to the original coordinates value = no displacement
209-
new_map[np.isnan(new_map)] = np.array(x)[np.isnan(new_map)]
210-
return new_map
208+
mapped_coords[np.isnan(mapped_coords)] = np.array(x)[np.isnan(mapped_coords)]
209+
return mapped_coords
211210

212211
def __matmul__(self, b):
213212
"""

nitransforms/tests/test_nonlinear.py

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import numpy as np
99
import nibabel as nb
10-
from nitransforms.resampling import apply
1110
from nitransforms.base import TransformError, ImageGrid
1211
from nitransforms.io.base import TransformFileError
1312
from nitransforms.nonlinear import (
@@ -17,13 +16,16 @@
1716
from ..io.itk import ITKDisplacementsField
1817

1918

20-
SOME_TEST_POINTS = np.array([
21-
[0.0, 0.0, 0.0],
22-
[1.0, 2.0, 3.0],
23-
[10.0, -10.0, 5.0],
24-
[-5.0, 7.0, -2.0],
25-
[12.0, 0.0, -11.0],
26-
])
19+
SOME_TEST_POINTS = np.array(
20+
[
21+
[0.0, 0.0, 0.0],
22+
[1.0, 2.0, 3.0],
23+
[10.0, -10.0, 5.0],
24+
[-5.0, 7.0, -2.0],
25+
[12.0, 0.0, -11.0],
26+
]
27+
)
28+
2729

2830
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 3)])
2931
def test_itk_disp_load(size):
@@ -88,109 +90,114 @@ def test_bsplines_references(testdata_path):
8890
testdata_path / "someones_bspline_coefficients.nii.gz"
8991
).to_field()
9092

91-
with pytest.raises(TransformError):
92-
apply(
93-
BSplineFieldTransform(
94-
testdata_path / "someones_bspline_coefficients.nii.gz"
95-
),
96-
testdata_path / "someones_anatomy.nii.gz",
97-
)
98-
99-
apply(
100-
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
101-
testdata_path / "someones_anatomy.nii.gz",
93+
BSplineFieldTransform(
94+
testdata_path / "someones_bspline_coefficients.nii.gz",
10295
reference=testdata_path / "someones_anatomy.nii.gz",
10396
)
10497

10598

106-
@pytest.mark.xfail(
107-
reason="GH-267: disabled while debugging",
108-
strict=False,
109-
)
110-
def test_bspline(tmp_path, testdata_path):
99+
def test_map_bspline_vs_displacement(tmp_path, testdata_path):
111100
"""Cross-check B-Splines and deformation field."""
112101
os.chdir(str(tmp_path))
113102

114103
img_name = testdata_path / "someones_anatomy.nii.gz"
115104
disp_name = testdata_path / "someones_displacement_field.nii.gz"
116105
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
117106

118-
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
107+
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name).to_field()
119108
dispxfm = DenseFieldTransform(disp_name)
109+
# Interpolating field should be reasonably similar
110+
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)
120111

121-
out_disp = apply(dispxfm, img_name)
122-
out_bspl = apply(bsplxfm, img_name)
123-
124-
out_disp.to_filename("resampled_field.nii.gz")
125-
out_bspl.to_filename("resampled_bsplines.nii.gz")
126-
127-
assert (
128-
np.sqrt(
129-
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
130-
** 2
131-
).mean()
132-
< 0.2
133-
)
134112

135113
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
136114
@pytest.mark.parametrize("ongrid", [True, False])
137-
def test_densefield_map(tmp_path, get_testdata, image_orientation, ongrid):
115+
def test_densefield_map(get_testdata, image_orientation, ongrid):
138116
"""Create a constant displacement field and compare mappings."""
139117

140118
nii = get_testdata[image_orientation]
141119

120+
# Get sampling indices
121+
rng = np.random.default_rng()
122+
142123
# Create a reference centered at the origin with various axis orders/flips
143124
shape = nii.shape
144125
ref_affine = nii.affine.copy()
145126
reference = ImageGrid(nb.Nifti1Image(np.zeros(shape), ref_affine, None))
146-
indices = reference.ndindex
147-
148-
gridpoints = reference.ras(indices)
149-
points = gridpoints if ongrid else SOME_TEST_POINTS
150-
151-
coordinates = gridpoints.reshape(*shape, 3)
152-
deltas = np.stack((
153-
np.zeros(np.prod(shape), dtype="float32").reshape(shape),
154-
np.linspace(-80, 80, num=np.prod(shape), dtype="float32").reshape(shape),
155-
np.linspace(-50, 50, num=np.prod(shape), dtype="float32").reshape(shape),
156-
), axis=-1)
157-
158-
atol = 1e-4 if image_orientation == "oblique" else 1e-7
127+
grid_ijk = reference.ndindex
128+
grid_xyz = reference.ras(grid_ijk)
129+
130+
subsample = rng.choice(grid_ijk.shape[0], 5000)
131+
points_ijk = grid_ijk.copy() if ongrid else grid_ijk[subsample]
132+
coords_xyz = (
133+
grid_xyz
134+
if ongrid
135+
else reference.ras(points_ijk) + rng.normal(size=points_ijk.shape)
136+
)
159137

160-
# Build an identity transform (deltas)
161-
id_xfm_deltas = DenseFieldTransform(reference=reference)
162-
np.testing.assert_array_equal(coordinates, id_xfm_deltas._field)
163-
np.testing.assert_allclose(points, id_xfm_deltas.map(points), atol=atol)
138+
coords_map = grid_xyz.reshape(*shape, 3)
139+
deltas = np.stack(
140+
(
141+
np.zeros(np.prod(shape), dtype="float32").reshape(shape),
142+
np.linspace(-80, 80, num=np.prod(shape), dtype="float32").reshape(shape),
143+
np.linspace(-50, 50, num=np.prod(shape), dtype="float32").reshape(shape),
144+
),
145+
axis=-1,
146+
)
164147

165-
# Build an identity transform (deformation)
166-
id_xfm_field = DenseFieldTransform(coordinates, is_deltas=False, reference=reference)
167-
np.testing.assert_array_equal(coordinates, id_xfm_field._field)
168-
np.testing.assert_allclose(points, id_xfm_field.map(points), atol=atol)
148+
if ongrid:
149+
atol = 1e-3 if image_orientation == "oblique" or not ongrid else 1e-7
150+
# Build an identity transform (deltas)
151+
id_xfm_deltas = DenseFieldTransform(reference=reference)
152+
np.testing.assert_array_equal(coords_map, id_xfm_deltas._field)
153+
np.testing.assert_allclose(coords_xyz, id_xfm_deltas.map(coords_xyz))
154+
155+
# Build an identity transform (deformation)
156+
id_xfm_field = DenseFieldTransform(
157+
coords_map, is_deltas=False, reference=reference
158+
)
159+
np.testing.assert_array_equal(coords_map, id_xfm_field._field)
160+
np.testing.assert_allclose(coords_xyz, id_xfm_field.map(coords_xyz), atol=atol)
169161

170-
# Collapse to zero transform (deltas)
171-
zero_xfm_deltas = DenseFieldTransform(-coordinates, reference=reference)
172-
np.testing.assert_array_equal(np.zeros_like(zero_xfm_deltas._field), zero_xfm_deltas._field)
173-
np.testing.assert_allclose(np.zeros_like(points), zero_xfm_deltas.map(points), atol=atol)
162+
# Collapse to zero transform (deltas)
163+
zero_xfm_deltas = DenseFieldTransform(-coords_map, reference=reference)
164+
np.testing.assert_array_equal(
165+
np.zeros_like(zero_xfm_deltas._field), zero_xfm_deltas._field
166+
)
167+
np.testing.assert_allclose(
168+
np.zeros_like(coords_xyz), zero_xfm_deltas.map(coords_xyz), atol=atol
169+
)
174170

175-
# Collapse to zero transform (deformation)
176-
zero_xfm_field = DenseFieldTransform(np.zeros_like(deltas), is_deltas=False, reference=reference)
177-
np.testing.assert_array_equal(np.zeros_like(zero_xfm_field._field), zero_xfm_field._field)
178-
np.testing.assert_allclose(np.zeros_like(points), zero_xfm_field.map(points), atol=atol)
171+
# Collapse to zero transform (deformation)
172+
zero_xfm_field = DenseFieldTransform(
173+
np.zeros_like(deltas), is_deltas=False, reference=reference
174+
)
175+
np.testing.assert_array_equal(
176+
np.zeros_like(zero_xfm_field._field), zero_xfm_field._field
177+
)
178+
np.testing.assert_allclose(
179+
np.zeros_like(coords_xyz), zero_xfm_field.map(coords_xyz), atol=atol
180+
)
179181

180182
# Now let's apply a transform
181183
xfm = DenseFieldTransform(deltas, reference=reference)
182184
np.testing.assert_array_equal(deltas, xfm._deltas)
183-
np.testing.assert_array_equal(coordinates + deltas, xfm._field)
185+
np.testing.assert_array_equal(coords_map + deltas, xfm._field)
184186

185-
mapped = xfm.map(points)
186-
nit_deltas = mapped - points
187+
mapped = xfm.map(coords_xyz)
188+
nit_deltas = mapped - coords_xyz
187189

188190
if ongrid:
189191
mapped_image = mapped.reshape(*shape, 3)
190-
np.testing.assert_allclose(deltas + coordinates, mapped_image)
192+
np.testing.assert_allclose(deltas + coords_map, mapped_image)
191193
np.testing.assert_allclose(deltas, nit_deltas.reshape(*shape, 3), atol=1e-4)
192194
np.testing.assert_allclose(xfm._field, mapped_image)
193-
195+
else:
196+
ongrid_xyz = xfm.map(grid_xyz[subsample])
197+
assert (
198+
(np.linalg.norm(ongrid_xyz - mapped, axis=1) > 2).sum()
199+
/ ongrid_xyz.shape[0]
200+
) < 0.5
194201

195202

196203
@pytest.mark.parametrize("is_deltas", [True, False])

nitransforms/tests/test_resampling.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,33 @@ def test_apply_4d(serialize_4d):
388388
data = np.asanyarray(moved.dataobj)
389389
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
390390
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
391+
392+
393+
@pytest.mark.xfail(
394+
reason="GH-267: disabled while debugging",
395+
strict=False,
396+
)
397+
def test_apply_bspline(tmp_path, testdata_path):
398+
"""Cross-check B-Splines and deformation field."""
399+
os.chdir(str(tmp_path))
400+
401+
img_name = testdata_path / "someones_anatomy.nii.gz"
402+
disp_name = testdata_path / "someones_displacement_field.nii.gz"
403+
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
404+
405+
bsplxfm = nitnl.BSplineFieldTransform(bs_name, reference=img_name)
406+
dispxfm = nitnl.DenseFieldTransform(disp_name)
407+
408+
out_disp = apply(dispxfm, img_name)
409+
out_bspl = apply(bsplxfm, img_name)
410+
411+
out_disp.to_filename("resampled_field.nii.gz")
412+
out_bspl.to_filename("resampled_bsplines.nii.gz")
413+
414+
assert (
415+
np.sqrt(
416+
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
417+
** 2
418+
).mean()
419+
< 0.2
420+
)

0 commit comments

Comments
 (0)