Skip to content
Merged
49 changes: 44 additions & 5 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
GridSamplePadMode,
InterpolateMode,
NumpyPadMode,
SpaceKeys,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
Expand All @@ -75,6 +76,7 @@
issequenceiterable,
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg_default
from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
Expand Down Expand Up @@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):

backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

@deprecated_arg_default(
name="labels",
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
new_default=None,
msg_suffix=(
"Default value changed to None meaning that the transform now uses the 'space' of a "
"meta-tensor, if applicable, to determine appropriate axis labels."
),
)
def __init__(
self,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
lazy: bool = False,
) -> None:
"""
Expand All @@ -573,7 +584,14 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
If ``None``, an appropriate value is chosen depending on the
value of the ``"space"`` metadata item of a metatensor: if
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
input is not a meta-tensor or has no ``"space"`` item, the
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
``None``, the provided value is always used and the ``"space"``
metadata item (if any) of the input is ignored.
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False

Expand Down Expand Up @@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
affine_: np.ndarray
affine_np: np.ndarray
labels = self.labels
if isinstance(data_array, MetaTensor):
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
affine_ = to_affine_nd(sr, affine_np)

# Set up "labels" such that LPS tensors are handled correctly by default
if (
self.labels is None
and "space" in data_array.meta
and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS
):
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

else:
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
# default to identity
Expand All @@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
"please make sure the input is in the channel-first format."
)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
if len(dst) < sr:
raise ValueError(
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
Expand All @@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
labels = self.labels

# Set up "labels" such that LPS tensors are handled correctly by default
if (
isinstance(data, MetaTensor)
and self.labels is None
and "space" in data.meta
and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
):
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data)
Expand Down
21 changes: 19 additions & 2 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
ensure_tuple_rep,
fall_back_tuple,
)
from monai.utils.deprecate_utils import deprecated_arg_default
from monai.utils.enums import TraceKeys
from monai.utils.module import optional_import

Expand Down Expand Up @@ -545,12 +546,21 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform):

backend = Orientation.backend

@deprecated_arg_default(
name="labels",
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
new_default=None,
msg_suffix=(
"Default value changed to None meaning that the transform now uses the 'space' of a "
"meta-tensor, if applicable, to determine appropriate axis labels."
),
)
def __init__(
self,
keys: KeysCollection,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
allow_missing_keys: bool = False,
lazy: bool = False,
) -> None:
Expand All @@ -564,7 +574,14 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
If ``None``, an appropriate value is chosen depending on the
value of the ``"space"`` metadata item of a metatensor: if
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
input is not a meta-tensor or has no ``"space"`` item, the
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
``None``, the provided value is always used and the ``"space"``
metadata item (if any) of the input is ignored.
allow_missing_keys: don't raise exception if key is missing.
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False
Expand Down
96 changes: 87 additions & 9 deletions tests/transforms/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from typing import cast

import nibabel as nib
import numpy as np
Expand All @@ -21,6 +22,7 @@
from monai.data.meta_obj import set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation, create_rotate, create_translate
from monai.utils import SpaceKeys
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.test_utils import TEST_DEVICES, assert_allclose

Expand All @@ -33,6 +35,18 @@
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"LPS",
True,
*device,
]
)
Expand All @@ -43,6 +57,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"ALS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PRS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"PRS",
True,
*device,
]
)
Expand All @@ -53,6 +79,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"LPS",
True,
*device,
]
)
Expand All @@ -63,6 +101,18 @@
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"AL",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PR"},
torch.arange(6).reshape((2, 1, 3)),
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"PR",
True,
*device,
]
)
Expand All @@ -73,6 +123,18 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "R"},
torch.arange(6).reshape((2, 3)),
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"R",
True,
*device,
]
)
Expand All @@ -83,6 +145,7 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
Expand All @@ -93,6 +156,7 @@
torch.as_tensor(np.diag([-1, 1])),
torch.arange(6).reshape((2, 3)),
"L",
False,
*device,
]
)
Expand All @@ -107,6 +171,7 @@
),
torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
"LPS",
False,
*device,
]
)
Expand All @@ -121,6 +186,7 @@
),
torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
"RAS",
False,
*device,
]
)
Expand All @@ -131,6 +197,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
"RA",
False,
*device,
]
)
Expand All @@ -141,6 +208,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
"LP",
False,
*device,
]
)
Expand All @@ -151,6 +219,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"LPID",
False,
*device,
]
)
Expand All @@ -161,6 +230,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"RASD",
False,
*device,
]
)
Expand All @@ -175,6 +245,11 @@
[{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
]

TESTS_INVERSE = []
for device in TEST_DEVICES:
TESTS_INVERSE.append([True, *device])
TESTS_INVERSE.append([False, *device])


class TestOrientationCase(unittest.TestCase):
@parameterized.expand(TESTS)
Expand All @@ -185,17 +260,20 @@ def test_ornt_meta(
affine: torch.Tensor,
expected_data: torch.Tensor,
expected_code: str,
lps_convention: bool,
device,
):
img = MetaTensor(img, affine=affine).to(device)
meta = {"space": SpaceKeys.LPS} if lps_convention else None
img = MetaTensor(img, affine=affine, meta=meta).to(device)
ornt = Orientation(**init_param)
call_param = {"data_array": img}
res = ornt(**call_param) # type: ignore[arg-type]
if img.ndim in (3, 4):
test_resampler_lazy(ornt, res, init_param, call_param)

assert_allclose(res, expected_data.to(device))
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) # type: ignore
labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels) # type: ignore
self.assertEqual("".join(new_code), expected_code)

@parameterized.expand(TESTS_TORCH)
Expand Down Expand Up @@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
with self.assertRaises(ValueError):
Orientation(**init_param)(img)

@parameterized.expand(TEST_DEVICES)
def test_inverse(self, device):
@parameterized.expand(TESTS_INVERSE)
def test_inverse(self, lps_convention: bool, device):
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
affine = torch.tensor(
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
)
meta = {"fname": "somewhere"}
meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
img = MetaTensor(img_t, affine=affine, meta=meta)
tr = Orientation("LPS")
# check that image and affine have changed
img = tr(img)
img = cast(MetaTensor, tr(img))
self.assertNotEqual(img.shape, img_t.shape)
self.assertGreater((affine - img.affine).max(), 0.5)
self.assertGreater(float((affine - img.affine).max()), 0.5)
# check that with inverse, image affine are back to how they were
img = tr.inverse(img)
img = cast(MetaTensor, tr.inverse(img))
self.assertEqual(img.shape, img_t.shape)
self.assertLess((affine - img.affine).max(), 1e-2)
self.assertLess(float((affine - img.affine).max()), 1e-2)


if __name__ == "__main__":
Expand Down
Loading
Loading