Closed
Labels
bug (Something isn't working)
Description
🐛 Describe the bug
onednn_mm crashes on consecutive matmuls with different dtypes but the same M, K, N
I think the oneDNN handler created for the first matmul is cached and reused for the second matmul, which has a different dtype.
It only happens if M, K, and N are all the same for the two matmuls.
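A minimal sketch of the suspected mechanism (hypothetical, not vLLM's actual caching code): if the matmul primitive cache were keyed only on the shape (M, N, K) and not on the dtype, the second handler with identical shapes would resolve to the fp32 primitive, and running the bf16 matmul through it would write fp32-sized data into the smaller bf16 buffers, which would explain the heap corruption.

# Hypothetical illustration only -- NOT vLLM's actual code.
# A cache keyed on shape alone returns the fp32 primitive for the bf16 call.
_primitive_cache: dict = {}

def get_matmul_primitive(m: int, n: int, k: int, dtype: str) -> str:
    key = (m, n, k)  # dtype is missing from the key (the suspected bug)
    if key not in _primitive_cache:
        _primitive_cache[key] = f"matmul_primitive<{dtype}>"
    return _primitive_cache[key]

print(get_matmul_primitive(1024, 1024, 1024, "float32"))   # builds an fp32 primitive
print(get_matmul_primitive(1024, 1024, 1024, "bfloat16"))  # wrongly reuses the fp32 one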
Reproducer:
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <[email protected]>
# SPDX-License-Identifier: BSD-3-Clause
import torch

from vllm import _custom_ops as ops


def onednn_gemm_test(m: int,
                     n: int,
                     k: int,
                     device: str = "cpu"):
    ## fp32
    a_f32 = torch.rand((m, k), dtype=torch.float32, device=device)
    b_f32 = torch.rand((n, k), dtype=torch.float32, device=device)
    out_f32_ref = torch.nn.functional.linear(a_f32, b_f32)
    handler_f32 = ops.create_onednn_mm(
        b_f32.t(),
    )
    out_f32 = ops.onednn_mm(handler_f32, a_f32, bias=None)
    torch.testing.assert_close(out_f32, out_f32_ref)
    print("Done FP32", flush=True)

    ## bf16
    a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
    b_bf16 = torch.rand((n, k), dtype=torch.bfloat16, device=device)
    out_bf16_ref = torch.nn.functional.linear(a_bf16, b_bf16)
    handler_bf16 = ops.create_onednn_mm(
        b_bf16.t(),
    )
    out_bf16 = ops.onednn_mm(handler_bf16, a_bf16, bias=None)
    torch.testing.assert_close(out_bf16, out_bf16_ref)
    print("Done BF16", flush=True)


if __name__ == "__main__":
    onednn_gemm_test(m=1024, n=1024, k=1024)
Reproducer output:
Done FP32
corrupted size vs. prev_size
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.