@@ -293,7 +293,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
293
293
static_shape = inputs .get_shape ()
294
294
if not static_shape or len (static_shape ) != 4 :
295
295
raise ValueError ("Inputs to conv must have statically known rank 4." )
296
- inputs .set_shape ([static_shape [0 ], None , None , static_shape [3 ]])
296
+ # inputs.set_shape([static_shape[0], None, None, static_shape[3]])
297
297
# Add support for left padding.
298
298
if "padding" in kwargs and kwargs ["padding" ] == "LEFT" :
299
299
dilation_rate = (1 , 1 )
@@ -307,9 +307,9 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
307
307
width_padding = 0 if static_shape [2 ] == 1 else cond_padding
308
308
padding = [[0 , 0 ], [height_padding , 0 ], [width_padding , 0 ], [0 , 0 ]]
309
309
inputs = tf .pad (inputs , padding )
310
+ # Set middle two dimensions to None to prevent convolution from complaining
311
+ inputs .set_shape ([static_shape [0 ], None , None , static_shape [3 ]])
310
312
kwargs ["padding" ] = "VALID"
311
- # Special argument we use to force 2d kernels (see below).
312
- force2d = kwargs .get ("force2d" , True )
313
313
314
314
def conv2d_kernel (kernel_size_arg , name_suffix ):
315
315
"""Call conv2d but add suffix to name."""
@@ -329,17 +329,7 @@ def conv2d_kernel(kernel_size_arg, name_suffix):
329
329
kwargs ["force2d" ] = original_force2d
330
330
return result
331
331
332
- # Manually setting the shape to be unknown in the middle two dimensions so
333
- # that the `tf.cond` below won't throw an error based on the convolution
334
- # kernels being too large for the data.
335
- inputs ._shape = tf .TensorShape ([static_shape [0 ], None , None , static_shape [3 ]]) # pylint: disable=protected-access
336
- if kernel_size [1 ] == 1 or force2d :
337
- # Avoiding the cond below can speed up graph and gradient construction.
338
- return conv2d_kernel (kernel_size , "single" )
339
- return tf .cond (
340
- tf .equal (tf .shape (inputs )[2 ],
341
- 1 ), lambda : conv2d_kernel ((kernel_size [0 ], 1 ), "small" ),
342
- lambda : conv2d_kernel (kernel_size , "std" ))
332
+ return conv2d_kernel (kernel_size , "single" )
343
333
344
334
345
335
def conv (inputs , filters , kernel_size , ** kwargs ):
@@ -566,20 +556,8 @@ def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)):
566
556
inputs = tf .pad (inputs , padding_ )
567
557
inputs .set_shape ([static_shape [0 ], None , None , static_shape [3 ]])
568
558
padding = "VALID"
569
- window_size_small = (window_size [0 ], 1 )
570
- strides_small = (strides [0 ], 1 )
571
- # Manually setting the shape to be unknown in the middle two dimensions so
572
- # that the `tf.cond` below won't throw an error based on the convolution
573
- # kernels being too large for the data.
574
- inputs ._shape = tf .TensorShape ( # pylint: disable=protected-access
575
- [static_shape [0 ], None , None , static_shape [3 ]])
576
- return tf .cond (
577
- tf .equal (tf .shape (inputs )[2 ], 1 ),
578
- lambda : tf .nn .pool ( # pylint: disable=g-long-lambda
579
- inputs , window_size_small , pooling_type , padding ,
580
- strides = strides_small ),
581
- lambda : tf .nn .pool ( # pylint: disable=g-long-lambda
582
- inputs , window_size , pooling_type , padding , strides = strides ))
559
+
560
+ return tf .nn .pool (inputs , window_size , pooling_type , padding , strides = strides )
583
561
584
562
585
563
def conv_block_downsample (x ,
0 commit comments