
Conversation

@buildwithsuhana commented Oct 8, 2025

This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras, enabling users to train models that are too large to fit on a single accelerator. The core of this feature is a new distribution strategy, AutoTPDistribution, which provides a simple, high-level API to shard an existing Keras model across multiple devices.

Description
This framework is designed to automate the complex process of model sharding and inter-device communication, making large-scale model training more accessible. The implementation is broken down into several key components:

  1. High-Level User API: AutoTPDistribution
    A new distribution strategy, AutoTPDistribution, is introduced in keras/src/distribution/distribution_lib.py. This class serves as the primary user entry point. The workflow is straightforward:
  • A user defines their hardware topology using a DeviceMesh.
  • They instantiate the AutoTPDistribution strategy.
  • They pass their standard Keras model to the distribution.shard() method.
  • This returns a new, sharded model instance ready for distributed training (a usage sketch follows the component descriptions below).
  2. Core Sharding Engine: TensorParallelKeras
    A new TensorParallelKeras model class (keras/src/distribution/tensor_parallel/tensor_parallel.py) acts as the core engine. When distribution.shard() is called, it wraps the original model and performs the following actions:

Auto-Configures Hardware: It discovers and assigns the available devices (TPU, GPU, or CPU).
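
As a rough illustration of what this kind of device discovery involves on the JAX backend (a sketch only; the helpers inside TensorParallelKeras may be structured differently):

```python
# Sketch only: typical JAX device discovery, preferring accelerators over CPU.
import jax

def discover_devices():
    devices = jax.devices()  # e.g. [TpuDevice(id=0), ...] or [CpuDevice(id=0)]
    accelerators = [d for d in devices if d.platform in ("tpu", "gpu")]
    return accelerators or devices
```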

Shards Parameters: It analyzes the model's layers and applies column-parallel or row-parallel sharding strategies to the weights and biases of relevant layers (e.g., Dense).
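
To make the two strategies concrete, here is an illustrative sketch (not the PR's internal API) of how a Dense kernel of shape (in_dim, out_dim) can be split:

```python
# Illustrative only: splitting a Dense kernel across `world_size` shards.
import numpy as np

def shard_dense_kernel(kernel, world_size, mode):
    if mode == "column":
        # Column-parallel: split the output dimension; each shard computes
        # a slice of the layer's output features.
        return np.split(kernel, world_size, axis=1)
    if mode == "row":
        # Row-parallel: split the input dimension; each shard consumes a
        # slice of the inputs and produces a partial sum of the outputs.
        return np.split(kernel, world_size, axis=0)
    raise ValueError(f"Unknown mode: {mode}")

kernel = np.ones((8, 4))
col_shards = shard_dense_kernel(kernel, 2, "column")  # two (8, 2) pieces
row_shards = shard_dense_kernel(kernel, 2, "row")     # two (4, 4) pieces
```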

Builds a Unified Graph: It creates a single, assembled Keras Functional model that internally manages the parallel computation. This design encapsulates the communication logic (e.g., AllGather, ReduceScatter) within the model's call graph, which simplifies execution and enables JIT compilation. Partial outputs from each device shard are combined correctly (e.g., concatenation for column-parallel, summation for row-parallel).
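
The combination rules can be summarized in a small sketch (illustrative, not the PR's code): column-parallel shards each hold a distinct slice of the output features, so they are concatenated, while row-parallel shards each hold a partial sum over their input slice, so they are added.

```python
# Illustrative only: merging per-shard partial outputs.
import numpy as np

def combine_outputs(partial_outputs, mode):
    if mode == "column":
        # Distinct feature slices -> concatenate along the feature axis
        # (what an AllGather achieves across devices).
        return np.concatenate(partial_outputs, axis=-1)
    if mode == "row":
        # Partial sums over input slices -> elementwise addition
        # (what an AllReduce/ReduceScatter achieves across devices).
        return sum(partial_outputs)
    raise ValueError(f"Unknown mode: {mode}")
```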

Coordinates Gradients: It overrides the compile method to wrap the user's optimizer in a TensorParallelOptimizer, which handles the synchronized computation and application of gradients across all shards.
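
Conceptually, the wrapper reduces the gradients that each shard computed for the same variables before a single update is applied. A minimal sketch under that assumption (not the PR's actual TensorParallelOptimizer):

```python
# Conceptual sketch only; the real TensorParallelOptimizer may differ.
def synchronized_apply(base_optimizer, per_shard_grads, variables):
    # per_shard_grads: one list of gradients per shard, aligned with `variables`.
    summed = [sum(grads) for grads in zip(*per_shard_grads)]
    base_optimizer.apply_gradients(zip(summed, variables))
```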

Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
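
For reference, the workflow described above looks roughly like the following sketch; import paths and constructor arguments are assumptions based on this description, not final documentation.

```python
import keras
from keras.distribution import DeviceMesh
# AutoTPDistribution is introduced by this PR; its import path is assumed here.
from keras.distribution import AutoTPDistribution

# 1. Describe the hardware topology (the axis name is illustrative).
devices = keras.distribution.list_devices()
mesh = DeviceMesh(shape=(len(devices),), axis_names=("model",), devices=devices)

# 2. Instantiate the strategy and shard a standard Keras model.
distribution = AutoTPDistribution(device_mesh=mesh)
model = keras.Sequential([keras.layers.Dense(1024, activation="relu"),
                          keras.layers.Dense(10)])
sharded_model = distribution.shard(model)

# 3. Compile and train the sharded model as usual; compile() wraps the
#    optimizer so gradients are coordinated across shards.
sharded_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```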

This is the 4th (out of 4) PR for AutoSharding Keras.

@buildwithsuhana marked this pull request as draft on October 8, 2025, 07:47

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This PR introduces automatic tensor parallelism to Keras, allowing models to be sharded across multiple devices for training larger models. It provides a high-level AutoTPDistribution API that uses a new TensorParallelKeras class to handle device auto-configuration, parameter sharding (column-parallel, row-parallel), and the construction of a unified functional model that manages inter-device communication and gradient coordination. New JAX backend utilities and extensive tests for sharding logic and communication have also been added.

Highlights

  • Intent: This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras. The primary goal is to enable users to train models that are too large to fit on a single accelerator by automating model sharding and inter-device communication. It provides a high-level API to shard existing Keras models across multiple devices.
  • Key Changes: The core of this feature is the new AutoTPDistribution strategy, which serves as the user's entry point to shard a Keras model. This strategy leverages a new TensorParallelKeras model class that acts as the core sharding engine. This engine automatically configures hardware, applies column-parallel or row-parallel sharding to relevant layer weights (e.g., Dense layers), and builds a unified Keras Functional model that encapsulates parallel computation and communication logic (like AllGather, ReduceScatter). It also coordinates gradients by wrapping the user's optimizer in a TensorParallelOptimizer.
  • Detailed File Changes: 1. keras/src/backend/jax/distribution_lib.py: Adds new JAX-specific utilities for device information, auto-configuration of tensor parallelism, and device validation. It also includes helper functions like get_best_devices and get_device_memory_info (a sketch of the underlying JAX queries follows this list).
  2. keras/src/distribution/distribution_lib.py: Introduces the AutoTPDistribution class, which is the public API for automatic tensor parallelism. This class handles the creation of a DeviceMesh (if not provided) and orchestrates the sharding process by instantiating and configuring the TensorParallelKeras model.
  3. keras/src/distribution/tensor_parallel/tensor_parallel.py (NEW FILE): This new file defines the TensorParallelKeras class, which wraps an original Keras model. It contains the logic for discovering devices, adjusting device lists, applying sharding rules to model parameters, and building a unified functional model (assembled_model) that manages the parallel execution and communication. It also overrides the compile method to integrate with a TensorParallelOptimizer for gradient coordination.
  4. Test Files (keras/src/backend/jax/distribution_lib_test.py, keras/src/distribution/distribution_lib_test.py, keras/src/distribution/tensor_parallel/tensor_parallel_test.py): Significant additions to test the new functionality. This includes environment setup for the JAX backend, tests for AutoTPDistribution's sharding correctness (including uneven splits), and comprehensive tests for TensorParallelKeras covering initialization, non-distributed behavior, forward-pass output shapes, and the correct invocation of communication primitives during gradient slicing and backward passes.
  • Reviewer Activity: No specific reviewer activity was provided in the context.
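
As a rough idea of the underlying JAX queries that utilities like get_best_devices and get_device_memory_info would build on (a sketch under that assumption; the PR's actual signatures may differ):

```python
# Sketch of the JAX primitives such utilities typically wrap.
import jax

devices = jax.devices()                 # all addressable devices
local_count = jax.local_device_count()  # devices attached to this host
print(f"{local_count} local device(s)")

for d in devices:
    try:
        stats = d.memory_stats()  # available on GPU/TPU; may be None on CPU
    except Exception:
        stats = None
    print(d.platform, d.id, stats)
```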
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please flag any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, and Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This is a great PR that introduces a powerful and much-needed feature for automatic tensor parallelism in Keras. The high-level API via AutoTPDistribution and the core engine TensorParallelKeras are well-designed, making large-scale model training more accessible. The use of a unified Functional model to encapsulate the parallel logic is particularly clever and should simplify execution and JIT compilation.

I've identified a few areas for improvement, mainly concerning API clarity, robustness, and consistency with Keras design principles. My comments focus on improving docstrings, handling edge cases more gracefully, and ensuring the code is as clear and maintainable as possible. I've also pointed out a few potential bugs and inconsistencies.

Overall, this is a fantastic contribution. Addressing these points will help ensure the new API is robust, intuitive, and easy for users to adopt.

@buildwithsuhana marked this pull request as ready for review on October 15, 2025, 20:29