From 21695730a4e4d37e81806219e38f22e9251201ea Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <592045536@qq.com> Date: Tue, 18 Nov 2025 20:46:23 +0800 Subject: [PATCH 1/5] support dynamic activation quant for w4afp8 --- .../model_executor/layers/moe/fused_moe_cutlass_backend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 1e706a07320..5736f6fda00 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -1090,7 +1090,7 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process "down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None), } for name, value in scale_key_map.items(): - if value is None: + if hasattr(layer, name) and value is None: raise ValueError(f"scale {name} should not be none in w4a8 mode.") # 2. Extract scale tensor from state dict @@ -1111,8 +1111,9 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process for expert_idx in logical_expert_ids: for name, scale_key_template in scale_key_map.items(): - scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) - scale_weight_map[name].append(scale_tensor) + if hasattr(layer, name): + scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) + scale_weight_map[name].append(scale_tensor) for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale") From b8386943f2cc384e89e0d744d8dbbb31ab67bbb2 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <592045536@qq.com> Date: Thu, 27 Nov 2025 22:34:20 +0800 Subject: [PATCH 2/5] support dynamic w4afp8 --- .../gpu_ops/moe/ep_moe_expert_dispatch.cu | 6 +- fastdeploy/model_executor/layers/moe/ep.py | 5 +- .../layers/moe/fused_moe_cutlass_backend.py | 75 ++++++-- .../layers/quantization/__init__.py | 7 + .../layers/quantization/w4afp8.py | 6 +- fastdeploy/model_executor/layers/utils.py | 167 ++++++++++++++++++ 6 files changed, 246 insertions(+), 20 deletions(-) diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu index 8a8fb111697..e4a1cc7f9cd 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu @@ -425,7 +425,9 @@ __global__ void permute_x_kernel( } abs_max = phi::BlockAllReduce(abs_max); float scale = 440.f / abs_max; // use 440 so we do not have to clip - dequant_scale[dst_token_idx] = abs_max; + if (tid == 0) { + dequant_scale[dst_token_idx] = abs_max; + } for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { Load(&data_smem[v_id * vec_size], &src_vec); #pragma unroll @@ -656,7 +658,7 @@ std::vector EPMoeExpertDispatch( int dequant_scale_size = 1; if (moe_quant_type == "w4afp8" && !up_gate_proj_in_scale) { - dequant_scale_size = moe_topk * num_rows; + dequant_scale_size = token_nums_this_rank; } auto dequant_scale = diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index a1dcda67f7e..be40b3d04c8 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -295,7 +295,7 @@ def low_latency_dispatch( use_fp8=use_fp8, async_finish=False, return_recv_hook=True, - # 
num_per_channel=quant_group_size, + num_per_channel=quant_group_size, ) return packed_recv_x, recv_expert_count, handle, dispatch_hook @@ -634,10 +634,11 @@ def dispatch( ): expertwise_scale = kwargs.get("expertwise_scale", None) use_fp8 = kwargs.get("use_fp8", False) + quant_group_size = kwargs.get("quant_group_size", 128) if not self.use_internode_ll_two_stage: recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch( - x, topk_idx, expertwise_scale, use_fp8 + x, topk_idx, expertwise_scale, use_fp8, quant_group_size ) else: # just supports dispatch_use_fp8 = True now! diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index d00dd45a0b1..7286f704ee6 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -23,7 +23,7 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.platforms import current_platform -from ..utils import get_tensor +from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_model from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): @@ -147,8 +147,11 @@ def apply_ep_prefill( recv_topk_weights, recv_num_tokens_per_expert_list, handle, - _, + event, ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights) + + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() token_all_num = sum(recv_num_tokens_per_expert_list) # 3. Compute ffn @@ -206,7 +209,10 @@ def apply_ep_prefill( tmp_ffn_out = recv_x # 4. EP combine - return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) + tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() + return tmp_ffn_out def apply_ep_decode( self, @@ -744,7 +750,7 @@ def __init__(self, quant_config): super().__init__(quant_config) self.quant_config = quant_config self.moe_quant_type = "w4afp8" - self.pack_num = 2 + self.pack_num = 2 if quant_config.is_quantized else 1 def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False): """ @@ -911,21 +917,58 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): """ Paddle cutlass load weight process. """ + if not layer.is_quantized: + logger.info( + f"Rotating ernie.layers.{layer.layer_idx}.mlp.experts.[{layer.ep_rank * layer.num_local_experts},{layer.ep_rank * layer.num_local_experts + layer.num_local_experts}).down_proj.weight..." 
+ ) + rotate_model( + state_dict, + layer.layer_idx, + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ep_rank=layer.ep_rank, + ) + up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( layer.extract_moe_ffn_weights(state_dict) ) self.check(layer, up_gate_proj_weights, down_proj_weights) + + up_gate_proj_weight_scales = [] + down_proj_weight_scales = [] + dynamic_scale_weight_map = { + self.added_scale_attrs[0]: up_gate_proj_weight_scales, + self.added_scale_attrs[1]: down_proj_weight_scales, + } + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] + weight_scale_name = self.added_scale_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): - quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i]) + quant_weight = weight_tensor[i] + if not layer.is_quantized: + block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512) + quant_weight, weight_scale = group_wise_int4_weight_quantize(weight_tensor[i], group_size=128) + free_tensor(weight_tensor[i]) + quant_weight = pack(quant_weight.transpose([1, 0]), bits=4) + if "down_proj" in weight_name: + weight_scale = weight_scale / (block_size**0.5) + dynamic_scale_weight_map[weight_scale_name].append(weight_scale) + + quant_weight = w4afp8_gemm_weight_convert(quant_weight) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) getattr(layer, weight_name).set_value(quanted_weight) self.load_w4afp8_scale_weights( - layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list + layer, + layer.weight_key_map, + state_dict, + logical_expert_ids, + ep_rank_to_expert_id_list, + dynamic_scale_weight_map, ) def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): @@ -937,7 +980,7 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): """ self.default_dtype = layer._helper.get_default_dtype() - if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant: + if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant: setattr( layer, "up_gate_proj_in_scale_all_experts", @@ -949,7 +992,7 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): ) # in_scales - if not layer.moe_quant_config.moe_dynamic_quant: + if layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant: for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: setattr( layer, @@ -988,6 +1031,7 @@ def load_w4afp8_scale_weights( state_dict: dict, logical_expert_ids: paddle.Tensor, ep_rank_to_expert_id_list: list, + dynamic_scale_weight_map: dict, ): """ Get w4afp8 weights from state dict and process them. @@ -1094,7 +1138,7 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process raise ValueError(f"scale {name} should not be none in w4a8 mode.") # 2. 
Extract scale tensor from state dict - if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant: + if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant: for expert_idx in ep_rank_to_expert_id_list: scale_tensor = get_tensor( ( @@ -1109,11 +1153,14 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process paddle.concat(up_gate_proj_in_scales_all_experts) ) - for expert_idx in logical_expert_ids: - for name, scale_key_template in scale_key_map.items(): - if hasattr(layer, name): - scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) - scale_weight_map[name].append(scale_tensor) + if not layer.is_quantized: + scale_weight_map = dynamic_scale_weight_map + else: + for expert_idx in logical_expert_ids: + for name, scale_key_template in scale_key_map.items(): + if hasattr(layer, name): + scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) + scale_weight_map[name].append(scale_tensor) for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale") diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index 5d882aed292..644021b1a47 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -84,6 +84,13 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quantization_config["moe_quant_type"] = "wint4" quantization_config["quantization"] = "mix_quant" quant_config_name = "mix_quant" + # Special handling for moe w4afp8 dynamic quant + elif quant_config_name == "w4afp8": + quantization_config["dense_quant_type"] = "block_wise_fp8" + quantization_config["moe_quant_type"] = "w4afp8" + quantization_config["hadamard_block_size"] = 512 + quantization_config["quantization"] = "mix_quant" + quant_config_name = "mix_quant" else: quant_config_name = None if quant_config_name is None: diff --git a/fastdeploy/model_executor/layers/quantization/w4afp8.py b/fastdeploy/model_executor/layers/quantization/w4afp8.py index e7be78b06dd..2e61c97fde8 100644 --- a/fastdeploy/model_executor/layers/quantization/w4afp8.py +++ b/fastdeploy/model_executor/layers/quantization/w4afp8.py @@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase): quantization config for weight 4bits and activation fp8 """ - def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None: + def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size, is_quantized) -> None: super().__init__() self.weight_scale_dict = weight_scale_dict self.act_scale_dict = act_scale_dict @@ -40,6 +40,7 @@ def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_bloc self.quant_round_type = 1 self.is_permuted = is_permuted self.hadamard_block_size = hadamard_block_size + self.is_quantized = is_quantized def name(self) -> str: return "w4afp8" @@ -50,7 +51,8 @@ def from_config(cls, config: dict) -> "W4AFP8Config": act_scale_dict = config.get("act_scale_dict", None) is_permuted = config.get("is_permuted", True) hadamard_block_size = config.get("hadamard_block_size", 128) - return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) + is_quantized = config.get("is_quantized", False) + return cls(weight_scale_dict, act_scale_dict, is_permuted, 
hadamard_block_size, is_quantized) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: if isinstance(layer, FusedMoE): diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index c0644896e8e..79d6e9c94f0 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -53,6 +53,173 @@ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> return ((vocab_size + pad_to - 1) // pad_to) * pad_to +def random_orthogonal_matrix(size, device): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. + Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + + Args: + size (int): The size of the matrix (size x size). + + Returns: + paddle.Tensor: An orthogonal matrix of the specified size. + """ + paddle.device.cuda.empty_cache() + if device == "cuda": + random_matrix = paddle.randn(size, size, dtype="float32").to("gpu") + q, r = paddle.linalg.qr(random_matrix) + q *= paddle.sign(paddle.diag(r)).unsqueeze(0) + return q + + +def is_pow2(n): + return (n & (n - 1) == 0) and (n > 0) + + +def get_hadK(n, transpose=False): + hadK, K = None, None + assert is_pow2(n) + K = 1 + return hadK, K + + +def matmul_hadU_int4(X, transpose=False): + n = X.shape[-1] + hadK, K = get_hadK(n, transpose) + input = X.clone().reshape((-1, n, 1)) + output = input.clone() + while input.shape[1] > K: + input = input.reshape((input.shape[0], input.shape[1] // 2, 2, input.shape[2])) + output = output.reshape(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.reshape((input.shape[0], input.shape[1], -1)) + (input, output) = (output, input) + del output + + if K > 1: + input = hadK.reshape((1, K, K)).to(input) @ input + + return input.reshape(X.shape) / paddle.to_tensor(n, dtype="float32").sqrt() + + +def random_hadamard_matrix_int4(size, device=None, ffn2=False): + # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" + if not ffn2: + Q = paddle.randint(low=0, high=2, shape=(size,)).cast("float32") + Q = paddle.ones_like(Q, dtype="float32") + Q = Q * 2 - 1 + Q = paddle.diag(Q) + return matmul_hadU_int4(Q), None + + else: + num_blocks = size + while not (num_blocks % 2): + num_blocks = num_blocks // 2 + block_size = size // num_blocks + Q = paddle.diag(paddle.ones((block_size,), dtype="float32")) + block = matmul_hadU_int4(Q) + large_matrix = paddle.zeros([size, size]) + + for i in range(num_blocks): + start_row = i * block_size + start_col = i * block_size + large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block + return large_matrix.cast("float32"), block_size + + +def get_orthogonal_matrix(size, mode="hadamard", device="cuda"): + if mode == "random": + return random_orthogonal_matrix(size, device) + elif mode == "hadamard": + return random_hadamard_matrix_int4(size, device) + elif mode == "hadamard_ffn2": + return random_hadamard_matrix_int4(size, device, True) + else: + raise ValueError(f"Unknown mode {mode}") + + +def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0): + with paddle.no_grad(): + # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size] + Q_ffn2, 
moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2") + # down_proj.weight: [moe_intermediate_size, hidden_size] + expert_list = [ + get_tensor( + state_dict[ + f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight" + ] + ) + for expert_idx in range(moe_num_experts) + ] + moe_weight = paddle.concat(expert_list, axis=-1) # [moe_intermediate_size, hidden_size * moe_num_experts] + new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.to(Q_ffn2.place) + for expert_idx in range(moe_num_experts): + rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size] + expert_idx_local = ep_rank * moe_num_experts + expert_idx + state_dict[f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = ( + rotated_weight.cpu() + ) + del moe_weight, new_moe_weight, rotated_weight + paddle.device.cuda.empty_cache() + return Q_ffn2.cpu() + + +def pack(src, bits=4): + pack_num = 8 // bits + shift_bits = (paddle.arange(0, pack_num) * bits).cast("uint8") + src = paddle.to_tensor(src).cast("uint8") + + if len(src.shape) == 2: + row, col = src.shape + src = src.reshape((row, col // pack_num, pack_num)) + else: + src = src.reshape((src.shape[0] // pack_num, pack_num)) + + src[..., 0] = paddle.bitwise_and(src[..., 0], paddle.to_tensor(15, dtype="uint8")) + src = paddle.to_tensor(src.numpy() << shift_bits.numpy()) + + return src.sum(axis=-1).transpose((1, 0)).cast("int8") + + +def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128): + """ + Block-wise int4 weight quantization. + + Args + weight: paddle.Tensor + group_size: int + + Returns + weight_quant: paddle.Tensor, int8 weight after quantization and pack + weight_scale: paddle.Tensor, fp32 weight scale with group_size + """ + if weight.dtype == paddle.bfloat16: + weight = weight.astype(paddle.float32) + assert weight.dim() == 2 + weight = weight.transpose((1, 0)) + out_features, in_features = weight.shape + q_max, q_min = 7, -8 + + # [out_features, in_features] -> [out_features, in_features // group_size, group_size] + assert ( + in_features % group_size == 0 + ), f"in_features must be divisible by group_size: {group_size}, but got in_features: {in_features}" + weight = weight.reshape((out_features, in_features // group_size, group_size)) + + # calculate weight_scale + abs_max = paddle.max(paddle.abs(weight), axis=-1, keepdim=False).astype(paddle.float32) + weight_scale = paddle.clip(abs_max, min=1e-8) + + quant_weight = paddle.round(weight / weight_scale.unsqueeze(-1) * q_max) + quant_weight = paddle.clip(quant_weight, min=q_min, max=q_max) + quant_weight = quant_weight.reshape((out_features, in_features)).transpose((1, 0)) + + return quant_weight.astype(paddle.int8), weight_scale + + def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: """ Only used in deep_gemm block wise quant weight. 
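
For reference, a minimal sketch of how the helpers added to
fastdeploy/model_executor/layers/utils.py are meant to be combined on the
dynamic (not pre-quantized) path. The shapes are illustrative only; the calls
mirror what process_loaded_weights in patch 2 does per expert when
layer.is_quantized is False:

    import paddle

    from fastdeploy.model_executor.layers.utils import (
        group_wise_int4_weight_quantize,
        pack,
    )

    # Toy weight laid out as [in_features, out_features], as the MoE loader sees it.
    w = paddle.randn([256, 512], paddle.bfloat16)
    q, scale = group_wise_int4_weight_quantize(w, group_size=128)
    # q:     int8, shape [256, 512], values clipped to [-8, 7]
    # scale: float32, shape [512, 2] -- one scale per 128-wide group of in_features
    packed = pack(q.transpose([1, 0]), bits=4)  # two int4 nibbles per int8 byte
    # packed: int8, shape [128, 512]; the backend then passes this through
    # w4afp8_gemm_weight_convert before stacking the per-expert weights.

For down_proj the backend additionally Hadamard-rotates the weights with
rotate_model before quantizing and divides the resulting scales by
sqrt(hadamard_block_size).
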
From 4cd3ff7581d1c9a6945b16e6f1ea3660e7bd46ee Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <592045536@qq.com> Date: Mon, 1 Dec 2025 19:57:27 +0800 Subject: [PATCH 3/5] add test --- .../utils/auto_gen_w4afp8_gemm_kernel.py | 2 +- tests/layers/test_w4afp8_moe.py | 219 ++++++++++++++++++ 2 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 tests/layers/test_w4afp8_moe.py diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index 86268380ea4..4d783341fe3 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -85,7 +85,7 @@ """ # [M, K, Number of experts, token Padding Size, weight K group size] -gemm_case = [[256, 256, 2, 0, 128]] +gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]] dtype = ["BF16"] diff --git a/tests/layers/test_w4afp8_moe.py b/tests/layers/test_w4afp8_moe.py new file mode 100644 index 00000000000..65b7733172c --- /dev/null +++ b/tests/layers/test_w4afp8_moe.py @@ -0,0 +1,219 @@ +import json +import os +import shutil +import unittest + +import paddle +from paddle.distributed import fleet + +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) +from fastdeploy.model_executor.layers.moe.moe import FusedMoE +from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config +from fastdeploy.scheduler import SchedulerConfig + +# from fastdeploy.worker.worker_process import init_distributed_environment +from tests.utils import OpPerformanceTester + +paddle.set_default_dtype("bfloat16") + + +class FuseMoEWrapper(paddle.nn.Layer): + def __init__( + self, + model_config: ModelConfig, + tp_size: int = 1, + tp_rank: int = 0, + ep_size: int = 1, + ep_rank: int = 0, + prefix: str = "ernie.layers.0", + nnodes: int = 1, + ): + super().__init__() + self.model_config = model_config + + self.tp_size = tp_size + self.ep_size = ep_size + self.ep_rank = ep_rank + + self.prefix = prefix + self.fd_config = FDConfig( + model_config=self.model_config, + parallel_config=ParallelConfig( + { + "tensor_parallel_size": self.tp_size, + "expert_parallel_size": self.ep_size, + "expert_parallel_rank": self.ep_rank, + "data_parallel_size": self.ep_size, + } + ), + quant_config=W4AFP8Config( + weight_scale_dict=None, + act_scale_dict=None, + is_permuted=False, + hadamard_block_size=512, + is_quantized=False, + ), + scheduler_config=SchedulerConfig({}), + cache_config=CacheConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + load_config=LoadConfig({}), + ips=",".join(["0"] * nnodes), + ) + self.fd_config.parallel_config.tp_group = None + self.fd_config.parallel_config.tensor_parallel_rank = tp_rank + self.fd_config.parallel_config.expert_parallel_size = self.ep_size + if self.ep_size > 1: + self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + self.fd_config.scheduler_config.splitwise_role = "mixed" + self.fd_config.model_config.moe_phase.phase = "decode" + + weight_key_map = { + "gate_weight_key": f"{self.prefix}.gate.weight", + "gate_correction_bias_key": f"{self.prefix}.moe_statics.e_score_correction_bias", + "up_gate_proj_expert_weight_key": f"{self.prefix}.mlp.experts.{{}}.up_gate_proj.weight", + "down_proj_expert_weight_key": f"{self.prefix}.mlp.experts.{{}}.down_proj.weight", + } + + self.fused_moe = FusedMoE( + fd_config=self.fd_config, + moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size, 
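+            # The MoE sizes here come from the toy config.json written in setUp():
+            # moe_intermediate_size=256, moe_num_experts=2, moe_k (top_k)=2.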
+ num_experts=self.fd_config.model_config.moe_num_experts, + top_k=self.fd_config.model_config.moe_k, + # avoiding invoke clean_low_latency_buffer in mixed ep. + layer_idx=0, + weight_key_map=weight_key_map, + topk_method="noaux_tc", + topk_group=4, + n_group=8, + gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32), + # gate_correction_bias = gate_correction_bias_real_data + ) + self.pack_num = 1 + moe_layer = self.fused_moe + + up_gate_proj_weight_shape = [ + moe_layer.num_local_experts, + moe_layer.hidden_size // self.pack_num, + moe_layer.moe_intermediate_size * 2, + ] + down_proj_weight_shape = [ + moe_layer.num_local_experts, + moe_layer.moe_intermediate_size // self.pack_num, + moe_layer.hidden_size, + ] + + up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) + down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16) + + local_expert_ids = list( + range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts) + ) + state_dict = {} + up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key") + down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key") + + for expert_idx in local_expert_ids: + up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx) + down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx) + + state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[ + expert_idx - moe_layer.expert_id_offset + ] + state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset] + + moe_layer.load_state_dict(state_dict) + + +class TestW4A8FusedMoE(unittest.TestCase): + def setUp(self) -> None: + self.architectures = ["Ernie4_5_MoeForCausalLM"] + self.hidden_size = 256 + self.moe_intermediate_size = 256 + self.moe_num_experts = 2 + self.moe_k = 2 + self.hidden_act = "silu" + self.num_attention_heads = 56 + self.num_hidden_layers = 1 + self.model_config = self.build_model_config() + + def build_model_config(self) -> ModelConfig: + model_name_or_path = self.build_config_json() + return ModelConfig( + { + "model": model_name_or_path, + "max_model_len": 2048, + } + ) + + def build_config_json(self) -> str: + config_dict = { + "architectures": self.architectures, + "hidden_size": self.hidden_size, + "moe_intermediate_size": self.moe_intermediate_size, + "moe_num_experts": self.moe_num_experts, + "moe_k": self.moe_k, + "hidden_act": self.hidden_act, + "num_attention_heads": self.num_attention_heads, + "num_hidden_layers": self.num_hidden_layers, + "dtype": "bfloat16", + "is_quantized": False, + } + + tmp_dir = "./tmp_w4afp8_moe" + os.makedirs(tmp_dir, exist_ok=True) + with open(f"./{tmp_dir}/config.json", "w") as f: + json.dump(config_dict, f) + self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir) + return self.model_name_or_path + + def test_fused_moe(self): + # init_distributed_environment() + + gating = paddle.nn.Linear(self.model_config.hidden_size, self.model_config.moe_num_experts) + gating.to(dtype=paddle.float32) # it's dtype is bfloat16 default, but the forward input is float32 + gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32)) + + # ep_size = paddle.distributed.get_world_size() + # ep_rank = paddle.distributed.get_rank() + ep_size = 1 + ep_rank = 0 + + tp_size = 1 + tp_rank = 0 + + nnodes = (ep_size + 7) // 8 + + # 这行代码必须保留,否则影响均匀性! 
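+        # (Translation of the comment above: this line must be kept, otherwise uniformity is affected!)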
+ paddle.seed(ep_rank + 100) + + fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes).fused_moe + weight_size = fused_moe.top_k * fused_moe.hidden_size * fused_moe.moe_intermediate_size * 3 / 2 + + tester = OpPerformanceTester( + op_name="w4afp8-moe", + op_fn=fused_moe, + num_layers=self.model_config.num_hidden_layers, + weight_size=weight_size, + gate=gating, + ) + + tester.benchmark( + input_size=self.model_config.hidden_size, + batch_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ) + + def tearDown(self) -> None: + if self.model_name_or_path: + print("Remove tmp model config file") + shutil.rmtree(self.model_name_or_path) + + +if __name__ == "__main__": + unittest.main() From 92e89094fee7c0998c2af4fdd95406be34e82157 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <592045536@qq.com> Date: Mon, 1 Dec 2025 20:00:37 +0800 Subject: [PATCH 4/5] fix --- .../layers/moe/fused_moe_cutlass_backend.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 98006223bf2..ded9ca4867f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -1001,24 +1001,25 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): ) # weight_scales - setattr( - layer, - "up_gate_proj_weight_scale", - layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - "down_proj_weight_scale", - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + if layer.is_quantized: + setattr( + layer, + "up_gate_proj_weight_scale", + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + "down_proj_weight_scale", + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) def load_w4afp8_scale_weights( self, From 29a88f9b933b62c19cbdbdec1a8f27c5a171f654 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:43:52 +0800 Subject: [PATCH 5/5] fix --- tests/quantization/test_w4afp8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/quantization/test_w4afp8.py b/tests/quantization/test_w4afp8.py index 97669acf246..36d24c966f1 100644 --- a/tests/quantization/test_w4afp8.py +++ b/tests/quantization/test_w4afp8.py @@ -32,6 +32,7 @@ def setUp(self): act_scale_dict={"layer.activation_scale": 1.0}, is_permuted=False, hadamard_block_size=128, + is_quantized=True, ) self.method = W4AFP8LinearMethod(self.config)
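
Taken together, the series keys the two w4afp8 MoE paths off the new
is_quantized flag. A small usage sketch (the dict keys follow from_config as
changed in patch 2; anything not shown keeps its default):

    from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config

    # Pre-quantized checkpoint: per-expert scales are read from the state dict,
    # as before.
    static_cfg = W4AFP8Config.from_config(
        {"is_permuted": False, "hadamard_block_size": 128, "is_quantized": True}
    )

    # Unquantized checkpoint: is_quantized defaults to False, so expert weights
    # are group-wise int4-quantized at load time (down_proj additionally
    # Hadamard-rotated), and activations are FP8-quantized dynamically during
    # EP dispatch.
    dynamic_cfg = W4AFP8Config.from_config({"hadamard_block_size": 512})
    assert static_cfg.is_quantized and not dynamic_cfg.is_quantized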