update vllm kernel benchmark scripts #176
1pikachu wants to merge 43 commits into vllm-project:main from
Conversation
Pull request overview
Updates and expands XPU kernel benchmark scripts to better support regular validation (correctness checks + perf reporting), with more consistent runtime configuration logging.
Changes:
- Added several new benchmark scripts for FP8 quant/GEMM, rotary embedding, MoE kernels, FlashAttention varlen, and MLA concat/cache.
- Updated existing benchmarks to print the active benchmark configuration/provider for easier CI log triage.
- Simplified RMSNorm benchmarking by removing optional IPEX path and expanding the config grid.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| benchmark/benchmark_topk.py | Adds per-run config/provider logging for CI traceability. |
| benchmark/benchmark_swigluoai_and_mul.py | Adds per-run config/provider logging for CI traceability. |
| benchmark/benchmark_static_scaled_fp8_quant.py | New correctness + perf benchmark for static scaled FP8 quant. |
| benchmark/benchmark_rotary_embedding.py | New correctness + perf benchmark for rotary embedding native vs vLLM paths. |
| benchmark/benchmark_rmsnorm.py | Removes IPEX path, expands config space, adds logging, refactors “naive” naming. |
| benchmark/benchmark_reshape_and_cache.py | Adds per-run config/provider logging for CI traceability. |
| benchmark/benchmark_moe_sum.py | New correctness + perf benchmark for moe_sum op. |
| benchmark/benchmark_moe_align_block_size.py | New correctness + perf benchmarks + opchecks for moe_align_block_size variants. |
| benchmark/benchmark_lora.py | Fixes imports to use benchmark.* package paths. |
| benchmark/benchmark_grouped_topk.py | Adds per-run config/provider logging for CI traceability. |
| benchmark/benchmark_fp8_gemm_w8a16.py | New correctness + perf benchmark for fp8_gemm_w8a16. |
| benchmark/benchmark_dynamic_per_token_scaled_fp8_quant.py | New correctness + perf benchmark for dynamic per-token FP8 quant. |
| benchmark/benchmark_cutlass_fused_moe.py | New correctness + perf benchmark for CUTLASS fused MoE vs reference. |
| benchmark/benchmark_cutlass_flash_attn_varlen.py | New correctness + perf benchmark for FlashAttention varlen vs native reference. |
| benchmark/benchmark_concat_and_cache_mla.py | New benchmark comparing torch.cat vs direct copy for MLA concat. |
```python
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
casual=is_causal,
```
ref_paged_attn is called with keyword argument casual, which is very likely a typo for causal. If ref_paged_attn doesn’t accept casual, this will raise a TypeError and break the benchmark/correctness run. Rename the keyword to causal (or match the exact parameter name in ref_paged_attn).
```diff
-casual=is_causal,
+causal=is_causal,
```
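A minimal sketch of why the typo fails at runtime rather than silently: Python rejects unknown keyword arguments before the function body runs. The `ref_paged_attn` signature below is an assumption for illustration, not the real one.

```python
# Hypothetical stand-in for the real ref_paged_attn signature.
def ref_paged_attn(*, scale, causal=False):
    return causal

# Misspelled keyword -> TypeError at call time, before any attention math runs.
try:
    ref_paged_attn(scale=1.0, casual=True)
    raised = False
except TypeError:
    raised = True

print(raised)  # the misspelled call raises
print(ref_paged_attn(scale=1.0, causal=True))  # the corrected keyword works
```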
```python
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
casual=is_causal,
```
Same issue as above: casual is likely an invalid keyword for ref_paged_attn and will crash at runtime. Use the correct keyword (likely causal).
```diff
-casual=is_causal,
+causal=is_causal,
```
```python
positions = torch.randint(0,
                          max_position, (batch_size, seq_len),
                          device=device)
head_stride = head_size + (64 if head_stride_is_contiguous else 0)
```
The head_stride_is_contiguous flag appears inverted here: when the stride is contiguous, head_stride typically equals head_size (no padding). Adding 64 when head_stride_is_contiguous is True makes the tensor less contiguous and undermines the intended layout coverage. Swap the condition so extra padding is added only when testing the non-contiguous stride case.
```diff
-head_stride = head_size + (64 if head_stride_is_contiguous else 0)
+head_stride = head_size + (0 if head_stride_is_contiguous else 64)
```
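A sketch of the corrected stride logic as a standalone function. Names mirror the diff; the 64-element pad comes from the original line, and wrapping it in a helper is only for illustration.

```python
# Hypothetical helper mirroring the corrected condition from the review.
def compute_head_stride(head_size: int, head_stride_is_contiguous: bool) -> int:
    # Contiguous layout: stride equals head_size, no padding.
    # Non-contiguous case: pad by 64 elements to force a strided layout.
    return head_size + (0 if head_stride_is_contiguous else 64)

print(compute_head_stride(128, True))   # contiguous: 128
print(compute_head_stride(128, False))  # strided: 192
```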
```python
                            dtype=dtype)
else:
    key_cache = torch.randn(sum(kv_lens),
                            num_query_heads,
```
In the non-paged KV path, key_cache is generated with num_query_heads. KV cache tensors are typically shaped with num_kv_heads (and the code already distinguishes num_query_heads vs num_kv_heads). Generating KV with query heads can make the reference and kernel consume mismatched shapes/semantics. Use num_kv_heads here (and ensure the reference path expects the same).
```diff
-num_query_heads,
+num_kv_heads,
```
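A shape-only sketch of the intended non-paged KV layout. The helper name is hypothetical; the point is that in grouped-query attention the cache carries `num_kv_heads` per cached token, which can be smaller than `num_query_heads`.

```python
# Hypothetical helper: the expected shape of a flat (non-paged) KV cache.
def kv_cache_shape(kv_lens, num_kv_heads, head_size):
    # One row per cached token across all sequences, KV heads only.
    return (sum(kv_lens), num_kv_heads, head_size)

print(kv_cache_shape([3, 5], 8, 128))  # (8, 8, 128)
```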
```python
args = parse_args()
seed_everything(0)

num_toknes = [1, 7, 83, 4096]
```
Correct the typo in the variable name num_toknes to num_tokens for readability and consistency (it’s printed as num_tokens and used to build configs).
```python
fp8_dtype = [torch.float8_e4m3fn]
group_shape = [(1, -1), (-1, 1)]
print("Final configuration:")
print(f"  num_tokens: {num_toknes}")
```
Correct the typo in the variable name num_toknes to num_tokens for readability and consistency (it’s printed as num_tokens and used to build configs).
```python
print(f"  group_shape: {group_shape}")

configs = list(
    itertools.product(num_toknes, hidden_size, dtype, fp8_dtype, group_shape))
```
Correct the typo in the variable name num_toknes to num_tokens for readability and consistency (it’s printed as num_tokens and used to build configs).
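A sketch of the config grid with `num_tokens` spelled consistently. Only `num_tokens` and `group_shape` values come from the diff; `hidden_size`, `dtype`, and `fp8_dtype` are placeholder strings standing in for the real torch dtypes.

```python
import itertools

# Values for num_tokens and group_shape are from the diff; the rest are
# placeholders for this sketch.
num_tokens = [1, 7, 83, 4096]
hidden_size = [1024]
dtype = ["bfloat16"]
fp8_dtype = ["float8_e4m3fn"]
group_shape = [(1, -1), (-1, 1)]

configs = list(
    itertools.product(num_tokens, hidden_size, dtype, fp8_dtype, group_shape))
print(len(configs))  # 4 * 1 * 1 * 1 * 2 = 8
```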
```python
from tests.utils import parse_args, opcheck, round_up, seed_everything
from tests.test_moe_align_block_size import torch_moe_align_block_size, _verify_expert_level_sorting

seed_everything(0)
```
Calling seed_everything(0) at import time introduces a module import side effect (and can interfere with other benchmarks/tests when this file is imported). Move seeding under the `if __name__ == "__main__":` guard (or into the specific routines) so importing the module doesn't alter global RNG state.
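A minimal sketch of the suggested restructuring. `seed_everything` here is a stdlib stand-in for the benchmark helper (an assumption; the real helper presumably also seeds torch): importing the module leaves global RNG state untouched, and seeding happens only when the script actually runs.

```python
import random

def seed_everything(seed: int) -> None:
    # Stand-in for the real helper, which would also seed torch/numpy.
    random.seed(seed)

def main() -> None:
    seed_everything(0)  # RNG state changes only when run as a script
    # ... build configs and run benchmarks ...

if __name__ == "__main__":
    main()
```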
benchmark/benchmark_rmsnorm.py
Outdated
```python
print(f"native output={output_native}")
print(f"vLLM output={output_vllm}")
```
Printing full output tensors for every correctness config can massively slow down benchmark runs and flood CI logs (especially with the expanded config grid). Consider printing only on mismatch (or printing summary stats like max/mean abs diff) and keeping the config identifier for debugging.
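A hedged sketch of the mismatch-only reporting pattern. The real benchmark would compare torch tensors with `torch.allclose`; this uses plain floats to keep the sketch self-contained, and the helper name is hypothetical.

```python
# Hypothetical helper: print summary stats plus the config tag only on
# mismatch, instead of dumping full output tensors for every config.
def report_correctness(native, vllm, tag, atol=1e-2):
    diffs = [abs(a - b) for a, b in zip(native, vllm)]
    max_diff = max(diffs)
    if max_diff > atol:
        print(f"[{tag}] MISMATCH max_abs={max_diff:.3e} "
              f"mean_abs={sum(diffs) / len(diffs):.3e}")
        return False
    return True

print(report_correctness([1.0, 2.0], [1.0, 2.0], "num_tokens=4096"))  # True
print(report_correctness([1.0, 2.0], [1.5, 2.0], "num_tokens=4096"))  # False
```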
Purpose
Update kernel benchmark scripts to set up regular validation pipelines.