Skip to content

Commit ec8ab8b

Browse files
authored
Merge pull request #24 from AdaptiveMotorControlLab/cy/download-improvements
Improved pretrained weights usage
2 parents 1709528 + 487560d commit ec8ab8b

File tree

8 files changed

+107
-34
lines changed

8 files changed

+107
-34
lines changed

napari_cellseg3d/log_utility.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import threading
2+
import warnings
23

34
from qtpy import QtCore
45
from qtpy.QtGui import QTextCursor
@@ -28,13 +29,15 @@ def write(self, message):
2829
try:
2930
if not hasattr(self, "flag"):
3031
self.flag = False
31-
message = message.replace('\r', '').rstrip()
32+
message = message.replace("\r", "").rstrip()
3233
if message:
3334
method = "replace_last_line" if self.flag else "append"
34-
QtCore.QMetaObject.invokeMethod(self,
35-
method,
36-
QtCore.Qt.QueuedConnection,
37-
QtCore.Q_ARG(str, message))
35+
QtCore.QMetaObject.invokeMethod(
36+
self,
37+
method,
38+
QtCore.Qt.QueuedConnection,
39+
QtCore.Q_ARG(str, message),
40+
)
3841
self.flag = True
3942
else:
4043
self.flag = False
@@ -77,3 +80,10 @@ def print_and_log(self, text, printing=True):
7780
)
7881
finally:
7982
self.lock.release()
83+
84+
def warn(self, warning):
85+
self.lock.acquire()
86+
try:
87+
warnings.warn(warning)
88+
finally:
89+
self.lock.release()

napari_cellseg3d/model_workers.py

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import importlib.util
55
from typing import Optional
6+
import warnings
67

78
import numpy as np
89
from tifffile import imwrite
@@ -65,19 +66,28 @@
6566
Path("/models/pretrained")
6667
)
6768

69+
6870
class WeightsDownloader:
71+
"""A utility class the downloads the weights of a model when needed."""
72+
73+
def __init__(self, log_widget: Optional[log_utility.Log] = None):
74+
"""
75+
Creates a WeightsDownloader, optionally with a log widget to display the progress.
6976
70-
def __init__(self, log_widget: Optional[log_utility.Log]= None):
77+
Args:
78+
log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
79+
"""
7180
self.log_widget = log_widget
7281

73-
def download_weights(self,model_name: str):
82+
def download_weights(self, model_name: str, model_weights_filename: str):
7483
"""
75-
Downloads a specific pretrained model.
76-
This code is adapted from DeepLabCut with permission from MWMathis.
84+
Downloads a specific pretrained model.
85+
This code is adapted from DeepLabCut with permission from MWMathis.
7786
78-
Args:
79-
model_name (str): name of the model to download
80-
"""
87+
Args:
88+
model_name (str): name of the model to download
89+
model_weights_filename (str): name of the .pth file expected for the model
90+
"""
8191
import json
8292
import tarfile
8393
import urllib.request
@@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
94104
json_path = os.path.join(
95105
pretrained_folder_path, "pretrained_model_urls.json"
96106
)
107+
108+
check_path = os.path.join(
109+
pretrained_folder_path, model_weights_filename
110+
)
111+
if os.path.exists(check_path):
112+
message = f"Weight file {model_weights_filename} already exists, skipping download step"
113+
if self.log_widget is not None:
114+
self.log_widget.print_and_log(message, printing=False)
115+
print(message)
116+
return
117+
97118
with open(json_path) as f:
98119
neturls = json.load(f)
99120
if model_name in neturls.keys():
@@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
107128
pbar = tqdm(unit="B", total=total_size, position=0)
108129
else:
109130
self.log_widget.print_and_log(start_message)
110-
pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget)
131+
pbar = tqdm(
132+
unit="B",
133+
total=total_size,
134+
position=0,
135+
file=self.log_widget,
136+
)
111137

112-
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
138+
filename, _ = urllib.request.urlretrieve(
139+
url, reporthook=show_progress
140+
)
113141
with tarfile.open(filename, mode="r:gz") as tar:
114142
tar.extractall(pretrained_folder_path)
115143
else:
@@ -121,10 +149,12 @@ def show_progress(count, block_size, total_size):
121149
class LogSignal(WorkerBaseSignals):
122150
"""Signal to send messages to be logged from another thread.
123151
124-
Separate from Worker instances as indicated `here`_"""
152+
Separate from Worker instances as indicated `here`_""" # TODO link ?
125153

