Commit dfbc1f8

[Speculative Decoding] Add speculators config support (#21345)

1 parent 87c94bc commit dfbc1f8

File tree: 9 files changed, +232 -11 lines changed
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
+     ("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
+def test_llama(vllm_runner, example_prompts, model_path):
+    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                  max_tokens=20)
+        print(vllm_outputs)
+        assert vllm_outputs
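The test above exercises the end-to-end path this commit adds. A minimal offline-inference sketch of the same flow (not part of the diff; assumes the test checkpoint above is reachable):

from vllm import LLM, SamplingParams

# vLLM detects the speculators-format config, swaps in the verifier
# (target) model, and builds the speculative config automatically.
llm = LLM(model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717")
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)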

vllm/config.py

Lines changed: 16 additions & 4 deletions
@@ -39,8 +39,8 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    try_get_generation_config, try_get_safetensors_metadata,
-    try_get_tokenizer_config, uses_mrope)
+    maybe_override_with_speculators_target_model, try_get_generation_config,
+    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -535,6 +535,15 @@ def __post_init__(self) -> None:
                 "affect the random state of the Python process that "
                 "launched vLLM.", self.seed)
 
+        if self.runner != "draft":
+            # If we're not running the draft model, check for a speculators
+            # config. If one is present, set model / tokenizer to the
+            # target model.
+            self.model, self.tokenizer = maybe_override_with_speculators_target_model(  # noqa: E501
+                model=self.model,
+                tokenizer=self.tokenizer,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code)
+
         # Keep set served_model_name before maybe_model_redirect(self.model)
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
@@ -606,8 +615,8 @@ def __post_init__(self) -> None:
             self.config_format,
             hf_overrides_kw=hf_overrides_kw,
             hf_overrides_fn=hf_overrides_fn)
-        self.hf_config = hf_config
 
+        self.hf_config = hf_config
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(self.hf_text_config,
                                             "attention_chunk_size", None)
@@ -2980,10 +2989,13 @@ def __post_init__(self):
                     "Chunked prefill and EAGLE are not compatible "
                     "when using V0.")
 
+            from vllm.transformers_utils.configs import (
+                SpeculatorsConfig)
             from vllm.transformers_utils.configs.eagle import (
                 EAGLEConfig)
+
             if isinstance(self.draft_model_config.hf_config,
-                          EAGLEConfig):
+                          (EAGLEConfig, SpeculatorsConfig)):
                 pass
             else:
                 eagle_config = EAGLEConfig(
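The override added to __post_init__ can also be exercised directly; a sketch using the helper defined later in this diff (checkpoint name reused from the test above, network access assumed):

from vllm.transformers_utils.config import (
    maybe_override_with_speculators_target_model)

repo = "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"
# If the repo's config.json carries a speculators_config, both return
# values are replaced with the verifier's name_or_path; otherwise the
# inputs pass through unchanged.
model, tokenizer = maybe_override_with_speculators_target_model(
    model=repo, tokenizer=repo, trust_remote_code=False)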

vllm/engine/arg_utils.py

Lines changed: 21 additions & 1 deletion
@@ -978,8 +978,28 @@ def create_speculative_config(
         provided as a JSON string input via CLI arguments or directly as a
         dictionary from the engine.
         """
+
+        from vllm.transformers_utils.config import get_config
+        from vllm.transformers_utils.configs.speculators.base import (
+            SpeculatorsConfig)
+
         if self.speculative_config is None:
-            return None
+            hf_config = get_config(self.hf_config_path or self.model,
+                                   self.trust_remote_code, self.revision,
+                                   self.code_revision, self.config_format)
+
+            # If loading a SpeculatorsConfig, read the speculative_config
+            # details directly from the config; no user input is
+            # required or expected.
+            if isinstance(hf_config, SpeculatorsConfig):
+                # No speculative config was provided by the user,
+                # so build one here.
+                self.speculative_config = {}
+                self.speculative_config[
+                    "num_speculative_tokens"] = hf_config.num_lookahead_tokens
+                self.speculative_config["model"] = self.model
+                self.speculative_config["method"] = hf_config.method
+            else:
+                return None
 
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
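Concretely, for a speculators-format checkpoint the branch above synthesizes a dict equivalent to what a user would otherwise pass via --speculative-config; a sketch with illustrative values:

# Assuming hf_config parsed as a SpeculatorsConfig with method="eagle3"
# and num_lookahead_tokens=3, the resulting dict is:
speculative_config = {
    "num_speculative_tokens": 3,
    "model": "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    "method": "eagle3",
}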

vllm/model_executor/models/llama_eagle3.py

Lines changed: 23 additions & 3 deletions
@@ -51,6 +51,25 @@ def __init__(
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        if getattr(config, "norm_before_residual", False):
+            self._residual_norm = self._norm_before_residual
+        else:
+            self._residual_norm = self._norm_after_residual
+
+    def _norm_before_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.hidden_norm(hidden_states)
+        residual = hidden_states
+        return hidden_states, residual
+
+    def _norm_after_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+        return hidden_states, residual
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -59,9 +78,10 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        residual = hidden_states
         embeds = self.input_layernorm(embeds)
-        hidden_states = self.hidden_norm(hidden_states)
+
+        hidden_states, residual = self._residual_norm(
+            hidden_states=hidden_states)
 
         hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
@@ -102,7 +122,7 @@ def __init__(
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
-                self.config,
+                config=self.config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
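The only difference between the two helpers is whether the residual branch carries the normalized or the raw hidden states. A standalone sketch of the two orderings (plain-torch stand-in for illustration, not vLLM's RMSNorm):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Minimal RMSNorm stand-in.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

h = torch.randn(2, 4)

# norm_before_residual=True: residual carries the *normed* states.
h1 = rms_norm(h)
residual1 = h1

# norm_before_residual=False (default): residual carries the raw states.
residual2 = h
h2 = rms_norm(h)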

vllm/transformers_utils/config.py

Lines changed: 29 additions & 3 deletions
@@ -35,8 +35,9 @@
                                             MllamaConfig, MLPSpeculatorConfig,
                                             Nemotron_Nano_VL_Config,
                                             NemotronConfig, NVLM_D_Config,
-                                            RWConfig, Step3TextConfig,
-                                            Step3VLConfig, UltravoxConfig)
+                                            RWConfig, SpeculatorsConfig,
+                                            Step3TextConfig, Step3VLConfig,
+                                            UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +82,7 @@ def _get_hf_token() -> Optional[str]:
     "mlp_speculator": MLPSpeculatorConfig,
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
+    "speculators": SpeculatorsConfig,
     "nemotron": NemotronConfig,
     "NVLM_D": NVLM_D_Config,
     "ultravox": UltravoxConfig,
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     return config
 
 
+def maybe_override_with_speculators_target_model(
+        model: str,
+        tokenizer: str,
+        trust_remote_code: bool,
+        revision: Optional[str] = None) -> tuple[str, str]:
+    """
+    If running a speculators config, override running model with target model
+    """
+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model,
+        revision=revision,
+        trust_remote_code=trust_remote_code,
+        token=_get_hf_token(),
+    )
+    spec_config = config_dict.get("speculators_config")
+    # Return the target model
+    if spec_config is not None:
+        model = tokenizer = spec_config["verifier"]["name_or_path"]
+    return model, tokenizer
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -345,9 +368,12 @@ def get_config(
             token=_get_hf_token(),
             **kwargs,
         )
-
     # Use custom model class if it's in our registry
     model_type = config_dict.get("model_type")
+    if model_type is None:
+        model_type = "speculators" if config_dict.get(
+            "speculators_config") is not None else model_type
+
     if model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[model_type]
         config = config_class.from_pretrained(
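To make the lookup concrete, here is a hypothetical config.json fragment (all field values invented) of the shape maybe_override_with_speculators_target_model and the model_type fallback above inspect:

# Hypothetical speculators-format config.json, shown as a Python dict.
config_dict = {
    # If no top-level "model_type" is present, get_config falls back to
    # "speculators" whenever a speculators_config key exists.
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 3}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
}
# With such a config, maybe_override_with_speculators_target_model returns
# ("meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct").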

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
 from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
                                                       Step3VisionEncoderConfig,
                                                       Step3VLConfig)
@@ -44,6 +45,7 @@
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
     "NVLM_D_Config",
+    "SpeculatorsConfig",
     "UltravoxConfig",
     "Step3VLConfig",
     "Step3VisionEncoderConfig",
vllm/transformers_utils/configs/speculators/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+SUPPORTED_SPECULATORS_TYPES = {}
+
+
+def register_speculator(name):
+
+    def decorator(fn):
+        SUPPORTED_SPECULATORS_TYPES[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_speculator("eagle3")
+def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
+    """
+    Apply Eagle-3 specific configuration transformations.
+
+    Eagle-3 specific fields:
+    - draft_vocab_size: Size of the draft model's vocabulary
+    - target_hidden_size: Hidden size of the target model
+    - norm_before_residual: Whether to apply norm before residual connection
+    """
+
+    vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
+    if config_dict.get("target_hidden_size") is not None:
+        vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
+    vllm_config["norm_before_residual"] = config_dict.get(
+        "norm_before_residual", True)
+    vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
vllm/transformers_utils/configs/speculators/base.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from typing import Any, Union
+
+from transformers import PretrainedConfig
+
+from vllm.transformers_utils.configs.speculators.algos import (
+    SUPPORTED_SPECULATORS_TYPES)
+
+__all__ = ["SpeculatorsConfig"]
+
+
+class SpeculatorsConfig(PretrainedConfig):
+    model_type = "speculators"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        **kwargs,
+    ) -> "SpeculatorsConfig":
+        """Load speculators Eagle config and convert to vLLM format."""
+        config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
+                                             **kwargs)
+
+        speculators_model_type = config_dict.get("speculators_model_type")
+        if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
+            raise ValueError(
+                f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
+                "Please ensure you're loading a speculators-format model.")
+
+        # validate fields
+        # TODO: @dsikka - use speculators pydantic model to validate
+        cls.validate_speculators_config(config_dict=config_dict)
+        # Convert from speculators config -> format that can be ingested by vLLM
+        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
+        # Apply anything specific to the supported algorithm
+        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
+        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
+        return cls(**vllm_config)
+
+    @classmethod
+    def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
+        try:
+            spec_config = config_dict["speculators_config"]
+            methods = spec_config["proposal_methods"]
+            first_method = methods[0]
+            _ = first_method["speculative_tokens"]
+            _ = spec_config["verifier"]["name_or_path"]
+            _ = config_dict["speculators_model_type"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError("Invalid speculators config structure") from e
+
+        if "transformer_layer_config" not in config_dict:
+            raise ValueError("Must provide transformer_layer_config")
+
+        if not isinstance(config_dict["transformer_layer_config"], dict):
+            raise TypeError(
+                "'transformer_layer_config' must be a dictionary if provided")
+
+    @classmethod
+    def convert_speculators_to_vllm(
+            cls, config_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        Convert speculators config format to vLLM format.
+
+        This method handles the translation of field names and structure
+        between speculators and vLLM formats.
+
+        Returns:
+            Dictionary with vLLM-compatible configuration
+        """
+        # Currently we only support one proposal method
+        spec_config = config_dict["speculators_config"]
+        first_method = spec_config.get("proposal_methods")[0]
+        num_lookahead_tokens = first_method.get("speculative_tokens")
+
+        if num_lookahead_tokens is None:
+            raise ValueError(
+                "Missing 'speculative_tokens' in proposal method. "
+                f"Got: {first_method}")
+
+        # Build base vLLM config
+        vllm_config = {
+            "method": config_dict.get("speculators_model_type"),
+            "num_lookahead_tokens": num_lookahead_tokens,
+            "target_model": spec_config.get("verifier")["name_or_path"]
+        }
+        vllm_config.update(config_dict["transformer_layer_config"])
+        return vllm_config
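A worked conversion under these rules (all field values invented, reusing the hypothetical config from earlier):

config_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 3}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
}
vllm_config = SpeculatorsConfig.convert_speculators_to_vllm(config_dict)
# vllm_config == {
#     "method": "eagle3",
#     "num_lookahead_tokens": 3,
#     "target_model": "meta-llama/Llama-3.1-8B-Instruct",
#     "hidden_size": 4096,
# }

After this, update_eagle3 from algos.py fills in draft_vocab_size, norm_before_residual, and the Eagle3LlamaForCausalLM architecture before the dict is passed to the SpeculatorsConfig constructor.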
