9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from monai .deploy .operators .monai_bundle_inference_operator import MonaiBundleInferenceOperator , get_bundle_config
13
- from monai .deploy .utils .importutil import optional_import
14
12
from typing import Any , Dict , Tuple , Union
13
+
15
14
from monai .deploy .core import Image
15
+ from monai .deploy .operators .monai_bundle_inference_operator import MonaiBundleInferenceOperator , get_bundle_config
16
+ from monai .deploy .utils .importutil import optional_import
16
17
17
18
MONAI_UTILS = "monai.utils"
18
19
nibabel , _ = optional_import ("nibabel" , "3.2.1" )
@@ -44,12 +45,12 @@ class MONetBundleInferenceOperator(MonaiBundleInferenceOperator):
44
45
This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
45
46
configurations and prediction logic. It initializes the nnUNet predictor and provides
46
47
a method for performing inference on input data.
47
-
48
+
48
49
Attributes
49
50
----------
50
51
_nnunet_predictor : torch.nn.Module
51
52
The nnUNet predictor module used for inference.
52
-
53
+
53
54
Methods
54
55
-------
55
56
_init_config(config_names)
@@ -65,16 +66,14 @@ def __init__(
65
66
** kwargs ,
66
67
):
67
68
68
-
69
69
super ().__init__ (* args , ** kwargs )
70
-
71
- self ._nnunet_predictor : torch .nn .Module = None
72
-
73
-
74
- def _init_config (self , config_names ):
70
+
71
+ self ._nnunet_predictor : torch .nn .Module = None
72
+
73
+ def _init_config (self , config_names ):
75
74
76
75
super ()._init_config (config_names )
77
- parser = get_bundle_config (str (self ._bundle_path ), config_names )
76
+ parser = get_bundle_config (str (self ._bundle_path ), config_names )
78
77
self ._parser = parser
79
78
80
79
self ._nnunet_predictor = parser .get_parsed_content ("network_def" )
@@ -83,7 +82,7 @@ def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ..
83
82
"""Predicts output using the inferer."""
84
83
85
84
self ._nnunet_predictor .predictor .network = self ._model_network
86
- #os.environ['nnUNet_def_n_proc'] = "1"
85
+ # os.environ['nnUNet_def_n_proc'] = "1"
87
86
if len (data .shape ) == 4 :
88
87
data = data [None ]
89
88
return self ._nnunet_predictor (data )
0 commit comments