
[WIP] [mxfp8] torchao mxfp8 moe integration #1549

Draft: wants to merge 1 commit into base: main
4 changes: 3 additions & 1 deletion torchtitan/components/quantization/float8.py
@@ -24,6 +24,8 @@


class Float8Converter(ModelConverter):
+    fp8_token_group_alignment_size = 16

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

@@ -69,7 +71,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

# For fp8 grouped GEMM, token group sizes must be multiples of 16
# (16 byte alignment / 1 byte per elem = 16 elements)
-        set_token_group_alignment_size_m(16)
+        set_token_group_alignment_size_m(self.fp8_token_group_alignment_size)

if float8_config.recipe_name is not None:
assert not float8_config.enable_fsdp_float8_all_gather, (
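The comment in the hunk above gives the reason for the value 16: fp8 elements are 1 byte each, so 16-byte alignment requires every token group fed to the grouped GEMM to have a size that is a multiple of 16. The mx.py changes below raise this to 32, matching the mxfp8 scaling block size. As a minimal illustration of the rounding such an alignment implies (the helper below is hypothetical and is not how `set_token_group_alignment_size_m` is implemented):

```python
# Illustrative sketch only; pad_token_group_sizes is a hypothetical helper,
# not part of this PR or of torchtitan.
import torch

def pad_token_group_sizes(group_sizes: torch.Tensor, alignment: int) -> torch.Tensor:
    """Round each expert's token-group size up to the nearest multiple of `alignment`."""
    return ((group_sizes + alignment - 1) // alignment) * alignment

sizes = torch.tensor([5, 17, 32])        # tokens routed to three experts
print(pad_token_group_sizes(sizes, 16))  # tensor([16, 32, 32])  fp8 rowwise case
print(pad_token_group_sizes(sizes, 32))  # tensor([32, 32, 32])  mxfp8 case
```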
46 changes: 43 additions & 3 deletions torchtitan/components/quantization/mx.py
@@ -11,6 +11,7 @@

import torch.nn as nn

+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchtitan.config.job_config import JobConfig, MX
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
@@ -30,6 +31,7 @@ class MXConverter(ModelConverter):
enabled: bool
filter_fqns: List[str]
mx_config: Any # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
# Ensure minimum torchao versions
@@ -59,9 +61,11 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

# For MoE training with mxfp8, token group sizes must be multiples of 32
if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+            self.moe_fqns = job_config.mx.moe_fqns_prototype
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )

# Configure MXFP8
from torchao.prototype.mx_formats.config import (
@@ -91,6 +95,13 @@ def convert(self, model: nn.Module):
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_

+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)

assert isinstance(self.config, MXLinearConfig)
quantize_(
model,
@@ -99,6 +110,35 @@
)
logger.info("Swapped to MXLinear layers")

+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor,
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
"""
MXFP8 doesn't require any post-optimizer hooks at the moment
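To summarize the new conversion flow, the sketch below mirrors the ordering in `MXConverter.convert` and the substring FQN matching in `_convert_moe_layers`: MoE expert modules are converted first with `MoETrainingConfig`, and only then are the remaining linears swapped to MXLinear, since doing it in the opposite order would, per the comment in the diff, convert the MXLinear modules back to nn.Linear. The names `moe_fqns`, `moe_filter`, `dense_filter`, and `convert_for_mxfp8_training` are assumptions for illustration only; the torchao imports are the prototype APIs used in the diff and require a torchao nightly build.

```python
# Hypothetical usage sketch, not part of the PR: MoE conversion first,
# then the MXLinear swap for the remaining dense linears.
import torch.nn as nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_

moe_fqns = ["experts"]  # assumed target FQN substrings, e.g. matching "layers.0.moe.experts.w1"

def moe_filter(mod: nn.Module, cur_fqn: str) -> bool:
    # Same substring match as moe_module_filter_fn in the diff above.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

def dense_filter(mod: nn.Module, cur_fqn: str) -> bool:
    # Assumed filter for the MXLinear swap: plain linears outside the MoE experts.
    return isinstance(mod, nn.Linear) and not moe_filter(mod, cur_fqn)

def convert_for_mxfp8_training(model: nn.Module, mx_config: MXLinearConfig) -> None:
    # 1) Convert MoE expert weights for MXFP8 scaled grouped GEMMs.
    quantize_(
        model,
        config=MoETrainingConfig(scaling_type=MoEScalingType.MXFP8),
        filter_fn=moe_filter,
    )
    # 2) Swap the remaining dense layers to MXLinear. Doing this step first would
    #    leave MXLinear modules that step 1 then converts back to nn.Linear.
    quantize_(model, config=mx_config, filter_fn=dense_filter)
```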