We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 69b4349 commit 6b8552bCopy full SHA for 6b8552b
MaxText/max_utils.py
@@ -86,7 +86,8 @@ def device_space():
86
return jax.memory.Space.Device # pytype: disable=module-attr
87
else:
88
# pytype: disable=module-attr
89
- return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access
+ return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access
90
+ # pytype: enable=module-attr
91
92
93
def calculate_total_params_per_chip(params):
0 commit comments