diff --git a/docs/source/lazy_resampling.rst b/docs/source/lazy_resampling.rst
new file mode 100644
index 0000000000..f2e9911e3f
--- /dev/null
+++ b/docs/source/lazy_resampling.rst
@@ -0,0 +1,104 @@
+:github_url: https://github.com/Project-MONAI/MONAI
+
+Lazy Resampling
+===============
+
+.. toctree::
+   :maxdepth: 2
+
+   mb_specification
+   config_syntax.md
+
+Introduction
+^^^^^^^^^^^^
+
+Lazy Resampling is a new feature for MONAI 1.2. This feature is still experimental at this time and it is possible that
+behaviour and APIs will change in upcoming releases.
+
+Lazy resampling is a feature that can be used to improve preprocessing pipelines in the following ways:
+
+* it can improve pipeline execution time
+* it can improve pipeline memory usage
+* it can improve image and segmentation quality by reducing incidental noise caused by resampling
+
+How Lazy Resampling changes preprocessing
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to understand how lazy resampling changes preprocessing, we'll first discuss standard preprocessing pipeline
+behaviour, and then compare it with the way lazy resampling works.
+
+Traditional resampling pipelines
+++++++++++++++++++++++++++++++++
+
+With traditional resampling, found both in MONAI and many other preprocessing libraries, you typically define a
+sequence of transforms and pass them to a ``Compose`` object, such as :class:`monai.transforms.compose.Compose`.
+
+Example::
+
+    transforms = [
+        LoadImaged(keys=["img", "seg"], ...),
+        EnsureChannelFirstd(keys=["img", "seg"], ...),
+        Spacingd(keys=["img", "seg"], ...),
+        Orientationd(keys=["img", "seg"], ...),
+        RandSpatialCropd(keys=["img", "seg"], ...),
+        RandRotate90d(keys=["img", "seg"], ...),
+        RandRotated(keys=["img", "seg"], ...),
+        RandZoomd(keys=["img", "seg"], ...),
+        RandGaussianNoised(keys="img", ...),
+    ]
+    compose = Compose(transforms)
+
+    # elsewhere this will be called many times (such as in a Dataset instance)
+    outputs = compose(inputs)
+
+The following will then happen when we call ``compose(inputs)``:
+
+1. ``LoadImaged`` is called with its inputs (a dictionary of strings containing file locations). It loads and
+   returns a dictionary of the corresponding data samples
+2. ``EnsureChannelFirstd`` is called with the dictionary of data samples and adds a channel so that they have the
+   appropriate shape for the rest of the pipeline
+3. ``Spacingd`` is called and resamples the data samples to the requested voxel spacing
+4. ``Orientationd`` permutes the spatial dimensions of the data samples into the requested orientation
+5. ``RandSpatialCropd`` crops a random patch of the data samples, throwing away the rest of the data in the process
+6. ``RandRotate90d`` has a chance of performing a tensor-based rotation of the data samples
+7. ``RandRotated`` has a chance of performing a full resample of the data samples
+8. ``RandZoomd`` has a chance of performing a reinterpolation of the data samples
+9. ``RandGaussianNoised`` has a chance of adding noise to ``img``
+
+Overall, there are up to three occasions where the data is either interpolated or resampled through spatial transforms.
+Furthermore, because the crop happens before the later spatial transforms, the output data samples may contain regions
+filled with padding values even though the original image had data for those regions: that data was already thrown
+away by ``RandSpatialCropd``.
+
+Each of these operations takes time and memory but, as we can see in the example above, they also create resampling
+artifacts and can even destroy data in the resulting data samples (see `lazy resampling best practices`_ for examples).
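+
+For comparison with the next section, the only change needed to run this same pipeline lazily is to construct the
+``Compose`` with its ``lazy`` flag set. This is a minimal sketch (the transform arguments are elided exactly as in the
+example above)::
+
+    compose = Compose(transforms, lazy=True)
+
+    # called exactly as before; resampling now happens as described in the next section
+    outputs = compose(inputs)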
+
+Lazy resampling pipelines
++++++++++++++++++++++++++
+
+Lazy resampling works very differently. When you execute the same pipeline with ``lazy=True``, the following happens:
+
+1. ``LoadImaged`` behaves identically
+2. ``EnsureChannelFirstd`` behaves identically
+3. ``Spacingd`` executes lazily. It puts a description of the operation that it wants to perform onto a list of
+   pending operations
+4. ``Orientationd`` executes lazily. It adds a description of its own operation to the pending operation list, so
+   now there are 2 pending operations
+5. ``RandSpatialCropd`` executes lazily. It adds a description of its own operation to the pending operation list,
+   so now there are 3 pending operations
+6. ``RandRotate90d`` executes lazily. It adds a description of its own operation to the pending operation list,
+   so now there are 4 pending operations
+7. ``RandRotated`` executes lazily. It adds a description of its own operation to the pending operation list,
+   so now there are 5 pending operations
+8. ``RandZoomd`` executes lazily. It adds a description of its own operation to the pending operation list,
+   so now there are 6 pending operations. At this point, ``Spacingd``, ``Orientationd``, ``RandSpatialCropd``,
+   ``RandRotate90d``, ``RandRotated`` and ``RandZoomd`` are all on the pending operations list but have yet to be
+   carried out on the data
+9. ``RandGaussianNoised`` is not a lazy transform, so it is now time for the pending operations to be evaluated.
+   Their descriptions are mathematically composed together to determine the single operation that results from all
+   of them being carried out. This composed operation is then applied in a single resample. Once that is done,
+   ``RandGaussianNoised`` operates on the resulting data
+
+The single resampling operation introduces less resampling noise, as resampling occurs only once in this pipeline
+rather than three times in the traditional pipeline. More importantly, although the crop describes an operation to
+keep only a subset of the data sample, the crop is not performed until after the spatial transforms are completed,
+which means that all of the data sample that is within bounds is preserved and is part of the resulting output.
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index e045a7e741..fe17fa4efe 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -958,6 +958,17 @@ MRI Transforms
   :special-members: __call__
 
 
+Lazy
+^^^^
+
+`ApplyPending`
+""""""""""""""
+
+.. autoclass:: ApplyPending
+    :members:
+    :special-members: __call__
+
+
 Utility
 ^^^^^^^
 
@@ -1912,6 +1923,17 @@ Smooth Field (Dict)
   :special-members: __call__
 
 
+Lazy (Dict)
+^^^^^^^^^^^
+
+`ApplyPendingd`
+"""""""""""""""
+
+.. autoclass:: ApplyPendingd
+    :members:
+    :special-members: __call__
+
+
 Utility (Dict)
 ^^^^^^^^^^^^^^
 
@@ -2211,9 +2233,3 @@ Utilities
 .. automodule:: monai.transforms.utils_pytorch_numpy_unification
   :members:
 
-
-Lazy
-----
-..
automodule:: monai.transforms.lazy - :members: - :imported-members: diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 2df043e506..9e8b081104 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -212,7 +212,7 @@ def get_all_case_stats(self, key="training", transform_list=None): manager_list = manager.list() processes = [] for rank in range(nprocs): - p = tmp_ctx.Process( + p = tmp_ctx.Process( # type: ignore[attr-defined] target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list) ) processes.append(p) diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index a692a42369..853971adc9 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -17,6 +17,7 @@ from __future__ import annotations +import warnings from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy from typing import Any @@ -1326,7 +1327,11 @@ def __init__( super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys) self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys)) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> Mapping[Hashable, torch.Tensor]: + if lazy is True: + warnings.warn("RandRotateBox90d cannot currently execute lazily; ignoring lazy=True") self.randomize() d = dict(data) diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index 11454b0b6b..6b96f10cf9 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from collections.abc import Hashable, Mapping, Sequence import numpy as np @@ -213,10 +214,10 @@ class ReferenceBasedSpatialCropd(Cropd): """ def __init__(self, keys: KeysCollection, ref_key: str, allow_missing_keys: bool = False) -> None: - super().__init__(keys, cropper=None, allow_missing_keys=allow_missing_keys) # type: ignore + super().__init__(keys, cropper=None, allow_missing_keys=allow_missing_keys, lazy=False) # type: ignore self.ref_key = ref_key - def __call__(self, data: Mapping[Hashable, Tensor]) -> dict[Hashable, Tensor]: + def __call__(self, data: Mapping[Hashable, Tensor], lazy: bool | None = None) -> dict[Hashable, Tensor]: """ This transform can support to crop ND spatial (channel-first) data. 
It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D @@ -229,6 +230,9 @@ def __call__(self, data: Mapping[Hashable, Tensor]) -> dict[Hashable, Tensor]: Returns: the new data dictionary """ + if lazy is True: + warnings.warn("ReferenceBasedSpatialCropd cannot currently execute lazily; ignoring lazy=True") + d = dict(data) # compute roi_size according to self.ref_key diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 75cbec5607..84817d17b0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -230,7 +230,9 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict -from .lazy.functional import apply_transforms +from .lazy.array import ApplyPending +from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict +from .lazy.functional import apply_pending from .lazy.utils import combine_transforms, resample from .meta_utility.dictionary import ( FromMetaTensord, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ec6ec6a0fe..a610692952 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -22,13 +22,15 @@ import numpy as np import monai -import monai.transforms as mt from monai.apps.utils import get_logger from monai.config import NdarrayOrTensor from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import ThreadUnsafe +from monai.transforms.lazy.array import ApplyPending +from monai.transforms.lazy.dictionary import ApplyPendingd # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) +from monai.transforms.lazy.executors import apply_pending_transforms +from monai.transforms.traits import LazyTrait, ThreadUnsafe from monai.transforms.transform import ( # noqa: F401 LazyTransform, MapTransform, @@ -37,85 +39,12 @@ Transform, apply_transform, ) -from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed, to_tuple_of_dictionaries +from monai.transforms.utils import is_tensor_invertible +from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed logger = get_logger(__name__) -__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides", "SomeOf"] - - -def evaluate_with_overrides( - data, - upcoming, - lazy_evaluation: bool | None = False, - overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, -): - """ - The previously applied transform may have been lazily applied to MetaTensor `data` and - made `data.has_pending_operations` equals to True. Given the upcoming transform ``upcoming``, - this function determines whether `data.pending_operations` should be evaluated. If so, it will - evaluate the lazily applied transforms. - - Currently, the conditions for evaluation are: - - - ``lazy_evaluation`` is ``True``, AND - - the data is a ``MetaTensor`` and has pending operations, AND - - the upcoming transform is an instance of ``Identity`` or ``IdentityD`` or ``None``. - - The returned `data` will then be ready for the ``upcoming`` transform. - - Args: - data: data to be evaluated. - upcoming: the upcoming transform. - lazy_evaluation: whether to evaluate the pending operations. - override: keyword arguments to apply transforms. - override_keys: to which the override arguments are used when apply transforms. 
- verbose: whether to print debugging info when evaluate MetaTensor with pending operations. - - """ - if not lazy_evaluation: - return data # eager evaluation - overrides = (overrides or {}).copy() - if isinstance(data, monai.data.MetaTensor): - if data.has_pending_operations and ( - (upcoming is None) - or (isinstance(upcoming, mt.Identity)) - or (isinstance(upcoming, mt.Identityd) and override_keys in upcoming.keys) - ): - data, _ = mt.apply_transforms(data, None, overrides=overrides) - if verbose: - next_name = "final output" if upcoming is None else f"'{upcoming.__class__.__name__}'" - logger.info(f"Evaluated - '{override_keys}' - up-to-date for - {next_name}") - elif verbose: - logger.info( - f"Lazy - '{override_keys}' - upcoming: '{upcoming.__class__.__name__}'" - f"- pending {len(data.pending_operations)}" - ) - return data - override_keys = ensure_tuple(override_keys) - if isinstance(data, dict): - if isinstance(upcoming, MapTransform): - applied_keys = {k for k in data if k in upcoming.keys} - if not applied_keys: - return data - else: - applied_keys = set(data.keys()) - - keys_to_override = {k for k in applied_keys if k in override_keys} - # generate a list of dictionaries with the appropriate override value per key - dict_overrides = to_tuple_of_dictionaries(overrides, override_keys) - for k in data: - if k in keys_to_override: - dict_for_key = dict_overrides[override_keys.index(k)] - data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, dict_for_key, k, verbose) - else: - data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, None, k, verbose) - - if isinstance(data, (list, tuple)): - return [evaluate_with_overrides(v, upcoming, lazy_evaluation, overrides, override_keys, verbose) for v in data] - return data +__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"] def execute_compose( @@ -125,12 +54,11 @@ def execute_compose( unpack_items: bool = False, start: int = 0, end: int | None = None, - lazy_evaluation: bool = False, + lazy: bool | None = False, + lazy_strategy: str = "in_order", overrides: dict | None = None, - override_keys: Sequence[str] | None = None, threading: bool = False, - log_stats: bool = False, - verbose: bool = False, + logger_name: str | None = None, ) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]: """ ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence @@ -148,26 +76,23 @@ def execute_compose( start: the index of the first transform to be executed. If not set, this defaults to 0 end: the index after the last transform to be exectued. If set, the transform at index-1 is the last transform that is executed. If this is not set, it defaults to len(transforms) - lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be + lazy: whether to enable lazy evaluation for lazy transforms. If False, transforms will be carried out on a transform by transform basis. If True, all lazy transforms will be executed by accumulating changes and resampling as few times as possible. - A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of - the pending operations and make the primary data up-to-date. + lazy_strategy: this field controls how execution occurs when processing data lazily. Permitted + options are "in_order", "out_of_order". Please see `Compose`_ for more details of what these + options mean. In general, you should not need to change this from its default. 
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. + to that transform before it is executed. Note that overrides are currently only applied when lazy + is True. If lazy is False they are ignored. currently supported args are: {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, - please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. + please see also :py:func:`monai.transforms.lazy.apply_pending` and ``Compose`` for more details. threading: whether executing is happening in a threaded environment. If set, copies are made of transforms that have the ``RandomizedTrait`` interface. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. - verbose: whether to print debugging info when lazy_evaluation=True. + logger_name: The name of the logger that should be used during transform execution. If None, logging is + suppressed. Returns: A tensorlike, sequence of tensorlikes or dict of tensorlists containing the result of running @@ -176,6 +101,8 @@ def execute_compose( end_ = len(transforms) if end is None else end if start is None: raise ValueError(f"'start' ({start}) cannot be None") + if start < 0: + raise ValueError(f"'start' ({start}) cannot be less than 0") if start > end_: raise ValueError(f"'start' ({start}) must be less than 'end' ({end_})") if end_ > len(transforms): @@ -188,21 +115,162 @@ def execute_compose( for _transform in transforms[start:end]: if threading: _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform - data = evaluate_with_overrides( - data, + data = apply_transform( _transform, - lazy_evaluation=lazy_evaluation, + data, + map_items, + unpack_items, + lazy=lazy, + lazy_strategy=lazy_strategy, overrides=overrides, - override_keys=override_keys, - verbose=verbose, + logger_name=logger_name, ) - data = apply_transform(_transform, data, map_items, unpack_items, log_stats) - data = evaluate_with_overrides( - data, None, lazy_evaluation=lazy_evaluation, overrides=overrides, override_keys=override_keys, verbose=verbose - ) + data = apply_pending_transforms(data, None, overrides, logger_name=logger_name) return data +class ExecutionOptions: + """ + ExecutionOptions is an implementation class that is required to parse options for Compose. It should currently + be considered an implementation detail that should not be interacted with directly by users of MONAI, although that + may change in subsequent releases. Its job is to parse options provided to `Compose.__call__`_ to set execution + modes for executing the pipeline. + + See `Compose`_ for a detailed explanation of lazy resampling. 
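+
+    For illustration only, here is a sketch of how the currently supported ``reorder`` option could be passed in (the
+    transforms named below are placeholders; options are supplied to ``Compose`` rather than to this class directly)::
+
+        pipeline = Compose(
+            [LoadImaged(keys="img"), RandGaussianNoised(keys="img"), Spacingd(keys="img", pixdim=(1.0, 1.0, 1.0))],
+            lazy=True,
+            options={"reorder": "lazy_last"},
+        )
+
+    Note that reordering is only considered when ``lazy`` is not False and ``options`` is not None; otherwise the
+    default execution policy is used.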
+ """ + + def __init__(self): + # construct the list of options + options = {"reorder": {"lazy_last": self.reorder_lazy_last, "lazy_last_nosync": self.reorder_lazy_last_nosync}} + self.options = options + + def __call__(self, transforms, lazy: bool | None, options: dict | None = None): + """ + Get a policy object that controls the way `Compose`_ executes a list of transforms. At present, the user can + only specify a single flag, but this design will be extended in future releases to allow multiple flags to + control different aspects of transform execution. + + Args: + transforms: a list of transforms to be executed + lazy: the current lazy mode (False, None, or True) + options: the options that determine the execution policy + + Returns: + a dictionary specifying the execution policy + + """ + if lazy is False or options is None: + return ExecutionOptions.generate_policy() + + if len(options.keys()) > 1: + raise ValueError("Only one option can currently be set") + + for k, v in options.items(): + if k not in self.options.keys(): + raise KeyError( + f"'{k}' is not a valid option key. Valid options are " f"{tuple(k for k in self.options.keys())}" + ) + + option = self.options[k] + if v not in option.keys(): + raise KeyError(f"'{v}' is not a valid option value. Value options for " f"'{k}' are {option.keys()}") + + action = option[v] + + return action(transforms=transforms, lazy=lazy) # type: ignore[operator] + + @classmethod + def reorder_lazy_last(cls, *, transforms: list, lazy: bool | None, **kwargs): + """ + 'reorder_lazy_last` effectively reorders a sequence of transforms so that lazy transforms are grouped together + after non-lazy ones. This operation can significantly change the behaviour of your pipeline and so should only + be used once you are clear about its behaviour. + + Example:: + + transforms = [LoadImage, Flip, GaussianNoise, Rotate90, ApplyPending, Zoom, Rotate] + + # ApplyPending effectively splits the pipeline up into two subranges. 
No transform can move after or before + # an ApplyPending instance, so we end up with transforms before and after ApplyPending + + sub_ranges = [[LoadImage, Flip, GaussianNoise, Rotate90], [ApplyPending], [Zoom, Rotate]] + + # Each subrange is then sorted so that non-lazy transform stay in their relative order but go before the + # lazy transforms (which also stay in their relative order) + + sub_ranges = [[LoadImage, GaussianNoise, Flip, Rotate90], [Apply + + :: + """ + subsections = list() + subsection_starts = list() + # pass 1: split the transform list into subsections + i_s = 0 + for i_t in range(len(transforms)): + if isinstance(transforms[i_t], (ApplyPending, ApplyPendingd)): + # this subsection ends and is added to the subsection list + if i_s < i_t: + subsections.append(transforms[i_s:i_t]) + subsection_starts.append(i_s) + # add apply pending in its own list + subsections.append([transforms[i_t]]) + subsection_starts.append(i_t) + i_s = i_t + 1 + + if i_s != len(transforms): + subsections.append(transforms[i_s:]) + subsection_starts.append(i_s) + + # pass 2: calculate the permuted indices + permuted_indices = list() + for sub_start, subsection in zip(subsection_starts, subsections): + for i_s, s in enumerate(subsection): + if not cls._executing_lazily(s, lazy): + permuted_indices.append(i_s + sub_start) + for i_s, s in enumerate(subsection): + if cls._executing_lazily(s, lazy): + permuted_indices.append(i_s + sub_start) + + # pass 2: sort the subsections + reordered = list() + for subsection in subsections: + # non-lazy, lazy + subsection = [t for t in subsection if not cls._executing_lazily(t, lazy)] + [ + t for t in subsection if cls._executing_lazily(t, lazy) + ] + reordered.extend(subsection) + + return ExecutionOptions.generate_policy({"indices": permuted_indices}) + + @classmethod + def reorder_lazy_last_nosync(cls, *, transforms: list, **_): + """ + 'reorder: lazy_last_nosync' is implemented through use of the 'out_of_order' execution + policy. See 'Compose'_ for details of this policy. + Args: + transforms: Not used by this method + + Returns: + + """ + return cls.generate_policy({"lazy_policy": "out_of_order"}) + + @staticmethod + def generate_policy(overrides: dict | None = None): + default_policy = {"indices": None, "transforms": None, "lazy_policy": "in_order"} + if overrides is not None: + for k, v in overrides.items(): + default_policy[k] = v + return default_policy + + @staticmethod + def _executing_lazily(t, lazy_policy): + if isinstance(t, LazyTrait): + lazy_ = t.lazy if lazy_policy is None else lazy_policy + return lazy_ + return False + + class Compose(Randomizable, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of callables together in @@ -291,6 +359,23 @@ class Compose(Randomizable, InvertibleTransform): should ensure that you fully execute the part of the pipeline that generates the data to be cached before caching it. This is quite simply done however, as shown by the following example. + Lazy resampling can be enabled or disabled through the ``lazy`` parameter. This is specified as an + optional boolean parameter. + + * False (default): Don't perform any lazy resampling + * None: Perform lazy resampling based on the 'lazy' properties of the transform instances. + * True: Always perform lazy resampling if possible. This will ignore the ``lazy`` properties + of the transform instances + + If you only want some of the pipeline to be executed lazily, there are two ways to achieve this. 
+ + The first way is to set lazy=True on your Compose instance and specify for each transform whether you + want it to be lazily executed or not. + + The second way is to set lazy=True on your Compose instance and add ``ApplyPending`` or `ApplyPendingd` + transforms after the final transform in a sequence that you want to execute lazily. This can be done at multiple + points in the pipeline. + Example: # run the part of the pipeline that needs to be cached data = self.transform(data, end=self.post_cache_index) @@ -308,24 +393,22 @@ class Compose(Randomizable, InvertibleTransform): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. - lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be - carried out on a transform by transform basis. If True, all lazy transforms will - be executed by accumulating changes and resampling as few times as possible. - A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of + lazy: whether to enable lazy evaluation for lazy transforms. This is an optional bool that can take + the following values. If lazy=False, lazy execution is disabled and transforms will be + carried out on a transform by transform basis. If lazy=True, all lazy transforms will + be executed by accumulating changes and resampling as few times as possible. If lazy is None, + Compose will perform lazy execution on lazy transforms that have their lazy flag set to True. + A `monai.transforms.ApplyPending[d]` transform in the pipeline will trigger the evaluation of the pending operations and make the primary data up-to-date. overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. + to that transform before it is executed. Note that overrides are currently only applied when lazy + is True. If lazy is False they are ignored. currently supported args are: {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. - verbose: whether to print debugging info when lazy_evaluation=True. + logger_name: this optional parameter allows you to specify a logger by name. If this is not set + it defaults to 'Compose'. You can also suppress logging by setting this to None. 
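+
+    Example (an illustrative sketch of the ``ApplyPending[d]`` approach described above; the transform arguments are
+    abbreviated and are not a complete configuration)::
+
+        pipeline = Compose(
+            [
+                LoadImaged(keys=["img"]),
+                EnsureChannelFirstd(keys=["img"]),
+                Spacingd(keys=["img"], pixdim=(1.0, 1.0, 1.0)),
+                Orientationd(keys=["img"], axcodes="RAS"),
+                # the pending Spacingd / Orientationd operations are resampled at this point
+                ApplyPendingd(keys=["img"]),
+                RandGaussianNoised(keys="img", prob=0.5),
+            ],
+            lazy=True,
+        )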
""" def __init__( @@ -333,29 +416,22 @@ def __init__( transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, - lazy_evaluation: bool | None = None, + lazy: bool | None = False, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, + options: dict | None = None, + logger_name: str | None = None, ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) self.map_items = map_items self.unpack_items = unpack_items - self.log_stats = log_stats self.set_random_state(seed=get_seed()) - - self.lazy_evaluation = lazy_evaluation + self.lazy = lazy self.overrides = overrides - self.override_keys = override_keys - self.verbose = verbose - - if self.lazy_evaluation is not None: - for t in self.flatten().transforms: # TODO: test Compose of Compose/OneOf - if isinstance(t, LazyTransform): - t.lazy_evaluation = self.lazy_evaluation + self.options = options + self.logger_name = logger_name + self.execution_options = ExecutionOptions() def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose: super().set_random_state(seed=seed, state=state) @@ -432,52 +508,99 @@ def __len__(self): """Return number of transformations.""" return len(self.flatten().transforms) - def evaluate_with_overrides(self, input_, upcoming_xform): - """ - Args: - input_: input data to be transformed. - upcoming_xform: a transform used to determine whether to evaluate with override - """ - return evaluate_with_overrides( - input_, - upcoming_xform, - lazy_evaluation=self.lazy_evaluation, - overrides=self.overrides, - override_keys=self.override_keys, - verbose=self.verbose, - ) + def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): + lazy_ = self.lazy if lazy is None else lazy + policy = self.execution_options(self.transforms, lazy_, self.options) + lazy_strategy = policy["lazy_policy"] + indices = policy["indices"] - def __call__(self, input_, start=0, end=None, threading=False): - return execute_compose( + # permute the transforms if required + transforms = self.transforms if indices is None else [self.transforms[i] for i in indices] + + result = execute_compose( input_, - self.transforms, + transforms, start=start, end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy_evaluation=self.lazy_evaluation, # type: ignore + lazy=self.lazy, # type: ignore + lazy_strategy=lazy_strategy, overrides=self.overrides, - override_keys=self.override_keys, threading=threading, - log_stats=self.log_stats, - verbose=self.verbose, + logger_name=self.logger_name, ) + # if the transforms were permuted, record it in the metadata for inversion + if indices is not None: + if isinstance(result, monai.data.MetaTensor): + self.push_transform(result, extra_info={"applied_order": indices}) + elif isinstance(result, Mapping): + for key in result: # dictionary not change size during iteration + if isinstance(result[key], monai.data.MetaTensor) or self.trace_key(key) in result: + self.push_transform(result, key, extra_info={"applied_order": indices}) + + return result + def inverse(self, data): - invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] - if not invertible_transforms: - warnings.warn("inverse has been called but no invertible transforms have been supplied") + policy = self.execution_options(self.transforms, self.lazy, self.options) + indices = 
policy["indices"] + + self._raise_if_tensor_is_not_invertible(data) + + if indices is not None: + applied_order = None + if isinstance(data, monai.data.MetaTensor): + applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"] + elif isinstance(data, Mapping): + for key in data: + if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: + applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"] + else: + raise RuntimeError( + f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}." + ) + if applied_order is None: + # no invertible transforms have been applied + return data - # loop backwards over transforms - for t in reversed(invertible_transforms): - if isinstance(t, LazyTransform) and t.lazy_evaluation: + # loop backwards over transforms + for o in reversed(applied_order): + if isinstance(self.transforms[o], InvertibleTransform): + data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items) + else: + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] + if not invertible_transforms: + warnings.warn("inverse has been called but no invertible transforms have been supplied") + + if self.lazy is not False: warnings.warn( - f"inversing {t.__class__.__name__} lazily may not implemented" - "please set `lazy_evaluation=False` before calling inverse." + f"'lazy' is set to {self.lazy} but lazy execution is not supported when inverting. " + f"'lazy' has been overridden to False for the call to inverse" + ) + # loop backwards over transforms + for t in reversed(invertible_transforms): + # if isinstance(t, LazyTrait) and t.lazy: + # warnings.warn( + # f"inversing {t.__class__.__name__} lazily may not implemented" + # "please set `lazy=False` before calling inverse." + # ) + data = apply_transform( + t.inverse, data, self.map_items, self.unpack_items, lazy=False, logger_name=self.logger_name ) - data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats) return data + @staticmethod + def _raise_if_tensor_is_not_invertible(data: Any): + invertible, reasons = is_tensor_invertible(data) + + if invertible is False: + if reasons is not None: + reason_text = "\n".join(reasons) + raise RuntimeError(f"Unable to run inverse on 'data' for the following reasons:\n{reason_text}") + else: + raise RuntimeError("Unable to run inverse on 'data'; no reason logged in trace data") + class OneOf(Compose): """ @@ -492,24 +615,20 @@ class OneOf(Compose): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. - lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will + lazy: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will be executed by accumulating changes and resampling as few times as possible. If False, transforms will be carried out on a transform by transform basis. A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of the pending operations and make the primary data up-to-date. 
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. + to that transform before it is executed. Note that overrides are currently only applied when lazy + is True. If lazy is False they are ignored. currently supported args are: {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. - verbose: whether to print debugging info when lazy_evaluation=True. + logger_name: this optional parameter allows you to specify a logger by name. If this is not set + it defaults to 'OneOf'. You can also suppress logging by setting this to None. """ def __init__( @@ -518,15 +637,11 @@ def __init__( weights: Sequence[float] | float | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, - lazy_evaluation: bool | None = None, + lazy: bool | None = None, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, + logger_name: str | None = None, ) -> None: - super().__init__( - transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose - ) + super().__init__(transforms, map_items, unpack_items, lazy, overrides) if len(self.transforms) == 0: weights = [] elif weights is None or isinstance(weights, float): @@ -537,6 +652,7 @@ def __init__( f"got {len(weights)} and {len(self.transforms)}." ) self.weights = ensure_tuple(self._normalize_probabilities(weights)) + self.logger_name = logger_name def _normalize_probabilities(self, weights): if len(weights) == 0: @@ -565,7 +681,12 @@ def flatten(self): weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) - def __call__(self, data, start=0, end=None, threading=False): + def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | None = None): + if start != 0: + raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})") + if end is not None: + raise ValueError(f"OneOf requires 'end' parameter to be None (end set to {end}") + if len(self.transforms) == 0: return data @@ -575,11 +696,14 @@ def __call__(self, data, start=0, end=None, threading=False): data = execute_compose( data, [_transform], - map_items=self.map_items, - unpack_items=self.unpack_items, start=start, end=end, + map_items=self.map_items, + unpack_items=self.unpack_items, + lazy=self.lazy, # type: ignore + overrides=self.overrides, threading=threading, + logger_name=self.logger_name, ) # if the data is a mapping (dictionary), append the OneOf transform to the end @@ -625,24 +749,20 @@ class RandomOrder(Compose): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. 
- lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will + lazy: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will be executed by accumulating changes and resampling as few times as possible. If False, transforms will be carried out on a transform by transform basis. A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of the pending operations and make the primary data up-to-date. overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. + to that transform before it is executed. Note that overrides are currently only applied when lazy + is True. If lazy is False they are ignored. currently supported args are: {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. - verbose: whether to print debugging info when lazy_evaluation=True. + logger_name: this optional parameter allows you to specify a logger by name. If this is not set + it defaults to 'Compose'. You can also suppress logging by setting this to None. """ def __init__( @@ -650,30 +770,35 @@ def __init__( transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, - lazy_evaluation: bool | None = None, + lazy: bool | None = None, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, + logger_name: str | None = None, ) -> None: - super().__init__( - transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose - ) + super().__init__(transforms, map_items, unpack_items, lazy, overrides) + self.logger_name = logger_name + + def __call__(self, input_, start=0, end=None, threading=False, lazy: str | bool | None = None): + if start != 0: + raise ValueError(f"RandomOrder requires 'start' parameter to be 0 (start set to {start})") + if end is not None: + raise ValueError(f"RandomOrder requires 'end' parameter to be None (end set to {end}") - def __call__(self, input_, start=0, end=None, threading=False): if len(self.transforms) == 0: return input_ + num = len(self.transforms) applied_order = self.R.permutation(range(num)) input_ = execute_compose( input_, [self.transforms[ind] for ind in applied_order], - map_items=self.map_items, - unpack_items=self.unpack_items, start=start, end=end, + map_items=self.map_items, + unpack_items=self.unpack_items, + lazy=self.lazy, threading=threading, + logger_name=self.logger_name, ) # if the data is a mapping (dictionary), append the RandomOrder transform to the end @@ -707,9 +832,7 @@ def inverse(self, data): # loop backwards over transforms for o in reversed(applied_order): if isinstance(self.transforms[o], InvertibleTransform): - data = apply_transform( - self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats - ) + data = apply_transform(self.transforms[o].inverse, data, 
self.map_items, self.unpack_items) return data @@ -727,14 +850,12 @@ class SomeOf(Compose): Defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. Defaults to `False`. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. Default to `False`. num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal. None sets it to `len(transforms)`. Default to `None`. replace: whether to sample with replacement. Defaults to `False`. weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform). + logger_name: the name of the logger to use when logging output. """ def __init__( @@ -742,16 +863,17 @@ def __init__( transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, *, num_transforms: int | tuple[int, int] | None = None, replace: bool = False, weights: list[int] | None = None, + logger_name: str | None = None, ) -> None: - super().__init__(transforms, map_items, unpack_items, log_stats) + super().__init__(transforms, map_items, unpack_items, logger_name=logger_name) self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms) self.replace = replace self.weights = self._normalize_probabilities(weights) + self.logger_name = logger_name def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple: if ( @@ -805,7 +927,12 @@ def _normalize_probabilities(self, weights): return ensure_tuple(list(weights)) - def __call__(self, data, start=0, end=None, threading=False): + def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | None = None): + if start != 0: + raise ValueError(f"SomeOf requires 'start' parameter to be 0 (start set to {start})") + if end is not None: + raise ValueError(f"SomeOf requires 'end' parameter to be None (end set to {end}") + if len(self.transforms) == 0: return data @@ -815,11 +942,14 @@ def __call__(self, data, start=0, end=None, threading=False): data = execute_compose( data, [self.transforms[a] for a in applied_order], - map_items=self.map_items, - unpack_items=self.unpack_items, start=start, end=end, + map_items=self.map_items, + unpack_items=self.unpack_items, + lazy=self.lazy, + overrides=self.overrides, threading=threading, + logger_name=self.logger_name, ) if isinstance(data, monai.data.MetaTensor): self.push_transform(data, extra_info={"applied_order": applied_order}) @@ -852,10 +982,7 @@ def inverse(self, data): # loop backwards over transforms for o in reversed(applied_order): - transform = self.transforms[o] - if isinstance(transform, InvertibleTransform): - data = apply_transform( - self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats - ) + if isinstance(self.transforms[o], InvertibleTransform): + data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items) return data diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 8cfd2c70ef..a0243cd237 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -106,8 +106,13 @@ class Pad(InvertibleTransform, 
LazyTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs + self, + to_pad: tuple[tuple[int, int]] | None = None, + mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, + **kwargs, ) -> None: + LazyTransform.__init__(self, lazy) self.to_pad = to_pad self.mode = mode self.kwargs = kwargs @@ -124,7 +129,12 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") def __call__( # type: ignore[override] - self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs + self, + img: torch.Tensor, + to_pad: tuple[tuple[int, int]] | None = None, + mode: str | None = None, + lazy: bool | None = None, + **kwargs, ) -> torch.Tensor: """ Args: @@ -150,7 +160,8 @@ def __call__( # type: ignore[override] kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_) + lazy_ = self.lazy if lazy is None else lazy + return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, lazy_, **kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -194,11 +205,12 @@ def __init__( spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...], method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ @@ -245,9 +257,11 @@ class BorderPad(Pad): """ - def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, **kwargs) -> None: + def __init__( + self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, lazy: bool = False, **kwargs + ) -> None: self.spatial_border = spatial_border - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) @@ -279,7 +293,12 @@ class DivisiblePad(Pad): backend = SpatialPad.backend def __init__( - self, k: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, **kwargs + self, + k: Sequence[int] | int, + mode: str = PytorchPadMode.CONSTANT, + method: str = Method.SYMMETRIC, + lazy: bool = False, + **kwargs, ) -> None: """ Args: @@ -301,7 +320,7 @@ def __init__( """ self.k = k self.method: Method = Method(method) - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) @@ -313,10 +332,16 @@ class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. + Args: + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = [TransformBackends.TORCH] + def __init__(self, lazy: bool = False): + LazyTransform.__init__(self, lazy) + @staticmethod def compute_slices( roi_center: Sequence[int] | NdarrayOrTensor | None = None, @@ -370,7 +395,9 @@ def compute_slices( [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] ) - def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override] + def __call__( # type: ignore[override] + self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None + ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -384,7 +411,8 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor slices_ = list([slice(None)] + slices_[:sd]) img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - return crop_func(img_t, tuple(slices_), self.get_transform_info()) + lazy_ = self.lazy if lazy is None else lazy + return crop_func(img_t, tuple(slices_), lazy_, self.get_transform_info()) def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) @@ -417,6 +445,7 @@ def __init__( roi_start: Sequence[int] | NdarrayOrTensor | None = None, roi_end: Sequence[int] | NdarrayOrTensor | None = None, roi_slices: Sequence[slice] | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -428,17 +457,19 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ + super().__init__(lazy) self.slices = self.compute_slices( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - return super().__call__(img=img, slices=ensure_tuple(self.slices)) + lazy_ = self.lazy if lazy is None else lazy + return super().__call__(img=img, slices=ensure_tuple(self.slices), lazy=lazy_) class CenterSpatialCrop(Crop): @@ -456,7 +487,8 @@ class CenterSpatialCrop(Crop): the spatial size of output data will be [32, 40, 40]. """ - def __init__(self, roi_size: Sequence[int] | int) -> None: + def __init__(self, roi_size: Sequence[int] | int, lazy: bool = False) -> None: + super().__init__(lazy=lazy) self.roi_size = roi_size def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override] @@ -464,15 +496,17 @@ def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" + lazy_ = self.lazy if lazy is None else lazy return super().__call__( img=img, slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), + lazy=lazy_, ) @@ -486,15 +520,17 @@ class CenterScaleCrop(Crop): """ - def __init__(self, roi_scale: Sequence[float] | float): + def __init__(self, roi_scale: Sequence[float] | float, lazy: bool = False): + super().__init__(lazy=lazy) self.roi_scale = roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - cropper = CenterSpatialCrop(roi_size=roi_size) - return super().__call__(img=img, slices=cropper.compute_slices(img_size)) + lazy_ = self.lazy if lazy is None else lazy + cropper = CenterSpatialCrop(roi_size=roi_size, lazy=lazy_) + return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_) class RandSpatialCrop(Randomizable, Crop): @@ -528,7 +564,9 @@ def __init__( max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: + super().__init__(lazy) self.roi_size = roi_size self.max_roi_size = max_roi_size self.random_center = random_center @@ -547,7 +585,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
@@ -558,10 +596,11 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(img_size) if self._size is None: raise RuntimeError("self._size not specified.") + lazy_ = self.lazy if lazy is None else lazy if self.random_center: - return super().__call__(img=img, slices=self._slices) - cropper = CenterSpatialCrop(self._size) - return super().__call__(img=img, slices=cropper.compute_slices(img_size)) + return super().__call__(img=img, slices=self._slices, lazy=lazy_) + cropper = CenterSpatialCrop(self._size, lazy=lazy_) + return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_) class RandScaleCrop(RandSpatialCrop): @@ -592,8 +631,11 @@ def __init__( max_roi_scale: Sequence[float] | float | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: - super().__init__(roi_size=-1, max_roi_size=None, random_center=random_center, random_size=random_size) + super().__init__( + roi_size=-1, max_roi_size=None, random_center=random_center, random_size=random_size, lazy=lazy + ) self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale @@ -609,14 +651,15 @@ def randomize(self, img_size: Sequence[int]) -> None: self.get_max_roi_size(img_size) super().randomize(img_size) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ self.get_max_roi_size(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) - return super().__call__(img=img, randomize=randomize) + lazy_ = self.lazy if lazy is None else lazy + return super().__call__(img=img, randomize=randomize, lazy=lazy_) class RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): @@ -660,11 +703,13 @@ def __init__( max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) + self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -673,25 +718,26 @@ def set_random_state( self.cropper.set_random_state(seed, state) return self - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value def randomize(self, data: Any | None = None) -> None: pass - def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> list[torch.Tensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. 
""" ret = [] + lazy_ = self.lazy if lazy is None else lazy for i in range(self.num_samples): - cropped = self.cropper(img) + cropped = self.cropper(img, lazy=lazy_) if get_track_meta(): cropped.meta[Key.PATCH_INDEX] = i # type: ignore - self.push_transform(cropped, replace=True) # track as this class instead of RandSpatialCrop + self.push_transform(cropped, replace=True, lazy=lazy_) # track as this class instead of RandSpatialCrop ret.append(cropped) return ret @@ -737,6 +783,7 @@ def __init__( return_coords: bool = False, k_divisible: Sequence[int] | int = 1, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **pad_kwargs, ) -> None: """ @@ -761,18 +808,23 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ + LazyTransform.__init__(self, lazy) self.select_fn = select_fn self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None self.margin = margin self.allow_smaller = allow_smaller self.return_coords = return_coords self.k_divisible = k_divisible - self.padder = Pad(mode=mode, **pad_kwargs) + self.padder = Pad(mode=mode, lazy=lazy, **pad_kwargs) + + @Crop.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + self.padder.lazy = _val - @Crop.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val - self.padder.lazy_evaluation = _val + @property + def checks_data(self): + return False def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]: """ @@ -794,14 +846,20 @@ def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarra return box_start_, box_end_ def crop_pad( - self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs + self, + img: torch.Tensor, + box_start: np.ndarray, + box_end: np.ndarray, + mode: str | None = None, + lazy: bool = False, + **pad_kwargs, ) -> torch.Tensor: """ Crop and pad based on the bounding box. """ slices = self.compute_slices(roi_start=box_start, roi_end=box_end) - cropped = super().__call__(img=img, slices=slices) + cropped = super().__call__(img=img, slices=slices, lazy=lazy) pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum( box_end - np.asarray(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), 0 @@ -810,11 +868,11 @@ def crop_pad( pad_width = BorderPad(spatial_border=pad).compute_pad_width( cropped.peek_pending_shape() if isinstance(cropped, MetaTensor) else cropped.shape[1:] ) - ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, **pad_kwargs) + ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, lazy=lazy, **pad_kwargs) # combine the traced cropping and padding into one transformation # by taking the padded info and placing it in a key inside the crop info. 
if get_track_meta() and isinstance(ret, MetaTensor): - if not self.lazy_evaluation: + if not lazy: ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop() else: pad_info = ret.pending_operations.pop() @@ -826,19 +884,21 @@ def crop_pad( orig_size=crop_info.get(TraceKeys.ORIG_SIZE), sp_size=pad_info[LazyAttr.SHAPE], affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE], + lazy=lazy, extra_info=extra, ) return ret def __call__( # type: ignore[override] - self, img: torch.Tensor, mode: str | None = None, **pad_kwargs + self, img: torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ box_start, box_end = self.compute_bounding_box(img) - cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) + lazy_ = self.lazy if lazy is None else lazy + cropped = self.crop_pad(img, box_start, box_end, mode, lazy=lazy_, **pad_kwargs) if self.return_coords: return cropped, box_start, box_end # type: ignore[return-value] @@ -871,8 +931,13 @@ class RandWeightedCrop(Randomizable, TraceableTransform, LazyTransform, MultiSam backend = SpatialCrop.backend def __init__( - self, spatial_size: Sequence[int] | int, num_samples: int = 1, weight_map: NdarrayOrTensor | None = None + self, + spatial_size: Sequence[int] | int, + num_samples: int = 1, + weight_map: NdarrayOrTensor | None = None, + lazy: bool = False, ): + LazyTransform.__init__(self, lazy) self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map @@ -883,12 +948,16 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val def __call__( - self, img: torch.Tensor, weight_map: NdarrayOrTensor | None = None, randomize: bool = True + self, + img: torch.Tensor, + weight_map: NdarrayOrTensor | None = None, + randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -915,15 +984,15 @@ def __call__( _spatial_size = fall_back_tuple(self.spatial_size, img_shape) results: list[torch.Tensor] = [] + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -997,7 +1066,9 @@ def __init__( fg_indices: NdarrayOrTensor | None = None, bg_indices: NdarrayOrTensor | None = None, allow_smaller: bool = False, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) self.spatial_size = spatial_size self.label = label if pos < 0 or neg < 0: @@ -1044,9 +1115,13 @@ def randomize( self.allow_smaller, ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = 
_val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + + @property + def checks_data(self): + return False def __call__( self, @@ -1056,6 +1131,7 @@ def __call__( fg_indices: NdarrayOrTensor | None = None, bg_indices: NdarrayOrTensor | None = None, randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -1082,15 +1158,15 @@ def __call__( if self.centers is not None: img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] roi_size = fall_back_tuple(self.spatial_size, default=img_shape) + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=center, roi_size=roi_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -1177,7 +1253,9 @@ def __init__( allow_smaller: bool = False, warn: bool = True, max_samples_per_class: int | None = None, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) self.spatial_size = spatial_size self.ratios = ratios self.label = label @@ -1215,9 +1293,13 @@ def randomize( self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + + @property + def checks_data(self): + return False def __call__( self, @@ -1226,6 +1308,7 @@ def __call__( image: torch.Tensor | None = None, indices: list[NdarrayOrTensor] | None = None, randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -1248,15 +1331,15 @@ def __call__( if self.centers is not None: img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] roi_size = fall_back_tuple(self.spatial_size, default=img_shape) + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -1292,18 +1375,22 @@ def __init__( spatial_size: Sequence[int] | int, method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **pad_kwargs, ): - self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) - self.cropper = CenterSpatialCrop(roi_size=spatial_size) + LazyTransform.__init__(self, lazy) + self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, lazy=lazy, **pad_kwargs) + self.cropper = CenterSpatialCrop(roi_size=spatial_size, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.padder.lazy_evaluation = val - 
self.cropper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.padder.lazy = val + self.cropper.lazy = val + self._lazy = val - def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> torch.Tensor: # type: ignore + def __call__( # type: ignore[override] + self, img: torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs + ) -> torch.Tensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1318,16 +1405,17 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> note that `np.pad` treats channel dimension as the first dimension. """ - ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) + lazy_ = self.lazy if lazy is None else lazy + ret = self.padder(self.cropper(img, lazy_), mode=mode, lazy=lazy_, **pad_kwargs) # remove the individual info and combine if get_track_meta(): ret_: MetaTensor = ret # type: ignore - if not self.lazy_evaluation: + if not lazy_: pad_info = ret_.applied_operations.pop() crop_info = ret_.applied_operations.pop() orig_size = crop_info.get(TraceKeys.ORIG_SIZE) self.push_transform( - ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info} + ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info}, lazy=lazy_ ) else: pad_info = ret_.pending_operations.pop() @@ -1339,6 +1427,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> sp_size=pad_info[LazyAttr.SHAPE], affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE], extra_info={"pad_info": pad_info, "crop_info": crop_info}, + lazy=lazy_, ) return ret diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 8e9b6b2f1e..0ea6dc3442 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -47,7 +47,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import MultiSampleTrait +from monai.transforms.traits import LazyTrait, MultiSampleTrait from monai.transforms.transform import LazyTransform, MapTransform, Randomizable from monai.transforms.utils import is_positive from monai.utils import MAX_SEED, Method, PytorchPadMode, deprecated_arg_default, ensure_tuple_rep @@ -124,6 +124,7 @@ def __init__( padder: Pad, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -140,20 +141,32 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. 
""" - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) + if lazy is True and not isinstance(padder, LazyTrait): + raise ValueError("'padder' must inherit LazyTrait if lazy is True " f"'padder' is of type({type(padder)})") self.padder = padder self.mode = ensure_tuple_rep(mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value if isinstance(self.padder, LazyTransform): - self.padder.lazy_evaluation = value + self.padder.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) + lazy_ = self.lazy if lazy is None else lazy + if lazy_ is True and not isinstance(self.padder, LazyTrait): + raise ValueError( + "'self.padder' must inherit LazyTrait if lazy is True " f"'self.padder' is of type({type(self.padder)}" + ) for key, m in self.key_iterator(d, self.mode): - d[key] = self.padder(d[key], mode=m) + if isinstance(self.padder, LazyTrait): + d[key] = self.padder(d[key], mode=m, lazy=lazy_) + else: + d[key] = self.padder(d[key], mode=m) + return d def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: @@ -177,6 +190,7 @@ def __init__( method: str = Method.SYMMETRIC, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -202,8 +216,9 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - padder = SpatialPad(spatial_size, method, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class BorderPadd(Padd): @@ -220,6 +235,7 @@ def __init__( spatial_border: Sequence[int] | int, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -249,8 +265,9 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. 
""" - padder = BorderPad(spatial_border=spatial_border, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class DivisiblePadd(Padd): @@ -268,6 +285,7 @@ def __init__( mode: SequenceStr = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -293,8 +311,9 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - padder = DivisiblePad(k=k, method=method, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class Cropd(MapTransform, InvertibleTransform, LazyTransform): @@ -311,20 +330,22 @@ class Cropd(MapTransform, InvertibleTransform, LazyTransform): backend = Crop.backend - def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): - super().__init__(keys, allow_missing_keys) + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False): + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.cropper = cropper - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value if isinstance(self.cropper, LazyTransform): - self.cropper.lazy_evaluation = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.cropper(d[key]) # type: ignore + d[key] = self.cropper(d[key], lazy=lazy_) # type: ignore return d def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: @@ -348,8 +369,8 @@ class RandCropd(Cropd, Randomizable): backend = Crop.backend - def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False): + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandCropd: super().set_random_state(seed, state) @@ -361,13 +382,21 @@ def randomize(self, img_size: Sequence[int]) -> None: if isinstance(self.cropper, Randomizable): self.cropper.randomize(img_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) # the first key must exist to execute random operations first_item = d[self.first_key(d)] self.randomize(first_item.peek_pending_shape() if isinstance(first_item, 
MetaTensor) else first_item.shape[1:]) + lazy_ = self.lazy if lazy is None else lazy + if lazy_ is True and not isinstance(self.cropper, LazyTrait): + raise ValueError( + "'self.cropper' must inherit LazyTrait if lazy is True " + f"'self.cropper' is of type({type(self.cropper)}" + ) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} + if isinstance(self.cropper, LazyTrait): + kwargs["lazy"] = lazy_ d[key] = self.cropper(d[key], **kwargs) # type: ignore return d @@ -396,6 +425,7 @@ def __init__( roi_end: Sequence[int] | None = None, roi_slices: Sequence[slice] | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -411,8 +441,8 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. """ - cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class CenterSpatialCropd(Cropd): @@ -433,9 +463,11 @@ class CenterSpatialCropd(Cropd): allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False) -> None: - cropper = CenterSpatialCrop(roi_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + def __init__( + self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False, lazy: bool = False + ) -> None: + cropper = CenterSpatialCrop(roi_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class CenterScaleCropd(Cropd): @@ -453,10 +485,14 @@ class CenterScaleCropd(Cropd): """ def __init__( - self, keys: KeysCollection, roi_scale: Sequence[float] | float, allow_missing_keys: bool = False + self, + keys: KeysCollection, + roi_scale: Sequence[float] | float, + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = CenterScaleCrop(roi_scale) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = CenterScaleCrop(roi_scale, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandSpatialCropd(RandCropd): @@ -498,9 +534,10 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandScaleCropd(RandCropd): @@ -537,9 +574,10 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandSpatialCropSamplesd(Randomizable, 
MapTransform, LazyTransform, MultiSampleTrait): @@ -590,19 +628,25 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) + LazyTransform.__init__(self, lazy) + self.cropper = RandSpatialCropSamples( + roi_size, num_samples, max_roi_size, random_center, random_size, lazy=lazy + ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value def randomize(self, data: Any | None = None) -> None: self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: ret: list[dict[Hashable, torch.Tensor]] = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): @@ -611,9 +655,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, # for each key we reset the random state to ensure crops are the same self.randomize() + + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(dict(data)): self.cropper.set_random_state(seed=self.sub_seed) - for i, im in enumerate(self.cropper(data[key])): + for i, im in enumerate(self.cropper(data[key], lazy=lazy_)): ret[i][key] = im return ret @@ -644,6 +690,7 @@ def __init__( start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, + lazy: bool = False, **pad_kwargs, ) -> None: """ @@ -683,17 +730,22 @@ def __init__( margin=margin, allow_smaller=allow_smaller, k_divisible=k_divisible, + lazy=lazy, **pad_kwargs, ) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value + + @property + def checks_data(self): + return True - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) @@ -701,8 +753,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc d[self.start_coord_key] = box_start # type: ignore if self.end_coord_key is not None: d[self.end_coord_key] = box_end # type: ignore + + lazy_ = self.lazy if lazy is None else lazy for key, m in self.key_iterator(d, self.mode): - d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) + d[key] = self.cropper.crop_pad(img=d[key], 
box_start=box_start, box_end=box_end, mode=m, lazy=lazy_) return d @@ -733,10 +787,12 @@ def __init__( spatial_size: Sequence[int] | int, num_samples: int = 1, allow_missing_keys: bool = False, + lazy: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.w_key = w_key - self.cropper = RandWeightedCrop(spatial_size, num_samples) + self.cropper = RandWeightedCrop(spatial_size, num_samples, lazy=lazy) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -748,12 +804,14 @@ def set_random_state( def randomize(self, weight_map: NdarrayOrTensor) -> None: self.cropper.randomize(weight_map) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: # output starts as empty list of dictionaries ret: list = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data @@ -762,8 +820,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, ret[i][key] = deepcopy(data[key]) self.randomize(weight_map=data[self.w_key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(data): - for i, im in enumerate(self.cropper(data[key], randomize=False)): + for i, im in enumerate(self.cropper(data[key], randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -836,8 +895,10 @@ def __init__( bg_indices_key: str | None = None, allow_smaller: bool = False, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.label_key = label_key self.image_key = image_key self.fg_indices_key = fg_indices_key @@ -849,6 +910,7 @@ def __init__( num_samples=num_samples, image_threshold=image_threshold, allow_smaller=allow_smaller, + lazy=lazy, ) def set_random_state( @@ -867,12 +929,18 @@ def randomize( ) -> None: self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value + + @property + def checks_data(self): + return True - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) fg_indices = d.pop(self.fg_indices_key, None) bg_indices = d.pop(self.bg_indices_key, None) @@ -886,8 +954,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], randomize=False)): + for i, im in enumerate(self.cropper(d[key], 
randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -984,8 +1053,10 @@ def __init__( allow_missing_keys: bool = False, warn: bool = True, max_samples_per_class: int | None = None, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.label_key = label_key self.image_key = image_key self.indices_key = indices_key @@ -998,6 +1069,7 @@ def __init__( allow_smaller=allow_smaller, warn=warn, max_samples_per_class=max_samples_per_class, + lazy=lazy, ) def set_random_state( @@ -1012,12 +1084,16 @@ def randomize( ) -> None: self.cropper.randomize(label=label, indices=indices, image=image) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, Any]) -> list[dict[Hashable, torch.Tensor]]: + @property + def checks_data(self): + return True + + def __call__(self, data: Mapping[Hashable, Any], lazy: bool | None = None) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) self.randomize(d.get(self.label_key), d.pop(self.indices_key, None), d.get(self.image_key)) # type: ignore @@ -1028,8 +1104,9 @@ def __call__(self, data: Mapping[Hashable, Any]) -> list[dict[Hashable, torch.Te for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], randomize=False)): + for i, im in enumerate(self.cropper(d[key], randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -1065,10 +1142,13 @@ def __init__( mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, method: str = Method.SYMMETRIC, + lazy: bool = False, **pad_kwargs, ) -> None: - padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) - super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) # type: ignore + padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs, lazy=lazy) + super().__init__( + keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy # type: ignore + ) class BoundingRectd(MapTransform): diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index e694edb737..783635e467 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -27,14 +27,7 @@ from monai.data.utils import to_affine_nd from monai.transforms.inverse import TraceableTransform from monai.transforms.utils import convert_pad_mode, create_translate -from monai.utils import ( - PytorchPadMode, - TraceKeys, - convert_to_dst_type, - convert_to_numpy, - convert_to_tensor, - ensure_tuple, -) +from monai.utils import PytorchPadMode, convert_to_dst_type, convert_to_numpy, convert_to_tensor, ensure_tuple __all__ = ["pad_nd", "pad_func", "crop_func", "crop_or_pad_nd"] @@ -161,11 +154,12 @@ def pad_func( to_pad: tuple[tuple[int, int]], transform_info: dict, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **kwargs, ) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according - to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). 
+ to ``lazy`` (default ``False``). `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, in which case `np.pad` will be used. @@ -181,6 +175,8 @@ def pad_func( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag indicating whether the operation should be performed in a lazy fashion or not. + transform_info: a dictionary with the relevant information pertaining to an applied transform. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ @@ -205,24 +201,25 @@ def pad_func( extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out out = convert_to_tensor(out, track_meta=get_track_meta()) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor: +def crop_func(img: torch.Tensor, slices: tuple[slice, ...], lazy: bool, transform_info: dict) -> torch.Tensor: """ Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according - to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + to ``lazy`` (default ``False``). Args: img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim. slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`. + lazy: a flag indicating whether the operation should be performed in a lazy fashion or not. transform_info: a dictionary with the relevant information pertaining to an applied transform. 
""" img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -243,10 +240,10 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore out = out[slices] return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 96b8e6b782..f010aa9de9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -23,8 +23,17 @@ from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd -from monai.transforms.transform import LazyTransform, Transform -from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor +from monai.transforms.traits import InvertibleTrait +from monai.transforms.transform import Transform +from monai.utils import ( + LazyAttr, + MetaKeys, + TraceKeys, + TraceStatusKeys, + convert_to_dst_type, + convert_to_numpy, + convert_to_tensor, +) __all__ = ["TraceableTransform", "InvertibleTransform"] @@ -77,13 +86,7 @@ def trace_key(key: Hashable = None): @staticmethod def transform_info_keys(): """The keys to store necessary info of an applied transform.""" - return ( - TraceKeys.CLASS_NAME, - TraceKeys.ID, - TraceKeys.TRACING, - TraceKeys.LAZY_EVALUATION, - TraceKeys.DO_TRANSFORM, - ) + return (TraceKeys.CLASS_NAME, TraceKeys.ID, TraceKeys.TRACING, TraceKeys.DO_TRANSFORM) def get_transform_info(self) -> dict: """ @@ -93,7 +96,6 @@ def get_transform_info(self) -> dict: self.__class__.__name__, id(self), self.tracing, - self.lazy_evaluation if isinstance(self, LazyTransform) else False, self._do_transform if hasattr(self, "_do_transform") else True, ) return dict(zip(self.transform_info_keys(), vals)) @@ -109,8 +111,9 @@ def push_transform(self, data, *args, **kwargs): set ``replace=True`` (default False) to rewrite the last transform infor in applied_operation/pending_operation based on ``self.get_transform_info()``. 
""" + lazy_eval = kwargs.get("lazy", False) transform_info = self.get_transform_info() - lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + # lazy_eval = transform_info.get(TraceKeys.lazy, False) do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) kwargs = kwargs or {} replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info @@ -125,9 +128,9 @@ def push_transform(self, data, *args, **kwargs): xform.update(transform_info) else: # lazy, replace=True, do_transform=False xform, extra = transform_info, {} - meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=True, extra_info=extra) + meta_obj = self.push_transform(data, transform_info=xform, lazy=True, extra_info=extra) return data.copy_meta_from(meta_obj) - kwargs["lazy_evaluation"] = lazy_eval + kwargs["lazy"] = lazy_eval if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict): kwargs["transform_info"].update(transform_info) else: @@ -145,7 +148,7 @@ def track_transform_meta( extra_info: dict | None = None, orig_size: tuple | None = None, transform_info=None, - lazy_evaluation=False, + lazy=False, ): """ Update a stack of applied/pending transforms metadata of ``data``. @@ -163,7 +166,7 @@ def track_transform_meta( orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. transform_info: info from self.get_transform_info(). - lazy_evaluation: whether to push the transform to pending_operations or applied_operations. + lazy: whether to push the transform to pending_operations or applied_operations. Returns: @@ -176,10 +179,10 @@ def track_transform_meta( if isinstance(data_t, MetaTensor): out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) - if lazy_evaluation and (not get_track_meta()): + if lazy and (not get_track_meta()): warnings.warn("metadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.") - if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): + if not lazy and affine is not None and isinstance(data_t, MetaTensor): # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] @@ -202,6 +205,10 @@ def track_transform_meta( info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() elif hasattr(data_t, "shape"): info[TraceKeys.ORIG_SIZE] = data_t.shape[1:] + + # add lazy status to the transform info + info[TraceKeys.LAZY] = lazy + # include extra_info if extra_info is not None: extra_info.pop(LazyAttr.SHAPE, None) @@ -209,7 +216,7 @@ def track_transform_meta( info[TraceKeys.EXTRA_INFO] = extra_info # push the transform info to the applied_operation or pending_operation stack - if lazy_evaluation: + if lazy: if sp_size is None: if LazyAttr.SHAPE not in info: info[LazyAttr.SHAPE] = info.get(TraceKeys.ORIG_SIZE, []) @@ -227,17 +234,18 @@ def track_transform_meta( if out_obj.pending_operations: transform_name = info.get(TraceKeys.CLASS_NAME, "") if isinstance(info, dict) else "" msg = ( - f"Applying transform {transform_name} to a MetaTensor with pending operations " - "is not supported (as this eventually changes the ordering of applied_operations when the pending " - f"operations are executed). Please clear the pending operations before transform {transform_name}." 
- f"\nPending operations: {[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}." + f"Transform {transform_name} has been applied to a MetaTensor with pending operations: " + f"{[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}" ) + if key is not None: + msg += f" for key {key}" + pend = out_obj.pending_operations[-1] - if not isinstance(pend.get(TraceKeys.EXTRA_INFO), dict): - pend[TraceKeys.EXTRA_INFO] = dict(pend.get(TraceKeys.EXTRA_INFO, {})) - if not isinstance(info.get(TraceKeys.EXTRA_INFO), dict): - info[TraceKeys.EXTRA_INFO] = dict(info.get(TraceKeys.EXTRA_INFO, {})) - info[TraceKeys.EXTRA_INFO]["warn"] = pend[TraceKeys.EXTRA_INFO]["warn"] = msg + statuses = pend.get(TraceKeys.STATUSES, dict()) + messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, list()) + messages.append(msg) + statuses[TraceStatusKeys.PENDING_DURING_APPLY] = messages + info[TraceKeys.STATUSES] = statuses out_obj.push_applied_operation(info) if isinstance(data, Mapping): if not isinstance(data, dict): @@ -329,7 +337,7 @@ def trace_transform(self, to_trace: bool): self.tracing = prev -class InvertibleTransform(TraceableTransform): +class InvertibleTransform(TraceableTransform, InvertibleTrait): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. This allows, for diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py index 02349dd0f2..1e97f89407 100644 --- a/monai/transforms/lazy/__init__.py +++ b/monai/transforms/lazy/__init__.py @@ -8,8 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -from .functional import apply_transforms -from .utils import combine_transforms, resample diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py new file mode 100644 index 0000000000..aa635eeda9 --- /dev/null +++ b/monai/transforms/lazy/array.py @@ -0,0 +1,32 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.transforms.traits import InvertibleTrait + +__all__ = ["ApplyPending"] + + +class ApplyPending(InvertibleTrait): + """ + ApplyPending can be inserted into a pipeline that is being executed lazily in order to ensure + resampling happens before the next transform. It doesn't do anything itself, but its presence + causes the pipeline to be executed as ApplyPending doesn't implement ```LazyTrait``. + + See ``Compose`` for a detailed explanation of the lazy resampling feature. 
+ """ + + def __call__(self, data): + return data + + def inverse(self, data): + return data diff --git a/monai/transforms/lazy/dictionary.py b/monai/transforms/lazy/dictionary.py new file mode 100644 index 0000000000..7abb0ea026 --- /dev/null +++ b/monai/transforms/lazy/dictionary.py @@ -0,0 +1,50 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from monai.config import KeysCollection +from monai.transforms.traits import InvertibleTrait +from monai.transforms.transform import MapTransform + +__all__ = ["ApplyPendingd", "ApplyPendingD", "ApplyPendingDict"] + + +class ApplyPendingd(InvertibleTrait, MapTransform): + """ + ApplyPendingd can be inserted into a pipeline that is being executed lazily in order + to ensure resampling happens before the next transform. It doesn't do anything itself, + but its presence causes the pipeline to be executed as it doesn't implement ``LazyTrait`` + + See ``Compose`` for a detailed explanation of the lazy resampling feature. + + Args: + keys: the keys for tensors that should have their pending transforms executed + """ + + def __init__(self, keys: KeysCollection): + super().__init__(keys) + + def __call__(self, data): + if not isinstance(data, dict): + raise ValueError(f"'data' must be of type dict but is '{type(data)}'") + + return data + + def inverse(self, data): + if not isinstance(data, dict): + raise ValueError(f"'data' must be of type dict but is '{type(data)}'") + + return data + + +ApplyPendingD = ApplyPendingDict = ApplyPendingd diff --git a/monai/transforms/lazy/executors.py b/monai/transforms/lazy/executors.py new file mode 100644 index 0000000000..1adec4c6a9 --- /dev/null +++ b/monai/transforms/lazy/executors.py @@ -0,0 +1,229 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Any, Mapping, Sequence + +from monai.apps.utils import get_logger +from monai.config import NdarrayOrTensor +from monai.data.meta_tensor import MetaTensor +from monai.transforms.lazy.array import ApplyPending +from monai.transforms.lazy.dictionary import ApplyPendingd +from monai.transforms.lazy.functional import apply_pending +from monai.transforms.traits import LazyTrait +from monai.transforms.transform import MapTransform + +__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending_transforms_out_of_order"] + + +def _log_pending_info( + transform: Any, + data: Any, + activity: str, + *, + lazy: bool | None = None, + key: str | None = None, + logger_name: str | None = None, +): + if logger_name is None: + return + logger = get_logger(logger_name) + + if isinstance(transform, LazyTrait): + if lazy is not None and lazy != transform.lazy: + tlazy = f", transform.lazy: {transform.lazy} (overridden)" + else: + tlazy = f", transform.lazy: {transform.lazy}" + else: + tlazy = ", transform is not lazy" + + if isinstance(transform, MapTransform): + transform_keys = transform.keys if key is None else (key,) + for k in transform_keys: + if k in data: + pcount = len(data[k].pending_operations) if isinstance(data[k], MetaTensor) else 0 + logger.info( + f"{activity} - lazy mode: {lazy}, key: '{k}', " + f"pending: {pcount}, upcoming '{transform.__class__.__name__}'{tlazy}" + ) + else: + pcount = len(data.pending_operations) if isinstance(data, MetaTensor) else 0 + if key is None: + logger.info( + f"{activity} - lazy: {lazy}, " f"pending: {pcount}, upcoming '{transform.__class__.__name__}'{tlazy}" + ) + else: + logger.info( + f"{activity} - lazy mode: {lazy}, key: '{key}', " + f"pending: {pcount}, upcoming '{transform.__class__.__name__}'{tlazy}" + ) + + +def _log_applied_info(data: Any, key=None, logger_name: str | None = None): + if logger_name is None: + return + logger = get_logger(logger_name) + + key_str = "" if key is None else f"key: '{key}', " + logger.info(f"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}") + + +def apply_pending_transforms( + data: NdarrayOrTensor | Sequence[Any, NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], + keys: tuple | None, + overrides: dict | None = None, + logger_name: str | None = None, +): + """ + apply_pending_transforms is called with either a tensor or a dictionary, some entries of which contain + tensors. + + When operating on a dictionary of tensors, the 'keys' parameter determines what tensors should be checked. + If 'keys' is not set, all keys of 'data' are considered. + + This method optionally takes a set of overrides that can be used to change specific parameters on the + transform pipeline. See ``Compose`` for more details. This method takes a logger_name that can be used + to override the default logger, to provide telemetry during the execution of pending transforms. + + This method is intended primarily for use by ``execute_compose`` and other methods that handle the + underlying execution of transform pipelines. You should not need to use it in the general case, unless + you are developing functionality to perform such operations. + + Args: + data: a ``torch.Tensor`` or ``MetaTensor``, or dictionary of tensors. 
+ keys: an optional tuple of keys that filters the keys on 'data' if it is a dict + overrides: An optional dictionary that specifies parameters that can be used to override transform + arguments when they are called. When 'data' is a dict, this dictionary should contain a dictionary + of overrides for each key that needs them + logger_name: An optional name for a logger to be used when applying pending transforms. If None, + logging is suppressed. + Returns: + an object of the same type as data if pending transforms were applied, or 'data' if they were not + """ + if isinstance(data, list): + return [apply_pending_transforms(d, keys, overrides, logger_name) for d in data] + if isinstance(data, tuple): + return tuple(apply_pending_transforms(d, keys, overrides, logger_name) for d in data) + + if isinstance(data, dict): + # get the keys from 'data' for metatensors with pending operations. If 'keys' is set, select + # only data keys that are in 'keys' + active_keys = [k for k in data.keys() if keys is None or k in keys] + keys_to_update = [k for k in active_keys if isinstance(data[k], MetaTensor) and data[k].has_pending_operations] + + if len(keys_to_update) > 0: + rdata = dict(data) + + for k in keys_to_update: + overrides_ = None if overrides is None else overrides.get(k, None) + rdata[k], _ = apply_pending(data[k], overrides=overrides_) + _log_applied_info(rdata[k], key=k, logger_name=logger_name) + + return rdata + else: + if isinstance(data, MetaTensor) and data.has_pending_operations: + rdata, _ = apply_pending(data, overrides=overrides) + _log_applied_info(rdata, logger_name=logger_name) + return rdata + + return data + + +def apply_pending_transforms_in_order( + transform, data, lazy: bool | None = None, overrides: dict | None = None, logger_name: str | None = None +): + """ + This method causes "in order" processing of pending transforms to occur. + + In order processing of pending transforms ensures that pending transforms are applied before any transform + that is not executing lazily (or that needs to inspect the data) operates on a tensor, and whenever + an `ApplyPending`_ or `ApplyPendingd`_ transform is encountered in the pipeline. + + This method is designed to be used only in the context of implementing lazy resampling functionality. In general + you should not need to interact with or use this method directly. + Args: + transform: a transform that should be evaluated to determine whether pending transforms should be applied + data: a tensor / MetaTensor, or dictionary containing tensors / MetaTensors whose pending transforms may + need to be applied + lazy: The lazy mode that is being applied (this can be False, True or None) + overrides: An optional dictionary containing overrides to be applied to the pending transforms when they + are lazily executed. If data is a dict, it should contain a dictionary of overrides for each key that + needs them + logger_name: An optional name for a logger to be used when applying pending transforms. If None, + logging is suppressed.
+ Returns: + an object of the same type as data if pending transforms were applied, or 'data' if they were not + + """ + apply_pending = False + keys = None + if isinstance(transform, LazyTrait): + if transform.checks_data: + apply_pending = True + else: + apply_pending = not (transform.lazy if lazy is None else lazy) + elif isinstance(transform, ApplyPending): + apply_pending = True + elif isinstance(transform, ApplyPendingd): + apply_pending = True + keys = transform.keys + else: + apply_pending = True + + if apply_pending is True: + _log_pending_info(transform, data, "Apply pending transforms", lazy=lazy, logger_name=logger_name) + return apply_pending_transforms(data, keys, overrides, logger_name) + + _log_pending_info(transform, data, "Accumulate pending transforms", lazy=lazy, logger_name=logger_name) + return data + + +def apply_pending_transforms_out_of_order( + transform, data, lazy: bool | None = None, overrides: dict | None = None, logger_name: str | None = None +): + """ + This method causes "out of order" processing of pending transforms to occur. + + Out of order processing for lazy resampling only causes pending transforms to be processed when + an `ApplyPending`_ or `ApplyPendingd`_ transform is encountered in the pipeline. + + This method is designed to be used only in the context of implementing lazy resampling functionality. In general + you should not need to interact with or use this method directly. + Args: + transform: a transform that should be evaluated to determine whether pending transforms should be applied + data: a tensor / MetaTensor, or dictionary containing tensors / MetaTensors whose pending transforms may + need to be applied + lazy: The lazy mode that is being applied (this can be False, True or None) + overrides: An optional dictionary containing overrides to be applied to the pending transforms when they + are lazily executed. If data is a dict, it should contain a dictionary of overrides for each key that + needs them + logger_name: An optional name for a logger to be used when applying pending transforms. If None, + logging is suppressed. 
+ Returns: + an object of the same type as data if pending transforms were applied, or 'data' if they were not + + """ + apply_pending = False + keys = None + if lazy is False: + apply_pending = True + elif isinstance(transform, ApplyPending): + apply_pending = True + elif isinstance(transform, ApplyPendingd): + apply_pending = True + keys = transform.keys + + if apply_pending is True: + _log_pending_info(transform, data, "Apply pending transforms", lazy=lazy, logger_name=logger_name) + return apply_pending_transforms(data, keys, overrides, logger_name) + + _log_pending_info(transform, data, "Accumulate pending transforms", lazy=lazy, logger_name=logger_name) + return data diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 22c74cef8a..5324fa7058 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -15,8 +15,13 @@ import torch +# from monai.apps.utils import get_logger from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd + +# from monai.transforms.lazy.array import ApplyPending +# from monai.transforms.lazy.dictionary import ApplyPendingd +# from monai.transforms.traits import LazyTrait from monai.transforms.lazy.utils import ( affine_from_pending, combine_transforms, @@ -26,17 +31,35 @@ ) from monai.utils import LazyAttr, look_up_option -__all__ = ["apply_transforms"] +__all__ = ["apply_pending"] __override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"} -def apply_transforms( - data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None, **kwargs: Any -): +def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None): """ This method applies pending transforms to `data` tensors. - Currently, only 2d and 3d input are supported. + Currently, only 2d and 3d inputs are supported. + + This method is designed to be called by ``apply_pending_transforms`` and other methods / classes + that are part of the implementation of lazy resampling. In general, you should not need to call + this method unless you are directly developing custom lazy execution strategies. + + It works by calculating the overall effect of the accumulated pending transforms. When it runs + out of pending transforms or when it finds incompatibilities between the accumulated pending + transform and the next pending transform, it then applies the accumulated transform in a call to + ``resample``. + + Pending transforms are incompatible with each other if one or more of the arguments in the pending + transforms differ. These are parameters such as 'mode', 'padding_mode', 'dtype' and so forth. If + a pending transform doesn't have a given parameter, it is considered compatible with the + accumulated transform. If a subsequent transform has a parameter that is incompatible with + the accumulated transform (e.g. 'mode' of 'bilinear' vs. 'mode' of 'nearest'), an intermediate + resample will be performed and the accumulated transform reset to its starting state. + + After resampling, the pending transforms are pushed to the ``applied_operations`` field of the + resulting MetaTensor. Note that if a torch.Tensor is passed to this method along with a list of + pending transforms, the resampled tensor will be wrapped in a MetaTensor before being returned. Args: data: A torch Tensor or a monai MetaTensor. @@ -63,10 +86,11 @@ def apply_transforms( - device: device for resampling computation.
Defaults to ``None``. - resample_mode: the mode of resampling, currently support ``"auto"``. Setting to other values will use the :py:class:`monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad). + logger_name: A logger name that is used to log output generated while applying pending transforms. You can + suppress logging by setting this to None (default). """ overrides = (overrides or {}).copy() - overrides.update((kwargs or {}).copy()) for k in overrides: look_up_option(k, __override_keywords) # check existence of the key @@ -103,9 +127,11 @@ def apply_transforms( _cur_kwargs = cur_kwargs.copy() _cur_kwargs.update(override_kwargs) data = resample(data.to(device), cumulative_xform, _cur_kwargs) + next_matrix = affine_from_pending(p) if next_matrix.shape[0] == 3: next_matrix = to_affine_nd(3, next_matrix) + cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) cur_kwargs.update(override_kwargs) diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index fa1bb6d48e..d9c8404cdb 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -223,7 +223,9 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = img.affine = call_kwargs["dst_affine"] return img + # TODO: lazy evaluation - no need to separately set lazy to False + # resampler = monai.transforms.SpatialResample(lazy=False, **init_kwargs) resampler = monai.transforms.SpatialResample(**init_kwargs) - resampler.lazy_evaluation = False # resampler is a lazytransform + resampler.lazy = False # resampler is a lazytransform with resampler.trace_transform(False): # don't track this transform in `img` return resampler(img=img, **call_kwargs) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6fe433a0bc..eb7f273e4c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -134,6 +134,7 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, + lazy: bool = False, ): """ Args: @@ -152,7 +153,10 @@ def __init__( dtype: data type for resampling computation. Defaults to ``float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.mode = mode self.padding_mode = padding_mode self.align_corners = align_corners @@ -167,6 +171,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -198,7 +203,9 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``. 
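# Illustrative sketch of driving the renamed ``apply_pending`` entry point directly, e.g. from a
# custom lazy-execution strategy. The tensor shape, the chosen transforms and the override values
# are assumptions for illustration; it is also assumed that a MetaTensor's own pending list is
# picked up when ``pending`` is not supplied.
import torch

from monai.data import MetaTensor
from monai.transforms import Flip, Rotate90
from monai.transforms.lazy.functional import apply_pending

img = MetaTensor(torch.rand(1, 32, 32, 16))
img = Flip(spatial_axis=0, lazy=True)(img)   # queues a pending operation instead of flipping now
img = Rotate90(k=1, lazy=True)(img)          # queues a second pending operation
print(len(img.pending_operations))           # 2 pending operations, nothing resampled yet

# Resolve everything in one resample; only keys from ``__override_keywords`` may be overridden.
result = apply_pending(img, overrides={"mode": "bilinear", "padding_mode": "border"})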
When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, @@ -210,8 +217,17 @@ def __call__( align_corners = align_corners if align_corners is not None else self.align_corners mode = mode if mode is not None else self.mode padding_mode = padding_mode if padding_mode is not None else self.padding_mode + lazy_ = self.lazy if lazy is None else lazy return spatial_resample( - img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, self.get_transform_info() + img, + dst_affine, + spatial_size, + mode, + padding_mode, + align_corners, + dtype_pt, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -244,6 +260,7 @@ def __call__( # type: ignore padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -267,6 +284,10 @@ def __call__( # type: ignore dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + Raises: ValueError: When the affine matrix of the source image is not invertible. Returns: @@ -275,6 +296,7 @@ def __call__( # type: ignore if img_dst is None: raise RuntimeError("`img_dst` is missing.") dst_affine = img_dst.peek_pending_affine() if isinstance(img_dst, MetaTensor) else torch.eye(4) + lazy_ = self.lazy if lazy is None else lazy img = super().__call__( img=img, dst_affine=dst_affine, @@ -283,8 +305,9 @@ def __call__( # type: ignore padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) - if not self.lazy_evaluation: + if not lazy_: if isinstance(img, MetaTensor): img.affine = dst_affine if isinstance(img_dst, MetaTensor): @@ -321,6 +344,7 @@ def __init__( recompute_affine: bool = False, min_pixdim: Sequence[float] | float | np.ndarray | None = None, max_pixdim: Sequence[float] | float | np.ndarray | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -373,8 +397,10 @@ def __init__( max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the value of `pixdim`. Default to `None`. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64) self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64) @@ -387,13 +413,13 @@ def __init__( raise ValueError(f"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.") self.sp_resample = SpatialResample( - mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype + mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.sp_resample.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.sp_resample.lazy = val @deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.") def __call__( @@ -406,6 +432,7 @@ def __call__( dtype: DtypeLike = None, scale_extent: bool | None = None, output_spatial_shape: Sequence[int] | np.ndarray | int | None = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -435,6 +462,9 @@ def __call__( output_spatial_shape: specify the shape of the output data_array. This is typically useful for the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization error with the affine. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``data_array`` has no spatial dimensions. @@ -485,6 +515,7 @@ def __call__( new_affine[:sr, -1] = offset[:sr] actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape + lazy_ = self.lazy if lazy is None else lazy data_array = self.sp_resample( data_array, dst_affine=torch.as_tensor(new_affine), @@ -493,9 +524,10 @@ def __call__( padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) if self.recompute_affine and isinstance(data_array, MetaTensor): - if self.lazy_evaluation: + if lazy_: raise NotImplementedError("recompute_affine is not supported with lazy evaluation.") a = scale_affine(original_spatial_shape, actual_shape) data_array.affine = convert_to_dst_type(a, affine_)[0] # type: ignore @@ -517,6 +549,7 @@ def __init__( axcodes: str | None = None, as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), + lazy: bool = False, ) -> None: """ Args: @@ -529,6 +562,8 @@ def __init__( labels: optional, None or sequence of (2,) sequences (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False Raises: ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values. @@ -536,6 +571,7 @@ def __init__( See Also: `nibabel.orientations.ornt2axcodes`. 
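# Illustrative sketch of the ``lazy`` constructor flag added to ``Spacing`` and ``Orientation``:
# both calls only queue pending operations, and their combined effect is visible through
# ``peek_pending_affine``. The image shape, pixdim and axcodes values are assumptions.
import torch

from monai.data import MetaTensor
from monai.transforms import Orientation, Spacing

img = MetaTensor(torch.rand(1, 64, 64, 32), affine=torch.eye(4))
img = Spacing(pixdim=(2.0, 2.0, 2.0), lazy=True)(img)   # queued, not resampled yet
img = Orientation(axcodes="RAS", lazy=True)(img)         # queued on top of the spacing op
print(img.peek_pending_affine())                          # affine reflecting both pending ops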
""" + LazyTransform.__init__(self, lazy=lazy) if axcodes is None and not as_closest_canonical: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") if axcodes is not None and as_closest_canonical: @@ -544,13 +580,16 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels - def __call__(self, data_array: torch.Tensor) -> torch.Tensor: + def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. If input type is `torch.Tensor`, original affine is assumed to be identity. Args: data_array: in shape (num_channels, H[, W, ...]). + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``data_array`` has no spatial dimensions. @@ -595,7 +634,10 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) - return orientation(data_array, affine_np, spatial_ornt, self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return orientation( + data_array, affine_np, spatial_ornt, lazy=lazy_, transform_info=self.get_transform_info() + ) # type: ignore[no-any-return] def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -622,21 +664,28 @@ class Flip(InvertibleTransform, LazyTransform): If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] - def __init__(self, spatial_axis: Sequence[int] | int | None = None) -> None: + def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: + LazyTransform.__init__(self, lazy=lazy) self.spatial_axis = spatial_axis - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - return flip(img, self.spatial_axis, transform_info=self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: self.pop_transform(data) @@ -677,6 +726,8 @@ class Resize(InvertibleTransform, LazyTransform): anti-aliasing is performed prior to rescaling. dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = [TransformBackends.TORCH] @@ -690,7 +741,9 @@ def __init__( anti_aliasing: bool = False, anti_aliasing_sigma: Sequence[float] | float | None = None, dtype: DtypeLike | torch.dtype = torch.float32, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size self.mode = mode @@ -707,6 +760,7 @@ def __call__( anti_aliasing: bool | None = None, anti_aliasing_sigma: Sequence[float] | float | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -729,7 +783,9 @@ def __call__( anti-aliasing is performed prior to rescaling. dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. @@ -760,6 +816,7 @@ def __call__( _mode = self.mode if mode is None else mode _align_corners = self.align_corners if align_corners is None else align_corners _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + lazy_ = self.lazy if lazy is None else lazy return resize( # type: ignore img, sp_size, @@ -769,6 +826,7 @@ def __call__( input_ndim, anti_aliasing, anti_aliasing_sigma, + lazy_, self.get_transform_info(), ) @@ -814,6 +872,8 @@ class Rotate(InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] @@ -826,7 +886,9 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike | torch.dtype = torch.float32, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.angle = angle self.keep_size = keep_size self.mode: str = mode @@ -841,6 +903,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -858,6 +921,9 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. 
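# Illustrative sketch of the call-time ``lazy`` argument overriding the value set at construction
# (shown here for ``Rotate``; the same pattern applies to the other transforms in this file). The
# angle and image shape are assumptions.
import torch

from monai.data import MetaTensor
from monai.transforms import Rotate

rot = Rotate(angle=0.5, keep_size=True, lazy=False)  # eager by default
img = MetaTensor(torch.rand(1, 48, 48))

eager_out = rot(img)             # resampled immediately
lazy_out = rot(img, lazy=True)   # per-call override: only a pending operation is recorded
print(len(lazy_out.pending_operations))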
@@ -870,8 +936,17 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] output_shape = im_shape if self.keep_size else None + lazy_ = self.lazy if lazy is None else lazy return rotate( # type: ignore - img, self.angle, output_shape, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() + img, + self.angle, + output_shape, + _mode, + _padding_mode, + _align_corners, + _dtype, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -934,9 +1009,10 @@ class Zoom(InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. keep_size: Should keep original size (padding/slicing if needed), default is True. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. - """ backend = [TransformBackends.TORCH] @@ -949,8 +1025,10 @@ def __init__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, + lazy: bool = False, **kwargs, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.zoom = zoom self.mode = mode self.padding_mode = padding_mode @@ -966,6 +1044,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -986,7 +1065,9 @@ def __call__( See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim @@ -994,8 +1075,17 @@ def __call__( _padding_mode = padding_mode or self.padding_mode _align_corners = self.align_corners if align_corners is None else align_corners _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + lazy_ = self.lazy if lazy is None else lazy return zoom( # type: ignore - img, _zoom, self.keep_size, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() + img, + _zoom, + self.keep_size, + _mode, + _padding_mode, + _align_corners, + _dtype, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1034,28 +1124,35 @@ class Rotate90(InvertibleTransform, LazyTransform): backend = [TransformBackends.TORCH] - def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1)) -> None: + def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False) -> None: """ Args: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.") self.spatial_axes = spatial_axes_ - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axes) - return rotate90(img, axes, self.k, self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return rotate90(img, axes, self.k, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -1078,7 +1175,9 @@ class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): backend = Rotate90.backend - def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1)) -> None: + def __init__( + self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False + ) -> None: """ Args: prob: probability of rotating. @@ -1086,8 +1185,11 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, i max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`, (Default 3). spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.max_k = max_k self.spatial_axes = spatial_axes @@ -1099,23 +1201,27 @@ def randomize(self, data: Any | None = None) -> None: return None self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
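# Illustrative sketch of a randomizable transform constructed with the new ``lazy`` flag: when the
# random draw fires, the inner ``Rotate90`` executes lazily and the pending operation is still
# tracked so that ``inverse`` remains possible. ``prob=1.0`` and the shape are assumptions.
import torch

from monai.data import MetaTensor
from monai.transforms import RandRotate90

rand_rot = RandRotate90(prob=1.0, max_k=3, lazy=True)
out = rand_rot(MetaTensor(torch.rand(1, 32, 32)))
print(len(out.pending_operations))   # the rotation is queued rather than applied eagerly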
""" + if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: - xform = Rotate90(self._rand_k, self.spatial_axes) - xform.lazy_evaluation = self.lazy_evaluation + xform = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) out = xform(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1152,6 +1258,8 @@ class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Rotate.backend @@ -1167,8 +1275,10 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike | torch.dtype = np.float32, + lazy: bool = False, ) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -1205,6 +1315,7 @@ def __call__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, randomize: bool = True, + lazy: bool | None = None, ): """ Args: @@ -1221,10 +1332,14 @@ def __call__( If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: ndim = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) rotator = Rotate( @@ -1234,12 +1349,12 @@ def __call__( padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, + lazy=lazy_, ) - rotator.lazy_evaluation = self.lazy_evaluation out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1258,30 +1373,37 @@ class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): Args: prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None) -> None: + def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: RandomizableTransform.__init__(self, prob) - self.flipper = Flip(spatial_axis=spatial_axis) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize(None) - out = self.flipper(img) if self._do_transform else img + lazy_ = self.lazy if lazy is None else lazy + out = self.flipper(img, lazy=lazy_) if self._do_transform else img out = convert_to_tensor(out, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1300,20 +1422,22 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): Args: prob: Probability of flipping. - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1) -> None: + def __init__(self, prob: float = 0.1, lazy: bool = False) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self._axis: int | None = None self.flipper = Flip(spatial_axis=self._axis) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) @@ -1321,21 +1445,25 @@ def randomize(self, data: NdarrayOrTensor) -> None: return None self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
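# Illustrative sketch of the ``lazy`` property setter pattern used here: assigning to ``lazy`` on
# the random transform also updates the wrapped ``Flip``, so the flag can be toggled after
# construction. ``prob=1.0`` and the shape are assumptions.
import torch

from monai.data import MetaTensor
from monai.transforms import RandFlip

rf = RandFlip(prob=1.0, spatial_axis=0)   # eager by default
rf.lazy = True                            # setter also updates rf.flipper.lazy
out = rf(MetaTensor(torch.rand(1, 16, 16)))
print(len(out.pending_operations))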
""" if randomize: self.randomize(data=img) + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: self.flipper.spatial_axis = self._axis - out = self.flipper(img) + out = self.flipper(img, lazy=lazy_) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1379,6 +1507,8 @@ class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1396,9 +1526,11 @@ def __init__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, + lazy: bool = False, **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): @@ -1434,6 +1566,7 @@ def __call__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, randomize: bool = True, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -1454,12 +1587,15 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. randomize: whether to execute `randomize()` function first, default to True. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ # match the spatial image dim if randomize: self.randomize(img=img) + lazy_ = self.lazy if lazy is None else lazy if not self._do_transform: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) else: @@ -1470,11 +1606,11 @@ def __call__( padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype, + lazy=lazy_, **self.kwargs, ) - xform.lazy_evaluation = self.lazy_evaluation out = xform(img) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1513,7 +1649,8 @@ class AffineGrid(LazyTransform): affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = [TransformBackends.TORCH] @@ -1528,7 +1665,9 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, affine: NdarrayOrTensor | None = None, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params @@ -1540,7 +1679,7 @@ def __init__( self.affine = affine def __call__( - self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None + self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None, lazy: bool | None = None ) -> tuple[torch.Tensor | None, torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. @@ -1550,12 +1689,15 @@ def __call__( Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if not self.lazy_evaluation: + lazy_ = self.lazy if lazy is None else lazy + if not lazy_: if grid is None: # create grid from spatial_size if spatial_size is None: raise ValueError("Incompatible values: grid=None and spatial_size=None.") @@ -1584,7 +1726,7 @@ def __call__( else: affine = self.affine # type: ignore affine = to_affine_nd(spatial_dims, affine) - if self.lazy_evaluation: + if lazy_: return None, affine affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore @@ -1615,6 +1757,7 @@ def __init__( scale_range: RandRange = None, device: torch.device | None = None, dtype: DtypeLike = np.float32, + lazy: bool = False, ) -> None: """ Args: @@ -1643,6 +1786,8 @@ def __init__( device: device to store the output grid data. dtype: data type for the grid computation. Defaults to ``np.float32``. If ``None``, use the data type of input data (if `grid` is provided). + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1651,6 +1796,7 @@ def __init__( - :py:meth:`monai.transforms.utils.create_scale` """ + LazyTransform.__init__(self, lazy=lazy) self.rotate_range = ensure_tuple(rotate_range) self.shear_range = ensure_tuple(shear_range) self.translate_range = ensure_tuple(translate_range) @@ -1683,19 +1829,27 @@ def randomize(self, data: Any | None = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, spatial_size: Sequence[int] | None = None, grid: NdarrayOrTensor | None = None, randomize: bool = True + self, + spatial_size: Sequence[int] | None = None, + grid: NdarrayOrTensor | None = None, + randomize: bool = True, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. randomize: boolean as to whether the grid parameters governing the grid should be randomized. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. 
""" if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy affine_grid = AffineGrid( rotate_params=self.rotate_params, shear_params=self.shear_params, @@ -1703,9 +1857,9 @@ def __call__( scale_params=self.scale_params, device=self.device, dtype=self.dtype, + lazy=lazy_, ) - affine_grid.lazy_evaluation = self.lazy_evaluation - if self.lazy_evaluation: # return the affine only, don't construct the grid + if lazy_: # return the affine only, don't construct the grid self.affine = affine_grid(spatial_size, grid)[1] # type: ignore return None # type: ignore _grid: torch.Tensor @@ -1943,6 +2097,7 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, image_only: bool = False, + lazy: bool = False, ) -> None: """ The affine transformations are applied in rotate, shear, translate, scale order. @@ -1997,8 +2152,10 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html image_only: if True return only the image volume, otherwise return (image, affine). - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.affine_grid = AffineGrid( rotate_params=rotate_params, shear_params=shear_params, @@ -2008,6 +2165,7 @@ def __init__( dtype=dtype, align_corners=align_corners, device=device, + lazy=lazy, ) self.image_only = image_only self.norm_coord = not normalized @@ -2016,10 +2174,10 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self.affine_grid.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self.affine_grid.lazy = val + self._lazy = val def __call__( self, @@ -2027,6 +2185,7 @@ def __call__( spatial_size: Sequence[int] | int | None = None, mode: str | int | None = None, padding_mode: str | None = None, + lazy: bool | None = None, ) -> torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor]: """ Args: @@ -2048,13 +2207,17 @@ def __call__( When `mode` is an integer, using numpy/cupy backends, this argument accepts {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
""" img = convert_to_tensor(img, track_meta=get_track_meta()) img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) + lazy_ = self.lazy if lazy is None else lazy _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode - grid, affine = self.affine_grid(spatial_size=sp_size) + grid, affine = self.affine_grid(spatial_size=sp_size, lazy=lazy_) return affine_func( # type: ignore img, @@ -2066,7 +2229,8 @@ def __call__( _padding_mode, True, self.image_only, - self.get_transform_info(), + lazy=lazy_, + transform_info=self.get_transform_info(), ) @classmethod @@ -2125,6 +2289,7 @@ def __init__( padding_mode: str = GridSamplePadMode.REFLECTION, cache_grid: bool = False, device: torch.device | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -2174,6 +2339,8 @@ def __init__( If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. device: device on which the tensor will be allocated. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. @@ -2181,32 +2348,33 @@ def __init__( """ RandomizableTransform.__init__(self, prob) - + LazyTransform.__init__(self, lazy=lazy) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, device=device, + lazy=lazy, ) self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid - self._cached_grid = self._init_identity_cache() + self._cached_grid = self._init_identity_cache(lazy) self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rand_affine_grid.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rand_affine_grid.lazy = val - def _init_identity_cache(self): + def _init_identity_cache(self, lazy: bool): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. """ - if self.lazy_evaluation: + if lazy: return None if self.spatial_size is None: if self.cache_grid: @@ -2226,14 +2394,14 @@ def _init_identity_cache(self): return None return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") - def get_identity_grid(self, spatial_size: Sequence[int]): + def get_identity_grid(self, spatial_size: Sequence[int], lazy: bool): """ Return a cached or new identity grid depends on the availability. Args: spatial_size: non-dynamic spatial size """ - if self.lazy_evaluation: + if lazy: return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( @@ -2265,6 +2433,7 @@ def __call__( padding_mode: str | None = None, randomize: bool = True, grid=None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -2288,7 +2457,9 @@ def __call__( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html randomize: whether to execute `randomize()` function first, default to True. grid: precomputed grid to be used (mainly to accelerate `RandAffined`). 
- + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize() @@ -2299,17 +2470,18 @@ def __call__( do_resampling = self._do_transform or (sp_size != ensure_tuple(ori_size)) _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode + lazy_ = self.lazy if lazy is None else lazy img = convert_to_tensor(img, track_meta=get_track_meta()) - if self.lazy_evaluation: + if lazy_: if self._do_transform: affine = self.rand_affine_grid.get_transformation_matrix() else: affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] else: if grid is None: - grid = self.get_identity_grid(sp_size) + grid = self.get_identity_grid(sp_size, lazy_) if self._do_transform: - grid = self.rand_affine_grid(grid=grid, randomize=randomize) + grid = self.rand_affine_grid(grid=grid, randomize=randomize, lazy=lazy_) affine = self.rand_affine_grid.get_transformation_matrix() return affine_func( # type: ignore img, @@ -2321,7 +2493,8 @@ def __call__( _padding_mode, do_resampling, True, - self.get_transform_info(), + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -2436,6 +2609,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, + lazy=False, ) self.resampler = Resample(device=device) @@ -2603,6 +2777,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, + lazy=False, ) self.resampler = Resample(device=device) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2f34f57ca2..b0fd53d3cd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -169,6 +169,7 @@ def __init__( dtype: Sequence[DtypeLike] | DtypeLike = np.float64, dst_keys: KeysCollection | None = "dst_affine", allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -196,21 +197,37 @@ def __init__( It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) - self.sp_transform = SpatialResample() + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.sp_transform = SpatialResample(lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.sp_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.sp_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d: dict = dict(data) for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.dst_keys @@ -223,6 +240,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) return d @@ -247,6 +265,7 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike] | DtypeLike = np.float64, allow_missing_keys: bool = False, + lazy: bool = False, ): """ Args: @@ -274,21 +293,37 @@ def __init__( the output data type is always ``float32``. It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.key_dst = key_dst self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.resampler = ResampleToMatch() + self.resampler = ResampleToMatch(lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.resampler.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.resampler.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -300,6 +335,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) return d @@ -341,6 +377,7 @@ def __init__( max_pixdim: Sequence[float] | float | None = None, ensure_same_shape: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -400,11 +437,18 @@ def __init__( ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim, whether to ensure exactly the same output spatial shape. Default to True. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.spacing_transform = Spacing( - pixdim, diagonal=diagonal, recompute_affine=recompute_affine, min_pixdim=min_pixdim, max_pixdim=max_pixdim + pixdim, + diagonal=diagonal, + recompute_affine=recompute_affine, + min_pixdim=min_pixdim, + max_pixdim=max_pixdim, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -413,16 +457,29 @@ def __init__( self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) self.ensure_same_shape = ensure_same_shape - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.spacing_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.spacing_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d: dict = dict(data) _init_shape, _pixdim, should_match = None, None, False output_shape_k = None # tracking output shape + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent @@ -442,6 +499,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc dtype=dtype, scale_extent=scale_extent, output_spatial_shape=output_shape_k if should_match else None, + lazy=lazy_, ) output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] return d @@ -471,6 +529,7 @@ def __init__( as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -484,23 +543,41 @@ def __init__( (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See Also: `nibabel.orientations.ornt2axcodes`. 
""" - super().__init__(keys, allow_missing_keys) - self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.ornt_transform = Orientation( + axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels, lazy=lazy + ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.ornt_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.ornt_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d: dict = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.ornt_transform(d[key]) + d[key] = self.ornt_transform(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -518,7 +595,12 @@ class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): backend = Rotate90.backend def __init__( - self, keys: KeysCollection, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False + self, + keys: KeysCollection, + k: int = 1, + spatial_axes: tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -526,19 +608,35 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ - super().__init__(keys, allow_missing_keys) - self.rotator = Rotate90(k, spatial_axes) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.rotator = Rotate90(k, spatial_axes, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rotator.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rotator.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.rotator(d[key]) + d[key] = self.rotator(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -564,6 +662,7 @@ def __init__( max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -576,9 +675,12 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.max_k = max_k self.spatial_axes = spatial_axes @@ -589,17 +691,31 @@ def randomize(self, data: Any | None = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> Mapping[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ self.randomize() d = dict(data) # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests - rotator = Rotate90(self._rand_k, self.spatial_axes) - rotator.lazy_evaluation = self.lazy_evaluation + lazy_ = self.lazy if lazy is None else lazy + rotator = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -649,6 +765,8 @@ class Resized(MapTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Resize.backend @@ -664,22 +782,37 @@ def __init__( anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys)) self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys)) - self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) + self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.resizer.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.resizer.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma, dtype in self.key_iterator( d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype ): @@ -690,6 +823,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc anti_aliasing=anti_aliasing, anti_aliasing_sigma=anti_aliasing_sigma, dtype=dtype, + lazy=lazy_, ) return d @@ -722,6 +856,7 @@ def __init__( dtype: DtypeLike | torch.dtype = np.float32, align_corners: bool = False, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -772,6 +907,8 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False See also: - :py:class:`monai.transforms.compose.MapTransform` @@ -779,6 +916,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.affine = Affine( rotate_params=rotate_params, shear_params=shear_params, @@ -789,19 +927,33 @@ def __init__( device=device, dtype=dtype, # type: ignore align_corners=align_corners, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.affine.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.affine.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode) + d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -832,6 +984,7 @@ def __init__( cache_grid: bool = False, device: torch.device | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -885,6 +1038,8 @@ def __init__( accelerate the transform. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False See also: - :py:class:`monai.transforms.compose.MapTransform` @@ -893,6 +1048,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.rand_affine = RandAffine( prob=1.0, # because probability handled in this class rotate_range=rotate_range, @@ -902,21 +1058,36 @@ def __init__( spatial_size=spatial_size, cache_grid=cache_grid, device=device, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rand_affine.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rand_affine.lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffined: self.rand_affine.set_random_state(seed, state) super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor], lazy: bool | None = None + ) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -929,6 +1100,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N item = d[first_key] spatial_size = item.peek_pending_shape() if isinstance(item, MetaTensor) else item.shape[1:] + lazy_ = self.lazy if lazy is None else lazy sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform @@ -936,18 +1108,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid - grid = self.rand_affine.get_identity_grid(sp_size) + grid = self.rand_affine.get_identity_grid(sp_size, lazy=lazy_) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy=lazy_) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform if do_resampling: - d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid) # type: ignore + d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid, lazy=lazy_) # type: ignore else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: @@ -1066,6 +1238,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) @@ -1208,6 +1389,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) @@ -1250,25 +1440,45 @@ class Flipd(MapTransform, InvertibleTransform, LazyTransform): keys: Keys to pick data for transformation. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Flip.backend def __init__( - self, keys: KeysCollection, spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False + self, + keys: KeysCollection, + spatial_axis: Sequence[int] | int | None = None, + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.flipper = Flip(spatial_axis=spatial_axis) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1290,6 +1500,8 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Flip.backend @@ -1300,30 +1512,45 @@ def __init__( prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.flipper = Flip(spatial_axis=spatial_axis) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipd: super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1348,27 +1575,43 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, La keys: Keys to pick data for transformation. prob: Probability of flipping. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = RandAxisFlip.backend - def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False + ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.flipper = RandAxisFlip(prob=1.0) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = RandAxisFlip(prob=1.0, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAxisFlipd: super().set_random_state(seed, state) self.flipper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -1379,12 +1622,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc # all the keys share the same random selected axis self.flipper.randomize(d[first_key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False) + d[key] = self.flipper(d[key], randomize=False, lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1423,6 +1667,8 @@ class Rotated(MapTransform, InvertibleTransform, LazyTransform): the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Rotate.backend @@ -1437,27 +1683,42 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) - self.rotator = Rotate(angle=angle, keep_size=keep_size) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.rotator = Rotate(angle=angle, keep_size=keep_size, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rotator.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rotator.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): d[key] = self.rotator( - d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ ) return d @@ -1501,6 +1762,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = RandRotate.backend @@ -1518,31 +1781,49 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_rotate = RandRotate(range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size) + LazyTransform.__init__(self, lazy=lazy) + self.rand_rotate = RandRotate( + range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy + ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rand_rotate.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rand_rotate.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRotated: super().set_random_state(seed, state) self.rand_rotate.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) # all the keys share the same random rotate angle self.rand_rotate.randomize() + lazy_ = self.lazy if lazy is None else lazy + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): @@ -1554,10 +1835,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc align_corners=align_corners, dtype=dtype, randomize=False, + lazy=lazy_, ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1598,6 +1880,8 @@ class Zoomd(MapTransform, InvertibleTransform, LazyTransform): If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1615,26 +1899,44 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, lazy=lazy, **kwargs) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.zoomer.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.zoomer.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + d[key] = self.zoomer( + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ + ) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1680,9 +1982,10 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - """ backend = RandZoom.backend @@ -1699,27 +2002,43 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, **kwargs) + LazyTransform.__init__(self, lazy=lazy) + self.rand_zoom = RandZoom( + prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs + ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rand_zoom.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rand_zoom.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomd: super().set_random_state(seed, state) self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -1730,6 +2049,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc # all the keys share the same random zoom factor self.rand_zoom.randomize(d[first_key]) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1742,10 +2062,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc align_corners=align_corners, dtype=dtype, randomize=False, + lazy=lazy_, ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1798,7 +2119,6 @@ def __init__( It also can be a sequence, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. - """ super().__init__(keys, allow_missing_keys) self.grid_distortion = GridDistortion(num_cells=num_cells, distort_steps=distort_steps, device=device) @@ -1806,6 +2126,15 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) @@ -1872,6 +2201,15 @@ def set_random_state( return self def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) if not self._do_transform: @@ -1922,6 +2260,15 @@ def __init__( self.splitter = GridSplit(grid=grid) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> list[dict[Hashable, NdarrayOrTensor]]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) n_outputs = np.prod(self.grid) output: list[dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] @@ -2003,6 +2350,15 @@ def __init__( ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) for key in self.key_iterator(d): d[key] = self.patcher(d[key]) @@ -2091,6 +2447,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) # All the keys share the same random noise for key in self.key_iterator(d): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 591ebbb489..9d77a83389 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -66,12 +66,12 @@ def _maybe_new_metatensor(img, dtype=None, device=None): def spatial_resample( - img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, transform_info + img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info ) -> torch.Tensor: """ Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be resampled, assuming `img` is channel-first. @@ -92,6 +92,7 @@ def spatial_resample( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype_pt: data `dtype` for resampling computation. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. 
""" original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -132,17 +133,16 @@ def spatial_resample( affine_unchanged = ( allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size) ) or (allclose(xform, np.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size)) - lazy_evaluation = transform_info.get(TraceKeys.LAZY_EVALUATION, False) meta_info = TraceableTransform.track_transform_meta( img, sp_size=spatial_size, - affine=None if affine_unchanged and not lazy_evaluation else xform, + affine=None if affine_unchanged and not lazy else xform, extra_info=extra_info, orig_size=original_spatial_shape, transform_info=transform_info, - lazy_evaluation=lazy_evaluation, + lazy=lazy, ) - if lazy_evaluation: + if lazy: out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if affine_unchanged: @@ -184,17 +184,18 @@ def spatial_resample( return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def orientation(img, original_affine, spatial_ornt, transform_info): +def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> torch.Tensor: """ Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. original_affine: original affine of the input image. spatial_ornt: orientations of the spatial axes, see also https://nipy.org/nibabel/reference/nibabel.orientations.html + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -217,23 +218,23 @@ def orientation(img, original_affine, spatial_ornt, transform_info): extra_info=extra_info, orig_size=spatial_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if axes: out = torch.flip(out, dims=axes) if not np.all(full_transpose == np.arange(len(out.shape))): out = out.permute(full_transpose.tolist()) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def flip(img, sp_axes, transform_info): +def flip(img, sp_axes, lazy, transform_info): """ Functional implementation of flip. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -242,6 +243,7 @@ def flip(img, sp_axes, transform_info): If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. 
+ lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -255,25 +257,22 @@ def flip(img, sp_axes, transform_info): sp = axis - 1 xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=sp_size, - affine=xform, - extra_info=extra_info, - transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info out = torch.flip(out, axes) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, transform_info): +def resize( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +): """ Functional implementation of resize. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -291,6 +290,7 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` anti_aliasing_sigma: {float, tuple of floats}, optional Standard deviation for Gaussian filtering used when anti-aliasing. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ img = convert_to_tensor(img, track_meta=get_track_meta()) @@ -308,10 +308,10 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, extra_info=extra_info, orig_size=orig_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - if anti_aliasing and transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: + if anti_aliasing and lazy: warnings.warn("anti-aliasing is not compatible with lazy evaluation.") out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -340,11 +340,11 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, transform_info): +def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of rotate. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -360,6 +360,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t dtype: data type for resampling computation. If None, use the data type of input data. 
To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -393,10 +394,10 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t extra_info=extra_info, orig_size=im_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info _, _m, _p, _ = resolves_modes(mode, padding_mode) xform = AffineTransform( @@ -410,11 +411,11 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, transform_info): +def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of zoom. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -432,6 +433,7 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, dtype: data type for resampling computation. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -447,9 +449,9 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, } if keep_size: do_pad_crop = not np.allclose(output_size, im_shape) - if do_pad_crop and transform_info.get(TraceKeys.LAZY_EVALUATION, False): # update for lazy evaluation + if do_pad_crop and lazy: # update for lazy evaluation _pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode) - _pad_crop.lazy_evaluation = True + _pad_crop.lazy = True _tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1)) _tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform}) lazy_cropped = _pad_crop(_tmp_img) @@ -465,10 +467,10 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, extra_info=extra_info, orig_size=im_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info img_t = out.to(dtype) _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1) @@ -493,17 +495,18 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, return out -def rotate90(img, axes, k, transform_info): +def rotate90(img, axes, k, lazy, transform_info): """ Functional implementation of rotate90. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). 
Args: img: data to be changed, assuming `img` is channel-first. axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. If axis is negative it counts from the last to the first axis. k: number of times to rotate by 90 degrees. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ extra_info = {"axes": [d - 1 for d in axes], "k": k} @@ -533,20 +536,22 @@ def rotate90(img, axes, k, transform_info): extra_info=extra_info, orig_size=ori_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info out = torch.rot90(out, k, axes) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, transform_info): +def affine_func( + img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy, transform_info +): """ Functional implementation of affine. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -570,6 +575,7 @@ def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_re do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but skipping the actual (potentially heavy) resampling operation. image_only: if True return only the image volume, otherwise return (image, affine). + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -592,9 +598,9 @@ def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_re extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: out = _maybe_new_metatensor(img) out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info return out if image_only else (out, affine) diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py index 0193065562..a47fc87e79 100644 --- a/monai/transforms/traits.py +++ b/monai/transforms/traits.py @@ -14,7 +14,9 @@ from __future__ import annotations -__all__ = ["LazyTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] +__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] + +from typing import Any class LazyTrait: @@ -27,23 +29,44 @@ class LazyTrait: """ @property - def lazy_evaluation(self): + def lazy(self): """ - Get whether lazy_evaluation is enabled for this transform instance. + Get whether lazy evaluation is enabled for this transform instance. Returns: True if the transform is operating in a lazy fashion, False if not. """ raise NotImplementedError() - @lazy_evaluation.setter - def lazy_evaluation(self, enabled: bool): + @lazy.setter + def lazy(self, enabled: bool): """ - Set whether lazy_evaluation is enabled for this transform instance. 
+ Set whether lazy evaluation is enabled for this transform instance. Args: enabled: True if the transform should operate in a lazy fashion, False if not. """ raise NotImplementedError() + @property + def checks_data(self): + """ + Get whether the transform checks the sample pixel/voxel data on its inputs or not as part of its + operation. A transform that checks data requires that all of the pending operations on its input + transforms are up to date before it is executed, but it can still execute lazily by adding pending + operations to the input tensors. + Returns: + True if the transform checks data and False if it does not + """ + + +class InvertibleTrait: + """ + An interface to indicate that the transform can be inverted, i.e. undone by performing + the inverse of the operation performed during `__call__`. + """ + + def inverse(self, data: Any) -> Any: + raise NotImplementedError() + class RandomizableTrait: """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3e66431bbc..8b4795f7e6 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -25,6 +25,8 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor + +# from monai.transforms.lazy.executors import apply_pending_transforms_in_order, apply_pending_transforms_out_of_order from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends @@ -44,27 +46,66 @@ def _apply_transform( - transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False + transform: Callable[..., ReturnType], + data: Any, + unpack_parameters: bool = False, + lazy: bool | None = False, + lazy_strategy: str | None = "in_order", + overrides: dict | None = None, + logger_name: str | None = None, ) -> ReturnType: """ - Perform transformation `transform` with the provided parameters `parameters`. + Perform a transform 'transform' on 'data', according to the other parameters specified. + + If `data` is a tuple and `unpack_parameters` is True, each parameter of `data` is unpacked + as arguments to `transform`. Otherwise `data` is considered as single argument to `transform`. + + If 'lazy' is True, this method first checks whether it can execute this method lazily. If it + can't, it will ensure that all pending lazy transforms on 'data' are applied before applying + this 'transform' to it. If 'lazy' is True, and 'overrides' are provided, those overrides will + be applied to the pending operations on 'data'. See ``Compose`` for more details on lazy + resampling, which is an experimental feature for 1.2. - If `parameters` is a tuple and `unpack_items` is True, each parameter of `parameters` is unpacked - as arguments to `transform`. - Otherwise `parameters` is considered as single argument to `transform`. + Please note, this class is function is designed to be called by ``apply_transform``. + In general, you should not need to make specific use of it unless you are implementing + pipeline execution mechanisms. Args: transform: a callable to be used to transform `data`. - parameters: parameters for the `transform`. + data: the tensorlike or dictionary of tensorlikes to be executed on unpack_parameters: whether to unpack parameters for `transform`. Defaults to False. + lazy: whether to enable lazy evaluation for lazy transforms. If False, transforms will be + carried out on a transform by transform basis. 
If True, all lazy transforms will + be executed by accumulating changes and resampling as few times as possible. + lazy_strategy: this field controls how execution occurs when processing data lazily. Permitted + options are "in_order", "out_of_order". Please see `Compose`_ for more details of what these + options mean. In general, you should not need to change this from its default. + overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden + when executing a pipeline. These each parameter that is compatible with a given transform is then applied + to that transform before it is executed. Note that overrides are currently only applied when lazy + is True. If lazy is False they are ignored. + currently supported args are: + {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, + please see also :py:func:`monai.transforms.lazy.apply_pending` and ``Compose`` for more details. + logger_name: The name of the logger that should be used during transform execution. If None, logging is + suppressed. Returns: ReturnType: The return type of `transform`. """ - if isinstance(parameters, tuple) and unpack_parameters: - return transform(*parameters) + from monai.transforms.lazy.executors import apply_pending_transforms_in_order, apply_pending_transforms_out_of_order - return transform(parameters) + if lazy_strategy == "in_order": + data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + elif lazy_strategy == "out_of_order": + data = apply_pending_transforms_out_of_order(transform, data, lazy, overrides, logger_name) + else: + raise ValueError(f"'lazy_strategy' must be one of {('in_order', 'out_of_order')} but is '{lazy_strategy}") + + if isinstance(data, tuple) and unpack_parameters: + return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + + return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) def apply_transform( @@ -72,7 +113,10 @@ def apply_transform( data: Any, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, + lazy: bool | None = False, + lazy_strategy: str = "in_order", + overrides: dict | None = None, + logger_name: str | None = None, ) -> list[ReturnType] | ReturnType: """ Transform `data` with `transform`. @@ -87,9 +131,12 @@ def apply_transform( map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. unpack_items: whether to unpack parameters using `*`. Defaults to False. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. + lazy: whether to execute in lazy mode or not. See ``Compose`` for more information about lazy resampling. + lazy_strategy: this field controls how execution occurs when processing data lazily. Permitted + options are "in_order", "out_of_order". Please see `Compose`_ for more details of what these + options mean. In general, you should not need to change this from its default. + overrides: optional overrides to apply to transform parameters. This parameter is ignored unless transforms + are being executed lazily. Raises: Exception: When ``transform`` raises an exception. 
@@ -99,18 +146,21 @@ def apply_transform( """ try: if isinstance(data, (list, tuple)) and map_items: - return [_apply_transform(transform, item, unpack_items) for item in data] - return _apply_transform(transform, data, unpack_items) + return [ + _apply_transform(transform, item, unpack_items, lazy, lazy_strategy, overrides, logger_name) + for item in data + ] + return _apply_transform(transform, data, unpack_items, lazy, lazy_strategy, overrides, logger_name) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. if MONAIEnvVars.debug(): raise - if log_stats and not isinstance(transform, transforms.compose.Compose): + if logger_name is not None and not isinstance(transform, transforms.compose.Compose): # log the input data information of exact transform in the transform chain - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=logger_name) logger = logging.getLogger(datastats._logger_name) - logger.info(f"\n=== Transform input info -- {type(transform).__name__} ===") + logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") if isinstance(data, (list, tuple)): data = data[0] @@ -254,17 +304,26 @@ class LazyTransform(Transform, LazyTrait): dictionary transforms to simplify implementation of new lazy transforms. """ - _lazy_evaluation: bool = False + def __init__(self, lazy: bool | None = False): + if lazy is not None: + if not isinstance(lazy, bool): + raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") + self._lazy = lazy @property - def lazy_evaluation(self): - return self._lazy_evaluation + def lazy(self): + return self._lazy - @lazy_evaluation.setter - def lazy_evaluation(self, lazy_evaluation: bool): - if not isinstance(lazy_evaluation, bool): - raise TypeError(f"lazy_evaluation must be a bool but is of type {type(lazy_evaluation)}") - self._lazy_evaluation = lazy_evaluation + @lazy.setter + def lazy(self, lazy: bool | None): + if lazy is not None: + if not isinstance(lazy, bool): + raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") + self._lazy = lazy + + @property + def checks_data(self): + return False class RandomizableTransform(Randomizable, Transform): @@ -347,6 +406,7 @@ def __new__(cls, *args, **kwargs): return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + super().__init__() self.keys: tuple[Hashable, ...] 
= ensure_tuple(keys) self.allow_missing_keys = allow_missing_keys if not self.keys: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 53193f4cb6..0720eadcd7 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from functools import lru_cache, wraps from inspect import getmembers, isclass -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -28,7 +28,8 @@ from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij -from monai.transforms.compose import Compose, OneOf + +# from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform, apply_transform from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, @@ -52,6 +53,7 @@ PytorchPadMode, SplineMode, TraceKeys, + TraceStatusKeys, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -66,6 +68,9 @@ from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor +if TYPE_CHECKING: + from monai.transforms.compose import Compose + measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) morphology, has_morphology = optional_import("skimage.morphology") ndimage, _ = optional_import("scipy.ndimage") @@ -121,6 +126,7 @@ "sync_meta_info", "reset_ops_id", "resolves_modes", + "is_tensor_invertible", ] @@ -1343,6 +1349,8 @@ def allow_missing_keys_mode(transform: MapTransform | Compose | tuple[MapTransfo with allow_missing_keys_mode(t): _ = t(data) # OK! """ + from monai.transforms.compose import Compose + # If given a sequence of transforms, Compose them to get a single list if issequenceiterable(transform): transform = Compose(transform) @@ -1549,6 +1557,7 @@ def get_number_image_type_conversions(transform: Compose, test_data: Any, key: H test_data: data to be used to count the number of conversions key: if using dictionary transforms, this key will be used to check the number of conversions. """ + from monai.transforms.compose import OneOf def _get_data(obj, key): return obj if key is None else obj[key] @@ -1970,5 +1979,80 @@ def resolves_modes( return backend, _interp_mode, _padding_mode, _kwargs +def check_applied_operations(entry: list | dict, status_key: str, default_message: str = "No message provided"): + """ + Check the applied operations of a MetaTensor to determine whether they contain any status entries. + Args: + entry: a dictionary that may contain TraceKeys.STATUSES entries, or a list of such dictionaries + status_key: the status key to search for.
This must be an entry in `TraceStatusKeys`_ + default_message: The message to provide if no messages are provided for the given status key entry + + Returns: + A list of status messages matching the provided status key + + """ + if isinstance(entry, list): + results = list() + for sub_entry in entry: + results.extend(check_applied_operations(sub_entry, status_key, default_message)) + return results + else: + status_key_ = TraceStatusKeys(status_key) + if TraceKeys.STATUSES in entry: + if status_key_ in entry[TraceKeys.STATUSES]: + reason = entry[TraceKeys.STATUSES][status_key_] + if reason is None: + return [default_message] + return reason if isinstance(reason, list) else [reason] + return [] + + +def is_tensor_invertible(data: torch.Tensor): + """ + Checks whether a given tensor is invertible. The rules are as follows: + 1. If the tensor is not a MetaTensor, it is not invertible + 2. If the tensor is a MetaTensor but it has `TraceStatusKeys.PENDING_DURING_APPLY` in the `TraceKeys.STATUSES` of any + of its `applied_operations`, it is not invertible + 3. Otherwise, it is invertible + + This function also accepts: + * dictionaries of tensors + * lists or tuples of tensors + * lists or tuples of dictionaries of tensors + In any of the above scenarios, it iterates through the collections and executes itself recursively until it is + operating on tensors. + + Args: + data: a `torch.Tensor` or `MetaTensor` or collections of torch.Tensor or MetaTensor, as described above + + Returns: + A tuple. The first entry is `True` or `False`, indicating whether the data is invertible. The second entry is + `None` or a list of status messages that the user can use to help debug their pipelines. + + """ + invert_disabled_reasons = list() + if isinstance(data, (list, tuple)): + for d in data: + _, reasons = is_tensor_invertible(d) + if reasons is not None: + invert_disabled_reasons.extend(reasons) + elif isinstance(data, monai.data.MetaTensor): + for op in data.applied_operations: + invert_disabled_reasons.extend( + check_applied_operations( + op, TraceStatusKeys.PENDING_DURING_APPLY, "Pending operations while applying an operation" + ) + ) + elif isinstance(data, dict): + for d in data.values(): + _, reasons = is_tensor_invertible(d) + if reasons is not None: + invert_disabled_reasons.extend(reasons) + + if len(invert_disabled_reasons) > 0: + return False, invert_disabled_reasons + return True, None + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 834e4866d7..4a8e439f0a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -55,6 +55,7 @@ SplineMode, StrEnum, TraceKeys, + TraceStatusKeys, TransformBackends, UpsampleMode, Weight, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index a7ea9e29a8..25c747ed90 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -36,6 +36,7 @@ "SkipMode", "Method", "TraceKeys", + "TraceStatusKeys", "CommonKeys", "GanKeys", "PostFix", @@ -316,7 +317,14 @@ class TraceKeys(StrEnum): KEY_SUFFIX: str = "_transforms" NONE: str = "none" TRACING: str = "tracing" - LAZY_EVALUATION: str = "lazy_evaluation" + STATUSES: str = "statuses" + LAZY: str = "lazy" + + +class TraceStatusKeys(StrEnum): + """Enumerable status keys for the TraceKeys.STATUSES flag""" + + PENDING_DURING_APPLY = "pending_during_apply" class CommonKeys(StrEnum): diff --git a/tests/croppers.py b/tests/croppers.py index 156600f202..8c9b43bf0a 100644 --- a/tests/croppers.py +++ b/tests/croppers.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from
monai.transforms import Randomizable -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.transform import MapTransform from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -117,18 +117,19 @@ def crop_test_pending_ops(self, input_param, input_shape, align_corners=False): expected = result_non_lazy["img"] if is_map else result_non_lazy self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(input_data) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(crop_fn, MapTransform): - crop_fn.lazy_evaluation = False + crop_fn.lazy = False inverted = crop_fn.inverse(result) self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations)) self.assertEqual(inverted.shape, im.shape) @@ -155,7 +156,7 @@ def crop_test_combine_ops(self, funcs, input_shape): # lazy pending_result = input_data for _func in _funcs: - _func.lazy_evaluation = True + _func.lazy = True if isinstance(_func, Randomizable): _func.set_random_state(seed=123) pending_result = _func(pending_result) @@ -164,7 +165,8 @@ def crop_test_combine_ops(self, funcs, input_shape): assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + overrides = {"mode": "nearest", "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/lazy_transforms_utils.py b/tests/lazy_transforms_utils.py index f1f8708285..1681e26037 100644 --- a/tests/lazy_transforms_utils.py +++ b/tests/lazy_transforms_utils.py @@ -15,7 +15,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import InvertibleTransform, MapTransform, Randomizable -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import assert_allclose apply_transforms_kwargs = ("pending", "mode", "padding_mode", "dtype", "align_corners") @@ -61,7 +61,7 @@ def test_resampler_lazy( if isinstance(resampler, Randomizable): resampler.set_random_state(seed=seed) set_track_meta(True) - resampler.lazy_evaluation = True + resampler.lazy = True pending_output = resampler(**deepcopy(call_param)) if output_idx is not None: expected_output, pending_output = expected_output[output_idx], pending_output[output_idx] @@ -73,7 +73,7 @@ def test_resampler_lazy( if not skip_shape_check: assert_allclose(lazy_out.peek_pending_shape(), non_lazy_out.shape[1:4]) apply_param = get_apply_param(init_param, call_param) - lazy_out = apply_transforms(lazy_out, **apply_param)[0] + lazy_out = apply_pending(lazy_out, overrides=apply_param)[0] assert_allclose(lazy_out, 
non_lazy_out, rtol=rtol, atol=atol) if ( isinstance(resampler, InvertibleTransform) @@ -82,10 +82,10 @@ def test_resampler_lazy( and isinstance(non_lazy_out, MetaTensor) and non_lazy_out.applied_operations ): - resampler.lazy_evaluation = False + resampler.lazy = False out = resampler.inverse(lazy_out.clone()) ref = resampler.inverse(non_lazy_out.clone()) assert_allclose(out.applied_operations, []) assert_allclose(out.pending_operations, []) assert_allclose(ref, out, type_test=False, rtol=1e-3, atol=1e-3) - resampler.lazy_evaluation = True + resampler.lazy = True diff --git a/tests/padders.py b/tests/padders.py index e21faabc10..02d7b40af6 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.transform import MapTransform from monai.utils.enums import NumpyPadMode, PytorchPadMode from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -127,18 +127,19 @@ def pad_test_pending_ops(self, input_param, input_shape): expected = result_non_lazy["img"] if is_map else result_non_lazy self.assertIsInstance(expected, MetaTensor) # lazy - pad_fn.lazy_evaluation = True + pad_fn.lazy = True pending_result = pad_fn(input_data) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): - pad_fn.lazy_evaluation = False + pad_fn.lazy = False inverted = pad_fn.inverse(result) self.assertTrue((not inverted.pending_operations) and (not inverted.applied_operations)) self.assertEqual(inverted.shape, im.shape) @@ -161,13 +162,14 @@ def pad_test_combine_ops(self, funcs, input_shape, expected_shape): # lazy pending_result = input_data for _func in _funcs: - _func.lazy_evaluation = True + _func.lazy = True pending_result = _func(pending_result) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_affine.py b/tests/test_affine.py index e8f7f33b17..9c2f4197a6 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -20,7 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Affine, Resize -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.utils import 
optional_import from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion @@ -208,16 +208,18 @@ def test_affine_resize(self, s): def method_0(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, padding_mode="border", align_corners=ac)[0] + overrides = {"padding_mode": "border", "align_corners": ac} + out = apply_pending(out, overrides=overrides)[0] return out def method_1(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0] + overrides = {"mode": 1, "padding_mode": "nearest", "align_corners": ac} + out = apply_pending(out, overrides=overrides)[0] return out def method_2(im, ac): diff --git a/tests/test_apply.py b/tests/test_apply.py index cf74721267..4784d46413 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.utils import create_rotate from monai.utils import LazyAttr, convert_to_tensor from tests.utils import get_arange_img @@ -40,20 +40,20 @@ def single_2d_transform_cases(): class TestApply(unittest.TestCase): def _test_apply_impl(self, tensor, pending_transforms, expected_shape): - result = apply_transforms(tensor, pending_transforms) + result = apply_pending(tensor, pending_transforms) self.assertListEqual(result[1], pending_transforms) self.assertEqual(result[0].shape, expected_shape) def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): tensor_ = convert_to_tensor(tensor, track_meta=True) if pending_as_parameter: - result, transforms = apply_transforms(tensor_, pending_transforms) + result, transforms = apply_pending(tensor_, pending_transforms) else: for p in pending_transforms: tensor_.push_pending_operation(p) if not isinstance(p, dict): return - result, transforms = apply_transforms(tensor_) + result, transforms = apply_pending(tensor_) self.assertEqual(result.shape, expected_shape) SINGLE_TRANSFORM_CASES = single_2d_transform_cases() diff --git a/tests/test_compose.py b/tests/test_compose.py index 65b9d8fbfb..132d399fa6 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -11,17 +11,19 @@ from __future__ import annotations +import logging import sys import unittest from copy import deepcopy +from io import StringIO import numpy as np import torch from parameterized import parameterized +import monai.transforms as mt from monai.data import DataLoader, Dataset -from monai.transforms import AddChannel, Compose, Flip, NormalizeIntensity, Rotate, Rotate90, Rotated, Zoom -from monai.transforms.compose import execute_compose +from monai.transforms.compose import ExecutionOptions, execute_compose from monai.transforms.transform import Randomizable from monai.utils import set_determinism @@ -37,7 +39,7 @@ def __call__(self, __unused): class TestCompose(unittest.TestCase): def test_empty_compose(self): - c = Compose() + c = mt.Compose() i = 1 self.assertEqual(c(i), 1) @@ -48,7 +50,7 @@ def a(i): def b(i): return i + "b" - c = Compose([a, b, a, b]) + c = mt.Compose([a, b, a, b]) self.assertEqual(c(""), "abab") def 
test_dict_compose(self): @@ -66,7 +68,7 @@ def b(d): data = {"a": 0, "b": 0} expected = {"a": 3, "b": 2} - self.assertDictEqual(Compose(transforms)(data), expected) + self.assertDictEqual(mt.Compose(transforms)(data), expected) self.assertDictEqual(execute_compose(data, transforms), expected) def test_list_dict_compose(self): @@ -89,7 +91,7 @@ def c(d): # transform to handle dict data transforms = [a, a, b, c, c] data = {"a": 0, "b": 0, "c": 0} expected = {"a": 2, "b": 1, "c": 2} - value = Compose(transforms)(data) + value = mt.Compose(transforms)(data) for item in value: self.assertDictEqual(item, expected) value = execute_compose(data, transforms) @@ -106,7 +108,7 @@ def b(i, i2): transforms = [a, b, a, b] data = ("", "") expected = ("abab", "a2b2a2b2") - self.assertEqual(Compose(transforms, map_items=False, unpack_items=True)(data), expected) + self.assertEqual(mt.Compose(transforms, map_items=False, unpack_items=True)(data), expected) self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected) def test_list_non_dict_compose_with_unpack(self): @@ -119,7 +121,7 @@ def b(i, i2): transforms = [a, b, a, b] data = [("", ""), ("t", "t")] expected = [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")] - self.assertEqual(Compose(transforms, unpack_items=True)(data), expected) + self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected) self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected) def test_list_dict_compose_no_map(self): @@ -143,7 +145,7 @@ def c(d): # transform to handle dict data transforms = [a, a, b, c, c] data = {"a": 0, "b": 0, "c": 0} expected = {"a": 2, "b": 1, "c": 2} - value = Compose(transforms, map_items=False)(data) + value = mt.Compose(transforms, map_items=False)(data) for item in value: self.assertDictEqual(item, expected) value = execute_compose(data, transforms, map_items=False) @@ -161,7 +163,7 @@ def __call__(self, data): self.randomize() return self.rand + data - c = Compose([_Acc(), _Acc()]) + c = mt.Compose([_Acc(), _Acc()]) self.assertNotAlmostEqual(c(0), c(0)) c.set_random_state(123) self.assertAlmostEqual(c(1), 1.61381597) @@ -177,17 +179,17 @@ def randomize(self, foo1, foo2): def __call__(self, data): pass - c = Compose([_RandomClass(), _RandomClass()]) + c = mt.Compose([_RandomClass(), _RandomClass()]) with self.assertWarns(Warning): c.randomize() def test_err_msg(self): - transforms = Compose([abs, AddChannel(), round], log_stats=False) + transforms = mt.Compose([abs, mt.AddChannel(), round]) with self.assertRaisesRegex(Exception, "AddChannel"): transforms(42.1) def test_data_loader(self): - xform_1 = Compose([_RandXform()]) + xform_1 = mt.Compose([_RandXform()]) train_ds = Dataset([1], transform=xform_1) xform_1.set_random_state(123) @@ -211,7 +213,7 @@ def test_data_loader(self): def test_data_loader_2(self): set_determinism(seed=123) - xform_2 = Compose([_RandXform(), _RandXform()]) + xform_2 = mt.Compose([_RandXform(), _RandXform()]) train_ds = Dataset([1], transform=xform_2) out_2 = train_ds[0] @@ -232,42 +234,47 @@ def test_data_loader_2(self): set_determinism(None) def test_flatten_and_len(self): - x = AddChannel() - t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + x = mt.AddChannel() + t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])]) t2 = t1.flatten() for t in t2.transforms: - self.assertNotIsInstance(t, Compose) + self.assertNotIsInstance(t, mt.Compose) # test len self.assertEqual(len(t1), 8) def test_backwards_compatible_imports(self): - 
from monai.transforms.compose import MapTransform, RandomizableTransform, Transform # noqa: F401 + from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], - [None, (Rotate(np.pi / 8),)], - [None, (Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())], - [("a",), (Rotated(("a",), np.pi / 8),)], + [None, (mt.Rotate(np.pi / 8),)], + [None, (mt.Flip(0), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity())], + [("a",), (mt.Rotated(("a",), np.pi / 8),)], ] class TestComposeExecute(unittest.TestCase): - @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) - def test_compose_execute_equivalence(self, keys, pipeline): + @staticmethod + def data_from_keys(keys): if keys is None: - data = torch.unsqueeze(torch.tensor(np.arange(24 * 32).reshape(24, 32)), axis=0) + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) else: data = {} for i_k, k in enumerate(keys): - data[k] = torch.unsqueeze(torch.tensor(np.arange(24 * 32)).reshape(24, 32) + i_k * 768, axis=0) + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_equivalence(self, keys, pipeline): + data = self.data_from_keys(keys) - expected = Compose(deepcopy(pipeline))(data) + expected = mt.Compose(deepcopy(pipeline))(data) for cutoff in range(len(pipeline)): - c = Compose(deepcopy(pipeline)) + c = mt.Compose(deepcopy(pipeline)) actual = c(c(data, end=cutoff), start=cutoff) if isinstance(actual, dict): for k in actual.keys(): @@ -283,6 +290,308 @@ def test_compose_execute_equivalence(self, keys, pipeline): else: self.assertTrue(torch.allclose(expected, actual)) + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_bad_start_param(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, start=None) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=None) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, start=-1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=-1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_negative_range(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, start=2, end=1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=2, end=1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_bad_end_param(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, end=len(pipeline) + 1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), end=len(pipeline) + 1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_empty_range(self, keys, pipeline): + data = self.data_from_keys(keys) + + c = mt.Compose(deepcopy(pipeline)) + for i in range(len(pipeline)): + result = c(data, start=i, end=i) + self.assertIs(data, result) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_with_logger_name(self, keys, pipeline): + data = self.data_from_keys(keys) + + c = 
mt.Compose(deepcopy(pipeline), logger_name="a_logger_name") + c(data) + + +TEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES = [ + [ + None, + (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()), + False, + ( + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Spacing', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Zoom', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + ), + ], + [ + None, + ( + mt.Flip(0, lazy=True), + mt.Spacing((1.2, 1.2), lazy=True), + mt.Flip(1, lazy=True), + mt.Rotate90(1), + mt.Zoom(0.8, lazy=True), + mt.NormalizeIntensity(), + ), + None, + ( + "INFO - Accumulate pending transforms - lazy: None, pending: 0, " + "upcoming 'Flip', transform.lazy: True\n" + "INFO - Accumulate pending transforms - lazy: None, pending: 1, " + "upcoming 'Spacing', transform.lazy: True\n" + "INFO - Accumulate pending transforms - lazy: None, pending: 2, " + "upcoming 'Flip', transform.lazy: True\n" + "INFO - Apply pending transforms - lazy: None, pending: 3, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Pending transforms applied: applied_operations: 3\n" + "INFO - Accumulate pending transforms - lazy: None, pending: 0, " + "upcoming 'Zoom', transform.lazy: True\n" + "INFO - Apply pending transforms - lazy: None, pending: 1, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + None, + (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, pending: 0, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Spacing', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 2, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 3, " + "upcoming 'Rotate90', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 4, " + "upcoming 'Zoom', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 5, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(("a", "b"), 0), + mt.Spacingd(("a", "b"), 1.2), + mt.Rotate90d(("a", "b"), 1), + mt.NormalizeIntensityd(("a",)), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 1, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy 
mode: True, key: 'b', pending: 1, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 2, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 2, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy mode: True, key: 'a', pending: 3, " + "upcoming 'NormalizeIntensityd', transform is not lazy\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 3\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(keys="a", spatial_axis=0), + mt.Rotate90d(keys="b", k=1, allow_missing_keys=True), + mt.Zoomd(keys=("a", "b"), zoom=0.8, allow_missing_keys=True), + mt.Spacingd(keys="a", pixdim=1.2), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 0, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 1, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 1, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 2, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 2\n" + ), + ], + [ + None, + ( + mt.Flip(0), + mt.Spacing((1.2, 1.2)), + mt.Flip(1), + mt.ApplyPending(), + mt.Rotate90(1), + mt.Zoom(0.8), + mt.NormalizeIntensity(), + ), + False, + ( + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Spacing', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'ApplyPending', transform is not lazy\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Zoom', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + ), + ], + [ + None, + ( + mt.Flip(0), + mt.Spacing((1.2, 1.2)), + mt.Flip(1), + mt.ApplyPending(), + mt.Rotate90(1), + mt.Zoom(0.8), + mt.NormalizeIntensity(), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, pending: 0, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Spacing', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 2, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 3, " + "upcoming 'ApplyPending', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 3\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 
0, " + "upcoming 'Rotate90', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Zoom', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 2, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(keys="a", spatial_axis=0), + mt.Rotate90d(keys="b", k=1, allow_missing_keys=True), + mt.ApplyPendingd(keys=("a", "b")), + mt.Zoomd(keys=("a", "b"), zoom=0.8, allow_missing_keys=True), + mt.Spacingd(keys="a", pixdim=1.2), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 0, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy mode: True, key: 'a', pending: 1, " + "upcoming 'ApplyPendingd', transform is not lazy\n" + "INFO - Apply pending transforms - lazy mode: True, key: 'b', pending: 1, " + "upcoming 'ApplyPendingd', transform is not lazy\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 1\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 1\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 0, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'b', pending: 0, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy mode: True, key: 'a', pending: 1, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 2\n" + ), + ], +] + + +class TestComposeExecuteWithLogging(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.arange(12 * 16).reshape(1, 12, 16) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES) + def test_compose_with_logging(self, keys, pipeline, lazy, expected): + stream = StringIO() + handler = logging.StreamHandler(stream) + formatter = logging.Formatter("%(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger = logging.getLogger("a_logger_name") + logger.setLevel(logging.INFO) + while len(logger.handlers) > 0: + logger.removeHandler(logger.handlers[-1]) + logger.addHandler(handler) + + data = self.data_from_keys(keys) + c = mt.Compose(deepcopy(pipeline), lazy=lazy, logger_name="a_logger_name") + c(data) + + handler.flush() + actual = stream.getvalue() + self.assertEqual(actual, expected) + class TestOps: @staticmethod @@ -318,10 +627,20 @@ def _inner(data1, data2): class TestComposeExecuteWithFlags(unittest.TestCase): @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES) def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): - expected = Compose(pipeline, **flags)(data) + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(24 * 32).reshape(24, 32)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(24 * 
32)).reshape(24, 32) + i_k * 768, dim=0) + return data + + expected = mt.Compose(pipeline, **flags)(data) for cutoff in range(len(pipeline)): - c = Compose(deepcopy(pipeline), **flags) + c = mt.Compose(deepcopy(pipeline), **flags) actual = c(c(data, end=cutoff), start=cutoff) if isinstance(actual, dict): for k in actual.keys(): @@ -338,16 +657,135 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): self.assertTrue(expected, actual) -TEST_LAZY_COMPOSE_PIPELINE_FIX_CASES = [[(Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())]] +TEST_LAZY_COMPOSE_PIPELINE_FIX_CASES = [ + [(mt.Flip(0), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity())] +] class TestLazyComposePipelineFixes(unittest.TestCase): @parameterized.expand(TEST_LAZY_COMPOSE_PIPELINE_FIX_CASES) def test_lazy_compose_pipeline_fixes(self, pipeline): data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) - c = Compose(deepcopy(pipeline), lazy_evaluation=True) + c = mt.Compose(deepcopy(pipeline), lazy=True) _ = c(data) +class TNonLazy(mt.Transform): + def __init__(self, tag): + self.tag = tag + + def __call__(self, data): + return data + + +class TLazy(mt.LazyTransform): + def __init__(self, tag, lazy): + super().__init__(lazy) + self.tag = tag + + def __call__(self, data): + return data + + +class TApplyPending(mt.ApplyPending): + def __init__(self, tag): + super().__init__() + self.tag = tag + + +TRANSFORM_REORDERING_TEST_CASES = [ + ( + [TNonLazy("a"), TLazy("lb", True), TLazy("lc", True), TApplyPending("ad"), TLazy("le", True), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["a", "lb", "lc", "ad", "f", "le"], + ["a", "lb", "lc", "ad", "f", "le"], + ), + ( + [TNonLazy("a"), TLazy("lb", True), TLazy("lc", True), TApplyPending("ad"), TLazy("le", False), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["a", "lb", "lc", "ad", "le", "f"], + ["a", "lb", "lc", "ad", "f", "le"], + ), + ( + [TLazy("la", True), TNonLazy("b"), TLazy("lc", True), TApplyPending("ad"), TLazy("le", True), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["b", "la", "lc", "ad", "f", "le"], + ["b", "la", "lc", "ad", "f", "le"], + ), + ( + [TLazy("la", False), TNonLazy("b"), TLazy("lc", True), TApplyPending("ad"), TLazy("le", True), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["la", "b", "lc", "ad", "f", "le"], + ["b", "la", "lc", "ad", "f", "le"], + ), + ( + [TLazy("la", True), TNonLazy("b"), TLazy("lc", True), TApplyPending("ad"), TLazy("le", True), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["b", "la", "lc", "ad", "f", "le"], + ["b", "la", "lc", "ad", "f", "le"], + ), + ( + [TNonLazy("a"), TLazy("lb", True), TLazy("lc", True), TApplyPending("ad"), TLazy("le", False), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["a", "lb", "lc", "ad", "le", "f"], + ["a", "lb", "lc", "ad", "f", "le"], + ), + ( + [TLazy("la", True), TLazy("lb", True), TNonLazy("c"), TApplyPending("ad"), TApplyPending("ae"), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["c", "la", "lb", "ad", "ae", "f"], + ["c", "la", "lb", "ad", "ae", "f"], + ), + ( + [TNonLazy("a"), TLazy("lb", True), TLazy("lc", True), TApplyPending("ad"), TLazy("le", True), TNonLazy("f")], + {"reorder": "lazy_last"}, + ["a", "lb", "lc", "ad", "f", "le"], + ["a", "lb", "lc", "ad", "f", "le"], + ), + ( + [ + TNonLazy("a"), + TLazy("lb", True), + TLazy("lc", True), + TApplyPending("ad"), + TLazy("le", True), + TApplyPending("af"), + TApplyPending("ag"), + ], + {"reorder": "lazy_last"}, + ["a", "lb", "lc", "ad", "le", "af", "ag"], + ["a", "lb", "lc", 
"ad", "le", "af", "ag"], + ), + ( + [TLazy("la", True), TLazy("lb", True), TNonLazy("c"), TLazy("ld", True)], + {"reorder": "lazy_last"}, + ["c", "la", "lb", "ld"], + ["c", "la", "lb", "ld"], + ), + ( + [TLazy("la", True), TLazy("lb", False), TNonLazy("c"), TLazy("ld", True)], + {"reorder": "lazy_last"}, + ["lb", "c", "la", "ld"], + ["c", "la", "lb", "ld"], + ), +] + + +class TestTransformReordering(unittest.TestCase): + @parameterized.expand(TRANSFORM_REORDERING_TEST_CASES) + def test_transform_reordering_test_cases(self, transforms, options, lazy_enabled_expected, lazy_on_expected): + with self.subTest("enable lazy"): + c = ExecutionOptions()(transforms, lazy=None, options={"reorder": "lazy_last"}) + reordered = [transforms[i] for i in c["indices"]] + actual = [t.tag for t in reordered] + self.assertListEqual(actual, lazy_enabled_expected) + + with self.subTest("force lazy"): + c = ExecutionOptions()(transforms, lazy=True, options={"reorder": "lazy_last"}) + reordered = [transforms[i] for i in c["indices"]] + actual = [t.tag for t in reordered] + self.assertListEqual(actual, lazy_on_expected) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 1ffdc9983e..4435b128ba 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -19,7 +19,7 @@ from monai.config import USE_COMPILED from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForeground -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_COORDS, TESTS, TEST_LAZY_ERROR = [], [], [] @@ -126,13 +126,14 @@ def test_pending_ops(self, input_param, image, _expected_data, align_corners): expected = crop_fn(image) self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) @@ -142,17 +143,18 @@ def test_lazy_error(self, input_param, image, _expected_data, align_corners): with self.assertRaises(ValueError): crop_fn = CropForeground(**input_param) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) - return apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + return apply_pending(pending_result, overrides=overrides)[0] @parameterized.expand(TEST_COORDS + TESTS) def test_inverse_pending_ops(self, input_param, image, _expected_data, align_corners): crop_fn = CropForeground(**input_param) - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) self.assertIsInstance(pending_result, MetaTensor) - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": align_corners})[0] inverted = crop_fn.inverse(result) self.assertEqual(image.shape, inverted.shape) 
self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations)) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index d2604ef9cf..776776f6c5 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForegroundd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_POSITION, TESTS = [], [] @@ -189,13 +189,14 @@ def test_pending_ops(self, input_param, image, _expected_data, align_corners): expected = crop_fn(image)["img"] self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_flip.py b/tests/test_flip.py index 287852c2c1..d7df55fde0 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -61,7 +61,7 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): init_param = {"spatial_axis": spatial_axis} xform = Flip(**init_param) call_param = {"img": img} - res = xform(**call_param) + res = xform(**call_param) # type: ignore[arg-type] self.assertEqual(img.shape, res.shape) if track_meta: test_resampler_lazy(xform, res, init_param, call_param) diff --git a/tests/test_integration_lazy_samples.py b/tests/test_integration_lazy_samples.py index 88e256f2dc..6afb756748 100644 --- a/tests/test_integration_lazy_samples.py +++ b/tests/test_integration_lazy_samples.py @@ -24,6 +24,7 @@ import monai import monai.transforms as mt from monai.data import create_test_image_3d, decollate_batch +from monai.transforms.utils import is_tensor_invertible from monai.utils import set_determinism from tests.utils import HAS_CUPY, DistTestCase, SkipIfBeforePyTorchVersion, skip_if_quick @@ -32,7 +33,9 @@ def _no_op(x): return x -def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None), num_workers=4, lazy=True): +def run_training_test( + root_dir, device="cuda:0", cachedataset=0, readers=(None, None), num_workers=4, lazy=True, options=None +): print(f"test case: {locals()}") images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz"))) @@ -41,9 +44,13 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, num_workers = 0 if torch.cuda.is_available() else num_workers # define transforms for image and segmentation - lazy_kwargs = dict( - mode=("bilinear", 0), device=device, padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8) - ) + # lazy_kwargs = dict( + # mode=("bilinear", 0), device=device, padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8) + # ) + lazy_kwargs = { + "img": {"mode": "bilinear", "device": device, "padding_mode": "border", "dtype": torch.float32}, + "seg": {"mode": 0, "device": device, "padding_mode": "nearest", "dtype": 
torch.uint8}, + } train_transforms = mt.Compose( [ mt.LoadImaged(keys=["img", "seg"], reader=readers[0], image_only=True), @@ -58,7 +65,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, mt.Orientationd(keys=["img", "seg"], axcodes="ARS"), mt.RandRotate90d(keys=["img", "seg"], prob=1.0, spatial_axes=(1, 2)), mt.ScaleIntensityd(keys="img"), - mt.IdentityD(keys=["seg"]), + mt.ApplyPendingd(keys=["seg"]), mt.RandCropByPosNegLabeld( keys=["img", "seg"], label_key="seg", spatial_size=[76, 82, 80], pos=1, neg=1, num_samples=4 ), @@ -70,10 +77,9 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, mt.Rotated(keys=["img", "seg"], angle=[np.pi / 2, np.pi / 2, 0], mode="nearest", keep_size=False), mt.Lambdad(keys=["img"], func=_no_op), ], - lazy_evaluation=lazy, + lazy=lazy, + options=options, overrides=lazy_kwargs, - override_keys=("img", "seg"), - verbose=num_workers > 0, # testing both flags ) # create a training data loader @@ -115,6 +121,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, train_ds, batch_size=1, shuffle=True, num_workers=num_workers, generator=_g, persistent_workers=num_workers > 0 ) all_coords = set() + batch_data = None for epoch in range(5): print("-" * 10) print(f"Epoch {epoch + 1}/5") @@ -149,16 +156,13 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, saver(item) # just testing the saving saver(in_img) saver(in_seg) - if lazy: - inverted = 0 - try: - inverted = [inverter(b_data) for b_data in decollate_batch(batch_data)] - except RuntimeError as e: - if "Lambda" in str(e): - inverted = None - assert inverted is None, "invert LambdaD + lazy is not supported" + invertible, reasons = is_tensor_invertible(batch_data) + if options == {"reorder": "lazy_last_nosync"}: + assert invertible is False, f"the output of this pipeline with options {options} should not be invertible" else: - [inverter(b_data) for b_data in decollate_batch(batch_data)] # expecting no error + assert invertible is True, f"the output of this pipeline with options {options} should be invertible" + inverted = [inverter(b_data) for b_data in decollate_batch(batch_data)] # expecting no error + return ops @@ -193,27 +197,34 @@ def train_and_infer(self, idx=0): elif idx == 2: _readers = ("itkreader", "nibabelreader") _w = 0 - results = run_training_test( - self.data_dir, device=self.device, cachedataset=idx, readers=_readers, num_workers=_w, lazy=True - ) + results_expected = run_training_test( self.data_dir, device=self.device, cachedataset=0, readers=_readers, num_workers=_w, lazy=False ) - self.assertFalse(np.allclose(results, [0])) - self.assertFalse(np.allclose(results_expected, [0])) - np.testing.assert_allclose(results, results_expected) - lazy_files = glob(os.path.join(self.data_dir, "output", "*_True_*.nii.gz")) - regular_files = glob(os.path.join(self.data_dir, "output", "*_False_*.nii.gz")) - diffs = [] - for a, b in zip(sorted(lazy_files), sorted(regular_files)): - img_lazy = mt.LoadImage(image_only=True)(a) - img_regular = mt.LoadImage(image_only=True)(b) - diff = np.size(img_lazy) - np.sum(np.isclose(img_lazy, img_regular, atol=1e-4)) - diff_rate = diff / np.size(img_lazy) - diffs.append(diff_rate) - np.testing.assert_allclose(diff_rate, 0.0, atol=0.03) - print("volume diff:", diffs) - return results + for options in (None, {"reorder": "lazy_last_nosync"}): + results = run_training_test( + self.data_dir, + device=self.device, + cachedataset=idx, + readers=_readers, + 
num_workers=_w, + lazy=True, + options=options, + ) + self.assertFalse(np.allclose(results, [0])) + self.assertFalse(np.allclose(results_expected, [0])) + np.testing.assert_allclose(results, results_expected) + lazy_files = glob(os.path.join(self.data_dir, "output", "*_True_*.nii.gz")) + regular_files = glob(os.path.join(self.data_dir, "output", "*_False_*.nii.gz")) + diffs = [] + for a, b in zip(sorted(lazy_files), sorted(regular_files)): + img_lazy = mt.LoadImage(image_only=True)(a) + img_regular = mt.LoadImage(image_only=True)(b) + diff = np.size(img_lazy) - np.sum(np.isclose(img_lazy, img_regular, atol=1e-4)) + diff_rate = diff / np.size(img_lazy) + diffs.append(diff_rate) + np.testing.assert_allclose(diff_rate, 0.0, atol=0.03) + print("volume diff:", diffs) def test_training(self): for i in range(4): diff --git a/tests/test_invert.py b/tests/test_invert.py index 0d53b4bf61..b7c11362ce 100644 --- a/tests/test_invert.py +++ b/tests/test_invert.py @@ -90,16 +90,15 @@ def test_invert(self): set_determinism(seed=None) def test_invert_warn_pending(self): + # this test shouldn't raise a warning or error any more as that issue was fixed + # by https://github.com/Project-MONAI/MONAI/pull/6257 set_determinism(seed=0) im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1]) # label image, discrete transform = Compose( - [LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RPS"), Lambda(func=lambda x: x)], - lazy_evaluation=True, + [LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RPS"), Lambda(func=lambda x: x)], lazy=True ) output = transform([im_fname for _ in range(2)]) - with self.assertRaises(RuntimeError): # transform id mismatch because of lambda - with self.assertWarns(Warning): # warning of wrong ordering lazy + nonlazy_invertible - transform.inverse(output) + transform.inverse(output) if __name__ == "__main__": diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index 1e7bea17d4..574fd49592 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -62,14 +62,7 @@ ] TEST_CASE_RECURSIVE_2 = [ torch.randn(3, 3), - Compose( - [ - ToNumpy(), - Flip(), - OneOf([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)], weights=[0, 1], log_stats=True), - ToTensor(), - ] - ), + Compose([ToNumpy(), Flip(), OneOf([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)], weights=[0, 1]), ToTensor()]), ] TEST_CASE_RECURSIVE_LIST = [ torch.randn(3, 3), @@ -167,7 +160,6 @@ def test_recursive_tranforms(self, input, transforms): # Check the outputs self.assertEqual(transforms.map_items, transforms_range.map_items) self.assertEqual(transforms.unpack_items, transforms_range.unpack_items) - self.assertEqual(transforms.log_stats, transforms_range.log_stats) np.testing.assert_equal(output.numpy(), output_r.numpy()) @parameterized.expand([TEST_CASE_RECURSIVE_LIST]) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 36980c23a7..2977b141ce 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -15,8 +15,12 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized +import monai.transforms.intensity.array as ia +import monai.transforms.spatial.array as sa +import monai.transforms.spatial.dictionary as sd from monai.data import MetaTensor from monai.transforms import ( InvertibleTransform, @@ -227,5 +231,37 @@ def test_one_of(self): self.assertAlmostEqual(counts[2] / 10000, 0.25, delta=1.0) +TEST_ONEOF_EXTENDED_TEST_CASES = [ + [None, tuple()], + [None, (sa.Rotate(np.pi 
/ 8),)], + [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())], + [("a",), (sd.Rotated(("a",), np.pi / 8),)], +] + + +class TestOneOfAPITests(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_ONEOF_EXTENDED_TEST_CASES) + def test_execute_change_start_end(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = OneOf(deepcopy(pipeline)) + c(data, start=1) + + with self.assertRaises(ValueError): + c = OneOf(deepcopy(pipeline)) + c(data, end=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_orientation.py b/tests/test_orientation.py index 6e89d085d2..aa1c326bdf 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -190,7 +190,7 @@ def test_ornt_meta( img = MetaTensor(img, affine=affine).to(device) ornt = Orientation(**init_param) call_param = {"data_array": img} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] if img.ndim in (3, 4): test_resampler_lazy(ornt, res, init_param, call_param) diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index ddb5dc3e98..cf4eb23d42 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -74,7 +74,7 @@ def test_orntd( img = MetaTensor(img, affine=affine) img = img.to(device) call_param = {"data": {k: img.clone() for k in ornt.keys}} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] for k in ornt.keys: if img.ndim in (3, 4): test_resampler_lazy(ornt, res, init_param, call_param, output_key=k) @@ -92,7 +92,7 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi expected_shape = img.shape expected_code = ornt.ornt_transform.axcodes call_param = {"data": {k: img.clone() for k in ornt.keys}} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] for k in ornt.keys: _im = res[k] np.testing.assert_allclose(_im.shape, expected_shape) diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 5c1e2359e8..a607029c1a 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -234,7 +234,7 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): resampler = RandAffined(**lazy_init_param).set_random_state(123) expected_output = resampler(**call_param) test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key) - resampler.lazy_evaluation = False + resampler.lazy = False if input_param.get("cache_grid", False): self.assertTrue(g.rand_affine._cached_grid is not None) diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 457617fc19..81e42372db 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -33,7 +33,7 @@ def test_correct_results(self): # test lazy test_resampler_lazy(flip, result, call_param=call_param, seed=321) - flip.lazy_evaluation = False + flip.lazy = False expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] assert_allclose(result, p(np.stack(expected)), type_test="tensor") diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index e6fac5637f..75357b23e1 100644 --- a/tests/test_rand_axis_flipd.py 
+++ b/tests/test_rand_axis_flipd.py @@ -33,7 +33,7 @@ def test_correct_results(self): # test lazy test_resampler_lazy(flip, result, call_param=call_param, output_key="img", seed=1234) - flip.lazy_evaluation = False + flip.lazy = False test_local_inversion(flip, result, {"img": im}, "img") expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]] diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 6723dfc4c6..88d2631ca5 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndices, RandCropByLabelClasses -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS_INDICES, TESTS_SHAPE = [], [] @@ -154,14 +154,14 @@ def test_pending_ops(self, input_param, input_data, _expected_type, _expected_sh self.assertIsInstance(expected[0], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(**input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 8af1df5c42..748f26f1ff 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [] @@ -143,14 +143,14 @@ def test_pending_ops(self, input_param, input_data, _expected_type, _expected_sh self.assertIsInstance(expected[0]["img"], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result["img"], overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i]["img"], rtol=1e-5) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e1c4cdff58..98af6b0b5e 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabel -from monai.transforms.lazy.functional 
import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ @@ -136,14 +136,14 @@ def test_pending_ops(self, input_param, input_data, _expected_shape): self.assertIsInstance(expected[0], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(**input_data_mod) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 11b7960617..1b57548d12 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabeld -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ @@ -153,15 +153,16 @@ def test_pending_ops(self, input_param, input_data, _expected_shape): self.assertIsInstance(expected[0]["image"], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data_mod) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["image"], MetaTensor) assert_allclose(_pending_result["image"].peek_pending_affine(), expected[i]["image"].affine) assert_allclose(_pending_result["image"].peek_pending_shape(), expected[i]["image"].shape[1:]) # only support nearest - result_image = apply_transforms(_pending_result["image"], mode="nearest", align_corners=False)[0] - result_extra = apply_transforms(_pending_result["extra"], mode="nearest", align_corners=False)[0] + overrides = {"mode": "nearest", "align_corners": False} + result_image = apply_pending(_pending_result["image"], overrides=overrides)[0] + result_extra = apply_pending(_pending_result["extra"], overrides=overrides)[0] # compare assert_allclose(result_image, expected[i]["image"], rtol=1e-5) assert_allclose(result_extra, expected[i]["extra"], rtol=1e-5) diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index d67b4ca31b..be5394c172 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -37,7 +37,7 @@ def test_correct_results(self, _, spatial_axis): # test lazy test_resampler_lazy(flip, result, init_param, call_param, output_key="img") - flip.lazy_evaluation = False + flip.lazy = False expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 8bd697efe5..ca3eda3b12 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -91,7 +91,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, # test lazy test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) - rotate_fn.lazy_evaluation = False + rotate_fn.lazy = False 
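The remaining hunks repeat two mechanical API changes: the per-transform flag ``lazy_evaluation`` is renamed to ``lazy``, and pending operations are materialised with ``apply_pending(..., overrides={...})`` rather than ``apply_transforms(..., **kwargs)``. A minimal sketch of the updated pattern, assuming the MONAI 1.2 lazy API exercised by these tests (the input tensor and override values are illustrative)::

    import torch

    from monai.data import MetaTensor
    from monai.transforms import Flip
    from monai.transforms.lazy.functional import apply_pending

    img = MetaTensor(torch.arange(12 * 16, dtype=torch.float32).reshape(1, 12, 16))

    flip = Flip(spatial_axis=0)
    flip.lazy = True                     # was: flip.lazy_evaluation = True
    pending = flip(img)                  # records a pending op, no resampling yet
    print(pending.peek_pending_shape())  # spatial shape once the pending op is applied

    # resampling keywords are now passed as an overrides dict
    out = apply_pending(pending, overrides={"mode": "nearest", "padding_mode": "border"})[0]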
_order = 0 if mode == "nearest" else 1 if mode == "border": @@ -133,7 +133,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, # test lazy test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) - rotate_fn.lazy_evaluation = False + rotate_fn.lazy = False assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) test_local_inversion(rotate_fn, rotated, im) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 2504c0f01b..88f88bf422 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -37,7 +37,7 @@ def test_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) - rotate.lazy_evaluation = False + rotate.lazy = False def test_k(self): init_param = {"max_k": 2} @@ -60,7 +60,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) - rotate.lazy_evaluation = False + rotate.lazy = False def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1), prob=1.0) @@ -71,7 +71,7 @@ def test_spatial_axes(self): rotated = rotate(**call_param) # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1234) - rotate.lazy_evaluation = False + rotate.lazy = False self.assertEqual(len(rotated.applied_operations), 1) expected = [np.rot90(channel, rotate._rand_k, (0, 1)) for channel in self.imt[0]] @@ -88,7 +88,7 @@ def test_prob_k_spatial_axes(self): rotated = rotate(**call_param) # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index f811f1a6a6..23e9025c08 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -34,7 +34,7 @@ def test_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1323, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -58,7 +58,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -76,7 +76,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -94,7 +94,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index a0d56bcaf3..df121e2220 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandScaleCrop, RandSpatialCrop -from monai.transforms.lazy.functional import 
apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -84,13 +84,13 @@ def test_random_shape(self, input_param, input_shape, expected_shape): # lazy # reset random seed to ensure the same results cropper.set_random_state(seed=123) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 69d2e5af5d..92f0f9d9be 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandSpatialCropSamples -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -112,14 +112,14 @@ def test_pending_ops(self, input_param, input_shape, _expected_shape, _expected_ self.assertIsInstance(expected[0], MetaTensor) # lazy xform.set_random_state(1234) - xform.lazy_evaluation = True + xform.lazy = True pending_result = xform(image) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index fc6e6c8c43..ec0d63cc50 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ @@ -122,15 +122,16 @@ def test_pending_ops(self, input_param, input_data, _expected_shape, _expected_l # lazy xform.set_random_state(1234) - xform.lazy_evaluation = True + xform.lazy = True pending_result = xform(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result_img = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] - result_seg = apply_transforms(_pending_result["seg"], mode="nearest", align_corners=False)[0] + overrides 
= {"mode": "nearest", "align_corners": False} + result_img = apply_pending(_pending_result["img"], overrides=overrides)[0] + result_seg = apply_pending(_pending_result["seg"], overrides=overrides)[0] # compare assert_allclose(result_img, expected[i]["img"], rtol=1e-5) assert_allclose(result_seg, expected[i]["seg"], rtol=1e-5) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 5114a45159..123459235f 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandScaleCropd, RandSpatialCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -89,13 +89,13 @@ def test_random_shape(self, input_param, input_shape, expected_shape): # lazy # reset random seed to ensure the same results cropper.set_random_state(seed=123) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index e279f29f68..47a8f3bfa2 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -178,14 +178,14 @@ def test_pending_ops(self, _, input_param, img, weight, expected_shape, expected self.assertIsInstance(expected[0], MetaTensor) # lazy crop.set_random_state(10) - crop.lazy_evaluation = True + crop.lazy = True pending_result = crop(img, weight) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 51e1b15c2c..9d37779613 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.dictionary import RandWeightedCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -166,14 +166,14 @@ def 
test_pending_ops(self, _, input_param, input_data, expected_shape, expected_ self.assertIsInstance(expected[0]["img"], MetaTensor) # lazy crop.set_random_state(10) - crop.lazy_evaluation = True + crop.lazy = True pending_result = crop(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result["img"], overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i]["img"], rtol=1e-5) diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index f080056b63..bb0495c793 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -58,7 +58,7 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz test_resampler_lazy( random_zoom, zoomed, init_param, call_param, key, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6 ) - random_zoom.lazy_evaluation = False + random_zoom.lazy = False test_local_inversion(random_zoom, zoomed, {key: im}, key) expected = [ diff --git a/tests/test_random_order.py b/tests/test_random_order.py index 9ed22d30ae..5eadedb58a 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -12,9 +12,15 @@ from __future__ import annotations import unittest +from copy import deepcopy +import numpy as np +import torch from parameterized import parameterized +import monai.transforms.intensity.array as ia +import monai.transforms.spatial.array as sa +import monai.transforms.spatial.dictionary as sd from monai.data import MetaTensor from monai.transforms import RandomOrder from monai.transforms.compose import Compose @@ -98,5 +104,37 @@ def test_inverse(self, transform, invertible, use_metatensor): self.assertDictEqual(fwd_data[i], _fwd_inv_data) +TEST_RANDOM_ORDER_EXTENDED_TEST_CASES = [ + [None, tuple()], + [None, (sa.Rotate(np.pi / 8),)], + [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())], + [("a",), (sd.Rotated(("a",), np.pi / 8),)], +] + + +class TestRandomOrderAPITests(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_RANDOM_ORDER_EXTENDED_TEST_CASES) + def test_execute_change_start_end(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = RandomOrder(deepcopy(pipeline)) + c(data, start=1) + + with self.assertRaises(ValueError): + c = RandomOrder(deepcopy(pipeline)) + c(data, end=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 8c33643d1f..287df039b8 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCrop -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import 
TEST_NDARRAYS_ALL, assert_allclose, pytorch_after TEST_CASES = [ @@ -79,18 +79,18 @@ def test_pending_ops(self, input_param, input_shape, _expected_data, align_corne expected = padcropper(image) self.assertIsInstance(expected, MetaTensor) # lazy - padcropper.lazy_evaluation = True + padcropper.lazy = True pending_result = padcropper(image) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms( - pending_result, - mode="nearest", - padding_mode=TESTS_PENDING_MODE[input_param["mode"]], - align_corners=align_corners, - )[0] + overrides = { + "mode": "nearest", + "padding_mode": TESTS_PENDING_MODE[input_param["mode"]], + "align_corners": align_corners, + } + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) inverted = padcropper.inverse(result) diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index a71652375b..471144a609 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -20,7 +20,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.test_resize_with_pad_or_crop import TESTS_PENDING_MODE from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, pytorch_after @@ -74,15 +74,18 @@ def test_pending_ops(self, input_param, input_data, _expected_data): expected = padcropper(input_data)["img"] self.assertIsInstance(expected, MetaTensor) # lazy - padcropper.lazy_evaluation = True + padcropper.lazy = True pending_result = padcropper(input_data)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms( - pending_result, mode="nearest", padding_mode=TESTS_PENDING_MODE[input_param["mode"]], align_corners=True - )[0] + overrides = { + "mode": "nearest", + "padding_mode": TESTS_PENDING_MODE[input_param["mode"]], + "align_corners": True, + } + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index fd54e7639f..0948469df9 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -18,7 +18,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Affine, Rotate90 -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.utils import optional_import from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( @@ -41,7 +41,7 @@ def test_rotate90_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] @@ -61,7 +61,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) 
for channel in self.imt[0]] @@ -77,7 +77,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] @@ -93,7 +93,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -111,7 +111,7 @@ def test_rotate90_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] @@ -127,7 +127,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -143,7 +143,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] @@ -159,7 +159,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -177,16 +177,16 @@ def test_affine_rot90(self, s): def method_0(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, padding_mode="border", align_corners=ac)[0] + out = apply_pending(out, overrides={"padding_mode": "border", "align_corners": ac})[0] return out def method_1(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0] + out = apply_pending(out, overrides={"mode": 1, "padding_mode": "nearest", "align_corners": ac})[0] return out def method_2(im, ac): diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 95d475d480..08d3a97498 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -33,7 +33,7 @@ def test_rotate90_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] @@ -54,7 +54,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -71,7 +71,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, {key: im}, 
key) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] @@ -88,7 +88,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, {key: im}, key) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] diff --git a/tests/test_some_of.py b/tests/test_some_of.py index 0cc903bb2d..cba2c8a464 100644 --- a/tests/test_some_of.py +++ b/tests/test_some_of.py @@ -12,9 +12,15 @@ from __future__ import annotations import unittest +from copy import deepcopy +import numpy as np +import torch from parameterized import parameterized +import monai.transforms.intensity.array as ia +import monai.transforms.spatial.array as sa +import monai.transforms.spatial.dictionary as sd from monai.data import MetaTensor from monai.transforms import TraceableTransform, Transform from monai.transforms.compose import Compose, SomeOf @@ -206,5 +212,37 @@ def test_bad_num_transforms(self): self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=("a", 1)) +TEST_SOMEOF_EXTENDED_TEST_CASES = [ + [None, tuple()], + [None, (sa.Rotate(np.pi / 8),)], + [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())], + [("a",), (sd.Rotated(("a",), np.pi / 8),)], +] + + +class TestSomeOfAPITests(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_SOMEOF_EXTENDED_TEST_CASES) + def test_execute_change_start_end(self, keys, pipeline): + data = self.data_from_keys(keys) + + with self.assertRaises(ValueError): + c = SomeOf(deepcopy(pipeline)) + c(data, start=1) + + with self.assertRaises(ValueError): + c = SomeOf(deepcopy(pipeline)) + c(data, end=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_combine_transforms.py b/tests/test_spatial_combine_transforms.py index 74c03fc4ff..8594daed16 100644 --- a/tests/test_spatial_combine_transforms.py +++ b/tests/test_spatial_combine_transforms.py @@ -20,7 +20,7 @@ import monai.transforms as mt from monai.data import create_test_image_2d, create_test_image_3d from monai.data.meta_tensor import MetaTensor -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.transform import MapTransform from monai.utils import set_determinism from tests.lazy_transforms_utils import get_apply_param @@ -162,7 +162,7 @@ def test_combine_transforms(self, input_shape, funcs): # lazy pending_result = input_data for _func in _funcs: - _func.lazy_evaluation = True + _func.lazy = True if isinstance(_func, mt.Randomizable): _func.set_random_state(seed=seed) pending_result = _func(pending_result) @@ -175,7 +175,7 @@ def test_combine_transforms(self, input_shape, funcs): init_param = funcs[-1][1] call_param = {} apply_param = get_apply_param(init_param, call_param) - result = apply_transforms(pending_result, **apply_param)[0] + result = apply_pending(pending_result, overrides=apply_param)[0] match_ratio = np.sum(np.isclose(result.array, expected.array, atol=5e-1)) / np.prod(result.shape) self.assertGreater(match_ratio, 0.5) # at least half of the images are very close diff --git 
a/tests/test_zoom.py b/tests/test_zoom.py index b614acc9e4..e1ea3c25a3 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -20,7 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Zoom -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import ( DEFAULT_TEST_AFFINE, TEST_NDARRAYS_ALL, @@ -53,12 +53,13 @@ def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False): expected = zoom_fn(im) self.assertIsInstance(expected, MetaTensor) # lazy - zoom_fn.lazy_evaluation = True + zoom_fn.lazy = True pending_result = zoom_fn(im) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) - result = apply_transforms(pending_result, mode="bilinear", dtype=np.float64, align_corners=align_corners)[0] + overrides = {"mode": "bilinear", "dtype": np.float64, "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare match_ratio = np.sum(np.isclose(result, expected)) / np.prod(result.shape) self.assertGreater(match_ratio, 0.95) diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index a76e43a6b4..1dcbf98572 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -57,7 +57,7 @@ def test_correct_results(self, zoom, mode, keep_size, align_corners=None): test_resampler_lazy( zoom_fn, zoomed, init_param, call_param, output_key=key, atol=1e-4 if USE_COMPILED else 1e-6 ) - zoom_fn.lazy_evaluation = False + zoom_fn.lazy = False test_local_inversion(zoom_fn, zoomed, {key: im}, key) _order = 0