This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 80b2f73

Merge pull request #570 from rsepassi/push

v1.4.4

2 parents: 1c98b8e + 290a12a


43 files changed: +1854 / -365 lines

.travis.yml

Lines changed: 13 additions & 1 deletion
@@ -3,8 +3,11 @@ python:
   - "2.7"
   - "3.6"
 before_install:
+  - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
+  - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
   - sudo apt-get update -qq
   - sudo apt-get install -qq libhdf5-dev
+  - sudo apt-get install -qq tensorflow-model-server
 install:
   - pip install -q .[tensorflow]
   - pip install -q .[tests]
@@ -21,7 +24,7 @@ script:
   - python -c "from tensor2tensor.models import transformer; print(transformer.Transformer.__name__)"

   # Run tests
-  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/utils/trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
+  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/utils/trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py --ignore=tensor2tensor/bin/t2t_trainer_test.py
   - pytest tensor2tensor/utils/registry_test.py
   - pytest tensor2tensor/utils/trainer_lib_test.py

@@ -36,5 +39,14 @@ script:
   - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
   - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
   - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10'
+
+  # Export and query (on Python 2 only)
+  - t2t-exporter --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR
+  - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
+      pip install tensorflow-serving-api;
+      tensorflow_model_server --port=9000 --model_name=my_model --model_base_path=$T2T_TRAIN_DIR/export/Servo &
+      sleep 10;
+      t2t-query-server --problem=$T2T_PROBLEM --server=localhost:9000 --servable_name=my_model --data_dir=$T2T_DATA_DIR --inputs_once='1 0 1 0 1 0';
+    fi
 git:
   depth: 3
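
The new CI steps export a model with `t2t-exporter`, start `tensorflow_model_server`, and query it with `t2t-query-server`. For a programmatic query instead of the CLI, a rough sketch against the Python 2-era `tensorflow-serving-api` gRPC interface might look like the following; the server address, servable name, feature name, request key, and input encoding are assumptions for illustration, not taken from this commit:

```python
# Sketch only: query a model served as in the CI step above. Assumes the
# tensorflow-serving-api package (gRPC beta API of this era) is installed and
# that the exported model reads serialized tf.Example protos; the "inputs"
# feature name and the "input" request key are guesses, not from this commit.
from grpc.beta import implementations
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

channel = implementations.insecure_channel("localhost", 9000)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

example = tf.train.Example(features=tf.train.Features(feature={
    "inputs": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 0, 1, 0, 1, 0])),
}))

request = predict_pb2.PredictRequest()
request.model_spec.name = "my_model"
request.inputs["input"].CopyFrom(
    tf.contrib.util.make_tensor_proto(example.SerializeToString(), shape=[1]))

response = stub.Predict(request, 10.0)  # 10-second timeout
print(response.outputs)
```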

docs/cloud_mlengine.md

Lines changed: 80 additions & 0 deletions
# Running on Cloud ML Engine

Google Cloud Platform offers a managed training environment for TensorFlow
models called [Cloud ML Engine](https://cloud.google.com/ml-engine/) and
you can easily launch Tensor2Tensor on it, including for hyperparameter tuning.

# Launch

It's the same `t2t-trainer` you know and love with the addition of the
`--cloud_mlengine` flag, which by default will launch on a 1-GPU machine.

```
# Note that both the data dir and output dir have to be on GCS
DATA_DIR=gs://my-bucket/data
OUTPUT_DIR=gs://my-bucket/train
t2t-trainer \
  --problems=translate_ende_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base \
  --data_dir=$DATA_DIR \
  --output_dir=$OUTPUT_DIR \
  --cloud_mlengine
```

By passing `--worker_gpu=4` or `--worker_gpu=8` it will automatically launch on
machines with 4 or 8 GPUs.

You can additionally pass the `--cloud_mlengine_master_type` to select another
kind of machine (see the
[docs for `masterType`](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput)
for your options). If you provide this flag yourself, make sure you pass the
correct value for `--worker_gpu`.

**Note**: `t2t-trainer` only currently supports launching with single machines,
possibly with multiple GPUs. Multi-machine setups are not yet supported out of
the box with the `--cloud_mlengine` flag, though multi-machine should in
principle work just fine. Contributions/testers welcome.

## `--t2t_usr_dir`

Launching on Cloud ML Engine works with `--t2t_usr_dir` as well, as long as the
directory is fully self-contained (i.e. the imports only refer to other modules
in the directory). If there are additional PyPI dependencies that you need, you
can include a `setup.py` file in your directory (ensure that it uses
`setuptools.find_packages`).
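
For the `setup.py` mentioned just above, a minimal version might look like this sketch; the package name and extra dependency are invented for illustration and are not part of this commit:

```python
# Hypothetical setup.py placed inside the --t2t_usr_dir directory. Cloud ML
# Engine installs the listed PyPI dependencies before training starts.
from setuptools import find_packages
from setuptools import setup

setup(
    name="my_usr_dir",            # invented package name
    version="0.1",
    packages=find_packages(),     # as noted above, use setuptools.find_packages
    install_requires=["nltk"],    # example extra PyPI dependency
)
```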
# Hyperparameter Tuning

Hyperparameter tuning with `t2t-trainer` and Cloud ML Engine is also a breeze
with `--hparams_range` and the `--autotune_*` flags:

```
t2t-trainer \
  --problems=translate_ende_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base \
  --data_dir=$DATA_DIR \
  --output_dir=$OUTPUT_DIR \
  --cloud_mlengine \
  --hparams_range=transformer_base_range \
  --autotune_objective='metrics-translate_ende_wmt32k/neg_log_perplexity' \
  --autotune_maximize \
  --autotune_max_trials=100 \
  --autotune_parallel_trials=3
```

The `--hparams_range` specifies the search space and should be registered with
`@register_ranged_hparams`. It defines a `RangedHParams` object that sets
search ranges and scales for various parameters. See `transformer_base_range`
in
[`transformer.py`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
for an example.

The metric name passed as `--autotune_objective` should be exactly what you'd
see in TensorBoard. To minimize a metric, set `--autotune_maximize=False`.

You control how many total trials to run with `--autotune_max_trials` and the
number of jobs to launch in parallel with `--autotune_parallel_trials`.

Happy tuning!
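
A ranged-hparams definition of the kind `--hparams_range` expects looks roughly like the sketch below, modeled loosely on `transformer_base_range`; the name, parameters, and ranges here are illustrative rather than taken from this commit:

```python
# Sketch of a RangedHParams search space; register it and pass its name
# via --hparams_range. Parameter names and ranges are illustrative only.
from tensor2tensor.utils import registry


@registry.register_ranged_hparams("my_transformer_range")
def my_transformer_range(rhp):
  """Hyperparameter search space for Cloud ML Engine tuning."""
  rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE)
  rhp.set_discrete("learning_rate_warmup_steps", [1000, 4000, 16000])
  rhp.set_float("weight_decay", 0.0, 2.0)
  rhp.set_int("num_hidden_layers", 2, 8)
```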

setup.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.4.3',
+    version='1.4.4',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
@@ -35,9 +35,9 @@
         'flask',
         'future',
         'gevent',
+        'google-api-python-client',
         'gunicorn',
         'gym<=0.9.5',  # gym in version 0.9.6 has some temporary issues.
-        'munch',
         'numpy',
         'requests',
         'scipy',

tensor2tensor/bin/t2t-rl-trainer

Lines changed: 0 additions & 16 deletions
This file was deleted.

tensor2tensor/bin/t2t_trainer.py

Lines changed: 67 additions & 3 deletions
@@ -26,7 +26,8 @@

 from tensor2tensor import models  # pylint: disable=unused-import
 from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
-from tensor2tensor.utils import cloud
+from tensor2tensor.utils import cloud_mlengine
+from tensor2tensor.utils import cloud_tpu
 from tensor2tensor.utils import decoding
 from tensor2tensor.utils import flags as t2t_flags  # pylint: disable=unused-import
 from tensor2tensor.utils import registry
@@ -81,13 +82,68 @@
 flags.DEFINE_bool("cloud_delete_on_done", False,
                   "Whether to delete the VM and TPU instance when done.")

+# Google Cloud ML Engine
+flags.DEFINE_bool("cloud_mlengine", False,
+                  "Whether to launch on Cloud ML Engine.")
+flags.DEFINE_string("cloud_mlengine_master_type", None,
+                    "Machine type for master on Cloud ML Engine. "
+                    "If provided, overrides default selections based on "
+                    "--worker_gpu. User is responsible for ensuring "
+                    "type is valid and that --worker_gpu matches number of "
+                    "GPUs on machine type. See documentation: "
+                    "https://cloud.google.com/ml-engine/reference/rest/v1/"
+                    "projects.jobs#traininginput")
+# Hyperparameter tuning on Cloud ML Engine
+# Pass an --hparams_range to enable
+flags.DEFINE_string("autotune_objective", None,
+                    "TensorBoard metric name to optimize.")
+flags.DEFINE_bool("autotune_maximize", True,
+                  "Whether to maximize (vs. minimize) autotune_objective.")
+flags.DEFINE_integer("autotune_max_trials", 10,
+                     "Maximum number of tuning experiments to run.")
+flags.DEFINE_integer("autotune_parallel_trials", 1,
+                     "How many trials to run in parallel (will spin up this "
+                     "many jobs.")
+# Note than in open-source TensorFlow, the dash gets converted to an underscore,
+# so access is FLAGS.job_dir.
+flags.DEFINE_string("job-dir", None,
+                    "DO NOT USE. Exists only for Cloud ML Engine to pass in "
+                    "during hyperparameter tuning. Overrides --output_dir.")
+

 def get_problem_name():
   problems = FLAGS.problems.split("-")
   assert len(problems) == 1
   return problems[0]


+def set_hparams_from_args(args):
+  """Set hparams overrides from unparsed args list."""
+  if not args:
+    return
+
+  hp_prefix = "--hp_"
+  tf.logging.info("Found unparsed command-line arguments. Checking if any "
+                  "start with %s and interpreting those as hparams "
+                  "settings.", hp_prefix)
+
+  pairs = []
+  i = 0
+  while i < len(args):
+    arg = args[i]
+    if arg.startswith(hp_prefix):
+      pairs.append((arg.lstrip(hp_prefix), args[i+1]))
+      i += 2
+    else:
+      tf.logging.warn("Found unknown flag: %s", arg)
+      i += 1
+
+  as_hparams = ",".join(["%s=%s" % (key, val) for key, val in pairs])
+  if FLAGS.hparams:
+    as_hparams = "," + as_hparams
+  FLAGS.hparams += as_hparams
+
+
 def create_hparams():
   if (FLAGS.cloud_tpu or FLAGS.use_tpu) and "tpu" not in FLAGS.hparams_set:
     tf.logging.warn("Not all hyperparameter sets work on TPU. "
@@ -244,23 +300,31 @@ def maybe_cloud_tpu():
                      "be gs:// paths, i.e. on Google Cloud Storage.")

   FLAGS.use_tpu = True
-  with cloud.cloud_tpu(
+  with cloud_tpu.cloud_tpu(
       FLAGS.cloud_vm_name,
       FLAGS.cloud_tpu_name,
       delete_on_done=FLAGS.cloud_delete_on_done) as tpu_master:
     FLAGS.master = tpu_master
     yield


-def main(_):
+def main(argv):
   tf.logging.set_verbosity(tf.logging.INFO)
   trainer_lib.set_random_seed(FLAGS.random_seed)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
   log_registry()

+  if FLAGS.cloud_mlengine:
+    return cloud_mlengine.launch()
+
   if FLAGS.generate_data:
     generate_data()

+  if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
+    FLAGS.output_dir = FLAGS.job_dir
+
+  if argv:
+    set_hparams_from_args(argv[1:])
   hparams = create_hparams()
   if is_chief():
     save_metadata(hparams)
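
The `set_hparams_from_args` helper above turns leftover `--hp_<name> <value>` argument pairs into entries appended to `FLAGS.hparams`. A self-contained sketch of the same convention (not the committed function, which operates on `FLAGS` directly):

```python
# Standalone sketch of the --hp_ convention used by set_hparams_from_args:
# leftover argv pairs like "--hp_batch_size 1024" become "batch_size=1024"
# entries appended to the comma-separated hparams override string.
def hparams_overrides_from_args(args, existing=""):
  prefix = "--hp_"
  pairs = []
  i = 0
  while i < len(args):
    if args[i].startswith(prefix):
      pairs.append((args[i][len(prefix):], args[i + 1]))
      i += 2
    else:
      i += 1  # skip unknown flags
  overrides = ",".join("%s=%s" % (k, v) for k, v in pairs)
  return existing + "," + overrides if existing else overrides


print(hparams_overrides_from_args(["--hp_batch_size", "1024",
                                   "--hp_learning_rate", "0.05"]))
# -> batch_size=1024,learning_rate=0.05
```

(The committed code uses `arg.lstrip(hp_prefix)`, which strips any of the prefix's characters from the left rather than the literal prefix; the slice above is the stricter equivalent.)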

tensor2tensor/bin/t2t_trainer_test.py

Lines changed: 50 additions & 0 deletions
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t2t_trainer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from tensor2tensor.bin import t2t_trainer
from tensor2tensor.utils import trainer_lib_test

import tensorflow as tf

FLAGS = tf.flags.FLAGS


class TrainerTest(tf.test.TestCase):

  @classmethod
  def setUpClass(cls):
    trainer_lib_test.TrainerLibTest.setUpClass()

  def testTrain(self):
    FLAGS.problems = "tiny_algo"
    FLAGS.model = "transformer"
    FLAGS.hparams_set = "transformer_tiny"
    FLAGS.train_steps = 1
    FLAGS.eval_steps = 1
    FLAGS.output_dir = tf.test.get_temp_dir()
    FLAGS.data_dir = tf.test.get_temp_dir()
    t2t_trainer.main(None)


if __name__ == "__main__":
  tf.test.main()

tensor2tensor/data_generators/librispeech.py

Lines changed: 1 addition & 2 deletions
@@ -66,8 +66,7 @@ def _collect_data(directory, input_ext, transcription_ext):
       transcript_path = os.path.join(root, transcript)
       with open(transcript_path, "r") as transcript_file:
         for transcript_line in transcript_file:
-          line_contents = transcript_line.split(" ", 1)
-          assert len(line_contents) == 2
+          line_contents = transcript_line.strip().split(" ", 1)
           media_base, label = line_contents
           key = os.path.join(root, media_base)
           assert key not in data_files
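
The one-line change above matters because lines read from the transcript file keep their trailing newline; without `strip()` it ends up attached to the label. A tiny illustration with a made-up transcript line:

```python
# Why .strip() before splitting: each transcript line ends with "\n".
line = "84-121123-0000 GO DO YOU HEAR\n"  # made-up LibriSpeech-style line

print(line.split(" ", 1))          # ['84-121123-0000', 'GO DO YOU HEAR\n']
print(line.strip().split(" ", 1))  # ['84-121123-0000', 'GO DO YOU HEAR']
```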

tensor2tensor/data_generators/problem.py

Lines changed: 19 additions & 7 deletions
@@ -517,16 +517,26 @@ def _maybe_reverse_and_copy(example):
     if shuffle_files:
       random.shuffle(data_files)
     dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_files))
-    dataset = dataset.apply(
-        tf.contrib.data.parallel_interleave(
-            _load_records, sloppy=is_training, cycle_length=8))
+
+    if hasattr(tf.contrib.data, "parallel_interleave"):
+      dataset = dataset.apply(
+          tf.contrib.data.parallel_interleave(
+              _load_records, sloppy=is_training, cycle_length=8))
+    else:
+      dataset = dataset.interleave(_load_records, cycle_length=8,
+                                   block_length=16)
+
     if repeat:
       dataset = dataset.repeat()
     dataset = dataset.map(self.decode_example, num_parallel_calls=num_threads)
     if preprocess:
-      dataset = dataset.apply(
-          tf.contrib.data.parallel_interleave(
-              _preprocess, sloppy=is_training, cycle_length=8))
+      if hasattr(tf.contrib.data, "parallel_interleave"):
+        dataset = dataset.apply(
+            tf.contrib.data.parallel_interleave(
+                _preprocess, sloppy=is_training, cycle_length=8))
+      else:
+        dataset = dataset.interleave(_preprocess, cycle_length=8,
+                                     block_length=16)
     dataset = dataset.map(
         _maybe_reverse_and_copy, num_parallel_calls=num_threads)

@@ -633,6 +643,8 @@ def _dataset_partition(self, mode, config):
       num_partitions: an integer
     """
     if mode != tf.estimator.ModeKeys.TRAIN or not hasattr(config, "tpu_config"):
+      # Reset in the case when using TPU but alternating TRAIN and EVAL.
+      self._next_partition_id = 0
       return 0, 1
     if config.tpu_config.per_host_input_for_training:
       num_partitions = max(config.tpu_config.num_shards // 8, 1)
@@ -670,7 +682,7 @@ def input_fn(self,
     partition_id, num_partitions = self._dataset_partition(mode, config)

     is_training = mode == tf.estimator.ModeKeys.TRAIN
-    if config.use_tpu:
+    if config and config.use_tpu:
       num_threads = 64
     else:
       num_threads = 4 if is_training else 1
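
The pattern introduced in the first hunk, feature-detecting `tf.contrib.data.parallel_interleave` and falling back to plain `Dataset.interleave` on older TensorFlow builds, looks like the sketch below in isolation; the file names are placeholders:

```python
# Version-tolerant interleave: prefer the sloppy parallel_interleave transform
# when this TensorFlow build provides it, else fall back to Dataset.interleave.
import tensorflow as tf


def _load_records(filename):
  return tf.data.TFRecordDataset(filename, buffer_size=8 * 1024 * 1024)


data_files = ["train-00000-of-00002", "train-00001-of-00002"]  # placeholders
dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_files))

if hasattr(tf.contrib.data, "parallel_interleave"):
  dataset = dataset.apply(
      tf.contrib.data.parallel_interleave(
          _load_records, sloppy=True, cycle_length=8))
else:
  dataset = dataset.interleave(_load_records, cycle_length=8, block_length=16)
```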
