@@ -73,34 +73,43 @@ def __init__(
7373 )
7474
7575 # It is possible for jax2obm_kwargs to be None if the key is present.
76- if not jax2obm_kwargs :
77- jax2obm_kwargs = {}
7876
77+ self ._jax2obm_kwargs = jax2obm_kwargs if jax2obm_kwargs else {}
78+
79+ enable_bf16_optimization = self .jax2obm_kwargs .get (
80+ constants .ENABLE_BF16_OPTIMIZATION , False
81+ )
82+
83+ if enable_bf16_optimization :
84+ mapped_apply_fn = utils .to_bfloat16 (apply_fn )
85+ self ._params_args_spec = utils .to_bfloat16 (params )
86+ else :
87+ mapped_apply_fn = apply_fn
88+ self ._params_args_spec = params
7989 (
8090 self ._apply_fn_map ,
8191 self .input_polymorphic_shape_map ,
8292 self .input_polymorphic_shape_symbol_values_map ,
8393 ) = self ._normalize_apply_fn_map (
84- apply_fn ,
94+ mapped_apply_fn ,
8595 input_polymorphic_shape ,
8696 input_polymorphic_shape_symbol_values ,
8797 )
8898
89- self ._jax_mesh = jax2obm_kwargs .get (constants .JAX_MESH , None )
90- self ._strip_xla_flags = jax2obm_kwargs .get (constants .STRIP_XLA_FLAGS , False )
91-
92- self .polymorphic_constraints = self ._maybe_set_polymorphic_constraints (
93- jax2obm_kwargs
99+ self ._jax_mesh = self .jax2obm_kwargs .get (constants .JAX_MESH , None )
100+ self ._strip_xla_flags = self .jax2obm_kwargs .get (
101+ constants .STRIP_XLA_FLAGS , False
94102 )
103+
104+ self .polymorphic_constraints = self ._maybe_set_polymorphic_constraints ()
95105 self ._native_serialization_platforms = utils .get_lowering_platforms (
96- jax2obm_kwargs
106+ self . jax2obm_kwargs
97107 )
98- self ._params_args_spec = params
99108
100109 self ._checkpoint_path : str = None
101110 # Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
102- self ._maybe_set_orbax_checkpoint_path (jax2obm_kwargs )
103- self ._load_all_checkpoint_weights = jax2obm_kwargs .get (
111+ self ._maybe_set_orbax_checkpoint_path (self . jax2obm_kwargs )
112+ self ._load_all_checkpoint_weights = self . jax2obm_kwargs .get (
104113 constants .LOAD_ALL_CHECKPOINT_WEIGHTS , False
105114 )
106115
@@ -203,15 +212,9 @@ def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
203212 else constants .DEFAULT_WEIGHTS_NAME
204213 )
205214
206- def _maybe_set_polymorphic_constraints (
207- self , jax2obm_kwargs
208- ) -> Mapping [str , Sequence [Any ]]:
215+ def _maybe_set_polymorphic_constraints (self ) -> Mapping [str , Sequence [Any ]]:
209216 """Sets the polymorphic constraints for the model.
210217
211- Args:
212- jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
213- library.
214-
215218 Returns:
216219 A mapping of function name to polymorphic constraints.
217220
@@ -221,7 +224,7 @@ def _maybe_set_polymorphic_constraints(
221224 size of the apply_fn_map or if a key in apply_fn_map is not found in
222225 polymorphic_constraints.
223226 """
224- polymorphic_constraints = jax2obm_kwargs .get (
227+ polymorphic_constraints = self . jax2obm_kwargs .get (
225228 constants .POLYMORPHIC_CONSTRAINTS , None
226229 )
227230 if not isinstance (polymorphic_constraints , Mapping ):
@@ -300,3 +303,8 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
  """Named methods in JAX context for validation.

  This accessor is intentionally unsupported here: it never returns and
  always raises.

  Raises:
    NotImplementedError: Always.
  """
  # Name the member that was actually accessed so the traceback points the
  # caller at the right attribute (the message previously named
  # `apply_fn_map`).
  raise NotImplementedError('jax_methods is not implemented for ObmModule.')
306+
@property
def jax2obm_kwargs(self) -> Mapping[str, Any]:
  """Read-only view of the jax2obm kwargs held in `self._jax2obm_kwargs`."""
  stored_kwargs = self._jax2obm_kwargs
  return stored_kwargs
0 commit comments