From 8767bfd138a0b1df9a07d63bd2fcd597ad576dfa Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Wed, 26 Nov 2025 14:24:59 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 837257237 --- checkpoint/CHANGELOG.md | 13 +++ .../standard_checkpoint_handler_test_utils.py | 42 ++++++- .../orbax/checkpoint/_src/path/atomicity.py | 4 +- .../orbax/checkpoint/_src/path/deleter.py | 104 +++++++----------- .../checkpoint/_src/path/deleter_test.py | 50 +++++++++ .../orbax/checkpoint/_src/path/gcs_utils.py | 8 +- checkpoint/orbax/checkpoint/_src/path/step.py | 9 +- .../_src/serialization/jax_array_handlers.py | 24 ++-- .../_src/serialization/tensorstore_utils.py | 59 ++++++++++ .../serialization/tensorstore_utils_test.py | 54 +++++++++ .../_src/serialization/type_handlers.py | 60 ++++------ .../checkpoint_manager_benchmark.py | 9 +- .../checkpoint_manager_perf_benchmark.py | 21 +--- .../orbax/checkpoint/checkpoint_manager.py | 5 - .../orbax/checkpoint/checkpoint_utils.py | 8 +- .../local_checkpoint_data_debugging.py | 3 +- .../experimental/v1/_src/context/options.py | 9 +- .../v1/_src/handlers/compatibility.py | 13 ++- .../v1/_src/handlers/composite_handler.py | 3 +- .../v1/_src/layout/orbax_layout.py | 4 +- .../v1/_src/layout/pytorch_layout.py | 8 +- .../experimental/v1/_src/layout/registry.py | 5 +- .../experimental/v1/_src/loading/loading.py | 13 +-- .../experimental/v1/_src/metadata/loading.py | 14 +-- .../experimental/v1/_src/partial/merging.py | 4 +- .../experimental/v1/_src/partial/path.py | 10 +- .../experimental/v1/_src/partial/saving.py | 8 +- .../experimental/v1/_src/path/types.py | 3 +- .../experimental/v1/_src/saving/execution.py | 5 +- .../v1/_src/serialization/compatibility.py | 3 +- checkpoint/orbax/checkpoint/version.py | 2 +- export/orbax/export/constants.py | 9 +- export/orbax/export/jax_module.py | 10 ++ export/orbax/export/modules/obm_module.py | 45 ++++---- .../orbax/export/modules/obm_module_test.py | 26 +++++ export/orbax/export/obm_configs.py | 19 ++++ export/orbax/export/serving_config.py | 12 ++ ...nist_oex_orchestration_pipelines.textproto | 1 + export/orbax/export/utils.py | 38 +++++++ model/orbax/experimental/model/cli/README.md | 1 + .../model/core/python/compile_options_util.py | 46 ++++++-- .../core/python/compile_options_util_test.py | 77 +++++-------- .../orbax/experimental/model/jax2obm/utils.py | 2 +- 43 files changed, 571 insertions(+), 292 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 78cf33006..5e8ba0915 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,11 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.11.30] - 2025-11-26 + +### Fixed + +- Roll back earlier change altering metadata format, which was observed to cause +breakages. + +## [0.11.29] - 2025-11-25 + ### Fixed - Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to include an arbitrary `step_prefix` with any character(s) such as underscores. - Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`. +- Fix using jax.eval_shape with StandardRestore ### Changed @@ -19,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 state, avoiding additional I/O to re-read metadata. - add `support_format` to utils.to_shape_dtype_struct() - Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`. 
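+  (Usage sketch, mirroring the benchmark change later in this patch:
+  `ocp.pathways.register_type_handlers(checkpointing_impl=...)`, where the
+  impl comes from `ocp.pathways.CheckpointingImpl.from_options(...)`.)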
+- Replace usage of `get_json_tpec_read` and delegate functionality to new + function `build_array_read_spec` which constructs and returns an + `ArrayReadSpec`. ## [0.11.28] - 2025-11-06 diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py index 5c01be41f..9e3bf5d46 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py @@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler: return StandardCheckpointHandler() def test_with_random_keys(self): + # TODO(b/393160483) investigate Pathways remote Python support for + # random.keys. if utils.is_pathways_backend(): - self.skipTest('Pathways does not support random keys checkpoint.') + self.skipTest( + 'Disabled on Pathways because random keys are not supported by' + ' remote Python.' + ) def create_random_keys(seed): duplicated_sharding = jax.sharding.NamedSharding( @@ -559,3 +564,38 @@ def create_random_keys(seed): args=self.restore_args_cls(abstract_tree), ) test_utils.assert_tree_equal(self, self.pytree, restored) + + def test_save_restore_random_keys_with_jax_eval_shape(self): + # TODO(b/393160483) investigate Pathways remote Python support for + # random.keys. + if utils.is_pathways_backend(): + self.skipTest( + 'Disabled on Pathways because random keys are not supported by' + ' remote Python.' + ) + + mesh = jax.sharding.Mesh(jax.devices(), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + @functools.partial( + jax.jit, + in_shardings=sharding, + out_shardings=sharding, + ) + def sharded_create_state_fn(root_key): + return dict( + matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + rngkey=jax.random.fold_in(root_key, 42), + ) + + pytree = sharded_create_state_fn(jax.random.key(0)) + abstract_pytree = jax.eval_shape( + sharded_create_state_fn, jax.random.key(0) + ) + + self.handler.save(self.directory, args=self.save_args_cls(pytree)) + + restored = self.handler.restore( + self.directory, args=self.restore_args_cls(abstract_pytree) + ) + test_utils.assert_tree_equal(self, pytree, restored) diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py index 8ffa2c7ca..4f76da675 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -161,13 +161,13 @@ async def _create_tmp_directory( def _get_tmp_directory(final_path: epath.Path) -> epath.Path: # Path may not be completely unique if a preemption occurs. We rely on the # existing tmp directory being deleted elsewhere. 
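+  # e.g. a final path 'ckpts/10' gets the tmp path 'ckpts/10' + TMP_DIR_SUFFIX,
+  # a sibling under the same parent, so finalization is a same-directory
+  # rename.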
- return epath.Path(final_path.parent) / (final_path.name + TMP_DIR_SUFFIX) + return final_path.parent / (final_path.name + TMP_DIR_SUFFIX) def _get_final_directory(tmp_path: epath.Path) -> epath.Path: if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1: raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".') - return epath.Path(tmp_path.parent) / tmp_path.name[:suffix_idx] + return tmp_path.parent / tmp_path.name[:suffix_idx] class TemporaryPathBase(atomicity_types.TemporaryPath): diff --git a/checkpoint/orbax/checkpoint/_src/path/deleter.py b/checkpoint/orbax/checkpoint/_src/path/deleter.py index b88a1dc58..d24483653 100644 --- a/checkpoint/orbax/checkpoint/_src/path/deleter.py +++ b/checkpoint/orbax/checkpoint/_src/path/deleter.py @@ -21,6 +21,7 @@ import threading import time from typing import Optional, Protocol, Sequence +from urllib import parse from absl import logging from etils import epath @@ -31,6 +32,7 @@ from orbax.checkpoint._src.path import step as step_lib +urlparse = parse.urlparse PurePosixPath = pathlib.PurePosixPath _THREADED_DELETE_DURATION = ( @@ -183,7 +185,9 @@ def delete(self, step: int) -> None: # Attempt to rename using GCS HNS API if configured. if self._todelete_full_path is not None: if gcs_utils.is_gcs_path(self._directory): - self._rename_gcs_step_with_hns(step, delete_target) + # This is recommended for GCS buckets with HNS enabled and requires + # `_todelete_full_path` to be specified. + self._gcs_rename_step(step, delete_target) else: raise NotImplementedError() # Attempt to rename to local subdirectory using `todelete_subdir` @@ -204,88 +208,56 @@ def delete(self, step: int) -> None: time.time() - start, ) - def _rename_gcs_step_with_hns( + def _gcs_rename_step( self, step: int, delete_target: epath.Path ): - """Renames a GCS directory using the Storage Control API. + """Renames a GCS directory to a temporary location for deletion. + + This method renames the directory using the + underlying `tf.io.gfile.rename` method. This underlying + implementation will automatically detect if the bucket is HNS-enabled + and use a fast atomic rename, or fall back to a legacy + copy/delete rename if it is not. Args: step: The checkpoint step number. delete_target: The path to the directory to be renamed. - - Raises: - ValueError: If the GCS bucket is not HNS-enabled, as this is a - hard requirement for this operation. """ - logging.info( - 'Condition: GCS path with `todelete_full_path` set. Checking for HNS.' - ) - bucket_name, _ = gcs_utils.parse_gcs_path(self._directory) - if not gcs_utils.is_hierarchical_namespace_enabled(self._directory): - raise ValueError( - f'Bucket "{bucket_name}" does not have Hierarchical Namespace' - ' enabled, which is required when _todelete_full_path is set.' - ) - - logging.info('HNS bucket detected. Attempting to rename step %d.', step) - # pylint: disable=g-import-not-at-top - from google.api_core import exceptions as google_exceptions # pytype: disable=import-error try: - from google.cloud import storage_control_v2 # pytype: disable=import-error - import google.auth # pytype: disable=import-error - - # Use default credentials, but without a quota project to avoid - # quota issues with this API. - credentials, _ = google.auth.default() - creds_without_quota_project = credentials.with_quota_project(None) - client = storage_control_v2.StorageControlClient( - credentials=creds_without_quota_project - ) - # Destination parent is the absolute path to the bucket. 
- destination_parent_dir_str = ( + # Get the bucket name from the source path + bucket_name = urlparse(str(delete_target)).netloc + if not bucket_name: + raise ValueError( + f'Could not parse bucket name from path: {delete_target}' + ) + + # Construct the destination path inside the `_todelete_full_path` dir. + destination_parent_path = epath.Path( f'gs://{bucket_name}/{self._todelete_full_path}' ) - destination_parent_path = PurePosixPath(destination_parent_dir_str) - logging.info( - 'Ensuring destination parent folder exists via HNS API: %s', - destination_parent_dir_str, - ) - try: - parent_folder_id = str( - destination_parent_path.relative_to(f'gs://{bucket_name}') - ) - bucket_resource_name = f'projects/_/buckets/{bucket_name}' - client.create_folder( - request=storage_control_v2.CreateFolderRequest( - parent=bucket_resource_name, - folder_id=parent_folder_id, - recursive=True, - ) - ) - logging.info('HNS parent folder creation request sent.') - except google_exceptions.AlreadyExists: - logging.info('HNS parent folder already exists, proceeding.') + destination_parent_path.mkdir(parents=True, exist_ok=True) + # Create a unique name for the destination to avoid collisions. now = datetime.datetime.now() timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f') new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}' dest_path = destination_parent_path / new_name_with_timestamp - source_folder_id = str(delete_target.relative_to(f'gs://{bucket_name}')) - destination_folder_id = str(dest_path.relative_to(f'gs://{bucket_name}')) - source_resource_name = ( - f'projects/_/buckets/{bucket_name}/folders/{source_folder_id}' - ) - logging.info('Rename API call: Source: %s', source_resource_name) - logging.info('Rename API call: Destination ID: %s', destination_folder_id) - request = storage_control_v2.RenameFolderRequest( - name=source_resource_name, - destination_folder_id=destination_folder_id, + + logging.info( + 'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`', + delete_target, + dest_path, ) - op = client.rename_folder(request=request) - op.result() + + # Call the high-level rename method. + # This will be fast on HNS and slow (but functional) on non-HNS. + delete_target.rename(dest_path) logging.info('Successfully renamed step %d to %s', step, dest_path) - except google_exceptions.GoogleAPIError as e: - logging.error('HNS rename failed for step %d. Error: %s', step, e) + + except Exception as e: + message = f'Rename failed for step {step}. Error: {e}' + logging.error(message) + raise RuntimeError(message) from e def _rename_step_to_subdir(self, step: int, delete_target: epath.Path): """Renames a step directory to its corresponding todelete_subdir.""" diff --git a/checkpoint/orbax/checkpoint/_src/path/deleter_test.py b/checkpoint/orbax/checkpoint/_src/path/deleter_test.py index 1702a58f4..977744cb1 100644 --- a/checkpoint/orbax/checkpoint/_src/path/deleter_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/deleter_test.py @@ -13,6 +13,10 @@ # limitations under the License. 
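+# The rename-based deletion flow tested below can be sketched as follows
+# (bucket and path names are hypothetical):
+#
+#   deleter = deleter_lib.StandardCheckpointDeleter(
+#       directory=epath.Path('gs://my-bucket/ckpts'),
+#       name_format=step_lib.standard_name_format(),
+#       primary_host=None,
+#       todelete_subdir=None,
+#       todelete_full_path='trash',
+#   )
+#   deleter.delete(step=10)
+#   # Instead of deleting in place, 'gs://my-bucket/ckpts/10' is renamed to
+#   # 'gs://my-bucket/trash/10-<timestamp>' for later cleanup.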
"""To test Orbax in single-host setup.""" + +import unittest +from unittest import mock + from absl.testing import absltest from absl.testing import parameterized from etils import epath @@ -64,5 +68,51 @@ def test_checkpoint_deleter_delete( deleter.close() +class GcsRenameTest(unittest.TestCase): + + @mock.patch('orbax.checkpoint._src.path.deleter.epath.Path') + def test_gcs_rename_logic_directly(self, mock_epath_constructor): + """Tests path construction and rename call logic.""" + standard_checkpoint_deleter = deleter_lib.StandardCheckpointDeleter + + deleter = standard_checkpoint_deleter( + directory=mock.MagicMock(), + name_format=step_lib.standard_name_format(), + primary_host=None, + todelete_subdir=None, + todelete_full_path='trash_bin', + enable_hns=False, + ) + # When epath.Path() is called inside the code, it returns this mock parent + mock_dest_parent = mock.MagicMock() + mock_epath_constructor.return_value = mock_dest_parent + + # When the code does (parent / child), return a specific final mock + mock_final_dest = mock.MagicMock() + mock_final_dest.__str__.return_value = 'gs://mocked/final/destination' + mock_dest_parent.__truediv__.return_value = mock_final_dest + + # Setup the "Source" Mock (The step being deleted) + mock_step_path = mock.MagicMock() + mock_step_path.__str__.return_value = 'gs://my-bucket/checkpoints/step_10' + mock_step_path.name = 'step_10' + + deleter._gcs_rename_step(step=10, delete_target=mock_step_path) + + # Verify mkdir was called on the destination parent. + mock_dest_parent.mkdir.assert_called_with(parents=True, exist_ok=True) + + # Verify the Parent Path string was constructed correctly + # The code does: epath.Path(f'gs://{bucket}/{todelete_full_path}') + (parent_path_arg,), _ = mock_epath_constructor.call_args + self.assertEqual(parent_path_arg, 'gs://my-bucket/trash_bin') + + # Verify the Child Filename was constructed correctly + (child_name_arg,), _ = mock_dest_parent.__truediv__.call_args + self.assertIn('step_10-', child_name_arg) + + # Verify the Rename was actually called + mock_step_path.rename.assert_called_with(mock_final_dest) + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py index 2fb12c963..aa64afeca 100644 --- a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py @@ -50,6 +50,12 @@ def get_bucket(bucket_name: str): def is_hierarchical_namespace_enabled(path: epath.PathLike) -> bool: """Return whether hierarchical namespace is enabled.""" + parsed = parse.urlparse(str(path)) + if parsed.scheme != 'gs': + return False bucket_name, _ = parse_gcs_path(path) bucket = get_bucket(bucket_name) - return bucket.hierarchical_namespace_enabled + return ( + hasattr(bucket, 'hierarchical_namespace_enabled') + and bucket.hierarchical_namespace_enabled + ) diff --git a/checkpoint/orbax/checkpoint/_src/path/step.py b/checkpoint/orbax/checkpoint/_src/path/step.py index 0e1bb44c6..282cced49 100644 --- a/checkpoint/orbax/checkpoint/_src/path/step.py +++ b/checkpoint/orbax/checkpoint/_src/path/step.py @@ -321,13 +321,11 @@ class _StandardNameFormat(NameFormat[Metadata]): single_host_load_and_broadcast: If True, the jax process=0 will list all steps and broadcast them to all other processes. NOTE: Ignored if jax backend is not multi controller. - enable_hns: Enables HNS-specific path logic. 
""" step_prefix: Optional[str] = None step_format_fixed_length: Optional[int] = None single_host_load_and_broadcast: bool = False - enable_hns: bool = False def __str__(self): return f'StandardNameFormat("{self.build_name(1234)}")' @@ -375,9 +373,7 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]: """Returns step paths under `base_path`.""" base_path = epath.Path(base_path) # _?<0 padding>?* - if self.enable_hns and gcs_utils.is_hierarchical_namespace_enabled( - base_path - ): + if gcs_utils.is_hierarchical_namespace_enabled(base_path): logging.vlog( 1, 'HNS enabled. Using GCS API to list step paths at %s', @@ -560,7 +556,6 @@ def standard_name_format( step_prefix: Optional[str] = None, step_format_fixed_length: Optional[int] = None, single_host_load_and_broadcast: bool = False, - enable_hns: bool = False, ) -> NameFormat[Metadata]: """Returns NameFormat for 'standard' steps for common Orbax use cases. @@ -580,13 +575,11 @@ def standard_name_format( single_host_load_and_broadcast: If True, the jax process=0 will list all steps and broadcast them to all other processes. NOTE: Ignored if jax backend is not multi controller. - enable_hns: Enables HNS-specific path logic. """ return _StandardNameFormat( step_prefix=step_prefix, step_format_fixed_length=step_format_fixed_length, single_host_load_and_broadcast=single_host_load_and_broadcast, - enable_hns=enable_hns, ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index e43a39118..40876a472 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -763,12 +763,13 @@ async def _async_deserialize( await _validate_non_ocdbt_files(infos, metadata_key) deserialize_ops = [] for info, arg, sharding in zip(infos, args, shardings): - tspec = ts_utils.get_json_tspec_read( + array_read_spec = ts_utils.build_array_read_spec( info, use_ocdbt=use_ocdbt, metadata_key=metadata_key, raise_array_data_missing_error=info.raise_array_data_missing_error, ) + tspec = array_read_spec.json tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg) # set dtype=None to deserialize for random keys @@ -939,19 +940,6 @@ def __init__( def has_dispatcher(self) -> bool: return self._dispatcher is not None - def _get_json_tspec_read( - self, - info: types.ParamInfo, - use_ocdbt: bool, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for reading.""" - return ts_utils.get_json_tspec_read( - info, - use_ocdbt=use_ocdbt, - metadata_key=self._metadata_key, - raise_array_data_missing_error=info.raise_array_data_missing_error, - ) - def typestr(self) -> str: return JAX_ARRAY_TYPE_STR @@ -968,7 +956,13 @@ async def metadata( for info in infos: # Use OCDBT flag from the existing checkpoint. 
use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json open_ops.append( ts.open(ts.Spec(tspec), open=True, context=info.ts_context) ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index 826c6c0e7..02a954aa8 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec( return array_tspec +class ArrayReadSpec: + """Full TensorStore spec for reading an array.""" + + def __init__( + self, + directory: str, + relative_array_filename: str, + use_zarr3: bool, + *, + use_ocdbt: bool, + metadata_key: str | None = None, + raise_array_data_missing_error: bool = True, + ): + """Builds a TensorStore spec for reading an array.""" + kvstore_tspec = build_kvstore_tspec( + directory, + name=relative_array_filename, + use_ocdbt=use_ocdbt, + process_id=None, + ) + + tspec = { + 'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2, + 'kvstore': kvstore_tspec, + 'recheck_cached_data': False, + 'recheck_cached_metadata': False, + # Raise error if data is missing. + 'fill_missing_data_reads': not raise_array_data_missing_error, + } + if metadata_key is not None: + tspec['metadata_key'] = metadata_key + self._json_spec = tspec + + @property + def json(self) -> JsonSpec: + """Spec to be used to open a TensorStore for reading the array.""" + return self._json_spec + + class ArrayWriteSpec: """Full TensorStore spec for writing an array.""" @@ -677,6 +716,26 @@ def get_json_tspec_write( return tspec +def build_array_read_spec( + info: types.ParamInfo, + *, + use_ocdbt: bool, + metadata_key: str | None = None, + raise_array_data_missing_error: bool = True, +) -> ArrayReadSpec: + """Gets ArrayReadSpec for reading.""" + if info.name is None or info.parent_dir is None: + raise ValueError('Must provide info.name and info.parent_dir.') + return ArrayReadSpec( + directory=info.parent_dir.as_posix(), + relative_array_filename=info.name, + use_zarr3=info.use_zarr3, + use_ocdbt=use_ocdbt, + metadata_key=metadata_key, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + + def build_array_write_spec( info: types.ParamInfo, arg: types.SaveArgs | None = None, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py index c052be72e..aa3420bfa 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py @@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self): self.assertTrue(ts_utils.is_remote_storage(nested_tspec)) +class BuildArrayTSpecForReadTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.directory = self.create_tempdir().full_path + self.param_name = 'params/a' + + self.array_read_spec_constructor = functools.partial( + ts_utils.ArrayReadSpec, + directory=self.directory, + relative_array_filename=self.param_name, + ) + + @parameterized.product( + use_zarr3=(True, False), + use_ocdbt=(True, False), + ) + def test_basic(self, use_zarr3: bool, use_ocdbt: bool): + tspec = 
self.array_read_spec_constructor( + use_zarr3=use_zarr3, + use_ocdbt=use_ocdbt, + ) + json_spec = tspec.json + self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr') + self.assertEqual( + json_spec['kvstore']['driver'], + 'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER, + ) + self.assertFalse(json_spec['recheck_cached_data']) + self.assertFalse(json_spec['recheck_cached_metadata']) + self.assertFalse(json_spec['fill_missing_data_reads']) + self.assertNotIn('metadata_key', json_spec) + + def test_metadata_key(self): + tspec = self.array_read_spec_constructor( + use_zarr3=False, + use_ocdbt=False, + metadata_key='custom_metadata', + ) + self.assertEqual(tspec.json['metadata_key'], 'custom_metadata') + + @parameterized.parameters(True, False) + def test_fill_missing_data_reads(self, raise_array_data_missing_error): + tspec = self.array_read_spec_constructor( + use_zarr3=False, + use_ocdbt=False, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + self.assertEqual( + tspec.json['fill_missing_data_reads'], + not raise_array_data_missing_error, + ) + + class GetTsContextTest(parameterized.TestCase): @parameterized.product( diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index 53a234d70..18668be86 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -77,39 +77,6 @@ def __init__( self._metadata_key = metadata_key self._override_ocdbt_process_id = ocdbt_process_id - def _get_array_write_spec( - self, - info: types.ParamInfo, - value: np.ndarray, - use_ocdbt: bool, - process_index: Optional[Union[int, str]] = None, - arg: Optional[types.SaveArgs] = None, - ) -> ts_utils.ArrayWriteSpec: - """Gets ArrayWriteSpec for writing.""" - return ts_utils.build_array_write_spec( - info=info, - arg=arg, - global_shape=value.shape, - local_shape=value.shape, - dtype=value.dtype, - use_ocdbt=use_ocdbt, - process_index=process_index, - metadata_key=self._metadata_key, - ) - - def _get_json_tspec_read( - self, - info: types.ParamInfo, - use_ocdbt: bool, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for reading.""" - return ts_utils.get_json_tspec_read( - info, - use_ocdbt=use_ocdbt, - metadata_key=self._metadata_key, - raise_array_data_missing_error=info.raise_array_data_missing_error, - ) - def typestr(self) -> str: return 'np.ndarray' @@ -120,7 +87,13 @@ async def metadata( for info in infos: # Use OCDBT flag from the existing checkpoint. 
use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json open_ops.append( ts.open(ts.Spec(tspec), open=True, context=info.ts_context) ) @@ -149,15 +122,18 @@ async def _background_serialize( """Serializes numpy arrays in a background thread.""" write_coros = [] for value, info, arg in zip(values, infos, args): - array_write_spec = self._get_array_write_spec( - info, - value, + array_write_spec = ts_utils.build_array_write_spec( + info=info, + arg=arg, + global_shape=value.shape, + local_shape=value.shape, + dtype=value.dtype, use_ocdbt=info.is_ocdbt_checkpoint, process_index=ocdbt_utils.get_process_index_for_subdir( use_ocdbt=info.is_ocdbt_checkpoint, override_ocdbt_process_id=self._override_ocdbt_process_id, ), - arg=arg, + metadata_key=self._metadata_key, ) tspec = array_write_spec.json if logging.vlog_is_on(1): @@ -205,7 +181,13 @@ async def deserialize( ) # Use OCDBT flag from the existing checkpoint. use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg) if logging.vlog_is_on(1): diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py index 8324fc670..067896fe2 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py @@ -20,10 +20,10 @@ from absl import logging import jax import numpy as np +import orbax.checkpoint as ocp from orbax.checkpoint import args as args_lib from orbax.checkpoint import checkpoint_manager from orbax.checkpoint import multihost -from orbax.checkpoint import type_handlers from orbax.checkpoint import utils from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib @@ -57,6 +57,13 @@ def test_fn( options = context.options assert isinstance(options, CheckpointManagerBenchmarkOptions) + if multihost.is_pathways_backend(): + checkpointing_impl = ocp.pathways.CheckpointingImpl.from_options( + use_remote_python=True + ) + ocp.pathways.register_type_handlers( + checkpointing_impl=checkpointing_impl, + ) cm_options = checkpoint_manager.CheckpointManagerOptions( save_interval_steps=options.save_interval_steps, diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py index 3e7d7d087..0087d726f 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark.py @@ -20,7 +20,6 @@ from typing import Any import jax -import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib from 
orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib @@ -118,24 +117,6 @@ def test_fn( with metrics.measure(f'train_step_{i}'): pytree = self._train_step(pytree) - save_times = np.array(save_times) - total_save_times = np.array(total_save_times) - - # Exclude step 0 from assertions; setup may take extra time. - asserting_save_times = save_times[1:] - asserting_total_save_times = total_save_times[1:] - - mean_save_time = np.mean(asserting_save_times) - mean_total_save_time = np.mean(asserting_total_save_times) - - assert np.all(asserting_save_times <= 2 * mean_save_time), ( - f'Save times={asserting_save_times}, mean save time={mean_save_time}' - ) - assert np.all(asserting_total_save_times <= 2 * mean_total_save_time), ( - f'Total save times={asserting_total_save_times}, mean total save' - f' time={mean_total_save_time}' - ) - abstract_pytree = jax.tree.map( lambda x: ocp.utils.to_shape_dtype_struct(x) if isinstance(x, jax.Array) @@ -154,5 +135,7 @@ def test_fn( ), ) + # TODO(nikhilbansall) : Add assertions for this test. + mngr.close() return benchmarks_core.TestResult(metrics=metrics) diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index ce9f9d9da..c519f3767 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -318,8 +318,6 @@ class CheckpointManagerOptions: gs://my-bucket/trash/. Useful when direct deletion is time consuming. It gathers all deleted items in a centralized path for future cleanup. - enable_hns: If True, enables HNS-specific path manipulation logic. - Experimental feature. enable_background_delete: If True, old checkpoint deletions will be done in a background thread, otherwise, it will be done at the end of each save. When it's enabled, make sure to call CheckpointManager.close() or use context to @@ -386,7 +384,6 @@ class CheckpointManagerOptions: single_host_load_and_broadcast: bool = False todelete_subdir: Optional[str] = None todelete_full_path: Optional[str] = None - enable_hns: bool = False enable_background_delete: bool = False read_only: bool = False enable_async_checkpointing: bool = True @@ -878,7 +875,6 @@ def __init__( single_host_load_and_broadcast=( self._options.single_host_load_and_broadcast ), - enable_hns=self._options.enable_hns, ) ) @@ -909,7 +905,6 @@ def __init__( primary_host=self._multiprocessing_options.primary_host, todelete_subdir=self._options.todelete_subdir, todelete_full_path=self._options.todelete_full_path, - enable_hns=self._options.enable_hns, enable_background_delete=self._options.enable_background_delete, ) ) diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils.py b/checkpoint/orbax/checkpoint/checkpoint_utils.py index 142389f20..f8d00904e 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils.py @@ -506,10 +506,16 @@ def _array_restore_args( sharding: Optional[jax.sharding.Sharding | Format], # pytype: disable=unsupported-operands dtype: Optional[np.dtype] = None, ) -> type_handlers.ArrayRestoreArgs: + global_shape = None + # For random keys, we only allow overriding the sharding. 
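+  # (JAX PRNG keys use an extended dtype whose user-visible shape differs
+  # from the underlying array's on-disk shape, so a `global_shape` override
+  # would conflict with the key's storage layout.)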
+ if set_global_shape and not jax.dtypes.issubdtype( + value.dtype, jax.dtypes.prng_key + ): + global_shape = value.shape return type_handlers.ArrayRestoreArgs( restore_type=jax.Array, sharding=sharding, - global_shape=value.shape if set_global_shape else None, + global_shape=global_shape, dtype=dtype, ) diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py b/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py index 6e78304d0..ff57fb24f 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py @@ -101,7 +101,8 @@ async def open_tensorstore( use_zarr3=use_zarr3, ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), ) - tspec = ts_utils.get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec(info, use_ocdbt=use_ocdbt) + tspec = array_read_spec.json return await ts.open( ts.Spec(tspec), read=True, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 31cecfa0a..df6dff164 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -20,6 +20,7 @@ import enum from typing import Any, Callable, Protocol, Type +from etils import epath import numpy as np from orbax.checkpoint import options as v0_options_lib from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib @@ -27,6 +28,7 @@ from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types +from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -102,10 +104,15 @@ class FileOptions: temporary_path_class: A class that is used to create and finallize temporary paths, and to ensure atomicity. + path_class: + The implementation of `path_types.Path` to use. Defaults to + `etils.epath.Path`, but may be overridden to some other subclass of + `path_types.Path`. 
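+
+    Example (illustrative; `MyPath` is a hypothetical `epath.Path` subclass):
+
+      FileOptions(path_permission_mode=0o750, path_class=MyPath)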
""" path_permission_mode: int | None = None - temporary_path_class: atomicity_types.TemporaryPath | None = None + temporary_path_class: type[atomicity_types.TemporaryPath] | None = None + path_class: type[path_types.Path] = epath.Path def v0(self) -> v0_options_lib.FileOptions: """Converts this FileOptions to a v0 FileOptions.""" diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py index 6c98985c5..538e1f27d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py @@ -20,7 +20,6 @@ import dataclasses from typing import Any -from etils import epath from orbax.checkpoint import checkpoint_args from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.futures import future @@ -64,7 +63,7 @@ def __init__(self, handler: handler_types.CheckpointableHandler): async def async_save( self, - directory: epath.Path, + directory: path_types.Path, args: Args, ) -> list[future.Future] | None: async_path = _PathAwaitingCreation( @@ -77,7 +76,7 @@ async def _background_save(): return [future.CommitFuture(_background_save())] - def save(self, directory: epath.Path, *args, **kwargs): + def save(self, directory: path_types.Path, *args, **kwargs): async def async_save(*args, **kwargs): commit_futures = await self.async_save(*args, **kwargs) # pytype: disable=bad-return-type # Futures are already running, so sequential waiting is equivalent to @@ -88,7 +87,9 @@ async def async_save(*args, **kwargs): asyncio_utils.run_sync(async_save(directory, *args, **kwargs)) - def restore(self, directory: epath.Path, args: Args | None = None) -> Any: + def restore( + self, directory: path_types.Path, args: Args | None = None + ) -> Any: abstract_checkpointable = args.checkpointable if args else None async def _synchronous_load(): @@ -99,10 +100,10 @@ async def _synchronous_load(): return asyncio_utils.run_sync(_synchronous_load()) - def metadata(self, directory: epath.Path) -> Any | None: + def metadata(self, directory: path_types.Path) -> Any | None: return asyncio_utils.run_sync(self._handler.metadata(directory)) - def finalize(self, directory: epath.Path) -> None: + def finalize(self, directory: path_types.Path) -> None: pass def close(self): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py index 79fc58eff..1043ff8ef 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py @@ -21,7 +21,6 @@ from typing import Any, Awaitable from absl import logging -from etils import epath from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost @@ -57,7 +56,7 @@ def _subdirs(directory: path_types.Path, *, limit: int = 3) -> list[str]: ) -def _existing_checkpointable_names(directory: epath.Path) -> set[str]: +def _existing_checkpointable_names(directory: path_types.Path) -> set[str]: return {p.name for p in directory.iterdir() if p.is_dir()} diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 
c7d6b7b2b..1fe8aa60a 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -18,7 +18,6 @@ from typing import Any, Awaitable from absl import logging -from etils import epath from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import temporary_paths from orbax.checkpoint.experimental.v1._src.context import context as context_lib @@ -70,7 +69,8 @@ def is_orbax_checkpoint(path: path_types.PathLike) -> bool: Returns: True if the path is an Orbax checkpoint, False otherwise. """ - path = epath.Path(path) + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) try: OrbaxLayout(path).validate() return True diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/pytorch_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/pytorch_layout.py index a5ceae195..7296c7cb2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/pytorch_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/pytorch_layout.py @@ -21,19 +21,19 @@ from typing import Any, Awaitable import zipfile -from etils import epath import jax import numpy as np from orbax.checkpoint._src.path import async_path +from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types -from orbax.checkpoint.experimental.v1._src.path import types +from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types CheckpointLayout = checkpoint_layout.CheckpointLayout InvalidLayoutError = checkpoint_layout.InvalidLayoutError -Path = types.Path +Path = path_types.Path _PICKLE_FILENAME = "data.pkl" @@ -228,7 +228,7 @@ def _read_zip_contents_sync(path: Path) -> tuple[bytes, dict[str, bytes]]: if name.endswith(_PICKLE_FILENAME): pickle_bytes = zf.read(name) else: - p = epath.Path(name) + p = context_lib.get_context().file_options.path_class(name) if p.parent.name == _STORAGE_PREFIX: storage_id = p.name # Accommodate different key formats. Some PyTorch versions may use diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index cefc2598e..22b5dd3a3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -14,7 +14,7 @@ """Registry for checkpoint layouts.""" -from etils import epath +from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import orbax_layout @@ -42,7 +42,8 @@ async def get_checkpoint_layout( InvalidLayoutError: If the path is not a valid checkpoint for any registered layout, with details from each layout's validation attempt. 
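+
+  Example (illustrative path):
+
+    layout = await get_checkpoint_layout(
+        '/tmp/ckpt', CheckpointLayoutEnum.ORBAX
+    )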
""" - path = epath.Path(path) + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) match layout_enum: case CheckpointLayoutEnum.ORBAX: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index f155af5ba..5b9f5ac08 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -19,7 +19,6 @@ from typing import Any from absl import logging -from etils import epath from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint.experimental.v1._src.context import context as context_lib @@ -104,10 +103,11 @@ def load_pytree( start_time = time.time() asyncio_utils.maybe_apply_nest_asyncio() logging.info('Loading checkpoint from %s.', path) - path = epath.Path(path) + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) layout, checkpointable_name = asyncio.run( layout_registry.get_checkpoint_layout_pytree( - path, context_lib.get_context().checkpoint_layout, checkpointable_name + path, ctx.checkpoint_layout, checkpointable_name ) ) return _load_checkpointables_impl( @@ -173,11 +173,10 @@ def load_checkpointables( start_time = time.time() asyncio_utils.maybe_apply_nest_asyncio() logging.info('Loading checkpoint from %s.', path) - path = epath.Path(path) + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) layout = asyncio.run( - layout_registry.get_checkpoint_layout( - path, context_lib.get_context().checkpoint_layout - ) + layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout) ) return _load_checkpointables_impl( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py index de9a130fc..ce258f604 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py @@ -17,7 +17,6 @@ import asyncio from typing import Any -from etils import epath from orbax.checkpoint.experimental.v1 import errors from orbax.checkpoint.experimental.v1._src.context import context as context_lib import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import @@ -88,10 +87,11 @@ def _get_abstract_array(arr): A `CheckpointMetadata[PyTreeMetadata]` object. """ asyncio_utils.maybe_apply_nest_asyncio() - path = epath.Path(path) + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) layout, checkpointable_name = asyncio.run( layout_registry.get_checkpoint_layout_pytree( - path, context_lib.get_context().checkpoint_layout, checkpointable_name + path, ctx.checkpoint_layout, checkpointable_name ) ) metadata = _checkpointables_metadata_impl(layout) @@ -131,12 +131,10 @@ def checkpointables_metadata( A `CheckpointMetadata[dict[str, Any]]` object. 
""" asyncio_utils.maybe_apply_nest_asyncio() - path = epath.Path(path) - context = context_lib.get_context() + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) layout = asyncio.run( - layout_registry.get_checkpoint_layout( - path, context.checkpoint_layout - ) + layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout) ) return _checkpointables_metadata_impl(layout) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/merging.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/merging.py index a28e8c58b..bf4fab6bc 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/merging.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/merging.py @@ -16,11 +16,11 @@ from typing import Any, NamedTuple, TypeVar -from etils import epath import jax from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.tree import structure_utils from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading +from orbax.checkpoint.experimental.v1._src.path import types as path_types PyTree = Any T = TypeVar('T') @@ -34,7 +34,7 @@ class SourceIndexedMetadata(NamedTuple): def merge_ckpt_metadata( - ckpts_to_merge: list[epath.Path], + ckpts_to_merge: list[path_types.Path], ) -> PyTreeOf[SourceIndexedMetadata]: """Merges metadata from multiple checkpoints, labeling each leaf with its source index.""" labeled_ckpt_metadata: list[PyTreeOf[SourceIndexedMetadata]] = [ diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/path.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/path.py index f5258cf7f..9183463e7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/path.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/path.py @@ -14,27 +14,23 @@ """Utility functions for partial saving paths.""" -from etils import epath from orbax.checkpoint.experimental.v1._src.path import types as path_types PARTIAL_SAVE_SUFFIX = '.partial_save' def is_partial_save_path( - path: path_types.PathLike, allow_tmp_dir: bool = False + path: path_types.Path, allow_tmp_dir: bool = False ) -> bool: - path = epath.Path(path) if allow_tmp_dir: return PARTIAL_SAVE_SUFFIX in path.name else: return path.name.endswith(PARTIAL_SAVE_SUFFIX) -def add_partial_save_suffix(path: path_types.PathLike) -> path_types.Path: - path = epath.Path(path) +def add_partial_save_suffix(path: path_types.Path) -> path_types.Path: return path.parent / (path.name + PARTIAL_SAVE_SUFFIX) -def remove_partial_save_suffix(path: path_types.PathLike) -> path_types.Path: - path = epath.Path(path) +def remove_partial_save_suffix(path: path_types.Path) -> path_types.Path: return path.parent / path.name.removesuffix(PARTIAL_SAVE_SUFFIX) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index ea5bfee0f..7543616bc 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -14,7 +14,6 @@ """Defines free-function interface for partial saving and finalizing.""" -from etils import epath from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.path import async_path from orbax.checkpoint.experimental.v1._src.context import context as context_lib @@ -185,7 +184,9 @@ def save_pytree_async( FileExistsError: If a finalized checkpoint already exists at `path`. 
To overwrite, it must be deleted first. """ - if epath.Path(path).exists(): + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) + if path.exists(): raise FileExistsError(f'Finalized checkpoint already exists at {path}.') return execution.save_checkpointables_impl( @@ -238,8 +239,7 @@ def finalize(path: path_types.PathLike) -> None: This can happen if `ocp.partial.save_*` was not called first. """ context = context_lib.get_context() - - path = epath.Path(path) + path = context.file_options.path_class(path) if partial_path_lib.is_partial_save_path(path): final_path = partial_path_lib.remove_partial_save_suffix(path) partial_path = path diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/path/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/path/types.py index d6447a151..85031edeb 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/path/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/path/types.py @@ -21,8 +21,9 @@ from etils import epath + Path = epath.Path -PathLike = epath.PathLike +PathLike = Path | str @typing.runtime_checkable diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index 34adae6cf..22c3a3cf7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -21,7 +21,6 @@ import uuid from absl import logging -from etils import epath import jax import numpy as np from orbax.checkpoint._src.futures import future @@ -266,7 +265,7 @@ def create_save_response( ) -def check_directory_consistency(directory: epath.PathLike): +def check_directory_consistency(directory: path_types.PathLike): """Raises error if directory paths are not consistent across processes.""" if multihost.process_count() <= 1: return @@ -310,8 +309,8 @@ def save_checkpointables_impl( asyncio_utils.maybe_apply_nest_asyncio() context = context_lib.get_context() + path = context.file_options.path_class(path) check_directory_consistency(path) - path = epath.Path(path) path_exists = path.exists() if partial_save else False # Prevent internal mutation from affecting the caller. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py index eac5ddae5..7ae370393 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py @@ -18,7 +18,6 @@ from typing import Any, Generic, Sequence, Tuple, Type, cast, get_args from absl import logging -from etils import epath import jax from jax import tree_util as jtu import jax.numpy as jnp @@ -254,7 +253,7 @@ def _validate_deserialization_infos( def _convert_v1_metadata_to_v0( name: str, - directory: epath.Path | None, + directory: path_types.Path | None, metadata: types.AbstractShardedArray, ) -> value_metadata.Metadata: """Wrap V1 metadata into V0Metadata.""" diff --git a/checkpoint/orbax/checkpoint/version.py b/checkpoint/orbax/checkpoint/version.py index 01b90483f..d25afc72f 100644 --- a/checkpoint/orbax/checkpoint/version.py +++ b/checkpoint/orbax/checkpoint/version.py @@ -17,7 +17,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased. # Also modify version and date in CHANGELOG. 
# LINT.IfChange -__version__ = '0.11.28' +__version__ = '0.11.30' # LINT.ThenChange(//depot//orbax/checkpoint/CHANGELOG.md) diff --git a/export/orbax/export/constants.py b/export/orbax/export/constants.py index c1fbeb91d..68128d942 100644 --- a/export/orbax/export/constants.py +++ b/export/orbax/export/constants.py @@ -97,8 +97,13 @@ class ExportModelType(enum.Enum): # Mesh for the model. JAX_MESH = 'jax_mesh' -# Whether to strip XLA flags from the model. -STRIP_XLA_FLAGS = 'strip_xla_flags' +# Whether to persist XLA flags in the model. +PERSIST_XLA_FLAGS = 'persist_xla_flags' + +# Whether to enable bf16 optimization for the model. +# TODO_REGEX: b/422170690: (1): Apply this flag to the pre/post processors. (2): +# Adding filter flags once the flag is applied to the pre/post processors. +ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization' ################################################################################ # Proto field names diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index 5ba509c51..3d076ac59 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]: tensorflow_module.TensorFlowModule, self._export_module ).jax2tf_kwargs_map + @property + def jax2obm_kwargs(self) -> Mapping[str, Any]: + """Returns the jax2obm_kwargs.""" + if self._export_version == constants.ExportModelType.TF_SAVEDMODEL: + raise TypeError( + 'jax2obm_kwargs is not implemented for export version' + ' ExportModelType.TF_SAVEDMODEL.' + ) + return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs + @property def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]: """Returns the polymorphic shapes.""" diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index 3debc78d6..378d66983 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -73,34 +73,40 @@ def __init__( ) # It is possible for jax2obm_kwargs to be None if the key is present. - if not jax2obm_kwargs: - jax2obm_kwargs = {} + self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {} + + enable_bf16_optimization = self.jax2obm_kwargs.get( + constants.ENABLE_BF16_OPTIMIZATION, False + ) + + if enable_bf16_optimization: + mapped_apply_fn = utils.to_bfloat16(apply_fn) + self._params_args_spec = utils.to_bfloat16(params) + else: + mapped_apply_fn = apply_fn + self._params_args_spec = params ( self._apply_fn_map, self.input_polymorphic_shape_map, self.input_polymorphic_shape_symbol_values_map, ) = self._normalize_apply_fn_map( - apply_fn, + mapped_apply_fn, input_polymorphic_shape, input_polymorphic_shape_symbol_values, ) - self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None) - self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False) + self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None) - self.polymorphic_constraints = self._maybe_set_polymorphic_constraints( - jax2obm_kwargs - ) + self.polymorphic_constraints = self._maybe_set_polymorphic_constraints() self._native_serialization_platforms = utils.get_lowering_platforms( - jax2obm_kwargs + self.jax2obm_kwargs ) - self._params_args_spec = params self._checkpoint_path: str = None # Set the Orbax checkpoint path if provided in the jax2obm_kwargs. 
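+    # For example, jax2obm_kwargs may carry (among other keys):
+    #   {constants.ENABLE_BF16_OPTIMIZATION: True, constants.JAX_MESH: mesh}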
- self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs) - self._load_all_checkpoint_weights = jax2obm_kwargs.get( + self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs) + self._load_all_checkpoint_weights = self.jax2obm_kwargs.get( constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False ) @@ -203,15 +209,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs): else constants.DEFAULT_WEIGHTS_NAME ) - def _maybe_set_polymorphic_constraints( - self, jax2obm_kwargs - ) -> Mapping[str, Sequence[Any]]: + def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]: """Sets the polymorphic constraints for the model. - Args: - jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion - library. - Returns: A mapping of function name to polymorphic constraints. @@ -221,7 +221,7 @@ def _maybe_set_polymorphic_constraints( size of the apply_fn_map or if a key in apply_fn_map is not found in polymorphic_constraints. """ - polymorphic_constraints = jax2obm_kwargs.get( + polymorphic_constraints = self.jax2obm_kwargs.get( constants.POLYMORPHIC_CONSTRAINTS, None ) if not isinstance(polymorphic_constraints, Mapping): @@ -300,3 +300,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]: def jax_methods(self) -> Mapping[str, Callable[..., Any]]: """Named methods in JAX context for validation.""" raise NotImplementedError('apply_fn_map is not implemented for ObmModule.') + + @property + def jax2obm_kwargs(self) -> Mapping[str, Any]: + """Returns the jax2obm_kwargs.""" + return self._jax2obm_kwargs diff --git a/export/orbax/export/modules/obm_module_test.py b/export/orbax/export/modules/obm_module_test.py index 62c6b73a5..c5555d4ad 100644 --- a/export/orbax/export/modules/obm_module_test.py +++ b/export/orbax/export/modules/obm_module_test.py @@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns( jax2obm_kwargs=jax2obm_kwargs, ) + @parameterized.named_parameters( + {'testcase_name': 'enable_bf16', 'enable_bf16_optimization': True}, + {'testcase_name': 'disable_bf16', 'enable_bf16_optimization': False}, + ) + def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization): + params_spec = { + 'w': jax.ShapeDtypeStruct((2, 2), jnp.float32), + 'b': jax.ShapeDtypeStruct((2,), jnp.float32), + } + input_spec = {constants.DEFAULT_METHOD_KEY: 'b, ...'} + + module = obm_module.ObmModule( + params=params_spec, + apply_fn=_linear, + input_polymorphic_shape=input_spec, + jax2obm_kwargs={ + constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization + }, + ) + + expected_dtype = jnp.bfloat16 if enable_bf16_optimization else jnp.float32 + with self.subTest('test_weights_w_dtype'): + self.assertEqual(module.model_params['w'].dtype, expected_dtype) + with self.subTest('test_weights_b_dtype'): + self.assertEqual(module.model_params['b'].dtype, expected_dtype) + if __name__ == '__main__': absltest.main() diff --git a/export/orbax/export/obm_configs.py b/export/orbax/export/obm_configs.py index afd93ba8e..bcf6db01c 100644 --- a/export/orbax/export/obm_configs.py +++ b/export/orbax/export/obm_configs.py @@ -51,6 +51,20 @@ class BatchPaddingPolicy(enum.Enum): MINIMIZE_TPU_COST_PER_REQUEST = "minimize_tpu_cost_per_request" +@enum.unique +class MixedPriorityBatchingPolicy(enum.Enum): + """The mixed priority batch policy for the batch scheduler. + + Options: + LOW_PRIORITY_PADDING_WITH_MAX_BATCH_SIZE: Pad low priority inputs up to the + max_batch_size. + """ + # TODO: b/417977029 - Add LOW_PRIORITY_PADDING_WITH_NEXT_ALLOWED_BATCH_SIZE, + # PRIORITY_MERGE, PRIORITY_ISOLATION. 
+ LOW_PRIORITY_PADDING_WITH_MAX_BATCH_SIZE = "low_priority_padding_with_max_batch_size" + + + # LINT.ThenChange(//depot//orbax/export/obm_export.py) @@ -105,6 +119,8 @@ class BatchOptions: batch_padding_policy: The batch padding policy for the batch scheduler. Default is PAD_UP. low_priority_batch_options: The batch options for low priority inputs. + mixed_priority_batching_policy: The mixed priority batching policy for the + batch scheduler. Default is LOW_PRIORITY_PADDING_WITH_MAX_BATCH_SIZE. """ batch_component: BatchComponent @@ -116,6 +132,9 @@ class BatchOptions: disable_large_batch_splitting: bool = False batch_padding_policy: BatchPaddingPolicy = BatchPaddingPolicy.PAD_UP low_priority_batch_options: LowPriorityBatchOptions | None = None + mixed_priority_batching_policy: MixedPriorityBatchingPolicy = ( + MixedPriorityBatchingPolicy.LOW_PRIORITY_PADDING_WITH_MAX_BATCH_SIZE + ) def _validate_batch_options( self, diff --git a/export/orbax/export/serving_config.py b/export/orbax/export/serving_config.py index 3879b5bae..6a722c021 100644 --- a/export/orbax/export/serving_config.py +++ b/export/orbax/export/serving_config.py @@ -102,6 +102,18 @@ class ServingConfig: # } preprocess_output_passthrough_enabled: bool = False + def __post_init__(self): + if self.tf_preprocessor and self.preprocessors: + raise ValueError( + '`tf_preprocessor` and `preprocessors` cannot be set at the same' + ' time.' + ) + if self.tf_postprocessor and self.postprocessors: + raise ValueError( + '`tf_postprocessor` and `postprocessors` cannot be set at the same' + ' time.' + ) + def get_signature_keys(self) -> Sequence[str]: if isinstance(self.signature_key, str): return [self.signature_key] diff --git a/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto b/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto index e4c74e80a..bfeceaf9b 100644 --- a/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto +++ b/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto @@ -92,6 +92,7 @@ name_to_pipeline { allowed_batch_sizes: 16 max_enqueued_batches: 300 } + mixed_priority_batching_policy: LOW_PRIORITY_PADDING_WITH_MAX_BATCH_SIZE } } } diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index 576d3dd13..8e27b074d 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -18,6 +18,7 @@ import dataclasses import functools import inspect +import jax.numpy as jnp import os from typing import Any, Callable, List, Optional, Tuple, Union @@ -532,3 +533,40 @@ def get_lowering_platforms( ) return native_serialization_platforms + + +def to_bfloat16(x: Any) -> Any: + """Helper to convert leaves of a pytree to bfloat16. + + It handles `float`, `jax.ShapeDtypeStruct`, and other array-like objects with + a floating point `dtype`. + + Args: + x: The input pytree to convert. + + Returns: + The input `x` with floating point values converted to `jnp.bfloat16`. 
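+
+  Example (illustrative)::
+
+    specs = {'w': jax.ShapeDtypeStruct((2, 2), jnp.float32), 'step': 3}
+    out = to_bfloat16(specs)
+    # out['w'].dtype is now jnp.bfloat16; the integer leaf passes through
+    # unchanged.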
+ """ + + def _to_bfloat16_leaf(x: Any) -> Any: + if isinstance(x, jax.ShapeDtypeStruct) and jnp.issubdtype( + x.dtype, jnp.floating + ): + return jax.ShapeDtypeStruct( + x.shape, + jnp.bfloat16, + sharding=x.sharding, + ) + if isinstance(x, jax.ShapeDtypeStruct): + return x + if hasattr(x, 'dtype') and jnp.issubdtype(x.dtype, jnp.floating): + return x.astype(jnp.bfloat16) + if isinstance(x, float): + return jnp.bfloat16(x) + return x + + flattened_x, treedef = jax.tree_util.tree_flatten(x) + flattened_y = [ + jax.tree_util.tree_map(_to_bfloat16_leaf, y) for y in flattened_x + ] + return jax.tree_util.tree_unflatten(treedef, flattened_y) diff --git a/model/orbax/experimental/model/cli/README.md b/model/orbax/experimental/model/cli/README.md index ddff30620..26c7aff55 100644 --- a/model/orbax/experimental/model/cli/README.md +++ b/model/orbax/experimental/model/cli/README.md @@ -2,6 +2,7 @@ A command-line tool for inspecting Orbax models. + ## Examples To inspect the model: diff --git a/model/orbax/experimental/model/core/python/compile_options_util.py b/model/orbax/experimental/model/core/python/compile_options_util.py index 9aefc7eab..356bc1538 100644 --- a/model/orbax/experimental/model/core/python/compile_options_util.py +++ b/model/orbax/experimental/model/core/python/compile_options_util.py @@ -39,7 +39,7 @@ def generate_tpu_compilation_env( - xla_flags: Sequence[str] | None = None, + xla_flags_overrides: Sequence[str] | None = None, ) -> xla_pb2.CompilationEnvironmentsProto: """Generates the TPU compilation environment.""" # Get default TPU compilation environment. @@ -48,9 +48,9 @@ def generate_tpu_compilation_env( tpu_compilation_env_str ) # Override with supplied XLA flags if any is provided. - if xla_flags: + if xla_flags_overrides: parsed_flags = {} - for flag in xla_flags: + for flag in xla_flags_overrides: if not flag.startswith('--'): raise ValueError( f"Flag {flag} does not start with '--'. All flags must be in the" @@ -115,7 +115,7 @@ def generate_xla_compile_options( native_serialization_platforms: Sequence[str] | None, xla_flags_per_platform: Mapping[str, Sequence[str] | None], jax_mesh: jax.sharding.Mesh | None = None, - strip_xla_flags: bool = False, + persist_xla_flags: bool = True, ) -> manifest_pb2.CompileOptionsProtoMap: """Sets the XLA compilation options. @@ -127,7 +127,7 @@ def generate_xla_compile_options( which will be used to override the default XLA compilation flags. jax_mesh: The JAX mesh used for sharding. If None, the compile options will be set for a default single-replica. - strip_xla_flags: Whether to strip XLA flags from the compile options. + persist_xla_flags: Whether to persist XLA flags in the compile options. Returns: A `CompileOptionsProtoMap` containing the XLA compilation options per @@ -140,6 +140,9 @@ def generate_xla_compile_options( ValueError: If a platform is provided for XLA flags which is not provided in the native serialization platforms. ValueError: If the supplied XLA flag overrides cannot be parsed. + ValueError: If `xla_flags` are provided but `persist_xla_flags` is False. + This ensures that the XLA flags are persisted in the compile options, + otherwise they would be lost, leading to unexpected behavior. 
""" tpu_platform_name = manifest_pb2.Platform.Name( manifest_pb2.Platform.TPU @@ -180,19 +183,44 @@ def generate_xla_compile_options( for platform in platforms: compile_environment = None if platform.lower() == tpu_platform_name: - xla_flags = None + xla_flags_overrides = None if xla_flags_per_platform: - xla_flags = xla_flags_per_platform.get(platform, None) - compile_environment = generate_tpu_compilation_env(xla_flags) + xla_flags_overrides = xla_flags_per_platform.get(platform, None) + _validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags) + compile_environment = generate_tpu_compilation_env(xla_flags_overrides) compile_options_map.map[platform.lower()].CopyFrom( generate_compilation_options(compile_environment, jax_mesh) ) - if strip_xla_flags: + if not persist_xla_flags: for compile_options in compile_options_map.map.values(): compile_options.executable_build_options.comp_envs.Clear() return compile_options_map +def _validate_xla_flags_setting( + xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool +) -> None: + """Validates the XLA flags setting. + + XLA flag overrides are allowed only when XLA flags are required to be + persisted. + + Args: + xla_flags_overrides: A sequence of XLA flags provided for overriding. Can be + None. + persist_xla_flags: A boolean indicating whether the XLA flags should be + persisted in the compile options. + + Raises: + ValueError: If `xla_flags_overrides` is not None but `persist_xla_flags` is + False. + """ + if xla_flags_overrides and not persist_xla_flags: + raise ValueError( + 'persist_xla_flags must be True if xla_flags_overrides are provided.' + ) + + def get_field_for_flag(flag_name: str) -> descriptor.FieldDescriptor: """Gets the protobuf field descriptor for a given flag name.""" if flag_name not in _XLA_FLAG_TO_FIELD_MAP: diff --git a/model/orbax/experimental/model/core/python/compile_options_util_test.py b/model/orbax/experimental/model/core/python/compile_options_util_test.py index aa6bd1270..bdf594a69 100644 --- a/model/orbax/experimental/model/core/python/compile_options_util_test.py +++ b/model/orbax/experimental/model/core/python/compile_options_util_test.py @@ -150,17 +150,21 @@ def test_merge_flags_into_compile_options(self): @parameterized.named_parameters( dict( testcase_name='dict_xla_flags', - xla_flags=[f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()], + xla_flags_overrides=[f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()], expected_env=EXPECTED_ENV, ), dict( testcase_name='no_xla_flags', - xla_flags=None, + xla_flags_overrides=None, expected_env=None, ), ) - def test_generate_tpu_compilation_env(self, xla_flags, expected_env): - env = compile_options_util.generate_tpu_compilation_env(xla_flags=xla_flags) + def test_generate_tpu_compilation_env( + self, xla_flags_overrides, expected_env + ): + env = compile_options_util.generate_tpu_compilation_env( + xla_flags_overrides=xla_flags_overrides + ) self.assertLen(env.environments, 1) actual_env_proto = tpu_comp_env_pb2.TpuCompilationEnvironment() env.environments[0].Unpack(actual_env_proto) @@ -185,7 +189,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self): ' --flag_name=flag_value.', ): compile_options_util.generate_tpu_compilation_env( - xla_flags=[ + xla_flags_overrides=[ '--xla_tpu_memory_bound_loop_optimizer_options=enabled:false', 'xla_tpu_allocate_scoped_vmem_at_same_offset: false', ] @@ -197,6 +201,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self): native_serialization_platforms=None, xla_flags_per_platform=None, 
expected_platforms=['tpu'], + persist_xla_flags=False, ), dict( testcase_name='no_native_serialization_platforms_with_xla_flags', @@ -205,12 +210,14 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self): 'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()] }, expected_platforms=['tpu'], + persist_xla_flags=True, ), dict( testcase_name='with_native_serialization_platforms_no_xla_flags', native_serialization_platforms=['cpu', 'tpu', 'cuda'], xla_flags_per_platform=None, expected_platforms=['cpu', 'tpu', 'cuda'], + persist_xla_flags=False, ), dict( testcase_name='with_native_serialization_platforms_with_xla_flags', @@ -219,6 +226,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self): 'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()] }, expected_platforms=['cpu', 'tpu', 'cuda'], + persist_xla_flags=True, ), ) def test_generate_xla_compile_options_flags_and_platforms( @@ -226,10 +234,12 @@ def test_generate_xla_compile_options_flags_and_platforms( native_serialization_platforms, xla_flags_per_platform, expected_platforms, + persist_xla_flags, ): compile_options_map = compile_options_util.generate_xla_compile_options( native_serialization_platforms=native_serialization_platforms, xla_flags_per_platform=xla_flags_per_platform, + persist_xla_flags=persist_xla_flags, ) self.assertLen(compile_options_map.map, len(expected_platforms)) @@ -237,7 +247,7 @@ def test_generate_xla_compile_options_flags_and_platforms( self.assertIn(platform, compile_options_map.map) compile_options = compile_options_map.map[platform] - if platform != 'tpu': + if platform != 'tpu' or not persist_xla_flags: self.assertEmpty( compile_options.executable_build_options.comp_envs.environments ) @@ -307,49 +317,18 @@ def test_generate_xla_compile_options_xla_flags_platform_not_in_native_serializa }, ) - @parameterized.named_parameters( - dict(testcase_name='strip_xla_flags_true', strip_xla_flags=True), - dict(testcase_name='strip_xla_flags_false', strip_xla_flags=False), - ) - def test_generate_xla_compile_options_strip_xla_flags(self, strip_xla_flags): - xla_flags_per_platform = { - 'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()] - } - compile_options_map = compile_options_util.generate_xla_compile_options( - native_serialization_platforms=['cpu', 'tpu', 'cuda'], - xla_flags_per_platform=xla_flags_per_platform, - strip_xla_flags=strip_xla_flags, - ) - self.assertLen(compile_options_map.map, 3) - for platform in ['cpu', 'tpu', 'cuda']: - self.assertIn(platform, compile_options_map.map) - compile_options = compile_options_map.map[platform] - - if strip_xla_flags or platform != 'tpu': - self.assertEmpty( - compile_options.executable_build_options.comp_envs.environments - ) - else: - # For TPU platform when not stripping, it should have xla flags. 
- self.assertLen( - compile_options.executable_build_options.comp_envs.environments, 1 - ) - actual_env_proto = tpu_comp_env_pb2.TpuCompilationEnvironment() - compile_options.executable_build_options.comp_envs.environments[ - 0 - ].Unpack(actual_env_proto) - - expected_env_overrides = EXPECTED_ENV - expected_env_proto = tpu_comp_env_pb2.TpuCompilationEnvironment() - expected_env_proto.ParseFromString( - tpu_comp_env.create_default_tpu_comp_env() - ) - expected_env_proto.MergeFrom(expected_env_overrides) - - self.assertEqual( - text_format.MessageToString(actual_env_proto), - text_format.MessageToString(expected_env_proto), - ) + def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'persist_xla_flags must be True if xla_flags_overrides are provided.', + ): + compile_options_util.generate_xla_compile_options( + native_serialization_platforms=['tpu'], + xla_flags_per_platform={ + 'tpu': [f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()] + }, + persist_xla_flags=False, + ) @parameterized.named_parameters( dict( diff --git a/model/orbax/experimental/model/jax2obm/utils.py b/model/orbax/experimental/model/jax2obm/utils.py index 0e172790e..4bb5956f3 100644 --- a/model/orbax/experimental/model/jax2obm/utils.py +++ b/model/orbax/experimental/model/jax2obm/utils.py @@ -95,7 +95,7 @@ def make_jax_exported_creator( def _aval_dtype(a) -> JaxDType: - assert isinstance(a, jax.core.UnshapedArray) + assert isinstance(a, jax.core.ShapedArray) return a.dtype
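
For reference, a standalone sketch of why the tightened assertion is the right
one: the abstract values this helper receives come from JAX tracing, and those
are `ShapedArray` instances (shape and dtype both known).

    import jax
    import jax.numpy as jnp

    # Abstract values from jax.eval_shape / tracing are ShapedArray instances,
    # so the narrower isinstance check holds and `.dtype` is always available.
    aval = jax.core.ShapedArray((4, 2), jnp.float32)
    assert isinstance(aval, jax.core.ShapedArray)
    print(aval.dtype)  # float32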