Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions export/orbax/export/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,18 @@ class ExportModelType(enum.Enum):
# Mesh for the model.
JAX_MESH = 'jax_mesh'

# TODO: b/459991985 - Remove this flag and use PERSIST_XLA_FLAGS instead.
# Whether to strip XLA flags from the model.
STRIP_XLA_FLAGS = 'strip_xla_flags'

# Whether to persist XLA flags in the model.
PERSIST_XLA_FLAGS = 'persist_xla_flags'

# Whether to enable bf16 optimization for the model.
# TODO: b/422170690 - (1): Apply this flag to the pre/post processors. (2):
# Adding filter flags once the flag is applied to the pre/post processors.
ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization'

################################################################################
# Proto field names
################################################################################
Expand Down
10 changes: 10 additions & 0 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
tensorflow_module.TensorFlowModule, self._export_module
).jax2tf_kwargs_map

@property
def jax2obm_kwargs(self) -> Mapping[str, Any]:
  """Returns the jax2obm_kwargs of the wrapped Orbax Model module.

  Raises:
    TypeError: if this module was created for TF SavedModel export, which
      has no jax2obm_kwargs.
  """
  if self._export_version != constants.ExportModelType.TF_SAVEDMODEL:
    return cast(obm_module.ObmModule, self._export_module).jax2obm_kwargs
  raise TypeError(
      'jax2obm_kwargs is not implemented for export version'
      ' ExportModelType.TF_SAVEDMODEL.'
  )

@property
def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
"""Returns the polymorphic shapes."""
Expand Down
48 changes: 28 additions & 20 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,43 @@ def __init__(
)

# It is possible for jax2obm_kwargs to be None if the key is present.
if not jax2obm_kwargs:
jax2obm_kwargs = {}

self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {}

enable_bf16_optimization = self.jax2obm_kwargs.get(
constants.ENABLE_BF16_OPTIMIZATION, False
)

if enable_bf16_optimization:
mapped_apply_fn = utils.to_bfloat16(apply_fn)
self._params_args_spec = utils.to_bfloat16(params)
else:
mapped_apply_fn = apply_fn
self._params_args_spec = params
(
self._apply_fn_map,
self.input_polymorphic_shape_map,
self.input_polymorphic_shape_symbol_values_map,
) = self._normalize_apply_fn_map(
apply_fn,
mapped_apply_fn,
input_polymorphic_shape,
input_polymorphic_shape_symbol_values,
)

self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None)
self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False)

self.polymorphic_constraints = self._maybe_set_polymorphic_constraints(
jax2obm_kwargs
self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None)
self._strip_xla_flags = self.jax2obm_kwargs.get(
constants.STRIP_XLA_FLAGS, False
)

self.polymorphic_constraints = self._maybe_set_polymorphic_constraints()
self._native_serialization_platforms = utils.get_lowering_platforms(
jax2obm_kwargs
self.jax2obm_kwargs
)
self._params_args_spec = params

self._checkpoint_path: str = None
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
self._load_all_checkpoint_weights = jax2obm_kwargs.get(
self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs)
self._load_all_checkpoint_weights = self.jax2obm_kwargs.get(
constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False
)

Expand Down Expand Up @@ -203,15 +212,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
else constants.DEFAULT_WEIGHTS_NAME
)

def _maybe_set_polymorphic_constraints(
self, jax2obm_kwargs
) -> Mapping[str, Sequence[Any]]:
def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]:
"""Sets the polymorphic constraints for the model.

Args:
jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
library.

Returns:
A mapping of function name to polymorphic constraints.

Expand All @@ -221,7 +224,7 @@ def _maybe_set_polymorphic_constraints(
size of the apply_fn_map or if a key in apply_fn_map is not found in
polymorphic_constraints.
"""
polymorphic_constraints = jax2obm_kwargs.get(
polymorphic_constraints = self.jax2obm_kwargs.get(
constants.POLYMORPHIC_CONSTRAINTS, None
)
if not isinstance(polymorphic_constraints, Mapping):
Expand Down Expand Up @@ -300,3 +303,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
"""Named methods in JAX context for validation."""
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')

@property
def jax2obm_kwargs(self) -> Mapping[str, Any]:
  """Returns the jax2obm_kwargs supplied at construction ({} if none given)."""
  return self._jax2obm_kwargs
26 changes: 26 additions & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns(
jax2obm_kwargs=jax2obm_kwargs,
)

@parameterized.named_parameters(
    {'testcase_name': 'enable_bf16', 'enable_bf16_optimization': True},
    {'testcase_name': 'disable_bf16', 'enable_bf16_optimization': False},
)
def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
  """Checks model params are cast to bf16 exactly when the flag is set."""
  weight_specs = {
      'w': jax.ShapeDtypeStruct((2, 2), jnp.float32),
      'b': jax.ShapeDtypeStruct((2,), jnp.float32),
  }
  poly_shape_spec = {constants.DEFAULT_METHOD_KEY: 'b, ...'}
  conversion_kwargs = {
      constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization
  }

  module = obm_module.ObmModule(
      params=weight_specs,
      apply_fn=_linear,
      input_polymorphic_shape=poly_shape_spec,
      jax2obm_kwargs=conversion_kwargs,
  )

  if enable_bf16_optimization:
    expected_dtype = jnp.bfloat16
  else:
    expected_dtype = jnp.float32
  with self.subTest('test_weights_w_dtype'):
    self.assertEqual(module.model_params['w'].dtype, expected_dtype)
  with self.subTest('test_weights_b_dtype'):
    self.assertEqual(module.model_params['b'].dtype, expected_dtype)


if __name__ == '__main__':
absltest.main()
38 changes: 38 additions & 0 deletions export/orbax/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import functools
import inspect
import jax.numpy as jnp
import os
from typing import Any, Callable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -532,3 +533,40 @@ def get_lowering_platforms(
)

return native_serialization_platforms


def to_bfloat16(x: Any) -> Any:
  """Converts the floating point leaves of a pytree to bfloat16.

  It handles `float`, `jax.ShapeDtypeStruct`, and other array-like objects
  with a floating point `dtype`. Non-floating leaves (ints, bools, integer
  arrays, etc.) pass through unchanged.

  Args:
    x: The input pytree to convert.

  Returns:
    The input `x` with floating point values converted to `jnp.bfloat16`.
  """

  def _to_bfloat16_leaf(leaf: Any) -> Any:
    if isinstance(leaf, jax.ShapeDtypeStruct):
      if jnp.issubdtype(leaf.dtype, jnp.floating):
        # Only the dtype changes; shape and sharding are preserved.
        return jax.ShapeDtypeStruct(
            leaf.shape,
            jnp.bfloat16,
            sharding=leaf.sharding,
        )
      return leaf
    if hasattr(leaf, 'dtype') and jnp.issubdtype(leaf.dtype, jnp.floating):
      return leaf.astype(jnp.bfloat16)
    if isinstance(leaf, float):
      return jnp.bfloat16(leaf)
    return leaf

  # A single tree_map visits every leaf; the previous
  # tree_flatten -> per-leaf tree_map -> tree_unflatten round trip was
  # redundant (tree_map over an already-flattened leaf is a no-op wrapper).
  return jax.tree_util.tree_map(_to_bfloat16_leaf, x)
1 change: 1 addition & 0 deletions model/orbax/experimental/model/cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

A command-line tool for inspecting Orbax models.


## Examples

To inspect the model:
Expand Down
Loading