Skip to content

Commit 3da3a66

Browse files
[V1] implement tree sampler for draft token acceptance
Signed-off-by: Giancarlo Delfin <[email protected]>
1 parent 3253ae7 commit 3da3a66

File tree

9 files changed

+536
-153
lines changed

9 files changed

+536
-153
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch
5+
from torch import Generator
6+
7+
from vllm.platforms import current_platform
8+
from vllm.v1.sample.ops.topk_topp_sampler import (
9+
apply_top_k_top_p,
10+
is_flashinfer_available,
11+
)
12+
13+
DEVICE = current_platform.device_type
14+
15+
BATCH_SIZE = 1024
16+
VOCAB_SIZE = 128 * 1024
17+
18+
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
19+
if is_flashinfer_available:
20+
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def reset_default_device():
25+
"""
26+
Explicitly set the default device, which can affect subsequent tests.
27+
Adding this fixture helps avoid this problem.
28+
"""
29+
original_device = torch.get_default_device()
30+
yield
31+
torch.set_default_device(original_device)
32+
33+
34+
def test_topk_impl_equivalence():
35+
36+
torch.set_default_device(DEVICE)
37+
generator = Generator(device=DEVICE).manual_seed(33)
38+
39+
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
40+
41+
# Random top-k values between 1 and 9.
42+
k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)
43+
44+
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
45+
k.masked_fill_(
46+
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE
47+
)
48+
49+
# Top-k only implementation
50+
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
51+
52+
# Top-p + top-k
53+
no_op_top_p = torch.tensor([1.0])
54+
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
55+
56+
assert torch.allclose(result1, result2)
57+
58+
59+
def test_tree_rejection_sampler():
60+
"""
61+
This test verifies that the FlashInfer top-k and top-p sampling
62+
implementation produces the same results as the Python implementation.
63+
64+
NOTE: FlashInfer did not directly expose an interface for fused top-k and
65+
top-p prob renorm (it did provide fused sampling but we cannot compare
66+
sampling results due to randomness), so we will compare the probability
67+
renormed consequently by top-k and then top-p of FlashInfer implementation.
68+
"""
69+
70+
if not FLASHINFER_ENABLED:
71+
pytest.skip("FlashInfer not installed or not available on this platform.")
72+
73+
torch.set_default_device(DEVICE)
74+
generator = Generator(device=DEVICE).manual_seed(42)
75+
76+
# Generate random logits
77+
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
78+
79+
# Generate various top-k and top-p values
80+
k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
81+
p_values = (
82+
torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
83+
) # range in [0.5, 1.0]
84+
85+
# Sometimes disable top-k (k=vocab_size)
86+
k_values.masked_fill_(
87+
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
88+
VOCAB_SIZE,
89+
)
90+
91+
# Sometimes disable top-p (p=1.0)
92+
p_values.masked_fill_(
93+
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
94+
)
95+
96+
python_logits = apply_top_k_top_p(
97+
logits=logits.clone(),
98+
k=k_values,
99+
p=p_values,
100+
)
101+
python_probs = torch.softmax(python_logits, dim=-1)
102+
103+
# FlashInfer only exposed renorm interfaces for probs so convert first
104+
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
105+
flashinfer_probs = top_k_renorm_probs(
106+
probs=flashinfer_probs,
107+
top_k=k_values,
108+
)
109+
flashinfer_probs = top_p_renorm_probs(
110+
probs=flashinfer_probs,
111+
top_p=p_values,
112+
)
113+
114+
# Compare the results
115+
assert torch.allclose(
116+
python_probs, flashinfer_probs, atol=2e-2
117+
), "FlashInfer and Python sampling implementations do not match!"

vllm/attention/layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
320320
def get_attn_backend(self) -> type[AttentionBackend]:
321321
return self.attn_backend
322322

323+
def get_kv_cache(self) -> torch.Tensor:
324+
forward_context: ForwardContext = get_forward_context()
325+
return self.kv_cache[forward_context.virtual_engine]
326+
323327

