-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain.py
More file actions
124 lines (104 loc) · 3.39 KB
/
train.py
File metadata and controls
124 lines (104 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import importlib
import tensorflow as tf
from keras import Model, ops
from keras import optimizers
from utils.losses import weighted_dice_loss
from keras.metrics import OneHotIoU
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from utils.callbacks import cosine_annealing_with_warmup
import backbones
from configs import config, update_config
def parse_args():
    """Parse command-line options and merge them into the global config.

    Returns:
        argparse.Namespace: parsed arguments (currently only ``--cfg``, the
        path to the experiment YAML configuration file).
    """
    parser = argparse.ArgumentParser(description="Train FOMO network")
    parser.add_argument(
        "--cfg",
        type=str,
        default="configs/mff/mff_mobilenetv2.yaml",
        help="experiment configure file name",
    )
    parsed = parser.parse_args()
    # Fold the chosen YAML file into the module-level config object.
    update_config(config, parsed)
    return parsed
def get_model_by_name(model_name):
    """Build the backbone model selected by ``model_name``.

    Args:
        model_name: lower-cased backbone key; one of ``"mobilenetv2"``,
            ``"squeezenet"``, ``"mobilenetv3"``, ``"mobilevit"``.

    Returns:
        The constructed (uncompiled) backbone model.

    Raises:
        NotImplementedError: if ``model_name`` is not one of the known keys.
    """
    # Resume from the best checkpoint when requested; otherwise start from
    # ImageNet-pretrained weights (only used by the backbones that accept it).
    weights = config.TRAIN.BEST_SAVE_PATH if config.TRAIN.RESUME else "imagenet"
    if model_name == "mobilenetv2":
        model = backbones.MobileFOMOv2(
            config.TRAIN.IMAGE_SIZE, config.MODEL.ALPHA, config.DATASET.NUM_CLASSES, weights
        )
    elif model_name == "squeezenet":
        model = backbones.SqueezeFOMO(
            config.TRAIN.IMAGE_SIZE, config.DATASET.NUM_CLASSES
        )
    elif model_name == "mobilenetv3":
        model = backbones.MobileFOMOv3(
            config.TRAIN.IMAGE_SIZE, config.MODEL.ALPHA, config.DATASET.NUM_CLASSES, weights
        )
    elif model_name == "mobilevit":
        model = backbones.MobileFOMOViT(
            config.TRAIN.IMAGE_SIZE, config.DATASET.NUM_CLASSES
        )
    else:
        # Carry the diagnostic in the exception itself rather than printing
        # to stdout and raising a bare, message-less error.
        raise NotImplementedError(
            f"Invalid model name or model not implemented yet: {model_name!r}"
        )
    return model
def main() -> Model:
    """Train the configured FOMO model and return the trained model.

    Reads the experiment config (via ``--cfg``), builds the dataset class
    named by ``config.DATASET.NAME``, compiles the selected backbone with a
    weighted dice loss and an IoU metric over the non-background classes,
    then runs ``model.fit`` with checkpointing and a cosine-annealing LR
    schedule.

    Returns:
        The trained Keras ``Model`` (best weights are also saved to
        ``config.TRAIN.BEST_SAVE_PATH`` by the checkpoint callback).
    """
    args = parse_args()

    # Dataset class is resolved dynamically from the ``dataloaders`` package
    # by the name given in the config.
    dataloader_module = importlib.import_module("dataloaders")
    DatasetClass = getattr(dataloader_module, config.DATASET.NAME)
    train_ds = DatasetClass(
        config,
        config.DATASET.TRAIN_SET,
        augment=True,
        shuffle=True,
        workers=4,
        use_multiprocessing=True,
    )
    val_ds = DatasetClass(
        config,
        config.DATASET.VALIDATION_SET,
        augment=False,
        workers=4,
        use_multiprocessing=True,
    )
    # Some loader classes wrap a tf.data pipeline; unwrap it when available.
    if hasattr(train_ds, "get_dataset"):
        train_ds = train_ds.get_dataset()
        val_ds = val_ds.get_dataset()

    model: Model = get_model_by_name(config.MODEL.BACKBONE.lower())
    loss_fn = weighted_dice_loss(config.TRAIN.CLASS_WEIGHTS)
    optim: optimizers.Optimizer = optimizers.get(config.TRAIN.OPTIMIZER)
    optim.learning_rate = config.TRAIN.LR
    model.compile(
        loss=loss_fn,
        optimizer=optim,
        # IoU over classes 1..N-1, i.e. skipping the background class 0.
        metrics=[OneHotIoU(config.DATASET.NUM_CLASSES, range(1, config.DATASET.NUM_CLASSES), "iou")],
    )

    callbacks = [
        # Keep only the best model by validation IoU.
        ModelCheckpoint(
            config.TRAIN.BEST_SAVE_PATH,
            monitor="val_iou",
            mode="max",
            save_best_only=True,
            verbose=1,
        ),
        # Cosine annealing with a short linear warmup.
        LearningRateScheduler(
            lambda epoch, lr: cosine_annealing_with_warmup(
                epoch,
                lr,
                total_epochs=config.TRAIN.NUM_EPOCHS,
                warmup_epochs=5,
                min_lr=1e-6,
                max_lr=config.TRAIN.LR,
            )
        ),
    ]

    # NOTE(review): Keras rejects an explicit ``batch_size`` when the input is
    # a tf.data.Dataset (the ``get_dataset`` path above) — confirm the loaders
    # used here are Sequence-style, or drop ``batch_size`` for dataset inputs.
    model.fit(
        train_ds,
        batch_size=config.TRAIN.BATCH_SIZE,
        callbacks=callbacks,
        epochs=config.TRAIN.NUM_EPOCHS,
        verbose=1,
        validation_data=val_ds,
    )
    # Fix: the signature promises ``-> Model`` but the original returned None.
    return model
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()