@@ -194,7 +194,6 @@ class FlashInferMetadata:
194
194
max_seq_len : int
195
195
seq_lens : torch .Tensor
196
196
block_table_tensor : torch .Tensor
197
- workspace_buffer : torch .Tensor
198
197
199
198
# For handling prefill decode split
200
199
num_decodes : int
@@ -473,7 +472,6 @@ def build(self,
473
472
max_seq_len = max_seq_len ,
474
473
seq_lens = seq_lens ,
475
474
block_table_tensor = block_table_tensor ,
476
- workspace_buffer = self ._get_workspace_buffer (),
477
475
)
478
476
479
477
self ._plan (num_prefills , num_decodes , attn_metadata )
@@ -641,11 +639,11 @@ def forward(
641
639
if decode_wrapper := attn_metadata .decode_wrapper :
642
640
decode_query = query [:num_decode_tokens ]
643
641
assert decode_query .shape [0 ] == num_decode_tokens
642
+ assert decode_wrapper is not None
644
643
if not FlashInferBackend .use_trtllm_decode_attention (
645
644
attn_metadata .num_decodes , attn_metadata .max_seq_len ,
646
645
self .kv_cache_dtype , attn_metadata .num_qo_heads ,
647
646
attn_metadata .num_kv_heads , attn_metadata .head_dim ):
648
- assert decode_wrapper is not None
649
647
assert decode_wrapper ._window_left == window_left
650
648
assert decode_wrapper ._logits_soft_cap == (self .logits_soft_cap
651
649
or 0.0 )
@@ -666,18 +664,20 @@ def forward(
666
664
num_decode_tokens ]
667
665
seq_lens_decode = attn_metadata .seq_lens [:
668
666
num_decode_tokens ]
667
+ workspace_buffer = decode_wrapper ._float_workspace_buffer
669
668
670
669
assert get_kv_cache_layout () == "HND"
671
670
assert decode_query .is_contiguous ()
672
671
assert kv_cache_permute .is_contiguous ()
673
672
assert block_tables_decode .is_contiguous ()
674
673
assert seq_lens_decode .is_contiguous ()
674
+ assert workspace_buffer .is_contiguous ()
675
675
676
676
output [:num_decode_tokens ] = (
677
677
trtllm_batch_decode_with_kv_cache (
678
678
query = decode_query ,
679
679
kv_cache = kv_cache_permute ,
680
- workspace_buffer = attn_metadata . workspace_buffer ,
680
+ workspace_buffer = workspace_buffer ,
681
681
num_heads = self .num_heads ,
682
682
num_kv_heads = self .num_kv_heads ,
683
683
scale = self .scale ,
0 commit comments