Merged
32 commits
7f850be
Pass token type ids as pooling param to the model runner
maxdebayser Jul 31, 2025
809384e
fix errors
maxdebayser Jul 31, 2025
6f330b7
fix cudagraph problem
maxdebayser Jul 31, 2025
794aaf2
compress token type ids
maxdebayser Jul 31, 2025
a6f949d
forgot to(gpu)
maxdebayser Jul 31, 2025
56dba67
Address review comments
maxdebayser Jul 31, 2025
cdf802a
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Jul 31, 2025
3fe425a
fix mistake
maxdebayser Jul 31, 2025
4b19f4c
address review comments
maxdebayser Jul 31, 2025
5d0999c
fix type hints
maxdebayser Jul 31, 2025
2074d29
address review comments
maxdebayser Jul 31, 2025
148ab54
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Jul 31, 2025
accf2f7
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 1, 2025
cb935de
change comment order
maxdebayser Aug 1, 2025
a250e5b
fix test error message
maxdebayser Aug 2, 2025
939165f
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 2, 2025
2add932
fix error msg inconsistency
maxdebayser Aug 4, 2025
4df6cd2
sync with gpu after changing input tensors
maxdebayser Aug 4, 2025
0123dc5
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 4, 2025
e486790
increase test tolerance
maxdebayser Aug 4, 2025
164d890
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 5, 2025
7e3b671
add TODO comment
maxdebayser Aug 5, 2025
2cac159
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 5, 2025
29ca69b
rename method
maxdebayser Aug 5, 2025
ed5a7ef
fix editing mistake
maxdebayser Aug 5, 2025
656059b
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 5, 2025
d9a8835
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 6, 2025
3d089dd
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 7, 2025
0471896
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 9, 2025
db612f7
rename argument
maxdebayser Aug 10, 2025
5184a3d
Merge branch 'upstream_main' into v1_token_type_ids
maxdebayser Aug 10, 2025
96e3871
rename argument
maxdebayser Aug 10, 2025
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


DTYPE = "half"


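A minimal sketch of the conftest.py promotion that the fixture's comment suggests; the target path is hypothetical, and run_with_both_engines is the engine-parametrizing fixture already used by the test above:

# Hypothetical tests/models/language/pooling/conftest.py
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Autouse: runs every test in the package against both engines.
    pass
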
54 changes: 26 additions & 28 deletions vllm/entrypoints/llm.py
@@ -28,11 +28,15 @@
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt)
# yapf: enable
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@@ -1259,7 +1263,8 @@ def _cross_encoding_score(
if len(data_1) == 1:
data_1 = data_1 * len(data_2)

pooling_params = PoolingParams(task="score")
default_pooling_params = PoolingParams(task="score")
pooling_params = list[PoolingParams]()
tokenization_kwargs: dict[str, Any] = {}

_validate_truncation_size(model_config.max_model_len,
@@ -1269,34 +1274,27 @@

input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

if model_config.is_multimodal_model:
for q, d in input_pairs:
_, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=q,
data_2=d,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
model_config = self.llm_engine.model_config

parsed_prompts.append(engine_prompt)
else:
for q, t in input_pairs:
if model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(
text=q, # type: ignore[arg-type]
text_pair=t, # type: ignore[arg-type]
**tokenization_kwargs)
else:
# `llm as reranker` models defaults to not using pad_token.
prompt_inputs = tokenizer(
text=q + t, # type: ignore[operator]
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
for q, d in input_pairs:
_, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=q,
data_2=d,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)

if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
compressed = compress_token_type_ids(token_type_ids)
pooling_params.append(
PoolingParams(
task="score",
extra_args={"compressed_token_type_ids": compressed}))
else:
pooling_params.append(default_pooling_params)

parsed_prompts.append(engine_prompt)

self._validate_and_add_requests(
prompts=parsed_prompts,
79 changes: 28 additions & 51 deletions vllm/entrypoints/openai/serving_score.py
@@ -3,6 +3,7 @@
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from copy import deepcopy
from typing import Any, Optional, Union

from fastapi import Request
@@ -17,11 +18,15 @@
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt)
# yapf: enable
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
@@ -188,64 +193,27 @@ async def _cross_encoding_score(

input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

if self.model_config.is_multimodal_model:
preprocess_async = make_async(self._preprocess_score,
executor=self._tokenizer_executor)

preprocess_async = make_async(self._preprocess_score,
executor=self._tokenizer_executor)
preprocessed_prompts = await asyncio.gather(
*(preprocess_async(request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2) for t1, t2 in input_pairs))

preprocessed_prompts = await asyncio.gather(
*(preprocess_async(request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2) for t1, t2 in input_pairs))

for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)

else:
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
use_pad_token = self.model_config.use_pad_token

if use_pad_token:
# cross_encoder models defaults to using pad_token.
tokenized_prompts = await asyncio.gather(*(
tokenize_async(
text=t1, # type: ignore[arg-type]
text_pair=t2, # type: ignore[arg-type]
**tokenization_kwargs) for t1, t2 in input_pairs))
else:
# `llm as reranker` models defaults to not using pad_token.
tokenized_prompts = await asyncio.gather(*(
tokenize_async(
text=t1 + # type: ignore[operator]
t2,
**tokenization_kwargs) for t1, t2 in input_pairs))

for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
sep_token = tokenizer.sep_token if (tokenizer.sep_token
and use_pad_token) else ''
request_prompt = f"{t1}{sep_token}{t2}"

input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))

request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)

# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

pooling_params = request.to_pooling_params()
default_pooling_params = request.to_pooling_params()

try:
pooling_params.verify("score", self.model_config)
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))

@@ -254,9 +222,18 @@

self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
params=default_pooling_params,
lora_request=lora_request)

if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
pooling_params = deepcopy(default_pooling_params)
compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_args = {
"compressed_token_type_ids": compressed
}
else:
pooling_params = (default_pooling_params)

generator = self.engine_client.encode(
engine_prompt,
pooling_params,
40 changes: 37 additions & 3 deletions vllm/entrypoints/score_utils.py
@@ -184,15 +184,49 @@ def get_score_prompt(
model_config,
tokenizer,
)
from vllm.model_executor.model_loader import get_model_cls

full_prompt = apply_score_template(model_config, prompt_1, prompt_2)

prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
model = get_model_cls(model_config)
if supports_score_template(model):
full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
elif model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(text=prompt_1,
text_pair=prompt_2,
**tokenization_kwargs)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` models defaults to not using pad_token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)

engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
engine_prompt["token_type_ids"] = token_type_ids

post_process_tokens(model_config, engine_prompt)

if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
return full_prompt, engine_prompt


def compress_token_type_ids(token_type_ids: list[int]) -> int:
[Inline review comment by the author: "This is to minimize the amount of data that is transferred between processes."]

"""
Return position of the first 1 or the length of the list
if not found.
"""
first_one = len(token_type_ids)
err_msg = "Token type ids are expected to be a sequence"\
" of zeros followed by a sequence of ones"
for i, type_id in enumerate(token_type_ids):
if type_id == 0 and first_one < i:
raise ValueError(err_msg)
elif type_id == 1 and first_one > i:
first_one = i
elif type_id > 1:
raise ValueError(err_msg)

return first_one
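
For illustration only (not part of the diff): a minimal sketch of how the compressed value round-trips, assuming the zeros-then-ones layout that compress_token_type_ids enforces, and how the entrypoints above attach it to PoolingParams:

# Token type ids for a query/document pair are a run of zeros followed by a
# run of ones, so the index of the first 1 is enough to rebuild the full list.
token_type_ids = [0, 0, 0, 1, 1, 1, 1]

compressed = token_type_ids.index(1) if 1 in token_type_ids else len(token_type_ids)
assert compressed == 3  # what compress_token_type_ids returns for this input

# The model-runner side can reconstruct the list for a prompt of known length:
restored = [0] * compressed + [1] * (len(token_type_ids) - compressed)
assert restored == token_type_ids

# The entrypoints then ship the single integer instead of the full list:
# PoolingParams(task="score",
#               extra_args={"compressed_token_type_ids": compressed})
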
59 changes: 39 additions & 20 deletions vllm/model_executor/models/bert.py
@@ -28,7 +28,7 @@
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -60,21 +60,13 @@ def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()

# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)
token_type_ids = _decode_token_type_ids(input_ids)

# Position embeddings.
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)

token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -361,14 +353,12 @@ def forward(
position_ids: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
position_ids=position_ids)
return self.encoder(hidden_states)

def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -468,13 +458,11 @@ def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

@@ -508,8 +496,33 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
})


class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
TOKEN_TYPE_SHIFT = 30


def _encode_token_type_ids(input_ids: torch.tensor,
token_type_ids: torch.tensor) -> None:

# input_ids can be padded to the right
input_ids[:token_type_ids.shape[0]].bitwise_or_(
token_type_ids << TOKEN_TYPE_SHIFT)


def _decode_token_type_ids(input_ids: torch.tensor) -> torch.tensor:

ids_mask = torch.ones(input_ids.shape,
dtype=torch.int32,
device=input_ids.device) << TOKEN_TYPE_SHIFT
tokens_mask = ids_mask.bitwise_not()

token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT

input_ids.bitwise_and_(tokens_mask)

return token_type_ids


class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
@@ -554,6 +567,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
),
})

assert config.vocab_size < (1 << TOKEN_TYPE_SHIFT)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
@@ -567,8 +582,12 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if token_type_ids is not None:
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)

return self.bert(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids)
intermediate_tensors=intermediate_tensors)
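
As a closing illustration (not part of the diff): a self-contained sketch of the bit-packing used by _encode_token_type_ids / _decode_token_type_ids. Because the constructor asserts config.vocab_size < (1 << TOKEN_TYPE_SHIFT), bit 30 of every input id is free to carry the token type, so both tensors can travel through the model runner as one; the tensors below are made up.

import torch

TOKEN_TYPE_SHIFT = 30

# Toy data: token ids stay well below 2**30, type ids are 0 or 1.
input_ids = torch.tensor([101, 2054, 102, 2003, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# Encode: set bit 30 on the positions that belong to the second segment.
packed = input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)

# Decode: read bit 30 back out, then clear it to recover the original ids.
decoded_types = (packed >> TOKEN_TYPE_SHIFT) & 1
restored_ids = packed & ~(1 << TOKEN_TYPE_SHIFT)

assert torch.equal(decoded_types, token_type_ids)
assert torch.equal(restored_ids, input_ids)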