This repository was archived by the owner on Jul 7, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change @@ -424,6 +424,35 @@ def imagetransformer2d_tiny():
424
424
return hparams
425
425
426
426
427
+ def update_hparams_for_tpu (hparams ):
428
+ hparams .use_pad_remover = False # where op not supported
429
+ hparams .optimizer = "TrueAdam"
430
+ hparams .batch_size = 4
431
+
432
+
433
+ @registry .register_hparams
434
+ def img2mg_transformer_base_tpu ():
435
+ """Hparams for training img2img_transformer on tpu."""
436
+ hparams = img2img_transformer_base ()
437
+ update_hparams_for_tpu (hparams )
438
+ hparams .batch_size = 4
439
+ hparams .num_heads = 4 # heads are expensive on tpu
440
+ hparams .num_decoder_layers = 8
441
+ hparams .num_encoder_layers = 4
442
+ hparams .shared_embedding_and_softmax_weights = False
443
+ return hparams
444
+
445
+
446
+ @registry .register_hparams
447
+ def img2mg_transformer_tiny_tpu ():
448
+ hparams = img2mg_transformer_base_tpu ()
449
+ hparams .num_hidden_layers = 2
450
+ hparams .hidden_size = 16
451
+ hparams .batch_size = 2
452
+ hparams .num_heads = 2
453
+ return hparams
454
+
455
+
427
456
@registry .register_hparams
428
457
def img2img_transformer2d_n3 ():
429
458
hparams = img2img_transformer2d_base ()
You can’t perform that action at this time.
0 commit comments