
Commit 1047fa6

Author: Ryan Sepassi
Document distributed training and update README

1 parent aee16b4

File tree

7 files changed, +192 -8 lines changed


README.md

Lines changed: 10 additions & 5 deletions
````diff
@@ -17,14 +17,14 @@ issues](https://github.com/tensorflow/tensor2tensor/issues).
 ```
 pip install tensor2tensor
 
-DATA_DIR=$HOME/t2t_data
-TMP_DIR=/tmp/t2t_datagen
 PROBLEM=wmt_ende_tokens_32k
 MODEL=transformer
 HPARAMS=transformer_base
-TRAIN_DIR=$HOME/t2t_train/$PROBLEM_$MODEL_$HPARAMS
+DATA_DIR=$HOME/t2t_data
+TMP_DIR=/tmp/t2t_datagen
+TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
 
-mkdir $DATA_DIR $TMP_DIR $TRAIN_DIR
+mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
 
 # Generate data
 t2t-datagen \
````
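
(A note on the `TRAIN_DIR` fix above: in shell, `$PROBLEM_$MODEL_$HPARAMS` is parsed as the undefined variables `$PROBLEM_` and `$MODEL_`, because underscores are valid in variable names, so the old path silently collapsed to `$HOME/t2t_train/` plus only the `$HPARAMS` value. The `$PROBLEM/$MODEL-$HPARAMS` form avoids the ambiguity, and `mkdir -p` creates the now-nested path, parents included, without erroring if it already exists.)
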
```diff
@@ -69,6 +69,10 @@ cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
 T2T modularizes training into several components, each of which can be seen in
 use in the above commands.
 
+See the models, problems, and hyperparameter sets that are available:
+
+`t2t-trainer --registry_help`
+
 ### Datasets
 
 **Datasets** are all standardized on TFRecord files with `tensorflow.Example`
```
```diff
@@ -118,7 +122,8 @@ The **trainer** binary is the main entrypoint for training, evaluation, and
 inference. Users can easily switch between problems, models, and hyperparameter
 sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific
 hyperparameters can be overriden with the `--hparams` flag. `--schedule` and
-related flags control local and distributed training/evaluation.
+related flags control local and distributed training/evaluation
+([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md)).
 
 ## Adding a dataset
```
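
As an illustration of how these flags compose, here is a minimal sketch of a local training run using the quickstart variables from above. The flag names are the ones documented in this README and in the new distributed-training doc; treat it as a sketch rather than the canonical command, since additional flags may be required:

```
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR
```
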

setup.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 """Install tensor2tensor."""
 
-from setuptools import setup, find_packages
+from setuptools import find_packages
+from setuptools import setup
 
 setup(
     name='tensor2tensor',
```

tensor2tensor/bin/make_tf_configs.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
1+
# Copyright 2017 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Output command line arguments and json-encoded TF_CONFIGs.
16+
17+
Usage:
18+
19+
`make_tf_configs.py --workers="server1:1234" --ps="server3:2134,server4:2334"`
20+
21+
Outputs 1 line per job to stdout, first the workers, then the parameter servers.
22+
Each line has the TF_CONFIG, then a tab, then the command line flags for that
23+
job.
24+
25+
If there is a single worker, workers will have the `--sync` flag.
26+
"""
27+
from __future__ import absolute_import
28+
from __future__ import division
29+
from __future__ import print_function
30+
31+
import json
32+
33+
# Dependency imports
34+
35+
import tensorflow as tf
36+
37+
flags = tf.flags
38+
FLAGS = flags.FLAGS
39+
40+
flags.DEFINE_string("workers", "", "Comma-separated list of worker addresses")
41+
flags.DEFINE_string("ps", "", "Comma-separated list of ps addresses")
42+
43+
44+
def main(_):
45+
if not (FLAGS.workers and FLAGS.ps):
46+
raise ValueError("Must provide --workers and --ps")
47+
48+
workers = FLAGS.workers.split(",")
49+
ps = FLAGS.ps.split(",")
50+
51+
cluster = {"ps": ps, "worker": workers}
52+
53+
for task_type, jobs in [("worker", workers), ("ps", ps)]:
54+
for idx, job in enumerate(jobs):
55+
if task_type == "worker":
56+
cmd_line_flags = " ".join([
57+
"--master=%s" % job,
58+
"--ps_replicas=%d" % len(ps),
59+
"--worker_replicas=%d" % len(workers),
60+
"--worker_gpu=1",
61+
"--worker_id=%d" % idx,
62+
"--ps_gpu=1",
63+
"--schedule=train",
64+
"--sync" if len(workers) == 1 else "",
65+
])
66+
else:
67+
cmd_line_flags = " ".join([
68+
"--schedule=run_std_server",
69+
])
70+
71+
tf_config = json.dumps({
72+
"cluster": cluster,
73+
"task": {
74+
"type": task_type,
75+
"index": idx
76+
}
77+
})
78+
print(tf_config + "\t" + cmd_line_flags)
79+
80+
81+
if __name__ == "__main__":
82+
tf.app.run()
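
To make the output format concrete, here is roughly what the docstring's example invocation would print, re-derived from the code above (illustrative: JSON key order may vary by Python version, and the literal tab between the config and the flags is widened here for readability):

```
$ python make_tf_configs.py --workers="server1:1234" --ps="server3:2134,server4:2334"
{"cluster": {"ps": ["server3:2134", "server4:2334"], "worker": ["server1:1234"]}, "task": {"type": "worker", "index": 0}}    --master=server1:1234 --ps_replicas=2 --worker_replicas=1 --worker_gpu=1 --worker_id=0 --ps_gpu=1 --schedule=train --sync
{"cluster": {"ps": ["server3:2134", "server4:2334"], "worker": ["server1:1234"]}, "task": {"type": "ps", "index": 0}}    --schedule=run_std_server
{"cluster": {"ps": ["server3:2134", "server4:2334"], "worker": ["server1:1234"]}, "task": {"type": "ps", "index": 1}}    --schedule=run_std_server
```

Since there is a single worker in this example, the worker line carries `--sync`, matching the conditional in `main`.
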

tensor2tensor/bin/t2t-trainer

Lines changed: 0 additions & 1 deletion
```diff
@@ -42,7 +42,6 @@ def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   utils.log_registry()
   utils.validate_flags()
-  # TODO(rsepassi): Document distributed training
   utils.run(
       data_dir=FLAGS.data_dir,
       model=FLAGS.model,
```

tensor2tensor/docs/distributed_training.md

Lines changed: 68 additions & 0 deletions
````markdown
# Distributed Training

The `t2t-trainer` supports both synchronous and asynchronous distributed
training.

T2T uses TensorFlow Estimators and so distributed training is configured with
the `TF_CONFIG` environment variable that is read by the
[RunConfig](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/estimators/run_config.py)
along with a set of flags.

## `TF_CONFIG`

Both workers and parameter servers must have the `TF_CONFIG` environment
variable set.

The `TF_CONFIG` environment variable is a json-encoded string with the addresses
of the workers and parameter servers (in the `'cluster'` key) and the
identification of the current task (in the `'task'` key).

For example:

```
cluster = {
    'ps': ['host1:2222', 'host2:2222'],
    'worker': ['host3:2222', 'host4:2222', 'host5:2222']
}
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': cluster,
    'task': {'type': 'worker', 'index': 1}
})
```

## Command-line flags

The following T2T command-line flags must also be set on the workers for
distributed training:

- `--master=$ADDRESS`
- `--worker_replicas=$NUM_WORKERS`
- `--worker_gpu=$NUM_GPUS_PER_WORKER`
- `--worker_id=$WORKER_ID`
- `--ps_replicas=$NUM_PS`
- `--ps_gpu=$NUM_GPUS_PER_PS`
- `--schedule=train`
- `--sync`, if you want synchronous training, i.e. for there to be a single
  master worker coordinating the work across "ps" jobs (yes, the naming is
  unfortunate). If not set, then each worker operates independently while
  variables are shared on the parameter servers.

Parameter servers only need `--schedule=run_std_server`.
````
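
To tie the `TF_CONFIG` and flag sections together, a single-worker launch might look like the sketch below. Host names and values are illustrative and not part of the committed doc; the quickstart variables (`$DATA_DIR`, `$PROBLEM`, `$MODEL`, `$HPARAMS`, `$TRAIN_DIR`) are assumed from the README, and with one worker the `--sync` flag applies, matching `make_tf_configs.py`:

```
# Worker 0 (e.g. on host3); extra flags may be required by flag validation.
export TF_CONFIG='{"cluster": {"ps": ["host1:2222"], "worker": ["host3:2222"]}, "task": {"type": "worker", "index": 0}}'
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --master=host3:2222 \
  --worker_replicas=1 \
  --worker_gpu=1 \
  --worker_id=0 \
  --ps_replicas=1 \
  --ps_gpu=1 \
  --schedule=train \
  --sync

# Parameter server (e.g. on host1)
export TF_CONFIG='{"cluster": {"ps": ["host1:2222"], "worker": ["host3:2222"]}, "task": {"type": "ps", "index": 0}}'
t2t-trainer --schedule=run_std_server
```

The committed doc continues below.
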
````markdown
## Utility to produce `TF_CONFIG` and flags

[`bin/make_tf_configs.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/make_tf_configs.py)
generates the `TF_CONFIG` json strings and the above-mentioned command-line
flags for the workers and parameter servers.

## Command-line flags for eval jobs

Eval jobs should set the following flags and do not need the `TF_CONFIG`
environment variable to be set as the eval jobs run locally and do not
communicate to the other jobs (the eval jobs read the model checkpoints that the
trainer writes out):

- `--schedule=continuous_eval_on_train_data` or
  `--schedule=continuous_eval` (for test data)
- `--worker_job='/job:localhost'`
- `--output_dir=$TRAIN_DIR`
````
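
Similarly, a sketch of an eval job per the flag list above, run locally next to the training output (data and model flags mirror the training job; this is an illustration, not a command from the doc):

```
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --schedule=continuous_eval \
  --worker_job='/job:localhost'
```
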

tensor2tensor/utils/trainer_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -126,7 +126,7 @@ def experiment_fn(output_dir):
 
 def create_experiment(output_dir, data_dir, model_name, train_steps,
                       eval_steps):
-  hparams = create_hparams(FLAGS.hparams_set, FLAGS.data_dir)
+  hparams = create_hparams(FLAGS.hparams_set, data_dir)
   estimator, input_fns = create_experiment_components(
       hparams=hparams,
       output_dir=output_dir,
```
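
(This one-line fix makes `create_experiment` honor its `data_dir` argument instead of reaching for the global `FLAGS.data_dir`; the new `testSingleStep` test below relies on it when pointing the experiment at a temporary data directory.)
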

tensor2tensor/utils/trainer_utils_test.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -20,14 +20,29 @@
 
 # Dependency imports
 
+from tensor2tensor.data_generators import algorithmic
+from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import trainer_utils as utils  # pylint: disable=unused-import
 
 import tensorflow as tf
 
+FLAGS = tf.flags.FLAGS
+
 
 class TrainerUtilsTest(tf.test.TestCase):
 
+  @classmethod
+  def setUpClass(cls):
+    # Generate a small test dataset
+    FLAGS.problems = "algorithmic_addition_binary40"
+    TrainerUtilsTest.data_dir = tf.test.get_temp_dir()
+    gen = algorithmic.identity_generator(2, 10, 300)
+    generator_utils.generate_files(gen, FLAGS.problems + "-train",
+                                   TrainerUtilsTest.data_dir, 1, 100)
+    generator_utils.generate_files(gen, FLAGS.problems + "-dev",
+                                   TrainerUtilsTest.data_dir, 1, 100)
+
   def testModelsImported(self):
     models = registry.list_models()
     self.assertTrue("baseline_lstm_seq2seq" in models)
@@ -36,6 +51,20 @@ def testHParamsImported(self):
     hparams = registry.list_hparams()
     self.assertTrue("transformer_base" in hparams)
 
+  def testSingleStep(self):
+    model_name = "transformer"
+    FLAGS.hparams_set = "transformer_base"
+    # Shrink the test model down
+    FLAGS.hparams = ("batch_size=10,hidden_size=10,num_heads=2,max_length=16,"
+                     "num_hidden_layers=1")
+    exp = utils.create_experiment(
+        output_dir=tf.test.get_temp_dir(),
+        data_dir=TrainerUtilsTest.data_dir,
+        model_name=model_name,
+        train_steps=1,
+        eval_steps=1)
+    exp.test()
+
 
 if __name__ == "__main__":
   tf.test.main()
```
