diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index c1f4e8066e37..be95bfdead36 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2742,6 +2742,45 @@ def round(x, decimals=0): def tile(x, repeats): x = convert_to_tensor(x) + + # Check if repeats contains only concrete integers + # If so, keep it as a Python list/tuple for better shape inference + try: + if isinstance(repeats, (list, tuple)): + # Try to extract concrete integer values + concrete_repeats = [] + for r in repeats: + if isinstance(r, int): + concrete_repeats.append(r) + elif hasattr(r, "numpy") and r.shape == (): + # Scalar tensor with concrete value + concrete_repeats.append(int(r.numpy())) + else: + # Not a concrete value, fall back to tensor path + concrete_repeats = None + break + + if concrete_repeats is not None: + # Use concrete repeats directly for better shape inference + repeats = concrete_repeats + # Pad or trim repeats to match x rank + x_rank = x.shape.rank + if x_rank is not None: + if len(repeats) < x_rank: + repeats = [1] * (x_rank - len(repeats)) + repeats + elif len(repeats) > x_rank: + # Need to reshape x to match repeats length + x_shape_list = [1] * (len(repeats) - x_rank) + [ + d if d is not None else -1 + for d in x.shape.as_list() + ] + x = tf.reshape(x, x_shape_list) + return tf.tile(x, repeats) + except (AttributeError, TypeError, ValueError): + # If anything goes wrong, fall back to original implementation + pass + + # Original dynamic implementation for non-concrete repeats repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1]) repeats_size = tf.size(repeats) repeats = tf.pad( diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 5190ff2cd807..214ebfe194f0 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -6411,6 +6411,15 @@ def compute_output_spec(self, x): repeats = self.repeats if isinstance(repeats, int): repeats = [repeats] + + # Convert repeats to list if it's a tuple or other iterable + # and extract concrete integer values + if not isinstance(repeats, list): + try: + repeats = list(repeats) + except TypeError: + repeats = [repeats] + if len(x_shape) > len(repeats): repeats = [1] * (len(x_shape) - len(repeats)) + repeats else: @@ -6418,10 +6427,15 @@ def compute_output_spec(self, x): output_shape = [] for x_size, repeat in zip(x_shape, repeats): + # Check if repeat is a concrete integer value + # If it's a symbolic tensor or unknown, we can't infer the size if x_size is None: output_shape.append(None) - else: + elif isinstance(repeat, int): output_shape.append(x_size * repeat) + else: + # repeat is symbolic (e.g., KerasTensor, tf.Tensor, etc.) + output_shape.append(None) return KerasTensor(output_shape, dtype=x.dtype) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 42a8c37b49e3..6a452ca4f46d 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1821,6 +1821,10 @@ def test_tile(self): self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6)) self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6)) + # Test with multi-dimensional input + x = KerasTensor((None, 3, 2, 2)) + self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2)) + def test_trace(self): x = KerasTensor((None, 3, None, 5)) self.assertEqual(knp.trace(x).shape, (None, 5)) @@ -9507,3 +9511,24 @@ def call(self, x): model.compile(jit_compile=jit_compile) model.predict(np.random.randn(1, 8)) + + def test_tile_shape_inference_in_layer(self): + """Test that ops.tile properly infers output shape in a Layer. + + This is a regression test for issue #20914 where TensorFlow backend + would return all-None shapes when tile was called inside a Layer's + call method with concrete integer repeats. + """ + + class TileLayer(keras.layers.Layer): + def call(self, x): + # Use concrete integer repeats + repeats = [1, 2, 1, 1] + return knp.tile(x, repeats) + + inputs = keras.Input(shape=(3, 2, 2)) + output = TileLayer()(inputs) + + # With the fix, output shape should be (None, 6, 2, 2) + # Before the fix, it was (None, None, None, None) + self.assertEqual(output.shape, (None, 6, 2, 2))