Skip to content

Commit 69b832b

Browse files
committed
enh: update tests accordingly
1 parent 44a1601 commit 69b832b

File tree

2 files changed

+134
-2
lines changed

2 files changed

+134
-2
lines changed

nitransforms/io/itk.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351+
field[..., (0, 1)] *= 1.0
352+
field = field.transpose(2, 1, 0, 3)
351353
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352354

353355
@classmethod
@@ -357,7 +359,9 @@ def to_image(cls, imgobj):
357359
hdr = imgobj.header.copy()
358360
hdr.set_intent("vector")
359361

360-
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
362+
field = imgobj.get_fdata()
363+
field = field.transpose(2, 1, 0, 3)[..., None, :]
364+
field[..., (0, 1)] *= 1.0
361365
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
362366

363367

nitransforms/tests/test_nonlinear.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"""Tests of nonlinear transforms."""
44

55
import os
6+
from subprocess import check_call
7+
import shutil
8+
69
import pytest
710

811
import numpy as np
@@ -244,3 +247,128 @@ def manual_map(x):
244247
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
245248
expected = np.vstack([manual_map(p) for p in pts])
246249
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
250+
251+
252+
def test_densefield_map_against_ants(testdata_path, tmp_path):
253+
"""Map points with DenseFieldTransform and compare to ANTs."""
254+
warpfile = (
255+
testdata_path
256+
/ "regressions"
257+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
258+
)
259+
if not warpfile.exists():
260+
pytest.skip("Composite transform test data not available")
261+
262+
points = np.array(
263+
[
264+
[0.0, 0.0, 0.0],
265+
[1.0, 2.0, 3.0],
266+
[10.0, -10.0, 5.0],
267+
[-5.0, 7.0, -2.0],
268+
[-12.0, 12.0, 0.0],
269+
]
270+
)
271+
csvin = tmp_path / "points.csv"
272+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
273+
274+
csvout = tmp_path / "out.csv"
275+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
276+
exe = cmd.split()[0]
277+
if not shutil.which(exe):
278+
pytest.skip(f"Command {exe} not found on host")
279+
check_call(cmd, shell=True)
280+
281+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
282+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
283+
284+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
285+
mapped = xfm.map(points)
286+
287+
assert np.allclose(mapped, ants_pts, atol=1e-6)
288+
289+
290+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
291+
@pytest.mark.parametrize("gridpoints", [True, False])
292+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
293+
"""Create a constant displacement field and compare mappings."""
294+
295+
nii = get_testdata[image_orientation]
296+
297+
# Create a reference centered at the origin with various axis orders/flips
298+
shape = nii.shape
299+
ref_affine = nii.affine.copy()
300+
301+
field = np.hstack((
302+
np.zeros(np.prod(shape)),
303+
np.linspace(-80, 80, num=np.prod(shape)),
304+
np.linspace(-50, 50, num=np.prod(shape)),
305+
)).reshape(shape + (3, ))
306+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
307+
308+
warpfile = tmp_path / "itk_transform.nii.gz"
309+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
310+
311+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
312+
xfm = DenseFieldTransform(fieldnii)
313+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
314+
315+
assert xfm == itk_xfm
316+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
317+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
318+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
319+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
320+
321+
points = (
322+
xfm.reference.ndcoords.T if gridpoints
323+
else np.array(
324+
[
325+
[0.0, 0.0, 0.0],
326+
[1.0, 2.0, 3.0],
327+
[10.0, -10.0, 5.0],
328+
[-5.0, 7.0, -2.0],
329+
[12.0, 0.0, -11.0],
330+
]
331+
)
332+
)
333+
334+
mapped = xfm.map(points)
335+
nit_deltas = mapped - points
336+
337+
if gridpoints:
338+
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
339+
340+
csvin = tmp_path / "points.csv"
341+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
342+
343+
csvout = tmp_path / "out.csv"
344+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
345+
exe = cmd.split()[0]
346+
if not shutil.which(exe):
347+
pytest.skip(f"Command {exe} not found on host")
348+
check_call(cmd, shell=True)
349+
350+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
351+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
352+
353+
# if gridpoints:
354+
# ants_field = ants_pts.reshape(shape + (3, ))
355+
# diff = xfm._field[..., 0] - ants_field[..., 0]
356+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
357+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
358+
359+
# diff = xfm._field[..., 1] - ants_field[..., 1]
360+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
361+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
362+
363+
# diff = xfm._field[..., 2] - ants_field[..., 2]
364+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
365+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
366+
367+
ants_deltas = ants_pts - points
368+
np.testing.assert_array_equal(nit_deltas, ants_deltas)
369+
np.testing.assert_array_equal(mapped, ants_pts)
370+
371+
diff = mapped - ants_pts
372+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
373+
374+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

0 commit comments

Comments
 (0)