Skip to content

Conversation

adityagupta1089
Copy link
Contributor

… spec handling and statistics aggregation

Copy link

Summary of Changes

Hello @adityagupta1089, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the DistributedEmbedding layer and its associated utility functions by streamlining how table specifications are handled and how input statistics are aggregated. The primary goal is to reduce code complexity and improve maintainability by leveraging more functionality directly from the underlying jax_tpu_embedding library, thereby making the process of synchronizing and updating embedding configurations more efficient and robust.

Highlights

  • Refactored DistributedEmbedding: The DistributedEmbedding layer now directly utilizes functions from the jax_tpu_embedding.sparsecore.lib.nn.embedding module for retrieving table specifications, simplifying its internal logic.
  • Streamlined Statistics Aggregation: The process of aggregating input statistics across devices has been significantly refactored. It now uses a pmax_aggregate function for efficient synchronization and directly updates preprocessing parameters via embedding.update_preprocessing_parameters when statistics change, replacing previous custom logic.
  • Simplified embedding_utils: The embedding_utils.py file has been cleaned up by removing redundant functions like InputStatsPerTable, get_stacked_table_stats, and update_stacked_table_stats. Its get_table_specs function now acts as a direct passthrough to the embedding library's equivalent, and other functions have been updated to return the more direct embedding.SparseDenseMatmulInputStats object.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors DistributedEmbedding and embedding_utils to better leverage the jax_tpu_embedding library. The changes streamline table spec handling and statistics aggregation by removing custom utility functions and types in favor of their upstream equivalents. This reduces code duplication and improves maintainability. My review includes a suggestion to refactor a loop for checking statistics changes, which also fixes a potential TypeError.

@adityagupta1089 adityagupta1089 force-pushed the manuadg/decouple-jax-sc-keras-rs-2 branch 6 times, most recently from 785bee3 to 31fa994 Compare October 2, 2025 16:40
@adityagupta1089 adityagupta1089 force-pushed the manuadg/decouple-jax-sc-keras-rs-2 branch from 31fa994 to 5494318 Compare October 2, 2025 16:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants