96 changes: 41 additions & 55 deletions segmenter_model_zoo/zoo.py
@@ -7,12 +7,12 @@
 import importlib
 
 import torch
-from torch.autograd import Variable
 from aicsmlsegment.utils import input_normalization
 from scipy.ndimage import zoom
 from aicsimageio import AICSImage
 
 from segmenter_model_zoo.quilt_utils import validate_model
+from aicsmlsegment.multichannel_sliding_window import sliding_window_inference
 
 ###############################################################################
 
@@ -214,6 +214,7 @@ def load_train(
 
         else:
             model_type = CHECKPOINT_PATH_MAPPING[checkpoint_name]["model_type"]
+            self.model_name = model_type
 
         # load default model parameters or from model_param
         if "size_in" in model_param:
Expand Down Expand Up @@ -304,6 +305,8 @@ def apply_on_single_zstack(
already_normalized: bool = False,
cutoff: float = None,
inference_param: Dict = {},
size_in: List = None,
size_out: List = None,
) -> np.ndarray:
"""
Apply a trained model on an image
@@ -334,6 +337,10 @@
             only one parameter is allowed: "ResizeRatio" (a list of three
             float numbers to indicate the ResizeRatio to apply on ZYX axis).
             More parameters may be added in the future.
+        size_in: List
+            the input patch size, to override the model's default
+        size_out: List
+            the output patch size, to override the model's default
 
         Return:
         -------------
@@ -383,8 +390,14 @@
         model = self.model
         model.eval()
 
+        # check if we need to fall back to the default size_in and size_out
+        if size_in is None:
+            size_in = self.size_in
+        if size_out is None:
+            size_out = self.size_out
+
         # do padding on input
-        padding = [(x - y) // 2 for x, y in zip(self.size_in, self.size_out)]
+        padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]
         img_pad0 = np.pad(
             input_img,
             ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])),
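
A quick worked example of the padding arithmetic in the hunk above, with hypothetical patch sizes (the real defaults come from the checkpoint config): each spatial axis is padded symmetrically by half the difference between the input and output patch sizes.

```python
# Hypothetical ZYX patch sizes, for illustration only.
size_in = [52, 420, 420]   # what the network consumes
size_out = [20, 152, 152]  # what the network emits

# Same computation as in the patch: half the size difference per axis.
padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]
print(padding)  # [16, 134, 134] -> pad Z by 16, Y and X by 134 on each side
```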
@@ -394,62 +407,35 @@
             img_pad0, ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), "constant"
         )
 
-        # we only support single output image in model zoo
-        # other outputs are only supported in full segmenter prediction so far
-        assert len(self.OutputCh) == 2
-        output_img = np.zeros(input_img.shape)
+        # pad the extra batch dimension
+        img_pad = np.expand_dims(img_pad, axis=0)
 
-        # loop through the image patch by patch
-        num_step_z = int(np.ceil(input_img.shape[1] / self.size_out[0]))
-        num_step_y = int(np.ceil(input_img.shape[2] / self.size_out[1]))
-        num_step_x = int(np.ceil(input_img.shape[3] / self.size_out[2]))
+        # run sliding window inference
         with torch.no_grad():
-            for ix in range(num_step_x):
-                if ix < num_step_x - 1:
-                    xa = ix * self.size_out[2]
-                else:
-                    xa = input_img.shape[3] - self.size_out[2]
-
-                for iy in range(num_step_y):
-                    if iy < num_step_y - 1:
-                        ya = iy * self.size_out[1]
-                    else:
-                        ya = input_img.shape[2] - self.size_out[1]
-
-                    for iz in range(num_step_z):
-                        if iz < num_step_z - 1:
-                            za = iz * self.size_out[0]
-                        else:
-                            za = input_img.shape[1] - self.size_out[0]
-
-                        input_patch = img_pad[
-                            :,
-                            za : (za + self.size_in[0]),
-                            ya : (ya + self.size_in[1]),
-                            xa : (xa + self.size_in[2]),
-                        ]
-                        input_img_tensor = torch.from_numpy(input_patch)
-                        tmp_out = model(Variable(input_img_tensor.cuda()).unsqueeze(0))
-                        assert len(self.OutputCh) // 2 <= len(
-                            tmp_out
-                        ), "the parameter OutputCh not compatible with output tensors"
-
-                        label = tmp_out[self.OutputCh[0]]
-                        prob = self.softmax(label)
-                        out_flat_tensor = prob.cpu().data
-                        out_tensor = out_flat_tensor.view(
-                            self.size_out[0],
-                            self.size_out[1],
-                            self.size_out[2],
-                            self.nclass[0],
-                        )
-                        out_nda = out_tensor.numpy()
-                        output_img[
-                            0,
-                            za : (za + self.size_out[0]),
-                            ya : (ya + self.size_out[1]),
-                            xa : (xa + self.size_out[2]),
-                        ] = out_nda[:, :, :, self.OutputCh[1]]
+            output_tensor, _ = sliding_window_inference(
+                inputs=torch.from_numpy(img_pad).float().cuda(),
+                roi_size=size_in,
+                out_size=size_out,
+                original_image_size=input_img.shape[-3:],
+                sw_batch_size=1,
+                predictor=model.forward,
+                overlap=0.25,
+                mode="gaussian",
+                model_name=self.model_name,
+            )
+
+            output_img = output_tensor.cpu().data.numpy()
+            if self.OutputCh:
+                # old models, only take the output from the highest resolution
+                if type(self.OutputCh) == list:
+                    # if it is [v1, v2], the second value is which channel to take from
+                    # the highest resolution output
+                    if len(self.OutputCh) >= 2:
+                        self.OutputCh = self.OutputCh[1]
+                    else:
+                        # just convert list to integer
+                        self.OutputCh = self.OutputCh[0]
+                output_img = output_img[:, self.OutputCh, :, :, :]
 
         torch.cuda.empty_cache()
 
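With the zoo.py changes applied, patch sizes can be overridden per call instead of being fixed by the checkpoint. Below is a minimal usage sketch; the wrapper class name, checkpoint name, and positional argument are assumptions for illustration, not confirmed parts of the zoo API.

```python
import numpy as np

from segmenter_model_zoo.zoo import SuperModel  # assumed class name

# Hypothetical checkpoint key; see CHECKPOINT_PATH_MAPPING for real entries.
model = SuperModel("DNA_mask_production")

img = np.random.rand(1, 40, 512, 512).astype(np.float32)  # CZYX z-stack

seg = model.apply_on_single_zstack(
    img,
    already_normalized=True,
    # New in this PR: per-call override of the default patch sizes,
    # e.g. to trade GPU memory against the number of windows.
    size_in=[48, 224, 224],
    size_out=[48, 224, 224],
)
```

The call now delegates to sliding_window_inference with overlap=0.25 and mode="gaussian", which stitches overlapping windows with gaussian weighting, presumably so that window borders contribute less than window centers.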
4 changes: 2 additions & 2 deletions setup.py
@@ -41,9 +41,9 @@
 requirements = [
     'PyYAML',
     'aicsimageio>3.3.0',
-    'aicsmlsegment>0.0.5'
+    'aicsmlsegment>0.0.5',
     'scikit-image',
-    "quilt3",
+    'quilt3',
 ]
 
 extra_requirements = {
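
The setup.py fix is one character plus a quote-style cleanup, but the comma matters: Python concatenates adjacent string literals, so the old list silently merged two requirements into a single invalid specifier. A quick demonstration:

```python
# Without the comma, adjacent string literals concatenate at parse time.
broken = [
    'aicsmlsegment>0.0.5'
    'scikit-image',
]
fixed = [
    'aicsmlsegment>0.0.5',
    'scikit-image',
]
print(broken)  # ['aicsmlsegment>0.0.5scikit-image']  <- one bad requirement
print(fixed)   # ['aicsmlsegment>0.0.5', 'scikit-image']
```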