Skip to content

Commit c033fce

Browse files
committed
fix: make nitransforms/tests/test_nonlinear.py::test_densefield_map_vs_ants pass
1 parent 6a36577 commit c033fce

File tree

2 files changed

+165
-140
lines changed

2 files changed

+165
-140
lines changed

nitransforms/io/itk.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,11 @@ def from_image(cls, imgobj):
348348
hdr.set_intent("vector")
349349

350350
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351-
field[..., (0, 1)] *= 1.0
352-
field = field.transpose(2, 1, 0, 3)
353-
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
351+
affine = imgobj.affine
352+
midindex = (np.array(field.shape[:3]) - 1) * 0.5
353+
offset = (LPS @ affine - affine) @ (*midindex, 1)
354+
affine[:3, 3] += offset[:3]
355+
return imgobj.__class__(np.flip(field, axis=(0, 1)), imgobj.affine, hdr)
354356

355357
@classmethod
356358
def to_image(cls, imgobj):

nitransforms/tests/test_nonlinear.py

Lines changed: 160 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,11 @@
1616
DenseFieldTransform,
1717
)
1818
from nitransforms.tests.utils import get_points
19+
from nitransforms.io.itk import ITKDisplacementsField
1920

2021
rng = np.random.default_rng()
2122

2223

23-
SOME_TEST_POINTS = np.array(
24-
[
25-
[0.0, 0.0, 0.0],
26-
[1.0, 2.0, 3.0],
27-
[10.0, -10.0, 5.0],
28-
[-5.0, 7.0, -2.0],
29-
[12.0, 0.0, -11.0],
30-
]
31-
)
32-
33-
3424
def test_displacements_init():
3525
identity1 = DenseFieldTransform(
3626
np.zeros((10, 10, 10, 3)),
@@ -80,7 +70,6 @@ def test_bsplines_references(testdata_path):
8070
reference=testdata_path / "someones_anatomy.nii.gz",
8171
)
8272

83-
8473
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
8574
@pytest.mark.parametrize("ongrid", [True, False])
8675
def test_densefield_map(get_testdata, image_orientation, ongrid):
@@ -173,6 +162,165 @@ def test_densefield_map_vs_bspline(tmp_path, testdata_path):
173162
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)
174163

175164

