Skip to content

Commit cb726c4

Browse files
committed
fix: write tests to avoid regressions
Resolves: #267.
1 parent 69b832b commit cb726c4

File tree

2 files changed

+55
-21
lines changed

2 files changed

+55
-21
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
import nibabel as nb
13-
from nitransforms.resampling import apply
1413
from nitransforms.base import TransformError
1514
from nitransforms.nonlinear import (
1615
BSplineFieldTransform,
@@ -21,6 +20,17 @@
2120
rng = np.random.default_rng()
2221

2322

23+
SOME_TEST_POINTS = np.array(
24+
[
25+
[0.0, 0.0, 0.0],
26+
[1.0, 2.0, 3.0],
27+
[10.0, -10.0, 5.0],
28+
[-5.0, 7.0, -2.0],
29+
[12.0, 0.0, -11.0],
30+
]
31+
)
32+
33+
2434
def test_displacements_init():
2535
identity1 = DenseFieldTransform(
2636
np.zeros((10, 10, 10, 3)),
@@ -65,17 +75,8 @@ def test_bsplines_references(testdata_path):
6575
testdata_path / "someones_bspline_coefficients.nii.gz"
6676
).to_field()
6777

68-
with pytest.raises(TransformError):
69-
apply(
70-
BSplineFieldTransform(
71-
testdata_path / "someones_bspline_coefficients.nii.gz"
72-
),
73-
testdata_path / "someones_anatomy.nii.gz",
74-
)
75-
76-
apply(
77-
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
78-
testdata_path / "someones_anatomy.nii.gz",
78+
BSplineFieldTransform(
79+
testdata_path / "someones_bspline_coefficients.nii.gz",
7980
reference=testdata_path / "someones_anatomy.nii.gz",
8081
)
8182

nitransforms/tests/test_resampling.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ def test_apply_linear_transform(
150150

151151

152152
@pytest.mark.xfail(
153-
reason="Disable while #266 is developed.",
153+
reason="GH-267: disabled while debugging",
154154
strict=False,
155155
)
156156
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
157157
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
158158
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
159-
def test_displacements_field1(
159+
def test_apply_displacements_field1(
160160
tmp_path,
161161
get_testdata,
162162
get_testmask,
@@ -190,16 +190,17 @@ def test_displacements_field1(
190190
else:
191191
field.to_filename(xfm_fname)
192192

193-
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
193+
# xfm = nitnl.load(xfm_fname, fmt=sw_tool)
194+
xfm = nitnl.DenseFieldTransform(fieldmap, reference=nii)
194195

196+
ants_output = tmp_path / "ants_brainmask.nii.gz"
195197
# Then apply the transform and cross-check with software
196198
cmd = APPLY_NONLINEAR_CMD[sw_tool](
197199
transform=os.path.abspath(xfm_fname),
198200
reference=tmp_path / "mask.nii.gz",
199201
moving=tmp_path / "mask.nii.gz",
200-
output=tmp_path / "resampled_brainmask.nii.gz",
201-
extra="",
202-
# extra="--output-data-type uchar" if sw_tool == "itk" else "",
202+
output=ants_output,
203+
extra="--output-data-type uchar" if sw_tool == "itk" else "",
203204
)
204205

205206
# skip test if command is not available on host
@@ -209,11 +210,13 @@ def test_displacements_field1(
209210

210211
# resample mask
211212
exit_code = check_call([cmd], shell=True)
212-
sw_moved_mask = nb.load("resampled_brainmask.nii.gz")
213+
assert exit_code == 0
214+
sw_moved_mask = nb.load(ants_output)
213215
nt_moved_mask = apply(xfm, msk, order=0)
214216

215-
# Calculate xor between both:
216-
sw_mask = np.asanyarray(sw_moved_mask.dataobj, dtype=bool)
217+
nt_moved_mask.to_filename(tmp_path / "nit_brainmask.nii.gz")
218+
219+
assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR
217220
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)
218221
percent_diff = (sw_mask != brainmask)[5:-5, 5:-5, 5:-5].sum() / brainmask.size
219222

@@ -403,3 +406,33 @@ def test_apply_4d(serialize_4d):
403406
data = np.asanyarray(moved.dataobj)
404407
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
405408
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
409+
410+
411+
@pytest.mark.xfail(
412+
reason="GH-267: disabled while debugging",
413+
strict=False,
414+
)
415+
def test_apply_bspline(tmp_path, testdata_path):
416+
"""Cross-check B-Splines and deformation field."""
417+
os.chdir(str(tmp_path))
418+
419+
img_name = testdata_path / "someones_anatomy.nii.gz"
420+
disp_name = testdata_path / "someones_displacement_field.nii.gz"
421+
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
422+
423+
bsplxfm = nitnl.BSplineFieldTransform(bs_name, reference=img_name)
424+
dispxfm = nitnl.DenseFieldTransform(disp_name)
425+
426+
out_disp = apply(dispxfm, img_name)
427+
out_bspl = apply(bsplxfm, img_name)
428+
429+
out_disp.to_filename("resampled_field.nii.gz")
430+
out_bspl.to_filename("resampled_bsplines.nii.gz")
431+
432+
assert (
433+
np.sqrt(
434+
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
435+
** 2
436+
).mean()
437+
< 0.2
438+
)

0 commit comments

Comments
 (0)