Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b287c0e

Browse files
authored
Merge pull request #92 from kolloldas/fix_issue_80
Fix issue 80: Crash with `algorithmic_multiplication_binary40` on `neural_gpu`
2 parents 4d5c599 + be043f9 commit b287c0e

File tree

2 files changed

+8
-30
lines changed

2 files changed

+8
-30
lines changed

tensor2tensor/models/common_layers.py

Lines changed: 6 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
   static_shape = inputs.get_shape()
   if not static_shape or len(static_shape) != 4:
     raise ValueError("Inputs to conv must have statically known rank 4.")
-  inputs.set_shape([static_shape[0], None, None, static_shape[3]])
+  #inputs.set_shape([static_shape[0], None, None, static_shape[3]])
   # Add support for left padding.
   if "padding" in kwargs and kwargs["padding"] == "LEFT":
     dilation_rate = (1, 1)
@@ -307,9 +307,9 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
     width_padding = 0 if static_shape[2] == 1 else cond_padding
     padding = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
     inputs = tf.pad(inputs, padding)
+    # Set middle two dimensions to None to prevent convolution from complaining
+    inputs.set_shape([static_shape[0], None, None, static_shape[3]])
     kwargs["padding"] = "VALID"
-  # Special argument we use to force 2d kernels (see below).
-  force2d = kwargs.get("force2d", True)

   def conv2d_kernel(kernel_size_arg, name_suffix):
     """Call conv2d but add suffix to name."""
@@ -329,17 +329,7 @@ def conv2d_kernel(kernel_size_arg, name_suffix):
     kwargs["force2d"] = original_force2d
     return result

-  # Manually setting the shape to be unknown in the middle two dimensions so
-  # that the `tf.cond` below won't throw an error based on the convolution
-  # kernels being too large for the data.
-  inputs._shape = tf.TensorShape([static_shape[0], None, None, static_shape[3]])  # pylint: disable=protected-access
-  if kernel_size[1] == 1 or force2d:
-    # Avoiding the cond below can speed up graph and gradient construction.
-    return conv2d_kernel(kernel_size, "single")
-  return tf.cond(
-      tf.equal(tf.shape(inputs)[2],
-               1), lambda: conv2d_kernel((kernel_size[0], 1), "small"),
-      lambda: conv2d_kernel(kernel_size, "std"))
+  return conv2d_kernel(kernel_size, "single")


 def conv(inputs, filters, kernel_size, **kwargs):
@@ -566,20 +556,8 @@ def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)):
     inputs = tf.pad(inputs, padding_)
     inputs.set_shape([static_shape[0], None, None, static_shape[3]])
     padding = "VALID"
-  window_size_small = (window_size[0], 1)
-  strides_small = (strides[0], 1)
-  # Manually setting the shape to be unknown in the middle two dimensions so
-  # that the `tf.cond` below won't throw an error based on the convolution
-  # kernels being too large for the data.
-  inputs._shape = tf.TensorShape(  # pylint: disable=protected-access
-      [static_shape[0], None, None, static_shape[3]])
-  return tf.cond(
-      tf.equal(tf.shape(inputs)[2], 1),
-      lambda: tf.nn.pool(  # pylint: disable=g-long-lambda
-          inputs, window_size_small, pooling_type, padding,
-          strides=strides_small),
-      lambda: tf.nn.pool(  # pylint: disable=g-long-lambda
-          inputs, window_size, pooling_type, padding, strides=strides))
+
+  return tf.nn.pool(inputs, window_size, pooling_type, padding, strides=strides)


 def conv_block_downsample(x,

tensor2tensor/models/common_layers_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -277,13 +277,13 @@ def testShiftLeft(self):
     self.assertAllEqual(actual, expected)

   def testConvStride2MultiStep(self):
-    x1 = np.random.rand(5, 32, 1, 11)
+    x1 = np.random.rand(5, 32, 16, 11)
     with self.test_session() as session:
       a = common_layers.conv_stride2_multistep(
           tf.constant(x1, dtype=tf.float32), 4, 16)
       session.run(tf.global_variables_initializer())
       actual = session.run(a[0])
-      self.assertEqual(actual.shape, (5, 2, 0, 16))
+      self.assertEqual(actual.shape, (5, 2, 1, 16))

   def testDeconvStride2MultiStep(self):
     x1 = np.random.rand(5, 2, 1, 11)

0 commit comments

Comments (0)