Skip to content

Commit 7363823

Browse files
authored
Merge pull request #12 from AdaptiveMotorControlLab/feature/swinunetr
[WIP] Adding SwinUNetR
2 parents fde3c71 + 158e0c2 commit 7363823

14 files changed

+249
-469
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,3 @@ venv/
9999
/napari_cellseg3d/models/saved_weights/
100100
/docs/res/logo/old_logo/
101101
/reqs/
102-

napari_cellseg3d/interface.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Optional
2-
from typing import Union
32
from typing import List
43

54

@@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget):
6160
widget.setVisible(checkbox.isChecked())
6261

6362

63+
def add_label(widget, label, label_before=True, horizontal=True):
64+
if label_before:
65+
return combine_blocks(widget, label, horizontal=horizontal)
66+
else:
67+
return combine_blocks(label, widget, horizontal=horizontal)
68+
69+
6470
class Button(QPushButton):
6571
"""Class for a button with a title and connected to a function when clicked. Inherits from QPushButton.
6672
@@ -494,20 +500,33 @@ def __init__(
494500
step=1,
495501
parent: Optional[QWidget] = None,
496502
fixed: Optional[bool] = True,
503+
label: Optional[str] = None,
497504
):
498505
"""Args:
499506
min (Optional[float]): minimum value, defaults to 0
500507
max (Optional[float]): maximum value, defaults to 10
501508
default (Optional[float]): default value, defaults to 0
502509
step (Optional[float]): step value, defaults to 1
503510
parent: parent widget, defaults to None
504-
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed"""
511+
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed
512+
label (Optional[str]): if provided, creates a label with the chosen title to use with the counter"""
505513

506514
super().__init__(parent)
507515
set_spinbox(self, min, max, default, step, fixed)
508516

517+
if label is not None:
518+
self.label = make_label(name=label)
519+
520+
# def setToolTip(self, a0: str) -> None:
521+
# self.setToolTip(a0)
522+
# if self.label is not None:
523+
# self.label.setToolTip(a0)
524+
525+
def get_with_label(self, horizontal=True):
526+
return add_label(self, self.label, horizontal=horizontal)
527+
509528
def set_precision(self, decimals):
510-
"""Sets the precision of the box to the speicifed number of decimals"""
529+
"""Sets the precision of the box to the specified number of decimals"""
511530
self.setDecimals(decimals)
512531

513532
@classmethod
@@ -535,6 +554,7 @@ def __init__(
535554
step=1,
536555
parent: Optional[QWidget] = None,
537556
fixed: Optional[bool] = True,
557+
label: Optional[str] = None,
538558
):
539559
"""Args:
540560
min (Optional[int]): minimum value, defaults to 0
@@ -546,6 +566,9 @@ def __init__(
546566

547567
super().__init__(parent)
548568
set_spinbox(self, min, max, default, step, fixed)
569+
self.label = None
570+
if label is not None:
571+
self.label = make_label(label, self)
549572

550573
@classmethod
551574
def make_n(

napari_cellseg3d/model_framework.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from napari_cellseg3d import utils
1414
from napari_cellseg3d.log_utility import Log
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
16-
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
16+
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR
17+
18+
# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1719
from napari_cellseg3d.models import model_VNet as VNet
1820
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
1921
from napari_cellseg3d.plugin_base import BasePluginFolder
@@ -62,8 +64,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6264
self.models_dict = {
6365
"VNet": VNet,
6466
"SegResNet": SegResNet,
65-
"TRAILMAP": TRAILMAP,
67+
# "TRAILMAP": TRAILMAP,
6668
"TRAILMAP_MS": TRAILMAP_MS,
69+
"SwinUNetR": SwinUNetR,
6770
}
6871
"""dict: dictionary of available models, with string for widget display as key
6972

napari_cellseg3d/model_workers.py

Lines changed: 115 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
# Qt
4646
from qtpy.QtCore import Signal
4747

48-
4948
from napari_cellseg3d import utils
5049
from napari_cellseg3d import log_utility
5150

@@ -165,6 +164,9 @@ def __init__(self):
165164
super().__init__()
166165

167166

