Skip to content

Provide MaxText axes to cudnn_flash_te to correctly perform dbias reduction if required #2221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Calling jax.device_count here prevents a "TPU platform already registered" error.
# See github.com/google/maxtext/issues/20 for more

from contextlib import contextmanager
from typing import Any, Sequence
import datetime
import functools
Expand Down Expand Up @@ -754,9 +755,34 @@ def run(config, recorder, diagnostic_config):
train_loop(config, recorder)


@contextmanager
def transformer_engine_context_or_noop():
"""If TransformerEngine is available, this context manager will provide the library with MaxText-specific details needed for correcct operation.

If TransformerEngine is not available, this is a No-Op and does not add any context.

If the transformer_engine package is available but TransformerEngine is not used in MaxText, this will still be a No-Op
as this context's data is only used if TransformerEngine modules, such as attention=cudnn_flash_te are called."""
try:
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
# Inform TransformerEngine of MaxText's physical mesh resources.
mesh_resource = MeshResource(
dp_resource="data",
tp_resource="tensor",
fsdp_resource="fsdp",
pp_resource=None,
cp_resource="context",
)
with global_shard_guard(mesh_resource):
yield
except ImportError:
yield


def main(argv: Sequence[str]) -> None:
config, recorder, diagnostic_config = initialize(argv)
run(config, recorder, diagnostic_config)
with transformer_engine_context_or_noop():
config, recorder, diagnostic_config = initialize(argv)
run(config, recorder, diagnostic_config)


if __name__ == "__main__":
Expand Down