16 changes: 16 additions & 0 deletions tests/speculative_decoding/speculators/test_eagle3.py
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch


@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path):
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
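For context, a minimal sketch of what this test exercises end to end, assuming the nm-testing checkpoint above is reachable on the Hugging Face Hub and a bfloat16-capable GPU is available:

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
          dtype=torch.bfloat16)
# temperature=0.0 gives greedy decoding, mirroring generate_greedy above
params = SamplingParams(temperature=0.0, max_tokens=20)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)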
20 changes: 16 additions & 4 deletions vllm/config.py
@@ -38,8 +38,8 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
maybe_override_with_speculators_target_model, try_get_generation_config,
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
# yapf conflicts with isort for this block
@@ -534,6 +534,15 @@ def __post_init__(self) -> None:
"affect the random state of the Python process that "
"launched vLLM.", self.seed)

if self.runner != "draft":
# If we're not running the draft model, check for a speculators config.
# If one is found, point model / tokenizer at the target (verifier) model.
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code)

# Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
@@ -605,8 +614,8 @@ def __post_init__(self) -> None:
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn)
self.hf_config = hf_config

self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(self.hf_text_config,
"attention_chunk_size", None)
@@ -2973,10 +2982,13 @@ def __post_init__(self):
"Chunked prefill and EAGLE are not compatible "
"when using V0.")

from vllm.transformers_utils.configs import (
SpeculatorsConfig)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)

if isinstance(self.draft_model_config.hf_config,
EAGLEConfig):
(EAGLEConfig, SpeculatorsConfig)):
pass
else:
eagle_config = EAGLEConfig(
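A sketch of the resolution step wired into __post_init__ above; the speculator checkpoint name is hypothetical:

from vllm.transformers_utils.config import (
    maybe_override_with_speculators_target_model)

# For a speculators-format checkpoint, both values come back pointing at
# the verifier (target) model; any other checkpoint is returned unchanged.
model, tokenizer = maybe_override_with_speculators_target_model(
    model="org/llama-eagle3-speculator",  # hypothetical checkpoint name
    tokenizer="org/llama-eagle3-speculator",
    trust_remote_code=False)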
22 changes: 21 additions & 1 deletion vllm/engine/arg_utils.py
@@ -978,8 +978,28 @@ def create_speculative_config(
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
"""

from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)

if self.speculative_config is None:
return None
hf_config = get_config(self.hf_config_path or self.model,
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)

# If loading a SpeculatorsConfig, read the speculative_config
# details directly from the config; no user input is required
# or expected.
if isinstance(hf_config, SpeculatorsConfig):
# No speculative_config was provided by the user, so create one
self.speculative_config = {}
self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
self.speculative_config["model"] = self.model
self.speculative_config["method"] = hf_config.method
else:
return None

# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
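The user-visible effect, sketched below under the assumption that the checkpoint from the test above is used: passing a speculators-format model is enough, and no --speculative-config JSON is needed.

from vllm import LLM

# create_speculative_config() detects the SpeculatorsConfig and fills in
# num_speculative_tokens and method ("eagle3") from the checkpoint itself.
llm = LLM(model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717")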
26 changes: 23 additions & 3 deletions vllm/model_executor/models/llama_eagle3.py
@@ -51,6 +51,25 @@ def __init__(

self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

if getattr(config, "norm_before_residual", False):
self._residual_norm = self._norm_before_residual
else:
self._residual_norm = self._norm_after_residual

def _norm_before_residual(
self,
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.hidden_norm(hidden_states)
residual = hidden_states
return hidden_states, residual

def _norm_after_residual(
self,
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.hidden_norm(hidden_states)
return hidden_states, residual

def forward(
self,
positions: torch.Tensor,
@@ -59,9 +78,10 @@ def forward(
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:

residual = hidden_states
embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)

hidden_states, residual = self._residual_norm(
hidden_states=hidden_states)

hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention
@@ -102,7 +122,7 @@ def __init__(

self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
config=self.config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
)
])
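The two orderings dispatched in __init__ above differ only in what the residual branch carries. A standalone sketch, using torch.nn.RMSNorm (PyTorch 2.4+) as a stand-in for vLLM's RMSNorm; shapes are illustrative:

import torch
from torch import nn

hidden_norm = nn.RMSNorm(8)  # stand-in for the layer's hidden_norm
h = torch.randn(2, 8)

# norm_before_residual=True: the residual carries the normalized states
normed = hidden_norm(h)
hidden_states, residual = normed, normed

# norm_before_residual=False (default): the residual keeps the raw states
hidden_states, residual = hidden_norm(h), h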
32 changes: 29 additions & 3 deletions vllm/transformers_utils/config.py
@@ -35,8 +35,9 @@
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
RWConfig, Step3TextConfig,
Step3VLConfig, UltravoxConfig)
RWConfig, SpeculatorsConfig,
Step3TextConfig, Step3VLConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +82,7 @@ def _get_hf_token() -> Optional[str]:
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"speculators": SpeculatorsConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config


def maybe_override_with_speculators_target_model(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> tuple[str, str]:
"""
If the model is a speculators-format checkpoint, override the running
model / tokenizer with its target (verifier) model.
"""
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
)
spec_config = config_dict.get("speculators_config")
# If this is a speculators checkpoint, run the verifier (target) model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer


def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@@ -345,9 +368,12 @@ def get_config(
token=_get_hf_token(),
**kwargs,
)

# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type is None and config_dict.get(
"speculators_config") is not None:
model_type = "speculators"

if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(
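An illustrative config.json fragment (values are made up) showing the keys this file now inspects: the presence of "speculators_config" selects the "speculators" registry entry, and verifier.name_or_path is what maybe_override_with_speculators_target_model returns:

config_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B"},
    },
    # the draft model's own transformers config lives here
    "transformer_layer_config": {"hidden_size": 4096},
}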
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,7 @@
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
@@ -44,6 +45,7 @@
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"SpeculatorsConfig",
"UltravoxConfig",
"Step3VLConfig",
"Step3VisionEncoderConfig",
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/speculators/__init__.py
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
32 changes: 32 additions & 0 deletions vllm/transformers_utils/configs/speculators/algos.py
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

SUPPORTED_SPECULATORS_TYPES = {}


def register_speculator(name):

def decorator(fn):
SUPPORTED_SPECULATORS_TYPES[name] = fn
return fn

return decorator


@register_speculator("eagle3")
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
"""
Apply Eagle-3 specific configuration transformations.

Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
"""

vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
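Supporting a new algorithm only takes another decorated updater. The sketch below is hypothetical and not part of this PR; the name "eagle" and the architecture string are illustrative:

@register_speculator("eagle")  # hypothetical additional algorithm
def update_eagle(config_dict: dict, vllm_config: dict) -> None:
    # Map any algorithm-specific fields into the vLLM config here.
    vllm_config["architectures"] = ["EagleLlamaForCausalLM"]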
91 changes: 91 additions & 0 deletions vllm/transformers_utils/configs/speculators/base.py
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Union

from transformers import PretrainedConfig

from vllm.transformers_utils.configs.speculators.algos import (
SUPPORTED_SPECULATORS_TYPES)

__all__ = ["SpeculatorsConfig"]


class SpeculatorsConfig(PretrainedConfig):
model_type = "speculators"

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "SpeculatorsConfig":
"""Load speculators Eagle config and convert to vLLM format."""
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)

speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
"Please ensure you're loading a speculators-format model.")

# validate fields
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return cls(**vllm_config)

@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
try:
spec_config = config_dict["speculators_config"]
methods = spec_config["proposal_methods"]
first_method = methods[0]
_ = first_method["speculative_tokens"]
_ = spec_config["verifier"]["name_or_path"]
_ = config_dict["speculators_model_type"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError("Invalid speculators config structure") from e

if "transformer_layer_config" not in config_dict:
raise ValueError("Must provide transformer_layer_config")

if not isinstance(config_dict["transformer_layer_config"], dict):
raise TypeError(
"'transformer_layer_config' must be a dictionary if provided")

@classmethod
def convert_speculators_to_vllm(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
"""
Convert speculators config format to vLLM format.

This method handles the translation of field names and structure
between speculators and vLLM formats.

Returns:
Dictionary with vLLM-compatible configuration
"""
# Currently we only support one proposal method
spec_config = config_dict["speculators_config"]
first_method = spec_config.get("proposal_methods")[0]
num_lookahead_tokens = first_method.get("speculative_tokens")

if num_lookahead_tokens is None:
raise ValueError(
"Missing 'speculative_tokens' in proposal method. "
f"Got: {first_method}")

# Build base vLLM config
vllm_config = {
"method": config_dict.get("speculators_model_type"),
"num_lookahead_tokens": num_lookahead_tokens,
"target_model": spec_config.get("verifier")["name_or_path"]
}
vllm_config.update(config_dict["transformer_layer_config"])
return vllm_config
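A worked example of the conversion, with illustrative values; the input has the same shape as the fragment shown earlier, and the expected output is in the trailing comment:

speculators_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
}
vllm_config = SpeculatorsConfig.convert_speculators_to_vllm(speculators_dict)
# -> {"method": "eagle3", "num_lookahead_tokens": 5,
#     "target_model": "meta-llama/Llama-3.1-8B", "hidden_size": 4096}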