Commit 2099708

[BugFix] BF16 MoE Cutlass Backend Support EP (#5242)
1 parent ba915e0 commit 2099708
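Enables expert parallelism (EP) for the unquantized BF16 ("w16a16") MoE path in the Cutlass backend: the config loader now also recognizes `n_routed_experts`, the EP prefill path waits on the asynchronous combine event, EP decode accepts `w16a16`, a stride typo in the Triton MoE backend is fixed, and the GLM-4 MoE model gains an `empty_input_forward` hook for ranks scheduled no tokens.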

4 files changed (+22, -3 lines)

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
@@ -304,6 +304,8 @@ def override_name_from_config(self):
 
         if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
             self.moe_num_experts = self.num_experts
+        if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
+            self.moe_num_experts = self.n_routed_experts
 
     def read_from_env(self):
         """

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 5 additions & 2 deletions
@@ -206,7 +206,10 @@ def apply_ep_prefill(
             tmp_ffn_out = recv_x
 
         # 4. EP combine
-        return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        if self.ep_prefill_runner.ep_engine.async_finish:
+            event.current_stream_wait()
+        return tmp_ffn_out
 
     def apply_ep_decode(
         self,
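Note: when the EP engine runs with `async_finish`, `combine` is launched asynchronously and returns a completion event along with the combined tensor; the old code returned the result of `combine` directly and never synchronized, so the BF16 prefill output could be consumed before the all-to-all finished. A minimal sketch of the same stream/event handshake using Paddle's CUDA API directly (the stand-in computation and shapes are illustrative; requires a CUDA build of PaddlePaddle):

import paddle

comm_stream = paddle.device.cuda.Stream()
done = paddle.device.cuda.Event()

with paddle.device.cuda.stream_guard(comm_stream):
    # Stand-in for the combine kernel launched on the communication stream.
    combined = paddle.ones([4, 8], dtype="bfloat16") * 2
    comm_stream.record_event(done)

# Block the compute stream until the event fires before reading `combined`;
# event.current_stream_wait() in the diff plays the same role.
paddle.device.cuda.current_stream().wait_event(done)
print(combined.sum())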
@@ -242,7 +245,7 @@ def apply_ep_decode(
         if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
             num_local_experts, max_num, _ = permute_input.shape
             expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
-        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
+        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4", "w16a16"]:
             expert_idx_per_token = None
         else:
             raise NotImplementedError
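Note: "w16a16" is the unquantized BF16 path this commit enables. Like the weight-only int8/int4 branches, it needs no per-token expert index (which the w4a8/w4afp8 kernels presumably use to select per-expert dequantization scales), so routing it to `expert_idx_per_token = None` replaces the previous fall-through to `NotImplementedError`. For reference, the w4a8 index grid built above expands to one expert id per padded token slot:

import paddle

num_local_experts, max_num = 4, 3
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
# shape [4, 3]:
# [[0, 0, 0],
#  [1, 1, 1],
#  [2, 2, 2],
#  [3, 3, 3]]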

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 1 addition & 1 deletion
@@ -808,7 +808,7 @@ def apply(
             N=hidden_size,
             K=moe_intermediate_size,
             stride_am=x_q.strides[0],
-            stride_ak=x_scale.strides[1],
+            stride_ak=x_q.strides[1],
             stride_be=layer.down_proj_weight.strides[0],
             stride_bk=layer.down_proj_weight.strides[2],
             stride_bn=layer.down_proj_weight.strides[1],
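Note: both A-operand strides of this GEMM describe the activation tensor `x_q`, so `stride_ak` must come from `x_q`, not from the scale tensor `x_scale`. For contiguous row-major inputs the two values happen to coincide, which is why the typo was latent. A quick check of the expected values (the `[M, 1]` per-token scale shape is an assumption for illustration, and `.strides` is taken to be in elements, as the kernel arguments above imply):

import paddle

M, K = 8, 16
x_q = paddle.zeros([M, K], dtype="bfloat16")     # activation operand A
x_scale = paddle.zeros([M, 1], dtype="float32")  # hypothetical per-token scales

print(x_q.strides)      # [16, 1]: stride_am = K, stride_ak = 1
print(x_scale.strides)  # [1, 1]: strides[1] is also 1 here, hence the latent bug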

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 14 additions & 0 deletions
@@ -494,6 +494,20 @@ def compute_logits(self, hidden_states: paddle.Tensor):
 
         return logits
 
+    def empty_input_forward(self):
+        """
+        empty_input_forward
+        """
+        fake_hidden_states = paddle.ones(
+            shape=[1, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(
+            self.fd_config.model_config.first_k_dense_replace,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
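Note: under expert parallelism the MoE dispatch/combine are collectives, so every EP rank has to enter them even when it was scheduled no tokens, or peer ranks block in the all-to-all. `empty_input_forward` pushes a single dummy hidden state through each MoE layer (layers `first_k_dense_replace` through `num_hidden_layers - 1`; the leading dense layers have no experts), presumably giving GLM-4 MoE the same no-token hook other EP-enabled models in the repo expose. A hedged sketch of how a runner might invoke it; the driver below is illustrative, not FastDeploy's scheduler:

def run_step(model, ids_remove_padding):
    """Illustrative only: keep EP collectives in sync across ranks."""
    if ids_remove_padding is None or ids_remove_padding.shape[0] == 0:
        # No tokens on this rank: still drive the experts' all-to-all.
        model.empty_input_forward()
        return None
    return model.forward(ids_remove_padding)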
