From 62894736132ef15b28f4f694e983b74f10ad50b1 Mon Sep 17 00:00:00 2001 From: Weiliang Liu Date: Thu, 24 Jul 2025 03:20:36 +0000 Subject: [PATCH] update flashinfer to v0.2.9rc1 Signed-off-by: Weiliang Liu --- docker/Dockerfile | 2 +- vllm/attention/backends/flashinfer.py | 10 +++------- vllm/v1/attention/backends/flashinfer.py | 9 ++------- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index d1fa92ce6d19..4c592f7d1508 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -390,7 +390,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -ARG FLASHINFER_GIT_REF="v0.2.8rc1" +ARG FLASHINFER_GIT_REF="v0.2.9rc1" RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . /etc/environment git clone --depth 1 --recursive --shallow-submodules \ diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 56d3da699f40..e6e60e756248 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1169,16 +1169,12 @@ def forward( query=decode_query, kv_cache=kv_cache.permute(*stride_order), workspace_buffer=workspace_buffer, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - scale=softmax_scale, block_tables=attn_metadata.block_tables, seq_lens=decode_meta.seq_lens_tensor, - block_size=attn_metadata.page_size, max_seq_len=attn_metadata.max_decode_seq_len, - kv_cache_dtype=kv_cache_dtype, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float) + bmm1_scale=layer._k_scale_float * softmax_scale, + bmm2_scale=layer._v_scale_float, + ) if prefill_output is None and decode_output is not None: # Decode only batch. diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 953ef26c8143..e19d941350c8 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -669,15 +669,10 @@ def forward( query=decode_query, kv_cache=kv_cache_permute, workspace_buffer=attn_metadata.workspace_buffer, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - scale=self.scale, block_tables=block_tables_decode, seq_lens=seq_lens_decode, - block_size=attn_metadata.page_size, max_seq_len=attn_metadata.max_seq_len, - kv_cache_dtype=self.kv_cache_dtype, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, + bmm1_scale=layer._k_scale_float * self.scale, + bmm2_scale=layer._v_scale_float, )) return output_padded