Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9afc190

Browse files
author
Ryan Sepassi
committed
Use problem.dataset in the TPU input pipeline
PiperOrigin-RevId: 174397407
1 parent 8fa33f6 commit 9afc190

File tree

1 file changed: +17 additions, −35 deletions

tensor2tensor/tpu/tpu_trainer_lib.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -69,44 +69,11 @@ def input_fn(mode, params):
       },
   }

-  def decode_record(record):
-    """Serialized Example to dict of <feature name, Tensor>."""
-    data_fields, _ = problem.example_reading_spec()
-    decoded = tf.parse_single_example(record, features=data_fields)
-    decoded["inputs"] = decoded["inputs"].values
-    decoded["targets"] = decoded["targets"].values
-    return decoded
-
-  data_files = tf.contrib.slim.parallel_reader.get_data_files(
-      problem.filepattern(data_dir, mode))
-  dataset = tf.data.TFRecordDataset(data_files)
-  dataset = dataset.map(decode_record, num_parallel_calls=num_threads)
-
-  def _preprocess(example, problem, hparams, mode):
-    example = problem.preprocess_example(example, mode, hparams)
-    # We do not want int64s as they are not supported on TPUs.
-    example = data_reader.cast_int64_to_int32(example)
-    return example
-
-  dataset = dataset.map(
-      lambda ex: _preprocess(ex, problem, hparams, mode),
-      num_parallel_calls=num_threads)
-
   def _valid_size(example):
     return data_reader.example_valid_size(
         example, batching_scheme["min_length"], batching_scheme["max_length"])

-  dataset = dataset.filter(_valid_size)
-  # TODO(rsepassi): In eval mode, should not repeat
-  dataset = dataset.repeat(None)
-  dataset = data_reader.padded_batch(dataset, batch_size,
-                                     batching_scheme["padded_shapes"])
-
-  if not is_training:
-    dataset = dataset.map(
-        lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads)
-
-  def shape_def(example):
+  def define_shapes(example):
     """Set the right shapes for the features."""
     inputs = example["inputs"]
     targets = example["targets"]
@@ -130,7 +97,22 @@ def input_fn(mode, params):

     return example

-  dataset = dataset.map(shape_def, num_parallel_calls=num_threads)
+  dataset = problem.dataset(
+      mode=mode, data_dir=data_dir, num_threads=num_threads, hparams=hparams)
+  dataset = dataset.map(
+      data_reader.cast_int64_to_int32, num_threads=num_threads)
+  dataset = dataset.filter(_valid_size)
+  if is_training:
+    dataset = dataset.shuffle(100)
+  # TODO(rsepassi): In eval mode, should not repeat. Do so because TPU seems
+  # to crash if it runs out of data during eval.
+  dataset = dataset.repeat(None)
+  dataset = data_reader.padded_batch(dataset, batch_size,
+                                     batching_scheme["padded_shapes"])
+  if not is_training:
+    dataset = dataset.map(
+        lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads)
+  dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
   dataset = dataset.prefetch(1)
   features = dataset.make_one_shot_iterator().get_next()
0 commit comments

Comments (0)