diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py
new file mode 100644
index 000000000..a2749f4c1
--- /dev/null
+++ b/torchtitan/distributed/tensor_parallel.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+from torchtitan.config import JobConfig
+from torchtitan.tools.logging import logger
+
+
+def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh):
+    if not job_config.parallelism.enable_async_tensor_parallel:
+        return
+
+    if not (job_config.compile.enable and "model" in job_config.compile.components):
+        raise RuntimeError("Async TP requires --training.compile")
+
+    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+    torch._inductor.config._micro_pipeline_tp = True
+    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
+    logger.info("Async TP is enabled")
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index a716c7890..e51168657 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -28,6 +28,7 @@
     ReordererSequenceParallel,
     TensorParallel,
 )
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
@@ -66,12 +67,6 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )
     if parallel_dims.tp_enabled:
-        if (
-            job_config.parallelism.enable_async_tensor_parallel
-            and not model_compile_enabled
-        ):
-            raise RuntimeError("Async TP requires torch.compile")
-
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
             "rowwise",
@@ -88,8 +83,8 @@
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
-            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
+        maybe_enable_async_tp(job_config, world_mesh["tp"])
 
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
@@ -177,7 +172,6 @@ def apply_non_moe_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
-    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -256,14 +250,8 @@ def apply_non_moe_tp(
             parallelize_plan=layer_plan,
         )
 
-    if enable_async_tp:
-        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
-
-        torch._inductor.config._micro_pipeline_tp = True
-        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
-
     logger.info(
-        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
         "Tensor Parallelism to the model"
     )
 
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 1aedd73ad..8423c2a8e 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -19,6 +19,7 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.expert_parallel import NoParallel
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_compile,
     apply_fsdp,
@@ -51,16 +52,7 @@ def parallelize_deepseekv3(
     ):
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
-    model_compile_enabled = (
-        job_config.compile.enable and "model" in job_config.compile.components
-    )
     if parallel_dims.tp_enabled:
-        if (
-            job_config.parallelism.enable_async_tensor_parallel
-            and not model_compile_enabled
-        ):
-            raise RuntimeError("Async TP requires --training.compile")
-
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
             "rowwise",
@@ -79,8 +71,8 @@
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            enable_async_tp=False,
         )
+        maybe_enable_async_tp(job_config, world_mesh["tp"])
 
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
@@ -100,7 +92,7 @@
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
-    if model_compile_enabled:
+    if job_config.compile.enable and "model" in job_config.compile.components:
         # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
         torch._dynamo.config.capture_scalar_outputs = True
         apply_compile(model)
@@ -167,7 +159,6 @@ def apply_non_moe_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
-    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -260,13 +251,7 @@ def apply_non_moe_tp(
             parallelize_plan=layer_plan,
         )
 
-    if enable_async_tp:
-        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
-
-        torch._inductor.config._micro_pipeline_tp = True
-        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
-
     logger.info(
-        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
         "Tensor Parallelism to the model"
     )
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 2e2e81302..05fd7043f 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -31,6 +31,7 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.tools.logging import logger
 
 
@@ -67,12 +68,6 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )
     if parallel_dims.tp_enabled:
-        if (
-            job_config.parallelism.enable_async_tensor_parallel
-            and not model_compile_enabled
-        ):
-            raise RuntimeError("Async TP requires torch.compile")
-
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
             "rowwise",
@@ -89,8 +84,8 @@
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
-            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
+        maybe_enable_async_tp(job_config, world_mesh["tp"])
 
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
@@ -144,7 +139,6 @@ def apply_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
-    enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -221,14 +215,8 @@ def apply_tp(
             parallelize_plan=layer_plan,
         )
 
-    if enable_async_tp:
-        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
-
-        torch._inductor.config._micro_pipeline_tp = True
-        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
-
     logger.info(
-        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
         "Tensor Parallelism to the model"
     )