167+
# TODO : use dataclass for config instead ?
168+
169+
168170
class InferenceWorker(GeneratorWorker):
169171
"""A custom worker to run inference jobs in.
170172
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
@@ -180,6 +182,7 @@ def __init__(
180182
instance,
181183
use_window,
182184
window_infer_size,
185+
window_overlap,
183186
keep_on_cpu,
184187
stats_csv,
185188
images_filepaths=None,
@@ -231,6 +234,7 @@ def __init__(
231234
self.instance_params = instance
232235
self.use_window = use_window
233236
self.window_infer_size = window_infer_size
237+
self.window_overlap_percentage = window_overlap
234238
self.keep_on_cpu = keep_on_cpu
235239
self.stats_to_csv = stats_csv
236240
############################################
@@ -301,8 +305,6 @@ def log_parameters(self):
301305
f"Probability threshold is {self.instance_params['threshold']:.2f}\n"
302306
f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n"
303307
)
304-
# self.log(f"")
305-
# self.log("\n")
306308
self.log("-" * 20)
307309

308310
def load_folder(self):
@@ -313,25 +315,57 @@ def load_folder(self):
313315
data_check = LoadImaged(keys=["image"])(images_dict[0])
314316

315317
check = data_check["image"].shape
316-
# TODO remove
317-
# z_aniso = 5 / 1.5
318-
# if zoom is not None :
319-
# pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
320-
# else:
318+
321319
self.log("\nChecking dimensions...")
322320
pad = utils.get_padding_dim(check)
323321

324-
load_transforms = Compose(
325-
[
326-
LoadImaged(keys=["image"]),
327-
# AddChanneld(keys=["image"]), #already done
328-
EnsureChannelFirstd(keys=["image"]),
329-
# Orientationd(keys=["image"], axcodes="PLI"),
330-
# anisotropic_transform,
331-
SpatialPadd(keys=["image"], spatial_size=pad),
332-
EnsureTyped(keys=["image"]),
333-
]
334-
)
322+
# dims = self.model_dict["model_input_size"]
323+
#
324+
# if self.model_dict["name"] == "SegResNet":
325+
# model = self.model_dict["class"].get_net(
326+
# input_image_size=[
327+
# dims,
328+
# dims,
329+
# dims,
330+
# ]
331+
# )
332+
# elif self.model_dict["name"] == "SwinUNetR":
333+
# model = self.model_dict["class"].get_net(
334+
# img_size=[dims, dims, dims],
335+
# use_checkpoint=False,
336+
# )
337+
# else:
338+
# model = self.model_dict["class"].get_net()
339+
#
340+
# self.log_parameters()
341+
#
342+
# model.to(self.device)
343+
344+
# print("FILEPATHS PRINT")
345+
# print(self.images_filepaths)
346+
if self.use_window:
347+
load_transforms = Compose(
348+
[
349+
LoadImaged(keys=["image"]),
350+
# AddChanneld(keys=["image"]), #already done
351+
EnsureChannelFirstd(keys=["image"]),
352+
# Orientationd(keys=["image"], axcodes="PLI"),
353+
# anisotropic_transform,
354+
EnsureTyped(keys=["image"]),
355+
]
356+
)
357+
else:
358+
load_transforms = Compose(
359+
[
360+
LoadImaged(keys=["image"]),
361+
# AddChanneld(keys=["image"]), #already done
362+
EnsureChannelFirstd(keys=["image"]),
363+
# Orientationd(keys=["image"], axcodes="PLI"),
364+
# anisotropic_transform,
365+
SpatialPadd(keys=["image"], spatial_size=pad),
366+
EnsureTyped(keys=["image"]),
367+
]
368+
)
335369

336370
self.log("\nLoading dataset...")
337371
inference_ds = Dataset(data=images_dict, transform=load_transforms)
@@ -364,19 +398,32 @@ def load_layer(self):
364398

365399
# print(volume.shape)
366400
# print(volume.dtype)
367-
368-
load_transforms = Compose(
369-
[
370-
ToTensor(),
371-
# anisotropic_transform,
372-
AddChannel(),
373-
SpatialPad(spatial_size=pad),
374-
AddChannel(),
375-
EnsureType(),
376-
],
377-
map_items=False,
378-
log_stats=True,
379-
)
401+
if self.use_window:
402+
load_transforms = Compose(
403+
[
404+
ToTensor(),
405+
# anisotropic_transform,
406+
AddChannel(),
407+
# SpatialPad(spatial_size=pad),
408+
AddChannel(),
409+
EnsureType(),
410+
],
411+
map_items=False,
412+
log_stats=True,
413+
)
414+
else:
415+
load_transforms = Compose(
416+
[
417+
ToTensor(),
418+
# anisotropic_transform,
419+
AddChannel(),
420+
SpatialPad(spatial_size=pad),
421+
AddChannel(),
422+
EnsureType(),
423+
],
424+
map_items=False,
425+
log_stats=True,
426+
)
380427

381428
self.log("\nLoading dataset...")
382429
input_image = load_transforms(volume)
@@ -405,8 +452,10 @@ def model_output(
405452

406453
if self.use_window:
407454
window_size = self.window_infer_size
455+
window_overlap = self.window_overlap_percentage
408456
else:
409457
window_size = None
458+
window_overlap = 0.25
410459

411460
outputs = sliding_window_inference(
412461
inputs,
@@ -415,6 +464,7 @@ def model_output(
415464
predictor=model_output,
416465
sw_device=self.device,
417466
device=dataset_device,
467+
overlap=window_overlap,
418468
)
419469

420470
out = outputs.detach().cpu()
@@ -508,13 +558,12 @@ def save_image(
508558
)
509559

510560
imwrite(file_path, image)
561+
filename = os.path.split(file_path)[1]
511562

512563
if from_layer:
513-
self.log(f"\nLayer prediction saved as :")
564+
self.log(f"\nLayer prediction saved as : {filename}")
514565
else:
515-
self.log(f"\nFile n°{i+1} saved as :")
516-
filename = os.path.split(file_path)[1]
517-
self.log(filename)
566+
self.log(f"\nFile n°{i+1} saved as : {filename}")
518567

519568
def aniso_transform(self, image):
520569
zoom = self.transforms["zoom"][1]
@@ -630,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):
630679

631680
self.save_image(out, from_layer=True)
632681

633-
instance_labels, data_dict = self.get_instance_result(out,from_layer=True)
682+
instance_labels, data_dict = self.get_instance_result(
683+
out, from_layer=True
684+
)
634685

635-
return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
686+
return self.create_result_dict(
687+
out, instance_labels, from_layer=True, data_dict=data_dict
688+
)
636689

637690
def inference(self):
638691
"""
@@ -674,29 +727,27 @@ def inference(self):
674727
torch.set_num_threads(1) # required for threading on macOS ?
675728
self.log("Number of threads has been set to 1 for macOS")
676729

