Skip to content

Commit e4e2b0a

Browse files
oestebanfeilong
andcommitted
fix: proper ordering of displacements fields when reading and writing
Transposes data as suggested by @feilong when reading and writing displacements fields. Tests have been added to ensure round-trip consistency. Related-to: #171. Co-Authored-By: Feilong Ma <[email protected]>
1 parent 5647caf commit e4e2b0a

File tree

5 files changed

+165
-77
lines changed

5 files changed

+165
-77
lines changed

env.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ dependencies:
2424
- nitime=0.10
2525
- scikit-image=0.22
2626
- scikit-learn=1.4
27+
# SimpleITK, so build doesn't complain about building scikit from sources
28+
- simpleitk=2.4
2729
# Utilities
2830
- graphviz=9.0
2931
- pandoc=3.1

nitransforms/io/itk.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Read/write ITK transforms."""
2+
23
import warnings
34
import numpy as np
45
from scipy.io import loadmat as _read_mat, savemat as _save_mat
@@ -138,8 +139,7 @@ def from_matlab_dict(cls, mdict, index=0):
138139
sa = tf.structarr
139140

140141
affine = mdict.get(
141-
"AffineTransform_double_3_3",
142-
mdict.get("AffineTransform_float_3_3")
142+
"AffineTransform_double_3_3", mdict.get("AffineTransform_float_3_3")
143143
)
144144

145145
if affine is None:
@@ -196,7 +196,7 @@ def from_string(cls, string):
196196
lines = lines[1:] # Drop banner with version
197197

198198
parameters = np.eye(4, dtype="f4")
199-
sa["index"] = int(lines[0][lines[0].index("T"):].split()[1])
199+
sa["index"] = int(lines[0][lines[0].index("T") :].split()[1])
200200
sa["offset"] = np.genfromtxt(
201201
[lines[3].split(":")[-1].encode()], dtype=cls.dtype["offset"]
202202
)
@@ -337,7 +337,7 @@ def from_image(cls, imgobj):
337337
hdr = imgobj.header.copy()
338338
shape = hdr.get_data_shape()
339339

340-
if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3):
340+
if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3):
341341
raise TransformFileError(
342342
'Displacements field "%s" does not come from ITK.'
343343
% imgobj.file_map["image"].filename
@@ -347,10 +347,10 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj))
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
351351
field[..., (0, 1)] *= -1.0
352352

353-
return imgobj.__class__(field, imgobj.affine, hdr)
353+
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
354354

355355
@classmethod
356356
def to_image(cls, imgobj):
@@ -359,10 +359,9 @@ def to_image(cls, imgobj):
359359
hdr = imgobj.header.copy()
360360
hdr.set_intent("vector")
361361

362-
warp_data = imgobj.get_fdata().reshape(imgobj.shape[:3] + (1, imgobj.shape[-1]))
363-
warp_data[..., (0, 1)] *= -1
364-
365-
return imgobj.__class__(warp_data, imgobj.affine, hdr)
362+
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
363+
field[..., (0, 1)] *= -1.0
364+
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
366365

367366

368367
class ITKCompositeH5:
@@ -410,21 +409,16 @@ def from_h5obj(cls, fileobj, check=True, only_linear=False):
410409
directions = np.reshape(_fixed[9:], (3, 3))
411410
affine = from_matvec(directions * zooms, offset)
412411
# ITK uses Fortran ordering, like NIfTI, but with the vector dimension first
413-
field = np.moveaxis(
414-
np.reshape(
415-
xfm[f"{typo_fallback}Parameters"], (3, *shape.astype(int)), order='F'
416-
),
417-
0,
418-
-1,
419-
)
420-
field[..., (0, 1)] *= -1.0
412+
# In practice, this seems to work (see issue #171)
413+
field = np.reshape(
414+
xfm[f"{typo_fallback}Parameters"], (*shape.astype(int), 3)
415+
).transpose(2, 1, 0, 3)
416+
421417
hdr = Nifti1Header()
422418
hdr.set_intent("vector")
423419
hdr.set_data_dtype("float")
424420

425-
xfm_list.append(
426-
Nifti1Image(field.astype("float"), LPS @ affine, hdr)
427-
)
421+
xfm_list.append(Nifti1Image(field.astype("float"), affine, hdr))
428422
continue
429423

430424
raise TransformIOError(

0 commit comments

Comments
 (0)