diff --git a/benchmarks/kernels/benchmark_trtllm_attention.py b/benchmarks/kernels/benchmark_trtllm_attention.py
index 8c980f930366..68c48858e61c 100644
--- a/benchmarks/kernels/benchmark_trtllm_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_attention.py
@@ -71,22 +71,20 @@ def benchmark_decode(
     if kv_cache_dtype.startswith("fp8"):
         kv_cache, _ = to_float8(kv_cache)
 
+    output_trtllm = torch.empty(q.shape, dtype=dtype)
+
     # Benchmark TRT decode
     def trt_decode():
         return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
             q,
             kv_cache,
             workspace_buffer,
-            num_qo_heads,
-            num_kv_heads,
-            sm_scale,
             block_tables,
             kv_lens_tensor,
-            page_size,
             max_kv_len,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
+            bmm1_scale=k_scale * sm_scale,
+            bmm2_scale=v_scale,
+            out=output_trtllm,
         )
 
     def time_fn(fn, warmup=10, trials=20):
@@ -125,6 +123,8 @@ def time_fn(fn, warmup=10, trials=20):
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
+    output_baseline = torch.empty(q.shape, dtype=dtype)
+
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -145,7 +145,7 @@ def time_fn(fn, warmup=10, trials=20):
     )
 
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
+        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
 
     baseline_mean, baseline_std = time_fn(baseline_decode)
 
@@ -214,25 +214,39 @@ def write_results_to_csv(results, filename=None):
     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     all_results = []
 
-    print("Running benchmark for kv_cache_dtype: bfloat16")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
     )
 
     for max_seq_len in max_seq_lens:
         for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="auto",
             )
             all_results.append(result)
 
-    print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
     )
 
     for max_seq_len in max_seq_lens:
        for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="fp8",
             )
             all_results.append(result)
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
index 96eee13695a9..2e2130fab6a2 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
                  kv_data_type=dtype,
                  logits_soft_cap=soft_cap)
 
-    output = wrapper.run(query, key_value_cache, scale)
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query, key_value_cache, scale, out=output)
 
     # TRTLLM Decode
     max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens,
                                   dtype=torch.int,
                                   device=query.device)
 
-    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query.contiguous(),
         key_value_cache,
         workspace_buffer,
-        num_query_heads,
-        num_kv_heads,
-        scale,
         block_tables,
         kv_lens_tensor,
-        block_size,
         max_kv_len,
-        "auto",
-        k_scale,
-        v_scale,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
     )
 
     torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
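
For reference, the updated FlashInfer entry point is driven as sketched below. This is a minimal sketch, not vLLM code: the device, shapes, page counts, workspace size, and unit scales are illustrative assumptions, and the paged-cache shape assumes the HND layout asserted elsewhere in this diff. Only the argument structure is taken from the call sites above: head counts, softmax scale, page size, and kv-cache dtype are no longer passed positionally, and the caller now supplies keyword bmm1_scale/bmm2_scale plus a preallocated out buffer.

    # Sketch only: shapes, page counts, and unit scales are assumptions.
    import torch
    import flashinfer

    num_seqs, num_qo_heads, num_kv_heads, head_dim = 4, 32, 8, 128
    page_size, pages_per_seq = 16, 32
    max_kv_len = page_size * pages_per_seq  # 512
    sm_scale = head_dim**-0.5
    k_scale = v_scale = 1.0

    q = torch.randn(num_seqs, num_qo_heads, head_dim,
                    dtype=torch.bfloat16, device="cuda")
    # Paged KV cache, assumed HND layout:
    # (num_pages, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(num_seqs * pages_per_seq, 2, num_kv_heads,
                           page_size, head_dim,
                           dtype=torch.bfloat16, device="cuda")
    workspace_buffer = torch.zeros(128 * 1024 * 1024,
                                   dtype=torch.uint8, device="cuda")
    block_tables = torch.arange(num_seqs * pages_per_seq,
                                dtype=torch.int32,
                                device="cuda").view(num_seqs, pages_per_seq)
    kv_lens = torch.full((num_seqs,), max_kv_len,
                         dtype=torch.int32, device="cuda")

    # The kernel fills a caller-provided buffer instead of returning a
    # freshly allocated tensor.
    output = torch.empty_like(q)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        q,
        kv_cache,
        workspace_buffer,
        block_tables,
        kv_lens,
        max_kv_len,
        bmm1_scale=k_scale * sm_scale,
        bmm2_scale=v_scale,
        out=output,
    )

Folding k_scale into bmm1_scale and v_scale into bmm2_scale mirrors how both the benchmark and the backends below pass quantization scales.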
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index e6e60e756248..824ff8cca201 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -1104,7 +1104,12 @@ def forward(
         window_left = window_size[0] if window_size is not None else -1
 
         prefill_output: Optional[torch.Tensor] = None
-        decode_output: Optional[torch.Tensor] = None
+        if num_decode_tokens > 0:
+            decode_output = torch.empty(decode_query.shape,
+                                        dtype=decode_query.dtype,
+                                        device=decode_query.device)
+        else:
+            decode_output = None
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
@@ -1155,17 +1160,18 @@ def forward(
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                decode_output = decode_meta.decode_wrapper.run(
+                decode_meta.decode_wrapper.run(
                     decode_query,
                     kv_cache.permute(*stride_order),
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
             else:
                 workspace_buffer = (
-                    decode_meta.decode_wrapper._int_workspace_buffer)
+                    decode_meta.decode_wrapper._float_workspace_buffer)
                 assert FlashInferState.get_kv_cache_layout() == "HND"
-                decode_output = trtllm_batch_decode_with_kv_cache(
+                trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache.permute(*stride_order),
                     workspace_buffer=workspace_buffer,
@@ -1174,6 +1180,7 @@ def forward(
                     max_seq_len=attn_metadata.max_decode_seq_len,
                     bmm1_scale=layer._k_scale_float * softmax_scale,
                     bmm2_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
 
         if prefill_output is None and decode_output is not None:
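
Both decode branches above now share one contract: decode_output is allocated once up front, and whichever kernel runs fills it in place through out= rather than returning a fresh tensor (the TRT-LLM path also switches to the wrapper's _float_workspace_buffer). A pure-PyTorch sketch of that contract follows; run_kernel is a hypothetical stand-in for either FlashInfer call, using torch.exp because it supports out=.

    import torch

    def run_kernel(x: torch.Tensor, out: torch.Tensor) -> None:
        # Hypothetical stand-in for decode_wrapper.run(..., out=...) or
        # trtllm_batch_decode_with_kv_cache(..., out=...): the result is
        # written in place and no return value is needed.
        torch.exp(x, out=out)

    num_decode_tokens = 8
    decode_query = torch.randn(num_decode_tokens, 32, 128)

    # Allocate once, before branching, as the rewritten forward() does.
    decode_output = (torch.empty(decode_query.shape,
                                 dtype=decode_query.dtype,
                                 device=decode_query.device)
                     if num_decode_tokens > 0 else None)
    if decode_output is not None:
        run_kernel(decode_query, out=decode_output)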
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index b72745ef156e..775780807eae 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ def build(self,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ def forward(
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
+            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -666,22 +664,24 @@ def forward(
                                                                        num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
 
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
-
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                assert workspace_buffer.is_contiguous()
+
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
 
         return output_padded
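
The v1 path goes one step further: out= receives a slice, and because output[:num_decode_tokens] is a view of output, the kernel writes decode results straight into the shared batch buffer. The old form, output[:num_decode_tokens] = trtllm_batch_decode_with_kv_cache(...), paid for a temporary tensor plus a copy back into the slice. A self-contained sketch of the view-writing pattern, with torch.exp standing in for the decode kernel:

    import torch

    num_tokens, num_decode_tokens, hidden = 10, 4, 8
    output = torch.zeros(num_tokens, hidden)
    decode_query = torch.randn(num_decode_tokens, hidden)

    # out= writes through the slice, which is a view of `output`: the
    # decode results land in the shared buffer with no intermediate
    # allocation and no copy-back.
    torch.exp(decode_query, out=output[:num_decode_tokens])
    assert output[:num_decode_tokens].data_ptr() == output.data_ptr()
    assert torch.equal(output[:num_decode_tokens], decode_query.exp())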