This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e892dc3

author
Ryan Sepassi
committed
Update example_life.md
PiperOrigin-RevId: 169625024
1 parent 8ee8350 commit e892dc3

File tree

13 files changed

+263
-102
lines changed

docs/example_life.md

Lines changed: 179 additions & 16 deletions
@@ -9,26 +9,189 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
99
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
1010
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1111

12-
This document show how a training example passes through the T2T pipeline,
13-
and how all its parts are connected to work together.
12+
This doc explains how a training example flows through T2T, from data generation
13+
to training, evaluation, and decoding. It points out the various hooks available
14+
in the `Problem` and `T2TModel` classes and gives an overview of the T2T code
15+
(key functions, files, hyperparameters, etc.).
1416

15-
## The Life of an Example
17+
Some key files and their functions:
1618

17-
A training example passes the following stages in T2T:
18-
* raw input (text from command line or file)
19-
* encoded input after [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode` is usually a sparse tensor, e.g., a vector of `tf.int32`s
20-
* batched input after [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188) are grouped by their length and made into batches.
21-
* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`.
22-
* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
23-
* back to sparse output through [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`.
24-
* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen.
19+
* [`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py):
20+
Constructs and runs all the main components of the system (the `Problem`,
21+
the `HParams`, the `Estimator`, the `Experiment`, the `input_fn`s and
22+
`model_fn`).
23+
* [`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py):
24+
`basic_params1` serves as the base for all model hyperparameters. Registered
25+
model hparams functions always start with this default set of
26+
hyperparameters.
27+
* [`problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py):
28+
Every dataset in T2T subclasses `Problem`.
29+
* [`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py):
30+
Every model in T2T subclasses `T2TModel`.
2531

26-
We go into these phases step by step below.
32+
## Data Generation
2733

28-
## Feature Encoders
34+
The `t2t-datagen` binary is the entrypoint for data generation. It simply looks
35+
up the `Problem` specified by `--problem` and calls
36+
`Problem.generate_data(data_dir, tmp_dir)`.
2937

30-
TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions.
38+
All `Problem`s are expected to generate 2 sharded `TFRecords` files - 1 for
39+
training and 1 for evaluation - with `tensorflow.Example` protocol buffers. The
40+
expected names of the files are given by `Problem.{training, dev}_filepaths`.
41+
Typically, the features in the `Example` will be `"inputs"` and `"targets"`;
42+
however, some tasks have a different on-disk representation that is converted to
43+
`"inputs"` and `"targets"` online in the input pipeline (e.g. image features are
44+
typically stored with features `"image/encoded"` and `"image/format"` and the
45+
decoding happens in the input pipeline).
3146

32-
## Modalities
47+
For tasks that require a vocabulary, this is also the point at which the
48+
vocabulary is generated and all examples are encoded.
3349

34-
TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets.
50+
There are several utility functions in
51+
[`generator_utils`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator_utils.py)
52+
that are commonly used by `Problem`s to generate data. Several are highlighted
53+
below:
54+
55+
* `generate_dataset_and_shuffle`: given 2 generators, 1 for training and 1 for
56+
eval, yielding dictionaries of `<feature name, list< int or float or
57+
string >>`, will produce sharded and shuffled `TFRecords` files with
58+
`tensorflow.Example` protos.
59+
* `maybe_download`: downloads a file at a URL to the given directory and
60+
filename (see `maybe_download_from_drive` if the URL points to Google
61+
Drive).
62+
* `get_or_generate_vocab_inner`: given a target vocabulary size and a
63+
generator that yields lines or tokens from the dataset, will build a
64+
`SubwordTextEncoder` along with a backing vocabulary file that can be used
65+
to map input strings to lists of ids.
66+
[`SubwordTextEncoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py)
67+
uses word pieces and its encoding is fully invertible.
68+
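As a rough illustration of the generator contract described above, the sketch below yields one dictionary per example, mapping feature names to lists of ints. It is a toy stand-in, not T2T code: `toy_example_generator`, `toy_vocab`, the tab-separated line format, and the end-of-sequence id of 1 are all assumptions made for this example.

```python
def toy_example_generator(lines, vocab, eos_id=1):
    """Yield one feature dict per tab-separated (source, target) line.

    Hypothetical helper: each yielded dict maps a feature name to a list
    of ints, which is the shape of data the dataset writer expects.
    """
    for line in lines:
        source, target = line.rstrip("\n").split("\t")
        yield {
            "inputs": [vocab[tok] for tok in source.split()] + [eos_id],
            "targets": [vocab[tok] for tok in target.split()] + [eos_id],
        }


# Toy vocabulary; a real Problem would use an encoder built from the data.
toy_vocab = {"hello": 2, "world": 3, "bonjour": 4, "monde": 5}
examples = list(
    toy_example_generator(["hello world\tbonjour monde"], toy_vocab))
```

A pair of generators like this one (one for training, one for eval) is the kind of input that `generate_dataset_and_shuffle` is described as consuming.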
69+
## Data Input Pipeline
70+
71+
Once the data is produced on disk, training, evaluation, and inference (if
72+
decoding from the dataset) consume it by way of the T2T input pipeline. This section
73+
will give an overview of that pipeline with specific attention to the various
74+
hooks in the `Problem` class and the model's `HParams` object (typically
75+
registered in the model's file and specified by the `--hparams_set` flag).
76+
77+
The entire input pipeline is implemented with the new `tf.data.Dataset` API
78+
(previously `tf.contrib.data.Dataset`).
79+
80+
The key function in the codebase for the input pipeline is
81+
[`data_reader.input_pipeline`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/data_reader.py).
82+
The full input function is built in
83+
[`input_fn_builder.build_input_fn`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/input_fn_builder.py)
84+
(which calls `data_reader.input_pipeline`).
85+
86+
### Reading and decoding data
87+
88+
`Problem.dataset_filename` specifies the prefix of the files on disk (they will
89+
be suffixed with `-train` or `-dev` as well as their sharding).
90+
91+
The features read from the files and their decoding are specified by
92+
`Problem.example_reading_spec`, which returns 2 items:
93+
94+
1. Dict mapping from on-disk feature name to on-disk types (`VarLenFeature` or
95+
`FixedLenFeature`).
96+
2. Dict mapping output feature name to decoder. This return value is optional
97+
and is only needed for tasks whose features may require additional decoding
98+
(e.g. images). You can find the available decoders in
99+
`tf.contrib.slim.tfexample_decoder`.
100+
101+
At this point in the input pipeline, the example is a `dict<feature name,
102+
Tensor>`.
103+
104+
### Preprocessing
105+
106+
The read `Example` now runs through `Problem.preprocess_example`, which by
107+
default runs
108+
[`problem.preprocess_example_common`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py),
109+
which may truncate the inputs/targets or prepend to targets, governed by some
110+
hyperparameters.
111+
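The truncation step can be pictured with a small pure-Python sketch. This is not the library implementation: `toy_preprocess_example` and its `max_input_len` / `max_target_len` parameters are hypothetical stand-ins for the hyperparameters mentioned above, and the optional prepend-to-targets behavior is left out.

```python
def toy_preprocess_example(example, max_input_len=0, max_target_len=0):
    """Truncate "inputs" and "targets"; a limit of 0 disables truncation.

    Hypothetical sketch operating on plain lists instead of Tensors.
    """
    out = dict(example)  # shallow copy so the caller's dict is untouched
    if max_input_len > 0:
        out["inputs"] = out["inputs"][:max_input_len]
    if max_target_len > 0:
        out["targets"] = out["targets"][:max_target_len]
    return out


raw = {"inputs": [5, 6, 7, 8], "targets": [9, 10, 11]}
preprocessed = toy_preprocess_example(raw, max_input_len=2)
```

A `Problem` that needs different behavior would override `preprocess_example` instead of relying on the default.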
112+
### Batching
113+
114+
Examples are bucketed by sequence length and then batched out of those buckets.
115+
This significantly improves performance over a naive batching scheme for
116+
variable length sequences because each example in a batch must be padded to
117+
match the example with the maximum length in the batch.
118+
119+
There are several hyperparameters that affect how examples are batched together:
120+
121+
* `hp.batch_size`: this is the approximate total number of tokens in the batch
122+
(i.e. for a sequence problem, long sequences will have smaller actual batch
123+
size and short sequences will have a larger actual batch size in order to
124+
generally have an equal number of tokens in the batch).
125+
* `hp.max_length`: sequences with length longer than this will be dropped
126+
during training (and also during eval if `hp.eval_drop_long_sequences` is
127+
`True`). If not set, the maximum length of examples is set to
128+
`hp.batch_size`.
129+
* `hp.batch_size_multiplier`: multiplier for the maximum length
130+
* `hp.min_length_bucket`: example length for the smallest bucket (i.e. the
131+
smallest bucket will bucket examples up to this length).
132+
* `hp.length_bucket_step`: controls how spaced out the length buckets are.
133+
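To make the interaction of these hyperparameters more concrete, here is a minimal pure-Python sketch of length bucketing with a token-based batch size. It is an assumption-laden illustration, not the `data_reader` implementation: the function names, the geometric spacing of bucket boundaries, and the rule "examples per batch = token budget // bucket boundary" are choices made for this example.

```python
from collections import defaultdict


def bucket_boundaries(min_length_bucket, max_length, length_bucket_step):
    """Return geometrically spaced bucket upper bounds below max_length."""
    boundaries = []
    x = min_length_bucket
    while x < max_length:
        boundaries.append(x)
        x = max(x + 1, int(x * length_bucket_step))
    return boundaries


def bucket_and_batch(examples, batch_size_tokens, max_length,
                     min_length_bucket=8, length_bucket_step=1.1):
    """Group sequences by length so each batch holds about the same tokens."""
    boundaries = bucket_boundaries(
        min_length_bucket, max_length, length_bucket_step) + [max_length]
    buckets = defaultdict(list)
    batches = []
    for ex in examples:
        if len(ex) > max_length:
            continue  # overly long sequences are dropped
        # Smallest bucket whose upper bound fits this example.
        bound = next(b for b in boundaries if len(ex) <= b)
        buckets[bound].append(ex)
        # Longer buckets get fewer examples per batch.
        if len(buckets[bound]) == max(1, batch_size_tokens // bound):
            batches.append(buckets.pop(bound))
    # Emit whatever is left as smaller, partial batches.
    batches.extend(b for b in buckets.values() if b)
    return batches


sequences = [[0] * n for n in (3, 4, 10, 20, 11)]
batches = bucket_and_batch(sequences, batch_size_tokens=16, max_length=16,
                           min_length_bucket=8, length_bucket_step=1.5)
batch_lengths = [[len(ex) for ex in batch] for batch in batches]
```

In this toy run the two short sequences share a batch, the length-20 sequence is dropped, and the two longer sequences each form their own batch, so padding within a batch stays small.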
134+
## Building the Model
135+
136+
At this point, the input features typically have `"inputs"` and `"targets"`,
137+
each of which is a batched 4-D Tensor (e.g. of shape `[batch_size,
138+
sequence_length, 1, 1]` for text input or `[batch_size, height, width, 3]` for
139+
image input).
140+
141+
A `T2TModel` is composed of transforms of the input features by `Modality`s,
142+
then the body of the model, then transforms of the model output to predictions
143+
by a `Modality`, and then a loss (during training).
144+
145+
The `Modality` types for the various input features and for the target are
146+
specified in `Problem.hparams`. A `Modality` is a feature adapter that enables
147+
models to be agnostic to input/output spaces. You can see the various
148+
`Modality`s in
149+
[`modalities.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/modalities.py).
150+
151+
The sketch structure of a T2T model is as follows:
152+
153+
```python
154+
features = {...} # output from the input pipeline
155+
input_modality = ... # specified in Problem.hparams
156+
target_modality = ... # specified in Problem.hparams
157+
158+
transformed_features = {}
159+
transformed_features["inputs"] = input_modality.bottom(
160+
features["inputs"])
161+
transformed_features["targets"] = target_modality.targets_bottom(
162+
features["targets"]) # for autoregressive models
163+
164+
body_outputs = model.model_fn_body(transformed_features)
165+
166+
predictions = target_modality.top(body_outputs, features["targets"])
167+
loss = target_modality.loss(predictions, features["targets"])
168+
```
169+
170+
Most `T2TModel`s only override `model_fn_body`.
171+
172+
## Training, Eval, Inference modes
173+
174+
Both the input function and model functions take a mode in the form of a
175+
`tf.estimator.ModeKeys`, which allows the functions to behave differently in
176+
different modes.
177+
178+
In training, the model function constructs an optimizer and minimizes the loss.
179+
180+
In evaluation, the model function constructs the evaluation metrics specified by
181+
`Problem.eval_metrics`.
182+
183+
In inference, the model function outputs predictions.
184+
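The mode-dependent behavior can be sketched without TensorFlow. In the toy below, the three constants mirror the string values of `tf.estimator.ModeKeys`; everything else (`toy_model_fn`, the one-weight linear "model", the plain dict return values, the learning rate) is hypothetical and only illustrates the branching, not the real `model_fn`.

```python
# String values of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def toy_model_fn(features, mode, weight=2.0, learning_rate=0.1):
    """One-weight linear model whose output depends on the mode."""
    predictions = [weight * x for x in features["inputs"]]
    if mode == PREDICT:
        # Inference: only predictions, targets are not required.
        return {"predictions": predictions}

    errors = [(p - t) ** 2 for p, t in zip(predictions, features["targets"])]
    loss = sum(errors) / len(errors)
    if mode == EVAL:
        # Evaluation: report the loss and metrics, no parameter update.
        return {"loss": loss, "metrics": {"mse": loss}}

    # Training: take a single gradient-descent step on the weight.
    grad = sum(2 * (p - t) * x for p, t, x in zip(
        predictions, features["targets"], features["inputs"])) / len(errors)
    return {"loss": loss, "weight": weight - learning_rate * grad}


toy_features = {"inputs": [1.0, 2.0], "targets": [2.0, 4.0]}
predict_out = toy_model_fn(toy_features, PREDICT)
eval_out = toy_model_fn(toy_features, EVAL)
train_out = toy_model_fn(toy_features, TRAIN, weight=1.0)
```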
185+
## `Estimator` and `Experiment`
186+
187+
With the input function and model functions constructed, the actual training
188+
loop and related services (checkpointing, summaries, continuous evaluation,
189+
etc.) are all handled by `Estimator` and `Experiment` objects, constructed in
190+
[`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py).
191+
192+
## Decoding
193+
194+
* [`decoding.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/decoding.py)
195+
196+
TODO(rsepassi): Explain decoding (interactive, from file, and from dataset) and
197+
`Problem.feature_encoders`.

tensor2tensor/data_generators/cnn_dailymail.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def use_train_shards_for_dev(self):
129129
def generator(self, data_dir, tmp_dir, _):
130130
encoder = generator_utils.get_or_generate_vocab_inner(
131131
data_dir, self.vocab_file, self.targeted_vocab_size,
132-
lambda: story_generator(tmp_dir))
132+
story_generator(tmp_dir))
133133
for story in story_generator(tmp_dir):
134134
summary, rest = _story_summary_split(story)
135135
encoded_summary = encoder.encode(summary) + [EOS]

tensor2tensor/data_generators/desc2code.py

Lines changed: 1 addition & 2 deletions
@@ -195,8 +195,7 @@ def generator_target():
195195
data_dir=data_dir,
196196
vocab_filename=self.vocab_target_filename,
197197
vocab_size=self.target_vocab_size,
198-
generator_fn=generator_target,
199-
)
198+
generator=generator_target(),)
200199

201200
# Yield the training and testing samples
202201
eos_list = [EOS]

tensor2tensor/data_generators/gene_expression.py

Lines changed: 5 additions & 5 deletions
@@ -159,17 +159,17 @@ def example_reading_spec(self):
159159
data_items_to_decoders = None
160160
return (data_fields, data_items_to_decoders)
161161

162-
def preprocess_examples(self, examples, mode, unused_hparams):
162+
def preprocess_example(self, example, mode, unused_hparams):
163163
del mode
164164

165165
# Reshape targets to contain num_output_predictions per output timestep
166-
examples["targets"] = tf.reshape(examples["targets"],
167-
[-1, 1, self.num_output_predictions])
166+
example["targets"] = tf.reshape(example["targets"],
167+
[-1, 1, self.num_output_predictions])
168168
# Slice off EOS - not needed, and messes up the GeneExpressionConv model
169169
# which expects the input length to be a multiple of the target length.
170-
examples["inputs"] = examples["inputs"][:-1]
170+
example["inputs"] = example["inputs"][:-1]
171171

172-
return examples
172+
return example
173173

174174
def eval_metrics(self):
175175
return [metrics.Metrics.LOG_POISSON, metrics.Metrics.R2]

tensor2tensor/data_generators/generator_utils.py

Lines changed: 9 additions & 9 deletions
@@ -300,15 +300,15 @@ def gunzip_file(gz_path, new_path):
300300

301301

302302
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
303-
generator_fn):
303+
generator):
304304
"""Inner implementation for vocab generators.
305305
306306
Args:
307307
data_dir: The base directory where data and vocab files are stored. If None,
308308
then do not save the vocab even if it doesn't exist.
309309
vocab_filename: relative filename where vocab file is stored
310310
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
311-
generator_fn: a generator that produces tokens from the vocabulary
311+
generator: a generator that produces tokens from the vocabulary
312312
313313
Returns:
314314
A SubwordTextEncoder vocabulary object.
@@ -325,7 +325,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
325325

326326
tf.logging.info("Generating vocab file: %s", vocab_filepath)
327327
token_counts = defaultdict(int)
328-
for item in generator_fn():
328+
for item in generator:
329329
for tok in tokenizer.encode(text_encoder.native_to_unicode(item)):
330330
token_counts[tok] += 1
331331

@@ -382,8 +382,8 @@ def generate():
382382
file_byte_budget -= len(line)
383383
yield line
384384

385-
return get_or_generate_vocab_inner(
386-
data_dir, vocab_filename, vocab_size, generator_fn=generate)
385+
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
386+
generate())
387387

388388

389389
def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename,
@@ -416,8 +416,8 @@ def generate():
416416
part = parts[index].strip()
417417
yield part
418418

419-
return get_or_generate_vocab_inner(
420-
data_dir, vocab_filename, vocab_size, generator_fn=generate)
419+
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
420+
generate())
421421

422422

423423
def get_or_generate_txt_vocab(data_dir, vocab_filename, vocab_size,
@@ -434,8 +434,8 @@ def generate():
434434
for line in source_file:
435435
yield line.strip()
436436

437-
return get_or_generate_vocab_inner(
438-
data_dir, vocab_filename, vocab_size, generator_fn=generate)
437+
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
438+
generate())
439439

440440

441441
def read_records(filename):
