Skip to content

Commit 2ac6d0e

Browse files
[Misc] Consolidate pooler config overrides (#10351)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 2ec8827 commit 2ac6d0e

File tree

7 files changed

+143
-192
lines changed

7 files changed

+143
-192
lines changed

docs/source/models/supported_models.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ Text Embedding
345345
Some model architectures support both generation and embedding tasks.
346346
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
347347

348+
.. tip::
349+
You can override the model's pooling method by passing :code:`--override-pooler-config`.
350+
348351
Reward Modeling
349352
---------------
350353

@@ -364,7 +367,7 @@ Reward Modeling
364367
- ✅︎
365368

366369
.. note::
367-
As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
370+
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
368371

369372
Classification
370373
---------------
@@ -385,7 +388,7 @@ Classification
385388
- ✅︎
386389

387390
.. note::
388-
As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now).
391+
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
389392

390393

391394
Multimodal Language Models
@@ -600,6 +603,9 @@ Multimodal Embedding
600603
Some model architectures support both generation and embedding tasks.
601604
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
602605

606+
.. tip::
607+
You can override the model's pooling method by passing :code:`--override-pooler-config`.
608+
603609
Model Support Policy
604610
=====================
605611

tests/engine/test_arg_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from vllm.config import PoolerConfig
56
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
67
from vllm.utils import FlexibleArgumentParser
78

@@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
3233

3334
def test_valid_pooling_config():
3435
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
35-
args = parser.parse_args(["--pooling-type=MEAN"])
36+
args = parser.parse_args([
37+
'--override-pooler-config',
38+
'{"pooling_type": "MEAN"}',
39+
])
3640
engine_args = EngineArgs.from_cli_args(args=args)
37-
assert engine_args.pooling_type == 'MEAN'
41+
assert engine_args.override_pooler_config == PoolerConfig(
42+
pooling_type="MEAN", )
3843

3944

