Skip to content

Conversation

@amitsrivastava78
Copy link
Collaborator

  • Modified load_own_variables() to use _direct_assign() for sharded variables
  • Prevents loading full weight tensors on single device before distribution
  • Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
  • Maintains backward compatibility for non-sharded variables
  • Enables loading of models like Gemma2 2B/7B without OOM errors
  • Added EinsumDense layer testing to ModelParallel sharded variable loading

- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
@github-actions github-actions bot added the Gemma Gemma model specific issues label Oct 7, 2025
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical Out-Of-Memory (OOM) issue encountered when loading large models using ModelParallel in the JAX backend. The core problem stemmed from loading entire weight tensors onto a single device before distributing them, leading to RESOURCE_EXHAUSTED errors. The solution refactors the variable assignment mechanism during weight loading to directly distribute sharded variables across devices, thereby avoiding peak memory usage on any single device. This enhancement significantly improves the ability to load and utilize large-scale models within the Keras JAX ecosystem.

Highlights

  • Memory Optimization for ModelParallel: Modified the load_own_variables() method to utilize _direct_assign() for sharded variables, preventing Out-Of-Memory (OOM) errors during weight loading of large models in JAX's ModelParallel.
  • Direct Sharded Variable Assignment: The change ensures that full weight tensors are not loaded onto a single device before distribution, instead distributing them directly across devices, which is crucial for handling large models like Gemma2 2B/7B.
  • Backward Compatibility: The modifications maintain backward compatibility for non-sharded variables, ensuring existing models continue to function as expected.
  • Enhanced Testing: Added comprehensive testing for ModelParallel sharded variable loading, including specific tests for EinsumDense layers and general inference simulation to prevent array deletion issues.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request provides a crucial fix for an out-of-memory issue when loading sharded models with ModelParallel in the JAX backend. The core change, which involves using a new _direct_assign method to distribute weights before assigning them to device variables, is well-implemented and effectively prevents loading full tensors onto a single device. The addition of _ProtectedShardedArray and strong referencing to prevent premature garbage collection of sharded arrays is a clever solution to a common problem in JAX. The refactoring of sharding logic into a shared _initialize_variable_with_sharding helper improves code clarity and maintainability. The comprehensive end-to-end tests for sharded variable loading across various layer types provide strong confidence in the correctness of the fix. Overall, this is an excellent and thorough contribution that significantly improves the usability of model parallelism in Keras.

@codecov-commenter
Copy link

codecov-commenter commented Oct 7, 2025

Codecov Report

❌ Patch coverage is 73.77049% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.57%. Comparing base (0ecb55d) to head (2bb83c6).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/layers/core/einsum_dense.py 68.42% 3 Missing and 3 partials ⚠️
keras/src/layers/core/dense.py 78.94% 2 Missing and 2 partials ⚠️
keras/src/layers/core/embedding.py 71.42% 2 Missing and 2 partials ⚠️
keras/src/layers/preprocessing/index_lookup.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21723      +/-   ##
==========================================
- Coverage   82.59%   82.57%   -0.02%     
==========================================
  Files         572      572              
  Lines       58535    58581      +46     
  Branches     9158     9166       +8     
==========================================
+ Hits        48345    48375      +30     
- Misses       7853     7862       +9     
- Partials     2337     2344       +7     
Flag Coverage Δ
keras 82.38% <73.77%> (-0.02%) ⬇️
keras-jax 63.19% <72.13%> (-0.01%) ⬇️
keras-numpy 57.56% <72.13%> (+<0.01%) ⬆️
keras-openvino 34.31% <4.91%> (-0.03%) ⬇️
keras-tensorflow 63.94% <72.13%> (-0.01%) ⬇️
keras-torch 63.49% <73.77%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

- Fix PyTorch backend CI failures by adding _direct_assign method for proper numpy-to-tensor conversion
- Restore JAX export functionality using jax_export.symbolic_shape for dynamic shape handling
- Refactor variable loading logic to eliminate duplication between Dense and EinsumDense layers
- Create shared utility function get_quantized_variable_load_order in keras/src/utils/variable_loading.py
- Update layer implementations to use the shared variable loading utility
- All tests passing: PyTorch backend, JAX backend, and layer-specific legacy loading tests
- Improve host memory allocation for sharded variables by preferring JAX arrays over NumPy conversion
- Remove unnecessary jax.block_until_ready() calls as JAX automatically blocks when needed
- Add comprehensive documentation for memory stability protection and host allocation
- Enhance logging for variable initialization and assignment operations
- Add support for both NumPy and JAX arrays in variable assignment methods
Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amitsrivastava78

One overall note: We basically never log in the successful case, the logs are just way too noisy.

Can you create two separate small PRs:

  • One for the fix for the initializer. My reading of it is that this is the only change needed:
    def _initialize_with_initializer(self, initializer):
        """Initialize variable with initializer, running on CPU if sharding
        is needed."""
        if self._layout is not None:
            # For sharded variables, run initializer on CPU to avoid device
            # placement issues
            with jax.default_device(jax.devices("cpu")[0]):
                value = self._convert_to_tensor(
                    initializer(self._shape, dtype=self._dtype)
                )
        else:
            # For non-sharded variables, use the default behavior
            value = self._convert_to_tensor(
                initializer(self._shape, dtype=self._dtype)
            )
        self._initialize(value)

