
[moe training] update tests and benchmarks for torchtitan moe refactor #2718

Status: Draft · wants to merge 1 commit into base: main
benchmarks/prototype/moe_training/benchmark_moe_layer.py (32 changes: 21 additions & 11 deletions)

```diff
@@ -30,16 +30,18 @@
     "CUDA not available or compute capability < 8.9", allow_module_level=True
 )
 
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
-# this test requires torchtitan
+# this benchmark requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -54,16 +56,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
 
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=16,
-        dim=5120,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -82,20 +83,27 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
     fully_shard(model)
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq, dim = 1, 8192, 5120
+    batch, seq = 1, 8192
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
     x = ref_x.detach().clone().requires_grad_(True)
 
+    def warmup(model, input):
+        for _ in range(3):
+            out = model(input)
+            loss = F.mse_loss(out, torch.ones_like(out))
+            loss.backward()
+        torch.cuda.synchronize()
+
     def bench_fn_microseconds(model, input):
         labels = torch.ones_like(input)
         times = []
@@ -142,6 +150,7 @@ def profile_fn(model, input, profile_name="profile"):
         model = torch.compile(model, fullgraph=False)
 
     print("Benchmarking MoE with FSDP2 using bf16 training")
+    warmup(ref_model, ref_x)
     bf16_us = bench_fn_microseconds(ref_model, ref_x)
     print(f"bf16 time: {bf16_us} us")
     if enable_profile:
@@ -152,6 +161,7 @@ def profile_fn(model, input, profile_name="profile"):
     set_token_group_alignment_size_m(16)
 
     print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
+    warmup(model, x)
     fp8_us = bench_fn_microseconds(model, x)
     print(f"fp8 time: {fp8_us} us")
     if enable_profile:
```
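For readers tracking the API change: a minimal sketch of the construction and conversion flow the updated benchmark uses, based only on the calls visible in this diff. It assumes the post-refactor torchtitan API where `MoEArgs` no longer carries `dim` and `MoE` takes `(args, dim, hidden_dim)` explicitly; the inline `filter_fn` lambda is illustrative, not the benchmark's actual `moe_module_filter_fn`.

```python
# Sketch under the assumptions above; exact MoEArgs fields and the MoE
# signature may differ across torchtitan versions.
import torch
from torchtitan.models.moe import MoE, MoEArgs
from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

# Model dims are now passed to MoE directly instead of living in the args object.
dim, hidden_dim = 5120, 4 * 5120
model = MoE(MoEArgs(num_experts=16), dim, hidden_dim).to(torch.bfloat16).cuda()
model.init_weights(0.02, torch.device("cuda"))

# Convert only the expert weights to fp8 rowwise training.
config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
quantize_(model, config=config, filter_fn=lambda mod, fqn: "experts" in fqn)
```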
test/prototype/moe_training/test_training.py (28 changes: 11 additions & 17 deletions)

```diff
@@ -22,11 +22,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool):
     # has the contraction dim be divisible by 16. 16 byte alignment is required
     # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
     set_token_group_alignment_size_m(16)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -75,22 +73,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig()
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
 
     if compile:
-        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
         ref_model = torch.compile(ref_model, fullgraph=False)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]):
     # Token groups must be divisible by 32 for mxfp8
     set_token_group_alignment_size_m(block_size)
 
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        multiple_of=block_size,
-        ffn_dim_multiplier=1.0,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
```
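The alignment values used in these tests follow from the element sizes noted in the comments: fp8 rowwise needs token groups padded to 16 elements (16-byte alignment with 1 byte per fp8 element), while mxfp8 needs groups divisible by its 32-element block size. A small illustrative helper (hypothetical, not part of torchao or torchtitan) that picks the alignment used above:

```python
# Hypothetical convenience wrapper around the torchtitan helper imported in the
# diff; the 16/32 values come from the test comments, nothing else is assumed.
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m


def configure_token_group_alignment(recipe: str) -> None:
    alignment_m = {"fp8_rowwise": 16, "mxfp8": 32}[recipe]
    set_token_group_alignment_size_m(alignment_m)


configure_token_group_alignment("fp8_rowwise")  # as in test_moe_float8_training
configure_token_group_alignment("mxfp8")        # as in test_moe_mxfp8_training
```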
torchao/prototype/moe_training/scaled_grouped_mm.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -48,15 +48,15 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        print("Using fp8 rowwise scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 scaled_grouped_mm")
+        print("Using mxfp8 scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
@@ -144,7 +144,7 @@ def forward(
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
         B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
-            B_t,
+            B_t._data,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
```
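The `B_t._data` change suggests that `B_t` arriving here is the training tensor subclass defined in torchao/prototype/moe_training/tensor.py, which keeps the plain high-precision weight in `_data`, while the Triton transpose kernel wants the unwrapped tensor. A minimal sketch of that wrapper/unwrap pattern, using a made-up `Wrapped` class rather than torchao's actual subclass:

```python
import torch


# Hypothetical stand-in for a training tensor subclass that stores its payload
# in `_data`; only the wrap/unwrap pattern mirrors the diff above.
class Wrapped(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data


def transpose_rhs(weights: torch.Tensor) -> torch.Tensor:
    # Kernel wrappers like triton_fp8_rowwise_3d_transpose_rhs operate on plain
    # tensors, which is why the diff passes B_t._data instead of the wrapper.
    return weights.transpose(-2, -1).contiguous()


B_t = Wrapped(torch.randn(4, 64, 128))  # (num_experts, in_dim, out_dim) expert weights
B_col_major = transpose_rhs(B_t._data)
```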
torchao/prototype/moe_training/tensor.py (5 changes: 4 additions & 1 deletion)

```diff
@@ -97,9 +97,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
+            other_args = args[2:]
             if A_is_2d and B_is_3d and has_offs:
                 return _scaled_grouped_mm(
-                    *args,
+                    A,
+                    B,
+                    *other_args,
                     scaling_type=scaling_type,
                     **kwargs,
                 )
```
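The tensor.py change splits `*args` into explicit `A`, `B` plus the remaining positional arguments before forwarding to `_scaled_grouped_mm`. A stripped-down sketch of this `__torch_function__` interception pattern, with a hypothetical `GroupedMMTensor` class and `my_scaled_grouped_mm` placeholder standing in for torchao's real implementation:

```python
import torch


class GroupedMMTensor(torch.Tensor):
    """Toy wrapper that reroutes grouped-mm calls; names and behavior are
    illustrative only, mirroring the dispatch structure in tensor.py."""

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, "__name__", "") == cls.grouped_mm_func_name:
            A, B = args[0], args[1]
            other_args = args[2:]  # forwarded untouched, as in the diff
            if A.dim() == 2 and B.dim() == 3 and kwargs.get(cls.offs_arg_name) is not None:
                # Hand A and B over explicitly (rather than *args) so they can be
                # inspected or unwrapped before reaching the scaled grouped mm.
                return my_scaled_grouped_mm(A, B, *other_args, **kwargs)
        # Everything else falls through to the default tensor behavior.
        return super().__torch_function__(func, types, args, kwargs)


def my_scaled_grouped_mm(A, B, *extra, offs=None, **kwargs):
    # Placeholder for torchao's _scaled_grouped_mm; just reports what it got.
    return f"grouped mm: A{tuple(A.shape)} x B{tuple(B.shape)}, {len(offs)} groups"


# Only exercised on PyTorch builds that expose torch._grouped_mm.
if hasattr(torch, "_grouped_mm"):
    x = torch.randn(16, 64, dtype=torch.bfloat16).as_subclass(GroupedMMTensor)
    w = torch.randn(4, 64, 64, dtype=torch.bfloat16).as_subclass(GroupedMMTensor)
    offs = torch.tensor([4, 8, 12, 16], dtype=torch.int32)
    print(torch._grouped_mm(x, w, offs=offs))  # intercepted by __torch_function__
```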