@@ -66,8 +66,9 @@ def generator(labels, batch_size=1):
66
66
p .add_argument ('--epochs' , type = int , default = 10000 , help = 'training epochs' )
67
67
p .add_argument ('--batch-size' , type = int , default = 1 , help = 'batch size' )
68
68
p .add_argument ('--init-epoch' , type = int , default = 0 , help = 'initial epoch number' )
69
- p .add_argument ('--init-weights ' , help = 'weights file to initialize model with' )
69
+ p .add_argument ('--init-joint ' , help = 'weights file to initialize joint model with' )
70
70
p .add_argument ('--init-affine' , help = 'weights file to initialize affine submodel with' )
71
+ p .add_argument ('--init-deform' , help = 'weights file to initialize deformable submodel with' )
71
72
p .add_argument ('--save-freq' , type = int , default = 100 , help = 'epochs between model saves' )
72
73
p .add_argument ('--lr' , type = float , default = 1e-5 , help = 'learning rate' )
73
74
p .add_argument ('--loss-mult' , type = float , default = 10 , help = 'similarity-loss weight' )
@@ -202,12 +203,15 @@ def call(self, x):
202
203
203
204
204
205
# initialization
205
- if arg .init_weights :
206
- model .load_weights (arg .init_weights )
206
+ if arg .init_joint :
207
+ model .load_weights (arg .init_joint )
207
208
208
209
if arg .init_affine :
209
210
model_aff .load_weights (arg .init_affine )
210
211
212
+ if arg .init_deform :
213
+ model_def .load_weights (arg .init_deform )
214
+
211
215
212
216
# training
213
217
model .fit (
0 commit comments