
Commit 84e689c

elvischenv committed
fix None workspace buffer issue
Signed-off-by: elvischenv <[email protected]>
1 parent: 15a72ac

4 files changed (+60 −42 lines)


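Summary of the fix, as reflected in the diffs below: the TRT-LLM decode path previously read its workspace buffer from FlashInferMetadata, a field that could evidently end up None (hence the commit title), and relied on the kernel to allocate and return its output. After this commit the workspace tensor is taken from the decode wrapper's _float_workspace_buffer, the output is pre-allocated by the caller and filled in place via out=, and the head-count, page-size, dtype, and scale positionals are replaced by bmm1_scale / bmm2_scale keywords. A condensed sketch of the new call shape; the wrapper function and argument values are illustrative, only the argument names are taken from the diff:

    import torch
    import flashinfer

    def run_trtllm_decode(q, kv_cache, workspace_buffer, block_tables,
                          seq_lens, max_seq_len, k_scale, sm_scale, v_scale):
        # Caller pre-allocates the output; the kernel writes into it in place.
        out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,               # e.g. decode_wrapper._float_workspace_buffer
            block_tables,
            seq_lens,
            max_seq_len,
            bmm1_scale=k_scale * sm_scale,  # KV-cache k_scale folded into the softmax scale
            bmm2_scale=v_scale,
            out=out,
        )
        return out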
benchmarks/kernels/benchmark_trtllm_attention.py

Lines changed: 28 additions & 14 deletions
@@ -71,22 +71,20 @@ def benchmark_decode(
     if kv_cache_dtype.startswith("fp8"):
         kv_cache, _ = to_float8(kv_cache)
 
+    output_trtllm = torch.empty(q.shape, dtype=dtype)
+
     # Benchmark TRT decode
     def trt_decode():
         return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
             q,
             kv_cache,
             workspace_buffer,
-            num_qo_heads,
-            num_kv_heads,
-            sm_scale,
             block_tables,
             kv_lens_tensor,
-            page_size,
             max_kv_len,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
+            bmm1_scale=k_scale * sm_scale,
+            bmm2_scale=v_scale,
+            out=output_trtllm,
         )
 
     def time_fn(fn, warmup=10, trials=20):
@@ -125,6 +123,8 @@ def time_fn(fn, warmup=10, trials=20):
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
+    output_baseline = torch.empty(q.shape, dtype=dtype)
+
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -145,7 +145,7 @@ def time_fn(fn, warmup=10, trials=20):
     )
 
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
+        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
 
     baseline_mean, baseline_std = time_fn(baseline_decode)
 
@@ -214,25 +214,39 @@ def write_results_to_csv(results, filename=None):
     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     all_results = []
 
-    print("Running benchmark for kv_cache_dtype: bfloat16")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
     )
     for max_seq_len in max_seq_lens:
         for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="auto",
             )
             all_results.append(result)
 
-    print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
     print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
+        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
+        "output_dtype: bfloat16"
+    )
+    print(
+        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
+        "baseline_std\tspeedup_percent"
     )
     for max_seq_len in max_seq_lens:
         for bs in num_seqs:
             result = benchmark_decode(
-                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
+                bs,
+                max_seq_len,
+                dtype=torch.bfloat16,
+                kv_cache_dtype="fp8",
            )
            all_results.append(result)

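The benchmark now creates output_trtllm and output_baseline once, before the timed closures are defined, so the measured loop covers only the kernel call and not the output allocation. The body of time_fn is unchanged and not shown in this diff; for orientation, a CUDA-event based helper with this signature could look like the following (an illustrative sketch, not the file's actual implementation):

    import numpy as np
    import torch

    def time_fn(fn, warmup=10, trials=20):
        # Warm-up iterations absorb one-time costs (lazy init, autotuning).
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        times = []
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # milliseconds
        return np.mean(times), np.std(times)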
tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py

Lines changed: 7 additions & 9 deletions
@@ -113,27 +113,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
                  kv_data_type=dtype,
                  logits_soft_cap=soft_cap)
 
-    output = wrapper.run(query, key_value_cache, scale)
+    output = torch.empty(query.shape, dtype=dtype)
+    wrapper.run(query, key_value_cache, scale, out=output)
 
     # TRTLLM Decode
     max_kv_len = max(kv_lens)
     kv_lens_tensor = torch.tensor(kv_lens,
                                   dtype=torch.int,
                                   device=query.device)
-    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query.contiguous(),
         key_value_cache,
         workspace_buffer,
-        num_query_heads,
-        num_kv_heads,
-        scale,
         block_tables,
         kv_lens_tensor,
-        block_size,
         max_kv_len,
-        "auto",
-        k_scale,
-        v_scale,
+        bmm1_scale=k_scale * scale,
+        bmm2_scale=v_scale,
+        out=output_trtllm,
     )
 
     torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \

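Both code paths now write into pre-allocated tensors of the same shape and dtype, and the results are compared element-wise with absolute and relative tolerances of 1e-2. As a reminder of what those tolerances mean, torch.testing.assert_close accepts each element when |actual - expected| <= atol + rtol * |expected|; a small self-contained example with made-up values:

    import torch

    expected = torch.tensor([1.000, 2.000])
    actual = torch.tensor([1.005, 2.015])
    # Element 0: |1.005 - 1.000| = 0.005 <= 1e-2 + 1e-2 * 1.0 = 0.02
    # Element 1: |2.015 - 2.000| = 0.015 <= 1e-2 + 1e-2 * 2.0 = 0.03
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)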
vllm/attention/backends/flashinfer.py

Lines changed: 10 additions & 4 deletions
@@ -1104,7 +1104,11 @@ def forward(
         window_left = window_size[0] if window_size is not None else -1
 
         prefill_output: Optional[torch.Tensor] = None
-        decode_output: Optional[torch.Tensor] = None
+        if num_decode_tokens > 0:
+            decode_output = torch.empty(decode_query.shape,
+                                        dtype=decode_query.dtype)
+        else:
+            decode_output = None
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
@@ -1155,17 +1159,18 @@ def forward(
                     num_decode_tokens, attn_metadata.max_decode_seq_len,
                     kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                decode_output = decode_meta.decode_wrapper.run(
+                decode_meta.decode_wrapper.run(
                     decode_query,
                     kv_cache.permute(*stride_order),
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
             else:
                 workspace_buffer = (
-                    decode_meta.decode_wrapper._int_workspace_buffer)
+                    decode_meta.decode_wrapper._float_workspace_buffer)
                 assert FlashInferState.get_kv_cache_layout() == "HND"
-                decode_output = trtllm_batch_decode_with_kv_cache(
+                trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache.permute(*stride_order),
                     workspace_buffer=workspace_buffer,
@@ -1174,6 +1179,7 @@ def forward(
                     max_seq_len=attn_metadata.max_decode_seq_len,
                     bmm1_scale=layer._k_scale_float * softmax_scale,
                     bmm2_scale=layer._v_scale_float,
+                    out=decode_output,
                 )
 
         if prefill_output is None and decode_output is not None:

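Two things change in the V0 backend: decode_output is allocated eagerly whenever there are decode tokens, so both the FlashInfer wrapper path and the TRT-LLM path fill it in place through out=, and the workspace tensor passed to the TRT-LLM kernel now comes from the wrapper's _float_workspace_buffer rather than _int_workspace_buffer. A generic, self-contained illustration of the allocate-once / write-in-place pattern, with a plain torch op standing in for the attention kernels:

    import torch

    num_decode_tokens = 4
    decode_query = torch.randn(num_decode_tokens, 64)

    # Allocate the result buffer only when there is decode work to do.
    if num_decode_tokens > 0:
        decode_output = torch.empty(decode_query.shape, dtype=decode_query.dtype)
    else:
        decode_output = None

    if decode_output is not None:
        # Stand-in for decode_wrapper.run(..., out=decode_output) or
        # trtllm_batch_decode_with_kv_cache(..., out=decode_output).
        torch.tanh(decode_query, out=decode_output)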
vllm/v1/attention/backends/flashinfer.py

Lines changed: 15 additions & 15 deletions
@@ -194,7 +194,6 @@ class FlashInferMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
-    workspace_buffer: torch.Tensor
 
     # For handling prefill decode split
     num_decodes: int
@@ -473,7 +472,6 @@ def build(self,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ def forward(
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
+            assert decode_wrapper is not None
             if not FlashInferBackend.use_trtllm_decode_attention(
                     attn_metadata.num_decodes, attn_metadata.max_seq_len,
                     self.kv_cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                assert decode_wrapper is not None
                 assert decode_wrapper._window_left == window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                            or 0.0)
@@ -666,22 +664,24 @@ def forward(
                                                          num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:
                                                          num_decode_tokens]
+                workspace_buffer = decode_wrapper._float_workspace_buffer
 
                 assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert block_tables_decode.is_contiguous()
                 assert seq_lens_decode.is_contiguous()
-
-                output[:num_decode_tokens] = (
-                    trtllm_batch_decode_with_kv_cache(
-                        query=decode_query,
-                        kv_cache=kv_cache_permute,
-                        workspace_buffer=attn_metadata.workspace_buffer,
-                        block_tables=block_tables_decode,
-                        seq_lens=seq_lens_decode,
-                        max_seq_len=attn_metadata.max_seq_len,
-                        bmm1_scale=layer._k_scale_float * self.scale,
-                        bmm2_scale=layer._v_scale_float,
-                    ))
+                assert workspace_buffer.is_contiguous()
+
+                trtllm_batch_decode_with_kv_cache(
+                    query=decode_query,
+                    kv_cache=kv_cache_permute,
+                    workspace_buffer=workspace_buffer,
+                    block_tables=block_tables_decode,
+                    seq_lens=seq_lens_decode,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=layer._k_scale_float * self.scale,
+                    bmm2_scale=layer._v_scale_float,
+                    out=output[:num_decode_tokens],
+                )
         return output_padded

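In the V1 backend the workspace_buffer field is dropped from FlashInferMetadata entirely; the TRT-LLM path reads decode_wrapper._float_workspace_buffer at call time and writes its result straight into the decode slice of the shared output tensor via out=output[:num_decode_tokens], instead of assigning the kernel's return value into that slice. A small self-contained example showing that out= on a slice view fills the parent buffer in place, with a plain torch op standing in for the kernel:

    import torch

    num_decode_tokens = 3
    output = torch.zeros(8, 16)          # shared, pre-allocated output buffer
    q = torch.randn(num_decode_tokens, 16)
    k = torch.randn(num_decode_tokens, 16)

    # Writing through the slice view lands directly in the parent tensor.
    torch.mul(q, k, out=output[:num_decode_tokens])

    assert torch.equal(output[:num_decode_tokens], q * k)  # decode rows filled in place
    assert torch.all(output[num_decode_tokens:] == 0)      # remaining rows untouched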