diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 552d9e9cf88f..d131f69ea989 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -5,12 +5,14 @@
 import torch
 from torch.nn import functional as F
 
+from vllm import _custom_ops as ops
 from vllm import envs
 
 
 def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    d = x.shape[-1] // 2
-    return F.silu(x[..., :d]) * x[..., d:]
+    output = torch.empty_like(x[..., : x.shape[-1] // 2])
+    torch.ops._C.silu_and_mul(output, x)
+    return output
 
 
 def swigluoai_and_mul(
@@ -237,7 +239,43 @@ def __call__(
 
 class CPUFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
-        pass
+        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
+
+        num_experts = layer.w13_weight.size(0)
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
+        layer.gate_up_linear = []
+        layer.down_linear = []
+
+        for i in range(num_experts):
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
+            layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
+            if use_onednn_mm:
+                gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
+                layer.gate_up_linear.append(
+                    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+                down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
+                layer.down_linear.append(
+                    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+            else:
+                layer.gate_up_linear.append(
+                    lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
+                )
+                layer.down_linear.append(
+                    lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
+                )
+        if use_onednn_mm:  # remove weight
+            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
 
     def __call__(
         self,
@@ -287,8 +325,6 @@ def __call__(
 
         outputs = []
         start_idx = 0
-        has_w13_bias = hasattr(layer, "w13_bias")
-        has_w2_bias = hasattr(layer, "w2_bias")
 
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
@@ -296,19 +332,12 @@ def __call__(
                 continue
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
-            layer_w13_weight = layer.w13_weight[i]
-            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
-            layer_w2_weight = layer.w2_weight[i]
-            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
-
-            gate_up = F.linear(
-                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
-            )
+            gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
             if activation == "swigluoai":
                 gate_up = swigluoai_and_mul(gate_up)
             else:
                 gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
+            expert_out = layer.down_linear[i](gate_up)
             outputs.append(expert_out)
             start_idx = end_idx
 
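The rewritten `silu_and_mul` dispatches to vLLM's fused C++ kernel (`torch.ops._C.silu_and_mul`), which writes the result into a preallocated buffer instead of materializing two sliced intermediates. The removed Python version doubles as a reference for what the kernel computes; a minimal sketch, runnable without the compiled extension:

```python
import torch
from torch.nn import functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Reference semantics of the fused kernel: split the last dim in
    # half, SiLU the first half, multiply elementwise by the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)
assert silu_and_mul_ref(x).shape == (4, 8)
```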
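In `__init__`, the per-expert callables bind `handle`/`bias` (and `w`/`b` on the fallback path) as lambda *default arguments* rather than closing over the loop variables. This is deliberate: Python closures capture variables, not values, so without the defaults every lambda would see the last expert's weights. A minimal sketch of the pitfall and the fix:

```python
# Closures capture variables, not values: after the loop finishes,
# every "broken" lambda sees i == 2. Binding the value as a default
# argument (as the patch does per expert) snapshots it per iteration.
broken = [lambda: i for i in range(3)]
fixed = [lambda i=i: i for i in range(3)]

assert [f() for f in broken] == [2, 2, 2]
assert [f() for f in fixed] == [0, 1, 2]
```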
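When oneDNN is usable (`ops._supports_onednn and ops.is_onednn_acl_supported()`), `create_onednn_mm` pre-packs each expert's transposed weight into a matmul handle, so the original `w13_weight`/`w2_weight` parameters are redundant afterwards; the patch swaps them for empty parameters so their storage can be freed while the attributes stay present. A sketch of the release pattern, assuming the handles own their packed copies:

```python
import torch

layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(
    torch.randn(8, 1024, 512), requires_grad=False
)

# Once a prepacked copy exists elsewhere, rebinding the attribute to
# an empty Parameter drops the last reference to the original storage,
# letting it be freed, while keeping the attribute in place for code
# that introspects the module.
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
assert layer.w13_weight.numel() == 0
```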
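With the handles built once at init time, `__call__` now only indexes into the prebuilt `gate_up_linear`/`down_linear` lists; the token routing itself is unchanged. A toy end-to-end version of that loop (illustrative names and shapes, `F.linear` fallback path, SiLU activation) showing how sorted tokens flow through each expert's two projections:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
hidden, inter = 8, 6

# Two toy experts: a fused gate+up projection and a down projection,
# mirroring w13_weight / w2_weight in the patch.
experts = [
    (torch.randn(2 * inter, hidden), torch.randn(hidden, inter))
    for _ in range(2)
]
gate_up_linear = [lambda x, w=w13: F.linear(x, w) for w13, _ in experts]
down_linear = [lambda x, w=w2: F.linear(x, w) for _, w2 in experts]

# Tokens pre-sorted by expert id; each expert owns a contiguous chunk.
sorted_tokens = torch.randn(5, hidden)
tokens_per_expert = [3, 2]

outputs, start = [], 0
for i, n in enumerate(tokens_per_expert):
    chunk = sorted_tokens[start : start + n]
    gate_up = gate_up_linear[i](chunk)
    d = gate_up.shape[-1] // 2
    act = F.silu(gate_up[..., :d]) * gate_up[..., d:]  # silu_and_mul
    outputs.append(down_linear[i](act))
    start += n

assert torch.cat(outputs).shape == sorted_tokens.shape
```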