40
40
def create_session_config (log_device_placement = False ,
41
41
enable_graph_rewriter = False ,
42
42
gpu_mem_fraction = 0.95 ,
43
- use_tpu = False ):
43
+ use_tpu = False ,
44
+ inter_op_parallelism_threads = 0 ,
45
+ intra_op_parallelism_threads = 0 ):
44
46
"""The TensorFlow Session config to use."""
45
47
if use_tpu :
46
48
graph_options = tf .GraphOptions ()
@@ -60,7 +62,9 @@ def create_session_config(log_device_placement=False,
60
62
allow_soft_placement = True ,
61
63
graph_options = graph_options ,
62
64
gpu_options = gpu_options ,
63
- log_device_placement = log_device_placement )
65
+ log_device_placement = log_device_placement ,
66
+ inter_op_parallelism_threads = inter_op_parallelism_threads ,
67
+ intra_op_parallelism_threads = intra_op_parallelism_threads )
64
68
return config
65
69
66
70
@@ -108,13 +112,17 @@ def create_run_config(master="",
108
112
random_seed = None ,
109
113
sync = False ,
110
114
tpu_infeed_sleep_secs = None ,
111
- use_tpu = False ):
115
+ use_tpu = False ,
116
+ inter_op_parallelism_threads = 0 ,
117
+ intra_op_parallelism_threads = 0 ):
112
118
"""Create RunConfig, TPUConfig, and Parallelism object."""
113
119
session_config = create_session_config (
114
120
log_device_placement = log_device_placement ,
115
121
enable_graph_rewriter = enable_graph_rewriter ,
116
122
gpu_mem_fraction = gpu_mem_fraction ,
117
- use_tpu = use_tpu )
123
+ use_tpu = use_tpu ,
124
+ inter_op_parallelism_threads = inter_op_parallelism_threads ,
125
+ intra_op_parallelism_threads = intra_op_parallelism_threads )
118
126
run_config_args = {
119
127
"master" : master ,
120
128
"evaluation_master" : master ,
0 commit comments