57 changes: 43 additions & 14 deletions vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -5,12 +5,14 @@
import torch
from torch.nn import functional as F

+from vllm import _custom_ops as ops
from vllm import envs


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    d = x.shape[-1] // 2
-    return F.silu(x[..., :d]) * x[..., d:]
+    output = torch.empty_like(x[..., : x.shape[-1] // 2])
+    torch.ops._C.silu_and_mul(output, x)
+    return output
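For context, the rewritten helper dispatches to vLLM's fused C++ kernel instead of composing eager PyTorch ops. A minimal sanity-check sketch, assuming a vLLM build where the `torch.ops._C` extension is registered; everything outside the call shown in the diff is illustrative:

```python
# Sketch: compare the fused silu_and_mul kernel against an eager reference.
# Assumes vLLM's _C extension is available; importing vllm._custom_ops is
# expected to register torch.ops._C on such a build.
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops  # noqa: F401


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Reference: SiLU-gate the first half of the last dim by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 256)
out = torch.empty_like(x[..., : x.shape[-1] // 2])
torch.ops._C.silu_and_mul(out, x)  # same (output, input) call shape as in the diff
torch.testing.assert_close(out, silu_and_mul_ref(x), rtol=1e-4, atol=1e-4)
```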


def swigluoai_and_mul(
@@ -237,7 +239,43 @@ def __call__(

class CPUFusedMOE:
    def __init__(self, layer: torch.nn.Module) -> None:
-        pass
+        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
+
+        num_experts = layer.w13_weight.size(0)
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
+        layer.gate_up_linear = []
+        layer.down_linear = []
+
+        for i in range(num_experts):
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
+            layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
+            if use_onednn_mm:
+                gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
+                layer.gate_up_linear.append(
+                    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
+                down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
+                layer.down_linear.append(
+                    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
+                        handle, x, bias
+                    )
+                )
Comment on lines +257 to +268 (Contributor, severity: high):

The CPUDNNLGEMMHandler objects created by ops.create_onednn_mm contain pointers to C++ state and are not serializable. Storing lambdas that capture these handlers in layer.gate_up_linear and layer.down_linear will cause issues if the model is serialized (e.g., with pickle or torch.save). Upon deserialization, the handler pointers will be invalid, which can lead to segmentation faults when the model is used or garbage collected.

To prevent this, the CPUDNNLGEMMHandler class should be made non-picklable by implementing __getstate__ to raise an exception. Since that class is not in this file, an alternative is to avoid storing these handlers on the torch.nn.Module instance if model serialization is a possibility. If serialization is not a supported use case for CPU-based models, this might be acceptable, but it's a significant risk.
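For illustration, a minimal sketch of the guard this comment suggests; hypothetical, since CPUDNNLGEMMHandler is defined outside this file and its real fields may differ:

```python
# Hypothetical sketch of a non-picklable handler wrapper (not the actual class).
class CPUDNNLGEMMHandler:
    def __init__(self, native_ptr: int) -> None:
        self.native_ptr = native_ptr  # opaque pointer into oneDNN/C++ state

    def __getstate__(self):
        # Fail loudly instead of serializing a dangling native pointer.
        raise TypeError(
            "CPUDNNLGEMMHandler wraps native oneDNN state and cannot be pickled"
        )
```

With a guard like this, pickling a module that captured such handles raises immediately instead of producing a checkpoint whose handles point at freed memory.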

+            else:
+                layer.gate_up_linear.append(
+                    lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
+                )
+                layer.down_linear.append(
+                    lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
+                )
+        if use_onednn_mm: # remove weight
+            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
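Editorial note: the lambdas above bind handle/w/b as default arguments on purpose; a plain closure over the loop variables would late-bind every expert to the last iteration's values. A standalone illustration of that Python behavior (not vLLM code):

```python
# Default-argument binding vs. late-binding closures in a loop.
weights = ["w0", "w1", "w2"]

late_bound = [lambda: w for w in weights]        # every call sees "w2"
bound_early = [lambda w=w: w for w in weights]   # each call keeps its own weight

assert [f() for f in late_bound] == ["w2", "w2", "w2"]
assert [f() for f in bound_early] == ["w0", "w1", "w2"]
```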

    def __call__(
        self,
@@ -287,28 +325,19 @@ def __call__(

        outputs = []
        start_idx = 0
-        has_w13_bias = hasattr(layer, "w13_bias")
-        has_w2_bias = hasattr(layer, "w2_bias")

        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]

-            layer_w13_weight = layer.w13_weight[i]
-            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
-            layer_w2_weight = layer.w2_weight[i]
-            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
-
-            gate_up = F.linear(
-                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
-            )
+            gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
            if activation == "swigluoai":
                gate_up = swigluoai_and_mul(gate_up)
            else:
                gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
+            expert_out = layer.down_linear[i](gate_up)
            outputs.append(expert_out)
            start_idx = end_idx
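For readers outside the MoE code path, a standalone sketch (hypothetical names, not vLLM's actual routing code) of the token grouping this loop assumes: tokens are sorted by assigned expert, so each expert reads one contiguous slice of sorted_tokens:

```python
# Standalone sketch of per-expert token grouping (illustrative only).
import torch

num_experts = 4
topk_ids = torch.tensor([2, 0, 2, 1, 0, 2])   # expert chosen for each token
hidden = torch.randn(topk_ids.numel(), 8)

order = torch.argsort(topk_ids)               # group tokens by expert id
sorted_tokens = hidden[order]
tokens_per_expert = torch.bincount(topk_ids, minlength=num_experts)

start = 0
for e, n in enumerate(tokens_per_expert.tolist()):
    end = start + n
    if n:
        chunk = sorted_tokens[start:end]      # all tokens routed to expert e
        # ... apply expert e's gate_up / activation / down projections here ...
    start = end
```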
