Skip to content

Commit 544ec15

Browse files
[V1] implement tree sampler for draft token acceptance
Signed-off-by: Giancarlo Delfin <[email protected]>
1 parent 3253ae7 commit 544ec15

File tree

9 files changed

+665
-184
lines changed

9 files changed

+665
-184
lines changed

tests/v1/e2e/test_spec_decode.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,6 @@ def test_eagle_correctness(
162162
mm_enabled: bool,
163163
attn_backend: str,
164164
):
165-
if attn_backend == "TREE_ATTN":
166-
# TODO: Fix this flaky test
167-
pytest.skip(
168-
"TREE_ATTN is flaky in the test disable for now until it can be "
169-
"reolved (see https://github.com/vllm-project/vllm/issues/22922)")
170-
171165
# Generate test prompts inside the function instead of using fixture
172166
test_prompts = get_test_prompts(mm_enabled)
173167
'''
@@ -222,7 +216,15 @@ def test_eagle_correctness(
222216

223217
# Heuristic: expect at least 66% of the prompts to match exactly
224218
# Upon failure, inspect the outputs to check for inaccuracy.
225-
assert matches > int(0.66 * len(ref_outputs))
219+
accuracy_threshold = 0.66
220+
221+
if attn_backend == "TREE_ATTN":
222+
# Tree attention uses Triton kernels, which can perform
223+
# non-deterministic floating-point arithmetic. Threshold needs to be
224+
# reduced to 50% to prevent flaky tests.
225+
accuracy_threshold = 0.50
226+
227+
assert matches > int(accuracy_threshold * len(ref_outputs))
226228
del spec_llm
227229
torch.cuda.empty_cache()
228230
cleanup_dist_env_and_memory()
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Any, Optional
4+
5+
import pytest
6+
import torch
7+
import torch.nn.functional as F
8+
9+
from vllm.platforms import current_platform
10+
from vllm.tree_drafter_params import TreeDrafterParams
11+
from vllm.v1.sample.logits_processor import LogitsProcessorManager
12+
from vllm.v1.sample.metadata import SamplingMetadata
13+
from vllm.v1.sample.sampler import Sampler
14+
from vllm.v1.sample.tree_rejection_sampler import (PLACEHOLDER_TOKEN_ID,
15+
TreeRejectionSampler)
16+
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
17+
18+
DEVICE = current_platform.device_type
19+
20+
21+
def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Build logits whose row-wise argmax reproduces the given token ids.

    The sequences in ``output_token_ids`` are concatenated in order, giving
    one logits row per token. Every entry is -100.0 except the column of the
    desired token id, which is set to 100.0.
    """
    flat_token_ids = [
        token_id for sequence in output_token_ids for token_id in sequence
    ]
    logits = torch.full((len(flat_token_ids), vocab_size),
                        -100.0,
                        device=DEVICE)
    for row, token_id in enumerate(flat_token_ids):
        logits[row, token_id] = 100.0
    return logits
34+
35+
36+
def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Build a v1 ``SamplingMetadata`` that is either all-greedy or all-random.

    When ``all_greedy`` is True any supplied ``temperature`` is discarded;
    otherwise a ``temperature`` tensor is required. Penalties are disabled
    and the remaining fields are filled with empty placeholders.
    """
    if not all_greedy:
        assert temperature is not None

    return SamplingMetadata(
        # Greedy sampling ignores the temperature altogether.
        temperature=None if all_greedy else temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators or {},
        max_num_logprobs=0,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessorManager(),
    )
71+
72+
73+
########################### Tests for Greedy Sampling ###################
74+
75+
def test_perfect_match():
    """Greedy targets that equal every draft token return the drafts plus
    the bonus token."""
    draft_token_ids = [[1, 2, 3]]
    # The trailing 4 is the bonus token.
    target_token_ids = [[1, 2, 3, 4]]

    sampler = TreeRejectionSampler(
        tree_drafter_params=TreeDrafterParams.from_spec_token_tree(
            "[(0, ), (0, 0), (0, 0, 0)]"),
        max_batch_size=1,
        main_sampler=Sampler(),
        device=None,
    )

    target_logits = create_logits_tensor(target_token_ids)
    sampled = sampler(
        SpecDecodeMetadata.make_dummy(draft_token_ids,
                                      device=target_logits.device),
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=None,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    assert torch.equal(
        sampled,
        torch.tensor(target_token_ids,
                     dtype=torch.int,
                     device=target_logits.device),
    )
105+
106+
@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0, )],  # A single token
        [(0, ), (0, 0), (0, 0, 0)],  # Chain
        [(0, ), (1, ), (2, )],  # Parallel
        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
         (2, 1)],  # Tree
    ])
def test_greedy_sampling(spec_token_tree: list[tuple[int, ...]]):
    """Run greedy tree rejection sampling over several draft tree shapes
    (single token, chain, parallel siblings, two-level tree)."""
    tree_drafter_params = TreeDrafterParams.from_spec_token_tree(
        str(spec_token_tree))
    tree_rejection_sampler = TreeRejectionSampler(
        tree_drafter_params=tree_drafter_params,
        max_batch_size=1,
        main_sampler=Sampler(),
        device=None,
    )

    # NOTE(review): spec_tokens has len(spec_token_tree) + 1 entries while
    # output_tokens is fixed at four entries for every tree shape — confirm
    # the intended sizes.
    spec_tokens = [[i + 1 for i in range(len(spec_token_tree) + 1)]]
    output_tokens = [[1, 2, 3, 4]]  # 4 is the bonus token
    # NOTE(review): find_longest_path is neither defined nor imported in the
    # visible source, so this call raises NameError as written — confirm
    # where the helper is meant to come from and what arguments it takes.
    longest_path = find_longest_path()

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = tree_rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=None,
        sampling_metadata=metadata,
    )
    # The sampler output is compared against the longest accepted path.
    expected = torch.tensor(longest_path,
                            dtype=torch.int,
                            device=logits.device)

    assert torch.equal(output, expected)

vllm/config/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
try_get_tokenizer_config, uses_mrope)
4949
from vllm.transformers_utils.s3_utils import S3Model
5050
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
51+
from vllm.tree_drafter_params import TreeDrafterParams
5152
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType,
5253
LazyLoader, common_broadcastable_dtype, random_uuid)
5354

@@ -1980,6 +1981,9 @@ class SpeculativeConfig:
19801981
ParallelConfig] = None # type: ignore
19811982
"""The parallel configuration for the draft model initialized internal."""
19821983

1984+
# params generated in the post-init stage for tree drafting.
1985+
tree_drafter_params: SkipValidation[TreeDrafterParams] = None
1986+
19831987
def compute_hash(self) -> str:
19841988
"""
19851989
WARNING: Whenever a new field is added to this config,
@@ -2201,12 +2205,9 @@ def __post_init__(self):
22012205
(i + 1) * (0, )
22022206
for i in range(self.num_speculative_tokens)
22032207
])
2204-
else:
2205-
# Sort the token tree breadth-first.
2206-
tree_choices = ast.literal_eval(
2207-
self.speculative_token_tree)
2208-
self.speculative_token_tree = str(
2209-
sorted(tree_choices, key=lambda t: (len(t), t)))
2208+
# Construct tree drafter params from the serialized token tree.
2209+
self.tree_drafter_params = TreeDrafterParams.from_spec_token_tree(
2210+
self.speculative_token_tree)
22102211

22112212
self.draft_tensor_parallel_size = \
22122213
SpeculativeConfig._verify_and_get_draft_tp(
@@ -2518,7 +2519,7 @@ class MultiModalConfig:
25182519

25192520
skip_mm_profiling: bool = False
25202521
"""
2521-
When enabled, skips multimodal memory profiling and only profiles with
2522+
When enabled, skips multimodal memory profiling and only profiles with
25222523
language backbone model during engine initialization.
25232524
25242525
This reduces engine startup time but shifts the responsibility to users for
@@ -2581,24 +2582,24 @@ class PoolerConfig:
25812582
## for embeddings models
25822583
normalize: Optional[bool] = None
25832584
"""
2584-
Whether to normalize the embeddings outputs.
2585+
Whether to normalize the embeddings outputs.
25852586
"""
25862587
dimensions: Optional[int] = None
25872588
"""
2588-
Reduce the dimensions of embeddings if model
2589+
Reduce the dimensions of embeddings if model
25892590
support matryoshka representation.
25902591
"""
25912592

25922593
## for classification models
25932594
activation: Optional[bool] = None
25942595
"""
2595-
Whether to apply activation function to the classification outputs.
2596+
Whether to apply activation function to the classification outputs.
25962597
"""
25972598

25982599
## for reward models
25992600
softmax: Optional[bool] = None
26002601
"""
2601-
Whether to apply softmax to the reward outputs.
2602+
Whether to apply softmax to the reward outputs.
26022603
"""
26032604
step_tag_id: Optional[int] = None
26042605
"""
@@ -2624,9 +2625,9 @@ class PoolerConfig:
26242625

26252626
max_embed_len: Optional[int] = None
26262627
"""
2627-
Maximum input length allowed for embedding generation. When set, allows
2628+
Maximum input length allowed for embedding generation. When set, allows
26282629
inputs longer than max_embed_len to be accepted for embedding models.
2629-
This parameter enables accepting long inputs without requiring
2630+
This parameter enables accepting long inputs without requiring
26302631
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
26312632
max_embed_len, it will be handled according to the original max_model_len
26322633
validation logic. Defaults to None (i.e. set to max_model_len).

vllm/tree_drafter_params.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Parameters derived from a speculative token tree for tree drafting."""
4+
5+
import ast
6+
from dataclasses import dataclass
7+
8+
@dataclass
class TreeDrafterParams:
    """Per-level properties and attention mask derived from a draft token tree.

    Instances are normally created with ``from_spec_token_tree``.
    """

    # Tree nodes as paths of child indices from the root, sorted
    # breadth-first (by path length, then lexicographically).
    tree_choices: list[tuple[int, ...]]
    # Square boolean mask with one row/column per node plus one for the
    # root, as produced by ``_prepare_tree_attn_bias``.
    attn_mask: list[list[bool]]
    # Running total of the number of nodes up to and including each level.
    cu_drafts_per_level: list[int]
    # Number of level-1 nodes, followed by the floor ratio of each level's
    # node count to the previous level's node count.
    child_drafts_per_level: list[int]

    @staticmethod
    def from_spec_token_tree(spec_token_tree: str) -> "TreeDrafterParams":
        """Build the drafter parameters from a serialized token tree.

        Args:
            spec_token_tree: Python-literal text for a sequence of node
                paths, e.g. ``"[(0,), (0, 0), (1,)]"``.

        Returns:
            A ``TreeDrafterParams`` holding the breadth-first sorted nodes,
            the tree attention mask and the per-level draft counts.

        Raises:
            ValueError: If the tree has no nodes, contains an empty path, or
                contains a node whose parent path is not part of the tree.
        """
        # Parse the speculative token tree. ``sorted`` accepts any literal
        # sequence (list or tuple) and orders the nodes breadth-first.
        parsed = ast.literal_eval(spec_token_tree)
        tree_choices: list[tuple[int, ...]] = sorted(
            (tuple(choice) for choice in parsed),
            key=lambda t: (len(t), t),
        )
        if not tree_choices:
            raise ValueError(
                "Speculative token tree must contain at least one node.")

        # Every non-root-level node needs its parent in the tree; this also
        # guarantees that no level between 1 and the tree depth is empty.
        known_nodes = set(tree_choices)
        for node in tree_choices:
            if not node:
                raise ValueError(
                    "Speculative token tree must not contain an empty path.")
            if len(node) > 1 and node[:-1] not in known_nodes:
                raise ValueError(
                    f"Speculative token tree node {node} is missing its "
                    f"parent {node[:-1]}.")

        # The deepest node sorts last, so its length is the tree depth.
        tree_depth = len(tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in tree_choices:
            num_drafts_per_level[len(node) - 1] += 1

        cu_drafts_per_level = [num_drafts_per_level[0]]
        child_drafts_per_level = [num_drafts_per_level[0]]
        for prev_count, count in zip(num_drafts_per_level,
                                     num_drafts_per_level[1:]):
            cu_drafts_per_level.append(cu_drafts_per_level[-1] + count)
            # Floor division: exact only when every node of the previous
            # level has the same number of children.
            # NOTE(review): confirm uneven trees are meant to truncate here.
            child_drafts_per_level.append(count // prev_count)

        # Construct the tree attention mask.
        depth_counts = _get_depth_counts(tree_choices)
        attn_mask = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
        )

        return TreeDrafterParams(
            tree_choices=tree_choices,
            attn_mask=attn_mask,
            cu_drafts_per_level=cu_drafts_per_level,
            child_drafts_per_level=child_drafts_per_level,
        )
50+
51+
52+
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
53+
# Count the number of choices at each depth of the tree.
54+
depth_counts = []
55+
prev_depth = 0
56+
for path in sorted_tree_choices:
57+
depth = len(path)
58+
if depth != prev_depth:
59+
depth_counts.append(0)
60+
depth_counts[depth - 1] += 1
61+
prev_depth = depth
62+
return depth_counts
63+
64+
65+
def _prepare_tree_attn_bias(
66+
sorted_tree_choices: list[tuple[int, ...]],
67+
depth_counts: list[int],
68+
) -> list[list[bool]]:
69+
# +1 comes from the additional root node.
70+
tree_len = len(sorted_tree_choices) + 1
71+
tree_attn_mask = [[False for _ in range(tree_len)] for _ in range(tree_len)]
72+
73+
mask_val = True
74+
for i in range(tree_len):
75+
# Set diagonal to all True. Each token should attend to itself.
76+
tree_attn_mask[i][i] = mask_val
77+
# Set root column to all True. All tokens attend to it.
78+
tree_attn_mask[i][0] = mask_val
79+
80+
# Set all ancestors to True.
81+
start = 0
82+
for i in range(len(depth_counts)):
83+
for j in range(depth_counts[i]):
84+
cur_tree_choice = sorted_tree_choices[start + j]
85+
if len(cur_tree_choice) == 1:
86+
continue
87+
88+
for c in range(len(cur_tree_choice) - 1):
89+
ancestor_idx = sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
90+
tree_attn_mask[j + start + 1][ancestor_idx] = mask_val
91+
start += depth_counts[i]
92+
return tree_attn_mask

0 commit comments

Comments
 (0)