Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
eda5176
Fix ModelParallel OOM issue during weight loading
amitsrivastava78 Oct 3, 2025
5da9108
Fix PyTorch backend tensor conversion and refactor variable loading
amitsrivastava78 Oct 7, 2025
9886e40
Optimize JAX backend variable initialization and memory management
amitsrivastava78 Oct 7, 2025
92bf1ed
Remove JAX reference-holding functionality and related tests
amitsrivastava78 Oct 8, 2025
250c19c
Remove quantization/saving changes that conflict with PR #21713
amitsrivastava78 Oct 8, 2025
6e222c9
Remove remaining variable_loading imports and use inline code
amitsrivastava78 Oct 8, 2025
6197d21
Revert embedding.py quantization logic while preserving _direct_assign
amitsrivastava78 Oct 8, 2025
cfc95da
Fix Dense and EinsumDense layers to use _direct_assign for consistency
amitsrivastava78 Oct 8, 2025
0f02b80
Update all layer load_own_variables methods to use _direct_assign
amitsrivastava78 Oct 9, 2025
2bb83c6
Add shape validation to layer load_own_variables methods for consistency
amitsrivastava78 Oct 9, 2025
10d0d0f
Add OrbaxCheckpoint callback for backend-agnostic checkpointing
amitsrivastava78 Oct 21, 2025
5dfa590
Add save_metadata support to OrbaxCheckpoint
amitsrivastava78 Oct 21, 2025
e03bcee
Add save_data_iterator support to OrbaxCheckpoint
amitsrivastava78 Oct 21, 2025
10f04f7
Added data iterator test for specific backend
amitsrivastava78 Oct 21, 2025
ac7d4e8
Add missing Orbax checkpoint features to Keras 3.0
amitsrivastava78 Oct 21, 2025
db78464
Remove direct orbax.checkpoint imports from test file
amitsrivastava78 Oct 21, 2025
0c56129
Add test for custom Orbax checkpoint handlers and registry functionality
amitsrivastava78 Oct 21, 2025
2e61b6e
Added few fixes
amitsrivastava78 Oct 21, 2025
a31c8e9
Added bridge design
amitsrivastava78 Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
35 changes: 35 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,38 @@ class name_scope(backend_name_scope):
@keras_export("keras.device")
def device(device_name):
return device_scope(device_name) # noqa: F405


def get_process_index():
    """Get the index of the current process in a distributed setup.

    Returns:
        int: The process index (0 for the primary process, >0 for others).
            Returns 0 when not running in a distributed setup, or when the
            backend's distribution runtime is unavailable or uninitialized.
    """
    backend_name = backend()
    if backend_name == "jax":
        try:
            import jax

            return jax.process_index()
        except (ImportError, AttributeError):
            return 0
    elif backend_name == "tensorflow":
        try:
            import tensorflow as tf

            # NOTE(review): `replica_id_in_sync_group` is the replica index
            # within a sync group and may be a scalar Tensor inside a
            # strategy scope -- it is not the same notion as a worker /
            # process index like `jax.process_index()`. Confirm this is the
            # intended semantic for the TF backend.
            replica_id = (
                tf.distribute.get_replica_context().replica_id_in_sync_group
            )
            # Coerce to a plain int so the return type matches the docstring
            # even when TF hands back a scalar Tensor.
            return int(replica_id)
        except (
            ImportError,
            AttributeError,
            RuntimeError,
            TypeError,
            ValueError,
        ):
            return 0
    elif backend_name == "torch":
        try:
            import torch.distributed as dist

            # `get_rank()` is only valid once the default process group has
            # been initialized; otherwise fall back to single-process (0).
            if dist.is_available() and dist.is_initialized():
                return dist.get_rank()
            return 0
        except (ImportError, AttributeError):
            return 0
    else:
        return 0
14 changes: 12 additions & 2 deletions keras/src/backend/jax/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
from keras.src.backend.jax.core import NnxVariable


class JaxCoreTest(testing.TestCase):
    def _require_min_devices(self, min_devices):
        """Skip the running test unless `min_devices` JAX devices exist."""
        if min_devices > len(jax.devices()):
            pytest.skip(
                f"Test requires at least {min_devices} devices, "
                f"but only {len(jax.devices())} available"
            )


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for core Variable integration with NNX.",
Expand All @@ -25,8 +35,8 @@
reason="Test requires NNX backend to be enabled by default for setup.",
)
class NnxVariableTest(testing.TestCase):
def setup(self):
super().setup()
def setUp(self):
super().setUp()

class NNXModel(nnx.Module):
def __init__(self, rngs):
Expand Down
21 changes: 21 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
def _require_min_devices(self, min_devices):
    """Skip the running test unless `min_devices` JAX devices exist.

    Args:
        min_devices: Minimum number of devices the test needs; when fewer
            are visible to JAX, the test is skipped via `pytest.skip`.
    """
    if len(jax.devices()) < min_devices:
        pytest.skip(
            f"Test requires at least {min_devices} devices, "
            f"but only {len(jax.devices())} available"
        )

def _create_jax_layout(self, sharding):
# Use jax_layout.Format or jax_layout.Layout if available.
if hasattr(jax_layout, "Format"):
Expand All @@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
return sharding

def test_list_devices(self):
self._require_min_devices(8)
self.assertEqual(len(distribution_lib.list_devices()), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
Expand Down Expand Up @@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
)

def test_distribute_tensor(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
Expand All @@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_variable(self):
self._require_min_devices(8)
# This test only verify the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
Expand All @@ -118,6 +129,7 @@ def test_distribute_variable(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_input_data(self):
self._require_min_devices(8)
# This test only verify the single worker/process behavior.
# The multi-process test lives in g3.
jax_mesh = jax.sharding.Mesh(
Expand All @@ -136,6 +148,7 @@ def test_distribute_input_data(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_tensor_with_jax_layout(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
Expand Down Expand Up @@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
)

def test_distribute_variable_with_jax_layout(self):
self._require_min_devices(8)
# This test only verify the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
Expand All @@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
)

def test_distribute_input_data_with_jax_layout(self):
self._require_min_devices(8)
# This test only verify the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
Expand All @@ -212,6 +227,7 @@ def test_processes(self):
self.assertEqual(backend_dlib.num_processes(), 1)

def test_to_backend_mesh(self):
self._require_min_devices(8)
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
axis_names = ["batch", "model"]
Expand All @@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

def test_to_backend_layout(self):
self._require_min_devices(8)
axes = ["data", None]
mesh = distribution_lib.DeviceMesh(
(4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
Expand All @@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
backend_dlib._to_backend_layout(layout)

def test_variable_assignment_reuse_layout(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
Expand Down Expand Up @@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_model(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
Expand Down Expand Up @@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_with_output_sharding(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
Expand Down Expand Up @@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
)

def test_distribute_data_input(self):
self._require_min_devices(4)
per_process_batch = jax.numpy.arange(24).reshape(
6, 4
) # Example input array
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _initialize(self, value):
).to(get_device())

def _direct_assign(self, value):
    """Copy `value` into the backing tensor without autograd tracking."""
    converted = convert_to_tensor(value, dtype=self._dtype)
    with torch.no_grad():
        self.value.copy_(converted)

Expand Down
6 changes: 6 additions & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback

# Orbax is an optional dependency: expose `OrbaxCheckpoint` as None when it
# (or its own dependencies) cannot be imported, so `keras.src.callbacks`
# still imports cleanly without orbax installed.
try:
    from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
except ImportError:
    OrbaxCheckpoint = None

from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading