[Bugfix] Fix workspace buffer None issue for Flashinfer TRTLLM Backend #21525

Merged
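Across the three touched files the fix follows a single pattern: preallocate the decode output tensor, pass it to FlashInfer through `out=` instead of consuming the kernel's return value, take the TRTLLM workspace from the decode wrapper's `_float_workspace_buffer`, and supply the fused scales as `bmm1_scale` / `bmm2_scale`. A minimal sketch of that call shape is below; the standalone helper and its parameter list are illustrative only, while the keyword arguments mirror the diff that follows.

```python
# Illustrative sketch of the call pattern introduced by this PR. The helper
# function itself is hypothetical; the keyword arguments match the diff below.
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache


def run_trtllm_decode(decode_query, kv_cache, decode_wrapper, block_tables,
                      seq_lens, max_seq_len, k_scale, v_scale, softmax_scale):
    # Preallocate the output so the TRTLLM kernel writes into a caller-owned
    # buffer rather than the caller depending on a returned tensor.
    decode_output = torch.empty(decode_query.shape,
                                dtype=decode_query.dtype,
                                device=decode_query.device)
    # Take the workspace directly from the decode wrapper instead of a
    # separately threaded-through tensor.
    workspace_buffer = decode_wrapper._float_workspace_buffer
    trtllm_batch_decode_with_kv_cache(
        query=decode_query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        bmm1_scale=k_scale * softmax_scale,  # K dequant scale folded into the softmax scale
        bmm2_scale=v_scale,                  # V dequant scale
        out=decode_output,                   # kernel fills this buffer in place
    )
    return decode_output
```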
42 changes: 28 additions & 14 deletions benchmarks/kernels/benchmark_trtllm_attention.py
@@ -71,22 +71,20 @@ def benchmark_decode(
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)

output_trtllm = torch.empty(q.shape, dtype=dtype)

# Benchmark TRT decode
def trt_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q,
kv_cache,
workspace_buffer,
num_qo_heads,
num_kv_heads,
sm_scale,
block_tables,
kv_lens_tensor,
page_size,
max_kv_len,
kv_cache_dtype,
k_scale,
v_scale,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
out=output_trtllm,
)

def time_fn(fn, warmup=10, trials=20):
@@ -125,6 +123,8 @@ def time_fn(fn, warmup=10, trials=20):
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

output_baseline = torch.empty(q.shape, dtype=dtype)

wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
@@ -145,7 +145,7 @@ def time_fn(fn, warmup=10, trials=20):
)

def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)

baseline_mean, baseline_std = time_fn(baseline_decode)

@@ -214,25 +214,39 @@ def write_results_to_csv(results, filename=None):
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []

print("Running benchmark for kv_cache_dtype: bfloat16")
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)

print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
)
all_results.append(result)

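The hunks above call a `time_fn(fn, warmup=10, trials=20)` helper whose body lies outside this diff. For orientation, here is one plausible CUDA-event-based implementation — purely an illustrative assumption, not the file's actual code. The relevant point for this PR is that `trt_decode` and `baseline_decode` now reuse preallocated output buffers, so per-call allocation stays out of the timed trials.

```python
# Hypothetical sketch of a time_fn(fn, warmup=10, trials=20) helper; the real
# implementation in benchmark_trtllm_attention.py is not shown in this diff.
import torch

def time_fn(fn, warmup=10, trials=20):
    # Warm up so one-time compilation/caching does not pollute the measurement.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    times_ms = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(trials):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))  # milliseconds

    t = torch.tensor(times_ms)
    return t.mean().item(), t.std().item()
```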
@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_data_type=dtype,
logits_soft_cap=soft_cap)

output = wrapper.run(query, key_value_cache, scale)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query, key_value_cache, scale, out=output)

# TRTLLM Decode
max_kv_len = max(kv_lens)
kv_lens_tensor = torch.tensor(kv_lens,
dtype=torch.int,
device=query.device)
output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query.contiguous(),
key_value_cache,
workspace_buffer,
num_query_heads,
num_kv_heads,
scale,
block_tables,
kv_lens_tensor,
block_size,
max_kv_len,
"auto",
k_scale,
v_scale,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
out=output_trtllm,
)

torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
15 changes: 11 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -1104,7 +1104,12 @@ def forward(
window_left = window_size[0] if window_size is not None else -1

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if num_decode_tokens > 0:
decode_output = torch.empty(decode_query.shape,
dtype=decode_query.dtype,
device=decode_query.device)
else:
decode_output = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
@@ -1155,17 +1160,18 @@ def forward(
num_decode_tokens, attn_metadata.max_decode_seq_len,
kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
decode_output = decode_meta.decode_wrapper.run(
decode_meta.decode_wrapper.run(
decode_query,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=decode_output,
)
else:
workspace_buffer = (
decode_meta.decode_wrapper._int_workspace_buffer)
decode_meta.decode_wrapper._float_workspace_buffer)
assert FlashInferState.get_kv_cache_layout() == "HND"
decode_output = trtllm_batch_decode_with_kv_cache(
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache.permute(*stride_order),
workspace_buffer=workspace_buffer,
@@ -1174,6 +1180,7 @@
max_seq_len=attn_metadata.max_decode_seq_len,
bmm1_scale=layer._k_scale_float * softmax_scale,
bmm2_scale=layer._v_scale_float,
out=decode_output,
)

if prefill_output is None and decode_output is not None:
30 changes: 15 additions & 15 deletions vllm/v1/attention/backends/flashinfer.py
@@ -194,7 +194,6 @@ class FlashInferMetadata:
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
workspace_buffer: torch.Tensor

# For handling prefill decode split
num_decodes: int
@@ -473,7 +472,6 @@ def build(self,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
workspace_buffer=self._get_workspace_buffer(),
)

self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ def forward(
if decode_wrapper := attn_metadata.decode_wrapper:
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not FlashInferBackend.use_trtllm_decode_attention(
attn_metadata.num_decodes, attn_metadata.max_seq_len,
self.kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
assert decode_wrapper is not None
assert decode_wrapper._window_left == window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
@@ -666,22 +664,24 @@ def forward(
num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:
num_decode_tokens]
workspace_buffer = decode_wrapper._float_workspace_buffer

assert get_kv_cache_layout() == "HND"
assert decode_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()

output[:num_decode_tokens] = (
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
workspace_buffer=attn_metadata.workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
))
assert workspace_buffer.is_contiguous()

trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
return output_padded
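In the v1 backend the kernel now writes straight into `output[:num_decode_tokens]` through `out=`, where the previous code assigned the kernel's returned tensor into that slice. The small, self-contained example below (using `torch.add` as a stand-in for the attention kernel) shows why filling a view in place is equivalent to the old slice assignment while skipping the intermediate allocation and copy.

```python
import torch

# Stand-in for any out=-style kernel: it writes results into the provided view.
def fake_decode_kernel(q: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(q, 1.0, out=out)

num_decode_tokens = 3
query = torch.randn(8, 4)
output = torch.zeros(8, 4)

# Old style: build a temporary result tensor, then copy it into the slice.
output[:num_decode_tokens] = query[:num_decode_tokens] + 1.0

# New style: the kernel fills the slice directly; the parent tensor sees the
# result because output[:num_decode_tokens] is a view, not a copy.
fake_decode_kernel(query[:num_decode_tokens], out=output[:num_decode_tokens])
```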