
Commit f566d49

misc: Add XQA decode to microbenchmark for sm90 and sm120 (#2055)
## 📌 Description

In #2001, XQA decode kernels became available through `trtllm_batch_decode_with_kv_cache` on SM90 and SM120. This PR adds the ability to benchmark them through the microbenchmark. Example microbenchmark command and outputs before and after:

```
### Before current PR:

## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping.
[PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping.
[PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec
[PERF] cudnn :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec

### After current PR:

## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec
[PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec
[PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec
[PERF] cudnn :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Chores**
  * Standardized backend identifier to "trtllm-native" and expanded its support across benchmark routines and utilities.
  * Argument parsing now canonicalizes deprecated backend aliases and emits a deprecation warning when encountered.
* **Documentation**
  * README and tool-facing messages updated to use the canonical backend name and include contextual notes about the change.
1 parent adcc5dd · commit f566d49

File tree

3 files changed (+68, -35 lines)


benchmarks/README.md

Lines changed: 7 additions & 7 deletions
```diff
@@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including:
 | `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) |
 | `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. |
 | `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. |
-| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-gen-native, cublas|
+| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|
 
 ### Attention Flags
 | Flag | Description |
@@ -213,14 +213,14 @@ Legend:
 - cutlass: CUTLASS
 - trtllm: TensorRT-LLM
 - trtllm-gen: TensorRT-LLM (generic wrapper)
-- trtllm-gen-native: TensorRT-LLM (native API)
+- trtllm-native: TensorRT-LLM (native API)
 -->
 | Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 |
 |---------|-----|-----|-----|-----|-----|-------|-------|-------|
-| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn |
-| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn |
-| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn |
-| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-gen-native | fa2, cutlass, trtllm-gen-native | fa2 |
+| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn |
+| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn |
+| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn |
+| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 |
 | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
 | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
 | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas |
@@ -238,4 +238,4 @@ Backend Legend:
 - cutlass: CUTLASS
 - trtllm: TensorRT-LLM
 - trtllm-gen: TensorRT-LLM
-- trtllm-gen-native: TensorRT-LLM (out-of-wrapper)
+- trtllm-native: TensorRT-LLM (out-of-wrapper)
```

benchmarks/routines/attention.py

Lines changed: 47 additions & 18 deletions
```diff
@@ -19,6 +19,30 @@
 )
 
 
+def normalize_backends(backends):
+    """
+    Normalize backend names planned for deprecation and print warnings.
+    Currently:
+    - Replaces deprecated 'trtllm-gen-native' with 'trtllm-native'.
+
+    Args:
+        backends: List of backend names
+
+    Returns:
+        List of normalized backend names
+    """
+    normalized = []
+    for backend in backends:
+        if backend == "trtllm-gen-native":
+            print(
+                "[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. "
+            )
+            normalized.append("trtllm-native")
+        else:
+            normalized.append(backend)
+    return normalized
+
+
 def run_attention_test(args):
     """
     Run an attention test.
@@ -66,7 +90,8 @@ def parse_attention_args(line, parser):
             "cudnn",
             "cutlass",
             "trtllm-gen",
-            "trtllm-gen-native",
+            "trtllm-native",
+            "trtllm-gen-native",  # Deprecated, will be removed in future
         ],
         help="Kernel backends to test. Default: fa2",
     )
@@ -151,6 +176,10 @@ def parse_attention_args(line, parser):
     )
 
     args = parser.parse_args(line)
+
+    # Normalize backend names (handle deprecated names)
+    args.backends = normalize_backends(args.backends)
+
     if args.verbose >= 1:
         print(f"[INFO] {args = }")
     return args
@@ -185,7 +214,7 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len
 def testBatchDecodeWithPagedKVCacheWrapper(args):
     """
     Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
-    Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native backends.
+    Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native backends.
 
     This test:
     1. Creates paged KV cache and query tensors
@@ -490,7 +519,7 @@ def run_backend_wrapper(backend):
                 batch_offsets_q=ragged_q,
                 batch_offsets_o=ragged_q,
             )
-        elif backend == "trtllm-gen-native":
+        elif backend == "trtllm-native":
             return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
                 query=q.contiguous(),
                 kv_cache=kv_cache,
@@ -614,7 +643,7 @@ def run_backend_wrapper(backend):
 def testBatchPrefillWithPagedKVCacheWrapper(args):
     """
     Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API.
-    Supports fa2, fa3, trtllm-gen, trtllm-gen-native, and cudnn backends.
+    Supports fa2, fa3, trtllm-gen, trtllm-native, and cudnn backends.
 
     This test:
     1. Creates paged KV cache and query tensors for prefill
@@ -697,13 +726,13 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
             remove_trtllm = True
         if remove_trtllm:
             backends.remove("trtllm-gen")
-    if "trtllm-gen-native" in backends:
+    if "trtllm-native" in backends:
         remove_trtllm_native = False
         if not causal:
-            print("[INFO] trtllm-gen-native backend currently requires causal = True")
+            print("[INFO] trtllm-native backend currently requires causal = True")
             remove_trtllm_native = True
         if remove_trtllm_native:
-            backends.remove("trtllm-gen-native")
+            backends.remove("trtllm-native")
 
     if "cutlass" in backends:
         print("[INFO] CUTLASS backend does not support prefill. Skipping.")
@@ -955,7 +984,7 @@ def run_backend_wrapper(backend):
                 batch_offsets_q=q_indptr,
                 batch_offsets_o=q_indptr,
             )[0]
-        elif backend == "trtllm-gen-native":
+        elif backend == "trtllm-native":
             return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
                 query=q,
                 kv_cache=kv_cache,
@@ -1178,21 +1207,21 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
             remove_trtllm = True
         if remove_trtllm:
             backends.remove("trtllm-gen")
-    if "trtllm-gen-native" in backends:
+    if "trtllm-native" in backends:
         remove_trtllm_native = False
         if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
         ]:
-            print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
+            print("[INFO] trtllm-native backend does not support FP8. Skipping.")
             remove_trtllm_native = True
         if not (head_dim_qk == 192 and head_dim_vo == 128):
             print(
-                "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
+                "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
             )
             remove_trtllm_native = True
         if remove_trtllm_native:
-            backends.remove("trtllm-gen-native")
+            backends.remove("trtllm-native")
 
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
@@ -1404,7 +1433,7 @@ def run_backend_wrapper(backend):
                 batch_offsets_stats=batch_offsets_stats,
                 is_cuda_graph_compatible=True,
             )[0]
-        elif backend == "trtllm-gen-native":
+        elif backend == "trtllm-native":
             return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
@@ -1538,7 +1567,7 @@ def run_backend_wrapper(backend):
 def testBatchMLAPagedAttentionWrapper(args):
     """
     Test BatchMLAPagedAttentionWrapper and equivalent APIs.
-    Supports fa2, fa3, cutlass, and trtllm-gen-native.
+    Supports fa2, fa3, cutlass, and trtllm-native.
 
     This test:
     1. Creates paged query and key-value cache tensors
@@ -1634,15 +1663,15 @@ def testBatchMLAPagedAttentionWrapper(args):
             remove_cutlass = True
         if remove_cutlass:
             backends.remove("cutlass")
-    if "trtllm-gen-native" in backends:
+    if "trtllm-native" in backends:
         remove_trtllm_native = False
         if page_size not in [32, 64]:
             print(
-                "[INFO] trtllm-gen-native backend only supports page size 32 or 64. Skipping."
+                "[INFO] trtllm-native backend only supports page size 32 or 64. Skipping."
             )
             remove_trtllm_native = True
         if remove_trtllm_native:
-            backends.remove("trtllm-gen-native")
+            backends.remove("trtllm-native")
     if len(backends) == 0:
         print("[ERROR] No backends to test. Exiting.")
         return res
@@ -1807,7 +1836,7 @@ def run_backend_wrapper(backend):
                 page_table=block_tables,
                 return_lse=False,
             )
-        if backend == "trtllm-gen-native":
+        elif backend == "trtllm-native":
             return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
                 query=q.unsqueeze(1),
                 kv_cache=kv_cache.unsqueeze(1),
```
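
For reference, the alias canonicalization above can be exercised standalone. Below is a minimal, self-contained sketch of the same flow; the `_DEPRECATED_ALIASES` table and the toy `argparse` wiring are illustrative, not the benchmark's actual parser setup (the commit hardcodes the single `trtllm-gen-native` → `trtllm-native` case):

```python
import argparse

# Hypothetical alias table: deprecated name -> canonical name.
_DEPRECATED_ALIASES = {"trtllm-gen-native": "trtllm-native"}


def normalize_backends(backends):
    """Canonicalize deprecated backend names, warning on each occurrence."""
    normalized = []
    for backend in backends:
        if backend in _DEPRECATED_ALIASES:
            canonical = _DEPRECATED_ALIASES[backend]
            print(
                f"[WARNING] Backend name '{backend}' has been renamed to "
                f"'{canonical}' and will be removed in a future release."
            )
            normalized.append(canonical)
        else:
            normalized.append(backend)
    return normalized


# Simulate a command line that still uses the deprecated name.
parser = argparse.ArgumentParser()
parser.add_argument("--backends", nargs="+", default=["fa2"])
args = parser.parse_args(["--backends", "fa2", "trtllm-gen-native", "cudnn"])
args.backends = normalize_backends(args.backends)
print(args.backends)  # -> ['fa2', 'trtllm-native', 'cudnn']
```

Old invocations that pass `trtllm-gen-native` therefore keep working, but emit the deprecation warning and proceed under the canonical name.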

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 14 additions & 10 deletions
```diff
@@ -162,43 +162,47 @@ def dtype_str_to_torch_dtype(dtype_str):
 routine_cc_to_supported_backends = {
     # ATTENTION
     "BatchDecodeWithPagedKVCacheWrapper": {
+        # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache
         "7.5": ["fa2"],
         "8.0": ["fa2", "fa2_tc", "cudnn"],
         "8.6": ["fa2", "fa2_tc", "cudnn"],
         "8.9": ["fa2", "fa2_tc", "cudnn"],
-        "9.0": ["fa2", "fa2_tc", "cudnn"],
-        "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
-        "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
-        "12.0": ["fa2", "fa2_tc", "cudnn"],
+        "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
+        "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
+        "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
+        "12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
     },
     "BatchPrefillWithPagedKVCacheWrapper": {
+        # NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
         "7.5": [],
         "8.0": ["fa2", "cudnn"],
         "8.6": ["fa2", "cudnn"],
         "8.9": ["fa2", "cudnn"],
         "9.0": ["fa2", "fa3", "cudnn"],
-        "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"],
-        "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"],
+        "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
+        "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
         "12.0": ["fa2", "cudnn"],
     },
     "BatchPrefillWithRaggedKVCacheWrapper": {
+        # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
         "7.5": [],
         "8.0": ["fa2", "cudnn"],
         "8.6": ["fa2", "cudnn"],
         "8.9": ["fa2", "cudnn"],
         "9.0": ["fa2", "fa3", "cudnn"],
-        "10.0": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"],
-        "10.3": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"],
+        "10.0": ["fa2", "cudnn", "cutlass", "trtllm-native"],
+        "10.3": ["fa2", "cudnn", "cutlass", "trtllm-native"],
         "12.0": ["fa2", "cudnn"],
     },
     "BatchMLAPagedAttentionWrapper": {
+        # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla
         "7.5": [],
         "8.0": ["fa2"],
         "8.6": ["fa2"],
         "8.9": ["fa2"],
         "9.0": ["fa2", "fa3"],
-        "10.0": ["fa2", "cutlass", "trtllm-gen-native"],
-        "10.3": ["fa2", "cutlass", "trtllm-gen-native"],
+        "10.0": ["fa2", "cutlass", "trtllm-native"],
+        "10.3": ["fa2", "cutlass", "trtllm-native"],
         "12.0": ["fa2"],
     },
     # GEMM
```
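
To show how a capability table like this is typically consumed, here is a hedged sketch of backend filtering; the `filter_backends` helper is hypothetical (this commit only edits the table itself), though the warning text mirrors the benchmark output quoted in the PR description:

```python
import torch

# Abbreviated copy of the table above (decode routine only).
routine_cc_to_supported_backends = {
    "BatchDecodeWithPagedKVCacheWrapper": {
        "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
        "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
        "12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
    },
}


def filter_backends(routine, backends, device=0):
    """Drop requested backends the current GPU's compute capability lacks."""
    major, minor = torch.cuda.get_device_capability(device)
    cc = f"{major}.{minor}"
    supported = routine_cc_to_supported_backends.get(routine, {}).get(cc, [])
    kept = []
    for backend in backends:
        if backend in supported:
            kept.append(backend)
        else:
            print(
                f"[WARNING] {backend} for routine {routine} is not supported "
                f"on compute capability {cc}. Skipping."
            )
    return kept


# On an H200 (compute capability 9.0) this now keeps trtllm-native,
# whereas before this commit it would have been skipped:
# filter_backends("BatchDecodeWithPagedKVCacheWrapper",
#                 ["fa2", "trtllm-native", "cudnn"])
```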
