diff --git a/MaxText/train.py b/MaxText/train.py
index eb20a164bd..7680fa651d 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -18,6 +18,7 @@
 
 # Calling jax.device_count here prevents a "TPU platform already registered" error.
 # See github.com/google/maxtext/issues/20 for more
+from contextlib import contextmanager
 from typing import Any, Sequence
 import datetime
 import functools
@@ -754,9 +755,34 @@ def run(config, recorder, diagnostic_config):
   train_loop(config, recorder)
 
 
+@contextmanager
+def transformer_engine_context_or_noop():
+  """If TransformerEngine is available, this context manager provides the library with the MaxText-specific details it needs for correct operation.
+
+  If TransformerEngine is not available, this is a no-op and does not add any context.
+
+  If the transformer_engine package is available but TransformerEngine is not used in MaxText, this is still effectively a no-op,
+  as this context's data is only consumed when TransformerEngine modules are invoked (e.g. with attention=cudnn_flash_te)."""
+  try:
+    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
+    # Inform TransformerEngine of MaxText's physical mesh resources.
+    mesh_resource = MeshResource(
+        dp_resource="data",
+        tp_resource="tensor",
+        fsdp_resource="fsdp",
+        pp_resource=None,
+        cp_resource="context",
+    )
+    with global_shard_guard(mesh_resource):
+      yield
+  except ImportError:
+    yield
+
+
 def main(argv: Sequence[str]) -> None:
-  config, recorder, diagnostic_config = initialize(argv)
-  run(config, recorder, diagnostic_config)
+  with transformer_engine_context_or_noop():
+    config, recorder, diagnostic_config = initialize(argv)
+    run(config, recorder, diagnostic_config)
 
 
 if __name__ == "__main__":
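
For context, a minimal, self-contained sketch of the optional-dependency pattern the new context manager relies on: attempt the import inside a @contextmanager, enter the library's context when the import succeeds, and fall back to a bare yield (a no-op) on ImportError so the wrapped code runs identically either way. The module name some_optional_lib and its enable_context() helper below are hypothetical placeholders, not MaxText or TransformerEngine APIs.

from contextlib import contextmanager


@contextmanager
def optional_context_or_noop():
  """Enter the optional library's context if it is installed; otherwise act as a pass-through."""
  try:
    # Hypothetical optional dependency; in the patch above this role is played by
    # transformer_engine.jax.sharding.global_shard_guard together with MeshResource.
    from some_optional_lib import enable_context

    with enable_context():
      yield
  except ImportError:
    # Library not installed: yield with no added context.
    yield


if __name__ == "__main__":
  with optional_context_or_noop():
    print("This body runs whether or not the optional library is installed.")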