Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3b133d0

Browse files
authored
Merge pull request #6 from rsepassi/master
v1.0.3
2 parents 4f55394 + 192e90f commit 3b133d0

23 files changed

+869
-330
lines changed

README.md

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,45 @@
11
# T2T: Tensor2Tensor Transformers
22

3+
[![PyPI
4+
version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
5+
[![GitHub
6+
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
7+
[![Contributions
8+
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
9+
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
10+
311
[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible
4-
library and binaries for supervised learning with TensorFlow and with a focus on
5-
sequence tasks. Actively used and maintained by researchers and engineers within
6-
Google Brain, T2T strives to maximize idea bandwidth and minimize execution
7-
latency.
8-
9-
T2T is particularly well-suited to researchers working on sequence tasks. We're
10-
eager to collaborate with you on extending T2T's powers, so please feel free to
11-
open an issue on GitHub to kick off a discussion and send along pull requests,
12-
See [our contribution doc](CONTRIBUTING.md) for details and our [open
12+
library and binaries for supervised learning with TensorFlow and with support
13+
for sequence tasks. It is actively used and maintained by researchers and
14+
engineers within the Google Brain team.
15+
16+
We're eager to collaborate with you on extending T2T, so please feel
17+
free to [open an issue on
18+
GitHub](https://github.com/tensorflow/tensor2tensor/issues) or
19+
send along a pull request to add your data-set or model.
20+
See [our contribution
21+
doc](CONTRIBUTING.md) for details and our [open
1322
issues](https://github.com/tensorflow/tensor2tensor/issues).
1423

15-
## T2T overview
24+
---
25+
26+
## Walkthrough
27+
28+
Here's a walkthrough training a good English-to-German translation
29+
model using the Transformer model from [*Attention Is All You
30+
Need*](https://arxiv.org/abs/1706.03762) on WMT data.
1631

1732
```
1833
pip install tensor2tensor
1934
35+
# See what problems, models, and hyperparameter sets are available.
36+
# You can easily swap between them (and add new ones).
37+
t2t-trainer --registry_help
38+
2039
PROBLEM=wmt_ende_tokens_32k
2140
MODEL=transformer
2241
HPARAMS=transformer_base
42+
2343
DATA_DIR=$HOME/t2t_data
2444
TMP_DIR=/tmp/t2t_datagen
2545
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
@@ -35,6 +55,7 @@ t2t-datagen \
3555
mv $TMP_DIR/tokens.vocab.32768 $DATA_DIR
3656
3757
# Train
58+
# * If you run out of memory, add --hparams='batch_size=2048' or even 1024.
3859
t2t-trainer \
3960
--data_dir=$DATA_DIR \
4061
--problems=$PROBLEM \
@@ -59,23 +80,63 @@ t2t-trainer \
5980
--output_dir=$TRAIN_DIR \
6081
--train_steps=0 \
6182
--eval_steps=0 \
62-
--beam_size=$BEAM_SIZE \
63-
--alpha=$ALPHA \
83+
--decode_beam_size=$BEAM_SIZE \
84+
--decode_alpha=$ALPHA \
6485
--decode_from_file=$DECODE_FILE
6586
6687
cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
6788
```
6889

69-
T2T modularizes training into several components, each of which can be seen in
70-
use in the above commands.
90+
---
7191

72-
See the models, problems, and hyperparameter sets that are available:
92+
## Installation
7393

74-
`t2t-trainer --registry_help`
94+
```
95+
pip install tensor2tensor
96+
```
97+
98+
Binaries:
99+
100+
```
101+
# Data generator
102+
t2t-datagen
103+
104+
# Trainer
105+
t2t-trainer --registry_help
106+
```
107+
108+
Library usage:
109+
110+
```
111+
python -c "from tensor2tensor.models.transformer import Transformer"
112+
```
113+
114+
---
115+
116+
## Features
117+
118+
* Many state of the art and baseline models are built-in and new models can be
119+
added easily (open an issue or pull request!).
120+
* Many datasets across modalities - text, audio, image - available for
121+
generation and use, and new ones can be added easily (open an issue or pull
122+
request for public datasets!).
123+
* Models can be used with any dataset and input mode (or even multiple); all
124+
modality-specific processing (e.g. embedding lookups for text tokens) is done
125+
with `Modality` objects, which are specified per-feature in the dataset/task
126+
specification.
127+
* Support for multi-GPU machines and synchronous (1 master, many workers) and
128+
asynchrounous (independent workers synchronizing through a parameter server)
129+
distributed training.
130+
* Easily swap amongst datasets and models by command-line flag with the data
131+
generation script `t2t-datagen` and the training script `t2t-trainer`.
132+
133+
---
134+
135+
## T2T overview
75136

76137
### Datasets
77138

78-
**Datasets** are all standardized on TFRecord files with `tensorflow.Example`
139+
**Datasets** are all standardized on `TFRecord` files with `tensorflow.Example`
79140
protocol buffers. All datasets are registered and generated with the
80141
[data
81142
generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen)
@@ -125,10 +186,12 @@ hyperparameters can be overriden with the `--hparams` flag. `--schedule` and
125186
related flags control local and distributed training/evaluation
126187
([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md)).
127188

189+
---
190+
128191
## Adding a dataset
129192

130-
See the data generators
131-
[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md).
193+
See the [data generators
194+
README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md).
132195

133196
---
134197

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.0.2',
8+
version='1.0.3',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',

tensor2tensor.egg-info/PKG-INFO

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Metadata-Version: 1.1
2+
Name: tensor2tensor
3+
Version: 1.0.3
4+
Summary: Tensor2Tensor
5+
Home-page: http://github.com/tensorflow/tensor2tensor
6+
Author: Google Inc.
7+
Author-email: [email protected]
8+
License: Apache 2.0
9+
Description: UNKNOWN
10+
Keywords: tensorflow
11+
Platform: UNKNOWN
12+
Classifier: Development Status :: 4 - Beta
13+
Classifier: Intended Audience :: Developers
14+
Classifier: Intended Audience :: Science/Research
15+
Classifier: License :: OSI Approved :: Apache Software License
16+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence

tensor2tensor.egg-info/SOURCES.txt

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
tensor2tensor/__init__.py
2+
tensor2tensor.egg-info/PKG-INFO
3+
tensor2tensor.egg-info/SOURCES.txt
4+
tensor2tensor.egg-info/dependency_links.txt
5+
tensor2tensor.egg-info/requires.txt
6+
tensor2tensor.egg-info/top_level.txt
7+
tensor2tensor/bin/t2t-datagen
8+
tensor2tensor/bin/t2t-trainer
9+
tensor2tensor/data_generators/__init__.py
10+
tensor2tensor/data_generators/algorithmic.py
11+
tensor2tensor/data_generators/algorithmic_math.py
12+
tensor2tensor/data_generators/algorithmic_math_test.py
13+
tensor2tensor/data_generators/algorithmic_test.py
14+
tensor2tensor/data_generators/audio.py
15+
tensor2tensor/data_generators/audio_test.py
16+
tensor2tensor/data_generators/concatenate_examples.py
17+
tensor2tensor/data_generators/generator_utils.py
18+
tensor2tensor/data_generators/generator_utils_test.py
19+
tensor2tensor/data_generators/image.py
20+
tensor2tensor/data_generators/image_test.py
21+
tensor2tensor/data_generators/lm_example.py
22+
tensor2tensor/data_generators/problem_hparams.py
23+
tensor2tensor/data_generators/problem_hparams_test.py
24+
tensor2tensor/data_generators/replace_oov.py
25+
tensor2tensor/data_generators/snli.py
26+
tensor2tensor/data_generators/text_encoder.py
27+
tensor2tensor/data_generators/text_encoder_build_subword.py
28+
tensor2tensor/data_generators/text_encoder_inspect_subword.py
29+
tensor2tensor/data_generators/tokenizer.py
30+
tensor2tensor/data_generators/tokenizer_test.py
31+
tensor2tensor/data_generators/wmt.py
32+
tensor2tensor/data_generators/wmt_test.py
33+
tensor2tensor/data_generators/wsj_parsing.py
34+
tensor2tensor/models/__init__.py
35+
tensor2tensor/models/attention_lm.py
36+
tensor2tensor/models/attention_lm_moe.py
37+
tensor2tensor/models/bytenet.py
38+
tensor2tensor/models/bytenet_test.py
39+
tensor2tensor/models/common_attention.py
40+
tensor2tensor/models/common_hparams.py
41+
tensor2tensor/models/common_layers.py
42+
tensor2tensor/models/common_layers_test.py
43+
tensor2tensor/models/lstm.py
44+
tensor2tensor/models/lstm_test.py
45+
tensor2tensor/models/models.py
46+
tensor2tensor/models/multimodel.py
47+
tensor2tensor/models/multimodel_test.py
48+
tensor2tensor/models/neural_gpu.py
49+
tensor2tensor/models/neural_gpu_test.py
50+
tensor2tensor/models/slicenet.py
51+
tensor2tensor/models/slicenet_test.py
52+
tensor2tensor/models/transformer.py
53+
tensor2tensor/models/transformer_test.py
54+
tensor2tensor/models/xception.py
55+
tensor2tensor/models/xception_test.py
56+
tensor2tensor/utils/__init__.py
57+
tensor2tensor/utils/avg_checkpoints.py
58+
tensor2tensor/utils/beam_search.py
59+
tensor2tensor/utils/beam_search_test.py
60+
tensor2tensor/utils/bleu_hook.py
61+
tensor2tensor/utils/bleu_hook_test.py
62+
tensor2tensor/utils/data_reader.py
63+
tensor2tensor/utils/data_reader_test.py
64+
tensor2tensor/utils/expert_utils.py
65+
tensor2tensor/utils/metrics.py
66+
tensor2tensor/utils/metrics_test.py
67+
tensor2tensor/utils/modality.py
68+
tensor2tensor/utils/modality_test.py
69+
tensor2tensor/utils/registry.py
70+
tensor2tensor/utils/registry_test.py
71+
tensor2tensor/utils/t2t_model.py
72+
tensor2tensor/utils/trainer_utils.py
73+
tensor2tensor/utils/trainer_utils_test.py
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

tensor2tensor.egg-info/requires.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
numpy
2+
sympy
3+
six
4+
tensorflow-gpu>=1.2.0rc1

tensor2tensor.egg-info/top_level.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensor2tensor

tensor2tensor/bin/make_tf_configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
# Dependency imports
3434

35+
import six
3536
import tensorflow as tf
3637

3738
flags = tf.flags
@@ -50,7 +51,7 @@ def main(_):
5051

5152
cluster = {"ps": ps, "worker": workers}
5253

53-
for task_type, jobs in [("worker", workers), ("ps", ps)]:
54+
for task_type, jobs in six.iteritems(cluster):
5455
for idx, job in enumerate(jobs):
5556
if task_type == "worker":
5657
cmd_line_flags = " ".join([

tensor2tensor/data_generators/algorithmic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def identity_generator(nbr_symbols, max_length, nbr_cases):
2828
"""Generator for the identity (copy) task on sequences of symbols.
2929
3030
The length of the sequence is drawn uniformly at random from [1, max_length]
31-
and then symbols are drawn uniformly at random from [1, nbr_symbols] until
31+
and then symbols are drawn uniformly at random from [2, nbr_symbols] until
3232
nbr_cases sequences have been produced.
3333
3434
Args:
@@ -42,15 +42,15 @@ def identity_generator(nbr_symbols, max_length, nbr_cases):
4242
"""
4343
for _ in xrange(nbr_cases):
4444
l = np.random.randint(max_length) + 1
45-
inputs = [np.random.randint(nbr_symbols) + 1 for _ in xrange(l)]
45+
inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)]
4646
yield {"inputs": inputs, "targets": inputs}
4747

4848

4949
def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
5050
"""Generator for the shift task on sequences of symbols.
5151
5252
The length of the sequence is drawn uniformly at random from [1, max_length]
53-
and then symbols are drawn uniformly at random from [1, nbr_symbols - shift]
53+
and then symbols are drawn uniformly at random from [2, nbr_symbols - shift]
5454
until nbr_cases sequences have been produced (output[i] = input[i] + shift).
5555
5656
Args:
@@ -65,15 +65,15 @@ def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
6565
"""
6666
for _ in xrange(nbr_cases):
6767
l = np.random.randint(max_length) + 1
68-
inputs = [np.random.randint(nbr_symbols - shift) + 1 for _ in xrange(l)]
68+
inputs = [np.random.randint(nbr_symbols - shift) + 2 for _ in xrange(l)]
6969
yield {"inputs": inputs, "targets": [i + shift for i in inputs]}
7070

7171

7272
def reverse_generator(nbr_symbols, max_length, nbr_cases):
7373
"""Generator for the reversing task on sequences of symbols.
7474
7575
The length of the sequence is drawn uniformly at random from [1, max_length]
76-
and then symbols are drawn uniformly at random from [1, nbr_symbols] until
76+
and then symbols are drawn uniformly at random from [2, nbr_symbols] until
7777
nbr_cases sequences have been produced.
7878
7979
Args:
@@ -87,7 +87,7 @@ def reverse_generator(nbr_symbols, max_length, nbr_cases):
8787
"""
8888
for _ in xrange(nbr_cases):
8989
l = np.random.randint(max_length) + 1
90-
inputs = [np.random.randint(nbr_symbols) + 1 for _ in xrange(l)]
90+
inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)]
9191
yield {"inputs": inputs, "targets": list(reversed(inputs))}
9292

9393

@@ -139,8 +139,8 @@ def addition_generator(base, max_length, nbr_cases):
139139
n2 = random_number_lower_endian(l2, base)
140140
result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base)
141141
# We shift digits by 1 on input and output to leave 0 for padding.
142-
inputs = [i + 1 for i in n1] + [base + 1] + [i + 1 for i in n2]
143-
targets = [i + 1 for i in number_to_lower_endian(result, base)]
142+
inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2]
143+
targets = [i + 2 for i in number_to_lower_endian(result, base)]
144144
yield {"inputs": inputs, "targets": targets}
145145

146146

@@ -173,6 +173,6 @@ def multiplication_generator(base, max_length, nbr_cases):
173173
n2 = random_number_lower_endian(l2, base)
174174
result = lower_endian_to_number(n1, base) * lower_endian_to_number(n2, base)
175175
# We shift digits by 1 on input and output to leave 0 for padding.
176-
inputs = [i + 1 for i in n1] + [base + 1] + [i + 1 for i in n2]
177-
targets = [i + 1 for i in number_to_lower_endian(result, base)]
176+
inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2]
177+
targets = [i + 2 for i in number_to_lower_endian(result, base)]
178178
yield {"inputs": inputs, "targets": targets}

0 commit comments

Comments
 (0)