Closed

Changes from 1 commit

Commits (19)
eda5176  Fix ModelParallel OOM issue during weight loading (amitsrivastava78, Oct 3, 2025)
5da9108  Fix PyTorch backend tensor conversion and refactor variable loading (amitsrivastava78, Oct 7, 2025)
9886e40  Optimize JAX backend variable initialization and memory management (amitsrivastava78, Oct 7, 2025)
92bf1ed  Remove JAX reference-holding functionality and related tests (amitsrivastava78, Oct 8, 2025)
250c19c  Remove quantization/saving changes that conflict with PR #21713 (amitsrivastava78, Oct 8, 2025)
6e222c9  Remove remaining variable_loading imports and use inline code (amitsrivastava78, Oct 8, 2025)
6197d21  Revert embedding.py quantization logic while preserving _direct_assig… (amitsrivastava78, Oct 8, 2025)
cfc95da  Fix Dense and EinsumDense layers to use _direct_assign for consistency (amitsrivastava78, Oct 8, 2025)
0f02b80  Update all layer load_own_variables methods to use _direct_assign (amitsrivastava78, Oct 9, 2025)
2bb83c6  Add shape validation to layer load_own_variables methods for consiste… (amitsrivastava78, Oct 9, 2025)
10d0d0f  Add OrbaxCheckpoint callback for backend-agnostic checkpointing (amitsrivastava78, Oct 21, 2025)
5dfa590  Add save_metadata support to OrbaxCheckpoint (amitsrivastava78, Oct 21, 2025)
e03bcee  Add save_data_iterator support to OrbaxCheckpoint (amitsrivastava78, Oct 21, 2025)
10f04f7  Added data iterator test for specific backend (amitsrivastava78, Oct 21, 2025)
ac7d4e8  Add missing Orbax checkpoint features to Keras 3.0 (amitsrivastava78, Oct 21, 2025)
db78464  Remove direct orbax.checkpoint imports from test file (amitsrivastava78, Oct 21, 2025)
0c56129  Add test for custom Orbax checkpoint handlers and registry functionality (amitsrivastava78, Oct 21, 2025)
2e61b6e  Added few fixes (amitsrivastava78, Oct 21, 2025)
a31c8e9  Added bridge design (amitsrivastava78, Oct 21, 2025)
433 changes: 381 additions & 52 deletions keras/src/backend/jax/core.py

Large diffs are not rendered by default.
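
Since the core.py diff is not rendered here, the sketch below is a rough, hypothetical reconstruction of the wrapper behavior exercised by the new tests in core_test.py further down. The names _ProtectedShardedArray, _array, _is_sharded, and delete() come from those tests; the internal logic is an assumption, not the PR's actual implementation.

# Hypothetical sketch (assumption), inferred from the tests below rather than
# copied from the unrendered core.py diff: wrap a JAX array and skip deletion
# when it is sharded, so later accesses do not hit "Array has been deleted".
class _ProtectedShardedArray:
    def __init__(self, array):
        self._array = array
        # Assumption: an array exposing more than one addressable shard is sharded.
        self._is_sharded = len(getattr(array, "addressable_shards", [])) > 1

    def delete(self):
        if self._is_sharded:
            return  # Protected: keep the sharded buffers alive.
        if self._array is not None:
            self._array.delete()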

130 changes: 128 additions & 2 deletions keras/src/backend/jax/core_test.py
@@ -7,15 +7,141 @@

import keras
from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax.core import JaxVariable
from keras.src.backend.jax.core import _ProtectedShardedArray

if is_nnx_enabled():
from keras.src.backend.jax.core import NnxVariable

if is_nnx_enabled():
from flax import nnx

from keras.src.backend.jax.core import NnxVariable


class JaxCoreTest(testing.TestCase):
def _require_min_devices(self, min_devices):
"""Skip test if fewer than min_devices are available."""
if len(jax.devices()) < min_devices:
pytest.skip(
f"Test requires at least {min_devices} devices, "
f"but only {len(jax.devices())} available"
)

def test_protected_sharded_array_deletion(self):
"""Test _ProtectedShardedArray prevents deletion of sharded arrays."""
# Create a mock sharded array
array = jax.numpy.ones((10, 10))
sharded_array = jax.device_put(array, jax.devices()[0])
sharded_array.addressable_shards = [
jax.device_put(array, d) for d in jax.devices()
]

protected = _ProtectedShardedArray(sharded_array)

# Attempt deletion (should not delete sharded arrays)
protected.delete()

# Verify array is still accessible
self.assertIs(protected._array, sharded_array)
self.assertTrue(
hasattr(protected, "_is_sharded") and protected._is_sharded
)

def test_jax_variable_strong_references_and_logging(self):
"""Test JaxVariable strong references and logging."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create a sharded variable
var = JaxVariable(jax.numpy.ones((100, 100)))

# Check strong references
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value multiple times to simulate inference
for _ in range(5):
value = var.value
self.assertIsNotNone(
value
) # Ensure no "Array has been deleted" error

# Final check: Value should still be accessible
self.assertIsNotNone(var.value)

@pytest.mark.skipif(not is_nnx_enabled(), reason="NNX not enabled")
def test_nnx_variable_strong_references_and_logging(self):
"""Test NnxVariable strong references and logging."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create NNX variable with sharding
var = NnxVariable(jax.numpy.ones((50, 50)), layout=("model", None))

# Check strong references
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value (simulates inference) and assert no deletion
value = var.value
self.assertIsNotNone(value) # Ensure no "Array has been deleted" error

# Additional accesses to simulate repeated inference
for _ in range(5):
value = var.value
self.assertIsNotNone(value)

def test_variable_loading_with_sharding(self):
"""Test variable loading with sharding support."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create test data
test_data = jax.numpy.ones((10, 10))

# Create variable with sharding
var = JaxVariable(jax.numpy.zeros((10, 10)))
# Load data into it
var._direct_assign(test_data)

# Verify it's a JaxVariable with sharding
self.assertIsInstance(var, JaxVariable)
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value to ensure no deletion
self.assertIsNotNone(var.value)

def test_inference_simulation_no_array_deletion(self):
"""Test inference simulation for no 'Array has been deleted' errors."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create a simple model with sharding
inputs = layers.Input(shape=(10,))
x = layers.Dense(50, name="dense")(inputs)
model = models.Model(inputs, x)

# Build and access weights (triggers sharding and protection)
model.build((None, 10))
for var in model.weights:
value = var.value # Access to trigger protection
self.assertIsNotNone(value) # Ensure initial access succeeds

# Simulate inference (multiple accesses) and assert no deletion
test_input = np.random.randn(1, 10)
for _ in range(10):
output = model(test_input)
self.assertIsNotNone(
output
) # Ensure inference succeeds without errors

# Final check: Weights should still be accessible
for var in model.weights:
self.assertIsNotNone(var.value)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for core Variable integration with NNX.",
@@ -25,8 +151,8 @@
reason="Test requires NNX backend to be enabled by default for setup.",
)
class NnxVariableTest(testing.TestCase):
def setup(self):
super().setup()
def setUp(self):
super().setUp()

class NNXModel(nnx.Module):
def __init__(self, rngs):
21 changes: 21 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -33,6 +33,14 @@
reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
def _require_min_devices(self, min_devices):

Review comment (Collaborator): How come this was not needed before?
Line 27 should make it work.
Under what circumstances did you need this?

"""Skip test if fewer than min_devices are available."""
if len(jax.devices()) < min_devices:
pytest.skip(
f"Test requires at least {min_devices} devices, "
f"but only {len(jax.devices())} available"
)

def _create_jax_layout(self, sharding):
# Use jax_layout.Format or jax_layout.Layout if available.
if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
return sharding

def test_list_devices(self):
self._require_min_devices(8)
self.assertEqual(len(distribution_lib.list_devices()), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
)

def test_distribute_tensor(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_variable(self):
self._require_min_devices(8)
# This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_input_data(self):
self._require_min_devices(8)
# This test only verifies the single worker/process behavior.
# The multi-process test lives in g3.
jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_tensor_with_jax_layout(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
)

def test_distribute_variable_with_jax_layout(self):
self._require_min_devices(8)
# This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
)

def test_distribute_input_data_with_jax_layout(self):
self._require_min_devices(8)
# This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
self.assertEqual(backend_dlib.num_processes(), 1)

def test_to_backend_mesh(self):
self._require_min_devices(8)
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

def test_to_backend_layout(self):
self._require_min_devices(8)
axes = ["data", None]
mesh = distribution_lib.DeviceMesh(
(4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
backend_dlib._to_backend_layout(layout)

def test_variable_assignment_reuse_layout(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_model(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_with_output_sharding(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
)

def test_distribute_data_input(self):
self._require_min_devices(4)
per_process_batch = jax.numpy.arange(24).reshape(
6, 4
) # Example input array
1 change: 1 addition & 0 deletions keras/src/backend/torch/core.py
@@ -110,6 +110,7 @@ def _initialize(self, value):
).to(get_device())

def _direct_assign(self, value):
value = convert_to_tensor(value, dtype=self._dtype)
with torch.no_grad():
self.value.copy_(value)
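
For context on this one-line change, the snippet below is a hedged illustration of the failure mode the added conversion guards against: torch.Tensor.copy_() expects a tensor operand, while checkpoint loading can hand _direct_assign a NumPy array, possibly with a different dtype. The variable names are illustrative only, not taken from this PR.

import numpy as np
import torch

# Illustration (assumption): weights read back from a checkpoint arrive as NumPy.
target = torch.zeros((4, 4), dtype=torch.float32)
loaded = np.ones((4, 4), dtype=np.float64)

with torch.no_grad():
    # Mirrors the fix: convert to a tensor with the variable's dtype first,
    # then copy in place, instead of passing the NumPy array straight to copy_().
    target.copy_(torch.as_tensor(loaded, dtype=target.dtype))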
