This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit aee16b4

Author: Ryan Sepassi

bug fixes

1 parent cbdb75d

6 files changed: +42 −19 lines changed

README.md

Lines changed: 24 additions & 4 deletions
@@ -17,33 +17,53 @@ issues](https://github.com/tensorflow/tensor2tensor/issues).
 ```
 pip install tensor2tensor
 
-DATA_DIR=$HOME/data
+DATA_DIR=$HOME/t2t_data
+TMP_DIR=/tmp/t2t_datagen
 PROBLEM=wmt_ende_tokens_32k
 MODEL=transformer
 HPARAMS=transformer_base
-TRAIN_DIR=$HOME/train
+TRAIN_DIR=$HOME/t2t_train/$PROBLEM_$MODEL_$HPARAMS
+
+mkdir $DATA_DIR $TMP_DIR $TRAIN_DIR
 
 # Generate data
 t2t-datagen \
   --data_dir=$DATA_DIR \
+  --tmp_dir=$TMP_DIR \
   --problem=$PROBLEM
 
+mv $TMP_DIR/tokens.vocab.32768 $DATA_DIR
+
 # Train
 t2t-trainer \
   --data_dir=$DATA_DIR \
   --problems=$PROBLEM \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
-  --output_dir=$TRAIN_DIR \
+  --output_dir=$TRAIN_DIR
 
 # Decode
+
+DECODE_FILE=$DATA_DIR/decode_this.txt
+echo "Hello world" >> $DECODE_FILE
+echo "Goodbye world" >> $DECODE_FILE
+
+BEAM_SIZE=4
+ALPHA=0.6
+
 t2t-trainer \
   --data_dir=$DATA_DIR \
   --problems=$PROBLEM \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
   --output_dir=$TRAIN_DIR \
-  --decode_from_file=$DATA_DIR/decode_this.txt
+  --train_steps=0 \
+  --eval_steps=0 \
+  --beam_size=$BEAM_SIZE \
+  --alpha=$ALPHA \
+  --decode_from_file=$DECODE_FILE
+
+cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
 ```
 
 T2T modularizes training into several components, each of which can be seen in

setup.py

Lines changed: 3 additions & 4 deletions
@@ -4,7 +4,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.1.dev1',
+    version='1.0.2',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
@@ -18,12 +18,11 @@
         'six',
         'tensorflow-gpu>=1.2.0rc1',
     ],
-    classifiers = [
+    classifiers=[
         'Development Status :: 4 - Beta',
         'Intended Audience :: Developers',
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: Apache Software License',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
-    keywords='tensorflow',
-)
+    keywords='tensorflow',)

tensor2tensor/bin/t2t-datagen

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ flags = tf.flags
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("data_dir", "", "Data directory.")
-flags.DEFINE_string("tmp_dir",
-                    tempfile.gettempdir(), "Temporary storage directory.")
+flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
+                    "Temporary storage directory.")
 flags.DEFINE_string("problem", "",
                     "The name of the problem to generate data for.")
 flags.DEFINE_integer("num_shards", 1, "How many shards to use.")

tensor2tensor/data_generators/text_encoder.py

Lines changed: 4 additions & 1 deletion
@@ -229,7 +229,10 @@ def subtoken_to_subtoken_string(self, subtoken):
         self._all_subtoken_strings[subtoken]):
       return self._all_subtoken_strings[subtoken]
     else:
-      return 'ID%d_' % subtoken
+      if 0 <= subtoken < self._num_reserved_ids:
+        return '%s_' % RESERVED_TOKENS[subtoken]
+      else:
+        return 'ID%d_' % subtoken
 
   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.

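With this change, an out-of-vocabulary subtoken ID that falls in the reserved range renders as its reserved token's name rather than the generic ID form. A minimal standalone sketch of the fallback, assuming RESERVED_TOKENS matches the pad/EOS pair in text_encoder.py; the helper name and example values here are hypothetical:

```
# Sketch of the new fallback in subtoken_to_subtoken_string.
# Assumes RESERVED_TOKENS matches text_encoder.py's pad/EOS pair.
RESERVED_TOKENS = ["<pad>", "<EOS>"]
NUM_RESERVED_IDS = len(RESERVED_TOKENS)


def subtoken_id_to_string(subtoken, all_subtoken_strings):
  """Renders a subtoken ID as a printable subtoken string."""
  if (0 <= subtoken < len(all_subtoken_strings) and
      all_subtoken_strings[subtoken]):
    return all_subtoken_strings[subtoken]
  # New behavior: reserved IDs print their token name, not "ID%d_".
  if 0 <= subtoken < NUM_RESERVED_IDS:
    return "%s_" % RESERVED_TOKENS[subtoken]
  return "ID%d_" % subtoken


print(subtoken_id_to_string(1, ["", ""]))   # <EOS>_
print(subtoken_id_to_string(99, ["", ""]))  # ID99_
```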
tensor2tensor/data_generators/wmt.py

Lines changed: 8 additions & 6 deletions
@@ -203,17 +203,19 @@ def _compile_data(tmp_dir, datasets, filename):
   with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_file:
     i = 0
     while i <= len(lang1_lines):
-      lang1_file.writelines(
-          lang1_lines[i * write_chunk_size:(i + 1) * write_chunk_size])
+      for line in lang1_lines[i * write_chunk_size:(i + 1) * write_chunk_size]:
+        lang1_file.write(line)
       i += 1
-    lang1_file.writelines(lang1_lines[i * write_chunk_size:])
+    for line in lang1_lines[i * write_chunk_size:]:
+      lang1_file.write(line)
   with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_file:
     i = 0
     while i <= len(lang2_lines):
-      lang2_file.writelines(
-          lang2_lines[i * write_chunk_size:(i + 1) * write_chunk_size])
+      for line in lang2_lines[i * write_chunk_size:(i + 1) * write_chunk_size]:
+        lang2_file.write(line)
       i += 1
-    lang2_file.writelines(lang2_lines[i * write_chunk_size:])
+    for line in lang2_lines[i * write_chunk_size:]:
+      lang2_file.write(line)
   return filename
 
 
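The writes still proceed chunk by chunk, but each line is now handed to write() individually instead of passing a whole list slice to writelines(). Note that the loop guard, while i <= len(lang1_lines), compares a chunk index against the total line count, so later passes cover empty slices. A minimal sketch of the per-line chunked-write pattern with the loop expressed over chunk offsets instead, using a built-in file object in place of tf.gfile.GFile; the function name and chunk size are illustrative:

```
# Sketch of chunked per-line writing as in _compile_data, with the
# loop expressed over chunk start offsets; open() stands in for
# tf.gfile.GFile here.
def write_in_chunks(path, lines, chunk_size=10000):
  """Writes lines to path one at a time, one chunk-sized slice per pass."""
  with open(path, "w") as f:
    for start in range(0, len(lines), chunk_size):
      for line in lines[start:start + chunk_size]:
        f.write(line)


write_in_chunks("/tmp/example.lang1", ["hello\n", "world\n"], chunk_size=1)
```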
tensor2tensor/utils/trainer_utils.py

Lines changed: 1 addition & 2 deletions
@@ -656,8 +656,7 @@ def log_fn(inputs, outputs):
   base_filename = FLAGS.decode_from_file
   decode_filename = (
       base_filename + "." + FLAGS.model + "." + FLAGS.hparams_set + ".beam" +
-      str(FLAGS.beam_size) + ".a" + str(FLAGS.alpha) + ".alpha" +
-      str(FLAGS.alpha) + ".decodes")
+      str(FLAGS.beam_size) + ".alpha" + str(FLAGS.alpha) + ".decodes")
   tf.logging.info("Writing decodes into %s" % decode_filename)
   outfile = tf.gfile.Open(decode_filename, "w")
   for index in range(len(sorted_inputs)):
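This drops the duplicated ".a<alpha>" segment, so decode outputs land at <decode_file>.<model>.<hparams_set>.beam<beam_size>.alpha<alpha>.decodes, which is the filename the cat line added to the README reads. A small sketch of the naming scheme, with plain variables standing in for the FLAGS values and using the README's example settings:

```
# Sketch of the fixed decode-output naming; plain variables stand in
# for FLAGS values, using the README's example settings.
base_filename = "decode_this.txt"
model = "transformer"
hparams_set = "transformer_base"
beam_size = 4
alpha = 0.6

decode_filename = (
    base_filename + "." + model + "." + hparams_set + ".beam" +
    str(beam_size) + ".alpha" + str(alpha) + ".decodes")
print(decode_filename)
# decode_this.txt.transformer.transformer_base.beam4.alpha0.6.decodes
```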
