Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c022afd

Browse files
Lukasz Kaiser and Ryan Sepassi
authored and committed
Work on generators: improve EnCs, add large EnFr and OCR test; LSTM corrections.
PiperOrigin-RevId: 174403513
1 parent 9afc190 commit c022afd

File tree

6 files changed

+242
-221
lines changed

6 files changed

+242
-221
lines changed

tensor2tensor/data_generators/generator_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
from collections import defaultdict
2323
import gzip
24-
import io
2524
import os
2625
import random
26+
import stat
2727
import tarfile
2828

2929
# Dependency imports
@@ -258,8 +258,11 @@ def gunzip_file(gz_path, new_path):
258258
tf.logging.info("File %s already exists, skipping unpacking" % new_path)
259259
return
260260
tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
261+
# We may be unpacking into a newly created directory, add write mode.
262+
mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH
263+
os.chmod(os.path.dirname(new_path), mode)
261264
with gzip.open(gz_path, "rb") as gz_file:
262-
with io.open(new_path, "wb") as new_file:
265+
with tf.gfile.GFile(new_path, mode="wb") as new_file:
263266
for line in gz_file:
264267
new_file.write(line)
265268

tensor2tensor/data_generators/image.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import json
2525
import os
2626
import random
27+
import struct
2728
import tarfile
2829
import zipfile
2930

@@ -925,3 +926,58 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
925926
@property
926927
def targeted_vocab_size(self):
927928
return 2**15 # 32768
929+
930+
931+
@registry.register_problem
932+
class OcrTest(Image2TextProblem):
933+
"""OCR test problem."""
934+
935+
@property
936+
def is_small(self):
937+
return True
938+
939+
@property
940+
def is_character_level(self):
941+
return True
942+
943+
@property
944+
def target_space_id(self):
945+
return problem.SpaceID.EN_CHR
946+
947+
@property
948+
def train_shards(self):
949+
return 1
950+
951+
@property
952+
def dev_shards(self):
953+
return 1
954+
955+
def preprocess_example(self, example, mode, _):
956+
# Resize from usual size ~1350x60 to 90x4 in this test.
957+
img = example["inputs"]
958+
example["inputs"] = tf.to_int64(
959+
tf.image.resize_images(img, [90, 4], tf.image.ResizeMethod.AREA))
960+
return example
961+
962+
def generator(self, data_dir, tmp_dir, is_training):
963+
# In this test problem, we assume that the data is in tmp_dir/ocr/ in
964+
# files named 0.png, 0.txt, 1.png, 1.txt and so on until num_examples.
965+
num_examples = 2
966+
ocr_dir = os.path.join(tmp_dir, "ocr/")
967+
tf.logging.info("Looking for OCR data in %s." % ocr_dir)
968+
for i in xrange(num_examples):
969+
image_filepath = os.path.join(ocr_dir, "%d.png" % i)
970+
text_filepath = os.path.join(ocr_dir, "%d.txt" % i)
971+
with tf.gfile.Open(text_filepath, "rb") as f:
972+
label = f.read()
973+
with tf.gfile.Open(image_filepath, "rb") as f:
974+
encoded_image_data = f.read()
975+
# In PNG files width and height are stored in these bytes.
976+
width, height = struct.unpack(">ii", encoded_image_data[16:24])
977+
yield {
978+
"image/encoded": [encoded_image_data],
979+
"image/format": ["png"],
980+
"image/class/label": label.strip(),
981+
"image/height": [height],
982+
"image/width": [width]
983+
}

tensor2tensor/data_generators/translate_enfr.py

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,50 +34,54 @@
3434
# End-of-sentence marker.
3535
EOS = text_encoder.EOS_ID
3636

