Centralize Async TP Enablement with maybe_enable_async_tp API #1619

Merged · 5 commits · Aug 22, 2025

27 changes: 27 additions & 0 deletions torchtitan/distributed/tensor_parallel.py
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config import JobConfig
from torchtitan.tools.logging import logger


def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh):
if not job_config.parallelism.enable_async_tensor_parallel:
return

if not (job_config.compile.enable and "model" in job_config.compile.components):
raise RuntimeError("Async TP requires --training.compile")

from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info("Async TP is enabled")
18 changes: 3 additions & 15 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -28,6 +28,7 @@
ReordererSequenceParallel,
TensorParallel,
)
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp

from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.tools.logging import logger
@@ -66,12 +67,6 @@ def parallelize_llama(
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -88,8 +83,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
@@ -177,7 +172,6 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -256,14 +250,8 @@
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)

23 changes: 4 additions & 19 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -19,6 +19,7 @@
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.llama4.infra.parallelize import (
apply_compile,
apply_fsdp,
@@ -51,16 +52,7 @@ def parallelize_deepseekv3(
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires --training.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -79,8 +71,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
enable_async_tp=False,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
Expand All @@ -100,7 +92,7 @@ def parallelize_deepseekv3(
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

if model_compile_enabled:
if job_config.compile.enable and "model" in job_config.compile.components:
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
torch._dynamo.config.capture_scalar_outputs = True
apply_compile(model)
@@ -167,7 +159,6 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -260,13 +251,7 @@
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)
18 changes: 3 additions & 15 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -31,6 +31,7 @@
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.tools.logging import logger


@@ -67,12 +68,6 @@ def parallelize_llama(
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -89,8 +84,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
@@ -144,7 +139,6 @@ def apply_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -221,14 +215,8 @@
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)
