
Conversation

mattdangerw
Member

We were trying to use a symbolic input shape as a fixed broadcast shape. Instead we need to capture the input as an input node whose shape should be used for broadcasting at execution time on real input tensors.

Fixes #21581

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @mattdangerw, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue within Keras's Concatenate layer concerning how masking interacts with symbolic input shapes. Previously, the system incorrectly used symbolic input shapes as fixed broadcast shapes. The fix ensures that the input is captured as an input node, allowing its shape to be used for broadcasting at execution time with real input tensors, thereby resolving an issue where masking was not correctly applied with symbolic inputs.

Highlights

  • Masking with Symbolic Inputs: The compute_mask method in Concatenate layers has been updated to correctly handle broadcasting of masks when the mask's dimensionality is less than the input's. The previous broadcast_to approach was replaced with a sequence of operations (expand_dims, cast, zeros_like, cast) to ensure the input is captured as a symbolic input in the operation graph, enabling proper broadcasting at execution time.
  • New Test Case for Symbolic Masking: A new test, test_concatenate_with_mask_symbolic, has been added to merging_test.py. This test specifically validates the fix by checking the behavior of Concatenate with Masking when symbolic inputs are used, ensuring the mask is correctly propagated.
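The broadcasting pattern described in the first highlight can be illustrated with a minimal NumPy sketch (NumPy stands in for keras.ops here; the function name and exact sequence are illustrative, not the actual Keras source). The key idea: instead of broadcast_to(mask, shape(input)), which needs a concrete shape at trace time, the mask is OR-ed against an all-False tensor shaped like the input, so the input itself becomes a node in the graph and the broadcast resolves against its real shape at execution time.

```python
import numpy as np

def broadcast_mask(mask, inputs):
    """Sketch of the fix's broadcasting idea (NumPy stand-in for
    keras.ops; name and details are illustrative, not Keras code)."""
    # Give the (batch, time) mask a trailing axis: (batch, time, 1).
    mask = np.expand_dims(mask, axis=-1)
    # An all-False tensor shaped like the real input. In the Keras op
    # graph this step captures `inputs` as an op input, so its shape
    # is only needed at execution time, never as a static argument.
    falses = np.zeros_like(inputs, dtype=bool)
    # logical_or broadcasts the mask against the input's full shape.
    return np.logical_or(mask, falses)

inputs = np.ones((2, 3, 4))
mask = np.array([[True, True, False],
                 [True, False, False]])
out = broadcast_mask(mask, inputs)
print(out.shape)  # (2, 3, 4)
```

In eager NumPy this is equivalent to broadcast_to, but in a symbolic graph the logical_or formulation keeps the input tensor, rather than its (possibly None-containing) shape, as the op's argument.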

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses an issue with mask computation in the Concatenate layer when dealing with symbolic inputs. The approach of using zeros_like to facilitate broadcasting in a graph-compatible way is a solid fix. The addition of a targeted test case for symbolic inputs is also a great way to prevent regressions. I have one suggestion to refactor the implementation for better readability and to use more idiomatic boolean operations.


codecov-commenter commented Aug 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.54%. Comparing base (ac5c97f) to head (c3a8519).
⚠️ Report is 22 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21611      +/-   ##
==========================================
- Coverage   82.71%   82.54%   -0.18%     
==========================================
  Files         568      571       +3     
  Lines       56897    57585     +688     
  Branches     8890     8994     +104     
==========================================
+ Hits        47063    47533     +470     
- Misses       7640     7759     +119     
- Partials     2194     2293      +99     
Flag               Coverage             Δ
keras              82.34% <100.00%>     (-0.18%) ⬇️
keras-jax          63.55% <100.00%>     (-0.11%) ⬇️
keras-numpy        57.89% <0.00%>       (-0.37%) ⬇️
keras-openvino     34.34% <0.00%>       (-0.21%) ⬇️
keras-tensorflow   64.25% <100.00%>     (+0.04%) ⬆️
keras-torch        63.77% <100.00%>     (-0.04%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@mattdangerw force-pushed the masking-fix branch 2 times, most recently from d0a8021 to ca1738b on August 25, 2025 at 17:07
)
# Broadcast mask shape to match in a way where we capture the
# input as a symbolic input in the op graph.
mask_i = ops.logical_or(
A Collaborator left a comment:

So you're doing an ops.logical_or just to do the broadcast.

Do you understand why this works and ops.broadcast_to doesn't?

In other words, should we try to fix ops.broadcast_to? Is it a bug in tf.broadcast_to?

@tobiasharren

I've looked into this a bit, and it seems there is a potential failure mode in BroadcastTo with symbolic shapes, as identified by @mattdangerw. I was able to create a fix that prevents this in the non-compiled example from #21581.
In the compiled variant, the fix works for the JAX backend (at least on CPU) but still fails on TensorFlow.

I will try to provide the commit.

@tobiasharren

I have provided this as a pull request to the fork of @mattdangerw
See: mattdangerw#2

@mattdangerw
Member Author

@hertschuh @tobiasharren So I am not sure the other fix will work.

The issue isn't really a bug in broadcast_to; it's inherent in the design of the op and the fact that Keras doesn't have a notion of a symbolic shape. The previous code was calling broadcast_to(first_input, ops.shape(second_input)). The problem is that when the second input is a symbolic input, ops.shape returns a shape containing None, and there is no way to fix that at the broadcast_to level: the op receives a None shape as an input argument that cannot be resolved to a concrete shape.

I think we should land this fix for now. Longer term there are two possibilities:

  • Write a broadcast_to_tensor(tensor1, tensor2) op, so the symbolic input is captured as an op input and its shape can be evaluated at execution time.
  • Add some notion of symbolic shape when ops.shape is called on a KerasTensor. This would be tricky and need some thought; it might not even be feasible.
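The first option could look like the following eager NumPy sketch. Everything here is hypothetical: broadcast_to_tensor is not an existing Keras op, and the symbolic-graph capture machinery that would make it useful is elided; the sketch only shows the proposed API shape, where the reference tensor itself (not a static shape tuple) is the op's argument.

```python
import numpy as np

def broadcast_to_tensor(x, ref):
    # Hypothetical op, sketched eagerly in NumPy: broadcast `x` to
    # the shape of `ref`. Because `ref` is passed as a tensor rather
    # than a static shape tuple, a symbolic version of this op could
    # defer shape resolution until `ref` is concrete at execution
    # time, avoiding the None-shape problem described above.
    return np.broadcast_to(x, ref.shape)

x = np.array([[True], [False]])  # e.g. a mask of shape (2, 1)
ref = np.zeros((2, 3))           # shape known only at execution time
result = broadcast_to_tensor(x, ref)
print(result.shape)  # (2, 3)
```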

@hertschuh (Collaborator) left a comment

Thanks for the fix!

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Sep 4, 2025
@tobiasharren

Alright, thanks @mattdangerw for looking into this and creating the fix!

@fchollet merged commit d936dc9 into keras-team:master Sep 7, 2025
10 of 11 checks passed
@google-ml-butler bot removed the awaiting review and ready to pull labels Sep 7, 2025

Successfully merging this pull request may close these issues.

BatchNormalization fails after Concatenation of masked Embeddings