|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import os |
| 4 | +from typing import Any, Union |
| 5 | + |
| 6 | +from transformers import PretrainedConfig |
| 7 | + |
| 8 | +from vllm.transformers_utils.configs.speculators.algos import ( |
| 9 | + SUPPORTED_SPECULATORS_TYPES) |
| 10 | + |
| 11 | +__all__ = ["SpeculatorsConfig"] |
| 12 | + |
| 13 | + |
| 14 | +class SpeculatorsConfig(PretrainedConfig): |
| 15 | + model_type = "speculators" |
| 16 | + |
| 17 | + @classmethod |
| 18 | + def from_pretrained( |
| 19 | + cls, |
| 20 | + pretrained_model_name_or_path: Union[str, os.PathLike], |
| 21 | + **kwargs, |
| 22 | + ) -> "SpeculatorsConfig": |
| 23 | + """Load speculators Eagle config and convert to vLLM format.""" |
| 24 | + config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, |
| 25 | + **kwargs) |
| 26 | + |
| 27 | + speculators_model_type = config_dict.get("speculators_model_type") |
| 28 | + if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: |
| 29 | + raise ValueError( |
| 30 | + f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. " |
| 31 | + "Please ensure you're loading a speculators-format model.") |
| 32 | + |
| 33 | + # validate fields |
| 34 | + # TODO: @dsikka - use speculators pydantic model to validate |
| 35 | + cls.validate_speculators_config(config_dict=config_dict) |
| 36 | + # Convert from speculators config -> format that can be ingested by vLLM |
| 37 | + vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict) |
| 38 | + # Apply anything specific to the supported algorithm |
| 39 | + algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] |
| 40 | + algo_updater(config_dict=config_dict, vllm_config=vllm_config) |
| 41 | + return cls(**vllm_config) |
| 42 | + |
| 43 | + @classmethod |
| 44 | + def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: |
| 45 | + try: |
| 46 | + spec_config = config_dict["speculators_config"] |
| 47 | + methods = spec_config["proposal_methods"] |
| 48 | + first_method = methods[0] |
| 49 | + _ = first_method["speculative_tokens"] |
| 50 | + _ = spec_config["verifier"]["name_or_path"] |
| 51 | + _ = config_dict["speculators_model_type"] |
| 52 | + except (KeyError, IndexError, TypeError) as e: |
| 53 | + raise ValueError("Invalid speculators config structure") from e |
| 54 | + |
| 55 | + if "transformer_layer_config" not in config_dict: |
| 56 | + raise ValueError("Must provide transformer_layer_config") |
| 57 | + |
| 58 | + if not isinstance(config_dict["transformer_layer_config"], dict): |
| 59 | + raise TypeError( |
| 60 | + "'transformer_layer_config' must be a dictionary if provided") |
| 61 | + |
| 62 | + @classmethod |
| 63 | + def convert_speculators_to_vllm( |
| 64 | + cls, config_dict: dict[str, Any]) -> dict[str, Any]: |
| 65 | + """ |
| 66 | + Convert speculators config format to vLLM format. |
| 67 | + |
| 68 | + This method handles the translation of field names and structure |
| 69 | + between speculators and vLLM formats. |
| 70 | + |
| 71 | + Returns: |
| 72 | + Dictionary with vLLM-compatible configuration |
| 73 | + """ |
| 74 | + # Currently we only support one proposal method |
| 75 | + spec_config = config_dict["speculators_config"] |
| 76 | + first_method = spec_config.get("proposal_methods")[0] |
| 77 | + num_lookahead_tokens = first_method.get("speculative_tokens") |
| 78 | + |
| 79 | + if num_lookahead_tokens is None: |
| 80 | + raise ValueError( |
| 81 | + "Missing 'speculative_tokens' in proposal method. " |
| 82 | + f"Got: {first_method}") |
| 83 | + |
| 84 | + # Build base vLLM config |
| 85 | + vllm_config = { |
| 86 | + "method": config_dict.get("speculators_model_type"), |
| 87 | + "num_lookahead_tokens": num_lookahead_tokens, |
| 88 | + "target_model": spec_config.get("verifier")["name_or_path"] |
| 89 | + } |
| 90 | + vllm_config.update(config_dict["transformer_layer_config"]) |
| 91 | + return vllm_config |
0 commit comments