Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 01369c9

Browse files
authored
Merge pull request #694 from SkyAndCloud/raw
add support for specifying checkpoint_path in t2t-decoder interactive mode
2 parents c8709d1 + 74e419b commit 01369c9

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

tensor2tensor/bin/t2t_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def create_decode_hparams():
8282

8383
def decode(estimator, hparams, decode_hp):
8484
if FLAGS.decode_interactive:
85-
decoding.decode_interactively(estimator, hparams, decode_hp)
85+
decoding.decode_interactively(estimator, hparams, decode_hp, checkpoint_path=FLAGS.checkpoint_path)
8686
elif FLAGS.decode_from_file:
8787
decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
8888
decode_hp, FLAGS.decode_to_file,

tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def create_run_config(hp):
178178
save_ckpt_secs = FLAGS.save_checkpoints_secs or None
179179
if save_ckpt_secs:
180180
save_ckpt_steps = None
181-
assert FLAGS.output_dir
181+
assert FLAGS.output_dir or FLAGS.checkpoint_path
182182
return trainer_lib.create_run_config(
183183
model_dir=os.path.expanduser(FLAGS.output_dir),
184184
master=FLAGS.master,

tensor2tensor/utils/decoding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def input_fn():
343343
return input_fn
344344

345345

346-
def decode_interactively(estimator, hparams, decode_hp):
346+
def decode_interactively(estimator, hparams, decode_hp, checkpoint_path=None):
347347
"""Interactive decoding."""
348348

349349
def input_fn():
@@ -353,7 +353,7 @@ def input_fn():
353353
example = _interactive_input_tensor_to_features_dict(example, hparams)
354354
return example
355355

356-
result_iter = estimator.predict(input_fn)
356+
result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
357357
for result in result_iter:
358358
problem_idx = result["problem_choice"]
359359
is_image = False # TODO(lukaszkaiser): find out from problem id / class.

0 commit comments

Comments
 (0)