Commit aeec6ab

update flashinfer to v0.2.9rc1

1 parent f3137cd

3 files changed: +15, -24 lines

docker/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -390,7 +390,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 
 # Install FlashInfer from source
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
-ARG FLASHINFER_GIT_REF="v0.2.8rc1"
+ARG FLASHINFER_GIT_REF="v0.2.9rc1"
 RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
 . /etc/environment
 git clone --depth 1 --recursive --shallow-submodules \

vllm/attention/backends/flashinfer.py

Lines changed: 3 additions & 7 deletions

@@ -1169,16 +1169,12 @@ def forward(
                     query=decode_query,
                     kv_cache=kv_cache.permute(*stride_order),
                     workspace_buffer=workspace_buffer,
-                    num_heads=num_heads,
-                    num_kv_heads=num_kv_heads,
-                    scale=softmax_scale,
                     block_tables=attn_metadata.block_tables,
                     seq_lens=decode_meta.seq_lens_tensor,
-                    block_size=attn_metadata.page_size,
                     max_seq_len=attn_metadata.max_decode_seq_len,
-                    kv_cache_dtype=kv_cache_dtype,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float)
+                    bmm1_scale=layer._k_scale_float * softmax_scale,
+                    bmm2_scale=layer._v_scale_float,
+                )
 
             if prefill_output is None and decode_output is not None:
                 # Decode only batch.
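
The substance of this hunk: with FlashInfer v0.2.9rc1 the call no longer takes separate scale, k_scale and v_scale arguments; the key dequantization scale and the softmax scale are folded into one scale for the Q·K^T GEMM (bmm1_scale), while the value dequantization scale becomes the scale for the softmax(QK^T)·V GEMM (bmm2_scale). A minimal sketch of that folding, assuming the usual softmax_scale = 1/sqrt(head_dim) convention (the helper name and head_dim value are illustrative, not part of this commit):

import math


def fold_trtllm_scales(k_scale: float, v_scale: float, head_dim: int):
    # Hypothetical helper (not from this commit): derive the fused scales
    # that the updated calls pass to trtllm_batch_decode_with_kv_cache.
    # bmm1_scale multiplies the Q.K^T GEMM, so it carries both the FP8 key
    # dequantization scale and the softmax scaling; bmm2_scale multiplies
    # the softmax(QK^T).V GEMM, so it only needs the value scale.
    softmax_scale = 1.0 / math.sqrt(head_dim)  # assumed convention
    bmm1_scale = k_scale * softmax_scale       # mirrors layer._k_scale_float * softmax_scale
    bmm2_scale = v_scale                       # mirrors layer._v_scale_float
    return bmm1_scale, bmm2_scale


# Example: unquantized cache (unit scales) with head_dim = 128.
print(fold_trtllm_scales(1.0, 1.0, 128))  # -> (~0.0884, 1.0)

With unit scales the folding collapses to the plain softmax scaling, so a non-quantized cache behaves exactly as before.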

vllm/v1/attention/backends/flashinfer.py

Lines changed: 11 additions & 16 deletions

@@ -664,20 +664,15 @@ def forward(
             assert block_tables_decode.is_contiguous()
             assert seq_lens_decode.is_contiguous()
 
-            output[:num_decode_tokens] = (
-                trtllm_batch_decode_with_kv_cache(
-                    query=decode_query,
-                    kv_cache=kv_cache_permute,
-                    workspace_buffer=attn_metadata.workspace_buffer,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    scale=self.scale,
-                    block_tables=block_tables_decode,
-                    seq_lens=seq_lens_decode,
-                    block_size=attn_metadata.page_size,
-                    max_seq_len=attn_metadata.max_seq_len,
-                    kv_cache_dtype=self.kv_cache_dtype,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                ))
+
+            output[:num_decode_tokens] = trtllm_batch_decode_with_kv_cache(
+                query=decode_query,
+                kv_cache=kv_cache_permute,
+                workspace_buffer=attn_metadata.workspace_buffer,
+                block_tables=block_tables_decode,
+                seq_lens=seq_lens_decode,
+                max_seq_len=attn_metadata.max_seq_len,
+                bmm1_scale=layer._k_scale_float * self.scale,
+                bmm2_scale=layer._v_scale_float,
+            )
         return output_padded
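
The V1 backend makes the same substitution and additionally drops num_heads, num_kv_heads, block_size and kv_cache_dtype. That information is already implicit in the tensors the call still receives, which is presumably why the newer entry point no longer asks for it. A rough sketch of the inference, assuming the query is laid out as [num_tokens, num_heads, head_dim] and the permuted cache as [num_blocks, 2, num_kv_heads, block_size, head_dim] (layouts and helper assumed for illustration, not taken from this commit):

import torch


def infer_dropped_kwargs(query: torch.Tensor, kv_cache: torch.Tensor):
    # Hypothetical helper (not from this commit): recover the values that
    # used to be passed explicitly, from tensors the call still receives.
    num_heads = query.shape[1]        # previously num_heads=self.num_heads
    num_kv_heads = kv_cache.shape[2]  # previously num_kv_heads=self.num_kv_heads
    block_size = kv_cache.shape[3]    # previously block_size=attn_metadata.page_size
    kv_dtype = kv_cache.dtype         # the information kv_cache_dtype described
    return num_heads, num_kv_heads, block_size, kv_dtype


# Toy shapes: 4 decode tokens, 32 query heads, 8 KV heads, 16-token pages, head_dim 128.
q = torch.empty(4, 32, 128, dtype=torch.float16)
kv = torch.empty(64, 2, 8, 16, 128, dtype=torch.float16)
print(infer_dropped_kwargs(q, kv))  # -> (32, 8, 16, torch.float16)

The block tables, sequence lengths and max_seq_len still have to be passed, since they describe which pages belong to which request rather than anything recoverable from tensor shapes.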
