
Commit 1cd6eab

Support encoder-only models without KV-Cache (#21270)

Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>

1 parent f27fdfc commit 1cd6eab

17 files changed: +352 −99 lines

examples/offline_inference/prithvi_geospatial_mae.py
Lines changed: 1 addition & 1 deletion

@@ -3,12 +3,12 @@
 import argparse
 import datetime
 import os
-import re
 from typing import Union
 
 import albumentations
 import numpy as np
 import rasterio
+import regex as re
 import torch
 from einops import rearrange
 from terratorch.datamodules import Sen1Floods11NonGeoDataModule
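Note: for the patterns used in this file, the third-party regex package is a drop-in replacement for the stdlib re module, so aliasing it as re leaves all call sites unchanged. A minimal sketch of the equivalence (the pattern below is illustrative, not from this file):

    import regex as re

    # Same API surface as the stdlib module for common operations.
    match = re.search(r"(\d+)", "epoch 42")
    assert match is not None and match.group(1) == "42"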

tests/conftest.py
Lines changed: 11 additions & 2 deletions

@@ -1062,8 +1062,17 @@ def score(
         return [req_output.outputs.score for req_output in req_outputs]
 
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
-        executor = self.llm.llm_engine.model_executor
-        return executor.apply_model(func)
+        if hasattr(self.llm.llm_engine, "model_executor"):
+            # This works either in V0 or in V1 with
+            # VLLM_ENABLE_V1_MULTIPROCESSING=0
+            executor = self.llm.llm_engine.model_executor
+            return executor.apply_model(func)
+
+        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
+        def _apply_model(self):
+            return func(self.get_model())
+
+        return self.llm.llm_engine.collective_rpc(_apply_model)
 
     def __enter__(self):
         return self
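When V1 runs with engine multiprocessing (the default), the in-process engine object has no model_executor attribute, so the callable is shipped to the worker processes via collective_rpc; serializing an arbitrary callable is exactly what VLLM_ALLOW_INSECURE_SERIALIZATION=1 gates. A hedged usage sketch inside a pytest test (model name and the lambda are illustrative, not from this commit):

    def test_module_count(vllm_runner, monkeypatch):
        # Required so apply_model can serialize the callable for the workers.
        monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        with vllm_runner("BAAI/bge-base-en-v1.5", dtype="float16") as m:
            # apply_model returns one result per worker, hence a list.
            counts = m.apply_model(
                lambda model: sum(1 for _ in model.named_modules()))
            assert all(c > 0 for c in counts)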

tests/model_executor/test_model_load_with_params.py
Lines changed: 9 additions & 3 deletions

@@ -22,10 +22,12 @@
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
-def test_model_loading_with_params(vllm_runner):
+def test_model_loading_with_params(vllm_runner, monkeypatch):
     """
     Test parameter weight loading with tp>1.
     """
+    # to use apply_model
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_name=MODEL_NAME,
                      revision=REVISION,
                      dtype="float16",
@@ -61,10 +63,12 @@ def check_model(model):
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
-def test_roberta_model_loading_with_params(vllm_runner):
+def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
     """
     Test parameter weight loading with tp>1.
     """
+    # to use apply_model
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_name=MODEL_NAME_ROBERTA,
                      revision=REVISION_ROBERTA,
                      dtype="float16",
@@ -101,10 +105,12 @@ def check_model(model):
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
-def test_facebook_roberta_model_loading_with_params(vllm_runner):
+def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch):
     """
     Test loading roberta-base model with no lm_head.
     """
+    # to use apply_model
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     model_name = "FacebookAI/roberta-base"
     with vllm_runner(model_name=model_name,
                      dtype="float16",
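The hunk headers show each of these tests defines a check_model callback; with the env var set, apply_model can ship it to the workers and run assertions against the nn.Module that was actually loaded. A minimal sketch of the pattern (the callback body is illustrative, not this file's exact checks):

    def check_model(model):
        # Runs on each worker against the loaded torch module.
        assert model is not None
        assert hasattr(model, "load_weights")

    vllm_model.apply_model(check_model)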

tests/models/language/pooling/test_embedding.py
Lines changed: 3 additions & 11 deletions

@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
     pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
                  marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
     # [Encoder-only]
-    pytest.param(
-        "BAAI/bge-base-en-v1.5",
-        marks=[
-            # CPU only supports V1
-            pytest.mark.core_model,
-            pytest.mark.skip_v1
-        ]),
-    pytest.param("sentence-transformers/all-MiniLM-L12-v2",
-                 marks=[pytest.mark.skip_v1]),
-    pytest.param("intfloat/multilingual-e5-small",
-                 marks=[pytest.mark.skip_v1]),
+    pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
+    pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
+    pytest.param("intfloat/multilingual-e5-small"),
     pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                  marks=[pytest.mark.skip_v1]),
     # [Cross-Encoder]

tests/models/language/pooling/test_jina.py
Lines changed: 8 additions & 0 deletions

@@ -23,6 +23,14 @@
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:
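run_with_both_engines is a shared fixture from the test suite. A plausible sketch of what it does, not vLLM's exact conftest code: parametrize each test over the V0 and V1 engines via the VLLM_USE_V1 environment variable, honoring the skip_v0/skip_v1 marks seen in the file above:

    import pytest

    @pytest.fixture(params=[False, True], ids=["v0", "v1"])
    def run_with_both_engines(request, monkeypatch):
        use_v1 = request.param
        skip_mark = "skip_v1" if use_v1 else "skip_v0"
        if request.node.get_closest_marker(skip_mark):
            pytest.skip(f"marked {skip_mark}")
        monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        yield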

tests/v1/attention/utils.py
Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ def create_common_attn_metadata(
         max_query_len=max_query_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
+        causal=True,
     )
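The new causal field distinguishes decoder-style attention from the bidirectional attention that encoder-only models need; the test helper pins it to True so existing decoder-attention tests keep their semantics. An illustrative sketch of what the flag selects, assuming the usual additive-bias formulation rather than vLLM's internal kernels:

    import torch

    def attn_bias(seq_len: int, causal: bool) -> torch.Tensor:
        if not causal:
            # Encoder-only: every token attends to every token.
            return torch.zeros(seq_len, seq_len)
        # Decoder: token i attends only to positions <= i.
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf")), diagonal=1)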

tests/v1/test_oracle.py
Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
     "state-spaces/mamba-130m-hf",  # mamba1
-    "BAAI/bge-m3",  # embedding
 ]
 
 MODEL = "meta-llama/Llama-3.2-1B-Instruct"

tests/v1/test_utils.py
Lines changed: 1 addition & 2 deletions

@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import re
-
 import pytest
+import regex as re
 import requests
 import torch

vllm/engine/arg_utils.py
Lines changed: 2 additions & 1 deletion

@@ -1649,7 +1649,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,
 
         if (self.max_num_seqs is None
                 and usage_context in default_max_num_seqs):
-            self.max_num_seqs = default_max_num_seqs[usage_context]
+            self.max_num_seqs = min(default_max_num_seqs[usage_context],
+                                    self.max_num_batched_tokens or sys.maxsize)
 
         logger.debug("Setting max_num_seqs to %d for %s usage context.",
                      self.max_num_seqs, use_context_value)
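Worked example of the new default: each running sequence contributes at least one token to a batch, so a max_num_seqs above max_num_batched_tokens could never be reached; the `or sys.maxsize` keeps the old behavior when no token budget is set. A small standalone sketch of the same expression (names are local to the example):

    import sys

    def resolve_max_num_seqs(default_seqs, max_num_batched_tokens):
        return min(default_seqs, max_num_batched_tokens or sys.maxsize)

    assert resolve_max_num_seqs(1024, 512) == 512    # clamped by token budget
    assert resolve_max_num_seqs(1024, None) == 1024  # budget unset: keep default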

vllm/model_executor/models/bert.py
Lines changed: 7 additions & 11 deletions

@@ -12,7 +12,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -119,7 +117,6 @@ def forward(
         return pooled_output
 
 
-@support_torch_compile
 class BertEncoder(nn.Module):
 
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
+@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):
 
     is_pooling_model = True
@@ -368,13 +366,9 @@ def forward(
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
-            attn_metadata = get_forward_context().attn_metadata
-            assert hasattr(attn_metadata, "seq_lens_tensor")
-            hidden_states = self.embeddings(
-                input_ids=input_ids,
-                seq_lens=attn_metadata.seq_lens_tensor,
-                position_ids=position_ids,
-                token_type_ids=token_type_ids)
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids,
+                                            token_type_ids=token_type_ids)
         return self.encoder(hidden_states)
 
     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
+class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
                           position_ids=positions,
+                          token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)