But we should check that it does what we think it does.

  • One for the fix for the weight loading. It will have this change, and only this change, for the relevant layers:
        for i, variable in enumerate(target_variables):
            variable._direct_assign(store[str(i)])

And jax_memory_cleanup should be addressed differently, which we can talk about.

- Remove _ProtectedShardedArray class and _maybe_create_strong_reference method from core.py
- Remove jax.block_until_ready calls that are no longer needed
- Simplify variable initialization and assignment logic
- Remove all test cases related to reference holding from core_test.py
- Tests now pass and are consistent with the simplified implementation
…1713

- Remove variable_loading.py (quantization/saving related)
- Fix duplicate import in core_test.py
- Revert layer files to remove quantization changes
- Keep only core JAX memory management changes for OOM fix
- Remove get_quantized_variable_load_order imports from dense.py and einsum_dense.py
- Replace function calls with inline variable ordering logic
- Maintain compatibility with quantization loading
…n usage

- Remove quantization-specific variable ordering in _legacy_load_own_variables
- Keep _direct_assign usage for OOM prevention during sharded variable loading
- Maintain compatibility with quantization_variable_spec
- Change dense.py and einsum_dense.py _legacy_load_own_variables to use _direct_assign
- Maintains OOM prevention for ModelParallel while ensuring consistency across all layers
- All layers now use _direct_assign for variable loading
hertschuh
hertschuh previously approved these changes Oct 9, 2025
Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. But what about base_conv, embedding? Don't they need this change too?

https://github.com/search?q=repo%3Akeras-team%2Fkeras%20variable.assign(store%5Bstr(i)%5D)&type=code

Oh and what about base_optimizer?

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Oct 9, 2025
- Change embedding.py, dense.py, and einsum_dense.py regular load_own_variables
  methods to use _direct_assign instead of assign
- Ensures consistent OOM prevention for ModelParallel across all loading paths
- base_conv.py and base_optimizer.py already used _direct_assign correctly
- All variable loading now uses the same _direct_assign approach
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Oct 9, 2025
@amitsrivastava78
Copy link
Collaborator Author

This looks good to me. But what about base_conv, embedding? Don't they need this change too?

https://github.com/search?q=repo%3Akeras-team%2Fkeras%20variable.assign(store%5Bstr(i)%5D)&type=code

Oh and what about base_optimizer?

Yes base_conv, embedding do have weights which needs to be sharded, same for base_optimizer

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking a closer look at this, I really don't understand why calling variable._direct_assign instead of variable.assign makes any difference.

The implementation of variable.assign calls variable._direct_assign after doing the shape check (which you've now replicated 8 times).

_direct_assign is just the backend dependent implementation of the assignment.

Can you explain why it changes anything?

@hertschuh hertschuh dismissed their stale review October 9, 2025 22:50

Many code changes and unclear if the code changes actually save memory.

@amitsrivastava78
Copy link
Collaborator Author

Taking a closer look at this, I really don't understand why calling variable._direct_assign instead of variable.assign makes any difference.

The implementation of variable.assign calls variable._direct_assign after doing the shape check (which you've now replicated 8 times).

_direct_assign is just the backend dependent implementation of the assignment.

Can you explain why it changes anything?

Ah.. yes you are right, the only issue then is the jax_memory_cleanup , should we remove this completely, or conditionally skip memory delete if sharding is present ?

reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
def _require_min_devices(self, min_devices):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come this was not needed before?
Line 27 should make it work.
Under what circumstances did you need this?

- Add OrbaxCheckpoint callback with similar API to ModelCheckpoint
- Supports async saving, best-only mode, max_to_keep, and batch/epoch saving
- Backend-agnostic implementation with conditional imports
- Add get_process_index() utility for distributed training support
- Comprehensive test suite with 8 test methods
- All code formatted to 80-character line limit
- Add save_metadata parameter to OrbaxCheckpoint constructor
- Support both static dict and callable metadata functions
- Include metadata in composite checkpoint state
- Add comprehensive tests for metadata saving functionality
- Ensure line length compliance and proper error handling
- Add save_data_iterator parameter to constructor for saving iterator state
- Support both static dict and callable iterator state functions
- Include iterator state in composite checkpoint state
- Add comprehensive test for iterator state saving functionality
- Ensure line length compliance and proper error handling
- Add async timeout, background delete, and post-finalization callbacks
- Add metrics state saving and restoration for composite checkpoints
- Add comprehensive tests for all new features
- Fix line lengths to comply with 80-character limit
- Enable loading checkpoints into different model instances
- Export CheckpointManager, SaveArgs, StandardRestore from orbax_checkpoint.py
- Update test file to import these classes from orbax_checkpoint module
- Remove direct import orbax.checkpoint as ocp from test file
- Fix line lengths to comply with 80-column limit
- Remove unused variable to fix linting error
- All tests pass with clean API separation
- Implement test_custom_handler_and_registry that demonstrates custom TypeHandler
- Test saves and restores custom dataclass objects using PyTreeCheckpointer
- Validates that type handlers work for individual custom objects
- Documents limitation that custom objects cannot be used in composite checkpoints
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Gemma Gemma model specific issues size:XL

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants