@@ -11,11 +11,13 @@
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.models.gemma2 import Gemma2Model
 from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors


 class MyGemma2Embedding(nn.Module):

+    is_pooling_model = True
+
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -24,7 +26,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = Gemma2Model(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))

-        self._pooler = Pooler.from_config_with_defaults(
+        self.pooler = Pooler.from_config_with_defaults(
             vllm_config.model_config.pooler_config,
             pooling_type=PoolingType.LAST,
             normalize=True,
@@ -54,13 +56,6 @@ def forward(
         # Return all-zero embeddings
         return torch.zeros_like(hidden_states)

-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

         weights = self.hf_to_vllm_mapper.apply(weights)
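For orientation, the diff above replaces the `pooler(...)` wrapper method with a plain attribute: the model opts in via the `is_pooling_model` class flag and exposes the `Pooler` module directly as `self.pooler`. A minimal sketch of a caller against the new interface (the `run_pooling` helper and its signature are hypothetical, not vLLM internals):

# Sketch only: `run_pooling` is a hypothetical caller, not part of vLLM.
def run_pooling(model, hidden_states, pooling_metadata):
    # New-style detection: check the class attribute added in this diff.
    assert getattr(model, "is_pooling_model", False)
    # `model.pooler` is now the Pooler module itself (an nn.Module callable),
    # so it is invoked directly instead of through a wrapper method.
    return model.pooler(hidden_states, pooling_metadata)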
vllm/entrypoints/openai/protocol.py (34 changes: 5 additions & 29 deletions)
@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

-    # --8<-- [start:embedding-pooling-params]
-    additional_data: Optional[Any] = None
-    # --8<-- [end:embedding-pooling-params]
-
     # --8<-- [start:embedding-extra-params]
     add_special_tokens: bool = Field(
         default=True,
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # --8<-- [end:embedding-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions)


 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     user: Optional[str] = None
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

-    # --8<-- [start:chat-embedding-pooling-params]
-    additional_data: Optional[Any] = None
-    # --8<-- [end:chat-embedding-pooling-params]
-
     # --8<-- [start:chat-embedding-extra-params]
     add_special_tokens: bool = Field(
         default=False,
@@ -1323,8 +1314,7 @@ def check_generation_prompt(cls, data):
         return data

     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions)


 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
     text_2: Union[list[str], str, ScoreMultiModalParam]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

-    # --8<-- [start:score-pooling-params]
-    additional_data: Optional[Any] = None
-    # --8<-- [end:score-pooling-params]
-
     # --8<-- [start:score-extra-params]

     mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]

     def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder,
-                             additional_data=self.additional_data)
+        return PoolingParams(use_cross_encoder=use_cross_encoder)


 class RerankRequest(OpenAIBaseModel):
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
     top_n: int = Field(default_factory=lambda: 0)
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

-    # --8<-- [start:rerank-pooling-params]
-    additional_data: Optional[Any] = None
-    # --8<-- [end:rerank-pooling-params]
-
     # --8<-- [start:rerank-extra-params]

     mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]

     def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder,
-                             additional_data=self.additional_data)
+        return PoolingParams(use_cross_encoder=use_cross_encoder)


 class RerankDocument(BaseModel):
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
     truncate_prompt_tokens: Optional[int] = None
     user: Optional[str] = None

-    # --8<-- [start:classification-pooling-params]
-    additional_data: Optional[Any] = None
-    # --8<-- [end:classification-pooling-params]
-
     # --8<-- [start:classification-extra-params]
     priority: int = Field(
         default=0,
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
     # --8<-- [end:classification-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams()


 class ClassificationData(OpenAIBaseModel):
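The protocol-side change is uniform across the request classes above: `additional_data` is removed from each schema and from every `to_pooling_params` conversion, so the resulting `PoolingParams` carries only the still-supported fields. A minimal sketch of the updated conversion (field values are illustrative; it assumes the classes are importable as shown in the diff):

# Sketch only: exercising to_pooling_params after this change.
from vllm.entrypoints.openai.protocol import EmbeddingCompletionRequest

req = EmbeddingCompletionRequest(
    model="my-embedding-model",  # illustrative model name
    input="hello world",
    dimensions=256,
)
params = req.to_pooling_params()
# `additional_data` is no longer a request field, so the call above now
# yields PoolingParams(dimensions=256) with nothing extra attached.
print(params)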