14
14
# ==============================================================================
15
15
"""Tests for Keras metrics."""
16
16
17
- import tensorflow .compat .v2 as tf
18
-
19
17
from absl .testing import parameterized
20
18
from keras import metrics
21
19
from keras .engine import base_layer
20
+ import tensorflow .compat .v2 as tf
21
+
22
+ combinations = tf .__internal__ .distribute .combinations
22
23
23
24
24
25
def _labeled_dataset_fn ():
@@ -67,18 +68,18 @@ def _regression_dataset_fn():
67
68
def all_combinations():
  """Returns graph-mode test combinations over the common non-TPU strategies.

  Covers the default strategy, the one-device strategy, and the mirrored
  strategy with (GPU+CPU) and (two-GPU) device sets.
  """
  strategies = [
      combinations.default_strategy,
      combinations.one_device_strategy,
      combinations.mirrored_strategy_with_gpu_and_cpu,
      combinations.mirrored_strategy_with_two_gpus,
  ]
  return tf.__internal__.test.combinations.combine(
      distribution=strategies, mode=["graph"])
76
76
77
77
78
78
def tpu_combinations():
  """Returns graph-mode test combinations for the TPU strategy only."""
  tpu_strategies = [combinations.tpu_strategy]
  return tf.__internal__.test.combinations.combine(
      distribution=tpu_strategies,
      mode=["graph"])
82
83
83
84
84
85
class KerasMetricsTest (tf .test .TestCase , parameterized .TestCase ):
@@ -106,7 +107,7 @@ def _test_metric(self, distribution, dataset_fn, metric_init_fn, expected_fn):
106
107
if batches_consumed >= 4 : # Consume 4 input batches in total.
107
108
break
108
109
109
- @tf . __internal__ . distribute . combinations .generate (all_combinations () + tpu_combinations ())
110
+ @combinations .generate (all_combinations () + tpu_combinations ())
110
111
def testMean (self , distribution ):
111
112
def _dataset_fn ():
112
113
return tf .data .Dataset .range (1000 ).map (tf .compat .v1 .to_float ).batch (
@@ -118,20 +119,21 @@ def _expected_fn(num_batches):
118
119
119
120
self ._test_metric (distribution , _dataset_fn , metrics .Mean , _expected_fn )
120
121
121
- @tf . __internal__ . distribute . combinations .generate (
122
+ @combinations .generate (
122
123
tf .__internal__ .test .combinations .combine (
123
124
distribution = [
124
- tf .__internal__ .distribute .combinations .mirrored_strategy_with_one_cpu ,
125
- tf .__internal__ .distribute .combinations .mirrored_strategy_with_gpu_and_cpu ,
126
- tf .__internal__ .distribute .combinations .mirrored_strategy_with_two_gpus ,
127
- tf .__internal__ .distribute .combinations .tpu_strategy_packed_var
125
+ combinations .mirrored_strategy_with_one_cpu ,
126
+ combinations .mirrored_strategy_with_gpu_and_cpu ,
127
+ combinations .mirrored_strategy_with_two_gpus ,
128
+ combinations .tpu_strategy_packed_var ,
129
+ combinations .parameter_server_strategy_1worker_2ps_cpu ,
130
+ combinations .parameter_server_strategy_1worker_2ps_1gpu ,
128
131
],
129
132
mode = ["eager" ],
130
- jit_compile = [False ]) +
131
- tf .__internal__ .test .combinations .combine (
132
- distribution = [tf .__internal__ .distribute .combinations .mirrored_strategy_with_two_gpus ],
133
- mode = ["eager" ],
134
- jit_compile = [True ]))
133
+ jit_compile = [False ]) + tf .__internal__ .test .combinations .combine (
134
+ distribution = [combinations .mirrored_strategy_with_two_gpus ],
135
+ mode = ["eager" ],
136
+ jit_compile = [True ]))
135
137
def testAddMetric (self , distribution , jit_compile ):
136
138
if not tf .__internal__ .tf2 .enabled ():
137
139
self .skipTest ("Skip test since tf2 is not enabled. Pass "
@@ -164,7 +166,13 @@ def func():
164
166
def run ():
165
167
return distribution .run (func )
166
168
167
- run ()
169
+ if distribution ._should_use_with_coordinator :
170
+ coord = tf .distribute .experimental .coordinator .ClusterCoordinator (
171
+ distribution )
172
+ coord .schedule (run )
173
+ coord .join ()
174
+ else :
175
+ run ()
168
176
169
177
self .assertEqual (layer .metrics [0 ].result ().numpy (),
170
178
1.0 * distribution .num_replicas_in_sync )
0 commit comments