From 0367f1bbc474c075bb83308d8f2cebea05c5e059 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 13 Aug 2025 18:09:28 -0700 Subject: [PATCH 1/9] support FP8 TRTLLM attn kernel Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 2 + .../benchmark_trtllm_decode_attention.py | 296 +++++++++-------- .../benchmark_trtllm_prefill_attention.py | 243 ++++++++------ tests/compile/test_fusion_attn.py | 253 ++++++++++++++- .../test_flashinfer_trtllm_attention.py | 300 ++++++++++-------- vllm/attention/backends/flashinfer.py | 5 +- vllm/attention/layer.py | 11 +- vllm/compilation/fusion_attn.py | 101 +++--- vllm/utils/flashinfer.py | 32 +- vllm/v1/attention/backends/flashinfer.py | 102 ++++-- 10 files changed, 897 insertions(+), 448 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4fc885785492..8f8cf342aa84 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -649,6 +649,7 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -665,6 +666,7 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca45b..694e91aa8d0a 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,149 +24,182 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len - - # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) 
- q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + sm_scale = float(1.0 / (head_size**0.5)) + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len + + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) - - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) - k_scale = v_scale = 1.0 - - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - def time_fn(fn, warmup=10, trials=20): - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - times = [] - for i in range(warmup): - fn() - for i in range(trials): - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) # ms - return sum(times) / len(times), torch.std(torch.tensor(times)) - - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] + for i in range(batch_size): + seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = 
torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), ) - wrapper.plan( kv_indptr, kv_indices, kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, "NONE", + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, + kv_data_type=dtype, ) + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output_baseline) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) + + # if o_quant_dtype == FP8_DTYPE: + # output_trtllm = output_trtllm.to(dtype) * o_scale + + # if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + # rtol, atol = 5e-2, 7e-2 + # else: + # rtol, atol = 1e-2, 1e-2 + + # torch.testing.assert_close( + # output_baseline, output_trtllm, atol=atol, rtol=rtol), \ + # f"{torch.max(torch.abs(output_baseline - output_trtllm))}" # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +211,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", 
"num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +241,41 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbcca..c0f10b387f00 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + 
kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + max_q_len = max_kv_len = max_seq_len - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) + sm_scale = float(1.0 / (head_size**0.5)) - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN - - seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) - - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) - k_scale = v_scale = 1.0 - - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = 
torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +128,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +151,69 @@ def time_fn(fn, warmup=10, trials=20): times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_prefill(): - return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline - ) + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) - def trt_prefill(): + def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output_baseline) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) + + # if o_quant_dtype == FP8_DTYPE: + # output_trtllm = output_trtllm.to(dtype) * o_scale + + # if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + # rtol, atol = 5e-2, 7e-2 + # else: + # rtol, atol = 1e-2, 1e-2 + + # torch.testing.assert_close( + # output_baseline, output_trtllm, atol=atol, rtol=rtol), \ + # f"{torch.max(torch.abs(output_baseline - output_trtllm))}" # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +225,18 @@ def write_results_to_csv(results, filename=None): filename = 
f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +255,41 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 70750eb9ac4e..bf6804a42245 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from typing import Optional import pytest @@ -8,12 +9,24 @@ from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal from vllm import LLM, SamplingParams +from vllm.attention import Attention from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + ModelConfig, PassConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform +from vllm.v1.attention.backends.flashinfer import FlashInferMetadataBuilder +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() # globals 
needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -132,3 +145,241 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # Reset backend to make sure llm2 gets released backend = None + + +class TestAttentionStaticQuantPatternModel(torch.nn.Module): + """Test model for AttentionStaticQuantPattern fusion.""" + + def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig): + super().__init__() + self.num_qo_heads = num_qo_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + self.attn = Attention( + num_heads=self.num_qo_heads, + head_size=self.head_size, + scale=1.0 / (self.head_size**0.5), + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix="model.layers.0.self_attn.attn", + ) + + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_fp8_scale = torch.tensor([1.0], dtype=torch.float32) + + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + query_start_loc = torch.arange(0, + batch_size + 1, + dtype=torch.int32, + device=self.device) + seq_lens = torch.ones(batch_size, + dtype=torch.int32, + device=self.device) + + # Create simple block table and slot mapping for testing + block_size = 16 + num_tokens = batch_size # num_tokens = batch_size for simplicity + num_blocks = max(1, (num_tokens + block_size - 1) // block_size) + block_table = torch.arange(num_blocks, + dtype=torch.int32, + device=self.device).unsqueeze(0).repeat( + batch_size, 1) + slot_mapping = torch.arange(batch_size, + dtype=torch.long, + device=self.device) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.cpu(), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + num_computed_tokens_cpu=torch.zeros(batch_size, dtype=torch.int32), + num_reqs=batch_size, + num_actual_tokens=batch_size, + max_query_len=1, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + ) + + # Mock the KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = [kv_cache] + + # Initialize FlashInferMetadataBuilder + builder = FlashInferMetadataBuilder( + kv_cache_spec=AttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + use_mla=False, + ), + layer_names=[self.attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, + ) + + # Build FlashInferMetadata + self.attn_metadata = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata) + + return self.attn_metadata + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + attn_output = attn_output.view(-1, self.num_qo_heads * self.head_size) + output, _ = self.quant_fp8(attn_output, self.quant_fp8_scale) + 
return output + + +@pytest.mark.parametrize("num_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("batch_size", [7, 256, 533]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "model_quant_dtype", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", FP8_DTYPE)]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), + reason="Only test on SM100(Blackwell)") +def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, + batch_size: int, dtype: torch.dtype, + model_quant_dtype: tuple[str, torch.dtype], + monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass with FlashInfer backend""" + + # Enable FlashInfer v1 backend for this test + monkeypatch.setenv("VLLM_USE_V1", "1") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + monkeypatch.setenv("VLLM_USE_TRTLLM_ATTN", "1") + + device = torch.device("cuda:0") + torch.manual_seed(42) + + num_qo_heads, num_kv_heads = num_heads + model_name, quant_dtype = model_quant_dtype + + quant_op = None + if quant_dtype == FP8_DTYPE: + quant_op = QUANT_OPS[kFp8StaticTensorSym] + else: + raise ValueError(f"Unsupported quant_dtype: {quant_dtype}") + + vllm_config = VllmConfig( + model_config=ModelConfig( + model=model_name, + max_model_len=2048, + ), + scheduler_config=SchedulerConfig(max_num_seqs=1024), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+quant_fp8"], + full_cuda_graph=True, + ), + ) + + # Create test inputs + q = torch.randn(batch_size, + num_qo_heads, + head_size, + dtype=dtype, + device=device) + k = torch.randn(batch_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + v = torch.randn(batch_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + + # Mark first dimension as dynamic for realistic testing + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(k, 0) + torch._dynamo.mark_dynamic(v, 0) + + # Run model directly without compilation and fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with set_current_vllm_config(vllm_config_unfused), set_forward_context( + attn_metadata=None, vllm_config=vllm_config_unfused): + model_unfused = TestAttentionStaticQuantPatternModel( + num_qo_heads, num_kv_heads, head_size, dtype, device, + vllm_config_unfused) + model_unfused = model_unfused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata( + batch_size) + + # Run model directly without compilation and fusion + result_unfused = model_unfused(q, k, v) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + enable_attn_fusion=True, enable_noop=True) + vllm_config.cache_config = CacheConfig(cache_dtype="fp8") + with set_current_vllm_config(vllm_config), set_forward_context( + attn_metadata=None, vllm_config=vllm_config): + model_fused = TestAttentionStaticQuantPatternModel( + num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, + vllm_config) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes enabled + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = lambda *args, **kw: 
AttnFusionPass(vllm_config)(*args, **kw + ) + backend = TestBackend(noop_pass, attn_pass) + + # Compile model with fusion enabled + model_compiled = torch.compile(model_fused, + backend=backend, + fullgraph=True) + result_fused = model_compiled(q, k, v) + + # Check quantization ops in the graph before and after fusion + backend.check_before_ops([quant_op], fully_replaced=True) + + # Check attention ops in the graph before and after fusion + attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" + assert len(attn_nodes_pre) == len(attn_nodes_post), \ + "Should have same number of attention nodes before and after fusion" + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + "Attention should not have output_scale before fusion" + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + "Attention should have output_scale after fusion" + + # Check that results are closed + torch.testing.assert_close(result_unfused.to(dtype), + result_fused.to(dtype), + atol=1e-2, + rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 4b84e6a00ece..619822f3ee43 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -13,21 +13,7 @@ allow_module_level=True) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -MAX_Q_LEN = 1024 -MAX_KV_LEN = 4096 -BATCH_SIZES = [4, 12] -NUM_HEADS = [(16, 16), (40, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -KV_LAYOUTS = ["HND"] -DTYPES = [torch.bfloat16] -KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. -SOFT_CAPS = [None, 50.0] +FP8_DTYPE = current_platform.fp8_dtype() def to_float8(x, dtype=torch.float8_e4m3fn): @@ -39,42 +25,59 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +DTYPE = [torch.bfloat16] +QUANT_DTYPES = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), +] +BATCH_SIZE = [4, 12] +MAX_SEQ_LENS = [(1024, 4096)] +NUM_HEADS = [(64, 8), (40, 8)] +HEAD_SIZE = [128] +KV_LAYOUT = ["HND"] # currently only HND is supported +BLOCK_SIZE = [16] +SOFT_CAP = [None, 50.0] + +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
+ + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - torch.set_default_device("cuda") current_platform.seed_everything(0) - kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - max_kv_len = torch.max(kv_lens).item() - num_seqs = len(kv_lens) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + _, max_kv_len = max_seq_lens - scale = head_size**-0.5 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -83,23 +86,40 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = 
[] - for i in range(num_seqs): - seq_len = kv_lens[i] + for i in range(batch_size): + seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) @@ -112,103 +132,93 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Decode wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, - use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4)) wrapper.plan(kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, "NONE", - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) # TRTLLM Decode - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - output_trtllm = torch.empty(query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=kv_lens_tensor, - max_seq_len=max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) @pytest.mark.parametrize("soft_cap", [None]) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - 
kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if dtype != kv_cache_dtype: - pytest.skip(f"Not supported dtype({dtype}) with " - "kv_cache_dtype({kv_cache_dtype})") - torch.set_default_device("cuda") current_platform.seed_everything(0) - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) - - kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - seq_lens = kv_lens + q_lens - max_seq_len = torch.max(seq_lens).item() - num_seqs = len(seq_lens) + max_q_len, max_kv_len = max_seq_lens - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - scale = head_size**-0.5 - - query = torch.randn(torch.sum(q_lens).item(), - num_query_heads, - head_size, - dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -217,22 +227,49 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) + + q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + query = torch.randn(torch.sum(q_lens).item(), + num_qo_heads, + head_size, + dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size @@ -246,48 +283,55 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) 
+ + # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout) wrapper.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, causal=True, - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) - # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=dtype) + # TRTLLM Prefill + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a85ec2463283..6b8691dff641 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1053,10 +1053,11 @@ def forward( assert decode_meta.decode_wrapper._sm_scale == softmax_scale # TODO: @pavanimajety Remove this once the switch happens # inside flashinfer. + # see https://github.com/flashinfer-ai/flashinfer/issues/1493 if not use_trtllm_attention( + attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, num_decode_tokens, attn_metadata.max_decode_seq_len, - kv_cache_dtype, attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, attn_metadata.head_dim): + kv_cache_dtype): decode_meta.decode_wrapper.run( decode_query, kv_cache.permute(*stride_order), diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0e87fa3f23e3..04ab100c8775 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -128,11 +128,17 @@ def __init__( self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep the float32 versions of k/v_scale for attention - # backends that don't support tensors (Flashinfer) + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. 
+ self._o_scale_float: Optional[float] = None + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size @@ -291,6 +297,7 @@ def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index a40a8caf34a8..992467761c76 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -9,7 +9,7 @@ unset_fake_temporarily) from vllm.attention import Attention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -18,28 +18,36 @@ logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default class AttentionStaticQuantPattern: + """ + Fusion for Attention+StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the StaticQuant + op will be removed from the graph, and its scale will be passed into + Attention op as the `output_scale` argument. + """ def __init__( self, - layer_name: str, - num_heads: int, - head_size: int, + layer: Attention, quant_dtype: torch.dtype, - symmetric=True, ): - self.layer_name = layer_name - self.num_heads = num_heads - self.head_size = head_size + self.layer = layer + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.head_size = layer.head_size self.quant_dtype = quant_dtype self.quant_key = QuantKey(dtype=quant_dtype, static=True, group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + symmetric=True) assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -48,51 +56,48 @@ def empty_quant(self, *args, **kwargs): kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) - def register_if_supported(self, pm_pass: PatternMatcherPass, - layer: Attention): - if layer.impl.fused_output_quant_supported(self.quant_dtype, - self.quant_key.static, - self.quant_key.group_shape): + def register_if_supported(self, pm_pass: PatternMatcherPass): + if self.layer.impl.fused_output_quant_supported( + self.quant_dtype, self.quant_key.static, + self.quant_key.group_shape): self._register(pm_pass) def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): - view_7 = RESHAPE_OP(output_attn, - [-1, self.num_heads, self.head_size]) - + attn_output: torch.Tensor, quant_scale: torch.Tensor, + quant_output: torch.Tensor): at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=attn_output, layer_name=self.layer_name, output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) - + attn_output_view = RESHAPE_OP( + at1[1], [-1, self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, - result=output_quant, - input=attn_out_view, - scale=scale) + 
result=quant_output, + input=attn_output_view, + scale=quant_scale) return at2[1] def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): - view_7 = RESHAPE_OP(output_quant, - [-1, self.num_heads, self.head_size]) - + attn_output: torch.Tensor, quant_scale: torch.Tensor, + quant_output: torch.Tensor): + # attn out in quant_dtype + attn_output = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size], + 0.0, + dtype=self.quant_dtype, + device=q.device) at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=attn_output, layer_name=self.layer_name, - output_scale=scale) - + output_scale=quant_scale) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. @@ -102,10 +107,10 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, empty_bf16(5, self.num_heads, self.head_size), # q empty_bf16(5, self.num_heads, self.head_size), # k empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads * self.head_size), # attn_output + empty_bf16(5, self.num_heads, self.head_size), # attn_output + empty_fp32(1, 1), # quant_scale self.empty_quant(5, self.num_heads * self.head_size), # quant_output - empty_fp32(1, 1) # scale ] def wrap_trace_fn(process_fx, trace_fn): @@ -140,27 +145,29 @@ class AttnFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): super().__init__(config) - self.static_fwd_ctx = config.compilation_config.static_forward_context + attn_layers_count = 0 self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for key, layer in self.static_fwd_ctx.items(): - pattern = AttentionStaticQuantPattern(key, layer.num_heads, - layer.head_size, - current_platform.fp8_dtype()) - pattern.register_if_supported(self.patterns, layer) - if len(self.static_fwd_ctx) == 0: + for layer_name, layer in get_layers_from_vllm_config( + config, Attention).items(): + pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) + pattern.register_if_supported(self.patterns) + attn_layers_count += 1 + + if attn_layers_count == 0: logger.warning( - "Attention + quant fusion is enabled, but " - "CompilationConfig.static_forward_context is empty. 
" - "Cannot access attention layers so no fusion " - "patterns were registered.") + "Attention + quant fusion is enabled, but no attention layers " + "were found in CompilationConfig.static_forward_context " + "so no fusion patterns were registered.") def __call__(self, graph: torch.fx.graph.Graph) -> None: self.begin() self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + graph.eliminate_dead_code() + logger.debug("Fused quantization onto %s attention nodes", count) self.dump_graph(graph, "after_attn_fusion") self.end_and_log() diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 2e31b7bad747..edad25234fa8 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -174,23 +174,42 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]: def use_trtllm_attention( + num_qo_heads: int, + num_kv_heads: int, num_tokens: int, max_seq_len: int, kv_cache_dtype: str, - num_qo_heads: Optional[int], - num_kv_heads: Optional[int], - attn_head_size: Optional[int], has_sinks: bool = False, + enable_fusion: bool = False, ) -> bool: use_trtllm, env_value = supports_trtllm_attention() if not use_trtllm: return False - # Check if the dimensions are supported by TRTLLM decode attention - if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads % num_kv_heads != 0): + if num_qo_heads % num_kv_heads != 0: return False + if kv_cache_dtype.startswith("fp8"): + if enable_fusion: + logger.info_once("Using TRTLLM attention (FP8 kv cache required).") + return True + else: + # Remove this when TRTLLM attn kernel supports FP8 kv cache + # without attn+quant fusion. + raise ValueError("Flashinfer TRTLLM attention does not support " + "FP8 kv cache without attn+quant fusion. " + "Suggested the following configs: " + "(1) set kv_cache_dtype to 'auto', or " + "(2) turn on enable_attn_fusion, or " + "(3) disable Flashinfer backend") + elif kv_cache_dtype == "auto" and enable_fusion: + raise ValueError("Flashinfer TRTLLM attention does not support " + "auto kv cache dtype with attn+quant fusion. 
" + "Suggested the following configs: " + "(1) set kv_cache_dtype to 'fp8', or " + "(2) turn off enable_attn_fusion, or " + "(3) disable Flashinfer backend") + # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: @@ -290,6 +309,7 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", + "support_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", ] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 991904229fd7..0caa37379ea2 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,12 +15,17 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_attention +from vllm.utils.flashinfer import (support_trtllm_attention, + use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ -35,6 +40,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FP8_DTYPE = current_platform.fp8_dtype() + logger = init_logger(__name__) @@ -519,22 +526,24 @@ def build(self, else: kv_cache_dtype = self.kv_cache_spec.dtype - num_qo_heads = self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) + config = self.vllm_config + num_qo_heads = config.model_config.get_num_attention_heads( + config.parallel_config) num_kv_heads = self.kv_cache_spec.num_kv_heads head_dim = self.kv_cache_spec.head_size + # Check if attn+quant fusion is enabled (requires TRTLLM attention) + enable_fusion = config.compilation_config.pass_config.enable_attn_fusion + # Check if any layer uses sinks (requires TRTLLM attention) has_sinks = self.global_hyperparameters.has_sinks - # currently prefill trtllm attention does not support fp8 kv cache - prefill_use_trtllm = not cache_dtype.startswith("fp8") \ - and use_trtllm_attention( - num_prefill_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) + prefill_use_trtllm = use_trtllm_attention( + num_qo_heads, num_kv_heads, num_prefill_tokens, + max_seq_len, cache_dtype, has_sinks, enable_fusion) decode_use_trtllm = use_trtllm_attention( - num_decode_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) + num_qo_heads, num_kv_heads, num_decode_tokens, + max_seq_len, cache_dtype, has_sinks, enable_fusion) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, @@ -548,7 +557,7 @@ def build(self, head_dim=head_dim, page_size=page_size, kv_data_type=kv_cache_dtype, - q_data_type=self.vllm_config.model_config.dtype, + q_data_type=kv_cache_dtype, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, max_seq_len=max_seq_len, @@ -622,6 +631,8 @@ def __init__( self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) + self.window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) 
self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -644,6 +655,19 @@ def __init__( ) self.sinks = sinks + self.support_trtllm_attn = support_trtllm_attention(num_heads, + num_kv_heads) + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + + def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, + group_shape: GroupShape): + supported_quant_type = (dtype == FP8_DTYPE and static and + group_shape == GroupShape.PER_TENSOR) + return (self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and supported_quant_type) + def forward( self, layer: torch.nn.Module, @@ -672,15 +696,30 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - if attn_metadata is None: # Profiling run. return output + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is not None: + assert (attn_metadata.prefill_use_trtllm and + attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + assert output.dtype == FP8_DTYPE, \ + "output must be fp8 quantized when the attn fusion applied" + + # TRTLLM attn kernel requires o scale as a host scalar, store the + # o scale to host scalar in warmup run with cuda graph not enabled + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. 
Thus, we need to be careful about any CPU overhead @@ -718,9 +757,6 @@ def forward( self.kv_cache_dtype) kv_cache = kv_cache.view(torch_dtype) - window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) - # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] output_padded = output @@ -732,6 +768,14 @@ def forward( output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output + # Insert quant op for query + if attn_metadata.q_data_type == FP8_DTYPE: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -748,7 +792,7 @@ def forward( if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._window_left == self.window_left assert prefill_wrapper._logits_soft_cap == ( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale @@ -783,12 +827,12 @@ def forward( seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, batch_size=attn_metadata.num_prefills, cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, - window_left=window_left, + window_left=self.window_left, sinks=self.sinks, out=output[num_decode_tokens:], ) @@ -800,7 +844,7 @@ def forward( assert decode_wrapper is not None if not attn_metadata.decode_use_trtllm: - assert decode_wrapper._window_left == window_left + assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale @@ -815,8 +859,8 @@ def forward( # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer - block_tables_decode = attn_metadata.block_table_tensor[: - num_decode_tokens] + block_tables_decode = attn_metadata.\ + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -834,9 +878,9 @@ def forward( block_tables=block_tables_decode, seq_lens=seq_lens_decode, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, - window_left=window_left, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + window_left=self.window_left, sinks=self.sinks, out=output[:num_decode_tokens], ) From 093ccdcd1668b94908a33cae3cda7b562c06192a Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Thu, 14 Aug 2025 22:50:29 -0700 Subject: [PATCH 2/9] address fusion_attn comment Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- vllm/attention/backends/flashinfer.py | 11 ++++--- vllm/compilation/fusion_attn.py | 33 ++++++++++---------- vllm/utils/flashinfer.py | 32 +++++++------------- vllm/v1/attention/backends/flashinfer.py | 38 +++++++++++++----------- 4 files changed, 56 insertions(+), 58 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py 
b/vllm/attention/backends/flashinfer.py index 6b8691dff641..1ffa7bd9ad55 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1054,10 +1054,13 @@ def forward( # TODO: @pavanimajety Remove this once the switch happens # inside flashinfer. # see https://github.com/flashinfer-ai/flashinfer/issues/1493 - if not use_trtllm_attention( - attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, - num_decode_tokens, attn_metadata.max_decode_seq_len, - kv_cache_dtype): + if not use_trtllm_attention(attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + num_decode_tokens, + attn_metadata.max_decode_seq_len, + kv_cache_dtype, + decode_query.dtype, + is_prefill=False): decode_meta.decode_wrapper.run( decode_query, kv_cache.permute(*stride_order), diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 992467761c76..8e0177e089d4 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -38,6 +38,7 @@ def __init__( self, layer: Attention, quant_dtype: torch.dtype, + symmetric=True, ): self.layer = layer self.layer_name = layer.layer_name @@ -47,7 +48,7 @@ def __init__( self.quant_key = QuantKey(dtype=quant_dtype, static=True, group_shape=GroupShape.PER_TENSOR, - symmetric=True) + symmetric=symmetric) assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -65,28 +66,28 @@ def register_if_supported(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - attn_output: torch.Tensor, quant_scale: torch.Tensor, - quant_output: torch.Tensor): + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=attn_output, + output=output_attn, layer_name=self.layer_name, output_scale=None) - attn_output_view = RESHAPE_OP( - at1[1], [-1, self.num_heads * self.head_size]) + attn_out_view = RESHAPE_OP(at1[1], + [-1, self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, - result=quant_output, - input=attn_output_view, - scale=quant_scale) + result=output_quant, + input=attn_out_view, + scale=scale) return at2[1] def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - attn_output: torch.Tensor, quant_scale: torch.Tensor, - quant_output: torch.Tensor): - # attn out in quant_dtype - attn_output = torch.ops.aten.full.default( + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): + # attn output in quant_dtype + output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size], 0.0, dtype=self.quant_dtype, @@ -95,9 +96,9 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, query=q, key=k, value=v, - output=attn_output, + output=output_attn, layer_name=self.layer_name, - output_scale=quant_scale) + output_scale=scale) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. 
@@ -108,9 +109,9 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, empty_bf16(5, self.num_heads, self.head_size), # k empty_bf16(5, self.num_heads, self.head_size), # v empty_bf16(5, self.num_heads, self.head_size), # attn_output - empty_fp32(1, 1), # quant_scale self.empty_quant(5, self.num_heads * self.head_size), # quant_output + empty_fp32(1, 1) # scale ] def wrap_trace_fn(process_fx, trace_fn): diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index edad25234fa8..191218493a59 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -179,8 +179,9 @@ def use_trtllm_attention( num_tokens: int, max_seq_len: int, kv_cache_dtype: str, + q_dtype: torch.dtype, + is_prefill: bool, has_sinks: bool = False, - enable_fusion: bool = False, ) -> bool: use_trtllm, env_value = supports_trtllm_attention() if not use_trtllm: @@ -189,26 +190,15 @@ def use_trtllm_attention( if num_qo_heads % num_kv_heads != 0: return False - if kv_cache_dtype.startswith("fp8"): - if enable_fusion: - logger.info_once("Using TRTLLM attention (FP8 kv cache required).") - return True - else: - # Remove this when TRTLLM attn kernel supports FP8 kv cache - # without attn+quant fusion. - raise ValueError("Flashinfer TRTLLM attention does not support " - "FP8 kv cache without attn+quant fusion. " - "Suggested the following configs: " - "(1) set kv_cache_dtype to 'auto', or " - "(2) turn on enable_attn_fusion, or " - "(3) disable Flashinfer backend") - elif kv_cache_dtype == "auto" and enable_fusion: - raise ValueError("Flashinfer TRTLLM attention does not support " - "auto kv cache dtype with attn+quant fusion. " - "Suggested the following configs: " - "(1) set kv_cache_dtype to 'fp8', or " - "(2) turn off enable_attn_fusion, or " - "(3) disable Flashinfer backend") + # Must use TRTLLM attention if query is FP8 quantized + if q_dtype == current_platform.fp8_dtype(): + logger.info_once("Using TRTLLM attention (query is quantized).") + return True + + # TRTLLM prefill attention does not support FP8 kv cache with + # non-quantized query + if is_prefill and kv_cache_dtype.startswith("fp8"): + return False # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0caa37379ea2..5b6da3ba5962 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -532,18 +532,21 @@ def build(self, num_kv_heads = self.kv_cache_spec.num_kv_heads head_dim = self.kv_cache_spec.head_size - # Check if attn+quant fusion is enabled (requires TRTLLM attention) - enable_fusion = config.compilation_config.pass_config.enable_attn_fusion - # Check if any layer uses sinks (requires TRTLLM attention) has_sinks = self.global_hyperparameters.has_sinks + # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled + q_dtype = self.vllm_config.model_config.dtype + enable_fusion = config.compilation_config.pass_config.enable_attn_fusion + if cache_dtype.startswith("fp8") and enable_fusion: + q_dtype = kv_cache_dtype + prefill_use_trtllm = use_trtllm_attention( - num_qo_heads, num_kv_heads, num_prefill_tokens, - max_seq_len, cache_dtype, has_sinks, enable_fusion) + num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len, + cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks) decode_use_trtllm = use_trtllm_attention( - num_qo_heads, num_kv_heads, num_decode_tokens, - max_seq_len, cache_dtype, has_sinks, 
enable_fusion) + num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len, + cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, @@ -557,7 +560,7 @@ def build(self, head_dim=head_dim, page_size=page_size, kv_data_type=kv_cache_dtype, - q_data_type=kv_cache_dtype, + q_data_type=q_dtype, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, max_seq_len=max_seq_len, @@ -709,10 +712,12 @@ def forward( # The attn+quant fusion happens when output_scale is provided. if output_scale is not None: + assert attn_metadata.q_data_type == FP8_DTYPE, \ + "Planned q_dtype must be FP8 when attn+quant fusion applied" assert (attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" assert output.dtype == FP8_DTYPE, \ - "output must be fp8 quantized when the attn fusion applied" + "output dtype must be FP8 when attn+quant fusion applied" # TRTLLM attn kernel requires o scale as a host scalar, store the # o scale to host scalar in warmup run with cuda graph not enabled @@ -720,6 +725,13 @@ def forward( layer._o_scale_float = output_scale.cpu().item() self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + # Insert FP8 quant for query + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -768,14 +780,6 @@ def forward( output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output - # Insert quant op for query - if attn_metadata.q_data_type == FP8_DTYPE: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape((num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens From 455ad978870e4d15db935e2dbc2364287a90c1b7 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Fri, 15 Aug 2025 10:39:59 -0700 Subject: [PATCH 3/9] adress test issue Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- tests/compile/test_fusion_attn.py | 147 +++++++++++++++--------------- 1 file changed, 72 insertions(+), 75 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index bf6804a42245..9562447102be 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -8,6 +8,7 @@ from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm import LLM, SamplingParams from vllm.attention import Attention from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym @@ -18,12 +19,12 @@ ModelConfig, PassConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.forward_context import get_forward_context, set_forward_context -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + 
Fp8LinearOp) from vllm.platforms import current_platform from vllm.v1.attention.backends.flashinfer import FlashInferMetadataBuilder -from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -150,9 +151,15 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, class TestAttentionStaticQuantPatternModel(torch.nn.Module): """Test model for AttentionStaticQuantPattern fusion.""" - def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype, device: torch.device, - vllm_config: VllmConfig): + def __init__(self, + num_qo_heads: int, + num_kv_heads: int, + head_size: int, + kv_cache_dtype: torch.dtype, + device: torch.device, + vllm_config: VllmConfig, + w: Optional[torch.Tensor] = None, + kv_cache: Optional[torch.Tensor] = None): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -170,64 +177,57 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, prefix="model.layers.0.self_attn.attn", ) - self.quant_fp8 = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) - self.quant_fp8_scale = torch.tensor([1.0], dtype=torch.float32) + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) - def build_attn_metadata(self, batch_size: int): - """Initialize attention metadata.""" - query_start_loc = torch.arange(0, - batch_size + 1, - dtype=torch.int32, - device=self.device) - seq_lens = torch.ones(batch_size, - dtype=torch.int32, - device=self.device) - - # Create simple block table and slot mapping for testing - block_size = 16 - num_tokens = batch_size # num_tokens = batch_size for simplicity - num_blocks = max(1, (num_tokens + block_size - 1) // block_size) - block_table = torch.arange(num_blocks, - dtype=torch.int32, - device=self.device).unsqueeze(0).repeat( - batch_size, 1) - slot_mapping = torch.arange(batch_size, - dtype=torch.long, - device=self.device) - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_cpu=query_start_loc.cpu(), - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - num_computed_tokens_cpu=torch.zeros(batch_size, dtype=torch.int32), - num_reqs=batch_size, - num_actual_tokens=batch_size, - max_query_len=1, - block_table_tensor=block_table, - slot_mapping=slot_mapping, - ) + hidden_size = num_qo_heads * head_size + if w is not None: + self.w = w + else: + self.w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() + self.wscale = torch.tensor([1.0], dtype=torch.float32) + self.scale = torch.tensor([1.0], dtype=torch.float32) + + self.block_size = 16 + self.kv_cache = None + if kv_cache is not None: + self.kv_cache = kv_cache - # Mock the KV cache for FlashInfer TRTLLM - # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - # Create kv_cache in HND layout and permute to NHD layout - # (later will be permuted back to HND layout in forward pass) - kv_cache = torch.zeros(num_blocks, - 2, - self.num_kv_heads, - block_size, - self.head_size, - dtype=self.kv_cache_dtype, - device=self.device) - kv_cache = kv_cache.permute(0, 1, 3, 2, 4) - self.attn.kv_cache = [kv_cache] + def build_attn_metadata(self, batch_size: int): + """Initialize Flashinfer attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + 
common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + # Create kv cache + if self.kv_cache is None: + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + self.kv_cache = torch.randn(num_blocks, 2, self.num_kv_heads, + self.block_size, self.head_size).to( + dtype=self.kv_cache_dtype, + device=self.device) + self.kv_cache = self.kv_cache.permute(0, 1, 3, 2, 4) + + self.attn.kv_cache = [self.kv_cache] # Initialize FlashInferMetadataBuilder builder = FlashInferMetadataBuilder( kv_cache_spec=AttentionSpec( - block_size=block_size, + block_size=self.block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_dtype, @@ -247,9 +247,10 @@ def build_attn_metadata(self, batch_size: int): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - attn_output = attn_output.view(-1, self.num_qo_heads * self.head_size) - output, _ = self.quant_fp8(attn_output, self.quant_fp8_scale) - return output + return self.fp8_linear.apply(input=attn_output, + weight=self.w, + weight_scale=self.wscale, + input_scale=self.scale) @pytest.mark.parametrize("num_heads", [(64, 8), (40, 8)]) @@ -259,6 +260,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): @pytest.mark.parametrize( "model_quant_dtype", [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", FP8_DTYPE)]) +@pytest.mark.parametrize("test_backend", ["FLASHINFER"]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), @@ -266,13 +268,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, batch_size: int, dtype: torch.dtype, model_quant_dtype: tuple[str, torch.dtype], - monkeypatch, dist_init): - """Test AttentionStaticQuantPattern fusion pass with FlashInfer backend""" + test_backend, monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass""" - # Enable FlashInfer v1 backend for this test monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - monkeypatch.setenv("VLLM_USE_TRTLLM_ATTN", "1") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", test_backend) device = torch.device("cuda:0") torch.manual_seed(42) @@ -280,6 +280,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, num_qo_heads, num_kv_heads = num_heads model_name, quant_dtype = model_quant_dtype + # The quant op to check the fusion happenes or not quant_op = None if quant_dtype == FP8_DTYPE: quant_op = QUANT_OPS[kFp8StaticTensorSym] @@ -297,22 +298,19 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, custom_ops=["+quant_fp8"], full_cuda_graph=True, ), - ) + cache_config=CacheConfig(cache_dtype="fp8")) # Create test inputs q = torch.randn(batch_size, - num_qo_heads, - head_size, + num_qo_heads * head_size, dtype=dtype, 
device=device) k = torch.randn(batch_size, - num_kv_heads, - head_size, + num_kv_heads * head_size, dtype=dtype, device=device) v = torch.randn(batch_size, - num_kv_heads, - head_size, + num_kv_heads * head_size, dtype=dtype, device=device) @@ -326,7 +324,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, with set_current_vllm_config(vllm_config_unfused), set_forward_context( attn_metadata=None, vllm_config=vllm_config_unfused): model_unfused = TestAttentionStaticQuantPatternModel( - num_qo_heads, num_kv_heads, head_size, dtype, device, + num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, vllm_config_unfused) model_unfused = model_unfused.to(device) @@ -340,12 +338,11 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( enable_attn_fusion=True, enable_noop=True) - vllm_config.cache_config = CacheConfig(cache_dtype="fp8") with set_current_vllm_config(vllm_config), set_forward_context( attn_metadata=None, vllm_config=vllm_config): model_fused = TestAttentionStaticQuantPatternModel( num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config) + vllm_config, model_unfused.w, model_unfused.kv_cache) model_fused = model_fused.to(device) forward_ctx = get_forward_context() From 6861f8a1e6489b4195babbaa94031d079aacda13 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Fri, 15 Aug 2025 11:57:25 -0700 Subject: [PATCH 4/9] address core parts Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- vllm/compilation/fusion_attn.py | 9 +++------ vllm/utils/flashinfer.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 8e0177e089d4..7d43b1d175e6 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -146,17 +146,14 @@ class AttnFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): super().__init__(config) - attn_layers_count = 0 self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for layer_name, layer in get_layers_from_vllm_config( - config, Attention).items(): + attn_layers = get_layers_from_vllm_config(config, Attention) + for layer_name, layer in attn_layers.items(): pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) pattern.register_if_supported(self.patterns) - attn_layers_count += 1 - - if attn_layers_count == 0: + if len(attn_layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " "were found in CompilationConfig.static_forward_context " diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 191218493a59..996be1265667 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -299,7 +299,7 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", - "support_trtllm_attention", + "supports_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", ] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5b6da3ba5962..c56e721dff8c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -24,7 +24,7 @@ GroupShape) from vllm.platforms import current_platform from vllm.utils import cdiv, 
is_pin_memory_available -from vllm.utils.flashinfer import (support_trtllm_attention, +from vllm.utils.flashinfer import (supports_trtllm_attention, use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block @@ -536,7 +536,7 @@ def build(self, has_sinks = self.global_hyperparameters.has_sinks # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled - q_dtype = self.vllm_config.model_config.dtype + q_dtype = config.model_config.dtype enable_fusion = config.compilation_config.pass_config.enable_attn_fusion if cache_dtype.startswith("fp8") and enable_fusion: q_dtype = kv_cache_dtype @@ -658,8 +658,8 @@ def __init__( ) self.sinks = sinks - self.support_trtllm_attn = support_trtllm_attention(num_heads, - num_kv_heads) + self.support_trtllm_attn = (supports_trtllm_attention() and + num_heads % num_kv_heads == 0) self.bmm1_scale: Optional[float] = None self.bmm2_scale: Optional[float] = None @@ -711,13 +711,16 @@ def forward( self.bmm2_scale = layer._v_scale_float # The attn+quant fusion happens when output_scale is provided. - if output_scale is not None: + if output_scale is None: + assert attn_metadata.q_data_type != FP8_DTYPE, \ + "Query can only be FP8 if output fusion happened." + else: assert attn_metadata.q_data_type == FP8_DTYPE, \ - "Planned q_dtype must be FP8 when attn+quant fusion applied" + "Query must be FP8 when attn+quant fusion happened." assert (attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" assert output.dtype == FP8_DTYPE, \ - "output dtype must be FP8 when attn+quant fusion applied" + "Output must be FP8 when attn+quant fusion happened." # TRTLLM attn kernel requires o scale as a host scalar, store the # o scale to host scalar in warmup run with cuda graph not enabled From 9d4083ea6db3052f9cad925b2712309b4f1c9cc5 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:29:04 -0700 Subject: [PATCH 5/9] address fusion test issue Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- tests/compile/test_fusion_attn.py | 171 ++++++++++++++++-------------- 1 file changed, 92 insertions(+), 79 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 9562447102be..f9f8e2e7e775 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -8,7 +8,9 @@ from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal -from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + get_attention_backend) from vllm import LLM, SamplingParams from vllm.attention import Attention from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym @@ -24,7 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform -from vllm.v1.attention.backends.flashinfer import FlashInferMetadataBuilder +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -151,15 +153,10 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, class TestAttentionStaticQuantPatternModel(torch.nn.Module): """Test model for AttentionStaticQuantPattern fusion.""" - def __init__(self, - 
num_qo_heads: int, - num_kv_heads: int, - head_size: int, - kv_cache_dtype: torch.dtype, - device: torch.device, - vllm_config: VllmConfig, - w: Optional[torch.Tensor] = None, - kv_cache: Optional[torch.Tensor] = None): + def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + attn_builder: type[AttentionMetadataBuilder], + vllm_config: VllmConfig): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -179,53 +176,13 @@ def __init__(self, self.fp8_linear = Fp8LinearOp( act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) - - hidden_size = num_qo_heads * head_size - if w is not None: - self.w = w - else: - self.w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() self.wscale = torch.tensor([1.0], dtype=torch.float32) self.scale = torch.tensor([1.0], dtype=torch.float32) self.block_size = 16 - self.kv_cache = None - if kv_cache is not None: - self.kv_cache = kv_cache - - def build_attn_metadata(self, batch_size: int): - """Initialize Flashinfer attention metadata.""" - # Create common attn metadata - batch_spec = BatchSpec(seq_lens=[1] * batch_size, - query_lens=[1] * batch_size) - common_attn_metadata = create_common_attn_metadata( - batch_spec, - self.block_size, - self.device, - arange_block_indices=True) - - # Create kv cache - if self.kv_cache is None: - max_blocks = (max(batch_spec.seq_lens) + self.block_size - - 1) // self.block_size - num_blocks = batch_size * max_blocks - - # Create dummy KV cache for FlashInfer TRTLLM - # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - # Create kv_cache in HND layout and permute to NHD layout - # (later will be permuted back to HND layout in forward pass) - self.kv_cache = torch.randn(num_blocks, 2, self.num_kv_heads, - self.block_size, self.head_size).to( - dtype=self.kv_cache_dtype, - device=self.device) - self.kv_cache = self.kv_cache.permute(0, 1, 3, 2, 4) - - self.attn.kv_cache = [self.kv_cache] - - # Initialize FlashInferMetadataBuilder - builder = FlashInferMetadataBuilder( + # Initialize attn MetadataBuilder + self.builder = attn_builder( kv_cache_spec=AttentionSpec( block_size=self.block_size, num_kv_heads=self.num_kv_heads, @@ -238,17 +195,49 @@ def build_attn_metadata(self, batch_size: int): device=self.device, ) - # Build FlashInferMetadata - self.attn_metadata = builder.build( + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = [kv_cache] + + # Build attn metadata + self.attn_metadata = self.builder.build( 
common_prefix_len=0, common_attn_metadata=common_attn_metadata) return self.attn_metadata - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + w: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) return self.fp8_linear.apply(input=attn_output, - weight=self.w, + weight=w, weight_scale=self.wscale, input_scale=self.scale) @@ -260,7 +249,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): @pytest.mark.parametrize( "model_quant_dtype", [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", FP8_DTYPE)]) -@pytest.mark.parametrize("test_backend", ["FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASHINFER"]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), @@ -268,11 +257,11 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, batch_size: int, dtype: torch.dtype, model_quant_dtype: tuple[str, torch.dtype], - test_backend, monkeypatch, dist_init): + backend: str, monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", test_backend) + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) device = torch.device("cuda:0") torch.manual_seed(42) @@ -280,10 +269,14 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, num_qo_heads, num_kv_heads = num_heads model_name, quant_dtype = model_quant_dtype + if backend == "FLASHINFER": + attn_builder, _ = get_attention_backend(_Backend.FLASHINFER_VLLM_V1) + else: + raise ValueError(f"Unsupported backend: {backend}") + # The quant op to check the fusion happenes or not - quant_op = None if quant_dtype == FP8_DTYPE: - quant_op = QUANT_OPS[kFp8StaticTensorSym] + quant_key = kFp8StaticTensorSym else: raise ValueError(f"Unsupported quant_dtype: {quant_dtype}") @@ -296,15 +289,12 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+quant_fp8"], - full_cuda_graph=True, ), cache_config=CacheConfig(cache_dtype="fp8")) # Create test inputs - q = torch.randn(batch_size, - num_qo_heads * head_size, - dtype=dtype, - device=device) + hidden_size = num_qo_heads * head_size + q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, @@ -313,6 +303,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, num_kv_heads * head_size, dtype=dtype, device=device) + linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() # Mark first dimension as dynamic for realistic testing torch._dynamo.mark_dynamic(q, 0) @@ -325,7 +316,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, attn_metadata=None, vllm_config=vllm_config_unfused): model_unfused = TestAttentionStaticQuantPatternModel( num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config_unfused) + attn_builder, vllm_config_unfused) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() @@ -333,7 +324,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, 
batch_size) # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v) + result_unfused = model_unfused(q, k, v, linear_w) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( @@ -342,7 +333,7 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, attn_metadata=None, vllm_config=vllm_config): model_fused = TestAttentionStaticQuantPatternModel( num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config, model_unfused.w, model_unfused.kv_cache) + attn_builder, vllm_config) model_fused = model_fused.to(device) forward_ctx = get_forward_context() @@ -352,20 +343,38 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, noop_pass = NoOpEliminationPass(vllm_config) attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw ) - backend = TestBackend(noop_pass, attn_pass) + test_backend = TestBackend(noop_pass, attn_pass) # Compile model with fusion enabled model_compiled = torch.compile(model_fused, - backend=backend, + backend=test_backend, fullgraph=True) - result_fused = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v, linear_w) + + # After the 1st round of the forward pass, output quant scale should be + # loaded into the attn layer's _o_scale_float, the 2nd round should + # reuse the loaded _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v, linear_w) + assert model_compiled.attn._o_scale_float is not None - # Check quantization ops in the graph before and after fusion - backend.check_before_ops([quant_op], fully_replaced=True) + # Check attn fusion support + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key.dtype, + quant_key.static, + quant_key.group_shape) for key, + layer in vllm_config.compilation_config.static_forward_context.items() + ] + if any(attn_fusion_supported): + # Check quantization ops in the graph before and after fusion + test_backend.check_before_ops([QUANT_OPS[quant_key]], + fully_replaced=True) # Check attention ops in the graph before and after fusion - attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass)) - attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass)) + attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, + test_backend.graph_post_pass)) assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" assert len(attn_nodes_pre) == len(attn_nodes_post), \ @@ -376,7 +385,11 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, "Attention should have output_scale after fusion" # Check that results are closed - torch.testing.assert_close(result_unfused.to(dtype), - result_fused.to(dtype), + torch.testing.assert_close(result_unfused, + result_fused_1, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(result_unfused, + result_fused_2, atol=1e-2, rtol=1e-2) From 01e62983ff81b84071e49ab86b215432e8d7d06a Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 18 Aug 2025 01:51:36 -0700 Subject: [PATCH 6/9] add github issue for graph.eliminate_dead_code() Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- vllm/compilation/fusion_attn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/compilation/fusion_attn.py 
b/vllm/compilation/fusion_attn.py index 7d43b1d175e6..1f77a2667613 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -164,6 +164,10 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + + # TODO: Move this to pass_manager.py after the fx graph broken issue + # has been resolved. + # see https://github.com/vllm-project/vllm/issues/23091 graph.eliminate_dead_code() logger.debug("Fused quantization onto %s attention nodes", count) From a178cacc434d186044ebe6300d122a2106cbd49d Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 18 Aug 2025 07:18:20 -0700 Subject: [PATCH 7/9] fix fusion test Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- tests/compile/test_fusion_attn.py | 42 +++++++++++++------------------ 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index f9f8e2e7e775..99d746281df1 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -9,10 +9,10 @@ from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, - get_attention_backend) + create_common_attn_metadata) from vllm import LLM, SamplingParams from vllm.attention import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes @@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -155,7 +154,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module): def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype, device: torch.device, - attn_builder: type[AttentionMetadataBuilder], vllm_config: VllmConfig): super().__init__() self.num_qo_heads = num_qo_heads @@ -182,7 +180,7 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, self.block_size = 16 # Initialize attn MetadataBuilder - self.builder = attn_builder( + self.builder = self.attn.attn_backend.get_builder_cls()( kv_cache_spec=AttentionSpec( block_size=self.block_size, num_kv_heads=self.num_kv_heads, @@ -242,38 +240,30 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_scale=self.scale) -@pytest.mark.parametrize("num_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize( - "model_quant_dtype", + "model_name, quant_dtype", [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", FP8_DTYPE)]) -@pytest.mark.parametrize("backend", ["FLASHINFER"]) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") 
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), reason="Only test on SM100(Blackwell)") -def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, - batch_size: int, dtype: torch.dtype, - model_quant_dtype: tuple[str, torch.dtype], - backend: str, monkeypatch, dist_init): +def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, + head_size: int, batch_size: int, + dtype: torch.dtype, model_name: str, + quant_dtype: torch.dtype, backend: _Backend, + monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) device = torch.device("cuda:0") torch.manual_seed(42) - num_qo_heads, num_kv_heads = num_heads - model_name, quant_dtype = model_quant_dtype - - if backend == "FLASHINFER": - attn_builder, _ = get_attention_backend(_Backend.FLASHINFER_VLLM_V1) - else: - raise ValueError(f"Unsupported backend: {backend}") - # The quant op to check the fusion happenes or not if quant_dtype == FP8_DTYPE: quant_key = kFp8StaticTensorSym @@ -313,10 +303,11 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, # Run model directly without compilation and fusion vllm_config_unfused = copy.deepcopy(vllm_config) with set_current_vllm_config(vllm_config_unfused), set_forward_context( - attn_metadata=None, vllm_config=vllm_config_unfused): + attn_metadata=None, vllm_config=vllm_config_unfused + ), global_force_attn_backend_context_manager(backend): model_unfused = TestAttentionStaticQuantPatternModel( num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - attn_builder, vllm_config_unfused) + vllm_config_unfused) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() @@ -330,10 +321,11 @@ def test_attention_quant_pattern(num_heads: tuple[int, int], head_size: int, vllm_config.compilation_config.pass_config = PassConfig( enable_attn_fusion=True, enable_noop=True) with set_current_vllm_config(vllm_config), set_forward_context( - attn_metadata=None, vllm_config=vllm_config): + attn_metadata=None, vllm_config=vllm_config + ), global_force_attn_backend_context_manager(backend): model_fused = TestAttentionStaticQuantPatternModel( num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - attn_builder, vllm_config) + vllm_config) model_fused = model_fused.to(device) forward_ctx = get_forward_context() From 0db5d8cc6f84c8b43c2720fae1a24132dfa210be Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 18 Aug 2025 08:17:54 -0700 Subject: [PATCH 8/9] fix comment Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- .../kernels/benchmark_trtllm_decode_attention.py | 14 -------------- .../kernels/benchmark_trtllm_prefill_attention.py | 14 -------------- tests/compile/test_fusion_attn.py | 12 +++--------- 3 files changed, 3 insertions(+), 37 deletions(-) diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 694e91aa8d0a..ad537e6697c2 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -162,22 +162,8 @@ def trtllm_decode(): ) baseline_mean, baseline_std = time_fn(baseline_decode) - if o_quant_dtype == FP8_DTYPE: - _, o_scale = to_float8(output_baseline) trtllm_mean, trtllm_std = time_fn(trtllm_decode) - # if o_quant_dtype == FP8_DTYPE: - # output_trtllm = output_trtllm.to(dtype) * 
o_scale - - # if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - # rtol, atol = 5e-2, 7e-2 - # else: - # rtol, atol = 1e-2, 1e-2 - - # torch.testing.assert_close( - # output_baseline, output_trtllm, atol=atol, rtol=rtol), \ - # f"{torch.max(torch.abs(output_baseline - output_trtllm))}" - # Calculate percentage speedup (positive means TRT is faster) speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index c0f10b387f00..49810e20c7d8 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -176,22 +176,8 @@ def trtllm_prefill(): ) baseline_mean, baseline_std = time_fn(baseline_prefill) - if o_quant_dtype == FP8_DTYPE: - _, o_scale = to_float8(output_baseline) trtllm_mean, trtllm_std = time_fn(trtllm_prefill) - # if o_quant_dtype == FP8_DTYPE: - # output_trtllm = output_trtllm.to(dtype) * o_scale - - # if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - # rtol, atol = 5e-2, 7e-2 - # else: - # rtol, atol = 1e-2, 1e-2 - - # torch.testing.assert_close( - # output_baseline, output_trtllm, atol=atol, rtol=rtol), \ - # f"{torch.max(torch.abs(output_baseline - output_trtllm))}" - # Calculate percentage speedup (positive means TRT is faster) speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 99d746281df1..bef0fdef985e 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -245,8 +245,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize( - "model_name, quant_dtype", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", FP8_DTYPE)]) + "model_name, quant_key", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)]) @pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @@ -255,7 +255,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, - quant_dtype: torch.dtype, backend: _Backend, + quant_key: QuantKey, backend: _Backend, monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" @@ -264,12 +264,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, device = torch.device("cuda:0") torch.manual_seed(42) - # The quant op to check the fusion happenes or not - if quant_dtype == FP8_DTYPE: - quant_key = kFp8StaticTensorSym - else: - raise ValueError(f"Unsupported quant_dtype: {quant_dtype}") - vllm_config = VllmConfig( model_config=ModelConfig( model=model_name, From 78befbf075a63c4ee40417acc658d6d5a7ca67e3 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 18 Aug 2025 18:14:42 -0700 Subject: [PATCH 9/9] add mix quant dtype for decode input Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- benchmarks/kernels/benchmark_trtllm_decode_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index ad537e6697c2..b3f81715461b 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -235,6 +235,7 @@ def write_results_to_csv(results, filename=None): quant_dtypes = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (None, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), ]
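Taken together, the series gates TRT-LLM attention on head-count divisibility, the (possibly FP8-quantized) query dtype, the KV-cache dtype, and the presence of attention sinks. The sketch below condenses that gating into one standalone function; it is illustrative only: the platform/environment probe from supports_trtllm_attention() and the remaining token-count heuristics of the real helper are elided, and the function name is not part of the patch.

    def should_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int,
                                    kv_cache_dtype: str, q_is_fp8: bool,
                                    is_prefill: bool, has_sinks: bool) -> bool:
        # GQA/MQA layouts must divide evenly between query and KV heads.
        if num_qo_heads % num_kv_heads != 0:
            return False
        # An FP8 query (attn+quant fusion planned an FP8 q_data_type) can only
        # be consumed by the TRT-LLM kernels.
        if q_is_fp8:
            return True
        # TRT-LLM prefill does not take an FP8 KV cache with an unquantized query.
        if is_prefill and kv_cache_dtype.startswith("fp8"):
            return False
        # Attention sinks are only supported by the TRT-LLM path.
        if has_sinks:
            return True
        # Remaining size-based heuristics of the real helper are omitted here.
        return False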
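When the planned q_data_type is FP8 (fusion enabled together with an FP8 KV cache), FlashInferImpl.forward statically quantizes the query per tensor on the flattened (tokens, heads * head_size) view before calling the TRT-LLM kernels. Below is a minimal torch-only stand-in for that step under a static per-tensor scale; vLLM itself uses ops.scaled_fp8_quant, so the helper and the literal scale here are only illustrative.

    import torch

    FP8_DTYPE = torch.float8_e4m3fn

    def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
        # Flatten heads, quantize with the static per-tensor scale, reshape back.
        num_tokens, num_heads, head_size = query.shape
        flat = query.reshape(num_tokens, num_heads * head_size).contiguous()
        finfo = torch.finfo(FP8_DTYPE)
        q_fp8 = (flat / q_scale).clamp(finfo.min, finfo.max).to(FP8_DTYPE)
        return q_fp8.reshape(num_tokens, num_heads, head_size)

    query = torch.randn(4, 8, 128, dtype=torch.bfloat16)
    q_fp8 = quantize_query_fp8(query, torch.tensor(0.5))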
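The TRT-LLM kernels take their dequant/requant factors as two host-side scalars: bmm1_scale folds the query and key dequant scales into the softmax scale, while bmm2_scale carries the value dequant scale and is further divided by the output scale once the fused FP8 output path has captured layer._o_scale_float during warmup. A small sketch of that folding, with made-up scale values:

    def fold_trtllm_scales(q_scale: float, k_scale: float, v_scale: float,
                           softmax_scale: float,
                           o_scale: float | None = None) -> tuple[float, float]:
        # bmm1 (Q @ K^T): query/key dequant factors fold into the softmax scale.
        bmm1_scale = q_scale * k_scale * softmax_scale
        # bmm2 (P @ V): value dequant factor, divided by the output quant scale
        # when the attention output is produced directly in FP8.
        bmm2_scale = v_scale if o_scale is None else v_scale / o_scale
        return bmm1_scale, bmm2_scale

    head_size = 128
    bmm1, bmm2 = fold_trtllm_scales(q_scale=0.5, k_scale=0.25, v_scale=2.0,
                                    softmax_scale=head_size ** -0.5, o_scale=0.125)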