165+
@pytest.mark.parametrize("ongrid", [True, False])
166+
def test_densefield_map_vs_ants(testdata_path, tmp_path, ongrid):
167+
"""Map points with DenseFieldTransform and compare to ANTs."""
168+
warpfile = (
169+
testdata_path
170+
/ "regressions"
171+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
172+
)
173+
if not warpfile.exists():
174+
pytest.skip("Composite transform test data not available")
175+
176+
nii = ITKDisplacementsField.from_filename(warpfile)
177+
178+
# Get sampling indices
179+
coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample = (
180+
get_points(nii, ongrid, npoints=5, rng=rng)
181+
)
182+
coords_map = grid_xyz.reshape(*shape, 3)
183+
184+
csvin = tmp_path / "fixed_coords.csv"
185+
csvout = tmp_path / "moving_coords.csv"
186+
np.savetxt(csvin, coords_xyz, delimiter=",", header="x,y,z", comments="")
187+
188+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
189+
exe = cmd.split()[0]
190+
if not shutil.which(exe):
191+
pytest.skip(f"Command {exe} not found on host")
192+
check_call(cmd, shell=True)
193+
194+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
195+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
196+
197+
xfm = DenseFieldTransform(nii, reference=reference)
198+
mapped = xfm.map(coords_xyz)
199+
200+
if ongrid:
201+
ants_mapped_xyz = ants_pts.reshape(*shape, 3)
202+
nit_mapped_xyz = mapped.reshape(*shape, 3)
203+
204+
nb.Nifti1Image(coords_map, ref_affine, None).to_filename(
205+
tmp_path / "baseline_field.nii.gz"
206+
)
207+
208+
nb.Nifti1Image(ants_mapped_xyz, ref_affine, None).to_filename(
209+
tmp_path / "ants_deformation_xyz.nii.gz"
210+
)
211+
nb.Nifti1Image(nit_mapped_xyz, ref_affine, None).to_filename(
212+
tmp_path / "nit_deformation_xyz.nii.gz"
213+
)
214+
nb.Nifti1Image(ants_mapped_xyz - coords_map, ref_affine, None).to_filename(
215+
tmp_path / "ants_deltas_xyz.nii.gz"
216+
)
217+
nb.Nifti1Image(nit_mapped_xyz - coords_map, ref_affine, None).to_filename(
218+
tmp_path / "nit_deltas_xyz.nii.gz"
219+
)
220+
221+
atol = 0 if ongrid else 1e-2
222+
rtol = 1e-4 if ongrid else 1e-6
223+
assert np.allclose(mapped, ants_pts, atol=atol, rtol=rtol)
224+
225+
226+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
227+
@pytest.mark.parametrize("ongrid", [True, False])
228+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, ongrid):
229+
"""Create a constant displacement field and compare mappings."""
230+
231+
nii = get_testdata[image_orientation]
232+
233+
# Get sampling indices
234+
coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample = (
235+
get_points(nii, ongrid, npoints=5, rng=rng)
236+
)
237+
238+
coords_map = grid_xyz.reshape(*shape, 3)
239+
gold_mapped_xyz = coords_map + deltas
240+
241+
deltas = np.hstack(
242+
(
243+
np.zeros(np.prod(shape)),
244+
np.linspace(-80, 80, num=np.prod(shape)),
245+
np.linspace(-50, 50, num=np.prod(shape)),
246+
)
247+
).reshape(shape + (3,))
248+
249+
fieldnii = nb.Nifti1Image(deltas, ref_affine, None)
250+
warpfile = tmp_path / "itk_transform.nii.gz"
251+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
252+
253+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
254+
xfm = DenseFieldTransform(fieldnii)
255+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
256+
257+
assert xfm == itk_xfm
258+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
259+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
260+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
261+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
262+
263+
# Ensure transform (xfm_orig) and ITK roundtrip (itk_xfm) are equivalent
264+
xfm_orig = DenseFieldTransform(deltas, reference=reference)
265+
np.testing.assert_allclose(xfm_orig.reference.shape, itk_xfm.reference.shape)
266+
np.testing.assert_allclose(ref_affine, xfm_orig.reference.affine)
267+
np.testing.assert_allclose(xfm_orig.reference.affine, itk_xfm.reference.affine)
268+
np.testing.assert_allclose(xfm_orig._field, itk_xfm._field)
269+
270+
# Ensure deltas and mapped grid are equivalent
271+
grid_mapped_xyz = itk_xfm.map(grid_xyz).reshape(*shape, -1)
272+
orig_grid_mapped_xyz = xfm_orig.map(grid_xyz).reshape(*shape, -1)
273+
274+
# Check apparent healthiness of mapping
275+
np.testing.assert_array_equal(orig_grid_mapped_xyz, grid_mapped_xyz)
276+
np.testing.assert_array_equal(gold_mapped_xyz, orig_grid_mapped_xyz)
277+
np.testing.assert_array_equal(gold_mapped_xyz, grid_mapped_xyz)
278+
279+
csvout = tmp_path / "mapped_xyz.csv"
280+
csvin = tmp_path / "coords_xyz.csv"
281+
np.savetxt(csvin, coords_xyz, delimiter=",", header="x,y,z", comments="")
282+
283+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
284+
exe = cmd.split()[0]
285+
if not shutil.which(exe):
286+
pytest.skip(f"Command {exe} not found on host")
287+
check_call(cmd, shell=True)
288+
289+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
290+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
291+
292+
nb.Nifti1Image(grid_mapped_xyz, ref_affine, None).to_filename(
293+
tmp_path / "grid_mapped.nii.gz"
294+
)
295+
nb.Nifti1Image(coords_map, ref_affine, None).to_filename(
296+
tmp_path / "baseline_field.nii.gz"
297+
)
298+
nb.Nifti1Image(gold_mapped_xyz, ref_affine, None).to_filename(
299+
tmp_path / "gold_mapped_xyz.nii.gz"
300+
)
301+
302+
if ongrid:
303+
ants_pts = ants_pts.reshape(*shape, 3)
304+
305+
nb.Nifti1Image(ants_pts, ref_affine, None).to_filename(
306+
tmp_path / "ants_mapped_xyz.nii.gz"
307+
)
308+
np.testing.assert_array_equal(gold_mapped_xyz, ants_pts)
309+
np.testing.assert_array_equal(deltas, ants_pts - coords_map)
310+
else:
311+
ants_deltas = ants_pts - coords_xyz
312+
deltas_xyz = deltas.reshape(-1, 3)[subsample]
313+
gold_xyz = coords_xyz + deltas_xyz
314+
np.testing.assert_array_equal(gold_xyz, ants_pts)
315+
np.testing.assert_array_equal(deltas_xyz, ants_deltas)
316+
317+
# np.testing.assert_array_equal(mapped, ants_pts)
318+
# diff = mapped - ants_pts
319+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
320+
321+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
322+
323+
176324
@pytest.mark.parametrize("is_deltas", [True, False])
177325
def test_densefield_oob_resampling(is_deltas):
178326
"""Ensure mapping outside the field returns input coordinates."""
@@ -248,128 +396,3 @@ def manual_map(x):
248396
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
249397
expected = np.vstack([manual_map(p) for p in pts])
250398
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
251-
252-
253-
def test_densefield_map_against_ants(testdata_path, tmp_path):
254-
"""Map points with DenseFieldTransform and compare to ANTs."""
255-
warpfile = (
256-
testdata_path
257-
/ "regressions"
258-
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
259-
)
260-
if not warpfile.exists():
261-
pytest.skip("Composite transform test data not available")
262-
263-
points = np.array(
264-
[
265-
[0.0, 0.0, 0.0],
266-
[1.0, 2.0, 3.0],
267-
[10.0, -10.0, 5.0],
268-
[-5.0, 7.0, -2.0],
269-
[-12.0, 12.0, 0.0],
270-
]
271-
)
272-
csvin = tmp_path / "points.csv"
273-
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
274-
275-
csvout = tmp_path / "out.csv"
276-
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
277-
exe = cmd.split()[0]
278-
if not shutil.which(exe):
279-
pytest.skip(f"Command {exe} not found on host")
280-
check_call(cmd, shell=True)
281-
282-
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
283-
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
284-
285-
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
286-
mapped = xfm.map(points)
287-
288-
assert np.allclose(mapped, ants_pts, atol=1e-6)
289-
290-
291-
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
292-
@pytest.mark.parametrize("gridpoints", [True, False])
293-
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
294-
"""Create a constant displacement field and compare mappings."""
295-
296-
nii = get_testdata[image_orientation]
297-
298-
# Create a reference centered at the origin with various axis orders/flips
299-
shape = nii.shape
300-
ref_affine = nii.affine.copy()
301-
302-
field = np.hstack((
303-
np.zeros(np.prod(shape)),
304-
np.linspace(-80, 80, num=np.prod(shape)),
305-
np.linspace(-50, 50, num=np.prod(shape)),
306-
)).reshape(shape + (3, ))
307-
fieldnii = nb.Nifti1Image(field, ref_affine, None)
308-
309-
warpfile = tmp_path / "itk_transform.nii.gz"
310-
ITKDisplacementsField.to_filename(fieldnii, warpfile)
311-
312-
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
313-
xfm = DenseFieldTransform(fieldnii)
314-
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
315-
316-
assert xfm == itk_xfm
317-
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
318-
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
319-
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
320-
np.testing.assert_allclose(xfm._field, itk_xfm._field)
321-
322-
points = (
323-
xfm.reference.ndcoords.T if gridpoints
324-
else np.array(
325-
[
326-
[0.0, 0.0, 0.0],
327-
[1.0, 2.0, 3.0],
328-
[10.0, -10.0, 5.0],
329-
[-5.0, 7.0, -2.0],
330-
[12.0, 0.0, -11.0],
331-
]
332-
)
333-
)
334-
335-
mapped = xfm.map(points)
336-
nit_deltas = mapped - points
337-
338-
if gridpoints:
339-
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
340-
341-
csvin = tmp_path / "points.csv"
342-
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
343-
344-
csvout = tmp_path / "out.csv"
345-
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
346-
exe = cmd.split()[0]
347-
if not shutil.which(exe):
348-
pytest.skip(f"Command {exe} not found on host")
349-
check_call(cmd, shell=True)
350-
351-
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
352-
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
353-
354-
# if gridpoints:
355-
# ants_field = ants_pts.reshape(shape + (3, ))
356-
# diff = xfm._field[..., 0] - ants_field[..., 0]
357-
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
358-
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
359-
360-
# diff = xfm._field[..., 1] - ants_field[..., 1]
361-
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
362-
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
363-
364-
# diff = xfm._field[..., 2] - ants_field[..., 2]
365-
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
366-
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
367-
368-
ants_deltas = ants_pts - points
369-
np.testing.assert_array_equal(nit_deltas, ants_deltas)
370-
np.testing.assert_array_equal(mapped, ants_pts)
371-
372-
diff = mapped - ants_pts
373-
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
374-
375-
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

0 commit comments

Comments
 (0)