Skip to content

Commit 60168b2

Browse files
authored
Merge pull request #274 from nipy/enh/extract-test-from-266
ENH: Add unit test on dense fields (extracted from #266)
2 parents 6c64288 + 8bb8024 commit 60168b2

File tree

2 files changed

+101
-16
lines changed

2 files changed

+101
-16
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
BSplineFieldTransform,
1414
DenseFieldTransform,
1515
)
16+
from nitransforms.tests.utils import get_points
17+
18+
rng = np.random.default_rng()
1619

1720

1821
def test_displacements_init():
@@ -74,24 +77,81 @@ def test_bsplines_references(testdata_path):
7477
)
7578

7679

77-
@pytest.mark.xfail(
78-
reason="Disable while #266 is developed.",
79-
strict=False,
80-
)
81-
def test_bspline(tmp_path, testdata_path):
82-
"""
83-
Cross-check B-Splines and deformation field.
80+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
81+
@pytest.mark.parametrize("ongrid", [True, False])
82+
def test_densefield_map(get_testdata, image_orientation, ongrid):
83+
"""Create a constant displacement field and compare mappings."""
84+
85+
nii = get_testdata[image_orientation]
86+
87+
# Get sampling indices
88+
coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample = (
89+
get_points(nii, ongrid, rng=rng)
90+
)
91+
92+
coords_map = grid_xyz.reshape(*shape, 3)
93+
deltas = np.stack(
94+
(
95+
np.zeros(np.prod(shape), dtype="float32").reshape(shape),
96+
np.linspace(-80, 80, num=np.prod(shape), dtype="float32").reshape(shape),
97+
np.linspace(-50, 50, num=np.prod(shape), dtype="float32").reshape(shape),
98+
),
99+
axis=-1,
100+
)
101+
102+
if ongrid:
103+
atol = 1e-3 if image_orientation == "oblique" or not ongrid else 1e-7
104+
# Build an identity transform (deltas)
105+
id_xfm_deltas = DenseFieldTransform(reference=reference)
106+
np.testing.assert_array_equal(coords_map, id_xfm_deltas._field)
107+
np.testing.assert_allclose(coords_xyz, id_xfm_deltas.map(coords_xyz))
108+
109+
# Build an identity transform (deformation)
110+
id_xfm_field = DenseFieldTransform(
111+
coords_map, is_deltas=False, reference=reference
112+
)
113+
np.testing.assert_array_equal(coords_map, id_xfm_field._field)
114+
np.testing.assert_allclose(coords_xyz, id_xfm_field.map(coords_xyz), atol=atol)
84115

85-
This test is disabled and will be split into two separate tests.
86-
The current implementation will be moved into test_resampling.py,
87-
since that's what it actually tests.
116+
# Collapse to zero transform (deltas)
117+
zero_xfm_deltas = DenseFieldTransform(-coords_map, reference=reference)
118+
np.testing.assert_array_equal(
119+
np.zeros_like(zero_xfm_deltas._field), zero_xfm_deltas._field
120+
)
121+
np.testing.assert_allclose(
122+
np.zeros_like(coords_xyz), zero_xfm_deltas.map(coords_xyz), atol=atol
123+
)
88124

89-
In GH-266, this test will be re-implemented by testing the equivalence
90-
of the B-Spline and deformation field transforms by calling the
91-
transform's `map()` method on points.
125+
# Collapse to zero transform (deformation)
126+
zero_xfm_field = DenseFieldTransform(
127+
np.zeros_like(deltas), is_deltas=False, reference=reference
128+
)
129+
np.testing.assert_array_equal(
130+
np.zeros_like(zero_xfm_field._field), zero_xfm_field._field
131+
)
132+
np.testing.assert_allclose(
133+
np.zeros_like(coords_xyz), zero_xfm_field.map(coords_xyz), atol=atol
134+
)
135+
136+
# Now let's apply a transform
137+
xfm = DenseFieldTransform(deltas, reference=reference)
138+
np.testing.assert_array_equal(deltas, xfm._deltas)
139+
np.testing.assert_array_equal(coords_map + deltas, xfm._field)
92140

93-
"""
94-
assert True
141+
mapped = xfm.map(coords_xyz)
142+
nit_deltas = mapped - coords_xyz
143+
144+
if ongrid:
145+
mapped_image = mapped.reshape(*shape, 3)
146+
np.testing.assert_allclose(deltas + coords_map, mapped_image)
147+
np.testing.assert_allclose(deltas, nit_deltas.reshape(*shape, 3), atol=1e-4)
148+
np.testing.assert_allclose(xfm._field, mapped_image)
149+
else:
150+
ongrid_xyz = xfm.map(grid_xyz[subsample])
151+
assert (
152+
(np.linalg.norm(ongrid_xyz - mapped, axis=1) > 2).sum()
153+
/ ongrid_xyz.shape[0]
154+
) < 0.5
95155

96156

97157
def test_map_bspline_vs_displacement(tmp_path, testdata_path):

nitransforms/tests/utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from pathlib import Path
44
import numpy as np
5+
import nibabel as nb
56

6-
from .. import linear as nbl
7+
from nitransforms import linear as nbl
8+
from nitransforms.base import ImageGrid
79

810

911
def assert_affines_by_filename(affine1, affine2):
@@ -26,3 +28,26 @@ def assert_affines_by_filename(affine1, affine2):
2628
xfm1 = np.loadtxt(str(affine1))
2729
xfm2 = np.loadtxt(str(affine2))
2830
assert np.allclose(xfm1, xfm2, atol=1e-04)
31+
32+
33+
def get_points(reference_nii, ongrid, npoints=5000, rng=None):
34+
"""Get points in RAS space."""
35+
if rng is None:
36+
rng = np.random.default_rng()
37+
38+
# Get sampling indices
39+
shape = reference_nii.shape[:3]
40+
ref_affine = reference_nii.affine.copy()
41+
reference = ImageGrid(nb.Nifti1Image(np.zeros(shape), ref_affine, None))
42+
grid_ijk = reference.ndindex
43+
grid_xyz = reference.ras(grid_ijk)
44+
45+
subsample = rng.choice(grid_ijk.shape[0], npoints)
46+
points_ijk = grid_ijk.copy() if ongrid else grid_ijk[subsample]
47+
coords_xyz = (
48+
grid_xyz
49+
if ongrid
50+
else reference.ras(points_ijk) + rng.normal(size=points_ijk.shape)
51+
)
52+
53+
return coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample

0 commit comments

Comments
 (0)