Skip to content

Commit 62c395e

Browse files
taehoonleefchollet
authored andcommitted
Make separable conv backend tests efficient (#9570)
1 parent 614a8b4 commit 62c395e

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

tests/keras/backend/backend_test.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ def ref_depthwise_conv(x, w, padding, data_format):
257257
return y
258258

259259

260+
def ref_separable_conv(x, w1, w2, padding, data_format):
261+
x2 = ref_depthwise_conv(x, w1, padding, data_format)
262+
return ref_conv(x2, w2, padding, data_format)
263+
264+
260265
def ref_rnn(x, w, init, go_backwards=False, mask=None, unroll=False, input_length=None):
261266
w_i, w_h, w_o = w
262267
h = []
@@ -1086,27 +1091,30 @@ def legacy_test_conv3d(self):
10861091
BACKENDS, cntk_dynamicity=True,
10871092
data_format=data_format)
10881093

1089-
def test_separable_conv2d(self):
1090-
for (input_shape, data_format) in [
1091-
((2, 3, 4, 5), 'channels_first'),
1092-
((2, 3, 5, 6), 'channels_first'),
1093-
((1, 6, 5, 3), 'channels_last')]:
1094-
input_depth = input_shape[1] if data_format == 'channels_first' else input_shape[-1]
1095-
_, x_val = parse_shape_or_val(input_shape)
1096-
x_tf = KTF.variable(x_val)
1097-
for kernel_shape in [(2, 2), (4, 3)]:
1098-
for depth_multiplier in [1, 2]:
1099-
_, depthwise_val = parse_shape_or_val(kernel_shape + (input_depth, depth_multiplier))
1100-
_, pointwise_val = parse_shape_or_val((1, 1) + (input_depth * depth_multiplier, 7))
1101-
1102-
z_tf = KTF.eval(KTF.separable_conv2d(x_tf, KTF.variable(depthwise_val),
1103-
KTF.variable(pointwise_val),
1104-
data_format=data_format))
1105-
z_c = cntk_func_three_tensor('separable_conv2d', input_shape,
1106-
depthwise_val,
1107-
pointwise_val,
1108-
data_format=data_format)([x_val])[0]
1109-
assert_allclose(z_tf, z_c, 1e-3)
1094+
@pytest.mark.skipif(K.backend() == 'theano', reason='Not supported.')
1095+
@pytest.mark.parametrize('op,input_shape,kernel_shape,depth_multiplier,padding,data_format', [
1096+
('separable_conv2d', (2, 3, 4, 5), (3, 3), 1, 'same', 'channels_first'),
1097+
('separable_conv2d', (2, 3, 5, 6), (4, 3), 2, 'valid', 'channels_first'),
1098+
('separable_conv2d', (1, 6, 5, 3), (3, 4), 1, 'valid', 'channels_last'),
1099+
('separable_conv2d', (1, 7, 6, 3), (3, 3), 2, 'same', 'channels_last'),
1100+
])
1101+
def test_separable_conv2d(self, op, input_shape, kernel_shape, depth_multiplier, padding, data_format):
1102+
input_depth = input_shape[1] if data_format == 'channels_first' else input_shape[-1]
1103+
_, x = parse_shape_or_val(input_shape)
1104+
_, depthwise = parse_shape_or_val(kernel_shape + (input_depth, depth_multiplier))
1105+
_, pointwise = parse_shape_or_val((1, 1) + (input_depth * depth_multiplier, 7))
1106+
y1 = ref_separable_conv(x, depthwise, pointwise, padding, data_format)
1107+
if K.backend() == 'cntk':
1108+
y2 = cntk_func_three_tensor(
1109+
op, input_shape,
1110+
depthwise, pointwise,
1111+
padding=padding, data_format=data_format)([x])[0]
1112+
else:
1113+
y2 = K.eval(getattr(K, op)(
1114+
K.variable(x),
1115+
K.variable(depthwise), K.variable(pointwise),
1116+
padding=padding, data_format=data_format))
1117+
assert_allclose(y1, y2, atol=1e-05)
11101118

11111119
def test_pool2d(self):
11121120
check_single_tensor_operation('pool2d', (5, 10, 12, 3),

0 commit comments

Comments
 (0)