Skip to content

Commit baf2fdd

Browse files
jerryxyj and Orbax Authors
authored and committed
Add bfloat16 flags and support bfloat16 optimization for native JAX function.
PiperOrigin-RevId: 832034030
1 parent d966ddf commit baf2fdd

File tree

6 files changed

+112
-20
lines changed

6 files changed

+112
-20
lines changed

export/orbax/export/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,18 @@ class ExportModelType(enum.Enum):
9797
# Mesh for the model.
9898
JAX_MESH = 'jax_mesh'
9999

100+
# TODO: b/459991985 - Remove this flag and use PERSIST_XLA_FLAGS instead.
100101
# Whether to strip XLA flags from the model.
101102
STRIP_XLA_FLAGS = 'strip_xla_flags'
102103

104+
# Whether to persist XLA flags in the model.
105+
PERSIST_XLA_FLAGS = 'persist_xla_flags'
106+
107+
# Whether to enable bf16 optimization for the model.
108+
# TODO: b/422170690 - (1) Apply this flag to the pre/post processors. (2) Add
109+
# filter flags once the flag is applied to the pre/post processors.
110+
ENABLE_BF16_OPTIMIZATION = 'enable_bf16_optimization'
111+
103112
################################################################################
104113
# Proto field names
105114
################################################################################

export/orbax/export/jax_module.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,16 @@ def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
197197
tensorflow_module.TensorFlowModule, self._export_module
198198
).jax2tf_kwargs_map
199199

200+
@property
def jax2obm_kwargs(self) -> Mapping[str, Any]:
  """Exposes the jax2obm_kwargs of the wrapped export module.

  Returns:
    The `jax2obm_kwargs` attribute of the underlying export module, which is
    treated as an `obm_module.ObmModule`.

  Raises:
    TypeError: If the export version is `ExportModelType.TF_SAVEDMODEL`.
  """
  is_tf_savedmodel = (
      self._export_version == constants.ExportModelType.TF_SAVEDMODEL
  )
  if is_tf_savedmodel:
    message = (
        'jax2obm_kwargs is not implemented for export version'
        ' ExportModelType.TF_SAVEDMODEL.'
    )
    raise TypeError(message)

  obm_export_module = cast(obm_module.ObmModule, self._export_module)
  return obm_export_module.jax2obm_kwargs
209+
200210
@property
201211
def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
202212
"""Returns the polymorphic shapes."""

export/orbax/export/modules/obm_module.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,34 +73,43 @@ def __init__(
7373
)
7474

7575
# It is possible for jax2obm_kwargs to be None if the key is present.
76-
if not jax2obm_kwargs:
77-
jax2obm_kwargs = {}
7876

77+
self._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {}
78+
79+
enable_bf16_optimization = self.jax2obm_kwargs.get(
80+
constants.ENABLE_BF16_OPTIMIZATION, False
81+
)
82+
83+
if enable_bf16_optimization:
84+
mapped_apply_fn = utils.to_bfloat16(apply_fn)
85+
self._params_args_spec = utils.to_bfloat16(params)
86+
else:
87+
mapped_apply_fn = apply_fn
88+
self._params_args_spec = params
7989
(
8090
self._apply_fn_map,
8191
self.input_polymorphic_shape_map,
8292
self.input_polymorphic_shape_symbol_values_map,
8393
) = self._normalize_apply_fn_map(
84-
apply_fn,
94+
mapped_apply_fn,
8595
input_polymorphic_shape,
8696
input_polymorphic_shape_symbol_values,
8797
)
8898

