Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion docs/design/logits_processors.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
changes to the batch makeup.
"""
raise NotImplementedError


@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.

Raise ValueError for invalid ones.
"""
return None

```

A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
Expand All @@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
* Use the `BatchUpdate` members to update logits processor internal state
* **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.

* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor.
* When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments.

### `BatchUpdate` data structure

The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):
Expand Down
3 changes: 3 additions & 0 deletions docs/features/custom_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t

Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.

!!! note
Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise invalid custom arguments can cause unexpected behaviour.

## Offline Custom Arguments

Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
Expand Down
42 changes: 33 additions & 9 deletions docs/features/custom_logitsprocs.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s

Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:

* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor.
* When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments.
* **Note:** it's important to implement `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor.

* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
* `vllm_config`: engine configuration data structure
* `device`: hardware accelerator device info
Expand Down Expand Up @@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
self.req_info: dict[int, int] = {}
Expand All @@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
self.validate_params(params)
if params.extra_args and (target_token :=
params.extra_args.get("target_token")):
self.req_info[index] = target_token
Expand Down Expand Up @@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
logits[rows, cols] = values_to_keep

return logits

```

In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
Expand All @@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[

While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.

You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.):

* Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate request's sampling parameters.

* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.

* Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:

??? code "Example of Wrapping a Request-Level Logits Processor"

Expand Down Expand Up @@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} is not int"
)

def is_argmax_invariant(self) -> bool:
return False

Expand All @@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)
```

Expand Down
19 changes: 17 additions & 2 deletions examples/offline_inference/logits_processor/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class object.
------------------------------------------------------------
"""

from typing import Any

import torch

from vllm import LLM, SamplingParams
Expand All @@ -48,6 +50,16 @@ class object.
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
Expand All @@ -57,14 +69,17 @@ def is_argmax_invariant(self) -> bool:
return False

def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")

process_dict_updates(
self.req_info,
batch_update,
# This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the
# request.
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
lambda params, _, __: extract_extra_arg(params),
)

def apply(self, logits: torch.Tensor) -> torch.Tensor:
Expand Down
15 changes: 8 additions & 7 deletions examples/offline_inference/logits_processor/custom_req.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")

def is_argmax_invariant(self) -> bool:
return False

Expand All @@ -101,13 +109,6 @@ def new_req_logits_processor(
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


Expand Down
15 changes: 8 additions & 7 deletions examples/offline_inference/logits_processor/custom_req_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
Expand Down Expand Up @@ -113,13 +121,6 @@ def new_req_logits_processor(
is None
):
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


Expand Down
29 changes: 29 additions & 0 deletions tests/v1/logits_processors/test_custom_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):

# Alternate whether to activate dummy logitproc for each request
use_dummy_logitproc = not use_dummy_logitproc


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_invalid_custom_logitsproc_arg(
client: openai.AsyncOpenAI, model_name: str
):
"""Test that request with invalid custom logitsproc is rejected"""

prompt = "Hello, my name is"
# Pass invalid (non-int) target_token value to dummy logits processor
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
"extra_body": {
"vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
},
}

with pytest.raises(openai.OpenAIError) as exc_info:
await client.completions.create(
model=model_name,
prompt=prompt,
**request_keyword_args,
)

assert "is not int" in str(exc_info.value)
17 changes: 15 additions & 2 deletions tests/v1/logits_processors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ class CustomLogitprocSource(Enum):
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)

def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
Expand All @@ -62,11 +72,14 @@ def is_argmax_invariant(self) -> bool:
return False

def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")

process_dict_updates(
self.req_info,
batch_update,
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
lambda params, _, __: extract_extra_arg(params),
)

def apply(self, logits: torch.Tensor) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,10 +777,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
description="KVTransfer parameters used for disaggregated serving.",
)

vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
Expand Down
8 changes: 8 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
validate_request_params,
)
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters

logger = init_logger(__name__)

Expand Down Expand Up @@ -110,6 +111,9 @@ def __init__(
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs

# set up logits processors
self.logits_processors = self.model_config.logits_processors

# set up reasoning parser
self.reasoning_parser = self._get_reasoning_parser(
reasoning_parser_name=reasoning_parser
Expand Down Expand Up @@ -294,6 +298,10 @@ async def create_chat_completion(
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)

self._log_inputs(
request_id,
Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters

logger = init_logger(__name__)

Expand All @@ -59,6 +60,10 @@ def __init__(
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)

# set up logits processors
self.logits_processors = self.model_config.logits_processors

self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
Expand Down Expand Up @@ -181,6 +186,10 @@ async def create_completion(
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)

request_id_item = f"{request_id}-{i}"

Expand Down
Loading