From 25cefed9ac53f6aca9e9a8898cc44343688d8165 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Fri, 18 Jul 2025 16:52:52 +0000
Subject: [PATCH 1/6] working register_fake, but need direct_register_custom_op

Signed-off-by: Randall Smith
---
 csrc/rocm/skinny_gemms.cu           |  2 +
 vllm/model_executor/layers/utils.py | 58 +++++++++++++++++++++++++++--
 vllm/utils/__init__.py              |  2 +
 3 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 6212570c79d1..e3feae758c91 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -1232,6 +1232,8 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   auto K_in = in_a.size(1);
   auto N_in = in_b.size(0);
 
+  std::cout << "HELLO WVSPLITK\n";
+
   TORCH_CHECK(in_a.dtype() == in_b.dtype());
   TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
   TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index ad4ba9c0b827..9f10b67698b0 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -9,6 +9,7 @@
 from vllm import envs
 from vllm.platforms import current_platform
 
+from vllm.utils import direct_register_custom_op
 def get_token_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
@@ -70,10 +71,11 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+@torch.library.custom_op("vllm::rocm_unquantized_gemm_impl", mutates_args=())
+def rocm_unquantized_gemm_impl(x: torch.Tensor,
+                               weight: torch.Tensor,
+                               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    # print(f"rocm_unquantized_gemm_impl")
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -96,6 +98,53 @@ def rocm_unquantized_gemm_impl(x: torch.Tensor,
             return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
 
+# @rocm_unquantized_gemm_impl.register_fake
+# def rocm_unquantized_gemm_impl_fake(x: torch.Tensor,
+    # weight: torch.Tensor,
+    # bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    # out = torch.nn.functional.linear(x, weight, bias)
+    # # print(f"out.shape={out.shape}, x.shape={x.shape}, weight.shape={weight.shape}, out.device={out.device}, x.device={x.device}, weight.device={weight.device}")
+    # # import os
+    # # print(f"rocm_unquantized_gemm_impl_fake:pid={os.getpid()}")
+    # # out = torch.empty(x.shape[0], weight.shape[0], device=x.device)
+    # # from torchdistx.fake import fake_mode
+    # # out = torch.zeros((x.shape[0], weight.shape[0]), device=x.device)
+    # return out
+
+# def rocm_unquantized_gemm_impl_fake(x: torch.tensor,
+    # weight: torch.tensor,
+    # bias: optional[torch.tensor] = none) -> torch.tensor:
+    # out = torch.nn.functional.linear(x, weight, bias)
+    # return out
+
+@rocm_unquantized_gemm_impl.register_fake
+def rocm_unquantized_gemm_impl_fake(x: torch.Tensor,
+                                    weight: torch.Tensor,
+                                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    # out = torch.nn.functional.linear(x, weight, bias)
+    # print(f"out.shape={out.shape}, x.shape={x.shape}, weight.shape={weight.shape}, out.device={out.device}, x.device={x.device}, weight.device={weight.device}")
+    # import os
+    # print(f"rocm_unquantized_gemm_impl_fake:pid={os.getpid()}")
+    out = weight.new_empty([x.shape[0], weight.shape[0]])
+    # out = torch.empty([x.shape[0], weight.shape[0]], device=weight.device)
+
+    # from torchdistx.fake import fake_mode
+    # out = torch.zeros((x.shape[0], weight.shape[0]), device=x.device)
+    return out
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None):
+    return rocm_unquantized_gemm_impl(x, weight, bias)
+
+# direct_register_custom_op(
+    # op_name="rocm_unquantized_gemm_impl",
+    # op_func=rocm_unquantized_gemm_impl,
+    # mutates_args=[],
+    # fake_impl=rocm_unquantized_gemm_impl_fake,
+    # dispatch_key=current_platform.dispatch_key,
+# )
 
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
@@ -114,3 +163,4 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
         return cpu_unquantized_gemm
     else:
         return default_unquantized_gemm
+
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index bfdbd682464a..040a368c462f 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2481,6 +2481,8 @@ def direct_register_custom_op(
     import torch.library
     if hasattr(torch.library, "infer_schema"):
+        print(f"op_func={op_func}, type(op_func)={type(op_func)}")
+        print(f"vllm_lib={vllm_lib}")
         schema_str = torch.library.infer_schema(op_func,
                                                 mutates_args=mutates_args)
     else:

From a35bcacb397768fa9d18d3a2a29bf80a84486439 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Sat, 19 Jul 2025 02:47:33 +0000
Subject: [PATCH 2/6] remove debug code

Signed-off-by: Randall Smith
---
 csrc/rocm/skinny_gemms.cu           |  1 -
 vllm/model_executor/layers/utils.py | 52 ++++++-----------------------
 2 files changed, 11 insertions(+), 42 deletions(-)

diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index e3feae758c91..9cbb7ccda8dd 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -1232,7 +1232,6 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   auto K_in = in_a.size(1);
   auto N_in = in_b.size(0);
 
-  std::cout << "HELLO WVSPLITK\n";
 
   TORCH_CHECK(in_a.dtype() == in_b.dtype());
   TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
   TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 9f10b67698b0..f5a2529f06e5 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -71,11 +71,9 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-@torch.library.custom_op("vllm::rocm_unquantized_gemm_impl", mutates_args=())
 def rocm_unquantized_gemm_impl(x: torch.Tensor,
                                weight: torch.Tensor,
                                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    # print(f"rocm_unquantized_gemm_impl")
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -98,53 +96,25 @@ def rocm_unquantized_gemm_impl(x: torch.Tensor,
             return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
 
-# @rocm_unquantized_gemm_impl.register_fake
-# def rocm_unquantized_gemm_impl_fake(x: torch.Tensor,
-    # weight: torch.Tensor,
-    # bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    # out = torch.nn.functional.linear(x, weight, bias)
-    # # print(f"out.shape={out.shape}, x.shape={x.shape}, weight.shape={weight.shape}, out.device={out.device}, x.device={x.device}, weight.device={weight.device}")
-    # # import os
-    # # print(f"rocm_unquantized_gemm_impl_fake:pid={os.getpid()}")
-    # # out = torch.empty(x.shape[0], weight.shape[0], device=x.device)
-    # # from torchdistx.fake import fake_mode
-    # # out = torch.zeros((x.shape[0], weight.shape[0]), device=x.device)
-    # return out
-
-# def rocm_unquantized_gemm_impl_fake(x: torch.tensor,
-    # weight: torch.tensor,
-    # bias: optional[torch.tensor] = none) -> torch.tensor:
-    # out = torch.nn.functional.linear(x, weight, bias)
-    # return out
-
-@rocm_unquantized_gemm_impl.register_fake
+
 def rocm_unquantized_gemm_impl_fake(x: torch.Tensor,
                                     weight: torch.Tensor,
                                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    # out = torch.nn.functional.linear(x, weight, bias)
-    # print(f"out.shape={out.shape}, x.shape={x.shape}, weight.shape={weight.shape}, out.device={out.device}, x.device={x.device}, weight.device={weight.device}")
-    # import os
-    # print(f"rocm_unquantized_gemm_impl_fake:pid={os.getpid()}")
-    out = weight.new_empty([x.shape[0], weight.shape[0]])
-    # out = torch.empty([x.shape[0], weight.shape[0]], device=weight.device)
-
-    # from torchdistx.fake import fake_mode
-    # out = torch.zeros((x.shape[0], weight.shape[0]), device=x.device)
-    return out
+    return weight.new_empty([x.shape[0], weight.shape[0]])
 
 def rocm_unquantized_gemm(layer: torch.nn.Module,
                           x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
-    return rocm_unquantized_gemm_impl(x, weight, bias)
-
-# direct_register_custom_op(
-    # op_name="rocm_unquantized_gemm_impl",
-    # op_func=rocm_unquantized_gemm_impl,
-    # mutates_args=[],
-    # fake_impl=rocm_unquantized_gemm_impl_fake,
-    # dispatch_key=current_platform.dispatch_key,
-# )
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
 
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,

From 0044390abf6e459803ad6a010c3ae02109e2346b Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Tue, 22 Jul 2025 04:54:13 +0000
Subject: [PATCH 3/6] formatting

Signed-off-by: Randall Smith
---
 csrc/rocm/skinny_gemms.cu           |  1 -
 vllm/model_executor/layers/utils.py | 20 ++++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 9cbb7ccda8dd..6212570c79d1 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -1232,7 +1232,6 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   auto K_in = in_a.size(1);
   auto N_in = in_b.size(0);
 
-
   TORCH_CHECK(in_a.dtype() == in_b.dtype());
   TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
   TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index f5a2529f06e5..eec92ab3cf6a 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -8,9 +8,9 @@
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
-
 from vllm.utils import direct_register_custom_op
+
+
 def get_token_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
@@ -71,9 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl(x: torch.Tensor,
-                               weight: torch.Tensor,
-                               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,17 +98,20 @@ def rocm_unquantized_gemm_impl(x: torch.Tensor,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl_fake(x: torch.Tensor,
-                                    weight: torch.Tensor,
-                                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return weight.new_empty([x.shape[0], weight.shape[0]])
 
+
 def rocm_unquantized_gemm(layer: torch.nn.Module,
                           x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
     return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
 
+
 direct_register_custom_op(
     op_name="rocm_unquantized_gemm_impl",
     op_func=rocm_unquantized_gemm_impl,
@@ -116,6 +120,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     dispatch_key=current_platform.dispatch_key,
 )
 
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
@@ -133,4 +138,3 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
         return cpu_unquantized_gemm
     else:
         return default_unquantized_gemm
-

From d22d7c7106f770ca0c7d9ac12b8e1b7cb32cd5ef Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Tue, 22 Jul 2025 04:55:34 +0000
Subject: [PATCH 4/6] revert __init__

Signed-off-by: Randall Smith
---
 vllm/utils/__init__.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 040a368c462f..bfdbd682464a 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2481,8 +2481,6 @@ def direct_register_custom_op(
     import torch.library
     if hasattr(torch.library, "infer_schema"):
-        print(f"op_func={op_func}, type(op_func)={type(op_func)}")
-        print(f"vllm_lib={vllm_lib}")
         schema_str = torch.library.infer_schema(op_func,
                                                 mutates_args=mutates_args)
     else:

From e355d2c71f6dc345a7841df7c28e4bfd97426181 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Tue, 22 Jul 2025 05:15:35 +0000
Subject: [PATCH 5/6] fake tensor

Signed-off-by: Randall Smith
---
 vllm/model_executor/layers/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index eec92ab3cf6a..447c8ff9dca8 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -102,7 +102,7 @@ def rocm_unquantized_gemm_impl_fake(
         x: torch.Tensor,
         weight: torch.Tensor,
         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return weight.new_empty([x.shape[0], weight.shape[0]])
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
 
 
 def rocm_unquantized_gemm(layer: torch.nn.Module,

From 1537e590441703de356cdc3abcd940e19e7779d4 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Tue, 22 Jul 2025 16:51:54 +0000
Subject: [PATCH 6/6] add return type

Signed-off-by: Randall Smith
---
 vllm/model_executor/layers/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 447c8ff9dca8..cd32f12f3c26 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -108,7 +108,7 @@ def rocm_unquantized_gemm_impl_fake(
 
 
 def rocm_unquantized_gemm(layer: torch.nn.Module,
                           x: torch.Tensor,
                           weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
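
For reference, a minimal standalone sketch of the custom-op pattern these patches converge on: a real implementation plus a shape-only fake implementation, so the op can be traced (e.g. by torch.compile) without running the kernel. Everything below is illustrative only, not vLLM code: the "mylib::toy_linear" namespace and names are placeholders, and it assumes PyTorch 2.4+ where torch.library.custom_op and register_fake exist. vLLM itself routes the equivalent registration through its direct_register_custom_op helper, as shown in PATCH 2/6.

# Hypothetical sketch (not vLLM code); assumes PyTorch >= 2.4.
from typing import Optional

import torch


@torch.library.custom_op("mylib::toy_linear", mutates_args=())
def toy_linear(x: torch.Tensor, weight: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real implementation: runs on concrete tensors.
    return torch.nn.functional.linear(x, weight, bias)


@toy_linear.register_fake
def _(x: torch.Tensor, weight: torch.Tensor,
      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Fake (meta) implementation: propagates shape/dtype only, no kernel launch,
    # mirroring the x.new_empty((*x.shape[:-1], weight.shape[0])) fix in PATCH 5/6.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


x = torch.randn(4, 8)
w = torch.randn(16, 8)
out = torch.ops.mylib.toy_linear(x, w)  # dispatches to the real implementation
assert out.shape == (4, 16)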