diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 78cf33006..5978eba36 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 state, avoiding additional I/O to re-read metadata. - add `support_format` to utils.to_shape_dtype_struct() - Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`. +- Replace usage of `get_json_tpec_read` and delegate functionality to new + function `build_array_read_spec` which constructs and returns an + `ArrayReadSpec`. ## [0.11.28] - 2025-11-06 diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index e43a39118..40876a472 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -763,12 +763,13 @@ async def _async_deserialize( await _validate_non_ocdbt_files(infos, metadata_key) deserialize_ops = [] for info, arg, sharding in zip(infos, args, shardings): - tspec = ts_utils.get_json_tspec_read( + array_read_spec = ts_utils.build_array_read_spec( info, use_ocdbt=use_ocdbt, metadata_key=metadata_key, raise_array_data_missing_error=info.raise_array_data_missing_error, ) + tspec = array_read_spec.json tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg) # set dtype=None to deserialize for random keys @@ -939,19 +940,6 @@ def __init__( def has_dispatcher(self) -> bool: return self._dispatcher is not None - def _get_json_tspec_read( - self, - info: types.ParamInfo, - use_ocdbt: bool, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for reading.""" - return ts_utils.get_json_tspec_read( - info, - use_ocdbt=use_ocdbt, - metadata_key=self._metadata_key, - raise_array_data_missing_error=info.raise_array_data_missing_error, - ) - def typestr(self) -> str: return JAX_ARRAY_TYPE_STR @@ -968,7 +956,13 @@ async def metadata( for info in infos: # Use OCDBT flag from the existing checkpoint. use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json open_ops.append( ts.open(ts.Spec(tspec), open=True, context=info.ts_context) ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index 826c6c0e7..02a954aa8 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -391,6 +391,45 @@ def _maybe_add_cast_to_write_spec( return array_tspec +class ArrayReadSpec: + """Full TensorStore spec for reading an array.""" + + def __init__( + self, + directory: str, + relative_array_filename: str, + use_zarr3: bool, + *, + use_ocdbt: bool, + metadata_key: str | None = None, + raise_array_data_missing_error: bool = True, + ): + """Builds a TensorStore spec for reading an array.""" + kvstore_tspec = build_kvstore_tspec( + directory, + name=relative_array_filename, + use_ocdbt=use_ocdbt, + process_id=None, + ) + + tspec = { + 'driver': ZARR_VER3 if use_zarr3 else ZARR_VER2, + 'kvstore': kvstore_tspec, + 'recheck_cached_data': False, + 'recheck_cached_metadata': False, + # Raise error if data is missing. + 'fill_missing_data_reads': not raise_array_data_missing_error, + } + if metadata_key is not None: + tspec['metadata_key'] = metadata_key + self._json_spec = tspec + + @property + def json(self) -> JsonSpec: + """Spec to be used to open a TensorStore for reading the array.""" + return self._json_spec + + class ArrayWriteSpec: """Full TensorStore spec for writing an array.""" @@ -677,6 +716,26 @@ def get_json_tspec_write( return tspec +def build_array_read_spec( + info: types.ParamInfo, + *, + use_ocdbt: bool, + metadata_key: str | None = None, + raise_array_data_missing_error: bool = True, +) -> ArrayReadSpec: + """Gets ArrayReadSpec for reading.""" + if info.name is None or info.parent_dir is None: + raise ValueError('Must provide info.name and info.parent_dir.') + return ArrayReadSpec( + directory=info.parent_dir.as_posix(), + relative_array_filename=info.name, + use_zarr3=info.use_zarr3, + use_ocdbt=use_ocdbt, + metadata_key=metadata_key, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + + def build_array_write_spec( info: types.ParamInfo, arg: types.SaveArgs | None = None, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py index c052be72e..aa3420bfa 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py @@ -613,6 +613,60 @@ def test_maybe_cloud_storage(self): self.assertTrue(ts_utils.is_remote_storage(nested_tspec)) +class BuildArrayTSpecForReadTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.directory = self.create_tempdir().full_path + self.param_name = 'params/a' + + self.array_read_spec_constructor = functools.partial( + ts_utils.ArrayReadSpec, + directory=self.directory, + relative_array_filename=self.param_name, + ) + + @parameterized.product( + use_zarr3=(True, False), + use_ocdbt=(True, False), + ) + def test_basic(self, use_zarr3: bool, use_ocdbt: bool): + tspec = self.array_read_spec_constructor( + use_zarr3=use_zarr3, + use_ocdbt=use_ocdbt, + ) + json_spec = tspec.json + self.assertEqual(json_spec['driver'], 'zarr3' if use_zarr3 else 'zarr') + self.assertEqual( + json_spec['kvstore']['driver'], + 'ocdbt' if use_ocdbt else ts_utils.DEFAULT_DRIVER, + ) + self.assertFalse(json_spec['recheck_cached_data']) + self.assertFalse(json_spec['recheck_cached_metadata']) + self.assertFalse(json_spec['fill_missing_data_reads']) + self.assertNotIn('metadata_key', json_spec) + + def test_metadata_key(self): + tspec = self.array_read_spec_constructor( + use_zarr3=False, + use_ocdbt=False, + metadata_key='custom_metadata', + ) + self.assertEqual(tspec.json['metadata_key'], 'custom_metadata') + + @parameterized.parameters(True, False) + def test_fill_missing_data_reads(self, raise_array_data_missing_error): + tspec = self.array_read_spec_constructor( + use_zarr3=False, + use_ocdbt=False, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + self.assertEqual( + tspec.json['fill_missing_data_reads'], + not raise_array_data_missing_error, + ) + + class GetTsContextTest(parameterized.TestCase): @parameterized.product( diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index 53a234d70..18668be86 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -77,39 +77,6 @@ def __init__( self._metadata_key = metadata_key self._override_ocdbt_process_id = ocdbt_process_id - def _get_array_write_spec( - self, - info: types.ParamInfo, - value: np.ndarray, - use_ocdbt: bool, - process_index: Optional[Union[int, str]] = None, - arg: Optional[types.SaveArgs] = None, - ) -> ts_utils.ArrayWriteSpec: - """Gets ArrayWriteSpec for writing.""" - return ts_utils.build_array_write_spec( - info=info, - arg=arg, - global_shape=value.shape, - local_shape=value.shape, - dtype=value.dtype, - use_ocdbt=use_ocdbt, - process_index=process_index, - metadata_key=self._metadata_key, - ) - - def _get_json_tspec_read( - self, - info: types.ParamInfo, - use_ocdbt: bool, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for reading.""" - return ts_utils.get_json_tspec_read( - info, - use_ocdbt=use_ocdbt, - metadata_key=self._metadata_key, - raise_array_data_missing_error=info.raise_array_data_missing_error, - ) - def typestr(self) -> str: return 'np.ndarray' @@ -120,7 +87,13 @@ async def metadata( for info in infos: # Use OCDBT flag from the existing checkpoint. use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json open_ops.append( ts.open(ts.Spec(tspec), open=True, context=info.ts_context) ) @@ -149,15 +122,18 @@ async def _background_serialize( """Serializes numpy arrays in a background thread.""" write_coros = [] for value, info, arg in zip(values, infos, args): - array_write_spec = self._get_array_write_spec( - info, - value, + array_write_spec = ts_utils.build_array_write_spec( + info=info, + arg=arg, + global_shape=value.shape, + local_shape=value.shape, + dtype=value.dtype, use_ocdbt=info.is_ocdbt_checkpoint, process_index=ocdbt_utils.get_process_index_for_subdir( use_ocdbt=info.is_ocdbt_checkpoint, override_ocdbt_process_id=self._override_ocdbt_process_id, ), - arg=arg, + metadata_key=self._metadata_key, ) tspec = array_write_spec.json if logging.vlog_is_on(1): @@ -205,7 +181,13 @@ async def deserialize( ) # Use OCDBT flag from the existing checkpoint. use_ocdbt = info.is_ocdbt_checkpoint - tspec = self._get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec( + info, + use_ocdbt=use_ocdbt, + metadata_key=self._metadata_key, + raise_array_data_missing_error=info.raise_array_data_missing_error, + ) + tspec = array_read_spec.json tspec = ts_utils.get_cast_tspec_deserialize(tspec, arg) if logging.vlog_is_on(1): diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py b/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py index 6e78304d0..ff57fb24f 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging.py @@ -101,7 +101,8 @@ async def open_tensorstore( use_zarr3=use_zarr3, ts_context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), ) - tspec = ts_utils.get_json_tspec_read(info, use_ocdbt=use_ocdbt) + array_read_spec = ts_utils.build_array_read_spec(info, use_ocdbt=use_ocdbt) + tspec = array_read_spec.json return await ts.open( ts.Spec(tspec), read=True, diff --git a/export/orbax/export/constants.py b/export/orbax/export/constants.py index c1fbeb91d..7cb277965 100644 --- a/export/orbax/export/constants.py +++ b/export/orbax/export/constants.py @@ -97,9 +97,18 @@ class ExportModelType(enum.Enum): # Mesh for the model. JAX_MESH = 'jax_mesh' +# TODO: b/459991985 - Remove this flag and use PERSIST_XLA_FLAGS instead. # Whether to strip XLA flags from the model. STRIP_XLA_FLAGS = 'strip_xla_flags' +# Whether to persist XLA flags in the model. +PERSIST_XLA_FLAGS = 'persist_xla_flags' + +# Whether to enable bf16 optimization for the model. +# TODO_REGEX: b/422170690: (1): Apply this flag to the pre/post processors. (2): +# Adding filter flags once the flag is applied to the pre/post processors. +ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization' + ################################################################################ # Proto field names ################################################################################ diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index 5ba509c51..3d076ac59 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]: tensorflow_module.TensorFlowModule, self._export_module ).jax2tf_kwargs_map + @property + def jax2obm_kwargs(self) -> Mapping[str, Any]: + """Returns the jax2obm_kwargs.""" + if self._export_version == constants.ExportModelType.TF_SAVEDMODEL: + raise TypeError( + 'jax2obm_kwargs is not implemented for export version' + ' ExportModelType.TF_SAVEDMODEL.' + ) + return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs + @property def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]: """Returns the polymorphic shapes.""" diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index 3debc78d6..d38aed98c 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -73,34 +73,43 @@ def __init__( ) # It is possible for jax2obm_kwargs to be None if the key is present. - if not jax2obm_kwargs: - jax2obm_kwargs = {} + self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {} + + enable_bf16_optimization = self.jax2obm_kwargs.get( + constants.ENABLE_BF16_OPTIMIZATION, False + ) + + if enable_bf16_optimization: + mapped_apply_fn = utils.to_bfloat16(apply_fn) + self._params_args_spec = utils.to_bfloat16(params) + else: + mapped_apply_fn = apply_fn + self._params_args_spec = params ( self._apply_fn_map, self.input_polymorphic_shape_map, self.input_polymorphic_shape_symbol_values_map, ) = self._normalize_apply_fn_map( - apply_fn, + mapped_apply_fn, input_polymorphic_shape, input_polymorphic_shape_symbol_values, ) - self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None) - self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False) - - self.polymorphic_constraints = self._maybe_set_polymorphic_constraints( - jax2obm_kwargs + self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None) + self._strip_xla_flags = self.jax2obm_kwargs.get( + constants.STRIP_XLA_FLAGS, False ) + + self.polymorphic_constraints = self._maybe_set_polymorphic_constraints() self._native_serialization_platforms = utils.get_lowering_platforms( - jax2obm_kwargs + self.jax2obm_kwargs ) - self._params_args_spec = params self._checkpoint_path: str = None # Set the Orbax checkpoint path if provided in the jax2obm_kwargs. - self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs) - self._load_all_checkpoint_weights = jax2obm_kwargs.get( + self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs) + self._load_all_checkpoint_weights = self.jax2obm_kwargs.get( constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False ) @@ -203,15 +212,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs): else constants.DEFAULT_WEIGHTS_NAME ) - def _maybe_set_polymorphic_constraints( - self, jax2obm_kwargs - ) -> Mapping[str, Sequence[Any]]: + def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]: """Sets the polymorphic constraints for the model. - Args: - jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion - library. - Returns: A mapping of function name to polymorphic constraints. @@ -221,7 +224,7 @@ def _maybe_set_polymorphic_constraints( size of the apply_fn_map or if a key in apply_fn_map is not found in polymorphic_constraints. """ - polymorphic_constraints = jax2obm_kwargs.get( + polymorphic_constraints = self.jax2obm_kwargs.get( constants.POLYMORPHIC_CONSTRAINTS, None ) if not isinstance(polymorphic_constraints, Mapping): @@ -300,3 +303,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]: def jax_methods(self) -> Mapping[str, Callable[..., Any]]: """Named methods in JAX context for validation.""" raise NotImplementedError('apply_fn_map is not implemented for ObmModule.') + + @property + def jax2obm_kwargs(self) -> Mapping[str, Any]: + """Returns the jax2obm_kwargs.""" + return self._jax2obm_kwargs diff --git a/export/orbax/export/modules/obm_module_test.py b/export/orbax/export/modules/obm_module_test.py index 62c6b73a5..c5555d4ad 100644 --- a/export/orbax/export/modules/obm_module_test.py +++ b/export/orbax/export/modules/obm_module_test.py @@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns( jax2obm_kwargs=jax2obm_kwargs, ) + @parameterized.named_parameters( + {'testcase_name': 'enable_bf16', 'enable_bf16_optimization': True}, + {'testcase_name': 'disable_bf16', 'enable_bf16_optimization': False}, + ) + def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization): + params_spec = { + 'w': jax.ShapeDtypeStruct((2, 2), jnp.float32), + 'b': jax.ShapeDtypeStruct((2,), jnp.float32), + } + input_spec = {constants.DEFAULT_METHOD_KEY: 'b, ...'} + + module = obm_module.ObmModule( + params=params_spec, + apply_fn=_linear, + input_polymorphic_shape=input_spec, + jax2obm_kwargs={ + constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization + }, + ) + + expected_dtype = jnp.bfloat16 if enable_bf16_optimization else jnp.float32 + with self.subTest('test_weights_w_dtype'): + self.assertEqual(module.model_params['w'].dtype, expected_dtype) + with self.subTest('test_weights_b_dtype'): + self.assertEqual(module.model_params['b'].dtype, expected_dtype) + if __name__ == '__main__': absltest.main() diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index 576d3dd13..8e27b074d 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -18,6 +18,7 @@ import dataclasses import functools import inspect +import jax.numpy as jnp import os from typing import Any, Callable, List, Optional, Tuple, Union @@ -532,3 +533,40 @@ def get_lowering_platforms( ) return native_serialization_platforms + + +def to_bfloat16(x: Any) -> Any: + """Helper to convert leaves of a pytree to bfloat16. + + It handles `float`, `jax.ShapeDtypeStruct`, and other array-like objects with + a floating point `dtype`. + + Args: + x: The input pytree to convert. + + Returns: + The input `x` with floating point values converted to `jnp.bfloat16`. + """ + + def _to_bfloat16_leaf(x: Any) -> Any: + if isinstance(x, jax.ShapeDtypeStruct) and jnp.issubdtype( + x.dtype, jnp.floating + ): + return jax.ShapeDtypeStruct( + x.shape, + jnp.bfloat16, + sharding=x.sharding, + ) + if isinstance(x, jax.ShapeDtypeStruct): + return x + if hasattr(x, 'dtype') and jnp.issubdtype(x.dtype, jnp.floating): + return x.astype(jnp.bfloat16) + if isinstance(x, float): + return jnp.bfloat16(x) + return x + + flattened_x, treedef = jax.tree_util.tree_flatten(x) + flattened_y = [ + jax.tree_util.tree_map(_to_bfloat16_leaf, y) for y in flattened_x + ] + return jax.tree_util.tree_unflatten(treedef, flattened_y) diff --git a/model/orbax/experimental/model/cli/README.md b/model/orbax/experimental/model/cli/README.md index ddff30620..26c7aff55 100644 --- a/model/orbax/experimental/model/cli/README.md +++ b/model/orbax/experimental/model/cli/README.md @@ -2,6 +2,7 @@ A command-line tool for inspecting Orbax models. + ## Examples To inspect the model: