Skip to content

Commit aacf125

Browse files
committed
Expand test points for constant field
1 parent 8257f7b commit aacf125

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,15 @@ def test_densefield_map_against_ants(testdata_path, tmp_path):
258258
if not warpfile.exists():
259259
pytest.skip("Composite transform test data not available")
260260

261-
points = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
261+
points = np.array(
262+
[
263+
[0.0, 0.0, 0.0],
264+
[1.0, 2.0, 3.0],
265+
[10.0, -10.0, 5.0],
266+
[-5.0, 7.0, -2.0],
267+
[-12.0, 12.0, 0.0],
268+
]
269+
)
262270
csvin = tmp_path / "points.csv"
263271
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
264272

@@ -272,7 +280,7 @@ def test_densefield_map_against_ants(testdata_path, tmp_path):
272280
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
273281
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
274282

275-
xfm = DenseFieldTransform(warpfile)
283+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
276284
mapped = xfm.map(points)
277285

278286
assert np.allclose(mapped, ants_pts, atol=1e-6)
@@ -282,15 +290,13 @@ def test_constant_field_vs_ants(tmp_path):
282290
"""Create a constant displacement field and compare mappings."""
283291

284292
# Create a reference centered at the origin
285-
shape = (5, 5, 5)
293+
shape = (25, 25, 25)
286294
ref_affine = from_matvec(np.eye(3), -(np.array(shape) - 1) / 2)
287295

288296
field = np.zeros(shape + (3,), dtype="float32")
289297
field[..., 0] = -5
290-
field[..., 1] = 0
291-
field[..., 2] = 5
292-
293-
field_img = nb.Nifti1Image(field, ref_affine)
298+
field[..., 1] = 5
299+
field[..., 2] = 0 # No flip in the third axis
294300

295301
warpfile = tmp_path / "const_disp.nii.gz"
296302
itk_img = sitk.GetImageFromArray(field, isVector=True)
@@ -301,7 +307,15 @@ def test_constant_field_vs_ants(tmp_path):
301307
itk_img.SetDirection(tuple(direction))
302308
sitk.WriteImage(itk_img, str(warpfile))
303309

304-
points = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
310+
points = np.array(
311+
[
312+
[0.0, 0.0, 0.0],
313+
[1.0, 2.0, 3.0],
314+
[10.0, -10.0, 5.0],
315+
[-5.0, 7.0, -2.0],
316+
[12.0, 0.0, -11.0],
317+
]
318+
)
305319
csvin = tmp_path / "points.csv"
306320
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
307321

@@ -315,7 +329,7 @@ def test_constant_field_vs_ants(tmp_path):
315329
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
316330
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
317331

318-
xfm = DenseFieldTransform(field_img)
332+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
319333
mapped = xfm.map(points)
320334

321335
assert np.allclose(mapped, ants_pts, atol=1e-6)

0 commit comments

Comments
 (0)