37-
_ENFR_TRAIN_DATASETS = [
37+
_ENFR_TRAIN_SMALL_DATA = [
3838
[
3939
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
4040
("baseline-1M-enfr/baseline-1M_train.en",
4141
"baseline-1M-enfr/baseline-1M_train.fr")
4242
],
43-
# [
44-
# "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
45-
# ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
46-
# ],
47-
# [
48-
# "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
49-
# ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
50-
# ],
51-
# [
52-
# "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
53-
# ("training/news-commentary-v9.fr-en.en",
54-
# "training/news-commentary-v9.fr-en.fr")
55-
# ],
56-
# [
57-
# "http://www.statmt.org/wmt10/training-giga-fren.tar",
58-
# ("giga-fren.release2.fixed.en.gz",
59-
# "giga-fren.release2.fixed.fr.gz")
60-
# ],
61-
# [
62-
# "http://www.statmt.org/wmt13/training-parallel-un.tgz",
63-
# ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
64-
# ],
6543
]
66-
_ENFR_TEST_DATASETS = [
44+
_ENFR_TEST_SMALL_DATA = [
6745
[
6846
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
6947
("baseline-1M-enfr/baseline-1M_valid.en",
7048
"baseline-1M-enfr/baseline-1M_valid.fr")
7149
],
72-
# [
73-
# "http://data.statmt.org/wmt17/translation-task/dev.tgz",
74-
# ("dev/newstest2013.en", "dev/newstest2013.fr")
75-
# ],
50+
]
51+
_ENFR_TRAIN_LARGE_DATA = [
52+
[
53+
"http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
54+
("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
55+
],
56+
[
57+
"http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
58+
("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
59+
],
60+
[
61+
"http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
62+
("training/news-commentary-v9.fr-en.en",
63+
"training/news-commentary-v9.fr-en.fr")
64+
],
65+
[
66+
"http://www.statmt.org/wmt10/training-giga-fren.tar",
67+
("giga-fren.release2.fixed.en.gz",
68+
"giga-fren.release2.fixed.fr.gz")
69+
],
70+
[
71+
"http://www.statmt.org/wmt13/training-parallel-un.tgz",
72+
("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
73+
],
74+
]
75+
_ENFR_TEST_LARGE_DATA = [
76+
[
77+
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
78+
("dev/newstest2013.en", "dev/newstest2013.fr")
79+
],
7680
]
7781

7882

7983
@registry.register_problem
80-
class TranslateEnfrWmt8k(translate.TranslateProblem):
84+
class TranslateEnfrWmtSmall8k(translate.TranslateProblem):
8185
"""Problem spec for WMT En-Fr translation."""
8286

8387
@property
@@ -88,11 +92,18 @@ def targeted_vocab_size(self):
8892
def vocab_name(self):
8993
return "vocab.enfr"
9094

95+
@property
96+
def use_small_dataset(self):
97+
return True
98+
9199
def generator(self, data_dir, tmp_dir, train):
92100
symbolizer_vocab = generator_utils.get_or_generate_vocab(
93101
data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
94-
_ENFR_TRAIN_DATASETS)
95-
datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
102+
_ENFR_TRAIN_SMALL_DATA)
103+
if self.use_small_dataset:
104+
datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
105+
else:
106+
datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
96107
tag = "train" if train else "dev"
97108
data_path = translate.compile_data(tmp_dir, datasets,
98109
"wmt_enfr_tok_%s" % tag)
@@ -109,15 +120,31 @@ def target_space_id(self):
109120

110121

111122
@registry.register_problem
112-
class TranslateEnfrWmt32k(TranslateEnfrWmt8k):
123+
class TranslateEnfrWmtSmall32k(TranslateEnfrWmtSmall8k):
113124

114125
@property
115126
def targeted_vocab_size(self):
116127
return 2**15 # 32768
117128

118129

119130
@registry.register_problem
120-
class TranslateEnfrWmtCharacters(translate.TranslateProblem):
131+
class TranslateEnfrWmt8k(TranslateEnfrWmtSmall8k):
132+
133+
@property
134+
def use_small_dataset(self):
135+
return False
136+
137+
138+
@registry.register_problem
139+
class TranslateEnfrWmt32k(TranslateEnfrWmtSmall32k):
140+
141+
@property
142+
def use_small_dataset(self):
143+
return False
144+
145+
146+
@registry.register_problem
147+
class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem):
121148
"""Problem spec for WMT En-Fr translation."""
122149

123150
@property
@@ -130,7 +157,10 @@ def vocab_name(self):
130157

131158
def generator(self, data_dir, tmp_dir, train):
132159
character_vocab = text_encoder.ByteTextEncoder()
133-
datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
160+
if self.use_small_dataset:
161+
datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
162+
else:
163+
datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
134164
tag = "train" if train else "dev"
135165
data_path = translate.compile_data(tmp_dir, datasets,
136166
"wmt_enfr_chr_%s" % tag)
@@ -144,3 +174,11 @@ def input_space_id(self):
144174
@property
145175
def target_space_id(self):
146176
return problem.SpaceID.FR_CHR
177+
178+
179+
@registry.register_problem
180+
class TranslateEnfrWmtCharacters(TranslateEnfrWmtSmallCharacters):
181+
182+
@property
183+
def use_small_dataset(self):
184+
return False

tensor2tensor/data_generators/translate_enzh.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,26 @@
3636
# End-of-sentence marker.
3737
EOS = text_encoder.EOS_ID
3838

39-
_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
39+
# End-of-sentence marker.
40+
EOS = text_encoder.EOS_ID
41+
42+
# This is far from being the real WMT17 task - only toyset here
43+
# you need to register to get UN data and CWT data. Also, by convention,
44+
# this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
45+
_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
4046
"training-parallel-nc-v12.tgz"),
41-
("training/news-commentary-v12.zh-en.zh",
42-
"training/news-commentary-v12.zh-en.en")]]
47+
("training/news-commentary-v12.zh-en.en",
48+
"training/news-commentary-v12.zh-en.zh")]]
4349

