diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md
index da61d2a85e46..acf7fc245462 100644
--- a/docs/design/logits_processors.md
+++ b/docs/design/logits_processors.md
@@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
         changes to the batch makeup.
         """
         raise NotImplementedError
-
+
+    @classmethod
+    def validate_params(cls, sampling_params: SamplingParams):
+        """Validate sampling params for this logits processor.
+
+        Raise ValueError for invalid ones.
+        """
+        return None
+
 ```

A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
@@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
     * Use the `BatchUpdate` members to update logits processor internal state
     * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.
 
+* `validate_params(cls, sampling_params: SamplingParams)`:
+    * Raise `ValueError` if the `SamplingParams` instance contains invalid arguments (especially custom arguments) used by the logits processor.
+    * When a request is sent to an entrypoint, `validate_params()` validates its `SamplingParams` and rejects the request if any arguments are invalid (see the sketch below).
+
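+For example, a minimal implementation might look like the following sketch (illustrative only; it checks a hypothetical `my_flag` custom argument):
+
+```python
+    @classmethod
+    def validate_params(cls, sampling_params: SamplingParams):
+        # "my_flag" is a made-up custom argument used for illustration.
+        my_flag = sampling_params.extra_args and sampling_params.extra_args.get("my_flag")
+        if my_flag is not None and not isinstance(my_flag, bool):
+            raise ValueError(f"my_flag value {my_flag!r} is not a bool")
+```
+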
 ### `BatchUpdate` data structure
 
 The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):
diff --git a/docs/features/custom_arguments.md b/docs/features/custom_arguments.md
index 74ed40835b4d..7a650d0e79c2 100644
--- a/docs/features/custom_arguments.md
+++ b/docs/features/custom_arguments.md
@@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
 
 Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
 
+!!! note
+    Make sure your custom logits processor implements `validate_params` for its custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour. For example, a request like the sketch below should be rejected at the entrypoint:
+
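+    ```python
+    # Illustrative sketch: assumes the server was started with a custom logits
+    # processor whose validate_params() requires the hypothetical custom
+    # argument "my_arg" to be an int.
+    client.completions.create(
+        model=model,
+        prompt="Hello, my name is",
+        extra_body={"vllm_xargs": {"my_arg": "not-an-int"}},
+    )  # rejected with an API error raised from validate_params()
+    ```
+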
code "Example of Wrapping a Request-Level Logits Processor" @@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro """Example of wrapping a fake request-level logit processor to create a batch-level logits processor""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} is not int" + ) + def is_argmax_invariant(self) -> bool: return False @@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro Returns: `Callable` request logits processor, or None """ - target_token: Optional[Any] = params.extra_args and params.extra_args.get( + target_token: Any | None = params.extra_args and params.extra_args.get( "target_token" ) if target_token is None: return None - if not isinstance(target_token, int): - logger.warning( - "target_token value %s is not int; not applying logits" - " processor to request.", - target_token, - ) - return None return DummyPerReqLogitsProcessor(target_token) ``` diff --git a/examples/offline_inference/logits_processor/custom.py b/examples/offline_inference/logits_processor/custom.py index 72e7ce24d7cc..ce000872dc96 100644 --- a/examples/offline_inference/logits_processor/custom.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -33,6 +33,8 @@ class object. ------------------------------------------------------------ """ +from typing import Any + import torch from vllm import LLM, SamplingParams @@ -48,6 +50,16 @@ class object. class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" + @classmethod + def validate_params(cls, params: SamplingParams): + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is not None and not isinstance(target_token, int): + raise ValueError( + f"target_token value {target_token} {type(target_token)} is not int" + ) + def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): @@ -57,14 +69,17 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: BatchUpdate | None): + def extract_extra_arg(params: SamplingParams) -> int | None: + self.validate_params(params) + return params.extra_args and params.extra_args.get("target_token") + process_dict_updates( self.req_info, batch_update, # This function returns the LP's per-request state based on the # request details, or None if this LP does not apply to the # request. 
+            self.validate_params(params)
+            return params.extra_args and params.extra_args.get("target_token")
+
         process_dict_updates(
             self.req_info,
             batch_update,
             # This function returns the LP's per-request state based on the
             # request details, or None if this LP does not apply to the
             # request.
-            lambda params, _, __: params.extra_args
-            and (params.extra_args.get("target_token")),
+            lambda params, _, __: extract_extra_arg(params),
         )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py
index 87cd7473fa9f..5763fff5410d 100644
--- a/examples/offline_inference/logits_processor/custom_req.py
+++ b/examples/offline_inference/logits_processor/custom_req.py
@@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
     """Example of wrapping a fake request-level logit processor to create a
     batch-level logits processor"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token: Any | None = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(f"target_token value {target_token} is not int")
+
     def is_argmax_invariant(self) -> bool:
         return False
 
@@ -101,13 +109,6 @@ def new_req_logits_processor(
         )
         if target_token is None:
             return None
-        if not isinstance(target_token, int):
-            logger.warning(
-                "target_token value %s is not int; not applying logits"
-                " processor to request.",
-                target_token,
-            )
-            return None
         return DummyPerReqLogitsProcessor(target_token)
diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py
index 3bb82a786040..acd2c47f230f 100644
--- a/examples/offline_inference/logits_processor/custom_req_init.py
+++ b/examples/offline_inference/logits_processor/custom_req_init.py
@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
     """Example of overriding the wrapper class `__init__()` in order to utilize
     info about the device type"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token = params.extra_args and params.extra_args.get("target_token")
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(
+                f"`target_token` has to be an integer, got {target_token}."
+            )
+
     def __init__(
         self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
     ):
@@ -113,13 +121,6 @@ def new_req_logits_processor(
             is None
         ):
             return None
-        if not isinstance(target_token, int):
-            logger.warning(
-                "target_token value %s is not int; not applying logits"
-                " processor to request.",
-                target_token,
-            )
-            return None
         return DummyPerReqLogitsProcessor(target_token)
diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py
index a85418d5b5f4..b05fa379c69f 100644
--- a/tests/entrypoints/openai/test_lora_resolvers.py
+++ b/tests/entrypoints/openai/test_lora_resolvers.py
@@ -40,6 +40,7 @@ class MockModelConfig:
     tokenizer_revision: str | None = None
     multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
     hf_config: MockHFConfig = field(default_factory=MockHFConfig)
+    logits_processors: list[str] | None = None
     logits_processor_pattern: str | None = None
     diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 1b83ed7e31e7..dd10384a7e8c 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -353,6 +353,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processors: list[str] | None = None
     logits_processor_pattern = None
     diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py
index 0d902b46bed5..3e0bb02ed68b 100644
--- a/tests/v1/logits_processors/test_custom_online.py
+++ b/tests/v1/logits_processors/test_custom_online.py
@@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
 
         # Alternate whether to activate dummy logitproc for each request
         use_dummy_logitproc = not use_dummy_logitproc
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_invalid_custom_logitsproc_arg(
+    client: openai.AsyncOpenAI, model_name: str
+):
+    """Test that a request with an invalid custom logitsproc arg is rejected"""
+
+    prompt = "Hello, my name is"
+    # Pass invalid (non-int) target_token value to dummy logits processor
+    request_keyword_args: dict[str, Any] = {
+        **api_keyword_args,
+        "extra_body": {
+            "vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
+        },
+    }
+
+    with pytest.raises(openai.OpenAIError) as exc_info:
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            **request_keyword_args,
+        )
+
+    assert "is not int" in str(exc_info.value)
diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py
index 36cffebb3b45..b8548bc31955 100644
--- a/tests/v1/logits_processors/utils.py
+++ b/tests/v1/logits_processors/utils.py
@@ -52,6 +52,16 @@ class CustomLogitprocSource(Enum):
 class DummyLogitsProcessor(LogitsProcessor):
     """Fake logit processor to support unit testing and examples"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token: int | None = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(
+                f"target_token value {target_token} {type(target_token)} is not int"
+            )
+
     def __init__(
         self, vllm_config: "VllmConfig", device: torch.device,
         is_pin_memory: bool
     ):
@@ -62,11 +72,14 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: BatchUpdate | None):
+        def extract_extra_arg(params: SamplingParams) -> int | None:
+            self.validate_params(params)
+            return params.extra_args and params.extra_args.get("target_token")
+
         process_dict_updates(
             self.req_info,
             batch_update,
-            lambda params, _, __: params.extra_args
-            and (params.extra_args.get("target_token")),
+            lambda params, _, __: extract_extra_arg(params),
         )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 33256de6dd47..cf80c4fccbad 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -772,10 +772,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description="KVTransfer parameters used for disaggregated serving.",
     )
 
-    vllm_xargs: dict[str, str | int | float] | None = Field(
+    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
        default=None,
         description=(
-            "Additional request parameters with string or "
+            "Additional request parameters with (list of) string or "
             "numeric values, used by custom extensions."
         ),
     )
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 25979d5502b0..e63909e82949 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -71,6 +71,7 @@
     validate_request_params,
 )
 from vllm.utils.collection_utils import as_list
+from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
 
 logger = init_logger(__name__)
 
@@ -110,6 +111,9 @@ def __init__(
         self.trust_request_chat_template = trust_request_chat_template
         self.enable_log_outputs = enable_log_outputs
 
+        # set up logits processors
+        self.logits_processors = self.model_config.logits_processors
+
         # set up reasoning parser
         self.reasoning_parser = self._get_reasoning_parser(
             reasoning_parser_name=reasoning_parser
@@ -294,6 +298,10 @@ async def create_chat_completion(
                     self.model_config.logits_processor_pattern,
                     self.default_sampling_params,
                 )
+                validate_logits_processors_parameters(
+                    self.logits_processors,
+                    sampling_params,
+                )
 
             self._log_inputs(
                 request_id,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 14dbdd4cb4c7..a114b77ebc16 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -36,6 +36,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.collection_utils import as_list
+from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
 
 logger = init_logger(__name__)
 
@@ -59,6 +60,10 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids,
             log_error_stack=log_error_stack,
         )
+
+        # set up logits processors
+        self.logits_processors = self.model_config.logits_processors
+
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
         self.enable_force_include_usage = enable_force_include_usage
@@ -181,6 +186,10 @@ async def create_completion(
                     self.model_config.logits_processor_pattern,
                     self.default_sampling_params,
                 )
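+                # Reject the request early if any configured custom logits
+                # processor deems its sampling params invalid.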
+                validate_logits_processors_parameters(
+                    self.logits_processors,
+                    sampling_params,
+                )
 
                 request_id_item = f"{request_id}-{i}"
diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py
index fa24db456af4..bfde8328da6e 100644
--- a/vllm/model_executor/models/deepseek_ocr.py
+++ b/vllm/model_executor/models/deepseek_ocr.py
@@ -131,25 +131,18 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
     """Example of overriding the wrapper class `__init__()` in order to utilize
     info about the device type"""
 
-    def __init__(
-        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
-    ):
-        super().__init__(vllm_config, device, is_pin_memory)
-
-    def is_argmax_invariant(self) -> bool:
-        return True
-
-    def new_req_logits_processor(
-        self,
-        params: SamplingParams,
-    ) -> RequestLogitsProcessor | None:
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
         ngram_size = params.extra_args and params.extra_args.get("ngram_size")
         window_size = params.extra_args and params.extra_args.get("window_size", 100)
         whitelist_token_ids = params.extra_args and params.extra_args.get(
             "whitelist_token_ids", None
         )
+        # if ngram_size is not provided, skip validation because the processor
+        # will not be used.
         if ngram_size is None:
             return None
+
         if not isinstance(ngram_size, int) or ngram_size <= 0:
             raise ValueError(
                 f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
@@ -163,13 +156,26 @@ def new_req_logits_processor(
             whitelist_token_ids, Iterable
         ):
             raise ValueError(
-                "`whitelist_token_ids` has to be a set of integers, "
+                "`whitelist_token_ids` has to be a sequence of integers, "
                 f"got {whitelist_token_ids}."
             )
-        else:
-            whitelist_token_ids = (
-                set(whitelist_token_ids) if whitelist_token_ids else None
-            )
+
+    def is_argmax_invariant(self) -> bool:
+        return True
+
+    def new_req_logits_processor(
+        self,
+        params: SamplingParams,
+    ) -> RequestLogitsProcessor | None:
+        ngram_size = params.extra_args and params.extra_args.get("ngram_size")
+        window_size = params.extra_args and params.extra_args.get("window_size", 100)
+        whitelist_token_ids = params.extra_args and params.extra_args.get(
+            "whitelist_token_ids", None
+        )
+        if ngram_size is None:
+            return None
+
+        whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None
         return NoRepeatNGramLogitsProcessor(
             ngram_size=ngram_size,
             window_size=window_size,
diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py
index 7abfe6229842..23b913157d6d 100644
--- a/vllm/transformers_utils/configs/deepseek_vl2.py
+++ b/vllm/transformers_utils/configs/deepseek_vl2.py
@@ -218,3 +218,9 @@ def __init__(
         self.global_view_pos = global_view_pos
         self.candidate_resolutions = candidate_resolutions
         self.vocab_size = self.text_config.vocab_size
+
+        # update model_type for OCR model
+        if "DeepseekOCRForCausalLM" in (
+            self.architectures or kwargs.get("architectures", [])
+        ):
+            self.model_type = "deepseek_ocr"
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index adcacb34cb7c..fd5c1b73f191 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import importlib.metadata
+import os
 import threading
 from collections.abc import Callable, Collection
 from functools import lru_cache
@@ -68,6 +69,33 @@ def set_default_torch_num_threads(num_threads: int):
         torch.set_num_threads(old_num_threads)
 
 
+@contextlib.contextmanager
+def guard_cuda_initialization():
+    """Avoid unexpected CUDA initialization."""
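+    # Illustrative usage: wrap an import that must not initialize CUDA, e.g.
+    #
+    #     with guard_cuda_initialization():
+    #         module = importlib.import_module(module_path)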
+    from vllm.platforms import current_platform
+
+    if not current_platform.is_cuda():
+        yield
+        return
+
+    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
+    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    try:
+        yield
+    except Exception as e:
+        if "No CUDA GPUs are available" in str(e):
+            err_msg = "CUDA initialization is blocked."
+        else:
+            err_msg = str(e)
+        raise RuntimeError(err_msg) from e
+    finally:
+        if had_key:
+            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
+        else:
+            os.environ.pop("CUDA_VISIBLE_DEVICES")
+
+
 def get_dtype_size(dtype: torch.dtype) -> int:
     """Get the size of the data type in bytes."""
     return torch.tensor([], dtype=dtype).element_size()
diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index 566de5bcda77..eb537eae6c90 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -13,6 +13,7 @@
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
 from vllm.sampling_params import SamplingParams
+from vllm.utils.torch_utils import guard_cuda_initialization
 from vllm.v1.sample.logits_processor.builtin import (
     LogitBiasLogitsProcessor,
     MinPLogitsProcessor,
@@ -72,8 +73,10 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
                 entrypoint.name,
                 entrypoint.value,
             )
-            classes.append(entrypoint.load())
+            with guard_cuda_initialization():
+                classes.append(entrypoint.load())
         except Exception as e:
+            logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
             raise RuntimeError(
                 f"Failed to load LogitsProcessor plugin {entrypoint}"
             ) from e
@@ -126,8 +129,15 @@ def _load_logitsprocs_by_fqcns(
 
         try:
             # Load module
-            module = importlib.import_module(module_path)
+            with guard_cuda_initialization():
+                module = importlib.import_module(module_path)
         except Exception as e:
+            logger.error(
+                "Failed to load %sth LogitsProcessor plugin %s: %s",
+                ldx,
+                logitproc,
+                e,
+            )
             raise RuntimeError(
                 f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
             ) from e
@@ -206,6 +216,14 @@ def build_logitsprocs(
     )
 
 
+def validate_logits_processors_parameters(
+    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
+    sampling_params: SamplingParams,
+):
+    for logits_procs in _load_custom_logitsprocs(logits_processors):
+        logits_procs.validate_params(sampling_params)
+
+
 class AdapterLogitsProcessor(LogitsProcessor):
     """Wrapper for per-request logits processors
diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py
index efa0f62ad6e1..0cbfb187878a 100644
--- a/vllm/v1/sample/logits_processor/interface.py
+++ b/vllm/v1/sample/logits_processor/interface.py
@@ -58,6 +58,14 @@ class BatchUpdate:
 
 
 class LogitsProcessor(ABC):
+    @classmethod
+    def validate_params(cls, sampling_params: SamplingParams):
+        """Validate sampling params for this logits processor.
+
+        Raise ValueError for invalid ones.
+        """
+        return None
+
     @abstractmethod
     def __init__(
         self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool