3
3
from pathlib import Path
4
4
import importlib .util
5
5
from typing import Optional
6
+ import warnings
6
7
7
8
import numpy as np
8
9
from tifffile import imwrite
65
66
Path ("/models/pretrained" )
66
67
)
67
68
69
+
68
70
class WeightsDownloader :
71
+ """A utility class the downloads the weights of a model when needed."""
72
+
73
+ def __init__ (self , log_widget : Optional [log_utility .Log ] = None ):
74
+ """
75
+ Creates a WeightsDownloader, optionally with a log widget to display the progress.
69
76
70
- def __init__ (self , log_widget : Optional [log_utility .Log ]= None ):
77
+ Args:
78
+ log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
79
+ """
71
80
self .log_widget = log_widget
72
81
73
- def download_weights (self ,model_name : str ):
82
+ def download_weights (self , model_name : str , model_weights_filename : str ):
74
83
"""
75
- Downloads a specific pretrained model.
76
- This code is adapted from DeepLabCut with permission from MWMathis.
84
+ Downloads a specific pretrained model.
85
+ This code is adapted from DeepLabCut with permission from MWMathis.
77
86
78
- Args:
79
- model_name (str): name of the model to download
80
- """
87
+ Args:
88
+ model_name (str): name of the model to download
89
+ model_weights_filename (str): name of the .pth file expected for the model
90
+ """
81
91
import json
82
92
import tarfile
83
93
import urllib .request
@@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
94
104
json_path = os .path .join (
95
105
pretrained_folder_path , "pretrained_model_urls.json"
96
106
)
107
+
108
+ check_path = os .path .join (
109
+ pretrained_folder_path , model_weights_filename
110
+ )
111
+ if os .path .exists (check_path ):
112
+ message = f"Weight file { model_weights_filename } already exists, skipping download step"
113
+ if self .log_widget is not None :
114
+ self .log_widget .print_and_log (message , printing = False )
115
+ print (message )
116
+ return
117
+
97
118
with open (json_path ) as f :
98
119
neturls = json .load (f )
99
120
if model_name in neturls .keys ():
@@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
107
128
pbar = tqdm (unit = "B" , total = total_size , position = 0 )
108
129
else :
109
130
self .log_widget .print_and_log (start_message )
110
- pbar = tqdm (unit = "B" , total = total_size , position = 0 , file = self .log_widget )
131
+ pbar = tqdm (
132
+ unit = "B" ,
133
+ total = total_size ,
134
+ position = 0 ,
135
+ file = self .log_widget ,
136
+ )
111
137
112
- filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
138
+ filename , _ = urllib .request .urlretrieve (
139
+ url , reporthook = show_progress
140
+ )
113
141
with tarfile .open (filename , mode = "r:gz" ) as tar :
114
142
tar .extractall (pretrained_folder_path )
115
143
else :
@@ -121,10 +149,12 @@ def show_progress(count, block_size, total_size):
121
149
class LogSignal (WorkerBaseSignals ):
122
150
"""Signal to send messages to be logged from another thread.
123
151
124
- Separate from Worker instances as indicated `here`_"""
152
+ Separate from Worker instances as indicated `here`_""" # TODO link ?
125
153
126
154
log_signal = Signal (str )
127
155
"""qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
156
+ warn_signal = Signal (str )
157
+ """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread"""
128
158
129
159
# Should not be an instance variable but a class variable, not defined in __init__, see
130
160
# https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
@@ -185,6 +215,7 @@ def __init__(
185
215
super ().__init__ (self .inference )
186
216
self ._signals = LogSignal () # add custom signals
187
217
self .log_signal = self ._signals .log_signal
218
+ self .warn_signal = self ._signals .warn_signal
188
219
###########################################
189
220
###########################################
190
221
self .device = device
@@ -204,7 +235,6 @@ def __init__(
204
235
self .downloader = WeightsDownloader ()
205
236
"""Download utility"""
206
237
207
-
208
238
@staticmethod
209
239
def create_inference_dict (images_filepaths ):
210
240
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -225,6 +255,10 @@ def log(self, text):
225
255
"""
226
256
self .log_signal .emit (text )
227
257
258
+ def warn (self , warning ):
259
+ """Sends a warning to main thread"""
260
+ self .warn_signal .emit (warning )
261
+
228
262
def log_parameters (self ):
229
263
230
264
self .log ("-" * 20 )
@@ -297,7 +331,7 @@ def inference(self):
297
331
sys = platform .system ()
298
332
print (f"OS is { sys } " )
299
333
if sys == "Darwin" :
300
- torch .set_num_threads (1 ) # required for threading on macOS ?
334
+ torch .set_num_threads (1 ) # required for threading on macOS ?
301
335
self .log ("Number of threads has been set to 1 for macOS" )
302
336
303
337
images_dict = self .create_inference_dict (self .images_filepaths )
@@ -323,7 +357,11 @@ def inference(self):
323
357
model = self .model_dict ["class" ].get_net ()
324
358
if self .model_dict ["name" ] == "SegResNet" :
325
359
model = self .model_dict ["class" ].get_net ()(
326
- input_image_size = [dims , dims , dims ], # TODO FIX ! find a better way & remove model-specific code
360
+ input_image_size = [
361
+ dims ,
362
+ dims ,
363
+ dims ,
364
+ ], # TODO FIX ! find a better way & remove model-specific code
327
365
out_channels = 1 ,
328
366
# dropout_prob=0.3,
329
367
)
@@ -372,8 +410,13 @@ def inference(self):
372
410
if self .weights_dict ["custom" ]:
373
411
weights = self .weights_dict ["path" ]
374
412
else :
375
- self .downloader .download_weights (self .model_dict ["name" ])
376
- weights = os .path .join (WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ())
413
+ self .downloader .download_weights (
414
+ self .model_dict ["name" ],
415
+ self .model_dict ["class" ].get_weights_file (),
416
+ )
417
+ weights = os .path .join (
418
+ WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ()
419
+ )
377
420
378
421
model .load_state_dict (
379
422
torch .load (
@@ -611,7 +654,10 @@ def __init__(
611
654
super ().__init__ (self .train )
612
655
self ._signals = LogSignal ()
613
656
self .log_signal = self ._signals .log_signal
657
+ self .warn_signal = self ._signals .warn_signal
614
658
659
+ self ._weight_error = False
660
+ #############################################
615
661
self .device = device
616
662
self .model_dict = model_dict
617
663
self .weights_path = weights_path
@@ -633,7 +679,7 @@ def __init__(
633
679
634
680
self .train_files = []
635
681
self .val_files = []
636
-
682
+ #######################################
637
683
self .downloader = WeightsDownloader ()
638
684
639
685
def set_download_log (self , widget ):
@@ -647,6 +693,10 @@ def log(self, text):
647
693
"""
648
694
self .log_signal .emit (text )
649
695
696
+ def warn (self , warning ):
697
+ """Sends a warning to main thread"""
698
+ self .warn_signal .emit (warning )
699
+
650
700
def log_parameters (self ):
651
701
652
702
self .log ("-" * 20 )
@@ -690,6 +740,13 @@ def log_parameters(self):
690
740
691
741
if self .weights_path is not None :
692
742
self .log (f"Using weights from : { self .weights_path } " )
743
+ if self ._weight_error :
744
+ self .log (
745
+ ">>>>>>>>>>>>>>>>>\n "
746
+ "WARNING:\n Chosen weights were incompatible with the model,\n "
747
+ "the model will be trained from random weights\n "
748
+ "<<<<<<<<<<<<<<<<<\n "
749
+ )
693
750
694
751
# self.log("\n")
695
752
self .log ("-" * 20 )
@@ -904,18 +961,27 @@ def train(self):
904
961
if self .weights_path is not None :
905
962
if self .weights_path == "use_pretrained" :
906
963
weights_file = model_class .get_weights_file ()
907
- self .downloader .download_weights (model_name )
964
+ self .downloader .download_weights (model_name , weights_file )
908
965
weights = os .path .join (WEIGHTS_DIR , weights_file )
909
966
self .weights_path = weights
910
967
else :
911
968
weights = os .path .join (self .weights_path )
912
969
913
- model .load_state_dict (
914
- torch .load (
915
- weights ,
916
- map_location = self .device ,
970
+ try :
971
+ model .load_state_dict (
972
+ torch .load (
973
+ weights ,
974
+ map_location = self .device ,
975
+ )
917
976
)
918
- )
977
+ except RuntimeError :
978
+ warn = (
979
+ "WARNING:\n It seems the weights were incompatible with the model,\n "
980
+ "the model will be trained from random weights"
981
+ )
982
+ self .log (warn )
983
+ self .warn (warn )
984
+ self ._weight_error = True
919
985
920
986
if self .device .type == "cuda" :
921
987
self .log ("\n Using GPU :" )
0 commit comments