4045
@pytest.mark.parametrize(

tests/test_config.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from dataclasses import asdict
2+
13
import pytest
24

3-
from vllm.config import ModelConfig
5+
from vllm.config import ModelConfig, PoolerConfig
46
from vllm.model_executor.layers.pooler import PoolingType
57
from vllm.platforms import current_platform
68

@@ -108,7 +110,7 @@ def test_get_sliding_window():
108110
reason="Xformers backend is not supported on ROCm.")
109111
def test_get_pooling_config():
110112
model_id = "sentence-transformers/all-MiniLM-L12-v2"
111-
minilm_model_config = ModelConfig(
113+
model_config = ModelConfig(
112114
model_id,
113115
task="auto",
114116
tokenizer=model_id,
@@ -119,39 +121,31 @@ def test_get_pooling_config():
119121
revision=None,
120122
)
121123

122-
minilm_pooling_config = minilm_model_config._init_pooler_config(
123-
pooling_type=None,
124-
pooling_norm=None,
125-
pooling_returned_token_ids=None,
126-
pooling_softmax=None,
127-
pooling_step_tag_id=None)
124+
pooling_config = model_config._init_pooler_config(None)
125+
assert pooling_config is not None
128126

129-
assert minilm_pooling_config.pooling_norm
130-
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
127+
assert pooling_config.normalize
128+
assert pooling_config.pooling_type == PoolingType.MEAN.name
131129

132130

133131
@pytest.mark.skipif(current_platform.is_rocm(),
134132
reason="Xformers backend is not supported on ROCm.")
135133
def test_get_pooling_config_from_args():
136134
model_id = "sentence-transformers/all-MiniLM-L12-v2"
137-
minilm_model_config = ModelConfig(model_id,
138-
task="auto",
139-
tokenizer=model_id,
140-
tokenizer_mode="auto",
141-
trust_remote_code=False,
142-
seed=0,
143-
dtype="float16",
144-
revision=None)
145-
146-
minilm_pooling_config = minilm_model_config._init_pooler_config(
147-
pooling_type='CLS',
148-
pooling_norm=True,
149-
pooling_returned_token_ids=None,
150-
pooling_softmax=None,
151-
pooling_step_tag_id=None)
152-
153-
assert minilm_pooling_config.pooling_norm
154-
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
135+
model_config = ModelConfig(model_id,
136+
task="auto",
137+
tokenizer=model_id,
138+
tokenizer_mode="auto",
139+
trust_remote_code=False,
140+
seed=0,
141+
dtype="float16",
142+
revision=None)
143+
144+
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
145+
146+
pooling_config = model_config._init_pooler_config(override_config)
147+
assert pooling_config is not None
148+
assert asdict(pooling_config) == asdict(override_config)
155149

156150

157151
@pytest.mark.skipif(current_platform.is_rocm(),

vllm/config.py

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -112,31 +112,19 @@ class ModelConfig:
112112
the model name will be the same as `model`.
113113
limit_mm_per_prompt: Maximum number of data items per modality
114114
per prompt. Only applicable for multimodal models.
115-
override_neuron_config: Initialize non default neuron config or
116-
override default neuron config that are specific to Neuron devices,
117-
this argument will be used to configure the neuron config that
118-
can not be gathered from the vllm arguments.
119115
config_format: The config format which shall be loaded.
120116
Defaults to 'auto' which defaults to 'hf'.
121117
hf_overrides: If a dictionary, contains arguments to be forwarded to the
122118
HuggingFace config. If a callable, it is called to update the
123119
HuggingFace config.
124120
mm_processor_kwargs: Arguments to be forwarded to the model's processor
125121
for multi-modal data, e.g., image processor.
126-
pooling_type: Used to configure the pooling method in the embedding
127-
model.
128-
pooling_norm: Used to determine whether to normalize the pooled
129-
data in the embedding model.
130-
pooling_softmax: Used to determine whether to softmax the pooled
131-
data in the embedding model.
132-
pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
133-
that the score corresponding to the pooling_step_tag_id in the
134-
generated sentence should be returned. Otherwise, it returns
135-
the scores for all tokens.
136-
pooling_returned_token_ids: pooling_returned_token_ids represents a
137-
list of indices for the vocabulary dimensions to be extracted,
138-
such as the token IDs of good_token and bad_token in the
139-
math-shepherd-mistral-7b-prm model.
122+
override_neuron_config: Initialize non default neuron config or
123+
override default neuron config that are specific to Neuron devices,
124+
this argument will be used to configure the neuron config that
125+
can not be gathered from the vllm arguments.
126+
override_pooler_config: Initialize non default pooling config or
127+
override default pooling config for the embedding model.
140128
"""
141129

142130
def __init__(
@@ -166,16 +154,12 @@ def __init__(
166154
served_model_name: Optional[Union[str, List[str]]] = None,
167155
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
168156
use_async_output_proc: bool = True,
169-
override_neuron_config: Optional[Dict[str, Any]] = None,
170157
config_format: ConfigFormat = ConfigFormat.AUTO,
171158
chat_template_text_format: str = "string",
172159
hf_overrides: Optional[HfOverrides] = None,
173160
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
174-
pooling_type: Optional[str] = None,
175-
pooling_norm: Optional[bool] = None,
176-
pooling_softmax: Optional[bool] = None,
177-
pooling_step_tag_id: Optional[int] = None,
178-
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
161+
override_neuron_config: Optional[Dict[str, Any]] = None,
162+
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
179163
self.model = model
180164
self.tokenizer = tokenizer
181165
self.tokenizer_mode = tokenizer_mode
@@ -280,13 +264,7 @@ def __init__(
280264
supported_tasks, task = self._resolve_task(task, self.hf_config)
281265
self.supported_tasks = supported_tasks
282266
self.task: Final = task
283-
self.pooler_config = self._init_pooler_config(
284-
pooling_type,
285-
pooling_norm,
286-
pooling_softmax,
287-
pooling_step_tag_id,
288-
pooling_returned_token_ids,
289-
)
267+
self.pooler_config = self._init_pooler_config(override_pooler_config)
290268

291269
self._verify_quantization()
292270
self._verify_cuda_graph()
@@ -311,27 +289,21 @@ def _get_encoder_config(self):
311289

312290
def _init_pooler_config(
313291
self,
314-
pooling_type: Optional[str] = None,
315-
pooling_norm: Optional[bool] = None,
316-
pooling_softmax: Optional[bool] = None,
317-
pooling_step_tag_id: Optional[int] = None,
318-
pooling_returned_token_ids: Optional[List[int]] = None
292+
override_pooler_config: Optional["PoolerConfig"],
319293
) -> Optional["PoolerConfig"]:
294+
320295
if self.task == "embedding":
321-
pooling_config = get_pooling_config(self.model, self.revision)
322-
if pooling_config is not None:
323-
# override if user does not
324-
# specifies pooling_type and/or pooling_norm
325-
if pooling_type is None:
326-
pooling_type = pooling_config["pooling_type"]
327-
if pooling_norm is None:
328-
pooling_norm = pooling_config["normalize"]
329-
return PoolerConfig(
330-
pooling_type=pooling_type,
331-
pooling_norm=pooling_norm,
332-
pooling_softmax=pooling_softmax,
333-
pooling_step_tag_id=pooling_step_tag_id,
334-
pooling_returned_token_ids=pooling_returned_token_ids)
296+
user_config = override_pooler_config or PoolerConfig()
297+
298+
base_config = get_pooling_config(self.model, self.revision)
299+
if base_config is not None:
300+
# Only set values that are not overridden by the user
301+
for k, v in base_config.items():
302+
if getattr(user_config, k) is None:
303+
setattr(user_config, k, v)
304+
305+
return user_config
306+
335307
return None
336308

337309
def _init_attention_free(self) -> bool:
@@ -1786,13 +1758,43 @@ class MultiModalConfig:
17861758

17871759
@dataclass
17881760
class PoolerConfig:
1789-
"""Controls the behavior of pooler in embedding model"""
1761+
"""Controls the behavior of output pooling in embedding models."""
17901762

17911763
pooling_type: Optional[str] = None
1792-
pooling_norm: Optional[bool] = None
1793-
pooling_softmax: Optional[bool] = None
1794-
pooling_step_tag_id: Optional[int] = None
1795-
pooling_returned_token_ids: Optional[List[int]] = None
1764+
"""
1765+
The pooling method of the embedding model. This should be a key in
1766+
:class:`vllm.model_executor.layers.pooler.PoolingType`.
1767+
"""
1768+
1769+
normalize: Optional[bool] = None
1770+
"""
1771+
Whether to normalize the pooled outputs. Usually, this should be set to
1772+
``True`` for embedding outputs.
1773+
"""
1774+
1775+
softmax: Optional[bool] = None
1776+
"""
1777+
Whether to apply softmax to the pooled outputs. Usually, this should be set
1778+
to ``True`` for classification outputs.
1779+
"""
1780+
1781+
step_tag_id: Optional[int] = None
1782+
"""
1783+
If set, only the score corresponding to the ``step_tag_id`` in the
1784+
generated sentence should be returned. Otherwise, the scores for all tokens
1785+
are returned.
1786+
"""
1787+
1788+
returned_token_ids: Optional[List[int]] = None
1789+
"""
1790+
A list of indices for the vocabulary dimensions to be extracted,
1791+
such as the token IDs of ``good_token`` and ``bad_token`` in the
1792+
``math-shepherd-mistral-7b-prm`` model.
1793+
"""
1794+
1795+
@staticmethod
1796+
def from_json(json_str: str) -> "PoolerConfig":
1797+
return PoolerConfig(**json.loads(json_str))
17961798

17971799

17981800
_STR_DTYPE_TO_TORCH_DTYPE = {

0 commit comments

Comments
 (0)