From 836f6e28d548db4282258846f7052d164ecd2555 Mon Sep 17 00:00:00 2001 From: praveenhosdrug123 Date: Fri, 17 Oct 2025 20:23:46 +0530 Subject: [PATCH 1/2] Fix OOM Issue --- keras_hub/src/models/gemma/gemma_backbone.py | 16 +- .../src/models/gemma3/gemma3_backbone.py | 30 +- keras_hub/src/models/llama/llama_backbone.py | 8 +- keras_hub/src/models/qwen/qwen_backbone.py | 8 +- keras_hub/src/models/qwen3/qwen3_backbone.py | 8 +- keras_hub/src/utils/dist_initializer.py | 647 ++++++++++++++++++ keras_hub/src/utils/dist_initializer_test.py | 315 +++++++++ 7 files changed, 1019 insertions(+), 13 deletions(-) create mode 100644 keras_hub/src/utils/dist_initializer.py create mode 100644 keras_hub/src/utils/dist_initializer_test.py diff --git a/keras_hub/src/models/gemma/gemma_backbone.py b/keras_hub/src/models/gemma/gemma_backbone.py index a08eedeca8..c9f660a58d 100644 --- a/keras_hub/src/models/gemma/gemma_backbone.py +++ b/keras_hub/src/models/gemma/gemma_backbone.py @@ -10,6 +10,16 @@ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +def _gemma_embedding_initializer(): + from keras_hub.src.utils import dist_initializer + + return dist_initializer.DistributedVarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + ) + + @keras_hub_export("keras_hub.models.GemmaBackbone") class GemmaBackbone(Backbone): """Gemma core network with hyperparameters. @@ -110,11 +120,7 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=True, - embeddings_initializer=keras.initializers.VarianceScaling( - scale=1.0, - mode="fan_in", - distribution="untruncated_normal", - ), + embeddings_initializer=_gemma_embedding_initializer(), dtype=dtype, logit_soft_cap=final_logit_soft_cap, name="token_embedding", diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index a65dbd726b..e36f01b75c 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -13,6 +13,26 @@ ) +# Note: For LoRA or quantization, apply after model loading using +# model.backbone.enable_lora() or model.quantize() respectively. +# Distributed initialization is designed for initial text only model loading. +def _gemma3_embedding_initializer(text_only_model): + if text_only_model: + from keras_hub.src.utils import dist_initializer + + return dist_initializer.DistributedVarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + ) + else: + return keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + ) + + @keras_hub_export("keras_hub.models.Gemma3Backbone") class Gemma3Backbone(Backbone): """Gemma3 core network with hyperparameters. 
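The distributed initializers above only engage when a Keras distribution with a layout for "token_embedding/embeddings" is active on the JAX backend; otherwise they fall back to the stock initializers. Below is a minimal sketch of that setup, mirroring the fixture in dist_initializer_test.py; the 8-device mesh and the commented preset name are illustrative assumptions, not part of this patch.

import jax
from keras import distribution

# Requires the JAX backend. The layout key must match TOKEN_EMBEDDING_PATH
# ("token_embedding/embeddings") for DistributedVarianceScaling /
# DistributedRandomNormal to find a layout; the vocabulary axis is sharded
# across the "model" dimension of the mesh.
devices = jax.devices()  # assumes 8 accelerators, as in the test fixture
mesh = distribution.DeviceMesh((1, 8), ["batch", "model"], devices=devices)
layout_map = distribution.LayoutMap(mesh)
layout_map["token_embedding/embeddings"] = ("model", None)
distribution.set_distribution(
    distribution.ModelParallel(device_mesh=mesh, layout_map=layout_map)
)
# With this distribution active, the Gemma/Gemma3 token embedding is
# materialized directly as a sharded jax.Array instead of a fully
# replicated tensor, e.g. (illustrative preset name):
# model = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")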
@@ -202,14 +222,14 @@ def __init__(
         **kwargs,
     ):
         # === Layers ===
+        text_only_model = True if vision_encoder is None else False
+
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=True,
-            embeddings_initializer=keras.initializers.VarianceScaling(
-                scale=1.0,
-                mode="fan_in",
-                distribution="untruncated_normal",
+            embeddings_initializer=_gemma3_embedding_initializer(
+                text_only_model
             ),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
@@ -217,7 +237,7 @@ def __init__(
         )
 
         self.vision_encoder = vision_encoder
-        text_only_model = True if vision_encoder is None else False
+
         if not text_only_model:
             self.interleave_embeddings = Gemma3InterleaveEmbeddings(
                 num_vision_tokens_per_image=self.vision_encoder.num_vision_tokens_per_image,
diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py
index 57ac319bfb..3c01515b56 100644
--- a/keras_hub/src/models/llama/llama_backbone.py
+++ b/keras_hub/src/models/llama/llama_backbone.py
@@ -14,6 +14,12 @@ def _llama_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _llama_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export("keras_hub.models.LlamaBackbone")
 class LlamaBackbone(Backbone):
     """
@@ -111,7 +117,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=tie_word_embeddings,
-            embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
+            embeddings_initializer=_llama_embedding_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
         )
diff --git a/keras_hub/src/models/qwen/qwen_backbone.py b/keras_hub/src/models/qwen/qwen_backbone.py
index b857f7adaf..c86b968fb3 100644
--- a/keras_hub/src/models/qwen/qwen_backbone.py
+++ b/keras_hub/src/models/qwen/qwen_backbone.py
@@ -14,6 +14,12 @@ def _qwen_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _qwen_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export(
     [
         "keras_hub.models.QwenBackbone",
@@ -114,7 +120,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=tie_word_embeddings,
-            embeddings_initializer=_qwen_kernel_initializer(stddev=0.01),
+            embeddings_initializer=_qwen_embedding_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
         )
diff --git a/keras_hub/src/models/qwen3/qwen3_backbone.py b/keras_hub/src/models/qwen3/qwen3_backbone.py
index 4db7f79a43..f84c294e09 100644
--- a/keras_hub/src/models/qwen3/qwen3_backbone.py
+++ b/keras_hub/src/models/qwen3/qwen3_backbone.py
@@ -14,6 +14,12 @@ def _qwen3_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _qwen3_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export("keras_hub.models.Qwen3Backbone")
 class Qwen3Backbone(Backbone):
     """The Qwen3 Transformer core architecture with hyperparameters.
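For Llama, Qwen, and Qwen3 only the token-embedding initializer changes; all other kernels keep the existing _*_kernel_initializer. When no distribution is set, no layout is registered for "token_embedding/embeddings", or the backend is not JAX, DistributedRandomNormal takes the standard RandomNormal path, so single-device behavior is unchanged. A minimal sketch of that fallback, with an illustrative shape:

from keras_hub.src.utils import dist_initializer

# With no distribution strategy active, _get_token_embedding_layout() returns
# None, so this draws an ordinary unsharded tensor with the requested stddev,
# exactly like keras.initializers.RandomNormal(stddev=0.01).
init = dist_initializer.DistributedRandomNormal(stddev=0.01)
embeddings = init(shape=(1024, 512), dtype="float32")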
@@ -105,7 +111,7 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=tie_word_embeddings, - embeddings_initializer=_qwen3_kernel_initializer(stddev=0.01), + embeddings_initializer=_qwen3_embedding_initializer(stddev=0.01), dtype=dtype, name="token_embedding", ) diff --git a/keras_hub/src/utils/dist_initializer.py b/keras_hub/src/utils/dist_initializer.py new file mode 100644 index 0000000000..fd90cf86b3 --- /dev/null +++ b/keras_hub/src/utils/dist_initializer.py @@ -0,0 +1,647 @@ +import math +from functools import partial + +import keras.src.saving +from jax.random import normal +from jax.random import truncated_normal +from jax.random import uniform +from keras.src.backend import backend +from keras.src.backend import distribution_lib +from keras.src.backend import random +from keras.src.backend.config import floatx +from keras.src.initializers.random_initializers import RandomNormal +from keras.src.initializers.random_initializers import RandomUniform +from keras.src.initializers.random_initializers import TruncatedNormal +from keras.src.initializers.random_initializers import VarianceScaling + +# Constant for token embedding path in layout map for all text based models. +# This may need to be updated if models with different layout maps are used. +# The failsafe is to just use non-distributed initializers if no layout is found + +TOKEN_EMBEDDING_PATH = "token_embedding/embeddings" + + +def _get_token_embedding_layout(): + """Get tensor layout for token embeddings if distribution is enabled.""" + from keras.src.distribution import distribution as get_distribution + + current_distribution = get_distribution() + + if current_distribution is None: + return None + if not hasattr(current_distribution, "_layout_map"): + return None + + layout_map = current_distribution._layout_map + tensor_layout = layout_map.get(TOKEN_EMBEDDING_PATH, None) + + if tensor_layout is None or tensor_layout.device_mesh is None: + return None + + if backend() != "jax": + return None + + return tensor_layout + + +def _get_number_of_shards_per_axis(tensor_layout): + """Get number of shards on each axis from tensor layout. + + NOTE: This function extracts sharding info but does NOT affect scaling. + + Design principle: "Aggregation pattern dictates variance scaling" + + For standard model parallelism (vocab sharding): + - Use GLOBAL fan_in/fan_out (don't divide by shards) + - Because all vocab entries are logically accessible via communication + - Example: Standard transformer with sharded embedding table + + For true vocabulary partitioning (rare): + - Use LOCAL fan_in/fan_out (divide by shards) + - Because each device operates on disjoint subsets with no aggregation + - Example: Multi-lingual model where device 0 only handles English tokens, + device 1 only handles French tokens, etc. + + Currently, we assume standard model parallelism (global fan_in/fan_out). 
+ """ + + if tensor_layout is None or tensor_layout.device_mesh is None: + return None + + mesh = tensor_layout.device_mesh + + # Currently only support 2D sharding for text models + if len(mesh.shape) != 2 or len(tensor_layout.axes) != 2: + return None + + mesh_dict = dict(zip(mesh.axis_names, mesh.shape)) + + num_shards_per_axis = tuple( + mesh_dict.get(axis, 1) for axis in tensor_layout.axes + ) + + return num_shards_per_axis + + +def jax_draw_seed(seed): + import jax + from keras.src.random.seed_generator import draw_seed + + if isinstance(seed, jax.Array): + return seed + else: + return draw_seed(seed) + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedRandomNormal(RandomNormal): + """Distributed Random normal initializer. + Given we have the apt layout, this initializer will use the layout to + shard the initialized weights. If no layout is found, it will fall back to + standard RandomNormal initializer. + Draws samples from a normal distribution for given parameters. + + Args: + mean: A python scalar or a scalar keras tensor. Mean of the random + values to generate. + stddev: A python scalar or a scalar keras tensor. Standard deviation of + the random values to generate. + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __call__(self, shape, dtype=None): + tensor_layout = _get_token_embedding_layout() + if tensor_layout is None: + return random.normal( + shape=shape, + mean=self.mean, + stddev=self.stddev, + seed=self.seed, + dtype=dtype, + ) + else: + dtype = dtype or floatx() + seed = jax_draw_seed(self.seed) + init_func = partial( + normal, + shape=shape, + dtype=dtype, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=self.mean, + stddev=self.stddev, + seed=seed, + layout=tensor_layout, + ) + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedTruncatedNormal(TruncatedNormal): + """Distributed Initializer that generates a truncated normal distribution. + Given we have the apt layout, this initializer will use the layout to + shard the initialized weights. If no layout is found, it will fall back to + standard TruncatedNormal initializer. + The values generated are similar to values from a + `RandomNormal` initializer, except that values more + than two standard deviations from the mean are + discarded and re-drawn. + + Args: + mean: A python scalar or a scalar keras tensor. Mean of the random + values to generate. + stddev: A python scalar or a scalar keras tensor. Standard deviation of + the random values to generate. + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. 
+ """ + + def __call__(self, shape, dtype=None): + tensor_layout = _get_token_embedding_layout() + if tensor_layout is None: + return random.truncated_normal( + shape=shape, + mean=self.mean, + stddev=self.stddev, + seed=self.seed, + dtype=dtype, + ) + else: + dtype = dtype or floatx() + seed = jax_draw_seed(self.seed) + init_func = partial( + truncated_normal, + shape=shape, + dtype=dtype, + lower=-2.0, + upper=2.0, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=self.mean, + stddev=self.stddev, + seed=seed, + layout=tensor_layout, + ) + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedRandomUniform(RandomUniform): + """Distributed Random uniform initializer. + Given we have the apt layout, this initializer will use the layout to + shard the initialized weights. If no layout is found, it will fall back to + standard RandomUniform initializer. + Draws samples from a uniform distribution for given parameters. + + Args: + minval: A python scalar or a scalar keras tensor. Lower bound of the + range of random values to generate (inclusive). + maxval: A python scalar or a scalar keras tensor. Upper bound of the + range of random values to generate (exclusive). + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __call__(self, shape, dtype=None): + tensor_layout = _get_token_embedding_layout() + if tensor_layout is None: + return random.uniform( + shape=shape, + minval=self.minval, + maxval=self.maxval, + seed=self.seed, + dtype=dtype, + ) + else: + dtype = dtype or floatx() + seed = jax_draw_seed(self.seed) + init_func = partial( + uniform, + shape=shape, + dtype=dtype, + minval=self.minval, + maxval=self.maxval, + ) + + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=None, + stddev=None, + seed=seed, + layout=tensor_layout, + ) + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedVarianceScaling(VarianceScaling): + """Distributed Initializer that adapts its scaling. + Given we have the apt layout, this initializer will use the layout to + shard the initialized weights. If no layout is found, it will fall back to + standard VarianceScaling initializer. + With `distribution="truncated_normal" or "untruncated_normal"`, samples are + drawn from a truncated/untruncated normal distribution with a mean of zero + and a standard deviation (after truncation, if used) `stddev = sqrt(scale / + n)`, where `n` is: + + - number of input units in the weight tensor, if `mode="fan_in"` + - number of output units, if `mode="fan_out"` + - average of the numbers of input and output units, if `mode="fan_avg"` + + With `distribution="uniform"`, samples are drawn from a uniform distribution + within `[-limit, limit]`, where `limit = sqrt(3 * scale / n)`. + + + Args: + scale: Scaling factor (positive float). + mode: One of `"fan_in"`, `"fan_out"`, `"fan_avg"`. + distribution: Random distribution to use. + One of `"truncated_normal"`, `"untruncated_normal"`, or `"uniform"`. 
+ seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __call__(self, shape, dtype=None): + scale = self.scale + fan_in, fan_out = compute_fans(shape) + if self.mode == "fan_in": + scale /= max(1.0, fan_in) + elif self.mode == "fan_out": + scale /= max(1.0, fan_out) + else: + scale /= max(1.0, (fan_in + fan_out) / 2.0) + tensor_layout = _get_token_embedding_layout() + if tensor_layout is None: + if self.distribution == "truncated_normal": + stddev = math.sqrt(scale) / 0.87962566103423978 + return random.truncated_normal( + shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + ) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random.normal( + shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + ) + else: + limit = math.sqrt(3.0 * scale) + return random.uniform( + shape, + minval=-limit, + maxval=limit, + dtype=dtype, + seed=self.seed, + ) + else: + dtype = dtype or floatx() + seed = jax_draw_seed(self.seed) + if self.distribution == "truncated_normal": + stddev = math.sqrt(scale) / 0.87962566103423978 + init_func = partial( + truncated_normal, + shape=shape, + dtype=dtype, + lower=-2.0, + upper=2.0, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=0.0, + stddev=stddev, + seed=seed, + layout=tensor_layout, + ) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + init_func = partial( + normal, + shape=shape, + dtype=dtype, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=0.0, + stddev=stddev, + seed=seed, + layout=tensor_layout, + ) + else: + limit = math.sqrt(3.0 * scale) + init_func = partial( + uniform, + shape=shape, + dtype=dtype, + minval=-limit, + maxval=limit, + ) + return distribution_lib._distribute_initializer( + init_func=init_func, + mean=None, + stddev=None, + seed=seed, + layout=tensor_layout, + ) + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedGlorotUniform(DistributedVarianceScaling): + """Distributed The Glorot uniform initializer, also called Xavier uniform + initializer. Given apt layout, this initializer will use the layout to + shard the initialized weights. If no layout is found, it will fall back to + standard GlorotUniform initializer. + + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(6 / (fan_in + fan_out))` (`fan_in` is the number of input + units in the weight tensor and `fan_out` is the number of output units). + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. 
+ + Reference: + + - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_avg", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedGlorotNormal(DistributedVarianceScaling): + """Distributed Glorot normal initializer, also called + Xavier normal initializer. Given apt layout, this initializer + will use the layout to shard the initialized weights.If no layout is found, + it will fall back to standard GlorotNormal initializer. + + Draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of + input units in the weight tensor and `fan_out` is the number of output units + in the weight tensor. + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, + mode="fan_avg", + distribution="truncated_normal", + seed=seed, + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedLecunNormal(DistributedVarianceScaling): + """Distributed Lecun normal initializer. + Given apt layout, this initializer will use the layout to shard the + initialized weights.If no layout is found, it will fall back to standard + LecunNormal initializer.Initializers allow you to pre-specify an + initialization strategy, encoded in the Initializer object, without + knowing the shape and dtype of the variable being initialized. + + Draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of input units in + the weight tensor. + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_in", distribution="truncated_normal", seed=seed + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedLecunUniform(DistributedVarianceScaling): + """Distributed Lecun uniform initializer. 
+ Given apt layout, this initializer will use the layout to shard the + initialized weights.If no layout is found, it will fall back to standard + LecunUniform initializer. Initializers allow you to pre-specify an + initialization strategy, encoded in the Initializer object, without knowing + the shape and dtype of the variable being initialized. + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(3 / fan_in)` (`fan_in` is the number of input units in the + weight tensor). + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_in", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedHeNormal(DistributedVarianceScaling): + """Distributed He normal initializer. + Given apt layout, this initializer will use the layout to shard the + initialized weights. If no layout is found, it will fall back to standard + HeNormal initializer. Initializers allow you to pre-specify an + initialization strategy, encoded in the Initializer object, without knowing + the shape and dtype of the variable being initialized. + It draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in + the weight tensor. + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [He et al., 2015](https://arxiv.org/abs/1502.01852) + """ + + def __init__(self, seed=None): + super().__init__( + scale=2.0, mode="fan_in", distribution="truncated_normal", seed=seed + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +# Registered for deserialization, but not in public API +@keras.src.saving.register_keras_serializable(package="keras_hub") +class DistributedHeUniform(DistributedVarianceScaling): + """Distributed He uniform variance scaling initializer. + Given apt layout, this initializer will use the layout to shard the + initialized weights. If no layout is found, it will fall back to standard + HeUniform initializer. Initializers allow you to pre-specify an + initialization strategy, encoded in the Initializer object, without + knowing the shape and dtype of the variable being initialized. + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(6 / fan_in)` (`fan_in` is the number of input units in the + weight tensor). 
+ + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [He et al., 2015](https://arxiv.org/abs/1502.01852) + """ + + def __init__(self, seed=None): + super().__init__( + scale=2.0, mode="fan_in", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": keras.src.saving.serialization_lib.serialize_keras_object( + self._init_seed + ) + } + + +def compute_fans(shape): + """Computes the number of input and output units for a weight shape. + + Args: + shape: Integer shape tuple. + + Returns: + A tuple of integer scalars: `(fan_in, fan_out)`. + """ + shape = tuple(shape) + if len(shape) < 1: # Just to avoid errors for constants. + fan_in = fan_out = 1 + elif len(shape) == 1: + fan_in = fan_out = shape[0] + elif len(shape) == 2: + fan_in = shape[0] + fan_out = shape[1] + else: + # Assuming convolution kernels (2D, 3D, or more). + # kernel shape: (..., input_depth, depth) + receptive_field_size = 1 + for dim in shape[:-2]: + receptive_field_size *= dim + fan_in = shape[-2] * receptive_field_size + fan_out = shape[-1] * receptive_field_size + return int(fan_in), int(fan_out) diff --git a/keras_hub/src/utils/dist_initializer_test.py b/keras_hub/src/utils/dist_initializer_test.py new file mode 100644 index 0000000000..3138cfe0b8 --- /dev/null +++ b/keras_hub/src/utils/dist_initializer_test.py @@ -0,0 +1,315 @@ +import jax +import keras.saving +import pytest +from keras import backend +from keras import distribution + +from keras_hub.src.utils import dist_initializer + + +@pytest.mark.skipif(backend.backend() != "jax", reason="jax only") +class TestDistributedInitializer: + @pytest.fixture(autouse=True) + def setUp(self): + """Set up distribution context for all tests""" + # Skip if not enough devices + if len(jax.devices()) < 8: + pytest.skip("requires 8+ devices") + + devices = jax.devices() + self.device_mesh = distribution.DeviceMesh( + (1, 8), ["batch", "model"], devices=devices + ) + self.layout_map = distribution.LayoutMap(self.device_mesh) + self.layout_map["token_embedding/embeddings"] = ("model", None) + + self.distribution = distribution.ModelParallel( + device_mesh=self.device_mesh, layout_map=self.layout_map + ) + distribution.set_distribution(self.distribution) + + yield + + distribution.set_distribution(None) + + def test_distributed_random_normal(self): + """Test DistributedRandomNormal creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedRandomNormal() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_truncated_normal(self): + 
"""Test DistributedTruncatedNormal creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedTruncatedNormal() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_variance_scaling(self): + """Test DistributedVarianceScaling creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedVarianceScaling( + scale=2.0, mode="fan_in" + ) + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_random_uniform(self): + """Test DistributedRandomUniform creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedRandomUniform() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_glorot_uniform(self): + """Test DistributedGlorotUniform creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedGlorotUniform() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_glorot_normal(self): + """Test DistributedGlorotNormal creates sharded arrays + and config 
serialization/deserialization""" + init = dist_initializer.DistributedGlorotNormal() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_lecun_normal(self): + """Test DistributedLecunNormal creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedLecunNormal() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_lecun_uniform(self): + """Test DistributedLecunUniform creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedLecunUniform() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_he_normal(self): + """Test DistributedHeNormal creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedHeNormal() + result = init(shape=(768, 256), dtype="float32") + assert isinstance(result, jax.Array) + assert result.shape == (768, 256) + assert len(result.sharding.device_set) == 8 + + # Get config + config = init.get_config() + + # Serialize and deserialize + serialized = keras.saving.serialize_keras_object(init) + + restored = keras.saving.deserialize_keras_object(serialized) + + # Check types match + assert type(init) is type(restored), ( + f"Type mismatch: {type(init)} vs {type(restored)}" + ) + + # Check config matches + restored_config = restored.get_config() + assert config == restored_config, ( + f"Config mismatch: {config} vs {restored_config}" + ) + + def test_distributed_he_uniform(self): + """Test DistributedHeUniform creates sharded arrays + and config serialization/deserialization""" + init = dist_initializer.DistributedHeUniform() + result = init(shape=(768, 256), dtype="float32") + assert 
isinstance(result, jax.Array)
+        assert result.shape == (768, 256)
+        assert len(result.sharding.device_set) == 8
+
+        # Get config
+        config = init.get_config()
+
+        # Serialize and deserialize
+        serialized = keras.saving.serialize_keras_object(init)
+
+        restored = keras.saving.deserialize_keras_object(serialized)
+
+        # Check types match
+        assert type(init) is type(restored), (
+            f"Type mismatch: {type(init)} vs {type(restored)}"
+        )
+
+        # Check config matches
+        restored_config = restored.get_config()
+        assert config == restored_config, (
+            f"Config mismatch: {config} vs {restored_config}"
+        )

From fe1a9529aab5021f28484cbe403094cf68de8312 Mon Sep 17 00:00:00 2001
From: praveenhosdrug123
Date: Fri, 17 Oct 2025 20:51:26 +0530
Subject: [PATCH 2/2] Deleting sharded or unsharded weights

---
 keras_hub/src/utils/preset_utils.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 71725eaaa5..5dd976c474 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -502,17 +502,11 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
-    # Skip deletion for sharded arrays to avoid breaking references in
-    # distributed setups.
+    # Delete every weight, sharded or not, to free memory before loading.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
-            if weight._value is not None:
-                # Do not delete sharded arrays, as they may be referenced in
-                # JAX's distributed computation graph and deletion can cause
-                # errors.
-                sharding = getattr(weight._value, "sharding", None)
-                if sharding is None:
-                    weight._value.delete()
+            if getattr(weight, "_value", None) is not None:
+                weight._value.delete()
 
 
 def set_dtype_in_config(config, dtype=None):
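Relatedly, the note in _get_number_of_shards_per_axis (first patch) that variance scaling uses the global rather than per-shard fans can be made concrete with a short worked example; the vocabulary and hidden sizes below are illustrative, not taken from any preset.

import math

# compute_fans((vocab_size, hidden_dim)) returns (vocab_size, hidden_dim), so
# with scale=1.0 and mode="fan_in" (the Gemma embedding setting) the
# untruncated-normal draw uses stddev = sqrt(1 / vocab_size), regardless of
# how many shards the embedding table is split into.
vocab_size, hidden_dim = 256_000, 2048  # illustrative sizes
fan_in, fan_out = vocab_size, hidden_dim
stddev = math.sqrt(1.0 / max(1.0, fan_in))
print(f"embedding init stddev ~= {stddev:.6f}")  # ~0.001976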