Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions monai/deploy/operators/monai_seg_inference_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2023 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -47,13 +47,29 @@
__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]


class BlendModeType(StrEnum):
"""Represents the supported blend modes for sliding window inference."""

CONSTANT = "constant"
GAUSSIAN = "gaussian"


class InfererType(StrEnum):
"""Represents the supported types of the inferer, e.g. Simple and Sliding Window."""

SIMPLE = "simple"
SLIDING_WINDOW = "sliding_window"


class PytorchPadModeType(StrEnum):
"""Represents the supported padding modes for sliding window inference."""

CONSTANT = "constant"
REFLECT = "reflect"
REPLICATE = "replicate"
CIRCULAR = "circular"


# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
class MonaiSegInferenceOperator(InferenceOperator):
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
Expand Down Expand Up @@ -85,6 +101,8 @@ def __init__(
model_name: Optional[str] = "",
overlap: float = 0.25,
sw_batch_size: int = 4,
mode: Union[BlendModeType, str] = BlendModeType.CONSTANT,
padding_mode: Union[PytorchPadModeType, str] = PytorchPadModeType.CONSTANT,
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
model_path: Path = MODEL_LOCAL_PATH,
**kwargs,
Expand All @@ -103,7 +121,12 @@ def __init__(
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
Applicable for "SLIDING_WINDOW" only.
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
Applicable for "SLIDING_WINDOW" only.
mode (BlendMode): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
Applicable for "SLIDING_WINDOW" only.
padding_mode (PytorchPadMode): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
"CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
Applicable for "SLIDING_WINDOW" only.
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
"""
Expand All @@ -121,6 +144,8 @@ def __init__(
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
self._overlap = overlap
self._sw_batch_size = sw_batch_size
self._mode = mode
self._padding_mode = padding_mode
self._inferer = inferer

# Add this so that the local model path can be set from the calling app
Expand Down Expand Up @@ -215,6 +240,28 @@ def sw_batch_size(self, val: int):
raise ValueError("sw_batch_size must be a positive integer.")
self._sw_batch_size = val

@property
def mode(self) -> Union[BlendModeType, str]:
"""The blend mode used during sliding window inference"""
return self._mode

@mode.setter
def mode(self, val: BlendModeType):
if not isinstance(val, BlendModeType):
raise ValueError(f"Value must be of the correct type {BlendModeType}.")
self._mode = val

@property
def padding_mode(self) -> Union[PytorchPadModeType, str]:
"""The padding mode to use when padding input images for inference"""
return self._padding_mode

@padding_mode.setter
def padding_mode(self, val: PytorchPadModeType):
if not isinstance(val, PytorchPadModeType):
raise ValueError(f"Value must be of the correct type {PytorchPadModeType}.")
self._padding_mode = val

@property
def inferer(self) -> Union[InfererType, str]:
"""The type of inferer to use"""
Expand Down Expand Up @@ -320,6 +367,8 @@ def compute_impl(self, input_image, context):
roi_size=self._roi_size,
sw_batch_size=self.sw_batch_size,
overlap=self.overlap,
mode=self._mode,
padding_mode=self._padding_mode,
predictor=self.model,
)
elif self._inferer == InfererType.SIMPLE:
Expand Down