126154
log_signal = Signal(str)
127155
"""qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
156+
warn_signal = Signal(str)
157+
"""qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread"""
128158

129159
# Should not be an instance variable but a class variable, not defined in __init__, see
130160
# https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
@@ -185,6 +215,7 @@ def __init__(
185215
super().__init__(self.inference)
186216
self._signals = LogSignal() # add custom signals
187217
self.log_signal = self._signals.log_signal
218+
self.warn_signal = self._signals.warn_signal
188219
###########################################
189220
###########################################
190221
self.device = device
@@ -204,7 +235,6 @@ def __init__(
204235
self.downloader = WeightsDownloader()
205236
"""Download utility"""
206237

207-
208238
@staticmethod
209239
def create_inference_dict(images_filepaths):
210240
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -225,6 +255,10 @@ def log(self, text):
225255
"""
226256
self.log_signal.emit(text)
227257

258+
def warn(self, warning):
259+
"""Sends a warning to main thread"""
260+
self.warn_signal.emit(warning)
261+
228262
def log_parameters(self):
229263

230264
self.log("-" * 20)
@@ -297,7 +331,7 @@ def inference(self):
297331
sys = platform.system()
298332
print(f"OS is {sys}")
299333
if sys == "Darwin":
300-
torch.set_num_threads(1) # required for threading on macOS ?
334+
torch.set_num_threads(1) # required for threading on macOS ?
301335
self.log("Number of threads has been set to 1 for macOS")
302336

303337
images_dict = self.create_inference_dict(self.images_filepaths)
@@ -323,7 +357,11 @@ def inference(self):
323357
model = self.model_dict["class"].get_net()
324358
if self.model_dict["name"] == "SegResNet":
325359
model = self.model_dict["class"].get_net()(
326-
input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code
360+
input_image_size=[
361+
dims,
362+
dims,
363+
dims,
364+
], # TODO FIX ! find a better way & remove model-specific code
327365
out_channels=1,
328366
# dropout_prob=0.3,
329367
)
@@ -372,8 +410,13 @@ def inference(self):
372410
if self.weights_dict["custom"]:
373411
weights = self.weights_dict["path"]
374412
else:
375-
self.downloader.download_weights(self.model_dict["name"])
376-
weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file())
413+
self.downloader.download_weights(
414+
self.model_dict["name"],
415+
self.model_dict["class"].get_weights_file(),
416+
)
417+
weights = os.path.join(
418+
WEIGHTS_DIR, self.model_dict["class"].get_weights_file()
419+
)
377420

