diff --git a/recml/core/training/keras_trainer.py b/recml/core/training/keras_trainer.py index 6c24223..7c26c34 100644 --- a/recml/core/training/keras_trainer.py +++ b/recml/core/training/keras_trainer.py @@ -118,6 +118,7 @@ def __init__( max_checkpoints_to_keep: int = 5, checkpoint_save_interval_epochs: int = 1, rng_seed: int = core.DEFAULT_RNG_SEED, + legacy_format: bool = True, ): """Initializes the instance.""" @@ -148,12 +149,14 @@ def __init__( model_dir, core.TRAINING_COMPLETE_MARKER_FILE ) self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR) + self._legacy_format = legacy_format if keras.backend.backend() == "jax": self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager( checkpoint_dir=self._checkpoint_dir, max_to_keep=max_checkpoints_to_keep, save_interval_epochs=checkpoint_save_interval_epochs, + legacy_format=self._legacy_format, ) self._train_callbacks = [ keras_utils.EpochSummaryCallback( @@ -314,13 +317,18 @@ def __init__( self, checkpoint_dir: str, epoch: int, + legacy_format: bool, ): self._checkpoint_dir = checkpoint_dir self._epoch = epoch + self._legacy_format = legacy_format def on_test_begin(self, logs: Mapping[str, Any] | None = None): keras_utils.restore_keras_model( - model, self._checkpoint_dir, step=self._epoch + model, + self._checkpoint_dir, + step=self._epoch, + legacy_format=self._legacy_format, ) history = None @@ -329,7 +337,9 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None): timeout=self._continuous_eval_timeout, timeout_fn=timeout_fn, ): - restore_callback = _RestoreCallback(self._checkpoint_dir, epoch) + restore_callback = _RestoreCallback( + self._checkpoint_dir, epoch, self._legacy_format + ) [tb_cbk] = [ cbk for cbk in self._eval_callbacks diff --git a/recml/core/training/keras_trainer_test.py b/recml/core/training/keras_trainer_test.py index 46844da..a986dd3 100644 --- a/recml/core/training/keras_trainer_test.py +++ b/recml/core/training/keras_trainer_test.py @@ -94,6 +94,22 @@ def test_keras_task_and_trainer(self, mode: str): ): self.assertEqual(history.history["num_params/trainable"][0], 2) + def test_new_checkpoint_format(self): + if keras.backend.backend() != "jax": + self.skipTest("Only supported on the Jax backend.") + trainer = keras_trainer.KerasTrainer( + distribution=keras.distribution.DataParallel(), + train_steps=5, + steps_per_eval=3, + steps_per_loop=2, + model_dir=self.create_tempdir().full_path, + continuous_eval_timeout=5, + legacy_format=False, + ) + experiment = core.Experiment(_KerasTask(), trainer) + core.run_experiment(experiment, core.Experiment.Mode.TRAIN) + core.run_experiment(experiment, core.Experiment.Mode.CONTINUOUS_EVAL) + if __name__ == "__main__": absltest.main() diff --git a/recml/core/utils/keras_utils.py b/recml/core/utils/keras_utils.py index 1ee6b39..bd54e86 100644 --- a/recml/core/utils/keras_utils.py +++ b/recml/core/utils/keras_utils.py @@ -13,7 +13,9 @@ # limitations under the License. """Utilities for training Keras models on Jax backend.""" -from collections.abc import Mapping +from collections.abc import Mapping, Sequence +import json +import re from typing import Any from absl import logging @@ -34,6 +36,187 @@ def _assert_variables_built(model: keras.Model): ) +class PathTrie: + """A class to create a Prefix Tree (Trie) from file paths.""" + + def __init__(self): + """Initializes the Trie with an empty dictionary as its root. + + Also initializes data structures for the context-aware re-indexing logic. 
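+
+    For example (illustrative), after inserting 'aaa/ee_0' and then
+    'aaa/ee_1', `reindex_map` contains {'aaa/ee_0': 'ee', 'aaa/ee_1': 'ee_1'}
+    and `base_name_counts` contains {'aaa': 1, 'aaa/ee': 2}.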
+    """
+    self.root = {}
+    # Stores a mapping from an original full path part (e.g., 'aaa/bcd/ee_1')
+    # to its new processed name (e.g., 'ee_1'). This acts as a cache.
+    self.reindex_map = {}
+    # Stores the next available index for a given full base name key, e.g.,
+    # {'aaa/bcd/ee': 2} means the next part with this prefix and base name
+    # will become 'ee_2'.
+    self.base_name_counts = {}
+
+  def reset(self):
+    """Resets the Trie to its initial state."""
+    self.root = {}
+    self.reindex_map = {}
+    self.base_name_counts = {}
+
+  def insert(self, path: str):
+    """Inserts a complete path string into the trie.
+
+    It processes each path component, re-indexing it based on its context,
+    which is defined by the full path of processed parts leading up to it.
+
+    Args:
+      path: A string representing the path, with components separated by '/'.
+
+    Returns:
+      The reindexed path.
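+
+    Example (illustrative; mirrors the expectations in the test_path_trie
+    unit test below):
+
+      trie = PathTrie()
+      trie.insert("aaa/bcd/ee_0/xyx")  # Returns "aaa/bcd/ee/xyx".
+      trie.insert("aaa/bcd/ee_1/zzz")  # Returns "aaa/bcd/ee_1/zzz".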
+    """
+    node = self.root
+    original_parts = path.strip("/").split("/")
+    processed_prefix_parts = []  # The path built from *new* names.
+
+    for i, part in enumerate(original_parts):
+      # The key for memoization must be unique to the original part in its
+      # context. The full original path up to and including this part serves
+      # as this unique key.
+      original_path_to_part = "/".join(original_parts[: i + 1])
+
+      # Check if we've processed this exact part in this exact context before.
+      if original_path_to_part in self.reindex_map:
+        processed_part = self.reindex_map[original_path_to_part]
+      else:
+        # This is a new, unique part we haven't seen before.
+
+        # Determine the local base name by stripping any '_<digits>' suffix.
+        match = re.match(r"(.+)_(\d+)$", part)
+        local_base_name = match.group(1) if match else part
+
+        # The key for counting is the path of processed parts leading to
+        # the current part's base name. This provides the context.
+        full_base_name_key = "/".join(
+            processed_prefix_parts + [local_base_name]
+        )
+
+        # Get the current count for this base name to determine its new index.
+        count = self.base_name_counts.get(full_base_name_key, 0)
+
+        # The first occurrence (count=0) of a base name keeps the bare base
+        # name; subsequent parts with the same base name context get a
+        # numeric suffix.
+        if count == 0:
+          processed_part = local_base_name
+        else:
+          processed_part = f"{local_base_name}_{count}"
+
+        # Store the result in our memoization map for the original path part.
+        self.reindex_map[original_path_to_part] = processed_part
+
+        # Increment the count for this base name context.
+        self.base_name_counts[full_base_name_key] = count + 1
+
+      # Move down the tree using the newly processed part name.
+      node = node.setdefault(processed_part, {})
+      # Add the processed part to our prefix for the next iteration's context.
+      processed_prefix_parts.append(processed_part)
+    return "/".join(processed_prefix_parts)
+
+  def get_all_paths(self) -> list[str]:
+    """Traverses the trie to reconstruct and return all full paths.
+
+    Returns:
+      A list of strings, where each string is a reconstructed path.
+    """
+    all_paths = []
+    self._traverse_paths(self.root, [], all_paths)
+    return all_paths
+
+  def _traverse_paths(
+      self,
+      node: dict[str, Any],
+      current_path_parts: list[Any],
+      all_paths: list[Any],
+  ):
+    """Recursively performs a depth-first traversal of the trie.
+
+    Args:
+      node: The current node (dictionary) in the trie.
+      current_path_parts: The list of parts forming the path to the current
+        node.
+      all_paths: The master list to which complete paths are added.
+    """
+    # Iterate through all children of the current node.
+    for part, child_node in node.items():
+      # Add the current part to our path tracker.
+      current_path_parts.append(part)
+
+      # If the child node is empty, it signifies the end of a complete path.
+      if not child_node:
+        all_paths.append("/".join(current_path_parts))
+      else:
+        # If there are more parts, recurse deeper into the tree.
+        self._traverse_paths(child_node, current_path_parts, all_paths)
+
+      # Backtrack: remove the current part to explore other branches
+      # (siblings).
+      current_path_parts.pop()
+
+  def __str__(self):
+    """Returns a string representation of the trie in readable JSON format."""
+    return json.dumps(self.root, indent=4)
+
+
+def _get_jax_state_with_keys(
+    model: keras.Model,
+    trainable_variables: bool = False,
+    non_trainable_variables: bool = False,
+    optimizer_variables: bool = False,
+    metrics_variables: bool = False,
+    purge_model_variables: bool = False,
+) -> tuple[Sequence[Mapping[str, jax.Array]], Sequence[Mapping[str, int]]]:
+  """Returns the model state as dictionaries keyed by re-indexed variable path.
+
+  Adapted from `keras.Model._get_jax_state`, but returns each requested
+  variable group as a dictionary keyed by its `PathTrie`-re-indexed variable
+  path instead of as a flat list.
+
+  Args:
+    model: The Keras model to get the variables from.
+    trainable_variables: Whether to get the trainable variables.
+    non_trainable_variables: Whether to get the non-trainable variables.
+    optimizer_variables: Whether to get the optimizer variables.
+    metrics_variables: Whether to get the metrics variables.
+    purge_model_variables: Whether to purge the model variables.
+
+  Returns:
+    A list with one dictionary per requested variable group, mapping
+    re-indexed variable paths to variable values, and a parallel list of
+    dictionaries mapping those same paths to each variable's positional index
+    within its group.
+  """
+  variable_list = []
+  variable_index_list = []
+  variables_path_trie = PathTrie()
+  for include_variables, variables in [
+      (trainable_variables, model.trainable_variables),
+      (non_trainable_variables, model.non_trainable_variables),
+      (optimizer_variables, model.optimizer.variables),
+      (metrics_variables, model.metrics_variables),
+  ]:
+    if include_variables:
+      index = 0
+      variable_dict = {}
+      variable_index = {}
+      for v in variables:
+        variable_key = variables_path_trie.insert(v.path)
+        variable_dict[variable_key] = v.value
+        variable_index[variable_key] = index
+        index += 1
+      variable_list.append(variable_dict)
+      variable_index_list.append(variable_index)
+  if purge_model_variables:
+    model._purge_model_variables(  # pylint: disable=protected-access
+        trainable_variables=trainable_variables,
+        non_trainable_variables=non_trainable_variables,
+        optimizer_variables=optimizer_variables,
+        metrics_variables=metrics_variables,
+    )
+  return variable_list, variable_index_list
+
+
 class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
   """An Orbax checkpoint manager for Keras 3."""
 
@@ -42,6 +225,7 @@ def __init__(
       self,
       checkpoint_dir: str,
       max_to_keep: int = 5,
      save_interval_epochs: int = 1,
+      legacy_format: bool = True,
   ):
     """Initializes a KerasOrbaxCheckpointManager.
 
@@ -49,6 +233,7 @@
       checkpoint_dir: The directory to save checkpoints to.
       max_to_keep: The maximum number of checkpoints to keep.
       save_interval_epochs: The interval (in epochs) to save checkpoints.
+      legacy_format: Whether to use the legacy checkpoint format.
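+
+    Example (illustrative, following the usage in keras_utils_test.py;
+    `model` is assumed to be a built Keras 3 model on the Jax backend):
+
+      manager = KerasOrbaxCheckpointManager(
+          checkpoint_dir, legacy_format=False
+      )
+      manager.save_model_variables(model, epoch=1)
+      manager.wait_until_finished()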
""" super().__init__( directory=checkpoint_dir, @@ -58,6 +243,7 @@ def __init__( max_to_keep=max_to_keep, ), ) + self._legacy_format = legacy_format def save_model_variables( self, @@ -66,28 +252,54 @@ def save_model_variables( logs: Mapping[str, Any] | None = None, ): _assert_variables_built(model) - state = model._get_jax_state( # pylint: disable=protected-access - trainable_variables=True, - non_trainable_variables=True, - optimizer_variables=True, - # metrics_variables is default to False because we don't want to save - # metrics variables in the checkpoint. The metrics varibles are reset - # after each epoch. We need to recalculate them after restoring from - # the checkpoint. - metrics_variables=False, - ) logging.info("Writing checkpoint for epoch %s...", epoch) + if self._legacy_format: + state = model._get_jax_state( # pylint: disable=protected-access + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + # metrics_variables is default to False because we don't want to save + # metrics variables in the checkpoint. The metrics varibles are reset + # after each epoch. We need to recalculate them after restoring from + # the checkpoint. + metrics_variables=False, + ) + else: + state, _ = _get_jax_state_with_keys( + model, + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + ) + + logging.info( + "Save Checkpointing state with keys '%s'", [v.keys() for v in state] + ) self.save(step=epoch, items=state, metrics=logs) def restore_model_variables(self, model: keras.Model, epoch: int): _assert_variables_built(model) - state = model._get_jax_state( # pylint: disable=protected-access - trainable_variables=True, - non_trainable_variables=True, - optimizer_variables=True, - purge_model_variables=True, - ) + if self._legacy_format: + state = model._get_jax_state( # pylint: disable=protected-access + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + purge_model_variables=True, + ) + variable_index_list = None + else: + state, variable_index_list = _get_jax_state_with_keys( + model, + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=False, + purge_model_variables=True, + ) + logging.info( + "Restore Checkpointing state with keys '%s'", + [v.keys() for v in state], + ) logging.info("Restoring checkpoint for epoch %s...", epoch) model._jax_state_synced = False # pylint: disable=protected-access @@ -124,16 +336,21 @@ def _restore(value): ) logging.info("Restored checkpoint for epoch %s.", epoch) model._initial_epoch = epoch + 1 # pylint: disable=protected-access + if not self._legacy_format and variable_index_list is not None: + for i in range(len(variable_index_list)): + restored_state[i] = [ + restored_state[i][k] for k, _ in variable_index_list[i].items() + ] ( trainable_variables, non_trainable_variables, - optimizer_variables, - ) = restored_state + ) = restored_state[:2] model._jax_state = { # pylint: disable=protected-access "trainable_variables": trainable_variables, "non_trainable_variables": non_trainable_variables, - "optimizer_variables": optimizer_variables, } + if self._legacy_format: + model._jax_state["optimizer_variables"] = restored_state[2] # pylint: disable=protected-access model.jax_state_sync() @@ -185,6 +402,8 @@ def restore_keras_model( restore_optimizer_vars: bool = True, restore_steps: bool = True, restore_iterations: bool = True, + legacy_format: bool = True, + transforms: Mapping[str, Any] | None = None, ): """Restores a Keras 3 Jax 
@@ -203,8 +422,13 @@
     restore_iterations: Whether to restore the model's iterations. If `True`
       then the model will continue training from the iteration the checkpoint
       was saved at. This is an optimizer variable used for controlling the
-      learning rate schedule. This is not supported if restore_optimizer_vars
-      is `False`.
+      learning rate schedule. This is not supported if restore_optimizer_vars is
+      `False`.
+    legacy_format: Whether to use the legacy checkpoint format when restoring
+      the model.
+    transforms: An optional mapping from variable-key regex patterns to Orbax
+      transforms (e.g. `ocp.Transform`) applied while restoring. If None, the
+      model is restored with the same variable structure as the checkpoint.
 
   Raises:
     FileNotFoundError: If no checkpoints are found in the checkpoint directory.
@@ -216,6 +440,9 @@
         "This function only supports restoring a Keras 3 Jax backend model from"
         " a TF Saved Model."
     )
+  if not legacy_format:
+    # TODO(zixiangzhou): Remove this once the optimizer format is supported.
+    restore_optimizer_vars = False
 
   _assert_variables_built(model)
 
@@ -241,12 +468,25 @@
       ORBAX_CHECKPOINT_DEFAULT_KEY: ocp.handlers.PyTreeCheckpointHandler()
   })
   )
-  state = model._get_jax_state(  # pylint: disable=protected-access
-      trainable_variables=True,
-      non_trainable_variables=True,
-      optimizer_variables=restore_optimizer_vars,
-      purge_model_variables=True,
-  )
+  if legacy_format:
+    state = model._get_jax_state(  # pylint: disable=protected-access
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=restore_optimizer_vars,
+        purge_model_variables=True,
+    )
+    variable_index_list = None
+  else:
+    state, variable_index_list = _get_jax_state_with_keys(
+        model,
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=restore_optimizer_vars,
+        purge_model_variables=True,
+    )
+    logging.info(
+        "Restoring checkpoint state with keys '%s'", [v.keys() for v in state]
+    )
   model._jax_state_synced = False  # pylint: disable=protected-access
 
   # Delete the state to save memory.
@@ -270,13 +510,18 @@
         args=ocp.args.Composite(**{
             ORBAX_CHECKPOINT_DEFAULT_KEY: ocp.args.PyTreeRestore(
                 item=abstract_state,
-                transforms={},
+                transforms={} if transforms is None else transforms,
                 restore_args=ocp.checkpoint_utils.construct_restore_args(
                     abstract_state
                 ),
            ),
         }),
     )[ORBAX_CHECKPOINT_DEFAULT_KEY]
+  if not legacy_format and variable_index_list is not None:
+    for i in range(len(variable_index_list)):
+      restored_state[i] = [
+          restored_state[i][k] for k, _ in variable_index_list[i].items()
+      ]
   (
       trainable_variables,
       non_trainable_variables,
@@ -287,7 +532,9 @@
   }
   if restore_optimizer_vars:
     optimizer_variables = restored_state[2]
-    model._jax_state["optimizer_variables"] = optimizer_variables  # pylint: disable=protected-access
+    model._jax_state["optimizer_variables"] = (  # pylint: disable=protected-access
+        optimizer_variables
+    )
   model.jax_state_sync()
   if restore_steps:
     model._initial_epoch = step + 1  # pylint: disable=protected-access
diff --git a/recml/core/utils/keras_utils_test.py b/recml/core/utils/keras_utils_test.py
index 010707a..0582761 100644
--- a/recml/core/utils/keras_utils_test.py
+++ b/recml/core/utils/keras_utils_test.py
@@ -14,6 +14,7 @@
 """Tests or utilities."""
 
 from collections.abc import Sequence
+import logging
 
 from absl import flags
 from absl.testing import absltest
@@ -23,6 +24,7 @@
 import keras
 import keras_hub
 import numpy as np
+import orbax.checkpoint as ocp
 from recml.core.utils import keras_utils
 
 _LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay(
@@ -68,31 +70,66 @@ def setUp(self):
       {
           "testcase_name": "single_core",
           "data_parallel": False,
          "restore_with_checkpointer": True,
+          "legacy_format": False,
       },
       {
           "testcase_name": "data_parallel",
           "data_parallel": True,
          "restore_with_checkpointer": True,
+          "legacy_format": False,
       },
       {
           "testcase_name": "restore_without_checkpointer_data_parallel",
           "data_parallel": True,
           "restore_with_checkpointer": False,
+          "legacy_format": False,
       },
       {
           "testcase_name": "restore_without_checkpointer_model_parallel",
           "data_parallel": False,
           "restore_with_checkpointer": False,
+          "legacy_format": False,
+      },
+      {
+          "testcase_name": "single_core_legacy_format",
+          "data_parallel": False,
+          "restore_with_checkpointer": True,
+          "legacy_format": True,
+      },
+      {
+          "testcase_name": "data_parallel_legacy_format",
+          "data_parallel": True,
+          "restore_with_checkpointer": True,
+          "legacy_format": True,
+      },
+      {
+          "testcase_name": (
+              "restore_without_checkpointer_data_parallel_legacy_format"
+          ),
+          "data_parallel": True,
+          "restore_with_checkpointer": False,
+          "legacy_format": True,
+      },
+      {
+          "testcase_name": (
+              "restore_without_checkpointer_model_parallel_legacy_format"
+          ),
+          "data_parallel": False,
+          "restore_with_checkpointer": False,
+          "legacy_format": True,
       },
   )
   def test_keras_orbax_checkpointer(
-      self, data_parallel: bool, restore_with_checkpointer: bool
+      self,
+      data_parallel: bool,
+      restore_with_checkpointer: bool,
+      legacy_format: bool,
   ):
     if data_parallel:
       keras.distribution.set_distribution(keras.distribution.DataParallel())
     checkpoint_dir = self.create_tempdir().full_path
     checkpointer = keras_utils.KerasOrbaxCheckpointManager(
-        checkpoint_dir, max_to_keep=5
+        checkpoint_dir, max_to_keep=5, legacy_format=legacy_format
     )
     epoch = 1
     dummy_inputs = {
@@ -142,7 +179,9 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model:
     if restore_with_checkpointer:
       checkpointer.restore_model_variables(bert_pretrainer, epoch)
     else:
-      keras_utils.restore_keras_model(bert_pretrainer, checkpoint_dir)
+      keras_utils.restore_keras_model(
+          bert_pretrainer, checkpoint_dir, legacy_format=legacy_format
+      )
     restored_state = (
         [v.value for v in bert_pretrainer.trainable_variables],
         [v.value for v in bert_pretrainer.non_trainable_variables],
@@ -160,6 +199,40 @@ def _close(a: jax.Array, b: jax.Array):
     # Ensures predictions are identical.
     self.assertTrue(_close(preds, preds_after_restoration))
 
+  def test_path_trie(self):
+    path_trie = keras_utils.PathTrie()
+    input_list_1 = [
+        "seed_generator_31/seed_generator_state",
+        "transformer_layer_0/self_attention_layer/seed_generator_32/seed_generator_state",
+        "transformer_layer_0/seed_generator_33/seed_generator_state",
+        "transformer_layer_0/seed_generator_34/seed_generator_state",
+        "aaa/bcd/ee_0/xyx",
+        "aaa/bcd/ee_0/ee_1",
+        "aaa/bcd/ee_0/ee_2",
+        "aaa/bcd/ee_1/zzz",
+        "aaa/cde/ee_2/fgh",
+        "bbb/data/images/img_01.jpg",
+        "bbb/data/images/img_02.jpg",
+        "bbb/data/text/doc.txt",
+    ]
+    for path in input_list_1:
+      path_trie.insert(path)
+    expected_output = [
+        "seed_generator/seed_generator_state",
+        "transformer_layer/self_attention_layer/seed_generator/seed_generator_state",
+        "transformer_layer/seed_generator/seed_generator_state",
+        "transformer_layer/seed_generator_1/seed_generator_state",
+        "aaa/bcd/ee/xyx",
+        "aaa/bcd/ee/ee",
+        "aaa/bcd/ee/ee_1",
+        "aaa/bcd/ee_1/zzz",
+        "aaa/cde/ee/fgh",
+        "bbb/data/images/img_01.jpg",
+        "bbb/data/images/img_02.jpg",
+        "bbb/data/text/doc.txt",
+    ]
+    self.assertEqual(path_trie.get_all_paths(), expected_output)
+
   def test_restore_keras_model_error_cases(self):
     checkpoint_dir = self.create_tempdir().full_path
     checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir)
@@ -355,6 +428,55 @@ def test_restore_keras_model_with_different_options(
         target_bert_pretrainer._initial_epoch, expected_initial_epoch
     )
 
+  def test_restore_ckpt_with_transform(self):
+    checkpoint_dir = self.create_tempdir().full_path
+    checkpointer = keras_utils.KerasOrbaxCheckpointManager(
+        checkpoint_dir, legacy_format=False
+    )
+    epoch = 2
+    dummy_inputs = {
+        "token_ids": jax.random.randint(
+            jax.random.key(0), (64, 128), minval=0, maxval=50_000
+        ),
+        "segment_ids": jax.random.randint(
+            jax.random.key(0), (64, 128), minval=0, maxval=7
+        ),
+        "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)),
+        "mask_positions": jax.random.randint(
+            jax.random.key(0), (64, 20), minval=0, maxval=128
+        ),
+    }
+
+    bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs))
+    checkpointer.save_model_variables(bert_pretrainer, epoch)
+    checkpointer.wait_until_finished()
+    target_bert_pretrainer = _create_model(
+        jax.tree.map(jnp.shape, dummy_inputs)
+    )
+    keras_utils.restore_keras_model(
+        target_bert_pretrainer,
+        checkpoint_dir,
+        legacy_format=False,
+        transforms={
+            r"(.*)transformer_layer_2(.*)": ocp.Transform(
+                original_key=r"\1transformer_layer_1\2"
+            ),
+        },
+    )
+    target_variables_list = []
+    source_variables_list = []
+    for v in target_bert_pretrainer.trainable_variables:
+      if "transformer_layer_2" in v.path:
+        target_variables_list.append(v)
+    for v in bert_pretrainer.trainable_variables:
+      if "transformer_layer_1" in v.path:
+        source_variables_list.append(v)
+    for v1, v2 in zip(target_variables_list, source_variables_list):
+      np.testing.assert_almost_equal(
+          keras.ops.convert_to_numpy(v1.value),
+          keras.ops.convert_to_numpy(v2.value),
+      )
+
 
 if __name__ == "__main__":
   absltest.main()
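
A minimal end-to-end sketch of the non-legacy checkpoint path introduced above, assuming `model` is a built Keras 3 model on the Jax backend and `/tmp/ckpt` is writable; the layer names in the transform regex are illustrative:

    import orbax.checkpoint as ocp
    from recml.core.utils import keras_utils

    manager = keras_utils.KerasOrbaxCheckpointManager(
        "/tmp/ckpt", legacy_format=False
    )
    manager.save_model_variables(model, epoch=1)
    manager.wait_until_finished()

    # Restore the keyed checkpoint, remapping one layer's weights onto another
    # via an Orbax transform. With legacy_format=False the optimizer state is
    # not restored yet (see the TODO in restore_keras_model).
    keras_utils.restore_keras_model(
        model,
        "/tmp/ckpt",
        legacy_format=False,
        transforms={
            r"(.*)transformer_layer_1(.*)": ocp.Transform(
                original_key=r"\1transformer_layer_0\2"
            ),
        },
    )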