Skip to content

Commit 923a37d

Browse files
committed
Initialize only the deformable submodel.
1 parent 2cd7061 commit 923a37d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

scripts/tf/train_synthmorph_joint.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ def generator(labels, batch_size=1):
6666
p.add_argument('--epochs', type=int, default=10000, help='training epochs')
6767
p.add_argument('--batch-size', type=int, default=1, help='batch size')
6868
p.add_argument('--init-epoch', type=int, default=0, help='initial epoch number')
69-
p.add_argument('--init-weights', help='weights file to initialize model with')
69+
p.add_argument('--init-joint', help='weights file to initialize joint model with')
7070
p.add_argument('--init-affine', help='weights file to initialize affine submodel with')
71+
p.add_argument('--init-deform', help='weights file to initialize deformable submodel with')
7172
p.add_argument('--save-freq', type=int, default=100, help='epochs between model saves')
7273
p.add_argument('--lr', type=float, default=1e-5, help='learning rate')
7374
p.add_argument('--loss-mult', type=float, default=10, help='similarity-loss weight')
@@ -202,12 +203,15 @@ def call(self, x):
202203

203204

204205
# initialization
205-
if arg.init_weights:
206-
model.load_weights(arg.init_weights)
206+
if arg.init_joint:
207+
model.load_weights(arg.init_joint)
207208

208209
if arg.init_affine:
209210
model_aff.load_weights(arg.init_affine)
210211

212+
if arg.init_deform:
213+
model_def.load_weights(arg.init_deform)
214+
211215

212216
# training
213217
model.fit(

0 commit comments

Comments
 (0)