Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c669cda

Browse files
authored
Merge pull request #725 from Intel-tensorflow/nhasabni/t2t-cpu-optimizations
Enabling setting Tensorflow inter_op and intra_op
2 parents 807cf8f + 0e91b45 commit c669cda

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

tensor2tensor/bin/t2t_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@
5757
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
5858
"Temporary storage directory, used if --generate_data.")
5959
flags.DEFINE_bool("profile", False, "Profile performance?")
60+
flags.DEFINE_integer("inter_op_parallelism_threads", 0, "Number of inter_op_parallelism_threads "
61+
"to use for CPU. See TensorFlow config.proto for details.")
62+
flags.DEFINE_integer("intra_op_parallelism_threads", 0, "Number of intra_op_parallelism_threads "
63+
"to use for CPU. See TensorFlow config.proto for details.")
6064

6165
# To maintain compatibility with some internal libs, we guard against these flag
6266
# definitions possibly erring. Apologies for the ugliness.
@@ -206,7 +210,9 @@ def create_run_config(hp):
206210
worker_id=FLAGS.worker_id,
207211
worker_job=FLAGS.worker_job,
208212
random_seed=FLAGS.random_seed,
209-
tpu_infeed_sleep_secs=FLAGS.tpu_infeed_sleep_secs)
213+
tpu_infeed_sleep_secs=FLAGS.tpu_infeed_sleep_secs,
214+
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
215+
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
210216

211217

212218
def generate_data():

tensor2tensor/utils/trainer_lib.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
def create_session_config(log_device_placement=False,
4141
enable_graph_rewriter=False,
4242
gpu_mem_fraction=0.95,
43-
use_tpu=False):
43+
use_tpu=False,
44+
inter_op_parallelism_threads=0,
45+
intra_op_parallelism_threads=0):
4446
"""The TensorFlow Session config to use."""
4547
if use_tpu:
4648
graph_options = tf.GraphOptions()
@@ -60,7 +62,9 @@ def create_session_config(log_device_placement=False,
6062
allow_soft_placement=True,
6163
graph_options=graph_options,
6264
gpu_options=gpu_options,
63-
log_device_placement=log_device_placement)
65+
log_device_placement=log_device_placement,
66+
inter_op_parallelism_threads=inter_op_parallelism_threads,
67+
intra_op_parallelism_threads=intra_op_parallelism_threads)
6468
return config
6569

6670

@@ -108,13 +112,17 @@ def create_run_config(master="",
108112
random_seed=None,
109113
sync=False,
110114
tpu_infeed_sleep_secs=None,
111-
use_tpu=False):
115+
use_tpu=False,
116+
inter_op_parallelism_threads=0,
117+
intra_op_parallelism_threads=0):
112118
"""Create RunConfig, TPUConfig, and Parallelism object."""
113119
session_config = create_session_config(
114120
log_device_placement=log_device_placement,
115121
enable_graph_rewriter=enable_graph_rewriter,
116122
gpu_mem_fraction=gpu_mem_fraction,
117-
use_tpu=use_tpu)
123+
use_tpu=use_tpu,
124+
inter_op_parallelism_threads=inter_op_parallelism_threads,
125+
intra_op_parallelism_threads=intra_op_parallelism_threads)
118126
run_config_args = {
119127
"master": master,
120128
"evaluation_master": master,

0 commit comments

Comments
 (0)