From c3a85192bb3c350ce9660e43f4c477c94871d59c Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Sat, 23 Aug 2025 23:03:42 +0000 Subject: [PATCH] Fix issue with concatenate, masking and symbolic inputs We were trying to use a symbolic input shape as a fixed broadcast shape. Instead, we need to capture the input as an input node whose shape should be used to broadcast at execution time on real input tensors. --- keras/src/layers/merging/concatenate.py | 11 ++++++----- keras/src/layers/merging/merging_test.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/keras/src/layers/merging/concatenate.py b/keras/src/layers/merging/concatenate.py index b19f3c0e6e4d..1ee3913b6581 100644 --- a/keras/src/layers/merging/concatenate.py +++ b/keras/src/layers/merging/concatenate.py @@ -145,12 +145,13 @@ def compute_mask(self, inputs, mask=None): # Input is unmasked. Append all 1s to masks, masks.append(ops.ones_like(input_i, dtype="bool")) elif mask_i.ndim < input_i.ndim: - # Mask is smaller than the input, expand it - masks.append( - ops.broadcast_to( - ops.expand_dims(mask_i, axis=-1), ops.shape(input_i) - ) + # Broadcast mask shape to match in a way where we capture the + # input as a symbolic input in the op graph. 
+ mask_i = ops.logical_or( + ops.expand_dims(mask_i, axis=-1), + ops.zeros_like(input_i, dtype="bool"), ) + masks.append(mask_i) else: masks.append(mask_i) concatenated = ops.concatenate(masks, axis=self.axis) diff --git a/keras/src/layers/merging/merging_test.py b/keras/src/layers/merging/merging_test.py index 0008ffd7af86..977ad9c2cc1d 100644 --- a/keras/src/layers/merging/merging_test.py +++ b/keras/src/layers/merging/merging_test.py @@ -340,6 +340,18 @@ def test_concatenate_with_mask(self): ) self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]]) + def test_concatenate_with_mask_symbolic(self): + input1 = layers.Input((4, 2)) + input2 = layers.Input((4, 2)) + mask = layers.Masking() + output = layers.Concatenate(axis=1)([mask(input1), input2]) + model = models.Model( + inputs=[input1, input2], outputs=output._keras_mask + ) + x1 = backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + self.assertAllClose(model([x1, x2]), [[0, 1, 0, 1, 1, 1, 1, 1]]) + def test_concatenate_errors(self): # This should work x1 = np.ones((1, 1, 1, 1, 5))