diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 7b3f45831279..fbca3b2872ad 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -162,12 +162,6 @@ def test_eagle_correctness( mm_enabled: bool, attn_backend: str, ): - if attn_backend == "TREE_ATTN": - # TODO: Fix this flaky test - pytest.skip( - "TREE_ATTN is flaky in the test disable for now until it can be " - "reolved (see https://github.com/vllm-project/vllm/issues/22922)") - # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) ''' @@ -222,7 +216,15 @@ def test_eagle_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.66 * len(ref_outputs)) + accuracy_threshold = 0.66 + + if attn_backend == "TREE_ATTN": + # The tree attention backend uses Triton kernels, which exhibit + # floating-point nondeterminism. Reducing the threshold to 50% + # to prevent flaky tests. + accuracy_threshold = 0.50 + + assert matches > int(accuracy_threshold * len(ref_outputs)) del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() diff --git a/tests/v1/sample/test_tree_rejection_sampler.py b/tests/v1/sample/test_tree_rejection_sampler.py new file mode 100644 index 000000000000..3c2fbc1cddba --- /dev/null +++ b/tests/v1/sample/test_tree_rejection_sampler.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +from vllm.platforms import current_platform +from vllm.tree_drafter_params import TreeDrafterParams +from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler +from vllm.v1.sample.tree_rejection_sampler import (PLACEHOLDER_TOKEN_ID, + TreeRejectionSampler) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +DEVICE = current_platform.device_type +VOCAB_SIZE = 100 +Node = tuple[int] + +########################### Helper Functions ########################### + + +def create_tree_rejection_sampler(tree_structure: list[Node], + batch_size: int) -> TreeRejectionSampler: + tree_drafter_params = TreeDrafterParams.from_spec_token_tree( + str(tree_structure)) + return TreeRejectionSampler( + tree_drafter_params=tree_drafter_params, + max_batch_size=batch_size, + main_sampler=Sampler(), + device=DEVICE, + ) + + +def get_token_id(tree: list[Node], node: Node) -> int: + # Token id is just the position of this node in the tree. + return tree.index(node) + + +def to_input_draft_token_ids(tree: list[Node], num_drafts: int, + draft_nodes: list[Node]) -> torch.Tensor: + """ + Creates a tensor of draft token ids to input into the rejection sampler. + Each given node is mapped to a unique token id. All other positions are + given a random token id. + """ + draft_token_ids = torch.randint( + # Offset the random token ids by the size of the tree. + low=len(tree), + high=VOCAB_SIZE, + size=(num_drafts, ), + device=DEVICE) + for draft_node in draft_nodes: + # Get the draft node's position in the tree, excluding the root node. + index = tree.index(draft_node) - 1 + # Assign unique token id to the node. + token_id = get_token_id(tree, draft_node) + draft_token_ids[index] = token_id + return draft_token_ids + + +def to_output_token_ids(tree: list[Node], + num_drafts: int, + accepted: list[Node], + bonus: Node) -> torch.Tensor: + """ + Creates a tensor where only the accepted and bonus nodes are mapped to + their token ids. + """ + output_token_ids = torch.empty(num_drafts + 1, device=DEVICE) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + for accepted_node in accepted: + index = tree.index(accepted_node) - 1 + token_id = get_token_id(tree, accepted_node) + output_token_ids[index] = token_id + output_token_ids[-1] = get_token_id(tree, bonus) + return output_token_ids + + +def create_logits_tensor(tree: list[Node], num_logits: int, + sample_map: dict[Node, Node]) -> torch.Tensor: + """ + Helper function to create logits tensor that will produce the desired + token ids on argmax + """ + logits = torch.full((num_logits, VOCAB_SIZE), -100.0, device=DEVICE) + for index in range(num_logits): + node = tree[index] + if node not in sample_map: + continue + sampled_node = sample_map[node] + token_id = get_token_id(tree, sampled_node) + logits[index, token_id] = 100.0 + return logits + + +def create_sampling_metadata( + all_greedy: bool, + temperature: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, + top_p: Optional[torch.Tensor] = None, + generators: Optional[dict[int, Any]] = None, +) -> SamplingMetadata: + """ + Create a v1 sampling metadata object with all_greedy set to the given + value. Either all greedy or all random sampling is used. + """ + generators = generators or {} + if all_greedy: + temperature = None + else: + assert temperature is not None + + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators=generators, + max_num_logprobs=0, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager(), + ) + + +def assert_rejection_sample( + draft_tree: list[Node], + spec_nodes: list[list[Node]], + target_sample_maps: list[dict[Node, Node]], + expected_accepted_nodes: list[list[Node]], + expected_bonus_nodes: list[Node], +): + num_drafts = len(draft_tree) + # Create tree rejection sampler. + tree_rejection_sampler = create_tree_rejection_sampler( + draft_tree, len(spec_nodes)) + + # Create the bonus level. + last_level = len(draft_tree[-1]) + leaves = [node for node in draft_tree if len(node) == last_level] + bonus_level = [leaf + (0, ) for leaf in leaves] + # Create tree with root node and bonus level added. + tree = [()] + draft_tree + bonus_level + + # Convert drafted tokens mapping to tensor representation. + input_draft_token_ids = torch.stack( + [to_input_draft_token_ids(tree, num_drafts, s) for s in spec_nodes]) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + input_draft_token_ids.tolist(), device=DEVICE) + + # Generate logits that deterministically produce the given sampled + # tokens. + target_logits = torch.cat([ + create_logits_tensor(tree, num_drafts + 1, sample_map) + for sample_map in target_sample_maps + ]) + + # Create greedy sampling metadata. + metadata = create_sampling_metadata(all_greedy=True) + + # Rejection sample. + output_tokens, _ = tree_rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=target_logits, + bonus_token_ids=None, + sampling_metadata=metadata, + ) + + # Compare with output with expected. + expected_tokens = torch.stack( + [to_output_token_ids(tree, num_drafts, a, b) for a, b in zip(expected_accepted_nodes, expected_bonus_nodes)]) + assert torch.equal(output_tokens, expected_tokens) + + +########################### Tests ########################### + + +def test_single_node(): + """ + Test exact match for a single node. + """ + draft_tree = [ + (0, ), + ] + drafted_tokens = [ + [(0, )], + ] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 0), + }] + expected_accepted_tokens = [ + [(0, )], + ] + expected_bonus_tokens = [ + (0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_chain_full_acceptance(): + draft_tree = [ + (0, ), + (0, 0), + (0, 0, 0), + ] + drafted_tokens = [ + [(0, ), (0, 0), (0, 0, 0)], + ] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + (0, 0, 0): (0, 0, 0, 0) + }] + expected_accepted_tokens = [ + [(0, ), (0, 0), (0, 0, 0)], + ] + expected_bonus_tokens = [ + (0, 0, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_chain_partial_acceptance(): + draft_tree = [ + (0, ), + (0, 0), + (0, 0, 0), + ] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + }] + drafted_tokens = [ + [(0, ), (0, 0), (0, 0)], # Mismatch for final draft (expected (0,0,0)) + ] + expected_accepted_tokens = [ + [(0, ), (0, 0)], + ] + expected_bonus_tokens = [ + (0, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_full_acceptance(): + draft_tree = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens = [ + [(1, ), (1, 1)], + ] + target_sample_maps = [{ + (): (1, ), + (1, ): (1, 1), + (1, 1): (1, 1, 0), + }] + expected_accepted_tokens = [ + [(1, ), (1, 1)], + ] + expected_bonus_tokens = [ + (1, 1, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_partial_acceptance(): + draft_tree = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 1), + }] + drafted_tokens = [ + [(0, ), (0, 0)], # Mismatch for final draft (expected (0,0)) + ] + expected_accepted_tokens = [ + [(0, )], + ] + expected_bonus_tokens = [ + (0, 1), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_early_rejection(): + draft_tree = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + }] + drafted_tokens = [ + [(1, ), (0, 1)], # Mismatch for the first draft (expected (0,)) + ] + expected_accepted_tokens = [ + [], + ] + expected_bonus_tokens = [ + (0, ), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_full_acceptance_multiple_sequences(): + draft_tree = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 1), + (0, 1): (0, 1, 0), + }, { + (): (1, ), + (1, ): (1, 0), + (1, 0): (1, 0, 0), + }] + drafted_tokens = [ + [(0, ), (0, 1)], # Sequence 1 + [(1, ), (1, 0)], # Sequence 2 + ] + expected_accepted_tokens = [ + [(0, ), (0, 1)], + [(1, ), (1, 0)], + ] + expected_bonus_tokens = [ + (0, 1, 0), + (1, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_partial_acceptance_multiple_sequences(): + draft_tree = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + target_sample_maps = [{ + (): (0, ), + (0, ): (0, 1), + }, { + (): (1, ), + (1, ): (1, 0), + }] + drafted_tokens = [ + [(0, ), (0, 0)], # Mismatch for the second draft (expected (0,1)) + [(0, ), (0, 1)], # Mismatch for the first draft (expected (1,)) + ] + expected_accepted_tokens = [ + [(0, )], + [], + ] + expected_bonus_tokens = [ + (0, 1), + (1, ), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_deep_tree_full_acceptance(): + draft_tree = [ + (0, ), + (1, ), # Level 1 + (0, 0), + (0, 1), + (1, 0), + (1, 1), # Level 2 + (0, 0, 0), + (0, 0, 1), + (0, 1, 0), + (0, 1, 1), + (1, 0, 0), + (1, 0, 1), + (1, 1, 0), + (1, 1, 1) # Level 3 + ] + target_sample_maps = [{ + (): (1, ), + (0, ): (0, 1), + (1, ): (1, 1), + (0, 0): (0, 0, 0), + (1, 1): (1, 1, 0), + (1, 1, 0): (1, 1, 0, 0), + }] + drafted_tokens = [ + [(1, ), (1, 1), (1, 1, 0)], + ] + expected_accepted_tokens = [ + [(1, ), (1, 1), (1, 1, 0)], + ] + expected_bonus_tokens = [ + (1, 1, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 7b8445a0b287..58280b4d0541 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -443,7 +443,7 @@ def create_deterministic_logits(token_ids, k: int): # Mock the model forward calls. forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), torch.zeros(total_tokens, hidden_size, device=device))] - for cu_num_drafts in proposer.cu_drafts_per_level: + for cu_num_drafts in proposer.cu_drafts_per_level[1:]: h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) @@ -454,7 +454,7 @@ def create_deterministic_logits(token_ids, k: int): model_mock.side_effect = forward_returns # Mock the compute_logits calls. - cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, + cu_num_drafts_tensor = torch.tensor(proposer.cu_drafts_per_level, dtype=torch.int32, device=device) logits_returns = [] diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 14fc5589a89a..e2183f14d392 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -48,6 +48,7 @@ try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect +from vllm.tree_drafter_params import TreeDrafterParams from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType, LazyLoader, common_broadcastable_dtype, random_uuid) @@ -1980,6 +1981,9 @@ class SpeculativeConfig: ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # params generated in the post-init stage for tree drafting. + tree_drafter_params: SkipValidation[TreeDrafterParams] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2201,12 +2205,9 @@ def __post_init__(self): (i + 1) * (0, ) for i in range(self.num_speculative_tokens) ]) - else: - # Sort the token tree breadth-first. - tree_choices = ast.literal_eval( - self.speculative_token_tree) - self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) + # Construct tree drafter params from the serialized token tree. + self.tree_drafter_params = TreeDrafterParams.from_spec_token_tree( + self.speculative_token_tree) self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( diff --git a/vllm/tree_drafter_params.py b/vllm/tree_drafter_params.py new file mode 100644 index 000000000000..79340e1ac6c0 --- /dev/null +++ b/vllm/tree_drafter_params.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with TreeAttention.""" + +import ast +from dataclasses import dataclass + + +@dataclass +class TreeDrafterParams: + tree_choices: list[tuple[int, ...]] + attn_mask: list[list[bool]] + # Cumulative number of drafts at each level. + cu_drafts_per_level: list[int] + # Number of child drafts that each token has at the given level. + child_drafts_per_level: list[int] + # Maps each draft token to its level in the tree. + draft_levels: list[int] + + @staticmethod + def from_spec_token_tree(spec_token_tree: str) -> "TreeDrafterParams": + # Parse the speculative token tree. + tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) + # Sort the tree breadth-first. + tree_choices.sort(key=lambda t: (len(t), t)) + # Only trees with fixed branching factor per level are + # currently supported for tree attention. + _assert_fixed_branching_factor_per_level(tree_choices, spec_token_tree) + + tree_depth = len(tree_choices[-1]) + 1 + # Precompute per-level properties of the tree. + num_nodes_per_level = [0] * tree_depth + num_nodes_per_level[0] = 1 + for node in tree_choices: + num_nodes_per_level[len(node)] += 1 + + cu_drafts_per_level = [0] + child_drafts_per_level = [] + draft_levels = [] + for level in range(1, tree_depth): + cu_drafts_per_level.append(cu_drafts_per_level[-1] + + num_nodes_per_level[level]) + child_drafts_per_level.append(num_nodes_per_level[level] // + num_nodes_per_level[level - 1]) + draft_levels += [level - 1] * num_nodes_per_level[level] + + # Construct the tree attention bias. + depth_counts = _get_depth_counts(tree_choices) + attn_mask = _prepare_tree_attn_bias( + tree_choices, + depth_counts, + ) + + return TreeDrafterParams( + tree_choices=tree_choices, + attn_mask=attn_mask, + cu_drafts_per_level=cu_drafts_per_level, + child_drafts_per_level=child_drafts_per_level, + draft_levels=draft_levels, + ) + +def _has_fixed_branching_factor(tree_nodes, level): + """ + Checks if all nodes at the given level have the same number of children. + """ + next_level_nodes = [node for node in tree_nodes if len(node) == level + 1] + if len(next_level_nodes) == 0: + return True + + level_nodes = [node for node in tree_nodes if len(node) == level] + child_counts = [] + for parent in level_nodes: + child_counts.append( + sum(1 for child in next_level_nodes if child[:-1] == parent) + ) + return len(set(child_counts)) <= 1 # All counts are the same. + +def _assert_fixed_branching_factor_per_level( + tree_nodes: list[tuple[int, ...]], + spec_token_tree: str) -> None: + """ + Asserts that each level of the tree has a fixed branching factor. That is, + the number of children per node is the same within a level, but can vary + across levels. + """ + tree_depth = len(tree_nodes[-1]) + 1 + for level in range(1, tree_depth): + assert _has_fixed_branching_factor(tree_nodes, level), \ + f"The configured spec token tree '{spec_token_tree}' has variable " \ + f"branching at level {level}. Tree speculative decoding requires " \ + f"a uniform number of children per level." + +def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: + """ + Counts the number of choices at each depth of the tree. + """ + depth_counts = [] + prev_depth = 0 + for path in sorted_tree_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + return depth_counts + + +def _prepare_tree_attn_bias( + sorted_tree_choices: list[tuple[int, ...]], + depth_counts: list[int], +) -> list[list[bool]]: + # +1 comes from the additional root node. + tree_len = len(sorted_tree_choices) + 1 + tree_attn_mask = [[False for _ in range(tree_len)] + for _ in range(tree_len)] + + mask_val = True + for i in range(tree_len): + # Set diagonal to all True. Each token should attend to itself. + tree_attn_mask[i][i] = mask_val + # Set root column to all True. All tokens attend to it. + tree_attn_mask[i][0] = mask_val + + # Set all ancestors to True. + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + if len(cur_tree_choice) == 1: + continue + + for c in range(len(cur_tree_choice) - 1): + ancestor_idx = sorted_tree_choices.index( + cur_tree_choice[:c + 1]) + 1 + tree_attn_mask[j + start + 1][ancestor_idx] = mask_val + start += depth_counts[i] + return tree_attn_mask diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 5d10e9e26082..1a14090178cb 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -167,21 +167,16 @@ def __init__( ): self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size - spec_config = vllm_config.speculative_config - spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) - # Construct the tree attention bias. - depth_counts = _get_depth_counts(tree_choices) - self.tree_attn_bias = _prepare_tree_attn_bias( - tree_choices, - depth_counts, - dtype=torch.float32, - device=device, - ) + tree_drafter_params = (spec := spec_config) and spec.tree_drafter_params + if tree_drafter_params is None: + # Standard decoding. + self.tree_attn_bias = torch.zeros((1, 1), dtype=torch.float32, device=device) + else: + # Spec decoding. + tree_attn_mask = torch.tensor(tree_drafter_params.attn_mask, device=device) + self.tree_attn_bias = torch.where(tree_attn_mask, 0, -torch.inf) + self.__class__.reorder_batch_threshold = self.tree_attn_bias.shape[0] def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -250,58 +245,6 @@ def build_for_drafting( return attn_metadata -def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: - # Count the number of choices at each depth of the tree. - depth_counts = [] - prev_depth = 0 - for path in sorted_tree_choices: - depth = len(path) - if depth != prev_depth: - depth_counts.append(0) - depth_counts[depth - 1] += 1 - prev_depth = depth - return depth_counts - - -def _prepare_tree_attn_bias( - sorted_tree_choices: list[tuple[int, ...]], - depth_counts: list[int], - dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> torch.Tensor: - # +1 comes from the additional root node. - tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) - - # Set diagonal to all zeros. Each token should - # attend to itself. - mask_val = 0 - for i in range(tree_len): - tree_attn_mask[i, i] = mask_val - - # Set root to all zeros. All tokens attend to it. - tree_attn_mask[:, 0] = mask_val - - # Set all ancestors to zeros. - start = 0 - for i in range(len(depth_counts)): - for j in range(depth_counts[i]): - cur_tree_choice = sorted_tree_choices[start + j] - # Retrieve ancestor position. - if len(cur_tree_choice) == 1: - continue - ancestor_idx = [] - for c in range(len(cur_tree_choice) - 1): - ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) - tree_attn_mask[j + start + 1, ancestor_idx] = mask_val - start += depth_counts[i] - return tree_attn_mask - - class TreeAttentionImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/sample/tree_rejection_sampler.py b/vllm/v1/sample/tree_rejection_sampler.py new file mode 100644 index 000000000000..6c5328397671 --- /dev/null +++ b/vllm/v1/sample/tree_rejection_sampler.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.tree_drafter_params import TreeDrafterParams +from vllm.triton_utils import tl +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +logger = init_logger(__name__) + +PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 +EPS: torch.float = 1e-10 + + +class TreeRejectionSampler(nn.Module): + + def __init__( + self, + tree_drafter_params: TreeDrafterParams, + max_batch_size: int, + main_sampler: Sampler, + device: Optional[torch.device], + ): + super().__init__() + tree_mask = torch.tensor(tree_drafter_params.attn_mask, + device=device)[:, 1:] + self.expanded_tree_mask = tree_mask.expand((max_batch_size, -1, -1)) + self.batch_indices = torch.arange(max_batch_size, device=device) + # Cumulative # of draft tokens per level. + self.cu_drafts_per_level = tree_drafter_params.cu_drafts_per_level + # Cumulative # of tokens per level, including the root token. + self.cu_tokens_per_level = [ + num_drafts + 1 for num_drafts in self.cu_drafts_per_level + ] + self.main_sampler = main_sampler + + # Get tree depth (# levels) and width (# drafts at last level). + self.tree_depth = len(self.cu_drafts_per_level) + self.tree_width = self.cu_drafts_per_level[ + -1] - self.cu_drafts_per_level[-2] + + # Used for getting the flattened tree position for any draft token, + # indexed by it's level and position in the level. + tree_draft_positions = torch.zeros( + (self.tree_depth, self.tree_width), + device=device, + dtype=torch.int32, + ) + for level in range(1, self.tree_depth): + start = self.cu_tokens_per_level[level - 1] + end = self.cu_tokens_per_level[level] + level_num_drafts = end - start + level_draft_positions = torch.arange(start, end, device=device) + tree_draft_positions[ + level] = level_draft_positions.repeat_interleave( + self.tree_width // level_num_drafts) + self.expanded_tree_draft_positions = tree_draft_positions.expand( + (max_batch_size, -1, -1)) + + # Precompute offsets for tree-decoding batches. + num_tree_tokens = self.cu_tokens_per_level[-1] + self.batch_offsets = self.batch_indices * num_tree_tokens + # Precompute indices for logits corresponding to tree-internal + # tokens across batches. + num_tree_internal_tokens = self.cu_tokens_per_level[-2] + self.tree_internal_indices = self.batch_offsets.unsqueeze( + 1) + torch.arange(num_tree_internal_tokens, device=device) + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + bonus_token_ids: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + bonus_token_ids_tensor (Optional[torch.Tensor]): + Not used, and expected to be None. This method will generate + a bonus token for each request depending on which branch is + accepted. + sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + accepted_token_indices (list[torch.Tensor]): + Contains the accepted token indices for each tree drafting + request. + """ + assert bonus_token_ids is None + + draft_tree_size = self.cu_drafts_per_level[-1] + total_num_draft_tokens = sum(metadata.num_draft_tokens) + num_tree_decodes = total_num_draft_tokens // draft_tree_size + batch_indices = self.batch_indices[:num_tree_decodes] + batch_offsets = self.batch_offsets[:num_tree_decodes] + tree_internal_indices = self.tree_internal_indices[:num_tree_decodes] + tree_mask = self.expanded_tree_mask[:num_tree_decodes] + tree_draft_positions = self.expanded_tree_draft_positions[: + num_tree_decodes] + + # Get only the logits associated with tree-drafted requests. + num_tree_logits = num_tree_decodes * (1 + draft_tree_size) + + # Compute target probabilities for all logits corresponding to internal + # nodes in the tree. + vocab_size = target_logits.shape[-1] + tree_internal_logits = target_logits[tree_internal_indices.flatten()] + num_logits_per_batch = tree_internal_logits.shape[ + 0] // num_tree_decodes + target_probs = self.compute_probs( + tree_internal_logits, + num_logits_per_batch, + sampling_metadata, + ).view(num_tree_decodes, -1, vocab_size) + + # Reshape the draft token ids to [num_tree_decodes, draft_tree_size]. + draft_token_ids = metadata.draft_token_ids.view(num_tree_decodes, -1) + + # Below tensor will hold 1 for a token if accepted, and 0 if rejected. + tree_acceptances = torch.zeros( + (num_tree_decodes, self.tree_depth, self.tree_width), + device=tree_mask.device, + dtype=torch.int32) + parents_end = 0 + for level in range(self.tree_depth - 1): + # Get target and draft start and end token indices for the current + # tree level. + parents_start = parents_end + parents_end = self.cu_tokens_per_level[level] + drafts_start = self.cu_drafts_per_level[level] + drafts_end = self.cu_drafts_per_level[level + 1] + + # Get the target probabilities and drafted token ids for the + # current level. + level_target_probs = target_probs[:, parents_start:parents_end] + level_draft_token_ids = draft_token_ids[:, drafts_start:drafts_end] + # Accept/reject tokens at the current level. + level_acceptances = self.rejection_sample(level_target_probs, + level_draft_token_ids) + + # Broadcast the acceptances to the width of the tree. + num_level_drafts = drafts_end - drafts_start + tree_acceptances[:, + level, :] = (level_acceptances.repeat_interleave( + self.tree_width // num_level_drafts, dim=1)) + + # Get the boolean mask for the maximum length path of accepted tokens. + path_lengths = tree_acceptances.argmin(dim=1) + accepted_path_levels, accepted_path_indices = path_lengths.max(dim=1) + accepted_paths = tree_draft_positions[batch_indices, + accepted_path_levels, + accepted_path_indices] + path_masks = tree_mask[batch_indices, accepted_paths] + + # Create output buffer. + num_req = len(metadata.num_draft_tokens) + output_token_ids = torch.empty( + # +1 for the bonus token. + (num_req, draft_tree_size + 1), + dtype=torch. + int32, # Consistent with SamplerOutput.sampled_token_ids. + device=draft_token_ids.device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + + # Set accepted draft tokens. + output_token_ids[:num_tree_decodes, :draft_tree_size][ + path_masks] = draft_token_ids[path_masks] + + # Sample and add a bonus token to the accepted paths. + bonus_logit_indices = batch_offsets + accepted_paths + bonus_target_logits = target_logits[bonus_logit_indices] + bonus_sampler_output = self.main_sampler( + logits=bonus_target_logits, + sampling_metadata=sampling_metadata, + ) + output_token_ids[:num_tree_decodes, + -1] = bonus_sampler_output.sampled_token_ids.view(-1) + + if num_req > num_tree_decodes: + # In some cases, we may have leftover requests with 0 draft tokens. + # Sample a bonus token for each. + bonus_sampler_output = self.main_sampler( + logits=target_logits[num_tree_logits:], + sampling_metadata=sampling_metadata, + ) + output_token_ids[num_tree_decodes:, + -1] = bonus_sampler_output.sampled_token_ids.view( + -1) + + accepted_token_indices = [torch.where(row)[0] for row in path_masks] + return output_token_ids, accepted_token_indices + + def compute_probs(self, logits: torch.Tensor, logits_per_batch: int, + sampling_metadata: SamplingMetadata): + if sampling_metadata.all_greedy: + return logits + + assert sampling_metadata.temperature is not None + temperature = sampling_metadata.temperature.repeat_interleave( + logits_per_batch) + logits.div_(temperature.view(-1, 1)) + + top_k = None + if sampling_metadata.top_k is not None: + top_k = sampling_metadata.top_k.repeat_interleave(logits_per_batch) + top_p = None + if sampling_metadata.top_p is not None: + top_p = sampling_metadata.top_p.repeat_interleave(logits_per_batch) + logits = apply_top_k_top_p(logits, top_k, top_p) + output_probs = logits.softmax(dim=-1, dtype=torch.float32) + return output_probs + + def rejection_sample(self, target_probs: torch.Tensor, + draft_tokens: torch.Tensor): + # TODO(TheEpicDolphin): Add support for probabilistic-style rejection + # sampling, as used in EAGLE. + target_argmax = target_probs.argmax(dim=-1) + target_sampled_tokens = target_argmax.repeat_interleave( + draft_tokens.shape[-1] // target_argmax.shape[-1], + dim=1, + ) + return target_sampled_tokens == draft_tokens + + @staticmethod + def parse_output( + output_token_ids: torch.Tensor, + vocab_size: int, + ) -> list[list[int]]: + """Parse the output of the rejection sampler. + + Args: + output_token_ids: The sampled token IDs in shape + [batch_size, max_spec_len + 1]. The rejected tokens are + replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler + and will be filtered out in this function. + vocab_size: The size of the vocabulary. + + Returns: + A list of lists of token IDs. + """ + output_token_ids_np = output_token_ids.cpu().numpy() + # Create mask for valid tokens. + valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & + (output_token_ids_np < vocab_size)) + outputs = [ + row[valid_mask[i]].tolist() + for i, row in enumerate(output_token_ids_np) + ] + return outputs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8a160a0f995..84586fbd52f6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast from dataclasses import replace from typing import Optional @@ -97,26 +96,14 @@ def __init__( dtype=self.dtype, device=device) - # Parse the speculative token tree. - spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) - tree_depth = len(self.tree_choices[-1]) - # Precompute per-level properties of the tree. - num_drafts_per_level = [0] * tree_depth - for node in self.tree_choices: - num_drafts_per_level[len(node) - 1] += 1 - self.cu_drafts_per_level = [num_drafts_per_level[0]] - self.child_drafts_per_level = [num_drafts_per_level[0]] - for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) - # Precompute draft position offsets in flattened tree. - self.tree_draft_pos_offsets = torch.arange( + # Get tree drafter params. + tree_drafter_params = self.speculative_config.tree_drafter_params + self.cu_drafts_per_level = tree_drafter_params.cu_drafts_per_level + self.child_drafts_per_level = tree_drafter_params.child_drafts_per_level + # Precompute draft token positions in flattened tree. + self.flattened_tree_positions = torch.arange( 1, - len(self.tree_choices) + 1, + len(tree_drafter_params.tree_choices) + 1, device=device, dtype=torch.int32, ).repeat(max_batch_size, 1) @@ -342,7 +329,6 @@ def propose_tree( TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] - level_num_drafts = total_num_drafts # Sample a draft token for each child at the tree root level. num_children = self.child_drafts_per_level[0] if num_children == 1: @@ -363,14 +349,19 @@ def propose_tree( tree_hidden_states = torch.empty(0, device=self.hidden_states.device, dtype=self.hidden_states.dtype) - # Precompute the draft token positions. - flattened_draft_positions = ( + # Precompute the draft token query positions. + flattened_query_positions = ( positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + self.flattened_tree_positions[:batch_size, :]) tree_depth = len(self.cu_drafts_per_level) - for level in range(tree_depth - 1): + for level in range(1, tree_depth - 1): + # Update the # drafts counters for the current level. + level_num_drafts = self.cu_drafts_per_level[ + level] - total_num_drafts + total_num_drafts = self.cu_drafts_per_level[level] + # Get draft positions for RoPE. - draft_positions = positions + (level + 1) + draft_positions = positions + level exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. @@ -411,7 +402,7 @@ def propose_tree( ) attn_metadata = tree_attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, - draft_index=level + 1, + draft_index=level, ) # Apply new attention metadata to all layers. @@ -427,8 +418,7 @@ def propose_tree( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_query_positions[:, :query_len] block_numbers = query_positions // self.block_size block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) @@ -442,8 +432,7 @@ def propose_tree( # Copy inputs to buffer for cudagraph. num_tokens = attn_metadata.num_actual_tokens - input_ids = tree_input_ids.view(-1) - self.input_ids[:num_tokens] = input_ids + self.input_ids[:num_tokens] = tree_input_ids.view(-1) self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view( num_tokens, -1) @@ -479,7 +468,7 @@ def propose_tree( ) # Sample a draft token for each child at the next tree level. - num_children = self.child_drafts_per_level[level + 1] + num_children = self.child_drafts_per_level[level] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: @@ -488,10 +477,6 @@ def propose_tree( batch_size, -1) draft_token_ids_list.append(draft_token_ids) - # Update the # drafts counters for the next tree level. - level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts - total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list def prepare_inputs( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5ee44a82574c..f567793e6cd4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, +from vllm.forward_context import (BatchDescriptor, DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase @@ -55,6 +55,7 @@ get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, @@ -66,11 +67,12 @@ KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + ModelRunnerOutput, SamplerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.sample.tree_rejection_sampler import TreeRejectionSampler from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -98,6 +100,7 @@ "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") logger = init_logger(__name__) +VIRTUAL_ENGINE = 0 class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @@ -178,6 +181,23 @@ def __init__( # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # Tree spec decoding. + self.use_tree_spec_decode = envs.VLLM_ATTENTION_BACKEND == "TREE_ATTN" + self.draft_position_offsets = None + if self.use_tree_spec_decode: + tree_drafter_params = self.speculative_config.tree_drafter_params + draft_levels = torch.tensor( + tree_drafter_params.draft_levels, + dtype=torch.int64, + device=self.device) + self.flattened_draft_indices = torch.arange( + draft_levels.shape[0], + dtype=torch.int64, + device=self.device) + # The adjustments needed to go from draft positions in the + # flattened tree to their levels. + self.draft_position_offsets = draft_levels - self.flattened_draft_indices + self.use_aux_hidden_state_outputs = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on @@ -198,7 +218,7 @@ def __init__( else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") - self.rejection_sampler = RejectionSampler() + self.rejection_sampler = self._create_rejection_sampler() # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -803,7 +823,14 @@ def _prepare_inputs( for req_id, draft_token_ids in ( scheduler_output.scheduled_spec_decode_tokens.items()): req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) + num_draft_toks = len(draft_token_ids) + num_draft_tokens[req_idx] = num_draft_toks + if (num_draft_toks > 0 + and self.draft_position_offsets is not None): + # Offset the draft positions. + start = self.query_start_loc[req_idx] + 1 + end = start + num_draft_toks + self.positions[start:end] += self.draft_position_offsets spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) @@ -1658,31 +1685,13 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) + draft_token_index_remap = None else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids + sampler_output, draft_token_index_remap = self._rejection_sample( + logits, spec_decode_metadata, sampling_metadata) + # Update the KV cache to reflect the newly accepted draft tokens. + self._rewind_kv_cache(attn_metadata, draft_token_index_remap) num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: @@ -1772,6 +1781,7 @@ def execute_model( aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, + draft_token_index_remap=draft_token_index_remap, ) self.eplb_step() @@ -1798,6 +1808,7 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, + draft_token_index_remap: Optional[list[torch.Tensor]] = None ) -> list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -1865,10 +1876,27 @@ def propose_draft_token_ids( ] num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, dtype=torch.int32) + common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) + if draft_token_index_remap is not None: + # Remap the draft token indices to the given values. This is + # currently used for tree spec decoding. + query_start_loc = common_attn_metadata.query_start_loc + for req_idx, new_indices in enumerate(draft_token_index_remap): + assert num_draft_tokens[req_idx] > 0 + num_accepted_tokens = new_indices.shape[0] + if num_accepted_tokens == 0: + # No tokens were accepted, skip remapping. + continue + # Get start and end of the draft tree tokens. Skip + # the root token at the start. + draft_start = query_start_loc[req_idx] + 1 + draft_end = query_start_loc[req_idx + 1] + token_indices[draft_start:draft_end] = new_indices + 1 + target_token_ids = self.input_ids[token_indices] # TODO(woosuk): Support M-RoPE. target_positions = self.positions[token_indices] @@ -2242,7 +2270,7 @@ def _dummy_run( - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.FULL: Full cudagraph, attention metadata is needed. - force_attention: If True, always create attention metadata. Used to + force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. @@ -2381,6 +2409,7 @@ def _dummy_run( with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, + virtual_engine=VIRTUAL_ENGINE, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, @@ -2462,25 +2491,34 @@ def _dummy_sampler_run( else: raise e if self.speculative_config: - draft_token_ids = [[0] for _ in range(num_reqs)] + num_spec_tokens = self.speculative_config.num_speculative_tokens + draft_token_ids = [[0] * num_spec_tokens for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( draft_token_ids, self.device) - num_tokens = sum(len(ids) for ids in draft_token_ids) + num_draft_tokens = num_spec_tokens * num_reqs # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, + if self.use_tree_spec_decode: + # The tree rejection sampler should not receive bonus tokens. + # It computes its own bonus tokens depending on which branch + # is accepted. + num_logits = num_draft_tokens + num_reqs + bonus_token_ids = None + else: + num_logits = num_draft_tokens + # NOTE(woosuk): Here, we should use int32 because the sampler uses + # int32 for bonus_token_ids. If the dtype mismatches, re-compilation + # will occur at runtime. + bonus_token_ids = torch.zeros(num_reqs, + device=self.device, + dtype=torch.int32) + target_logits = torch.randn(num_logits, logits.shape[-1], device=self.device, dtype=logits.dtype) - # NOTE(woosuk): Here, we should use int32 because the sampler uses - # int32 for bonus_token_ids. If the dtype mismatches, re-compilation - # will occur at runtime. - bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -3339,3 +3377,131 @@ def _build_encoder_only_attn_metadata( group_metadata[layer_name] = (common_metadata, metadata) return group_metadata + + def _create_rejection_sampler(self): + if self.use_tree_spec_decode: + # Tree rejection sampling is required when using tree attention. + return TreeRejectionSampler( + self.speculative_config.tree_drafter_params, + max_batch_size=self.max_num_reqs, + main_sampler=self.sampler, + device=self.device, + ) + return RejectionSampler() + + def _rejection_sample( + self, + logits: torch.Tensor, + spec_decode_metadata: SpecDecodeMetadata, + sampling_metadata: SamplingMetadata, + ) -> tuple[SamplerOutput, Optional[list[torch.Tensor]]]: + if self.use_tree_spec_decode: + # Rejection sample from the tree of drafts. Bonus tokens are not + # provided because it generates its own. + output_token_ids, accepted_token_pos_remap = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + None , # bonus_token_ids + sampling_metadata, + ) + sampler_output = SamplerOutput( + sampled_token_ids=output_token_ids, + logprobs_tensors=None, + ) + else: + # No token remapping is needed when rejection sampling from a + # chain of drafts. + accepted_token_pos_remap = None + + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + bonus_logits = logits[ + spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[ + spec_decode_metadata.target_logits_indices] + + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + return sampler_output, accepted_token_pos_remap + + def _rewind_kv_cache( + self, + layer_attn_metadatas: dict[str, Any], + accepted_draft_indices: Optional[list[torch.Tensor]], + ): + """ + Copies K/Vs from the accepted path and makes them contiguous in the + paged KV slot map, effectively "rewinding" the speculative process to + keep only validated tokens while discarding rejected branches. This + method is called immediately after rejection sampling. + + Args: + layer_attn_metadatas: Layer-to-attention-metadata mapping. + accepted_draft_indices: Per-batch list of accepted draft indices. + """ + + if accepted_draft_indices is None: + # Nothing to do. + return + + block_size = self.cache_config.block_size + # Get all attention layers. + layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + # Get attention metadata for a layer. The current assumption is + # that all attention layers share the same KV cache slot mapping. + layer_name = next(iter(layers.keys())) + attn_metadata = layer_attn_metadatas[layer_name] + assert isinstance(attn_metadata, TreeAttentionMetadata) + num_decode_tokens = attn_metadata.num_decode_tokens + num_decodes = attn_metadata.num_decodes + slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens].view(num_decodes, -1) + # Slice slot mapping to get only the draft tokens. + draft_tree_slot_mapping = slot_mapping[:, 1:] + + # Collect all slot remappings across batches. + from_slots = [] + to_slots = [] + for batch, from_indices in enumerate(accepted_draft_indices): + num_indices = from_indices.shape[0] + if num_indices == 0: + continue + to_indices = self.flattened_draft_indices[:num_indices] + from_slots.append(draft_tree_slot_mapping[batch, from_indices]) + to_slots.append(draft_tree_slot_mapping[batch, to_indices]) + + if len(to_slots) == 0: + return + + # Convert to flat tensors. + from_slots = torch.cat(from_slots) + to_slots = torch.cat(to_slots) + + # Get KV cache blocks and offsets. + from_blocks = from_slots // block_size + from_offsets = from_slots % block_size + to_blocks = to_slots // block_size + to_offsets = to_slots % block_size + + # For all layers, copy accepted token KVs to contiguous memory. + for layer in layers.values(): + kv_cache = layer.kv_cache[VIRTUAL_ENGINE] + kv_cache[:, to_blocks, to_offsets, :, :] = kv_cache[:, from_blocks, from_offsets, :, :]