61
61
_LOCAL_DEVICES = None
62
62
63
63
64
+ def _get_default_graph ():
65
+ try :
66
+ return tf .get_default_graph ()
67
+ except AttributeError :
68
+ raise RuntimeError (
69
+ 'It looks like you are trying to use '
70
+ 'a version of multi-backend Keras that '
71
+ 'does not support TensorFlow 2.0. We recommend '
72
+ 'using `tf.keras`, or alternatively, '
73
+ 'downgrading to TensorFlow 1.14.' )
74
+
75
+
64
76
def get_uid (prefix = '' ):
65
77
"""Get the uid for the default graph.
66
78
@@ -71,7 +83,7 @@ def get_uid(prefix=''):
71
83
A unique identifier for the graph.
72
84
"""
73
85
global _GRAPH_UID_DICTS
74
- graph = tf . get_default_graph ()
86
+ graph = _get_default_graph ()
75
87
if graph not in _GRAPH_UID_DICTS :
76
88
_GRAPH_UID_DICTS [graph ] = defaultdict (int )
77
89
_GRAPH_UID_DICTS [graph ][prefix ] += 1
@@ -101,7 +113,7 @@ def clear_session():
101
113
shape = (),
102
114
name = 'keras_learning_phase' )
103
115
_GRAPH_LEARNING_PHASES = {}
104
- _GRAPH_LEARNING_PHASES [tf . get_default_graph ()] = phase
116
+ _GRAPH_LEARNING_PHASES [_get_default_graph ()] = phase
105
117
106
118
107
119
def manual_variable_initialization (value ):
@@ -130,7 +142,7 @@ def learning_phase():
130
142
# Returns
131
143
Learning phase (scalar integer tensor or Python integer).
132
144
"""
133
- graph = tf . get_default_graph ()
145
+ graph = _get_default_graph ()
134
146
if graph not in _GRAPH_LEARNING_PHASES :
135
147
with tf .name_scope ('' ):
136
148
phase = tf .placeholder_with_default (
@@ -154,7 +166,7 @@ def set_learning_phase(value):
154
166
if value not in {0 , 1 }:
155
167
raise ValueError ('Expected learning phase to be '
156
168
'0 or 1.' )
157
- _GRAPH_LEARNING_PHASES [tf . get_default_graph ()] = value
169
+ _GRAPH_LEARNING_PHASES [_get_default_graph ()] = value
158
170
159
171
160
172
def get_session ():
@@ -247,7 +259,7 @@ def _get_current_tf_device():
247
259
the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
248
260
return `None`.
249
261
"""
250
- g = tf . get_default_graph ()
262
+ g = _get_default_graph ()
251
263
op = _TfDeviceCaptureOp ()
252
264
g ._apply_device_functions (op )
253
265
return op .device
0 commit comments