MONAISegInferenceOperator Additional Arguments #547

Open · wants to merge 2 commits into main
84 changes: 81 additions & 3 deletions monai/deploy/operators/monai_seg_inference_operator.py
@@ -1,4 +1,4 @@
# Copyright 2021-2023 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
from pathlib import Path
@@ -19,6 +20,7 @@

from monai.deploy.utils.importutil import optional_import
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
from monai.utils import BlendMode, PytorchPadMode

MONAI_UTILS = "monai.utils"
torch, _ = optional_import("torch", "1.5")
@@ -54,16 +56,27 @@ class InfererType(StrEnum):
SLIDING_WINDOW = "sliding_window"


# Aliases for other MONAI StrEnum types used by sliding window inference
BlendModeType = BlendMode
PytorchPadModeType = PytorchPadMode


# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
class MonaiSegInferenceOperator(InferenceOperator):
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.

This operator performs pre-transforms on an input image, inference
using a given model, and post-transforms. The segmentation image is saved
as a named Image object in memory.

If specified in the post transforms, results may also be saved to disk.

This operator uses the MONAI inference utility functions for sliding window and simple inference,
so its input parameters must match what those functions expect.

Any additional sliding window arguments not explicitly defined in this operator can be passed via
**kwargs for forwarding to ``sliding_window_inference``.

Named Input:
image: Image object of the input image.

@@ -74,6 +87,35 @@ class MonaiSegInferenceOperator(InferenceOperator):
# For testing the app directly, the model should be at the following path.
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

@staticmethod
def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
"""
Filters extra keyword arguments intended for sliding_window_inference, keeping only those that are:
- Not explicitly defined in the __init__ of this class
- Not explicitly used when calling sliding_window_inference

Args:
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.

Returns:
filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
"""

init_params = inspect.signature(MonaiSegInferenceOperator).parameters

# inputs + predictor explicitly used when calling sliding_window_inference
explicit_used = ["inputs", "predictor"]

# Use the class-qualified logger name to match instance log formatting
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")

filtered_params = {}
for name, val in kwargs.items():
if name in init_params or name in explicit_used:
logger.warning(f"{name!r} is already explicitly defined or used; ignoring input arg")
else:
filtered_params[name] = val
return filtered_params

def __init__(
self,
fragment: Fragment,
@@ -85,6 +127,8 @@ def __init__(
model_name: Optional[str] = "",
overlap: float = 0.25,
sw_batch_size: int = 4,
mode: Union[BlendModeType, str] = BlendModeType.CONSTANT,
padding_mode: Union[PytorchPadModeType, str] = PytorchPadModeType.CONSTANT,
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
model_path: Path = MODEL_LOCAL_PATH,
**kwargs,
@@ -103,9 +147,15 @@ def __init__(
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
Applicable for "SLIDING_WINDOW" only.
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
Applicable for "SLIDING_WINDOW" only.
mode (BlendModeType): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
Applicable for "SLIDING_WINDOW" only.
padding_mode (PytorchPadModeType): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
"CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
Applicable for "SLIDING_WINDOW" only.
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
model_path (Path): Path to the model file. Defaults to model/model.ts of current working dir.
**kwargs: any other sliding window parameters to forward (e.g. `sigma_scale`, `cval`, etc.).
"""

self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -121,7 +171,10 @@ def __init__(
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
self._overlap = overlap
self._sw_batch_size = sw_batch_size
self._mode = mode
self._padding_mode = padding_mode
self._inferer = inferer
self._implicit_params = self.filter_sw_kwargs(**kwargs)  # Filter extra keyword args to forward to sliding_window_inference

# Add this so that the local model path can be set from the calling app
self.model_path = model_path
@@ -215,6 +268,28 @@ def sw_batch_size(self, val: int):
raise ValueError("sw_batch_size must be a positive integer.")
self._sw_batch_size = val

@property
def mode(self) -> Union[BlendModeType, str]:
"""The blend mode used during sliding window inference"""
return self._mode

@mode.setter
def mode(self, val: BlendModeType):
if not isinstance(val, BlendModeType):
raise ValueError(f"Value must be of the correct type {BlendModeType}.")
self._mode = val

@property
def padding_mode(self) -> Union[PytorchPadModeType, str]:
"""The padding mode to use when padding input images for inference"""
return self._padding_mode

@padding_mode.setter
def padding_mode(self, val: PytorchPadModeType):
if not isinstance(val, PytorchPadModeType):
raise ValueError(f"Value must be of the correct type {PytorchPadModeType}.")
self._padding_mode = val

@property
def inferer(self) -> Union[InfererType, str]:
"""The type of inferer to use"""
@@ -320,7 +395,10 @@ def compute_impl(self, input_image, context):
roi_size=self._roi_size,
sw_batch_size=self.sw_batch_size,
overlap=self.overlap,
mode=self._mode,
padding_mode=self._padding_mode,
Collaborator:
It is good to add these couple params as explicitly supported named args to the MONAI inference function. In addition, I had the idea of extending the support to all other params on the MONAI inference functions by passing the (filtered) **kwargs down to the functions. I will provide a static function for the filtering in the general comment section.
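
One possible shape for such a filter, sketched here for illustration only (this is not the helper referenced above, and the function name filter_for_sliding_window is hypothetical): validate candidate kwargs against the signature of sliding_window_inference itself and drop anything the operator already sets explicitly.

import inspect

from monai.inferers import sliding_window_inference


def filter_for_sliding_window(already_used, **kwargs):
    """Keep only kwargs that sliding_window_inference accepts and the operator does not already set."""
    accepted = set(inspect.signature(sliding_window_inference).parameters)
    return {name: val for name, val in kwargs.items() if name in accepted and name not in already_used}


# 'overlap' is already an explicit operator argument, so it is dropped here;
# 'buffer_steps' is a valid extra sliding window parameter and passes through.
extras = filter_for_sliding_window(
    {"inputs", "predictor", "roi_size", "sw_batch_size", "overlap", "mode", "padding_mode"},
    overlap=0.5,
    buffer_steps=1,
)
print(extras)  # {'buffer_steps': 1}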

Contributor Author:
Filtering functionality added. Tested this out with a few example cases, and the desired behavior was displayed. Here is an example: we can see that the inputs and predictor arguments are ignored, and that the buffer_steps and buffer_dim parameters are successfully passed to sliding_window_inference to produce a ValueError (values chosen purposefully to produce the error):

# delegates inference and saving output to the built-in operator
# parameters pulled from inference.yaml file of the MONAI bundle
infer_operator = MonaiSegInferenceOperator(
    self.fragment,
    roi_size=(96, 96, 96),
    pre_transforms=pre_transforms,
    post_transforms=post_transforms,
    overlap=0.25,
    app_context=self.app_context,
    model_name="",
    inferer=InfererType.SLIDING_WINDOW,
    sw_batch_size=1,
    mode=BlendModeType.GAUSSIAN,
    padding_mode=PytorchPadModeType.REPLICATE,
    model_path=self.model_path,
    inputs="test",
    predictor="testing",
    buffer_steps=1,
    buffer_dim=4
)
[2025-08-01 19:04:15,548] [INFO] (ct_totalseg_operator.CTTotalSegOperator) - TorchScript model detected
[2025-08-01 19:04:15,548] [WARNING] (monai_seg_inference_operator.MonaiSegInferenceOperator) - 'inputs' is already explicitly defined or used; ignoring input arg
[2025-08-01 19:04:15,548] [WARNING] (monai_seg_inference_operator.MonaiSegInferenceOperator) - 'predictor' is already explicitly defined or used; ignoring input arg

....

[2025-08-01 19:04:18,125] [INFO] (monai_seg_inference_operator.MonaiSegInferenceOperator) - Input of <class 'monai.data.meta_tensor.MetaTensor'> shape: torch.Size([1, 1, 270, 270, 204])
[error] [gxf_wrapper.cpp:118] Exception occurred for operator: 'ct_totalseg_op' - ValueError: buffer_dim must be in [-3, 3], got 4.

At:
  /home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/site-packages/monai/inferers/utils.py(142): sliding_window_inference
  /home/bluna301/ct-totalsegmentator-map/my_app/monai_seg_inference_operator.py(435): compute_impl
  /home/bluna301/ct-totalsegmentator-map/my_app/ct_totalseg_operator.py(226): compute

[error] [entity_executor.cpp:596] Failed to tick codelet ct_totalseg_op in entity: ct_totalseg_op code: GXF_FAILURE
[warning] [greedy_scheduler.cpp:243] Error while executing entity 28 named 'ct_totalseg_op': GXF_FAILURE
[info] [greedy_scheduler.cpp:401] Scheduler finished.
[error] [program.cpp:580] wait failed. Deactivating...
[error] [runtime.cpp:1649] Graph wait failed with error: GXF_FAILURE
[warning] [gxf_executor.cpp:2428] GXF call GxfGraphWait(context) in line 2428 of file /workspace/holoscan-sdk/src/core/executors/gxf/gxf_executor.cpp failed with 'GXF_FAILURE' (1)
[info] [gxf_executor.cpp:2438] Graph execution finished.
[error] [gxf_executor.cpp:2446] Graph execution error: GXF_FAILURE
Traceback (most recent call last):
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/__main__.py", line 25, in <module>
    CTTotalSegmentatorApp().run()
  File "/home/bluna301/ct-totalsegmentator-map/my_app/app.py", line 61, in run
    super().run(*args, **kwargs)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/ct_totalseg_operator.py", line 226, in compute
    seg_image = infer_operator.compute_impl(input_image, context)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/monai_seg_inference_operator.py", line 391, in compute_impl
    d[self._pred_dataset_key] = sliding_window_inference(
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/site-packages/monai/inferers/utils.py", line 142, in sliding_window_inference
    raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
ValueError: buffer_dim must be in [-3, 3], got 4.
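
For comparison, a hypothetical variant of the construction above (a sketch, not part of the test run shown) uses an in-range buffer_dim, so the extra kwargs should reach sliding_window_inference without raising:

# Hypothetical sketch: same extra kwargs as above, but with buffer_dim in the
# valid range [-3, 3] for a 3D volume, so forwarding should succeed.
infer_operator = MonaiSegInferenceOperator(
    self.fragment,
    roi_size=(96, 96, 96),
    pre_transforms=pre_transforms,
    post_transforms=post_transforms,
    overlap=0.25,
    app_context=self.app_context,
    inferer=InfererType.SLIDING_WINDOW,
    sw_batch_size=1,
    mode=BlendModeType.GAUSSIAN,
    padding_mode=PytorchPadModeType.REPLICATE,
    model_path=self.model_path,
    buffer_steps=1,
    buffer_dim=-1,  # valid for a 3D input, unlike the value used above
)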

predictor=self.model,
**self._implicit_params, # additional sliding window arguments
)
elif self._inferer == InfererType.SIMPLE:
# Instantiates the SimpleInferer and directly uses its __call__ function
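
For a quick sense of the filtering behavior introduced in this diff, a small sketch of calling filter_sw_kwargs directly; the expected output follows from the filter logic shown above:

# Expected filter_sw_kwargs behavior per the diff above:
# 'overlap' collides with an explicit __init__ parameter and 'predictor' is
# explicitly used in the sliding window call, so both are dropped with a warning;
# 'buffer_steps' has no collision and passes through.
extras = MonaiSegInferenceOperator.filter_sw_kwargs(
    overlap=0.5,
    predictor=None,
    buffer_steps=1,
)
assert extras == {"buffer_steps": 1}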