Skip to content

Commit 6f7a846

Browse files
Enhance predict method to support multimodal data concatenation and update docstring
Signed-off-by: Simone Bendazzoli <[email protected]>
1 parent 6de6132 commit 6f7a846

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

monai/deploy/operators/monet_bundle_inference_operator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
from monai.deploy.core import Image
1515
from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
1616
from monai.deploy.utils.importutil import optional_import
17+
from monai.transforms import ConcatItemsd, ResampleToMatch
1718

1819
torch, _ = optional_import("torch", "1.10.2")
19-
20+
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
2021
__all__ = ["MONetBundleInferenceOperator"]
2122

2223

@@ -60,10 +61,18 @@ def _init_config(self, config_names):
6061
self._nnunet_predictor = parser.get_parsed_content("network_def")
6162

6263
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
63-
"""Predicts output using the inferer."""
64+
"""Predicts output using the inferer. If multimodal data is provided as keyword arguments,
65+
it concatenates the data with the main input data."""
6466

6567
self._nnunet_predictor.predictor.network = self._model_network
6668

69+
if len(kwargs) > 0:
70+
multimodal_data = {"image": data}
71+
for key in kwargs.keys():
72+
if isinstance(kwargs[key], MetaTensor):
73+
multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data
74+
)
75+
data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"]
6776
if len(data.shape) == 4:
6877
data = data[None]
6978
return self._nnunet_predictor(data)

0 commit comments

Comments
 (0)