|
19 | 19 | ) |
20 | 20 |
|
21 | 21 |
|
def normalize_backends(backends):
    """
    Normalize backend names planned for deprecation and print warnings.

    Deprecated names and their replacements live in a single mapping so that
    a future rename only requires adding one entry. Currently:
      - 'trtllm-gen-native' -> 'trtllm-native'

    Args:
        backends: List of backend names.

    Returns:
        A new list of normalized backend names; the input list is not mutated.
    """
    # Single source of truth for deprecated-name renames (the original code
    # hard-coded both the old and new name inline, in two places).
    renames = {
        "trtllm-gen-native": "trtllm-native",
    }
    normalized = []
    for backend in backends:
        replacement = renames.get(backend)
        if replacement is not None:
            # For the currently deprecated name this renders byte-identically
            # to the original hard-coded warning message.
            print(
                f"[WARNING] Backend name '{backend}' has been renamed to '{replacement}' and will be removed in a future release. "
            )
            normalized.append(replacement)
        else:
            normalized.append(backend)
    return normalized
| 44 | + |
| 45 | + |
22 | 46 | def run_attention_test(args): |
23 | 47 | """ |
24 | 48 | Run an attention test. |
@@ -66,7 +90,8 @@ def parse_attention_args(line, parser): |
66 | 90 | "cudnn", |
67 | 91 | "cutlass", |
68 | 92 | "trtllm-gen", |
69 | | - "trtllm-gen-native", |
| 93 | + "trtllm-native", |
| 94 | + "trtllm-gen-native", # Deprecated, will be removed in future |
70 | 95 | ], |
71 | 96 | help="Kernel backends to test. Default: fa2", |
72 | 97 | ) |
@@ -151,6 +176,10 @@ def parse_attention_args(line, parser): |
151 | 176 | ) |
152 | 177 |
|
153 | 178 | args = parser.parse_args(line) |
| 179 | + |
| 180 | + # Normalize backend names (handle deprecated names) |
| 181 | + args.backends = normalize_backends(args.backends) |
| 182 | + |
154 | 183 | if args.verbose >= 1: |
155 | 184 | print(f"[INFO] {args = }") |
156 | 185 | return args |
@@ -185,7 +214,7 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len |
185 | 214 | def testBatchDecodeWithPagedKVCacheWrapper(args): |
186 | 215 | """ |
187 | 216 | Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API. |
188 | | - Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native backends. |
| 217 | + Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native backends. |
189 | 218 |
|
190 | 219 | This test: |
191 | 220 | 1. Creates paged KV cache and query tensors |
@@ -490,7 +519,7 @@ def run_backend_wrapper(backend): |
490 | 519 | batch_offsets_q=ragged_q, |
491 | 520 | batch_offsets_o=ragged_q, |
492 | 521 | ) |
493 | | - elif backend == "trtllm-gen-native": |
| 522 | + elif backend == "trtllm-native": |
494 | 523 | return flashinfer.decode.trtllm_batch_decode_with_kv_cache( |
495 | 524 | query=q.contiguous(), |
496 | 525 | kv_cache=kv_cache, |
@@ -614,7 +643,7 @@ def run_backend_wrapper(backend): |
614 | 643 | def testBatchPrefillWithPagedKVCacheWrapper(args): |
615 | 644 | """ |
616 | 645 | Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API. |
617 | | - Supports fa2, fa3, trtllm-gen, trtllm-gen-native, and cudnn backends. |
| 646 | + Supports fa2, fa3, trtllm-gen, trtllm-native, and cudnn backends. |
618 | 647 |
|
619 | 648 | This test: |
620 | 649 | 1. Creates paged KV cache and query tensors for prefill |
@@ -697,13 +726,13 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): |
697 | 726 | remove_trtllm = True |
698 | 727 | if remove_trtllm: |
699 | 728 | backends.remove("trtllm-gen") |
700 | | - if "trtllm-gen-native" in backends: |
| 729 | + if "trtllm-native" in backends: |
701 | 730 | remove_trtllm_native = False |
702 | 731 | if not causal: |
703 | | - print("[INFO] trtllm-gen-native backend currently requires causal = True") |
| 732 | + print("[INFO] trtllm-native backend currently requires causal = True") |
704 | 733 | remove_trtllm_native = True |
705 | 734 | if remove_trtllm_native: |
706 | | - backends.remove("trtllm-gen-native") |
| 735 | + backends.remove("trtllm-native") |
707 | 736 |
|
708 | 737 | if "cutlass" in backends: |
709 | 738 | print("[INFO] CUTLASS backend does not support prefill. Skipping.") |
@@ -955,7 +984,7 @@ def run_backend_wrapper(backend): |
955 | 984 | batch_offsets_q=q_indptr, |
956 | 985 | batch_offsets_o=q_indptr, |
957 | 986 | )[0] |
958 | | - elif backend == "trtllm-gen-native": |
| 987 | + elif backend == "trtllm-native": |
959 | 988 | return flashinfer.prefill.trtllm_batch_context_with_kv_cache( |
960 | 989 | query=q, |
961 | 990 | kv_cache=kv_cache, |
@@ -1178,21 +1207,21 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): |
1178 | 1207 | remove_trtllm = True |
1179 | 1208 | if remove_trtllm: |
1180 | 1209 | backends.remove("trtllm-gen") |
1181 | | - if "trtllm-gen-native" in backends: |
| 1210 | + if "trtllm-native" in backends: |
1182 | 1211 | remove_trtllm_native = False |
1183 | 1212 | if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ |
1184 | 1213 | torch.float8_e4m3fn, |
1185 | 1214 | torch.float8_e5m2, |
1186 | 1215 | ]: |
1187 | | - print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.") |
| 1216 | + print("[INFO] trtllm-native backend does not support FP8. Skipping.") |
1188 | 1217 | remove_trtllm_native = True |
1189 | 1218 | if not (head_dim_qk == 192 and head_dim_vo == 128): |
1190 | 1219 | print( |
1191 | | - "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128" |
| 1220 | + "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128" |
1192 | 1221 | ) |
1193 | 1222 | remove_trtllm_native = True |
1194 | 1223 | if remove_trtllm_native: |
1195 | | - backends.remove("trtllm-gen-native") |
| 1224 | + backends.remove("trtllm-native") |
1196 | 1225 |
|
1197 | 1226 | if len(backends) == 0: |
1198 | 1227 | print("[ERROR] No backends to test. Exiting.") |
@@ -1404,7 +1433,7 @@ def run_backend_wrapper(backend): |
1404 | 1433 | batch_offsets_stats=batch_offsets_stats, |
1405 | 1434 | is_cuda_graph_compatible=True, |
1406 | 1435 | )[0] |
1407 | | - elif backend == "trtllm-gen-native": |
| 1436 | + elif backend == "trtllm-native": |
1408 | 1437 | return flashinfer.prefill.trtllm_ragged_attention_deepseek( |
1409 | 1438 | query=q, |
1410 | 1439 | key=k, |
@@ -1538,7 +1567,7 @@ def run_backend_wrapper(backend): |
1538 | 1567 | def testBatchMLAPagedAttentionWrapper(args): |
1539 | 1568 | """ |
1540 | 1569 | Test BatchMLAPagedAttentionWrapper and equivalent APIs. |
1541 | | - Supports fa2, fa3, cutlass, and trtllm-gen-native. |
| 1570 | + Supports fa2, fa3, cutlass, and trtllm-native. |
1542 | 1571 |
|
1543 | 1572 | This test: |
1544 | 1573 | 1. Creates paged query and key-value cache tensors |
@@ -1634,15 +1663,15 @@ def testBatchMLAPagedAttentionWrapper(args): |
1634 | 1663 | remove_cutlass = True |
1635 | 1664 | if remove_cutlass: |
1636 | 1665 | backends.remove("cutlass") |
1637 | | - if "trtllm-gen-native" in backends: |
| 1666 | + if "trtllm-native" in backends: |
1638 | 1667 | remove_trtllm_native = False |
1639 | 1668 | if page_size not in [32, 64]: |
1640 | 1669 | print( |
1641 | | - "[INFO] trtllm-gen-native backend only supports page size 32 or 64. Skipping." |
| 1670 | + "[INFO] trtllm-native backend only supports page size 32 or 64. Skipping." |
1642 | 1671 | ) |
1643 | 1672 | remove_trtllm_native = True |
1644 | 1673 | if remove_trtllm_native: |
1645 | | - backends.remove("trtllm-gen-native") |
| 1674 | + backends.remove("trtllm-native") |
1646 | 1675 | if len(backends) == 0: |
1647 | 1676 | print("[ERROR] No backends to test. Exiting.") |
1648 | 1677 | return res |
@@ -1807,7 +1836,7 @@ def run_backend_wrapper(backend): |
1807 | 1836 | page_table=block_tables, |
1808 | 1837 | return_lse=False, |
1809 | 1838 | ) |
1810 | | - if backend == "trtllm-gen-native": |
| 1839 | + elif backend == "trtllm-native": |
1811 | 1840 | return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( |
1812 | 1841 | query=q.unsqueeze(1), |
1813 | 1842 | kv_cache=kv_cache.unsqueeze(1), |
|
0 commit comments