677-
# if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
678-
# torch.backends.cudnn.benchmark = False
679-
680-
# TODO : better solution than loading first image always ?
681-
# data_check = LoadImaged(keys=["image"])(images_dict[0])
682-
# print(data)
683-
# check = data_check["image"].shape
684-
# print(check)
685-
686730
try:
687-
dims = self.model_dict["segres_size"]
731+
dims = self.model_dict["model_input_size"]
732+
self.log(f"MODEL DIMS : {dims}")
733+
self.log(self.model_dict["name"])
688734

689-
model = self.model_dict["class"].get_net()
690735
if self.model_dict["name"] == "SegResNet":
691-
model = self.model_dict["class"].get_net()(
736+
model = self.model_dict["class"].get_net(
692737
input_image_size=[
693738
dims,
694739
dims,
695740
dims,
696741
], # TODO FIX ! find a better way & remove model-specific code
697-
out_channels=1,
698-
# dropout_prob=0.3,
699742
)
743+
elif self.model_dict["name"] == "SwinUNetR":
744+
model = self.model_dict["class"].get_net(
745+
img_size=[dims, dims, dims],
746+
use_checkpoint=False,
747+
)
748+
else:
749+
model = self.model_dict["class"].get_net()
750+
model = model.to(self.device)
700751

701752
self.log_parameters()
702753

@@ -722,10 +773,7 @@ def inference(self):
722773
AsDiscrete(threshold=t), EnsureType()
723774
)
724775

725-
726-
self.log(
727-
"\nLoading weights..."
728-
)
776+
self.log("\nLoading weights...")
729777

730778
if self.weights_dict["custom"]:
731779
weights = self.weights_dict["path"]
@@ -1022,11 +1070,21 @@ def train(self):
10221070
else:
10231071
size = check
10241072
print(f"Size of image : {size}")
1025-
model = model_class.get_net()(
1073+
model = model_class.get_net(
10261074
input_image_size=utils.get_padding_dim(size),
10271075
out_channels=1,
10281076
dropout_prob=0.3,
10291077
)
1078+
elif model_name == "SwinUNetR":
1079+
if self.sampling:
1080+
size = self.sample_size
1081+
else:
1082+
size = check
1083+
print(f"Size of image : {size}")
1084+
model = model_class.get_net(
1085+
img_size=utils.get_padding_dim(size),
1086+
use_checkpoint=True,
1087+
)
10301088
else:
10311089
model = model_class.get_net() # get an instance of the model
10321090
model = model.to(self.device)

0 commit comments

Comments
 (0)