-
-
Notifications
You must be signed in to change notification settings - Fork 11k
[CPU]Improve cpu fused moe perf #27244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,12 +5,14 @@ | |
| import torch | ||
| from torch.nn import functional as F | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm import envs | ||
|
|
||
|
|
||
| def silu_and_mul(x: torch.Tensor) -> torch.Tensor: | ||
| d = x.shape[-1] // 2 | ||
| return F.silu(x[..., :d]) * x[..., d:] | ||
| output = torch.empty_like(x[..., : x.shape[-1] // 2]) | ||
| torch.ops._C.silu_and_mul(output, x) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this change needed here? is this faster than doing pytorch's |
||
| return output | ||
|
|
||
|
|
||
| def swigluoai_and_mul( | ||
|
|
@@ -237,7 +239,43 @@ def __call__( | |
|
|
||
| class CPUFusedMOE: | ||
| def __init__(self, layer: torch.nn.Module) -> None: | ||
| pass | ||
| use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you please add the perf implications of this PR to the description? |
||
|
|
||
| num_experts = layer.w13_weight.size(0) | ||
| has_w13_bias = hasattr(layer, "w13_bias") | ||
| has_w2_bias = hasattr(layer, "w2_bias") | ||
|
|
||
| layer.gate_up_linear = [] | ||
| layer.down_linear = [] | ||
|
|
||
| for i in range(num_experts): | ||
| layer_w13_weight = layer.w13_weight[i] | ||
| layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None | ||
| layer_w2_weight = layer.w2_weight[i] | ||
| layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None | ||
| if use_onednn_mm: | ||
| gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32) | ||
| layer.gate_up_linear.append( | ||
| lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm( | ||
| handle, x, bias | ||
| ) | ||
| ) | ||
| down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32) | ||
| layer.down_linear.append( | ||
| lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm( | ||
| handle, x, bias | ||
| ) | ||
| ) | ||
|
Comment on lines
+257
to
+268
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The To prevent this, the |
||
| else: | ||
| layer.gate_up_linear.append( | ||
| lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b) | ||
| ) | ||
| layer.down_linear.append( | ||
| lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b) | ||
| ) | ||
| if use_onednn_mm: # remove weight | ||
| layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) | ||
| layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) | ||
|
|
||
| def __call__( | ||
| self, | ||
|
|
@@ -287,28 +325,19 @@ def __call__( | |
|
|
||
| outputs = [] | ||
| start_idx = 0 | ||
| has_w13_bias = hasattr(layer, "w13_bias") | ||
| has_w2_bias = hasattr(layer, "w2_bias") | ||
|
|
||
| for i, num_tokens in enumerate(tokens_per_expert): | ||
| end_idx = start_idx + num_tokens | ||
| if num_tokens == 0: | ||
| continue | ||
| tokens_for_this_expert = sorted_tokens[start_idx:end_idx] | ||
|
|
||
| layer_w13_weight = layer.w13_weight[i] | ||
| layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None | ||
| layer_w2_weight = layer.w2_weight[i] | ||
| layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None | ||
|
|
||
| gate_up = F.linear( | ||
| tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias | ||
| ) | ||
| gate_up = layer.gate_up_linear[i](tokens_for_this_expert) | ||
| if activation == "swigluoai": | ||
| gate_up = swigluoai_and_mul(gate_up) | ||
| else: | ||
| gate_up = silu_and_mul(gate_up) | ||
| expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias) | ||
| expert_out = layer.down_linear[i](gate_up) | ||
| outputs.append(expert_out) | ||
| start_idx = end_idx | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left a similar comment on your 4bit PR
what happens if this assert fails
vllm/csrc/cpu/activation.cpp
Line 11 in 4f882be
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we crash because the activation kernel does not have a tail loop.
We should address this, but i don't think this particular concern should block this PR.