378421
model.load_state_dict(
379422
torch.load(
@@ -611,7 +654,10 @@ def __init__(
611654
super().__init__(self.train)
612655
self._signals = LogSignal()
613656
self.log_signal = self._signals.log_signal
657+
self.warn_signal = self._signals.warn_signal
614658

659+
self._weight_error = False
660+
#############################################
615661
self.device = device
616662
self.model_dict = model_dict
617663
self.weights_path = weights_path
@@ -633,7 +679,7 @@ def __init__(
633679

634680
self.train_files = []
635681
self.val_files = []
636-
682+
#######################################
637683
self.downloader = WeightsDownloader()
638684

639685
def set_download_log(self, widget):
@@ -647,6 +693,10 @@ def log(self, text):
647693
"""
648694
self.log_signal.emit(text)
649695

696+
def warn(self, warning):
697+
"""Sends a warning to main thread"""
698+
self.warn_signal.emit(warning)
699+
650700
def log_parameters(self):
651701

652702
self.log("-" * 20)
@@ -690,6 +740,13 @@ def log_parameters(self):
690740

691741
if self.weights_path is not None:
692742
self.log(f"Using weights from : {self.weights_path}")
743+
if self._weight_error:
744+
self.log(
745+
">>>>>>>>>>>>>>>>>\n"
746+
"WARNING:\nChosen weights were incompatible with the model,\n"
747+
"the model will be trained from random weights\n"
748+
"<<<<<<<<<<<<<<<<<\n"
749+
)
693750

694751
# self.log("\n")
695752
self.log("-" * 20)
@@ -904,18 +961,27 @@ def train(self):
904961
if self.weights_path is not None:
905962
if self.weights_path == "use_pretrained":
906963
weights_file = model_class.get_weights_file()
907-
self.downloader.download_weights(model_name)
964+
self.downloader.download_weights(model_name, weights_file)
908965
weights = os.path.join(WEIGHTS_DIR, weights_file)
909966
self.weights_path = weights
910967
else:
911968
weights = os.path.join(self.weights_path)
912969

913-
model.load_state_dict(
914-
torch.load(
915-
weights,
916-
map_location=self.device,
970+
try:
971+
model.load_state_dict(
972+
torch.load(
973+
weights,
974+
map_location=self.device,
975+
)
917976
)
918-
)
977+
except RuntimeError:
978+
warn = (
979+
"WARNING:\nIt seems the weights were incompatible with the model,\n"
980+
"the model will be trained from random weights"
981+
)
982+
self.log(warn)
983+
self.warn(warn)
984+
self._weight_error = True
919985

920986
if self.device.type == "cuda":
921987
self.log("\nUsing GPU :")

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from monai.networks.nets import SegResNetVAE
22

33

4-
54
def get_net():
65
return SegResNetVAE
76

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from torch import nn
33

44

5-
65
def get_weights_file():
76
# model additionally trained on Mathis/Wyss mesoSPIM data
87
return "TRAILMAP.pth"
98
# FIXME currently incorrect, find good weights from TRAILMAP_test and upload them
109

10+
1111
def get_net():
1212
return TRAILMAP(1, 1)
1313

@@ -120,4 +120,3 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"):
120120
# nn.BatchNorm3d(out_ch),
121121
)
122122
return out
123-

napari_cellseg3d/models/model_VNet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from monai.networks.nets import VNet
33

44

5-
65
def get_net():
76
return VNet()
87

napari_cellseg3d/plugin_model_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ def start(self):
599599

600600
self.worker.started.connect(self.on_start)
601601
self.worker.log_signal.connect(self.log.print_and_log)
602+
self.worker.warn_signal.connect(self.log.warn)
602603
self.worker.yielded.connect(yield_connect_show_res)
603604
self.worker.errored.connect(
604605
yield_connect_show_res

napari_cellseg3d/plugin_model_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ def start(self):
857857
[btn.setVisible(False) for btn in self.close_buttons]
858858

859859
self.worker.log_signal.connect(self.log.print_and_log)
860+
self.worker.warn_signal.connect(self.log.warn)
860861

861862
self.worker.started.connect(self.on_start)
862863

@@ -994,7 +995,7 @@ def make_csv(self):
994995
if len(self.loss_values) == 0 or self.loss_values is None:
995996
warnings.warn("No loss values to add to csv !")
996997
return
997-
998+
998999
self.df = pd.DataFrame(
9991000
{
10001001
"epoch": size_column,

napari_cellseg3d/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -978,5 +978,3 @@ def merge_imgs(imgs, original_image_shape):
978978

979979
print(merged_imgs.shape)
980980
return merged_imgs
981-
982-

0 commit comments

Comments
 (0)