23 changes: 7 additions & 16 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -29,7 +29,11 @@
TensorParallel,
)

from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.models.llama3.infra.parallelize import (
apply_ac,
apply_ddp,
maybe_enable_async_tp,
)
from torchtitan.tools.logging import logger


@@ -66,12 +70,6 @@ def parallelize_llama(
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -88,8 +86,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
@@ -177,7 +175,6 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
Expand Down Expand Up @@ -256,14 +253,8 @@ def apply_non_moe_tp(
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)

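For readability, here is the call-site pattern each `parallelize_*` function now follows, sketched with the names used in this diff (the leading `model` argument is assumed from the surrounding code, which the hunk does not show):

```python
# Sketch of the consolidated call site: the TP plan no longer takes an
# enable_async_tp flag; async TP is handled by one shared helper afterwards.
apply_non_moe_tp(
    model,  # assumed first argument, not visible in this hunk
    world_mesh["tp"],
    loss_parallel=not job_config.parallelism.disable_loss_parallel,
    enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
)
# No-op unless parallelism.enable_async_tensor_parallel is set; raises a
# RuntimeError if torch.compile is not enabled for the model component.
maybe_enable_async_tp(job_config, world_mesh["tp"])
```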
28 changes: 8 additions & 20 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -24,7 +24,11 @@
apply_fsdp,
apply_moe_ep_tp,
)
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.models.llama3.infra.parallelize import (
apply_ac,
apply_ddp,
maybe_enable_async_tp,
)
from torchtitan.tools.logging import logger


@@ -51,16 +55,7 @@ def parallelize_deepseekv3(
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires --training.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -79,8 +74,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
enable_async_tp=False,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
@@ -100,7 +95,7 @@
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

if model_compile_enabled:
if job_config.compile.enable and "model" in job_config.compile.components:
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
torch._dynamo.config.capture_scalar_outputs = True
apply_compile(model)
@@ -167,7 +162,6 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -260,13 +254,7 @@ def apply_non_moe_tp(
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)
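The inline guard removed above is preserved inside `maybe_enable_async_tp` (defined in the llama3 file below), so a misconfigured job still fails fast. A minimal illustration with a stubbed config object; the `SimpleNamespace` stand-in only mimics the attributes the helper reads and is not how torchtitan constructs a `JobConfig`:

```python
from types import SimpleNamespace

from torchtitan.models.llama3.infra.parallelize import maybe_enable_async_tp

# Async TP requested, but compile is disabled for the "model" component.
stub_config = SimpleNamespace(
    parallelism=SimpleNamespace(enable_async_tensor_parallel=True),
    compile=SimpleNamespace(enable=False, components=["model"]),
)

# The compile check runs before the TP mesh is used, so None is enough here.
try:
    maybe_enable_async_tp(stub_config, tp_mesh=None)
except RuntimeError as err:
    print(err)  # Async TP requires --training.compile
```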
32 changes: 17 additions & 15 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -67,12 +67,6 @@ def parallelize_llama(
job_config.compile.enable and "model" in job_config.compile.components
)
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not model_compile_enabled
):
raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
@@ -89,8 +83,8 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
@@ -139,12 +133,26 @@ def parallelize_llama(
return model


def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh):
Contributor:

I'd suggest we take this chance to put it into a model-agnostic file. Specifically, I'm thinking of torchtitan/distributed/tensor_parallel.py, where we can also put NoParallel (https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L116).

I'm also thinking we may want to move most of the apply_ac (and maybe apply_compile) logic into that folder, as it is pretty much the same across all models.

Contributor Author (@fegin, Aug 22, 2025):

It makes sense to have a tensor_parallel.py module. However, I am unsure whether NoParallel should be part of tensor_parallel.py. Perhaps it would be more appropriate to place it in torchtitan/distributed/__init__.py, so that users can simply import it with from torchtitan.distributed import NoParallel.

As for AC, we can do this in another PR.
    if not job_config.parallelism.enable_async_tensor_parallel:
        return

    if not (job_config.compile.enable and "model" in job_config.compile.components):
        raise RuntimeError("Async TP requires --training.compile")

    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    torch._inductor.config._micro_pipeline_tp = True
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info("Async TP is enabled")


def apply_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -221,14 +229,8 @@ def apply_tp(
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
"Tensor Parallelism to the model"
)

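Following up on the review thread above, a hypothetical sketch of the model-agnostic module the reviewer suggested (`torchtitan/distributed/tensor_parallel.py`); the path and the `JobConfig` import location are assumptions, not part of this PR, and the body simply mirrors the helper added above:

```python
# torchtitan/distributed/tensor_parallel.py -- hypothetical follow-up layout,
# not part of this PR.
import torch
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig  # assumed import path
from torchtitan.tools.logging import logger


def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh) -> None:
    """Enable async tensor parallelism when requested; no-op otherwise."""
    if not job_config.parallelism.enable_async_tensor_parallel:
        return

    if not (job_config.compile.enable and "model" in job_config.compile.components):
        raise RuntimeError("Async TP requires --training.compile")

    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    torch._inductor.config._micro_pipeline_tp = True
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info("Async TP is enabled")
```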