Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ba2f9c8

Browse files
Niki ParmarRyan Sepassi
authored andcommitted
Add tpu hparams for img2img
PiperOrigin-RevId: 186704921
1 parent 5704e55 commit ba2f9c8

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tensor2tensor/models/image_transformer_2d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,35 @@ def imagetransformer2d_tiny():
424424
return hparams
425425

426426

427+
def update_hparams_for_tpu(hparams):
428+
hparams.use_pad_remover = False # where op not supported
429+
hparams.optimizer = "TrueAdam"
430+
hparams.batch_size = 4
431+
432+
433+
@registry.register_hparams
434+
def img2mg_transformer_base_tpu():
435+
"""Hparams for training img2img_transformer on tpu."""
436+
hparams = img2img_transformer_base()
437+
update_hparams_for_tpu(hparams)
438+
hparams.batch_size = 4
439+
hparams.num_heads = 4 # heads are expensive on tpu
440+
hparams.num_decoder_layers = 8
441+
hparams.num_encoder_layers = 4
442+
hparams.shared_embedding_and_softmax_weights = False
443+
return hparams
444+
445+
446+
@registry.register_hparams
447+
def img2mg_transformer_tiny_tpu():
448+
hparams = img2mg_transformer_base_tpu()
449+
hparams.num_hidden_layers = 2
450+
hparams.hidden_size = 16
451+
hparams.batch_size = 2
452+
hparams.num_heads = 2
453+
return hparams
454+
455+
427456
@registry.register_hparams
428457
def img2img_transformer2d_n3():
429458
hparams = img2img_transformer2d_base()

0 commit comments

Comments
 (0)