89-
self._jax_mesh = jax2obm_kwargs.get(constants.JAX_MESH, None)
90-
self._strip_xla_flags = jax2obm_kwargs.get(constants.STRIP_XLA_FLAGS, False)
91-
92-
self.polymorphic_constraints = self._maybe_set_polymorphic_constraints(
93-
jax2obm_kwargs
99+
self._jax_mesh = self.jax2obm_kwargs.get(constants.JAX_MESH, None)
100+
self._strip_xla_flags = self.jax2obm_kwargs.get(
101+
constants.STRIP_XLA_FLAGS, False
94102
)
103+
104+
self.polymorphic_constraints = self._maybe_set_polymorphic_constraints()
95105
self._native_serialization_platforms = utils.get_lowering_platforms(
96-
jax2obm_kwargs
106+
self.jax2obm_kwargs
97107
)
98-
self._params_args_spec = params
99108

100109
self._checkpoint_path: str = None
101110
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
102-
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
103-
self._load_all_checkpoint_weights = jax2obm_kwargs.get(
111+
self._maybe_set_orbax_checkpoint_path(self.jax2obm_kwargs)
112+
self._load_all_checkpoint_weights = self.jax2obm_kwargs.get(
104113
constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False
105114
)
106115

@@ -203,15 +212,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
203212
else constants.DEFAULT_WEIGHTS_NAME
204213
)
205214

206-
def _maybe_set_polymorphic_constraints(
207-
self, jax2obm_kwargs
208-
) -> Mapping[str, Sequence[Any]]:
215+
def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[Any]]:
209216
"""Sets the polymorphic constraints for the model.
210217
211-
Args:
212-
jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
213-
library.
214-
215218
Returns:
216219
A mapping of function name to polymorphic constraints.
217220
@@ -221,7 +224,7 @@ def _maybe_set_polymorphic_constraints(
221224
size of the apply_fn_map or if a key in apply_fn_map is not found in
222225
polymorphic_constraints.
223226
"""
224-
polymorphic_constraints = jax2obm_kwargs.get(
227+
polymorphic_constraints = self.jax2obm_kwargs.get(
225228
constants.POLYMORPHIC_CONSTRAINTS, None
226229
)
227230
if not isinstance(polymorphic_constraints, Mapping):
@@ -300,3 +303,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
300303
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
301304
"""Named methods in JAX context for validation."""
302305
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')
306+
307+
@property
308+
def jax2obm_kwargs(self) -> Mapping[str, Any]:
309+
"""Returns the jax2obm_kwargs stored on this module as `_jax2obm_kwargs`."""
310+
return self._jax2obm_kwargs

export/orbax/export/modules/obm_module_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,32 @@ def test_obm_module_multiple_apply_fns(
357357
jax2obm_kwargs=jax2obm_kwargs,
358358
)
359359

360+
@parameterized.named_parameters(
    dict(testcase_name='enable_bf16', enable_bf16_optimization=True),
    dict(testcase_name='disable_bf16', enable_bf16_optimization=False),
)
def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
  """Checks that the params dtype follows ENABLE_BF16_OPTIMIZATION."""
  float32_spec = lambda shape: jax.ShapeDtypeStruct(shape, jnp.float32)

  module = obm_module.ObmModule(
      params={'w': float32_spec((2, 2)), 'b': float32_spec((2,))},
      apply_fn=_linear,
      input_polymorphic_shape={constants.DEFAULT_METHOD_KEY: 'b, ...'},
      jax2obm_kwargs={
          constants.ENABLE_BF16_OPTIMIZATION: enable_bf16_optimization
      },
  )

  # float32 specs must come back as bfloat16 only when the flag is on.
  if enable_bf16_optimization:
    expected_dtype = jnp.bfloat16
  else:
    expected_dtype = jnp.float32

  subtests = (('w', 'test_weights_w_dtype'), ('b', 'test_weights_b_dtype'))
  for param_name, subtest_name in subtests:
    with self.subTest(subtest_name):
      self.assertEqual(module.model_params[param_name].dtype, expected_dtype)
385+
360386

361387
if __name__ == '__main__':
362388
absltest.main()

export/orbax/export/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import dataclasses
1919
import functools
2020
import inspect
21+
import jax.numpy as jnp
2122
import os
2223
from typing import Any, Callable, List, Optional, Tuple, Union
2324

@@ -532,3 +533,40 @@ def get_lowering_platforms(
532533
)
533534

534535
return native_serialization_platforms
536+
537+
538+
def to_bfloat16(x: Any) -> Any:
  """Converts the floating point leaves of a pytree to bfloat16.

  Each leaf is handled as follows:
    * `jax.ShapeDtypeStruct` with a floating point dtype: rebuilt with the same
      shape and sharding and a `jnp.bfloat16` dtype.
    * Any other object exposing a floating point `dtype`: cast with
      `astype(jnp.bfloat16)`.
    * Python `float`: converted with `jnp.bfloat16(...)`.
    * Anything else (e.g. integer specs/arrays, ints, strings, callables): the
      leaf is returned unchanged.

  Args:
    x: The input pytree to convert.

  Returns:
    A pytree with the same structure as `x` whose floating point leaves are
    converted to `jnp.bfloat16`.
  """

  def _to_bfloat16_leaf(leaf: Any) -> Any:
    if isinstance(leaf, jax.ShapeDtypeStruct):
      # Non-floating specs (ints, bools, ...) must keep their dtype.
      if not jnp.issubdtype(leaf.dtype, jnp.floating):
        return leaf
      return jax.ShapeDtypeStruct(
          leaf.shape,
          jnp.bfloat16,
          sharding=leaf.sharding,
      )
    if hasattr(leaf, 'dtype') and jnp.issubdtype(leaf.dtype, jnp.floating):
      return leaf.astype(jnp.bfloat16)
    if isinstance(leaf, float):
      return jnp.bfloat16(leaf)
    return leaf

  # `tree_map` already visits every leaf and rebuilds the original structure,
  # so a manual flatten / per-leaf map / unflatten round trip is unnecessary.
  return jax.tree_util.tree_map(_to_bfloat16_leaf, x)

model/orbax/experimental/model/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
A command-line tool for inspecting Orbax models.
44

5+
56
## Examples
67

78
To inspect the model:

0 commit comments

Comments
 (0)