diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 976ac8dab..1d993a4cd 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -86,7 +86,8 @@ def device_space(): return jax.memory.Space.Device # pytype: disable=module-attr else: # pytype: disable=module-attr - return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access + return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access + # pytype: enable=module-attr def calculate_total_params_per_chip(params):