Commit 9c59df6

[V1] implement tree sampler for draft token acceptance
Signed-off-by: Giancarlo Delfin <[email protected]>
1 parent 4bf70cd commit 9c59df6

File tree

7 files changed: +513 −146 lines changed

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import Generator

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (
    apply_top_k_top_p,
    is_flashinfer_available,
)

DEVICE = current_platform.device_type

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
if is_flashinfer_available:
    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs


@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Tests in this file explicitly set the default device, which can affect
    subsequent tests. This fixture restores the original default device
    afterward to avoid that problem.
    """
    original_device = torch.get_default_device()
    yield
    torch.set_default_device(original_device)


def test_topk_impl_equivalence():
    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Random top-k values between 1 and 9.
    k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)

    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
    k.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE
    )

    # Top-k only implementation
    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

    # Top-p + top-k
    no_op_top_p = torch.tensor([1.0])
    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

    assert torch.allclose(result1, result2)


def test_tree_rejection_sampler():
    """
    This test verifies that the FlashInfer top-k and top-p sampling
    implementation produces the same results as the Python implementation.

    NOTE: FlashInfer does not directly expose an interface for fused top-k and
    top-p prob renorm (it does provide fused sampling, but sampling results
    cannot be compared due to randomness), so we compare the probabilities
    renormalized first by top-k and then by top-p in the FlashInfer
    implementation.
    """
    if not FLASHINFER_ENABLED:
        pytest.skip("FlashInfer not installed or not available on this platform.")

    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(42)

    # Generate random logits
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Generate various top-k and top-p values
    k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
    p_values = (
        torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
    )  # range in [0.5, 1.0]

    # Sometimes disable top-k (k=vocab_size)
    k_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
        VOCAB_SIZE,
    )

    # Sometimes disable top-p (p=1.0)
    p_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
    )

    python_logits = apply_top_k_top_p(
        logits=logits.clone(),
        k=k_values,
        p=p_values,
    )
    python_probs = torch.softmax(python_logits, dim=-1)

    # FlashInfer only exposes renorm interfaces for probs, so convert first
    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
    flashinfer_probs = top_k_renorm_probs(
        probs=flashinfer_probs,
        top_k=k_values,
    )
    flashinfer_probs = top_p_renorm_probs(
        probs=flashinfer_probs,
        top_p=p_values,
    )

    # Compare the results
    assert torch.allclose(
        python_probs, flashinfer_probs, atol=2e-2
    ), "FlashInfer and Python sampling implementations do not match!"

vllm/config/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -47,6 +47,7 @@
     try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
+from vllm.tree_drafter_params import TreeDrafterParams
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
@@ -2277,6 +2278,9 @@ class SpeculativeConfig:
         ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""
 
+    # params generated in the post-init stage for tree drafting.
+    tree_drafter_params: SkipValidation[TreeDrafterParams] = None
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2498,12 +2502,9 @@ def __post_init__(self):
                     (i + 1) * (0, )
                     for i in range(self.num_speculative_tokens)
                 ])
-            else:
-                # Sort the token tree breadth-first.
-                tree_choices = ast.literal_eval(
-                    self.speculative_token_tree)
-                self.speculative_token_tree = str(
-                    sorted(tree_choices, key=lambda t: (len(t), t)))
+            # Construct tree drafter params from the serialized token tree.
+            self.tree_drafter_params = TreeDrafterParams.from_spec_token_tree(
+                self.speculative_token_tree)
 
             self.draft_tensor_parallel_size = \
                 SpeculativeConfig._verify_and_get_draft_tp(
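
For readers following the config change: speculative_token_tree is a Python-literal string of node paths, and __post_init__ now hands it straight to TreeDrafterParams. A minimal sketch of what that produces, using an illustrative two-level tree (the tree shape is an assumption, not taken from the commit; running it requires a checkout containing this commit):

from vllm.tree_drafter_params import TreeDrafterParams

# Illustrative tree: two children under the root, one grandchild under each.
# Each tuple is a path of child indices from the root.
spec_token_tree = "[(0,), (1,), (0, 0), (1, 0)]"

params = TreeDrafterParams.from_spec_token_tree(spec_token_tree)
print(params.cu_drafts_per_level)     # [2, 4]  cumulative drafts per level
print(params.child_drafts_per_level)  # [2, 1]  children per parent at each level
print(params.first_branching_level)   # 0       the root already branches
print(len(params.attn_mask))          # 5       4 draft nodes + 1 root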

vllm/tree_drafter_params.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Parameters describing the token tree used for tree drafting."""

import ast
from dataclasses import dataclass
from typing import Optional


@dataclass
class TreeDrafterParams:
    tree_choices: list[tuple[int, ...]]
    attn_mask: list[list[bool]]
    first_branching_level: Optional[int]
    cu_drafts_per_level: list[int]
    child_drafts_per_level: list[int]

    @staticmethod
    def from_spec_token_tree(spec_token_tree: str) -> "TreeDrafterParams":
        # Parse the speculative token tree.
        tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # Sort the tree breadth-first.
        tree_choices.sort(key=lambda t: (len(t), t))

        tree_depth = len(tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        cu_drafts_per_level = [num_drafts_per_level[0]]
        child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            cu_drafts_per_level.append(
                cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Find the first level where the tree branches off into more than one
        # child.
        first_branching_level = None
        for level in range(tree_depth):
            if child_drafts_per_level[level] > 1:
                first_branching_level = level
                break

        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        attn_mask = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
        )

        return TreeDrafterParams(
            tree_choices=tree_choices,
            attn_mask=attn_mask,
            first_branching_level=first_branching_level,
            cu_drafts_per_level=cu_drafts_per_level,
            child_drafts_per_level=child_drafts_per_level,
        )


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
) -> list[list[bool]]:
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = [[False for _ in range(tree_len)] for _ in range(tree_len)]

    mask_val = True
    for i in range(tree_len):
        # Set diagonal to all True. Each token should attend to itself.
        tree_attn_mask[i][i] = mask_val
        # Set root column to all True. All tokens attend to it.
        tree_attn_mask[i][0] = mask_val

    # Set all ancestors to True.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            if len(cur_tree_choice) == 1:
                continue

            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx = sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
                tree_attn_mask[j + start + 1][ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask
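
A quick worked example of the attention bias _prepare_tree_attn_bias builds, again on an illustrative three-node tree (an assumption for demonstration, and it assumes the module above is importable, i.e. a checkout containing this commit):

from vllm.tree_drafter_params import TreeDrafterParams

params = TreeDrafterParams.from_spec_token_tree("[(0,), (1,), (0, 0)]")

# Node order after the breadth-first sort: root, (0,), (1,), (0, 0).
# Row i marks which nodes token i may attend to.
for row in params.attn_mask:
    print([int(v) for v in row])
# [1, 0, 0, 0]   root attends only to itself
# [1, 1, 0, 0]   (0,) attends to the root and itself
# [1, 0, 1, 0]   (1,) attends to the root and itself
# [1, 1, 0, 1]   (0, 0) attends to the root, its parent (0,), and itself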
