Skip to content

Commit c3a8519

Browse files
committed
Fix issue with concatenate, masking and symbolic inputs
We were trying to grab use a symbolic input shape as a fixed broadcast shape. Instead we need to capture the input as a input node who's shape should be used to broadcast at execution time on real input tensors.
1 parent ac5c97f commit c3a8519

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

keras/src/layers/merging/concatenate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,13 @@ def compute_mask(self, inputs, mask=None):
145145
# Input is unmasked. Append all 1s to masks,
146146
masks.append(ops.ones_like(input_i, dtype="bool"))
147147
elif mask_i.ndim < input_i.ndim:
148-
# Mask is smaller than the input, expand it
149-
masks.append(
150-
ops.broadcast_to(
151-
ops.expand_dims(mask_i, axis=-1), ops.shape(input_i)
152-
)
148+
# Broadcast mask shape to match in a way where we capture the
149+
# input as a symbolic input in the op graph.
150+
mask_i = ops.logical_or(
151+
ops.expand_dims(mask_i, axis=-1),
152+
ops.zeros_like(input_i, dtype="bool"),
153153
)
154+
masks.append(mask_i)
154155
else:
155156
masks.append(mask_i)
156157
concatenated = ops.concatenate(masks, axis=self.axis)

keras/src/layers/merging/merging_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,18 @@ def test_concatenate_with_mask(self):
340340
)
341341
self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]])
342342

343+
def test_concatenate_with_mask_symbolic(self):
344+
input1 = layers.Input((4, 2))
345+
input2 = layers.Input((4, 2))
346+
mask = layers.Masking()
347+
output = layers.Concatenate(axis=1)([mask(input1), input2])
348+
model = models.Model(
349+
inputs=[input1, input2], outputs=output._keras_mask
350+
)
351+
x1 = backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])
352+
x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])
353+
self.assertAllClose(model([x1, x2]), [[0, 1, 0, 1, 1, 1, 1, 1]])
354+
343355
def test_concatenate_errors(self):
344356
# This should work
345357
x1 = np.ones((1, 1, 1, 1, 5))

0 commit comments

Comments
 (0)