@@ -257,6 +257,11 @@ def ref_depthwise_conv(x, w, padding, data_format):
257
257
return y
258
258
259
259
260
+ def ref_separable_conv (x , w1 , w2 , padding , data_format ):
261
+ x2 = ref_depthwise_conv (x , w1 , padding , data_format )
262
+ return ref_conv (x2 , w2 , padding , data_format )
263
+
264
+
260
265
def ref_rnn (x , w , init , go_backwards = False , mask = None , unroll = False , input_length = None ):
261
266
w_i , w_h , w_o = w
262
267
h = []
@@ -1086,27 +1091,30 @@ def legacy_test_conv3d(self):
1086
1091
BACKENDS , cntk_dynamicity = True ,
1087
1092
data_format = data_format )
1088
1093
1089
- def test_separable_conv2d (self ):
1090
- for (input_shape , data_format ) in [
1091
- ((2 , 3 , 4 , 5 ), 'channels_first' ),
1092
- ((2 , 3 , 5 , 6 ), 'channels_first' ),
1093
- ((1 , 6 , 5 , 3 ), 'channels_last' )]:
1094
- input_depth = input_shape [1 ] if data_format == 'channels_first' else input_shape [- 1 ]
1095
- _ , x_val = parse_shape_or_val (input_shape )
1096
- x_tf = KTF .variable (x_val )
1097
- for kernel_shape in [(2 , 2 ), (4 , 3 )]:
1098
- for depth_multiplier in [1 , 2 ]:
1099
- _ , depthwise_val = parse_shape_or_val (kernel_shape + (input_depth , depth_multiplier ))
1100
- _ , pointwise_val = parse_shape_or_val ((1 , 1 ) + (input_depth * depth_multiplier , 7 ))
1101
-
1102
- z_tf = KTF .eval (KTF .separable_conv2d (x_tf , KTF .variable (depthwise_val ),
1103
- KTF .variable (pointwise_val ),
1104
- data_format = data_format ))
1105
- z_c = cntk_func_three_tensor ('separable_conv2d' , input_shape ,
1106
- depthwise_val ,
1107
- pointwise_val ,
1108
- data_format = data_format )([x_val ])[0 ]
1109
- assert_allclose (z_tf , z_c , 1e-3 )
1094
+ @pytest .mark .skipif (K .backend () == 'theano' , reason = 'Not supported.' )
1095
+ @pytest .mark .parametrize ('op,input_shape,kernel_shape,depth_multiplier,padding,data_format' , [
1096
+ ('separable_conv2d' , (2 , 3 , 4 , 5 ), (3 , 3 ), 1 , 'same' , 'channels_first' ),
1097
+ ('separable_conv2d' , (2 , 3 , 5 , 6 ), (4 , 3 ), 2 , 'valid' , 'channels_first' ),
1098
+ ('separable_conv2d' , (1 , 6 , 5 , 3 ), (3 , 4 ), 1 , 'valid' , 'channels_last' ),
1099
+ ('separable_conv2d' , (1 , 7 , 6 , 3 ), (3 , 3 ), 2 , 'same' , 'channels_last' ),
1100
+ ])
1101
+ def test_separable_conv2d (self , op , input_shape , kernel_shape , depth_multiplier , padding , data_format ):
1102
+ input_depth = input_shape [1 ] if data_format == 'channels_first' else input_shape [- 1 ]
1103
+ _ , x = parse_shape_or_val (input_shape )
1104
+ _ , depthwise = parse_shape_or_val (kernel_shape + (input_depth , depth_multiplier ))
1105
+ _ , pointwise = parse_shape_or_val ((1 , 1 ) + (input_depth * depth_multiplier , 7 ))
1106
+ y1 = ref_separable_conv (x , depthwise , pointwise , padding , data_format )
1107
+ if K .backend () == 'cntk' :
1108
+ y2 = cntk_func_three_tensor (
1109
+ op , input_shape ,
1110
+ depthwise , pointwise ,
1111
+ padding = padding , data_format = data_format )([x ])[0 ]
1112
+ else :
1113
+ y2 = K .eval (getattr (K , op )(
1114
+ K .variable (x ),
1115
+ K .variable (depthwise ), K .variable (pointwise ),
1116
+ padding = padding , data_format = data_format ))
1117
+ assert_allclose (y1 , y2 , atol = 1e-05 )
1110
1118
1111
1119
def test_pool2d (self ):
1112
1120
check_single_tensor_operation ('pool2d' , (5 , 10 , 12 , 3 ),
0 commit comments