
Commit 71d6d6e

ckkuang authored and tensorflower-gardener committed

Add ParameterServerStrategy combination to keras_metrics_test.
PiperOrigin-RevId: 381424814
1 parent 3fef2b1 commit 71d6d6e

File tree

1 file changed: 28 additions & 20 deletions


keras/distribute/keras_metrics_test.py

Lines changed: 28 additions & 20 deletions
@@ -14,11 +14,12 @@
 # ==============================================================================
 """Tests for Keras metrics."""
 
-import tensorflow.compat.v2 as tf
-
 from absl.testing import parameterized
 from keras import metrics
 from keras.engine import base_layer
+import tensorflow.compat.v2 as tf
+
+combinations = tf.__internal__.distribute.combinations
 
 
 def _labeled_dataset_fn():
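
Note on the hunk above: hoisting tf.__internal__.distribute.combinations into a module-level combinations alias is what lets the decorators later in this file fit on one line. As a minimal sketch of how the alias is used (the test class and metric math below are illustrative, not part of this commit):

import tensorflow.compat.v2 as tf
from absl.testing import parameterized
from keras import metrics

combinations = tf.__internal__.distribute.combinations


class AliasUsageSketch(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(
      tf.__internal__.test.combinations.combine(
          distribution=[combinations.default_strategy], mode=["eager"]))
  def testSumSketch(self, distribution):
    # Each generated case receives one strategy; variables created under its
    # scope are placed according to that strategy.
    with distribution.scope():
      m = metrics.Sum()
    m.update_state([1.0, 2.0])
    self.assertEqual(m.result().numpy(), 3.0)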
@@ -67,18 +68,18 @@ def _regression_dataset_fn():
 def all_combinations():
   return tf.__internal__.test.combinations.combine(
       distribution=[
-          tf.__internal__.distribute.combinations.default_strategy,
-          tf.__internal__.distribute.combinations.one_device_strategy,
-          tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,
-          tf.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus
+          combinations.default_strategy, combinations.one_device_strategy,
+          combinations.mirrored_strategy_with_gpu_and_cpu,
+          combinations.mirrored_strategy_with_two_gpus
       ],
       mode=["graph"])
 
 
 def tpu_combinations():
   return tf.__internal__.test.combinations.combine(
-      distribution=[tf.__internal__.distribute.combinations.tpu_strategy,],
-      mode=["graph"])
+      distribution=[
+          combinations.tpu_strategy,
+      ], mode=["graph"])
 
 
 class KerasMetricsTest(tf.test.TestCase, parameterized.TestCase):
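
The refactor above works because tf.__internal__.test.combinations.combine(...) returns what is effectively a list of parameter dictionaries, so combination suites compose with ordinary +, which the decorator in the next hunk relies on. A hedged sketch of that composition (variable names are illustrative):

graph_cases = tf.__internal__.test.combinations.combine(
    distribution=[combinations.default_strategy], mode=["graph"])
tpu_cases = tf.__internal__.test.combinations.combine(
    distribution=[combinations.tpu_strategy], mode=["graph"])

# Plain list concatenation: a test decorated with the result is generated
# once per entry in the combined list.
all_cases = graph_cases + tpu_cases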
@@ -106,7 +107,7 @@ def _test_metric(self, distribution, dataset_fn, metric_init_fn, expected_fn):
       if batches_consumed >= 4:  # Consume 4 input batches in total.
         break
 
-  @tf.__internal__.distribute.combinations.generate(all_combinations() + tpu_combinations())
+  @combinations.generate(all_combinations() + tpu_combinations())
   def testMean(self, distribution):
     def _dataset_fn():
       return tf.data.Dataset.range(1000).map(tf.compat.v1.to_float).batch(
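
For context, the _test_metric helper (only partially visible in this hunk) follows the usual distributed-metric test pattern: create the metric under the strategy's scope, update it once per input batch, then compare the result against an expected value. A simplified eager-mode sketch of that pattern, assuming the same arguments as the real helper (which runs these steps in graph mode with distributed iterators):

def _test_metric_sketch(distribution, dataset_fn, metric_init_fn, expected_fn):
  with distribution.scope():
    metric = metric_init_fn()  # e.g. metrics.Mean
  batches_consumed = 0
  for batch in dataset_fn():
    metric.update_state(batch)
    batches_consumed += 1
    if batches_consumed >= 4:  # Consume 4 input batches in total.
      break
  # The metric's result should match the expectation for the batches seen.
  return metric.result(), expected_fn(batches_consumed)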
@@ -118,20 +119,21 @@ def _expected_fn(num_batches):
 
     self._test_metric(distribution, _dataset_fn, metrics.Mean, _expected_fn)
 
-  @tf.__internal__.distribute.combinations.generate(
+  @combinations.generate(
       tf.__internal__.test.combinations.combine(
           distribution=[
-              tf.__internal__.distribute.combinations.mirrored_strategy_with_one_cpu,
-              tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,
-              tf.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus,
-              tf.__internal__.distribute.combinations.tpu_strategy_packed_var
+              combinations.mirrored_strategy_with_one_cpu,
+              combinations.mirrored_strategy_with_gpu_and_cpu,
+              combinations.mirrored_strategy_with_two_gpus,
+              combinations.tpu_strategy_packed_var,
+              combinations.parameter_server_strategy_1worker_2ps_cpu,
+              combinations.parameter_server_strategy_1worker_2ps_1gpu,
           ],
           mode=["eager"],
-          jit_compile=[False]) +
-      tf.__internal__.test.combinations.combine(
-          distribution=[tf.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus],
-          mode=["eager"],
-          jit_compile=[True]))
+          jit_compile=[False]) + tf.__internal__.test.combinations.combine(
+              distribution=[combinations.mirrored_strategy_with_two_gpus],
+              mode=["eager"],
+              jit_compile=[True]))
   def testAddMetric(self, distribution, jit_compile):
     if not tf.__internal__.tf2.enabled():
       self.skipTest("Skip test since tf2 is not enabled. Pass "
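
The two parameter_server_strategy_1worker_2ps_* entries added above are the point of this commit: they run testAddMetric under ParameterServerStrategy with one worker and two parameter servers, once CPU-only and once with a single GPU. In public-API terms, the fixture provides the test with roughly the following strategy (the cluster setup below is an assumption for illustration; the real combinations build an in-process cluster behind the scenes):

# Hypothetical setup: TF_CONFIG (or an equivalent resolver) describes one
# "worker" task and two "ps" tasks.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)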
@@ -164,7 +166,13 @@ def func():
     def run():
       return distribution.run(func)
 
-    run()
+    if distribution._should_use_with_coordinator:
+      coord = tf.distribute.experimental.coordinator.ClusterCoordinator(
+          distribution)
+      coord.schedule(run)
+      coord.join()
+    else:
+      run()
 
     self.assertEqual(layer.metrics[0].result().numpy(),
                      1.0 * distribution.num_replicas_in_sync)
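
Parameter server training is coordinator-driven: strategy.run is not invoked inline on the test process; instead, functions are scheduled onto workers through a ClusterCoordinator, which is what the new branch above does (keyed off the private _should_use_with_coordinator flag). A hedged public-API sketch of that dispatch, assuming a strategy built as in the previous note:

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def step():
  def replica_fn():
    return tf.constant(1.0)
  return strategy.run(replica_fn)

# schedule() enqueues the tf.function and returns a RemoteValue immediately;
# join() blocks until every scheduled function has executed on a worker.
remote_value = coordinator.schedule(step)
coordinator.join()
print(remote_value.fetch())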
