From b0ed8d45533628d6915c0960d6abe6bc8a8a2dbf Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 15:44:00 +0800 Subject: [PATCH 01/18] fix deepseek ocr chat template Signed-off-by: Isotr0py --- vllm/transformers_utils/chat_templates/registry.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 3bdbe1d0a67b..7a5474d89e9f 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -29,12 +29,21 @@ def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Path | return CHAT_TEMPLATES_DIR / "template_chatml.jinja" +def _get_deepseek_vl_v2_chat_template_fallback( + tokenizer_name_or_path: str, +) -> Path | None: + if tokenizer_name_or_path.endswith("-OCR"): + return CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja" + + return CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja" + + _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", - "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", + "deepseek_vl_v2": _get_deepseek_vl_v2_chat_template_fallback, "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", From 8f538eb7699058ed4d197e33db4ed823868be2eb Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 16:10:42 +0800 Subject: [PATCH 02/18] fix Signed-off-by: Isotr0py --- vllm/entrypoints/openai/protocol.py | 4 ++-- vllm/transformers_utils/chat_templates/registry.py | 11 +---------- vllm/transformers_utils/configs/deepseek_vl2.py | 6 ++++++ 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9782641296d6..3f9bc1f8cd7f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -769,7 +769,7 @@ class ChatCompletionRequest(OpenAIBaseModel): description="KVTransfer parameters used for disaggregated serving.", ) - vllm_xargs: dict[str, str | int | float] | None = Field( + vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( default=None, description=( "Additional request parameters with string or " @@ -1266,7 +1266,7 @@ class CompletionRequest(OpenAIBaseModel): description="KVTransfer parameters used for disaggregated serving.", ) - vllm_xargs: dict[str, str | int | float] | None = Field( + vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( default=None, description=( "Additional request parameters with string or " diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 7a5474d89e9f..3bdbe1d0a67b 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -29,21 +29,12 @@ def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Path | return CHAT_TEMPLATES_DIR / "template_chatml.jinja" -def _get_deepseek_vl_v2_chat_template_fallback( - tokenizer_name_or_path: str, -) -> Path | None: - if tokenizer_name_or_path.endswith("-OCR"): - return CHAT_TEMPLATES_DIR / 
"template_deepseek_ocr.jinja" - - return CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja" - - _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", - "deepseek_vl_v2": _get_deepseek_vl_v2_chat_template_fallback, + "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py index 7abfe6229842..23b913157d6d 100644 --- a/vllm/transformers_utils/configs/deepseek_vl2.py +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -218,3 +218,9 @@ def __init__( self.global_view_pos = global_view_pos self.candidate_resolutions = candidate_resolutions self.vocab_size = self.text_config.vocab_size + + # update model_type for OCR model + if "DeepseekOCRForCausalLM" in ( + self.architectures or kwargs.get("architectures", []) + ): + self.model_type = "deepseek_ocr" From 980912556703fecfe1c4d239aafcd33b02d86026 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 16:13:11 +0800 Subject: [PATCH 03/18] revert Signed-off-by: Isotr0py --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3f9bc1f8cd7f..ae8bec1463e3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1266,7 +1266,7 @@ class CompletionRequest(OpenAIBaseModel): description="KVTransfer parameters used for disaggregated serving.", ) - vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, description=( "Additional request parameters with string or " From 3dce5a1adc87fa52fdf5d4b6e0d391b05daeb7c0 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 16:39:07 +0800 Subject: [PATCH 04/18] update description Signed-off-by: Isotr0py --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ae8bec1463e3..b95e943e7765 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -772,7 +772,7 @@ class ChatCompletionRequest(OpenAIBaseModel): vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( default=None, description=( - "Additional request parameters with string or " + "Additional request parameters with (list of) string or " "numeric values, used by custom extensions." 
), ) From 2618941db06a15bfd54713dbf8a989cfa352c24a Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 19:40:41 +0800 Subject: [PATCH 05/18] refactor Signed-off-by: Isotr0py --- vllm/entrypoints/openai/logits_processors.py | 20 ++++++++++++-- vllm/entrypoints/openai/serving_chat.py | 10 +++++++ vllm/model_executor/models/deepseek_ocr.py | 29 ++++++++++++++------ vllm/v1/sample/logits_processor/__init__.py | 12 ++++++++ 4 files changed, 60 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index dedbc23ec83f..efa3768b9c8f 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from functools import lru_cache, partial import torch -from vllm.sampling_params import LogitsProcessor +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + _load_custom_logitsprocs, +) class AllowedTokenIdsLogitsProcessor: @@ -90,3 +94,15 @@ def get_logits_processors( ) return logits_processors + + +def validate_logits_processors_parameters( + logits_processors: Sequence[str | LogitsProcessor] | None, + sampling_params: SamplingParams, +): + if logits_processors is None: + return None + + for logits_procs in _load_custom_logitsprocs(logits_processors): # type: ignore[arg-type] + if isinstance(logits_procs, AdapterLogitsProcessor): + logits_procs.validate_params(sampling_params) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3bf887c659dc..09ca38ad2de6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -32,6 +32,9 @@ render_for_completion, ) from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.logits_processors import ( + validate_logits_processors_parameters, +) from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -110,6 +113,9 @@ def __init__( self.trust_request_chat_template = trust_request_chat_template self.enable_log_outputs = enable_log_outputs + # set up logits processors + self.logits_processors = self.model_config.logits_processors + # set up reasoning parser self.reasoning_parser = self._get_reasoning_parser( reasoning_parser_name=reasoning_parser @@ -291,6 +297,10 @@ async def create_chat_completion( self.model_config.logits_processor_pattern, self.default_sampling_params, ) + validate_logits_processors_parameters( + self.logits_processors, + sampling_params, + ) self._log_inputs( request_id, diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index fa24db456af4..6ad0f13cf558 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -139,17 +139,18 @@ def __init__( def is_argmax_invariant(self) -> bool: return True - def new_req_logits_processor( - self, - params: SamplingParams, - ) -> RequestLogitsProcessor | None: + @classmethod + def validate_params(cls, params: SamplingParams): ngram_size = params.extra_args and params.extra_args.get("ngram_size") window_size = params.extra_args and params.extra_args.get("window_size", 100) 
whitelist_token_ids = params.extra_args and params.extra_args.get( "whitelist_token_ids", None ) + # if ngram_size is not provided, skip validation because the processor + # will not be used. if ngram_size is None: return None + if not isinstance(ngram_size, int) or ngram_size <= 0: raise ValueError( f"`ngram_size` has to be a strictly positive integer, got {ngram_size}." @@ -163,13 +164,23 @@ def new_req_logits_processor( whitelist_token_ids, Iterable ): raise ValueError( - "`whitelist_token_ids` has to be a set of integers, " + "`whitelist_token_ids` has to be a sequence of integers, " f"got {whitelist_token_ids}." ) - else: - whitelist_token_ids = ( - set(whitelist_token_ids) if whitelist_token_ids else None - ) + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> RequestLogitsProcessor | None: + ngram_size = params.extra_args and params.extra_args.get("ngram_size") + window_size = params.extra_args and params.extra_args.get("window_size", 100) + whitelist_token_ids = params.extra_args and params.extra_args.get( + "whitelist_token_ids", None + ) + if ngram_size is None: + return None + + whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None return NoRepeatNGramLogitsProcessor( ngram_size=ngram_size, window_size=window_size, diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 566de5bcda77..59a40e2bcadc 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -245,6 +245,18 @@ def __init__( # was when the partial was created. self.req_info: dict[int, partial[torch.Tensor]] = {} + @classmethod + def validate_params(cls, params: SamplingParams): + """Validate sampling params for this logits processor. + + Raise ValueError if params are invalid. + + Args: + params: request sampling params + + """ + raise NotImplementedError + @abstractmethod def new_req_logits_processor( self, From 14cd026f1c1b88305ff888842ac9a8de519bbd0d Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 20:09:06 +0800 Subject: [PATCH 06/18] update doc Signed-off-by: Isotr0py --- docs/features/custom_logitsprocs.md | 21 +++++++++++-------- .../logits_processor/custom_req.py | 15 ++++++------- .../logits_processor/custom_req_init.py | 15 ++++++------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index b8ad53863cd7..49a9fed030d2 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -180,7 +180,7 @@ RequestLogitsProcessor = Union[ While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above. -You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) 
Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: +You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate input sampling parameters and raise error for invalid parameters. Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: ??? code "Example of Wrapping a Request-Level Logits Processor" @@ -223,6 +223,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro def is_argmax_invariant(self) -> bool: return False + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} is not int" + ) + def new_req_logits_processor( self, params: SamplingParams, @@ -240,18 +250,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro Returns: `Callable` request logits processor, or None """ - target_token: Optional[Any] = params.extra_args and params.extra_args.get( + target_token: Any | None = params.extra_args and params.extra_args.get( "target_token" ) if target_token is None: return None - if not isinstance(target_token, int): - logger.warning( - "target_token value %s is not int; not applying logits" - " processor to request.", - target_token, - ) - return None return DummyPerReqLogitsProcessor(target_token) ``` diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py index 87cd7473fa9f..894d2ee2b285 100644 --- a/examples/offline_inference/logits_processor/custom_req.py +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -79,6 +79,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): def is_argmax_invariant(self) -> bool: return False + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError(f"target_token value {target_token} is not int") + def new_req_logits_processor( self, params: SamplingParams, @@ -101,13 +109,6 @@ def new_req_logits_processor( ) if target_token is None: return None - if not isinstance(target_token, int): - logger.warning( - "target_token value %s is not int; not applying logits" - " processor to request.", - target_token, - ) - return None return DummyPerReqLogitsProcessor(target_token) diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py index 3bb82a786040..60d754211caa 100644 --- 
a/examples/offline_inference/logits_processor/custom_req_init.py +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -86,6 +86,14 @@ def __init__( def is_argmax_invariant(self) -> bool: return False + @classmethod + def validate_params(cls, params: SamplingParams): + target_token = params.extra_args and params.extra_args.get("target_token") + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"`target_token` has to be an integer, got {target_token}." + ) + def new_req_logits_processor( self, params: SamplingParams, @@ -113,13 +121,6 @@ def new_req_logits_processor( is None ): return None - if not isinstance(target_token, int): - logger.warning( - "target_token value %s is not int; not applying logits" - " processor to request.", - target_token, - ) - return None return DummyPerReqLogitsProcessor(target_token) From 865d10b209e97548471296cbce913c000f3a1f42 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 27 Oct 2025 20:13:15 +0800 Subject: [PATCH 07/18] update doc Signed-off-by: Isotr0py --- docs/features/custom_logitsprocs.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index 49a9fed030d2..81b8369ce3f1 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -180,7 +180,10 @@ RequestLogitsProcessor = Union[ While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above. -You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate input sampling parameters and raise error for invalid parameters. Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: +You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.): +- Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate request's sampling parameters. +- Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. +- Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: ??? 
code "Example of Wrapping a Request-Level Logits Processor" From 4abab784cff1aac3212da5f0cd373c64396f92f3 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 28 Oct 2025 23:08:05 +0800 Subject: [PATCH 08/18] move validate_params as abstract method for LogitsProcessor Signed-off-by: Isotr0py --- vllm/v1/sample/logits_processor/__init__.py | 12 ------------ vllm/v1/sample/logits_processor/builtin.py | 14 ++++++++++++++ vllm/v1/sample/logits_processor/interface.py | 9 +++++++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 59a40e2bcadc..566de5bcda77 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -245,18 +245,6 @@ def __init__( # was when the partial was created. self.req_info: dict[int, partial[torch.Tensor]] = {} - @classmethod - def validate_params(cls, params: SamplingParams): - """Validate sampling params for this logits processor. - - Raise ValueError if params are invalid. - - Args: - params: request sampling params - - """ - raise NotImplementedError - @abstractmethod def new_req_logits_processor( self, diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 4ee7dc2880c8..ecb7e549960c 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -46,6 +46,12 @@ def is_argmax_invariant(self) -> bool: """Min-p never impacts greedy sampling""" return True + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + min_p = sampling_params.min_p + if min_p is not None and (min_p < 0.0 or min_p > 1.0): + raise ValueError("min_p should be in the range [0.0, 1.0]") + def get_min_p_by_index(self, index: int) -> float: return float(self.min_p_cpu[index]) @@ -131,6 +137,10 @@ def is_argmax_invariant(self) -> bool: outcome of argmax in greedy sampling.""" return False + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + pass + def update_state(self, batch_update: BatchUpdate | None): needs_update = process_dict_updates( self.biases, batch_update, lambda params, _, __: params.logit_bias or None @@ -183,6 +193,10 @@ def is_argmax_invariant(self) -> bool: of the argmax operation in greedy sampling.""" return False + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + pass + @staticmethod def add_request( params: SamplingParams, _: list[int] | None, output_tok_ids: list[int] diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index efa0f62ad6e1..6fd50f71410e 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -64,6 +64,15 @@ def __init__( ) -> None: raise NotImplementedError + @classmethod + @abstractmethod + def validate_params(cls, sampling_params: SamplingParams): + """Validate sampling params for this logits processor. + + Raise ValueError for invalid ones. + """ + raise NotImplementedError + @abstractmethod def apply(self, logits: torch.Tensor) -> torch.Tensor: """Apply LogitsProcessor to batch logits tensor. 
From b70bdaf7ea7f0f0054afcb6a3bf122b9e43806e7 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 30 Oct 2025 01:39:54 +0800 Subject: [PATCH 09/18] guard CUDA initialization Signed-off-by: Isotr0py --- vllm/entrypoints/openai/serving_chat.py | 4 +--- vllm/utils/torch_utils.py | 20 ++++++++++++++++++++ vllm/v1/sample/logits_processor/__init__.py | 18 ++++++++++++++++++ vllm/v1/sample/logits_processor/builtin.py | 14 -------------- vllm/v1/sample/logits_processor/interface.py | 17 ++++++++--------- 5 files changed, 47 insertions(+), 26 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 09ca38ad2de6..9c9c22e490f3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -32,9 +32,6 @@ render_for_completion, ) from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.logits_processors import ( - validate_logits_processors_parameters, -) from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -74,6 +71,7 @@ validate_request_params, ) from vllm.utils.collection_utils import as_list +from vllm.v1.sample.logits_processor import validate_logits_processors_parameters logger = init_logger(__name__) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index adcacb34cb7c..ea01ecd2c088 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import importlib.metadata +import os import threading from collections.abc import Callable, Collection from functools import lru_cache @@ -68,6 +69,25 @@ def set_default_torch_num_threads(num_threads: int): torch.set_num_threads(old_num_threads) +@contextlib.contextmanager +def guard_cuda_initialization(): + """Avoid unexpected CUDA initialization.""" + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + yield + return + + had_key = "CUDA_VISIBLE_DEVICES" in os.environ + old_value = os.environ.get("CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = "" + yield + if had_key: + os.environ["CUDA_VISIBLE_DEVICES"] = old_value + else: + os.environ.pop("CUDA_VISIBLE_DEVICES") + + def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 566de5bcda77..63cb58913943 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor from vllm.sampling_params import SamplingParams +from vllm.utils.torch_utils import guard_cuda_initialization from vllm.v1.sample.logits_processor.builtin import ( LogitBiasLogitsProcessor, MinPLogitsProcessor, @@ -206,6 +207,23 @@ def build_logitsprocs( ) +def validate_logits_processors_parameters( + logits_processors: Sequence[str | type[LogitsProcessor]] | None, + sampling_params: SamplingParams, +): + if logits_processors is None: + return None + + # we don't expect any CUDA initialization when loading custom logitsprocs, + # hide all visible GPUs here to guarantee process. + # TODO(Isotr0py): Make the error message more informative if CUDA is + # attempted to be initialized here. Currently, only an internal server + # error is raised. 
+ with guard_cuda_initialization(): + for logits_procs in _load_custom_logitsprocs(logits_processors): + logits_procs.validate_params(sampling_params) + + class AdapterLogitsProcessor(LogitsProcessor): """Wrapper for per-request logits processors diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index ecb7e549960c..4ee7dc2880c8 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -46,12 +46,6 @@ def is_argmax_invariant(self) -> bool: """Min-p never impacts greedy sampling""" return True - @classmethod - def validate_params(cls, sampling_params: SamplingParams): - min_p = sampling_params.min_p - if min_p is not None and (min_p < 0.0 or min_p > 1.0): - raise ValueError("min_p should be in the range [0.0, 1.0]") - def get_min_p_by_index(self, index: int) -> float: return float(self.min_p_cpu[index]) @@ -137,10 +131,6 @@ def is_argmax_invariant(self) -> bool: outcome of argmax in greedy sampling.""" return False - @classmethod - def validate_params(cls, sampling_params: SamplingParams): - pass - def update_state(self, batch_update: BatchUpdate | None): needs_update = process_dict_updates( self.biases, batch_update, lambda params, _, __: params.logit_bias or None @@ -193,10 +183,6 @@ def is_argmax_invariant(self) -> bool: of the argmax operation in greedy sampling.""" return False - @classmethod - def validate_params(cls, sampling_params: SamplingParams): - pass - @staticmethod def add_request( params: SamplingParams, _: list[int] | None, output_tok_ids: list[int] diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 6fd50f71410e..670217528cf4 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -64,15 +64,6 @@ def __init__( ) -> None: raise NotImplementedError - @classmethod - @abstractmethod - def validate_params(cls, sampling_params: SamplingParams): - """Validate sampling params for this logits processor. - - Raise ValueError for invalid ones. - """ - raise NotImplementedError - @abstractmethod def apply(self, logits: torch.Tensor) -> torch.Tensor: """Apply LogitsProcessor to batch logits tensor. @@ -105,3 +96,11 @@ def update_state( to the batch makeup. """ raise NotImplementedError + + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + """Validate sampling params for this logits processor. + + Raise ValueError for invalid ones. + """ + return None From 7b7473a285091fd8341bda1e98fed9ae27e33093 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 16:09:36 +0800 Subject: [PATCH 10/18] better error messages and logging Signed-off-by: Isotr0py --- vllm/utils/torch_utils.py | 9 ++++++++- vllm/v1/sample/logits_processor/__init__.py | 18 +++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index ea01ecd2c088..eb703d4d4ae4 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -81,7 +81,14 @@ def guard_cuda_initialization(): had_key = "CUDA_VISIBLE_DEVICES" in os.environ old_value = os.environ.get("CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = "" - yield + try: + yield + except Exception as e: + if "No CUDA GPUs are available" in str(e): + err_msg = "CUDA initialization is blocked." 
+ else: + err_msg = str(e) + raise RuntimeError(err_msg) from e if had_key: os.environ["CUDA_VISIBLE_DEVICES"] = old_value else: diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 63cb58913943..2cf8556b8c6d 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -73,8 +73,10 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: entrypoint.name, entrypoint.value, ) - classes.append(entrypoint.load()) + with guard_cuda_initialization(): + classes.append(entrypoint.load()) except Exception as e: + logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e) raise RuntimeError( f"Failed to load LogitsProcessor plugin {entrypoint}" ) from e @@ -127,8 +129,15 @@ def _load_logitsprocs_by_fqcns( try: # Load module - module = importlib.import_module(module_path) + with guard_cuda_initialization(): + module = importlib.import_module(module_path) except Exception as e: + logger.error( + "Failed to load {%s}th LogitsProcessor plugin {%s}: %s", + ldx, + logitproc, + e, + ) raise RuntimeError( f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}" ) from e @@ -219,9 +228,8 @@ def validate_logits_processors_parameters( # TODO(Isotr0py): Make the error message more informative if CUDA is # attempted to be initialized here. Currently, only an internal server # error is raised. - with guard_cuda_initialization(): - for logits_procs in _load_custom_logitsprocs(logits_processors): - logits_procs.validate_params(sampling_params) + for logits_procs in _load_custom_logitsprocs(logits_processors): + logits_procs.validate_params(sampling_params) class AdapterLogitsProcessor(LogitsProcessor): From b1b365fbd7b5a2ad8946b2d94306a98b7d28808b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 16:10:14 +0800 Subject: [PATCH 11/18] oops Signed-off-by: Isotr0py --- vllm/v1/sample/logits_processor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 2cf8556b8c6d..0f69dbd60e85 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -133,7 +133,7 @@ def _load_logitsprocs_by_fqcns( module = importlib.import_module(module_path) except Exception as e: logger.error( - "Failed to load {%s}th LogitsProcessor plugin {%s}: %s", + "Failed to load %sth LogitsProcessor plugin %s: %s", ldx, logitproc, e, From 843866fdb1b432e356f7b032c9336f7a87ae3203 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 17:00:22 +0800 Subject: [PATCH 12/18] update doc and example Signed-off-by: Isotr0py --- docs/design/logits_processors.md | 14 +++++++++- docs/features/custom_arguments.md | 3 +++ docs/features/custom_logitsprocs.md | 26 ++++++++++++++++--- .../logits_processor/custom.py | 17 ++++++++++-- tests/v1/logits_processors/utils.py | 17 ++++++++++-- 5 files changed, 68 insertions(+), 9 deletions(-) diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md index da61d2a85e46..acf7fc245462 100644 --- a/docs/design/logits_processors.md +++ b/docs/design/logits_processors.md @@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus changes to the batch makeup. """ raise NotImplementedError - + + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + """Validate sampling params for this logits processor. 
+ + Raise ValueError for invalid ones. + """ + return None + ``` A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods: @@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) * Use the `BatchUpdate` members to update logits processor internal state * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. +* `validate_params(cls, sampling_params: SamplingParams)`: + * Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor. + * When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments. + ### `BatchUpdate` data structure The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`): diff --git a/docs/features/custom_arguments.md b/docs/features/custom_arguments.md index 74ed40835b4d..7a650d0e79c2 100644 --- a/docs/features/custom_arguments.md +++ b/docs/features/custom_arguments.md @@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code. +!!! note + Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise invalid custom arguments can cause unexpected behaviour. + ## Offline Custom Arguments Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`: diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index 81b8369ce3f1..8fe55325b7cf 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -38,6 +38,11 @@ Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsPr * Use the `BatchUpdate` members to update logits processor internal state * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. +* `validate_params(cls, sampling_params: SamplingParams)`: + * Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor. + * When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments. + * **Note:** it's important to implent `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor. + ### How the vLLM engine builds the `BatchUpdate` data structure !!! important @@ -118,6 +123,7 @@ The contrived example below implements a custom logits processor which consumes # Process added requests. 
for index, params, _, _ in batch_update.added: assert params is not None + self.validate_params(params) if params.extra_args and (target_token := params.extra_args.get("target_token")): self.req_info[index] = target_token @@ -157,6 +163,15 @@ The contrived example below implements a custom logits processor which consumes logits[rows, cols] = values_to_keep return logits + + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: int | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError(f"target_token value {target_token} is not int") + ``` In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor. @@ -180,10 +195,13 @@ RequestLogitsProcessor = Union[ While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above. -You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.): -- Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate request's sampling parameters. -- Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. -- Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: +You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.): + +* Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate request's sampling parameters. + +* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. + +* Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: ??? 
code "Example of Wrapping a Request-Level Logits Processor" diff --git a/examples/offline_inference/logits_processor/custom.py b/examples/offline_inference/logits_processor/custom.py index 72e7ce24d7cc..73730f72c03a 100644 --- a/examples/offline_inference/logits_processor/custom.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -57,14 +57,17 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: BatchUpdate | None): + def extract_extra_arg(params: SamplingParams) -> int | None: + self.validate_params(params) + return params.extra_args and params.extra_args.get("target_token") + process_dict_updates( self.req_info, batch_update, # This function returns the LP's per-request state based on the # request details, or None if this LP does not apply to the # request. - lambda params, _, __: params.extra_args - and (params.extra_args.get("target_token")), + lambda params, _, __: extract_extra_arg(params), ) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -86,6 +89,16 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: int | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} {type(target_token)} is not int" + ) + # Sample prompts. prompts = [ diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index 36cffebb3b45..39175d704746 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -62,11 +62,14 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: BatchUpdate | None): + def extract_extra_arg(params: SamplingParams) -> int | None: + self.validate_params(params) + return params.extra_args and params.extra_args.get("target_token") + process_dict_updates( self.req_info, batch_update, - lambda params, _, __: params.extra_args - and (params.extra_args.get("target_token")), + lambda params, _, __: extract_extra_arg(params), ) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -88,6 +91,16 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: int | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} {type(target_token)} is not int" + ) + """Dummy module with dummy logitproc class""" dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE) From 8c085cf9ea51c6e9b4670d8b6d492d6e49f925c3 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 17:52:18 +0800 Subject: [PATCH 13/18] add test Signed-off-by: Isotr0py --- .../logits_processors/test_custom_online.py | 29 +++++++++++++++++++ vllm/entrypoints/openai/serving_completion.py | 9 ++++++ vllm/v1/sample/logits_processor/__init__.py | 8 ----- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index 0d902b46bed5..3e0bb02ed68b 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): # Alternate whether to activate dummy 
logitproc for each request use_dummy_logitproc = not use_dummy_logitproc + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_invalid_custom_logitsproc_arg( + client: openai.AsyncOpenAI, model_name: str +): + """Test that request with invalid custom logitsproc is rejected""" + + prompt = "Hello, my name is" + # Pass invalid (non-int) target_token value to dummy logits processor + request_keyword_args: dict[str, Any] = { + **api_keyword_args, + "extra_body": { + "vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"} + }, + } + + with pytest.raises(openai.OpenAIError) as exc_info: + await client.completions.create( + model=model_name, + prompt=prompt, + **request_keyword_args, + ) + + assert "is not int" in str(exc_info.value) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 14dbdd4cb4c7..a114b77ebc16 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -36,6 +36,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.async_utils import merge_async_iterators from vllm.utils.collection_utils import as_list +from vllm.v1.sample.logits_processor import validate_logits_processors_parameters logger = init_logger(__name__) @@ -59,6 +60,10 @@ def __init__( return_tokens_as_token_ids=return_tokens_as_token_ids, log_error_stack=log_error_stack, ) + + # set up logits processors + self.logits_processors = self.model_config.logits_processors + self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = self.model_config.get_diff_sampling_param() self.enable_force_include_usage = enable_force_include_usage @@ -181,6 +186,10 @@ async def create_completion( self.model_config.logits_processor_pattern, self.default_sampling_params, ) + validate_logits_processors_parameters( + self.logits_processors, + sampling_params, + ) request_id_item = f"{request_id}-{i}" diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 0f69dbd60e85..eb537eae6c90 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -220,14 +220,6 @@ def validate_logits_processors_parameters( logits_processors: Sequence[str | type[LogitsProcessor]] | None, sampling_params: SamplingParams, ): - if logits_processors is None: - return None - - # we don't expect any CUDA initialization when loading custom logitsprocs, - # hide all visible GPUs here to guarantee process. - # TODO(Isotr0py): Make the error message more informative if CUDA is - # attempted to be initialized here. Currently, only an internal server - # error is raised. 
for logits_procs in _load_custom_logitsprocs(logits_processors): logits_procs.validate_params(sampling_params) From 2f1a66b697a9fe7b485411d8c5300cbf2e05488b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 20:39:12 +0800 Subject: [PATCH 14/18] move classmethod to top of class Signed-off-by: Isotr0py --- docs/features/custom_logitsprocs.md | 32 +++++++++---------- .../logits_processor/custom.py | 22 +++++++------ .../logits_processor/custom_req.py | 6 ++-- .../logits_processor/custom_req_init.py | 16 +++++----- tests/v1/logits_processors/utils.py | 20 ++++++------ vllm/model_executor/models/deepseek_ocr.py | 16 +++++----- vllm/v1/sample/logits_processor/interface.py | 16 +++++----- 7 files changed, 65 insertions(+), 63 deletions(-) diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index 8fe55325b7cf..52fcc44efacc 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods: +* `validate_params(cls, sampling_params: SamplingParams)`: + * Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor. + * When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments. + * **Note:** it's important to implement `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor. + * `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)` * `vllm_config`: engine configuration data structure * `device`: hardware accelerator device info @@ -38,11 +43,6 @@ Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsPr * Use the `BatchUpdate` members to update logits processor internal state * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. -* `validate_params(cls, sampling_params: SamplingParams)`: - * Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor. - * When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments. - * **Note:** it's important to implent `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor. - ### How the vLLM engine builds the `BatchUpdate` data structure !!! 
important @@ -108,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: int | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError(f"target_token value {target_token} is not int") + def __init__(self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool): self.req_info: dict[int, int] = {} @@ -164,14 +172,6 @@ The contrived example below implements a custom logits processor which consumes return logits - @classmethod - def validate_params(cls, params: SamplingParams): - target_token: int | None = params.extra_args and params.extra_args.get( - "target_token" - ) - if target_token is not None and not isinstance(target_token, int): - raise ValueError(f"target_token value {target_token} is not int") - ``` In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor. @@ -241,9 +241,6 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro """Example of wrapping a fake request-level logit processor to create a batch-level logits processor""" - def is_argmax_invariant(self) -> bool: - return False - @classmethod def validate_params(cls, params: SamplingParams): target_token: Any | None = params.extra_args and params.extra_args.get( @@ -254,6 +251,9 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro f"target_token value {target_token} is not int" ) + def is_argmax_invariant(self) -> bool: + return False + def new_req_logits_processor( self, params: SamplingParams, diff --git a/examples/offline_inference/logits_processor/custom.py b/examples/offline_inference/logits_processor/custom.py index 73730f72c03a..ce000872dc96 100644 --- a/examples/offline_inference/logits_processor/custom.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -33,6 +33,8 @@ class object. ------------------------------------------------------------ """ +from typing import Any + import torch from vllm import LLM, SamplingParams @@ -48,6 +50,16 @@ class object. class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} {type(target_token)} is not int" + ) + def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): @@ -89,16 +101,6 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits - @classmethod - def validate_params(cls, params: SamplingParams): - target_token: int | None = params.extra_args and params.extra_args.get( - "target_token" - ) - if target_token is not None and not isinstance(target_token, int): - raise ValueError( - f"target_token value {target_token} {type(target_token)} is not int" - ) - # Sample prompts. 
prompts = [ diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py index 894d2ee2b285..5763fff5410d 100644 --- a/examples/offline_inference/logits_processor/custom_req.py +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -76,9 +76,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): """Example of wrapping a fake request-level logit processor to create a batch-level logits processor""" - def is_argmax_invariant(self) -> bool: - return False - @classmethod def validate_params(cls, params: SamplingParams): target_token: Any | None = params.extra_args and params.extra_args.get( @@ -87,6 +84,9 @@ def validate_params(cls, params: SamplingParams): if target_token is not None and not isinstance(target_token, int): raise ValueError(f"target_token value {target_token} is not int") + def is_argmax_invariant(self) -> bool: + return False + def new_req_logits_processor( self, params: SamplingParams, diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py index 60d754211caa..acd2c47f230f 100644 --- a/examples/offline_inference/logits_processor/custom_req_init.py +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): """Example of overriding the wrapper class `__init__()` in order to utilize info about the device type""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token = params.extra_args and params.extra_args.get("target_token") + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"`target_token` has to be an integer, got {target_token}." + ) + def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): @@ -86,14 +94,6 @@ def __init__( def is_argmax_invariant(self) -> bool: return False - @classmethod - def validate_params(cls, params: SamplingParams): - target_token = params.extra_args and params.extra_args.get("target_token") - if target_token is not None and not isinstance(target_token, int): - raise ValueError( - f"`target_token` has to be an integer, got {target_token}." 
- ) - def new_req_logits_processor( self, params: SamplingParams, diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index 39175d704746..b8548bc31955 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -52,6 +52,16 @@ class CustomLogitprocSource(Enum): class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: int | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} {type(target_token)} is not int" + ) + def __init__( self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool ): @@ -91,16 +101,6 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits - @classmethod - def validate_params(cls, params: SamplingParams): - target_token: int | None = params.extra_args and params.extra_args.get( - "target_token" - ) - if target_token is not None and not isinstance(target_token, int): - raise ValueError( - f"target_token value {target_token} {type(target_token)} is not int" - ) - """Dummy module with dummy logitproc class""" dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE) diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index 6ad0f13cf558..7bea143408df 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -131,14 +131,6 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor): """Example of overriding the wrapper class `__init__()` in order to utilize info about the device type""" - def __init__( - self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool - ): - super().__init__(vllm_config, device, is_pin_memory) - - def is_argmax_invariant(self) -> bool: - return True - @classmethod def validate_params(cls, params: SamplingParams): ngram_size = params.extra_args and params.extra_args.get("ngram_size") @@ -168,6 +160,14 @@ def validate_params(cls, params: SamplingParams): f"got {whitelist_token_ids}." ) + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + + def is_argmax_invariant(self) -> bool: + return True + def new_req_logits_processor( self, params: SamplingParams, diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 670217528cf4..0cbfb187878a 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -58,6 +58,14 @@ class BatchUpdate: class LogitsProcessor(ABC): + @classmethod + def validate_params(cls, sampling_params: SamplingParams): + """Validate sampling params for this logits processor. + + Raise ValueError for invalid ones. + """ + return None + @abstractmethod def __init__( self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool @@ -96,11 +104,3 @@ def update_state( to the batch makeup. """ raise NotImplementedError - - @classmethod - def validate_params(cls, sampling_params: SamplingParams): - """Validate sampling params for this logits processor. - - Raise ValueError for invalid ones. 
- """ - return None From b583b018bc939f3bd34004aaf8143dfaa38a1b8f Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 20:43:10 +0800 Subject: [PATCH 15/18] use finally for cuda initialization guard Signed-off-by: Isotr0py --- vllm/utils/torch_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index eb703d4d4ae4..fd5c1b73f191 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -89,10 +89,11 @@ def guard_cuda_initialization(): else: err_msg = str(e) raise RuntimeError(err_msg) from e - if had_key: - os.environ["CUDA_VISIBLE_DEVICES"] = old_value - else: - os.environ.pop("CUDA_VISIBLE_DEVICES") + finally: + if had_key: + os.environ["CUDA_VISIBLE_DEVICES"] = old_value + else: + os.environ.pop("CUDA_VISIBLE_DEVICES") def get_dtype_size(dtype: torch.dtype) -> int: From 4da4f6a8515d472e879982d4053f57e3bd5995e9 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 3 Nov 2025 20:54:02 +0800 Subject: [PATCH 16/18] clean ngram logitsproc __init__ Signed-off-by: Isotr0py --- vllm/model_executor/models/deepseek_ocr.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index 7bea143408df..bfde8328da6e 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -160,11 +160,6 @@ def validate_params(cls, params: SamplingParams): f"got {whitelist_token_ids}." ) - def __init__( - self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool - ): - super().__init__(vllm_config, device, is_pin_memory) - def is_argmax_invariant(self) -> bool: return True From c4c316775c862111fae5f9f7f2609eafe0a0129a Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 4 Nov 2025 12:04:21 +0800 Subject: [PATCH 17/18] fix failed test Signed-off-by: Isotr0py --- tests/entrypoints/openai/test_lora_resolvers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index a85418d5b5f4..b05fa379c69f 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -40,6 +40,7 @@ class MockModelConfig: tokenizer_revision: str | None = None multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig) + logits_processors: list[str] | None = None logits_processor_pattern: str | None = None diff_sampling_param: dict | None = None allowed_local_media_path: str = "" From dd91b8303e3dd5b32b2e497928973647da4bea16 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 5 Nov 2025 20:12:17 +0800 Subject: [PATCH 18/18] fix Signed-off-by: Isotr0py --- tests/entrypoints/openai/test_serving_chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 1b83ed7e31e7..dd10384a7e8c 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -353,6 +353,7 @@ class MockModelConfig: tokenizer_revision = None multimodal_config = MultiModalConfig() hf_config = MockHFConfig() + logits_processors: list[str] | None = None logits_processor_pattern = None diff_sampling_param: dict | None = None allowed_local_media_path: str = ""
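
To show how the pieces fit together at the API layer once the series is applied, here is a sketch of a request that exercises the new validation path. The base URL and model id are placeholders, the whitelist token ids are example values only, and the image content normally sent to the OCR model is omitted; `ngram_size`, `window_size` and `whitelist_token_ids` are the extra args read by `NGramPerReqLogitsProcessor` in `deepseek_ocr.py`, and the list value relies on the widened `vllm_xargs` typing from the first patches.

```python
# Illustrative sketch only -- not part of this patch series.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder server

resp = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-OCR",  # placeholder model id
    messages=[{"role": "user", "content": "Free OCR."}],  # image content omitted
    extra_body={
        "vllm_xargs": {
            # Extra args consumed by NGramPerReqLogitsProcessor (see deepseek_ocr.py above).
            "ngram_size": 30,
            "window_size": 90,
            # Example ids only; passing a list here relies on the widened
            # vllm_xargs type (str | int | float | list[...]).
            "whitelist_token_ids": [128821, 128822],
        }
    },
)
print(resp.choices[0].message.content)

# NGramPerReqLogitsProcessor.validate_params() raises ValueError for malformed
# values such as "ngram_size": "30" (a string) or a non-positive integer, so
# such a request can be refused at the entrypoint instead of failing in the engine.
```
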