324328
class MultiHeadAttention(nn.Module):
325329
"""Multi-headed attention without any cache, used for ViT."""
@@ -409,7 +413,6 @@ def forward(
409413

410414
return out.reshape(bsz, q_len, -1)
411415

412-
413416
def wait_for_kv_layer_from_connector(layer_name: str):
414417
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
415418
return

vllm/config/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
try_get_tokenizer_config, uses_mrope)
4949
from vllm.transformers_utils.s3_utils import S3Model
5050
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
51+
from vllm.tree_drafter_params import TreeDrafterParams
5152
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType,
5253
LazyLoader, common_broadcastable_dtype, random_uuid)
5354

@@ -1980,6 +1981,9 @@ class SpeculativeConfig:
19801981
ParallelConfig] = None # type: ignore
19811982
"""The parallel configuration for the draft model initialized internal."""
19821983

1984+
# params generated in the post-init stage for tree drafting.
1985+
tree_drafter_params: SkipValidation[TreeDrafterParams] = None
1986+
19831987
def compute_hash(self) -> str:
19841988
"""
19851989
WARNING: Whenever a new field is added to this config,
@@ -2201,12 +2205,9 @@ def __post_init__(self):
22012205
(i + 1) * (0, )
22022206
for i in range(self.num_speculative_tokens)
22032207
])
2204-
else:
2205-
# Sort the token tree breadth-first.
2206-
tree_choices = ast.literal_eval(
2207-
self.speculative_token_tree)
2208-
self.speculative_token_tree = str(
2209-
sorted(tree_choices, key=lambda t: (len(t), t)))
2208+
# Construct tree drafter params from the serialized token tree.
2209+
self.tree_drafter_params = TreeDrafterParams.from_spec_token_tree(
2210+
self.speculative_token_tree)
22102211

22112212
self.draft_tensor_parallel_size = \
22122213
SpeculativeConfig._verify_and_get_draft_tp(
@@ -2518,7 +2519,7 @@ class MultiModalConfig:
25182519

25192520
skip_mm_profiling: bool = False
25202521
"""
2521-
When enabled, skips multimodal memory profiling and only profiles with
2522+
When enabled, skips multimodal memory profiling and only profiles with
25222523
language backbone model during engine initialization.
25232524
25242525
This reduces engine startup time but shifts the responsibility to users for
@@ -2581,24 +2582,24 @@ class PoolerConfig:
25812582
## for embeddings models
25822583
normalize: Optional[bool] = None
25832584
"""
2584-
Whether to normalize the embeddings outputs.
2585+
Whether to normalize the embeddings outputs.
25852586
"""
25862587
dimensions: Optional[int] = None
25872588
"""
2588-
Reduce the dimensions of embeddings if model
2589+
Reduce the dimensions of embeddings if model
25892590
support matryoshka representation.
25902591
"""
25912592

25922593
## for classification models
25932594
activation: Optional[bool] = None
25942595
"""
2595-
Whether to apply activation function to the classification outputs.
2596+
Whether to apply activation function to the classification outputs.
25962597
"""
25972598

25982599
## for reward models
25992600
softmax: Optional[bool] = None
26002601
"""
2601-
Whether to apply softmax to the reward outputs.
2602+
Whether to apply softmax to the reward outputs.
26022603
"""
26032604
step_tag_id: Optional[int] = None
26042605
"""
@@ -2624,9 +2625,9 @@ class PoolerConfig:
26242625

26252626
max_embed_len: Optional[int] = None
26262627
"""
2627-
Maximum input length allowed for embedding generation. When set, allows
2628+
Maximum input length allowed for embedding generation. When set, allows
26282629
inputs longer than max_embed_len to be accepted for embedding models.
2629-
This parameter enables accepting long inputs without requiring
2630+
This parameter enables accepting long inputs without requiring
26302631
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
26312632
max_embed_len, it will be handled according to the original max_model_len
26322633
validation logic. Defaults to None (i.e. set to max_model_len).

