diff --git a/monai/deploy/operators/monai_seg_inference_operator.py b/monai/deploy/operators/monai_seg_inference_operator.py
index 63d6007a..4778791c 100644
--- a/monai/deploy/operators/monai_seg_inference_operator.py
+++ b/monai/deploy/operators/monai_seg_inference_operator.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2023 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 import os
 from pathlib import Path
@@ -19,6 +20,7 @@
 
 from monai.deploy.utils.importutil import optional_import
 from monai.utils import StrEnum  # Will use the built-in StrEnum when SDK requires Python 3.11.
+from monai.utils import BlendMode, PytorchPadMode
 
 MONAI_UTILS = "monai.utils"
 torch, _ = optional_import("torch", "1.5")
@@ -54,9 +56,14 @@ class InfererType(StrEnum):
     SLIDING_WINDOW = "sliding_window"
 
 
+# define other StrEnum types
+BlendModeType = BlendMode
+PytorchPadModeType = PytorchPadMode
+
+
 # @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
 class MonaiSegInferenceOperator(InferenceOperator):
-    """This segmentation operator uses MONAI transforms and Sliding Window Inference.
+    """This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
 
     This operator performs pre-transforms on a input image, inference
     using a given model, and post-transforms. The segmentation image is saved
@@ -64,6 +71,12 @@ class MonaiSegInferenceOperator(InferenceOperator):
 
     If specified in the post transforms, results may also be saved to disk.
 
+    This operator uses the MONAI inference utility functions for sliding window and simple inference,
+    and thus input parameters need to match what these functions expect.
+
+    Any additional sliding window arguments not explicitly defined in this operator can be passed via
+    **kwargs for forwarding to 'sliding_window_inference'.
+
     Named Input:
         image: Image object of the input image.
 
@@ -74,6 +87,35 @@ class MonaiSegInferenceOperator(InferenceOperator):
     # For testing the app directly, the model should be at the following path.
     MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
 
+    @staticmethod
+    def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
+        """
+        Returns a dictionary of named parameters of the sliding_window_inference function that are:
+            - Not explicitly defined in the __init__ of this class
+            - Not explicitly used when calling sliding_window_inference
+
+        Args:
+            **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
+
+        Returns:
+            filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
+        """
+
+        init_params = inspect.signature(MonaiSegInferenceOperator).parameters
+
+        # inputs + predictor explicitly used when calling sliding_window_inference
+        explicit_used = ["inputs", "predictor"]
+
+        filtered_params = {}
+        for name, val in kwargs.items():
+            if name in init_params or name in explicit_used:
+                # match log formatting
+                logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
+                logger.warning(f"{name!r} is already explicitly defined or used; ignoring input arg")
+            else:
+                filtered_params[name] = val
+        return filtered_params
+
     def __init__(
         self,
         fragment: Fragment,
@@ -85,6 +127,8 @@ def __init__(
         model_name: Optional[str] = "",
         overlap: float = 0.25,
         sw_batch_size: int = 4,
+        mode: Union[BlendModeType, str] = BlendModeType.CONSTANT,
+        padding_mode: Union[PytorchPadModeType, str] = PytorchPadModeType.CONSTANT,
         inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
         model_path: Path = MODEL_LOCAL_PATH,
         **kwargs,
@@ -103,9 +147,15 @@ def __init__(
             overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
                 Applicable for "SLIDING_WINDOW" only.
             sw_batch_size(int): The batch size to run window slices. Defaults to 4.
-                Applicable for "SLIDING_WINDOW" only.
+                Applicable for "SLIDING_WINDOW" only.
+            mode (BlendModeType): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
+                Applicable for "SLIDING_WINDOW" only.
+            padding_mode (PytorchPadModeType): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
+                "CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
+                Applicable for "SLIDING_WINDOW" only.
             inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
             model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
+            **kwargs: Any other sliding window parameters to forward (e.g. `sigma_scale`, `cval`, etc.).
         """
 
         self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -121,7 +171,10 @@ def __init__(
         self._model_name = model_name.strip() if isinstance(model_name, str) else ""
         self._overlap = overlap
         self._sw_batch_size = sw_batch_size
+        self._mode = mode
+        self._padding_mode = padding_mode
         self._inferer = inferer
+        self._implicit_params = self.filter_sw_kwargs(**kwargs)  # Filter keyword args
 
         # Add this so that the local model path can be set from the calling app
         self.model_path = model_path
@@ -215,6 +268,28 @@ def sw_batch_size(self, val: int):
             raise ValueError("sw_batch_size must be a positive integer.")
         self._sw_batch_size = val
 
+    @property
+    def mode(self) -> Union[BlendModeType, str]:
+        """The blend mode used during sliding window inference"""
+        return self._mode
+
+    @mode.setter
+    def mode(self, val: BlendModeType):
+        if not isinstance(val, BlendModeType):
+            raise ValueError(f"Value must be of the correct type {BlendModeType}.")
+        self._mode = val
+
+    @property
+    def padding_mode(self) -> Union[PytorchPadModeType, str]:
+        """The padding mode to use when padding input images for inference"""
+        return self._padding_mode
+
+    @padding_mode.setter
+    def padding_mode(self, val: PytorchPadModeType):
+        if not isinstance(val, PytorchPadModeType):
+            raise ValueError(f"Value must be of the correct type {PytorchPadModeType}.")
+        self._padding_mode = val
+
     @property
     def inferer(self) -> Union[InfererType, str]:
         """The type of inferer to use"""
@@ -320,7 +395,10 @@ def compute_impl(self, input_image, context):
                     roi_size=self._roi_size,
                     sw_batch_size=self.sw_batch_size,
                     overlap=self.overlap,
+                    mode=self._mode,
+                    padding_mode=self._padding_mode,
                     predictor=self.model,
+                    **self._implicit_params,  # additional sliding window arguments
                 )
             elif self._inferer == InfererType.SIMPLE:
                 # Instantiates the SimpleInferer and directly uses its __call__ function