|
14 | 14 | from monai.deploy.core import Image
|
15 | 15 | from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
|
16 | 16 | from monai.deploy.utils.importutil import optional_import
|
| 17 | +from monai.transforms import ConcatItemsd, ResampleToMatch |
17 | 18 |
|
18 | 19 | torch, _ = optional_import("torch", "1.10.2")
|
19 |
| - |
| 20 | +MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor") |
20 | 21 | __all__ = ["MONetBundleInferenceOperator"]
|
21 | 22 |
|
22 | 23 |
|
@@ -60,10 +61,18 @@ def _init_config(self, config_names):
|
60 | 61 | self._nnunet_predictor = parser.get_parsed_content("network_def")
|
61 | 62 |
|
62 | 63 | def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
|
63 |
| - """Predicts output using the inferer.""" |
| 64 | + """Predicts output using the inferer. If multimodal data is provided as keyword arguments, |
| 65 | + it concatenates the data with the main input data.""" |
64 | 66 |
|
65 | 67 | self._nnunet_predictor.predictor.network = self._model_network
|
66 | 68 |
|
| 69 | + if len(kwargs) > 0: |
| 70 | + multimodal_data = {"image": data} |
| 71 | + for key in kwargs.keys(): |
| 72 | + if isinstance(kwargs[key], MetaTensor): |
| 73 | + multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data |
| 74 | + ) |
| 75 | + data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"] |
67 | 76 | if len(data.shape) == 4:
|
68 | 77 | data = data[None]
|
69 | 78 | return self._nnunet_predictor(data)
|
0 commit comments