From cab372fd004b2fc862041a279769042f466b4adc Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 11 Aug 2025 09:18:27 -0700
Subject: [PATCH] torchao mxfp8 moe integration

---
 torchtitan/components/quantization/float8.py |  4 +-
 torchtitan/components/quantization/mx.py     | 46 ++++++++++++++++++--
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py
index 362925815..9a5d4e0aa 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8.py
@@ -24,6 +24,8 @@ class Float8Converter(ModelConverter):
 
 
 class Float8Converter(ModelConverter):
+    fp8_token_group_alignment_size = 16
+
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
@@ -69,7 +71,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # For fp8 grouped GEMM, token group sizes must be multiples of 16
         # (16 byte alignment / 1 byte per elem = 16 elements)
-        set_token_group_alignment_size_m(16)
+        set_token_group_alignment_size_m(self.fp8_token_group_alignment_size)
 
         if float8_config.recipe_name is not None:
             assert not float8_config.enable_fsdp_float8_all_gather, (
diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py
index 15c74b7fd..916fa5802 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx.py
@@ -11,6 +11,7 @@
 
 import torch.nn as nn
 
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchtitan.config.job_config import JobConfig, MX
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
@@ -30,6 +31,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32
 
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -59,9 +61,11 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
+        self.moe_fqns = job_config.mx.moe_fqns_prototype
         if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )
 
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
@@ -91,6 +95,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_
 
+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear
+        # layers will be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
@@ -99,6 +110,35 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")
 
+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor,
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install a torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
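-- 
Note (reviewer sketch, not part of the patch): the two alignment values above exist because each
expert's token group must be padded to a multiple of the quantization granularity before the scaled
grouped GEMM: 16 elements for fp8 rowwise (16-byte alignment / 1 byte per element, per the comment
in float8.py) and 32 for mxfp8 (the mxfp8 scaling block size). A minimal illustration of that
rounding follows; pad_group_sizes is a hypothetical helper for this sketch only, since torchtitan
applies the padding internally once set_token_group_alignment_size_m is called.

    # Illustrative only; not part of torchtitan or torchao.
    import torch

    def pad_group_sizes(tokens_per_expert: torch.Tensor, alignment: int) -> torch.Tensor:
        """Round each expert's token count up to the next multiple of `alignment`."""
        return ((tokens_per_expert + alignment - 1) // alignment) * alignment

    tokens_per_expert = torch.tensor([5, 40, 17])
    print(pad_group_sizes(tokens_per_expert, 16))  # fp8 rowwise -> tensor([16, 48, 32])
    print(pad_group_sizes(tokens_per_expert, 32))  # mxfp8       -> tensor([32, 64, 32])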