6 changes: 4 additions & 2 deletions custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
@@ -430,7 +430,9 @@ __global__ void permute_x_kernel(
}
abs_max = phi::BlockAllReduce<MaxOp, float, Kthread>(abs_max);
float scale = 440.f / abs_max; // use 440 so we do not have to clip
dequant_scale[dst_token_idx] = abs_max;
if (tid == 0) {
dequant_scale[dst_token_idx] = abs_max;
}
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
Load<T, vec_size>(&data_smem[v_id * vec_size], &src_vec);
#pragma unroll
@@ -661,7 +663,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(

int dequant_scale_size = 1;
if (moe_quant_type == "w4afp8" && !up_gate_proj_in_scale) {
dequant_scale_size = moe_topk * num_rows;
dequant_scale_size = token_nums_this_rank;
}

auto dequant_scale =
2 changes: 1 addition & 1 deletion custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
@@ -85,7 +85,7 @@
"""

# [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [[256, 256, 2, 0, 128]]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]

dtype = ["BF16"]

5 changes: 3 additions & 2 deletions fastdeploy/model_executor/layers/moe/ep.py
@@ -295,7 +295,7 @@ def low_latency_dispatch(
use_fp8=use_fp8,
async_finish=False,
return_recv_hook=True,
# num_per_channel=quant_group_size,
num_per_channel=quant_group_size,
)

return packed_recv_x, recv_expert_count, handle, dispatch_hook
@@ -634,10 +634,11 @@ def dispatch(
):
expertwise_scale = kwargs.get("expertwise_scale", None)
use_fp8 = kwargs.get("use_fp8", False)
quant_group_size = kwargs.get("quant_group_size", 128)

if not self.use_internode_ll_two_stage:
recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
x, topk_idx, expertwise_scale, use_fp8
x, topk_idx, expertwise_scale, use_fp8, quant_group_size
)
else:
# just supports dispatch_use_fp8 = True now!
102 changes: 72 additions & 30 deletions fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -22,7 +22,7 @@
import fastdeploy
from fastdeploy.platforms import current_platform

from ..utils import get_tensor
from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_model
from .fused_moe_backend_base import UnquantizedFusedMoEMethod

if current_platform.is_cuda():
@@ -745,7 +745,7 @@ def __init__(self, quant_config):
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = "w4afp8"
self.pack_num = 2
self.pack_num = 2 if quant_config.is_quantized else 1

def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
@@ -912,21 +912,58 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
if not layer.is_quantized:
logger.info(
f"Rotating ernie.layers.{layer.layer_idx}.mlp.experts.[{layer.ep_rank * layer.num_local_experts},{layer.ep_rank * layer.num_local_experts + layer.num_local_experts}).down_proj.weight..."
)
rotate_model(
state_dict,
layer.layer_idx,
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
ep_rank=layer.ep_rank,
)

up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)

up_gate_proj_weight_scales = []
down_proj_weight_scales = []
dynamic_scale_weight_map = {
self.added_scale_attrs[0]: up_gate_proj_weight_scales,
self.added_scale_attrs[1]: down_proj_weight_scales,
}

for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_scale_name = self.added_scale_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i])
quant_weight = weight_tensor[i]
if not layer.is_quantized:
block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512)
quant_weight, weight_scale = group_wise_int4_weight_quantize(weight_tensor[i], group_size=128)
free_tensor(weight_tensor[i])
quant_weight = pack(quant_weight.transpose([1, 0]), bits=4)
if "down_proj" in weight_name:
weight_scale = weight_scale / (block_size**0.5)
dynamic_scale_weight_map[weight_scale_name].append(weight_scale)

quant_weight = w4afp8_gemm_weight_convert(quant_weight)
weight_list.append(quant_weight)
quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight)

self.load_w4afp8_scale_weights(
layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
layer,
layer.weight_key_map,
state_dict,
logical_expert_ids,
ep_rank_to_expert_id_list,
dynamic_scale_weight_map,
)

def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
@@ -938,7 +975,7 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
"""

self.default_dtype = layer._helper.get_default_dtype()
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
Contributor:

Regarding is_quantized and moe_dynamic_quant: how about merging them into a single field later? If the user sets is_quantized and the quant type is w4afp8, just check the weight dtype and shape, and fall back to dynamic quantization if they don't match.

Collaborator Author:

  1. is_quantized is not only used here; the model-building code also uses it to decide whether the field given in weight_key_map is "weight" or "quant_weight". It already exists in the framework, so I think reusing it here is fine.
  2. moe_dynamic_quant is kept here because we want to support three modes at once: 1. both weights and activations quantized dynamically (not is_quantized); 2. static weights with dynamic activations (is_quantized and moe_dynamic_quant); 3. both weights and activations static (is_quantized). Once W4AFP8 activations always go through dynamic quantization, moe_dynamic_quant will no longer be needed.
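A minimal sketch of the three modes described in the reply above, using only the `is_quantized` and `moe_dynamic_quant` flags that appear in this diff; the helper name and the returned strings are illustrative assumptions, not part of the PR:

```python
def resolve_w4afp8_mode(is_quantized: bool, moe_dynamic_quant: bool) -> str:
    """Map the two flags onto the three W4AFP8 modes listed in the reply above."""
    if not is_quantized:
        # Mode 1: weights and activations are both quantized dynamically.
        return "dynamic_weight_dynamic_act"
    if moe_dynamic_quant:
        # Mode 2: weights are quantized offline, activations dynamically at runtime.
        return "static_weight_dynamic_act"
    # Mode 3: weights and activations both use static (offline) scales.
    return "static_weight_static_act"


# Example: a checkpoint that ships quantized weights but quantizes activations on the fly.
assert resolve_w4afp8_mode(is_quantized=True, moe_dynamic_quant=True) == "static_weight_dynamic_act"
```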

setattr(
layer,
"up_gate_proj_in_scale_all_experts",
@@ -950,7 +987,7 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
)

# in_scales
if not layer.moe_quant_config.moe_dynamic_quant:
if layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr(
layer,
@@ -963,24 +1000,25 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
)

# weight_scales
setattr(
layer,
"up_gate_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
"down_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
if layer.is_quantized:
setattr(
layer,
"up_gate_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
"down_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)

def load_w4afp8_scale_weights(
self,
@@ -989,6 +1027,7 @@
state_dict: dict,
logical_expert_ids: paddle.Tensor,
ep_rank_to_expert_id_list: list,
dynamic_scale_weight_map: dict,
):
"""
Get w4afp8 weights from state dict and process them.
@@ -1095,7 +1134,7 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process
raise ValueError(f"scale {name} should not be none in w4a8 mode.")

# 2. Extract scale tensor from state dict
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
@@ -1110,11 +1149,14 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process
paddle.concat(up_gate_proj_in_scales_all_experts)
)

for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
if hasattr(layer, name):
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
if not layer.is_quantized:
scale_weight_map = dynamic_scale_weight_map
else:
for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
if hasattr(layer, name):
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)

for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
7 changes: 7 additions & 0 deletions fastdeploy/model_executor/layers/quantization/__init__.py
@@ -84,6 +84,13 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
# Special handling for moe w4afp8 dynamic quant
Contributor:

I'd suggest not hard-coding this here.

Collaborator Author:

How should it be written then? Setting it from the command line isn't flexible enough at the moment.
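A minimal sketch of one way to make these values configurable instead of hard-coded, assuming the same `quantization_config` dict used in this hunk; the helper name is hypothetical, and `setdefault` keeps any value the user already supplied:

```python
def apply_w4afp8_defaults(quantization_config: dict) -> dict:
    """Fill in the w4afp8 mix-quant fields only where the user did not set them."""
    quantization_config.setdefault("dense_quant_type", "block_wise_fp8")
    quantization_config.setdefault("moe_quant_type", "w4afp8")
    quantization_config.setdefault("hadamard_block_size", 512)
    quantization_config["quantization"] = "mix_quant"
    return quantization_config


# Example: a user-supplied hadamard_block_size of 256 is preserved.
cfg = apply_w4afp8_defaults({"hadamard_block_size": 256})
assert cfg["moe_quant_type"] == "w4afp8" and cfg["hadamard_block_size"] == 256
```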

elif quant_config_name == "w4afp8":
quantization_config["dense_quant_type"] = "block_wise_fp8"
quantization_config["moe_quant_type"] = "w4afp8"
quantization_config["hadamard_block_size"] = 512
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None
if quant_config_name is None:
6 changes: 4 additions & 2 deletions fastdeploy/model_executor/layers/quantization/w4afp8.py
@@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase):
quantization config for weight 4bits and activation fp8
"""

def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None:
def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size, is_quantized) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
@@ -40,6 +40,7 @@ def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_bloc
self.quant_round_type = 1
self.is_permuted = is_permuted
self.hadamard_block_size = hadamard_block_size
self.is_quantized = is_quantized

def name(self) -> str:
return "w4afp8"
@@ -50,7 +51,8 @@ def from_config(cls, config: dict) -> "W4AFP8Config":
act_scale_dict = config.get("act_scale_dict", None)
is_permuted = config.get("is_permuted", True)
hadamard_block_size = config.get("hadamard_block_size", 128)
return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size)
is_quantized = config.get("is_quantized", False)
return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size, is_quantized)

def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE):