Commit b568d6d

russellb and NickLucche committed
v1: Add Whisper encoder-decoder model support
Implements Whisper model support in the V1 engine. Key changes include:

- Add encoder-decoder architecture support with cross-attention KV cache management
- Add CrossAttentionManager and CrossAttentionSpec for encoder-decoder KV cache
- Update scheduler to handle cross-attention block allocation and disable prefix caching
- Modify GPU model runner for encoder input processing and attention metadata
- Disable BART / other enc-dec tests/examples (Whisper-only support for now)
- Optimize test performance and fix various integration issues

This closes a major feature gap between V0 and V1, enabling Whisper transcription in the new engine architecture while maintaining backward compatibility. Related to V0 deprecation (#18571) and 2025 Q3 roadmap (#20336).

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
1 parent 04d1dd7 commit b568d6d
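
For context, the Whisper path this commit enables can be exercised offline much like the offline_inference/encoder_decoder_multimodal.py example referenced in the pipeline changes below. The sketch that follows is illustrative only and not part of this commit; the "<|startoftranscript|>" decoder prompt, the AudioAsset sample, and the engine arguments are assumptions borrowed from that example.

# Illustrative sketch (not part of this commit): offline Whisper transcription
# through the vLLM LLM API, with the V1 engine selected (e.g. VLLM_USE_V1=1 if
# it is not already the default).
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="openai/whisper-large-v3",
    max_model_len=448,                 # Whisper's decoder context is short
    limit_mm_per_prompt={"audio": 1},  # one audio clip per request
)

prompt = {
    "prompt": "<|startoftranscript|>",
    "multi_modal_data": {
        # Bundled sample clip; any (waveform, sampling_rate) tuple works here.
        "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
    },
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)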

19 files changed, 550 insertions(+), 72 deletions(-)

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 2 deletions
@@ -295,7 +295,6 @@ steps:
   - python3 offline_inference/vision_language_pooling.py --seed 0
   - python3 offline_inference/vision_language_multi_image.py --seed 0
   - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
-  - python3 offline_inference/encoder_decoder.py
   - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
   - python3 offline_inference/basic/classify.py
   - python3 offline_inference/basic/embed.py
@@ -580,7 +579,7 @@ steps:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pip freeze | grep -E 'torch'
   - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
-  - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
+  - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work

- label: Multi-Modal Models Test (Extended) 1
  mirror_hardwares: [amdexperimental]

examples/offline_inference/encoder_decoder.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 encoder/decoder models, specifically BART and mBART.

 This script is refactored to allow model selection via command-line arguments.
+
+NOTE: This example is not yet supported in V1.
 """

 import argparse

tests/encoder_decoder/test_e2e_correctness.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def clear_cache():
     current_platform.is_cpu(),
     reason="CPU backend is not currently supported with encoder/decoder models"
 )
+@pytest.mark.skip(reason="bart not supported in V1")
 def test_encoder_decoder_e2e(
     hf_runner,
     vllm_runner,

tests/entrypoints/openai/correctness/test_transcription_api_correctness.py

Lines changed: 7 additions & 4 deletions
@@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr):
     return latency, num_output_tokens, transcription.text


-async def bound_transcribe(model_name, sem, client, audio, reference):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+async def bound_transcribe(sem, client, tokenizer, audio, reference):
     # Use semaphore to limit concurrent requests.
     async with sem:
         result = await transcribe_audio(client, tokenizer, *audio)
@@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference):
 async def process_dataset(model, client, data, concurrent_request):
     sem = asyncio.Semaphore(concurrent_request)

+    # Load tokenizer once outside the loop
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
     # Warmup call as the first `librosa.load` server-side is quite slow.
     audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
-    _ = await bound_transcribe(model, sem, client, (audio, sr), "")
+    _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "")

     tasks: list[asyncio.Task] = []
     for sample in data:
         audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
         task = asyncio.create_task(
-            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
+            bound_transcribe(sem, client, tokenizer, (audio, sr),
+                             sample["text"]))
         tasks.append(task)
     return await asyncio.gather(*tasks)
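
The helper above exercises vLLM's OpenAI-compatible /v1/audio/transcriptions endpoint, which this commit makes usable with the V1 engine. A minimal standalone sketch of the same request, assuming a server is already running with a Whisper model (the base_url and the audio file name are placeholders):

# Standalone sketch (not from this diff): one transcription request against a
# running vLLM server; base_url and sample.wav are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=f,
    )
print(transcription.text)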

tests/entrypoints/openai/test_encoder_decoder.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ async def client(server):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.skip(reason="bart is not yet supported in V1")
 async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
     completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",

tests/models/language/generation/test_bart.py

Lines changed: 2 additions & 0 deletions
@@ -178,6 +178,7 @@ def run_test(
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
+@pytest.mark.skip(reason="bart not supported in V1")
 def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
                 dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:

@@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
+@pytest.mark.skip(reason="bart not supported in V1")
 def test_models_distributed(hf_runner, vllm_runner,
                             example_encoder_decoder_prompts,
                             distributed_executor_backend, model, dtype,

tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@

 ARCH_TO_SKIP = {
     "MolmoForCausalLM": "incompatible requirements",
+    "Florence2ForConditionalGeneration": "not supported in V1",
 }
 ARCH_NEEDS_EXTRAS = [
     "InternVLChatModel",

tests/models/test_initialization.py

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,12 @@ def _initialize_kv_caches_v1(self, vllm_config):
         # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
         # L4 supports FA3.
         m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        if model_arch == "Florence2ForConditionalGeneration":
+            # An encoder-decoder model that's V0-only. Just skip it
+            # since V0 is about to be removed.
+            pytest.skip("Skipping Florence2ForConditionalGeneration")
+        if model_arch == "WhisperForConditionalGeneration":
+            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,

tests/v1/test_oracle.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine

 UNSUPPORTED_MODELS_V1 = [
-    "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
 ]

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional

import torch
from transformers import CacheConfig

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              subclass_attention_backend)


@functools.lru_cache
def create_cross_attention_backend(
        underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    prefix = "CrossAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            # Cross-attention metadata is built in GPU model runner
            # We just ensure it's non-causal and pass through
            new_common_attn_metadata = copy(common_attn_metadata)
            new_common_attn_metadata.causal = False
            return super().build(common_prefix_len, new_common_attn_metadata,
                                 fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)

    return attn_backend


class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_cross_attention_backend(
                underlying_attn_backend)
        else:
            # in v0 cross attention is handled inside the backends
            attn_backend = None

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)
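
For orientation, a decoder layer in an encoder-decoder model would instantiate this layer roughly as sketched below. Everything except CrossAttention itself (the class name, the hidden/head sizes, and the prefix argument) is illustrative and not taken from this commit; extra keyword arguments such as prefix are simply forwarded to the base Attention layer via **kwargs.

# Hypothetical wiring sketch; not part of this commit. Assumes CrossAttention
# is imported from the new module above.
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    """Fragment showing how a decoder layer might own a CrossAttention."""

    def __init__(self, hidden_size: int, num_heads: int, cache_config=None,
                 prefix: str = ""):
        super().__init__()
        head_size = hidden_size // num_heads
        # Decoder queries attend to encoder keys/values; the layer pins
        # AttentionType.ENCODER_DECODER and, on V1, wraps the selected
        # backend with the non-causal CrossAttentionBuilder defined above.
        self.cross_attn = CrossAttention(
            num_heads=num_heads,
            head_size=head_size,
            scale=head_size**-0.5,
            cache_config=cache_config,
            prefix=f"{prefix}.cross_attn",  # forwarded to Attention via **kwargs
        )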
