Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ed9e3bd

Browse files
authored
Merge pull request #751 from rsepassi/push
v1.6.1
2 parents 757b529 + 1027772 commit ed9e3bd

File tree

213 files changed

+5164
-366
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

213 files changed

+5164
-366
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ script:
4141
--ignore=tensor2tensor/problems_test.py
4242
--ignore=tensor2tensor/bin/t2t_trainer_test.py
4343
--ignore=tensor2tensor/data_generators/algorithmic_math_test.py
44+
--ignore=tensor2tensor/models/research/r_transformer_test.py # Requires new feature in tf.foldl (rm with TF 1.9)
4445
- pytest tensor2tensor/utils/registry_test.py
4546
- pytest tensor2tensor/utils/trainer_lib_test.py
4647
- pytest tensor2tensor/visualization/visualization_test.py

docs/distributed_training.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ distributed training:
5151
Parameter servers only need `--master=grpc://$ADDRESS` and
5252
`--schedule=run_std_server`.
5353

54-
>> Note about `output_dir`: All the workers (masters and parameter servers) should use the same `output_dir`. If training
55-
>> on separate nodes, output_dir can be a shared filesystem like NFS or an object store like GCS.
54+
>> Note about `--output_dir`: All the nodes should use the same `--output_dir`.
55+
>> When using multiple machines, `output_dir` should point to a shared
56+
>> filesystem like NFS or an object store like Google Cloud Storage
57+
>> (`gs://...`).
5658
5759
## Utility to produce `TF_CONFIG` and flags
5860

setup.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.6.0',
8+
version='1.6.1',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',
@@ -14,6 +14,7 @@
1414
packages=find_packages(),
1515
package_data={
1616
'tensor2tensor.data_generators': ['test_data/*'],
17+
'tensor2tensor.data_generators.wikisum': ['test_data/*'],
1718
'tensor2tensor.visualization': [
1819
'attention.js', 'TransformerVisualization.ipynb'
1920
],
@@ -37,7 +38,8 @@
3738
'gevent',
3839
'google-api-python-client',
3940
'gunicorn',
40-
'gym<=0.9.5', # gym in version 0.9.6 has some temporary issues.
41+
'gym',
42+
'h5py',
4143
'numpy',
4244
'requests',
4345
'scipy',
@@ -47,7 +49,7 @@
4749
extras_require={
4850
'tensorflow': ['tensorflow>=1.5.0'],
4951
'tensorflow_gpu': ['tensorflow-gpu>=1.5.0'],
50-
'tests': ['pytest', 'h5py', 'mock'],
52+
'tests': ['pytest', 'mock'],
5153
},
5254
classifiers=[
5355
'Development Status :: 4 - Beta',

tensor2tensor/bin/make_tf_configs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
"""Output command line arguments and json-encoded TF_CONFIGs.
1716
1817
Usage:

tensor2tensor/bin/t2t_avg_all.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
"""Script to continuously average last N checkpoints in a given directory."""
1716
from __future__ import absolute_import
1817
from __future__ import division

tensor2tensor/bin/t2t_bleu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
"""Evaluate BLEU score for all checkpoints/translations in a given directory.
1716
1817
This script can be used in two ways.

tensor2tensor/bin/t2t_datagen.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
"""Produces the training and dev data for --problem into --data_dir.
1716
1817
Produces sharded and shuffled TFRecord files of tensorflow.Example protocol

tensor2tensor/bin/t2t_decoder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
r"""Decode from trained T2T models.
1716
1817
This binary performs inference using the Estimator API.
@@ -82,9 +81,13 @@ def create_decode_hparams():
8281

8382
def decode(estimator, hparams, decode_hp):
8483
if FLAGS.decode_interactive:
84+
if estimator.config.use_tpu:
85+
raise ValueError("TPU can only decode from dataset.")
8586
decoding.decode_interactively(estimator, hparams, decode_hp,
8687
checkpoint_path=FLAGS.checkpoint_path)
8788
elif FLAGS.decode_from_file:
89+
if estimator.config.use_tpu:
90+
raise ValueError("TPU can only decode from dataset.")
8891
decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
8992
decode_hp, FLAGS.decode_to_file,
9093
checkpoint_path=FLAGS.checkpoint_path)
@@ -160,7 +163,6 @@ def main(_):
160163
tf.logging.set_verbosity(tf.logging.INFO)
161164
trainer_lib.set_random_seed(FLAGS.random_seed)
162165
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
163-
FLAGS.use_tpu = False # decoding not supported on TPU
164166

165167
if FLAGS.score_file:
166168
filename = os.path.expanduser(FLAGS.score_file)
@@ -183,7 +185,7 @@ def main(_):
183185
hp,
184186
t2t_trainer.create_run_config(hp),
185187
decode_hparams=decode_hp,
186-
use_tpu=False)
188+
use_tpu=FLAGS.use_tpu)
187189

188190
decode(estimator, hp, decode_hp)
189191

tensor2tensor/bin/t2t_distill.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
r"""Perform distillation for a teacher to student.
1716
1817
This script is intended to be used with --model=distillation. See the model for

tensor2tensor/bin/t2t_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
"""Train and evaluate."""
1716
from __future__ import absolute_import
1817
from __future__ import division
@@ -87,6 +86,8 @@
8786
"Name of Cloud TPU instance to use or create.")
8887
flags.DEFINE_bool("cloud_delete_on_done", False,
8988
"Whether to delete the VM and TPU instance when done.")
89+
flags.DEFINE_bool("cloud_skip_confirmation", False,
90+
"Whether to skip launch confirmations.")
9091

9192
# Google Cloud ML Engine
9293
flags.DEFINE_bool("cloud_mlengine", False,
@@ -319,7 +320,8 @@ def maybe_cloud_tpu():
319320
with cloud_tpu.cloud_tpu(
320321
FLAGS.cloud_vm_name,
321322
FLAGS.cloud_tpu_name,
322-
delete_on_done=FLAGS.cloud_delete_on_done) as tpu_master:
323+
delete_on_done=FLAGS.cloud_delete_on_done,
324+
skip_confirmation=FLAGS.cloud_skip_confirmation) as tpu_master:
323325
FLAGS.master = tpu_master
324326
yield
325327

0 commit comments

Comments
 (0)