Skip to content

Commit 78b341b

Browse files
committed
fix None workspace buffer issue
Signed-off-by: elvischenv <[email protected]>
1 parent 5a19a6c commit 78b341b

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1163,7 +1163,7 @@ def forward(
11631163
)
11641164
else:
11651165
workspace_buffer = (
1166-
decode_meta.decode_wrapper._int_workspace_buffer)
1166+
decode_meta.decode_wrapper._float_workspace_buffer)
11671167
assert FlashInferState.get_kv_cache_layout() == "HND"
11681168
decode_output = trtllm_batch_decode_with_kv_cache(
11691169
query=decode_query,

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,6 @@ class FlashInferMetadata:
194194
max_seq_len: int
195195
seq_lens: torch.Tensor
196196
block_table_tensor: torch.Tensor
197-
workspace_buffer: torch.Tensor
198197

199198
# For handling prefill decode split
200199
num_decodes: int
@@ -473,7 +472,6 @@ def build(self,
473472
max_seq_len=max_seq_len,
474473
seq_lens=seq_lens,
475474
block_table_tensor=block_table_tensor,
476-
workspace_buffer=self._get_workspace_buffer(),
477475
)
478476

479477
self._plan(num_prefills, num_decodes, attn_metadata)
@@ -641,11 +639,11 @@ def forward(
641639
if decode_wrapper := attn_metadata.decode_wrapper:
642640
decode_query = query[:num_decode_tokens]
643641
assert decode_query.shape[0] == num_decode_tokens
642+
assert decode_wrapper is not None
644643
if not FlashInferBackend.use_trtllm_decode_attention(
645644
attn_metadata.num_decodes, attn_metadata.max_seq_len,
646645
self.kv_cache_dtype, attn_metadata.num_qo_heads,
647646
attn_metadata.num_kv_heads, attn_metadata.head_dim):
648-
assert decode_wrapper is not None
649647
assert decode_wrapper._window_left == window_left
650648
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
651649
or 0.0)
@@ -666,18 +664,20 @@ def forward(
666664
num_decode_tokens]
667665
seq_lens_decode = attn_metadata.seq_lens[:
668666
num_decode_tokens]
667+
workspace_buffer = decode_wrapper._float_workspace_buffer
669668

670669
assert get_kv_cache_layout() == "HND"
671670
assert decode_query.is_contiguous()
672671
assert kv_cache_permute.is_contiguous()
673672
assert block_tables_decode.is_contiguous()
674673
assert seq_lens_decode.is_contiguous()
674+
assert workspace_buffer.is_contiguous()
675675

676676
output[:num_decode_tokens] = (
677677
trtllm_batch_decode_with_kv_cache(
678678
query=decode_query,
679679
kv_cache=kv_cache_permute,
680-
workspace_buffer=attn_metadata.workspace_buffer,
680+
workspace_buffer=workspace_buffer,
681681
num_heads=self.num_heads,
682682
num_kv_heads=self.num_kv_heads,
683683
scale=self.scale,

0 commit comments

Comments
 (0)