Enable Automatic Tensor Parallelism #21726
Code Review
This is a great PR that introduces a powerful and much-needed feature for automatic tensor parallelism in Keras. The high-level API via AutoTPDistribution and the core engine TensorParallelKeras are well-designed, making large-scale model training more accessible. The use of a unified Functional model to encapsulate the parallel logic is particularly clever and should simplify execution and JIT compilation.
I've identified a few areas for improvement, mainly concerning API clarity, robustness, and consistency with Keras design principles. My comments focus on improving docstrings, handling edge cases more gracefully, and ensuring the code is as clear and maintainable as possible. I've also pointed out a few potential bugs and inconsistencies.
Overall, this is a fantastic contribution. Addressing these points will help ensure the new API is robust, intuitive, and easy for users to adopt.
This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras, enabling users to train models that are too large to fit on a single accelerator. The core of this feature is a new distribution strategy, AutoTPDistribution, which provides a simple, high-level API to shard an existing Keras model across multiple devices.
Description
This framework is designed to automate the complex process of model sharding and inter-device communication, making large-scale model training more accessible. The implementation is broken down into several key components:
A new distribution strategy, AutoTPDistribution, is introduced in keras/src/distribution/distribution_lib.py. This class serves as the primary user entry point, and the workflow is straightforward (see the usage sketch after the example link below).
A new TensorParallelKeras model class (keras/src/distribution/tensor_parallel/tensor_parallel.py) acts as the core engine. When distribution.shard() is called, this class wraps the original model and performs the following actions:
Auto-Configures Hardware: It discovers and assigns available devices (TPU, GPU, or CPU).
Shards Parameters: It analyzes the model's layers and applies column-parallel or row-parallel sharding strategies to the weights and biases of relevant layers (e.g., Dense).
Builds a Unified Graph: It creates a single, assembled Keras Functional model that internally manages the parallel computation. This design encapsulates the communication logic (e.g., AllGather, ReduceScatter) within the model's call graph, simplifying execution and enabling JIT compilation. Partial outputs from each device shard are correctly combined (e.g., concatenation for column-parallel, summation for row-parallel); the sketch after this list illustrates the arithmetic.
Coordinates Gradients: It overrides the compile method to wrap the user's optimizer in a TensorParallelOptimizer, which handles the synchronized computation and application of gradients across all shards.
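To make the recombination rules concrete, here is a small NumPy sketch of the general tensor-parallel arithmetic (an illustration only, not code from this PR): splitting a Dense kernel by output columns means the partial results are concatenated, while splitting it by input rows means they are summed.

```python
# Illustration only: the arithmetic behind column- and row-parallel sharding,
# simulated with NumPy on a single host (not the PR's implementation).
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))        # batch of 8, feature dim 16
kernel = rng.normal(size=(16, 32))  # Dense kernel: 16 -> 32
full = x @ kernel                   # reference (unsharded) output

# Column-parallel: each of 2 shards holds half of the output columns.
# Partial outputs are concatenated (the AllGather case).
k_cols = np.split(kernel, 2, axis=1)
col_parallel = np.concatenate([x @ k for k in k_cols], axis=1)

# Row-parallel: each shard holds half of the kernel's input rows and sees the
# matching slice of the activations; partial outputs are summed
# (the reduce / AllReduce case).
k_rows = np.split(kernel, 2, axis=0)
x_slices = np.split(x, 2, axis=1)
row_parallel = sum(xs @ k for xs, k in zip(x_slices, k_rows))

np.testing.assert_allclose(col_parallel, full, atol=1e-6)
np.testing.assert_allclose(row_parallel, full, atol=1e-6)
```

In the column-parallel case the bias can typically be split along the output dimension as well; in the row-parallel case it is usually added once, after the reduction.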
Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
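For readers who cannot open the notebook, the intended workflow looks roughly like the following. This is a minimal sketch based only on the description above; the constructor arguments, the return value of shard(), and the exact import path are assumptions rather than the final API.

```python
# Hypothetical usage sketch based on this PR's description; argument names,
# the return value of shard(), and the import path are assumptions.
import numpy as np
import keras
from keras.src.distribution.distribution_lib import AutoTPDistribution

x_train = np.random.normal(size=(32, 64)).astype("float32")
y_train = np.random.randint(0, 10, size=(32,))

model = keras.Sequential([
    keras.layers.Dense(4096, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])

# Device discovery (TPU, GPU, or CPU) is handled by the distribution itself.
distribution = AutoTPDistribution()

# shard() wraps the model in the TensorParallelKeras engine described above.
sharded_model = distribution.shard(model)

# compile() wraps the optimizer in a TensorParallelOptimizer so that gradients
# are computed and applied consistently across all shards.
sharded_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
sharded_model.fit(x_train, y_train, epochs=1, batch_size=8)
```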
This is the fourth of four PRs for AutoSharding in Keras.