Skip to content

Commit 6b8552b

Browse files
Fix [late-directive] error reported by pytype static analyzer in max_utils.py
PiperOrigin-RevId: 796887584
1 parent 69b4349 commit 6b8552b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

MaxText/max_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def device_space():
8686
return jax.memory.Space.Device # pytype: disable=module-attr
8787
else:
8888
# pytype: disable=module-attr
89-
return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access
89+
return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access
90+
# pytype: enable=module-attr
9091

9192

9293
def calculate_total_params_per_chip(params):

0 commit comments

Comments
 (0)