diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer.py b/benchmarks/prototype/moe_training/benchmark_moe_layer.py
index 549aae5a5e..d18c6dc176 100644
--- a/benchmarks/prototype/moe_training/benchmark_moe_layer.py
+++ b/benchmarks/prototype/moe_training/benchmark_moe_layer.py
@@ -30,16 +30,18 @@
     "CUDA not available or compute capability < 8.9", allow_module_level=True
 )
 
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
-# this test requires torchtitan
+# this benchmark requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -54,16 +56,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
 
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=16,
-        dim=5120,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -82,7 +83,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
@@ -90,12 +91,19 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq, dim = 1, 8192, 5120
+    batch, seq = 1, 8192
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
     x = ref_x.detach().clone().requires_grad_(True)
 
+    def warmup(model, input):
+        for _ in range(3):
+            out = model(input)
+            loss = F.mse_loss(out, torch.ones_like(out))
+            loss.backward()
+        torch.cuda.synchronize()
+
     def bench_fn_microseconds(model, input):
         labels = torch.ones_like(input)
         times = []
@@ -142,6 +150,7 @@ def profile_fn(model, input, profile_name="profile"):
         model = torch.compile(model, fullgraph=False)
 
     print("Benchmarking MoE with FSDP2 using bf16 training")
+    warmup(ref_model, ref_x)
     bf16_us = bench_fn_microseconds(ref_model, ref_x)
     print(f"bf16 time: {bf16_us} us")
     if enable_profile:
@@ -152,6 +161,7 @@ def profile_fn(model, input, profile_name="profile"):
     set_token_group_alignment_size_m(16)
 
     print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
+    warmup(model, x)
     fp8_us = bench_fn_microseconds(model, x)
     print(f"fp8 time: {fp8_us} us")
     if enable_profile:
diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py
index d08f218842..0ffdd65dff 100644
--- a/test/prototype/moe_training/test_training.py
+++ b/test/prototype/moe_training/test_training.py
@@ -22,11 +22,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool):
     # has the contraction dim be divisible by 16. 16 byte alignment is required
     # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
     set_token_group_alignment_size_m(16)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
@@ -75,7 +73,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig()
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
@@ -83,14 +81,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
         ref_model = torch.compile(ref_model, fullgraph=False)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]):
 
     # Token groups must be divisible by 32 for mxfp8
     set_token_group_alignment_size_m(block_size)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        multiple_of=block_size,
-        ffn_dim_multiplier=1.0,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 7dc246e251..30dfda4a6f 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -48,7 +48,7 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        print("Using fp8 rowwise scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
@@ -56,7 +56,7 @@ def _scaled_grouped_mm(
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 scaled_grouped_mm")
+        print("Using mxfp8 scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
@@ -144,7 +144,7 @@ def forward(
         # low precision B tensor instead of the high precision B tensor.
        # In the backward this is needed for grad_A: grad_output @ B.
         B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
-            B_t,
+            B_t._data,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 1ddd098675..e0ab00fce8 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -97,9 +97,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
+            other_args = args[2:]
             if A_is_2d and B_is_3d and has_offs:
                 return _scaled_grouped_mm(
-                    *args,
+                    A,
+                    B,
+                    *other_args,
                     scaling_type=scaling_type,
                     **kwargs,
                 )
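
Usage sketch (not part of the patch): the snippet below shows how the pieces this diff touches fit together after the migration, i.e. building a torchtitan MoE from MoEArgs with explicit dim/hidden_dim and converting its experts with torchao's MoETrainingConfig. The shapes, num_experts, and the simplified expert_filter are illustrative assumptions taken from the benchmark above, not requirements of the change.

import torch
from torch import nn

from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
from torchtitan.models.moe import MoE, MoEArgs

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

# fp8 rowwise training needs token groups aligned to 16 elements (see the test comment above).
set_token_group_alignment_size_m(16)

# The new MoE constructor takes dim/hidden_dim positionally instead of via TransformerModelArgs.
dim, hidden_dim = 5120, 4 * 5120
model = MoE(MoEArgs(num_experts=16), dim, hidden_dim).to(torch.bfloat16).cuda()
model.init_weights(0.02, torch.device("cuda"))

# Convert only the expert weights; this filter is a simplified stand-in for the
# moe_module_filter_fn used in the benchmark and tests.
def expert_filter(mod: nn.Module, cur_fqn: str) -> bool:
    return "experts" in cur_fqn

config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
quantize_(model, config=config, filter_fn=expert_filter)

x = torch.randn(1, 8192, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = model(x)  # grouped GEMMs inside the experts now dispatch to _scaled_grouped_mm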