vllm/model_executor/models/llama.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ def forward(
304304
hidden_states = self.mlp(hidden_states)
305305
return hidden_states, residual
306306

307+
def get_kv_cache(self) -> torch.Tensor:
308+
return self.self_attn.attn.get_kv_cache()
309+
307310

308311
@support_torch_compile
309312
class LlamaModel(nn.Module):
@@ -556,6 +559,12 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
556559
num_layers = len(self.model.layers)
557560
return (2, num_layers // 2, num_layers - 3)
558561

562+
def get_layer_kv_caches(self) -> list[torch.Tensor]:
563+
kv_caches = []
564+
for layer in self.model.layers:
565+
kv_caches.append(layer.get_kv_cache())
566+
return kv_caches
567+
559568
def _init_model(self,
560569
vllm_config: VllmConfig,
561570
prefix: str = "",

vllm/tree_drafter_params.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Attention layer with TreeAttention."""
4+
5+
import ast
6+
from dataclasses import dataclass
7+
from typing import Optional
8+
9+
10+
@dataclass
11+
class TreeDrafterParams:
12+
tree_choices: list[tuple[int, ...]]
13+
attn_mask: list[list[bool]]
14+
first_branching_level: Optional[int]
15+
cu_drafts_per_level: list[int]
16+
child_drafts_per_level: list[int]
17+
18+
@staticmethod
19+
def from_spec_token_tree(spec_token_tree: str) -> "TreeDrafterParams":
20+
# Parse the speculative token tree.
21+
tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
22+
# Sort the tree breadth-first.
23+
tree_choices.sort(key=lambda t: (len(t), t))
24+
25+
tree_depth = len(tree_choices[-1])
26+
# Precompute per-level properties of the tree.
27+
num_drafts_per_level = [0] * tree_depth
28+
for node in tree_choices:
29+
num_drafts_per_level[len(node) - 1] += 1
30+
cu_drafts_per_level = [num_drafts_per_level[0]]
31+
child_drafts_per_level = [num_drafts_per_level[0]]
32+
for level in range(1, tree_depth):
33+
cu_drafts_per_level.append(
34+
cu_drafts_per_level[-1] + num_drafts_per_level[level]
35+
)
36+
child_drafts_per_level.append(
37+
num_drafts_per_level[level] // num_drafts_per_level[level - 1]
38+
)
39+
# Find the first level where the tree branches off into one or more
40+
# children.
41+
first_branching_level = None
42+
for level in range(tree_depth):
43+
if child_drafts_per_level[level] > 1:
44+
first_branching_level = level
45+
break
46+
47+
# Construct the tree attention bias.
48+
depth_counts = _get_depth_counts(tree_choices)
49+
attn_mask = _prepare_tree_attn_bias(
50+
tree_choices,
51+
depth_counts,
52+
)
53+
54+
return TreeDrafterParams(
55+
tree_choices=tree_choices,
56+
attn_mask=attn_mask,
57+
first_branching_level=first_branching_level,
58+
cu_drafts_per_level=cu_drafts_per_level,
59+
child_drafts_per_level=child_drafts_per_level,
60+
)
61+
62+
63+
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
64+
# Count the number of choices at each depth of the tree.
65+
depth_counts = []
66+
prev_depth = 0
67+
for path in sorted_tree_choices:
68+
depth = len(path)
69+
if depth != prev_depth:
70+
depth_counts.append(0)
71+
depth_counts[depth - 1] += 1
72+
prev_depth = depth
73+
return depth_counts
74+
75+
76+
def _prepare_tree_attn_bias(
77+
sorted_tree_choices: list[tuple[int, ...]],
78+
depth_counts: list[int],
79+
) -> list[list[bool]]:
80+
# +1 comes from the additional root node.
81+
tree_len = len(sorted_tree_choices) + 1
82+
tree_attn_mask = [[False for _ in range(tree_len)] for _ in range(tree_len)]
83+
84+
mask_val = True
85+
for i in range(tree_len):
86+
# Set diagonal to all True. Each token should attend to itself.
87+
tree_attn_mask[i][i] = mask_val
88+
# Set root column to all True. All tokens attend to it.
89+
tree_attn_mask[i][0] = mask_val
90+
91+
# Set all ancestors to True.
92+
start = 0
93+
for i in range(len(depth_counts)):
94+
for j in range(depth_counts[i]):
95+
cur_tree_choice = sorted_tree_choices[start + j]
96+
if len(cur_tree_choice) == 1:
97+
continue
98+
99+
for c in range(len(cur_tree_choice) - 1):
100+
ancestor_idx = sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
101+
tree_attn_mask[j + start + 1][ancestor_idx] = mask_val
102+
start += depth_counts[i]
103+
return tree_attn_mask

0 commit comments

Comments
 (0)