# Qt
from qtpy.QtCore import Signal

-
from napari_cellseg3d import utils
from napari_cellseg3d import log_utility

@@ -165,6 +164,9 @@ def __init__(self):
        super().__init__()


+# TODO : use dataclass for config instead ?
+
+
class InferenceWorker(GeneratorWorker):
    """A custom worker to run inference jobs in.
    Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
@@ -180,6 +182,7 @@ def __init__(
        instance,
        use_window,
        window_infer_size,
+        window_overlap,
        keep_on_cpu,
        stats_csv,
        images_filepaths=None,
@@ -231,6 +234,7 @@ def __init__(
        self.instance_params = instance
        self.use_window = use_window
        self.window_infer_size = window_infer_size
+        self.window_overlap_percentage = window_overlap
        self.keep_on_cpu = keep_on_cpu
        self.stats_to_csv = stats_csv
        ############################################
@@ -301,8 +305,6 @@ def log_parameters(self):
            f"Probability threshold is {self.instance_params['threshold']:.2f}\n"
            f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n"
        )
-        # self.log(f"")
-        # self.log("\n")
        self.log("-" * 20)

    def load_folder(self):
@@ -313,25 +315,57 @@ def load_folder(self):
        data_check = LoadImaged(keys=["image"])(images_dict[0])

        check = data_check["image"].shape
-        # TODO remove
-        # z_aniso = 5 / 1.5
-        # if zoom is not None:
-        #     pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
-        # else:
+
        self.log("\nChecking dimensions...")
        pad = utils.get_padding_dim(check)

-        load_transforms = Compose(
-            [
-                LoadImaged(keys=["image"]),
-                # AddChanneld(keys=["image"]), #already done
-                EnsureChannelFirstd(keys=["image"]),
-                # Orientationd(keys=["image"], axcodes="PLI"),
-                # anisotropic_transform,
-                SpatialPadd(keys=["image"], spatial_size=pad),
-                EnsureTyped(keys=["image"]),
-            ]
-        )
+        # dims = self.model_dict["model_input_size"]
+        #
+        # if self.model_dict["name"] == "SegResNet":
+        #     model = self.model_dict["class"].get_net(
+        #         input_image_size=[
+        #             dims,
+        #             dims,
+        #             dims,
+        #         ]
+        #     )
+        # elif self.model_dict["name"] == "SwinUNetR":
+        #     model = self.model_dict["class"].get_net(
+        #         img_size=[dims, dims, dims],
+        #         use_checkpoint=False,
+        #     )
+        # else:
+        #     model = self.model_dict["class"].get_net()
+        #
+        # self.log_parameters()
+        #
+        # model.to(self.device)
+
+        # print("FILEPATHS PRINT")
+        # print(self.images_filepaths)
+        if self.use_window:
+            load_transforms = Compose(
+                [
+                    LoadImaged(keys=["image"]),
+                    # AddChanneld(keys=["image"]), #already done
+                    EnsureChannelFirstd(keys=["image"]),
+                    # Orientationd(keys=["image"], axcodes="PLI"),
+                    # anisotropic_transform,
+                    EnsureTyped(keys=["image"]),
+                ]
+            )
+        else:
+            load_transforms = Compose(
+                [
+                    LoadImaged(keys=["image"]),
+                    # AddChanneld(keys=["image"]), #already done
+                    EnsureChannelFirstd(keys=["image"]),
+                    # Orientationd(keys=["image"], axcodes="PLI"),
+                    # anisotropic_transform,
+                    SpatialPadd(keys=["image"], spatial_size=pad),
+                    EnsureTyped(keys=["image"]),
+                ]
+            )

        self.log("\nLoading dataset...")
        inference_ds = Dataset(data=images_dict, transform=load_transforms)
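For context, a minimal sketch (not part of this diff; the shapes and target size are made up) of what MONAI's SpatialPad does to a channel-first volume. Padding to a fixed spatial size is only kept in the non-windowed branch above, since sliding-window inference crops its own patches:

import torch
from monai.transforms import SpatialPad

# Channel-first dummy volume smaller than the target spatial size
img = torch.rand(1, 40, 50, 60)

# Symmetric zero-padding up to (64, 64, 64); spatial_size plays the role of `pad` above
padded = SpatialPad(spatial_size=(64, 64, 64))(img)
print(padded.shape)  # torch.Size([1, 64, 64, 64])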
@@ -364,19 +398,32 @@ def load_layer(self):

        # print(volume.shape)
        # print(volume.dtype)
-
-        load_transforms = Compose(
-            [
-                ToTensor(),
-                # anisotropic_transform,
-                AddChannel(),
-                SpatialPad(spatial_size=pad),
-                AddChannel(),
-                EnsureType(),
-            ],
-            map_items=False,
-            log_stats=True,
-        )
+        if self.use_window:
+            load_transforms = Compose(
+                [
+                    ToTensor(),
+                    # anisotropic_transform,
+                    AddChannel(),
+                    # SpatialPad(spatial_size=pad),
+                    AddChannel(),
+                    EnsureType(),
+                ],
+                map_items=False,
+                log_stats=True,
+            )
+        else:
+            load_transforms = Compose(
+                [
+                    ToTensor(),
+                    # anisotropic_transform,
+                    AddChannel(),
+                    SpatialPad(spatial_size=pad),
+                    AddChannel(),
+                    EnsureType(),
+                ],
+                map_items=False,
+                log_stats=True,
+            )

        self.log("\nLoading dataset...")
        input_image = load_transforms(volume)
@@ -405,8 +452,10 @@ def model_output(

        if self.use_window:
            window_size = self.window_infer_size
+            window_overlap = self.window_overlap_percentage
        else:
            window_size = None
+            window_overlap = 0.25

        outputs = sliding_window_inference(
            inputs,
@@ -415,6 +464,7 @@
            predictor=model_output,
            sw_device=self.device,
            device=dataset_device,
+            overlap=window_overlap,
        )

        out = outputs.detach().cpu()
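For reference, a minimal sketch (not taken from the plugin; the dummy volume, roi_size and predictor are illustrative) of how the overlap fraction added above is consumed by MONAI's sliding_window_inference:

import torch
from monai.inferers import sliding_window_inference

# Dummy 3D volume: (batch, channel, depth, height, width)
volume = torch.rand(1, 1, 64, 64, 64)

def predictor(patch):
    # Stand-in for model(patch); any callable returning a tensor of matching spatial size works
    return torch.sigmoid(patch)

window_overlap = 0.25  # fraction of overlap between adjacent windows (window_overlap_percentage above)
outputs = sliding_window_inference(
    inputs=volume,
    roi_size=(32, 32, 32),   # analogous to window_infer_size
    sw_batch_size=1,
    predictor=predictor,
    overlap=window_overlap,
)
print(outputs.shape)  # torch.Size([1, 1, 64, 64, 64])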
@@ -508,13 +558,12 @@ def save_image(
        )

        imwrite(file_path, image)
+        filename = os.path.split(file_path)[1]

        if from_layer:
-            self.log(f"\nLayer prediction saved as :")
+            self.log(f"\nLayer prediction saved as : {filename}")
        else:
-            self.log(f"\nFile n°{i+1} saved as :")
-            filename = os.path.split(file_path)[1]
-            self.log(filename)
+            self.log(f"\nFile n°{i+1} saved as : {filename}")

    def aniso_transform(self, image):
        zoom = self.transforms["zoom"][1]
@@ -630,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):

        self.save_image(out, from_layer=True)

-        instance_labels, data_dict = self.get_instance_result(out, from_layer=True)
+        instance_labels, data_dict = self.get_instance_result(
+            out, from_layer=True
+        )

-        return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
+        return self.create_result_dict(
+            out, instance_labels, from_layer=True, data_dict=data_dict
+        )

    def inference(self):
        """
@@ -674,29 +727,27 @@ def inference(self):
            torch.set_num_threads(1)  # required for threading on macOS ?
            self.log("Number of threads has been set to 1 for macOS")

-        # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
-        #     torch.backends.cudnn.benchmark = False
-
-        # TODO : better solution than loading first image always ?
-        # data_check = LoadImaged(keys=["image"])(images_dict[0])
-        # print(data)
-        # check = data_check["image"].shape
-        # print(check)
-
        try:
-            dims = self.model_dict["segres_size"]
+            dims = self.model_dict["model_input_size"]
+            self.log(f"MODEL DIMS : {dims}")
+            self.log(self.model_dict["name"])

-            model = self.model_dict["class"].get_net()
            if self.model_dict["name"] == "SegResNet":
-                model = self.model_dict["class"].get_net()(
+                model = self.model_dict["class"].get_net(
                    input_image_size=[
                        dims,
                        dims,
                        dims,
                    ],  # TODO FIX ! find a better way & remove model-specific code
-                    out_channels=1,
-                    # dropout_prob=0.3,
                )
+            elif self.model_dict["name"] == "SwinUNetR":
+                model = self.model_dict["class"].get_net(
+                    img_size=[dims, dims, dims],
+                    use_checkpoint=False,
+                )
+            else:
+                model = self.model_dict["class"].get_net()
+            model = model.to(self.device)

            self.log_parameters()

@@ -722,10 +773,7 @@ def inference(self):
                AsDiscrete(threshold=t), EnsureType()
            )

-
-            self.log(
-                "\nLoading weights..."
-            )
+            self.log("\nLoading weights...")

            if self.weights_dict["custom"]:
                weights = self.weights_dict["path"]
@@ -1022,11 +1070,21 @@ def train(self):
                else:
                    size = check
                print(f"Size of image : {size}")
-                model = model_class.get_net()(
+                model = model_class.get_net(
                    input_image_size=utils.get_padding_dim(size),
                    out_channels=1,
                    dropout_prob=0.3,
                )
+            elif model_name == "SwinUNetR":
+                if self.sampling:
+                    size = self.sample_size
+                else:
+                    size = check
+                print(f"Size of image : {size}")
+                model = model_class.get_net(
+                    img_size=utils.get_padding_dim(size),
+                    use_checkpoint=True,
+                )
            else:
                model = model_class.get_net()  # get an instance of the model
            model = model.to(self.device)