44-
_ZHEN_TEST_DATASETS = [[
50+
_ENZH_TEST_DATASETS = [[
4551
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
46-
("dev/newsdev2017-zhen-src.zh.sgm", "dev/newsdev2017-zhen-ref.en.sgm")
52+
("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
4753
]]
4854

4955

5056
@registry.register_problem
5157
class TranslateEnzhWmt8k(translate.TranslateProblem):
52-
"""Problem spec for WMT Zh-En translation."""
58+
"""Problem spec for WMT En-Zh translation."""
5359

5460
@property
5561
def targeted_vocab_size(self):
@@ -61,16 +67,16 @@ def num_shards(self):
6167

6268
@property
6369
def source_vocab_name(self):
64-
return "vocab.zhen-zh.%d" % self.targeted_vocab_size
70+
return "vocab.enzh-en.%d" % self.targeted_vocab_size
6571

6672
@property
6773
def target_vocab_name(self):
68-
return "vocab.zhen-en.%d" % self.targeted_vocab_size
74+
return "vocab.enzh-zh.%d" % self.targeted_vocab_size
6975

7076
def generator(self, data_dir, tmp_dir, train):
71-
datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
72-
source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS]
73-
target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS]
77+
datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
78+
source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
79+
target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
7480
source_vocab = generator_utils.get_or_generate_vocab(
7581
data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
7682
source_datasets)
@@ -79,21 +85,18 @@ def generator(self, data_dir, tmp_dir, train):
7985
target_datasets)
8086
tag = "train" if train else "dev"
8187
data_path = translate.compile_data(tmp_dir, datasets,
82-
"wmt_zhen_tok_%s" % tag)
83-
# We generate English->X data by convention, to train reverse translation
84-
# just add the "_rev" suffix to the problem name, e.g., like this.
85-
# --problems=translate_enzh_wmt8k_rev
86-
return translate.bi_vocabs_token_generator(data_path + ".lang2",
87-
data_path + ".lang1",
88+
"wmt_enzh_tok_%s" % tag)
89+
return translate.bi_vocabs_token_generator(data_path + ".lang1",
90+
data_path + ".lang2",
8891
source_vocab, target_vocab, EOS)
8992

9093
@property
9194
def input_space_id(self):
92-
return problem.SpaceID.ZH_TOK
95+
return problem.SpaceID.EN_TOK
9396

9497
@property
9598
def target_space_id(self):
96-
return problem.SpaceID.EN_TOK
99+
return problem.SpaceID.ZH_TOK
97100

98101
def feature_encoders(self, data_dir):
99102
source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)

0 commit comments

Comments
 (0)