
Conversation

@TheEpicDolphin (Contributor) commented Aug 12, 2025

Purpose

Continuing my work from #20401, where I added support for drafting a tree of speculative tokens that are then validated by the target model. In this PR, I add the class that performs rejection sampling on those draft tokens so that they conform to the target model's output distribution. The class is based on RejectionSampler, but with some key differences needed to support a tree structure for the drafted tokens. I also added tests for the new class to verify its correctness. A minimal sketch of the per-node verification idea is shown below.
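
For intuition, here is a minimal sketch of the per-node verification step that a tree rejection sampler performs (SpecInfer-style multi-candidate rejection sampling). It is illustrative only and does not reflect the actual TreeRejectionSampler API or its batched implementation; the function name and signature are hypothetical.

import torch

def verify_siblings(drafted, draft_probs, target_probs, generator=None):
    """Verify the drafted children of a single tree node, in draft order.

    drafted:      list of (child_node_id, token_id) drafted at this node.
    draft_probs:  draft distribution q over the vocab at this node.
    target_probs: target distribution p over the vocab at this node.
    Returns (accepted child_node_id or None, residual distribution to sample
    a recovery token from when every child is rejected).
    """
    p = target_probs.clone()
    q = draft_probs
    for child_id, token in drafted:
        u = torch.rand(1, generator=generator).item()
        # Accept the drafted token with probability min(1, p(token) / q(token)).
        if q[token].item() > 0 and u < (p[token] / q[token]).item():
            return child_id, None
        # Rejected: verify the next sibling against the residual
        # distribution norm(max(p - q, 0)).
        p = torch.clamp(p - q, min=0)
        total = p.sum()
        if total.item() <= 0:
            # Numerically degenerate residual; fall back to the target.
            return None, target_probs
        p = p / total
    return None, p

Applied over the whole tree, this amounts to: verify the root's children, descend into the accepted child, repeat, and stop at the first node where all children are rejected, sampling one recovery token there from the returned residual.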

In addition, I refactored the tree attention parameters to improve readability and performance. I created a new class, TreeDrafterParams, which is constructed during SpeculativeConfig initialization and precomputes several properties of the spec token tree (for example, the attention mask and the children per level) so that other tree-attention systems can use them without recomputing them. A rough sketch of this precomputation follows.
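
For illustration, here is a rough sketch of the kind of precomputation involved, assuming the spec token tree format used in this PR's configs, where each node is a tuple of child-branch indices describing its path from the root. The variable names are illustrative and are not the actual TreeDrafterParams fields.

import torch

# Example tree: two children of the root, three grandchildren.
tree = [(0,), (1,), (0, 0), (0, 1), (1, 0)]

depths = [len(path) for path in tree]
# Number of drafted nodes at each tree level (level 1 = children of the root).
nodes_per_level = [depths.count(d) for d in range(1, max(depths) + 1)]  # [2, 3]

# attn_mask[i, j] is True iff node j is an ancestor of node i (or i itself),
# i.e. j's path is a prefix of i's path, so node i may attend to node j.
num_nodes = len(tree)
attn_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
for i, path_i in enumerate(tree):
    for j, path_j in enumerate(tree):
        attn_mask[i, j] = path_i[:len(path_j)] == path_j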

Finally, I re-enabled the test_eagle_correctness test, which currently skips the tree attention backend due to flakiness. After the changes in this PR, and from my testing on an H100, I no longer observe flakiness in this test. However, I reduced the accuracy threshold from 66% to 50% for tree attention only, so as not to add noise to test signals for others. This is also because Triton attention kernels (which tree attention uses under the hood) reportedly have more floating-point non-determinism than FlashAttention 3.

Test Plan

Automated Testing

New tree rejection sampler tests:

(py312conda) bash-5.1$ pytest tests/v1/sample/test_tree_rejection_sampler.py
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 9 items                                                                                                                                                                 

tests/v1/sample/test_tree_rejection_sampler.py .........                                                                                                                    [100%]

================================================================================ 9 passed in 6.51s ================================================================================

Eagle tree proposer test:

(py312conda) bash-5.1$ pytest tests/v1/spec_decode/test_eagle.py -k test_propose_tree
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 47 items / 43 deselected / 4 selected                                                                                                                                   

tests/v1/spec_decode/test_eagle.py ....                                                                                                                                     [100%]

======================================================================== 4 passed, 43 deselected in 10.82s ========================================================================

Spec decode e2e test:

(py312conda) bash-5.1$ pytest tests/v1/e2e/test_spec_decode.py -k test_eagle_correctness
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 13 items / 1 deselected / 12 selected                                                                                                                                   

tests/v1/e2e/test_spec_decode.py ssssssss..ss                                                                                                                               [100%]

================================================================================ warnings summary =================================================================================
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TREE_ATTN-llama3_eagle]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TREE_ATTN-llama3_eagle]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TREE_ATTN-llama3_eagle3]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TREE_ATTN-llama3_eagle3]
  /home/gdelfin/.conda/envs/py312conda/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=4129255) is multi-threaded, use of fork() may lead to deadlocks in the child.
    self.pid = os.fork()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================= 2 passed, 10 skipped, 1 deselected, 4 warnings in 419.61s (0:06:59) =======================================================

Manual Testing

Server

export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=TREE_ATTN
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (0, 0), (0, 0, 0)]"}'
python -m vllm.entrypoints.openai.api_server --model $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --block-size=128 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log
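
Note: the speculative_token_tree above appears to be a depth-3 chain, with each tuple giving a node's path of child-branch indices from the root, so this configuration drafts a single path of three tokens per step; a branching tree would list multiple nodes per level.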

Client

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain the theory of relativity in simple terms."},
    ],
    temperature=0.2,
)
print(response)

Response

ChatCompletion(id='chatcmpl-c19ec61b38214df896f3dfe01472b6b3', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The theory of relativity, developed by Albert Einstein, is a fundamental concept in physics that explains how space and time are connected. It's a bit complex, but I'll try to break it down in simple terms.\n\n**Special Relativity (1905)**\n\nImagine you're on a train, and you throw a ball straight up in the air. From your perspective on the train, the ball goes up and comes down in a straight line. Now, imagine someone is standing outside the train, watching you throw the ball. From their perspective, the ball doesn't just go straight up and down; it also moves forward because the train is moving.\n\nEinstein said that how we measure time and space depends on how we're moving relative to each other. If you're on the train, time and space seem normal to you. But if you're standing outside the train, time and space seem different because you're moving relative to the train.\n\n**Key points:**\n\n1. **Time dilation**: Time can seem to pass slower for someone who is moving relative to you.\n2. **Length contraction**: Objects can appear shorter to someone who is moving relative to you.\n3. **The speed of light is always the same**: No matter how fast you're moving, the speed of light remains constant.\n\n**General Relativity (1915)**\n\nImagine you're standing on a trampoline. If you put a heavy object, like a bowling ball, on the trampoline, it will warp and curve, creating a dent. That's kind of like what gravity does to space and time.\n\nEinstein said that gravity is not a force that pulls objects towards each other, but rather a curvature of space and time caused by massive objects. The more massive the object, the greater the curvature.\n\n**Key points:**\n\n1. **Gravity is a curvature of space and time**: Massive objects warp space and time, creating gravity.\n2. **Equivalence principle**: The effects of gravity are the same as the effects of acceleration.\n\nIn summary, the theory of relativity says that:\n\n* Time and space are connected and can seem different depending on how we're moving relative to each other.\n* Gravity is a curvature of space and time caused by massive objects.\n* The speed of light is always the same, no matter how fast we're moving.\n\nI hope this helps you understand the basics of the theory of relativity!", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], created=1756252397, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=487, prompt_tokens=52, total_tokens=539, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None, kv_transfer_params=None)

Benchmark

Server

export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=TREE_ATTN
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (0, 0), (0, 0, 0)]"}'
python -m vllm.entrypoints.openai.api_server --model $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --block-size=128 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log

Client

export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
python benchmarks/benchmark_serving.py --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 --dataset-name random --ignore-eos --request-rate inf --random-input-len 1000 --random-output-len 300 --max-concurrency 64 --num-prompts 128

Results

----------------Serving Benchmark Result----------------
Successful requests:                     64        
Maximum request concurrency:             64        
Benchmark duration (s):                  5.97      
Total input tokens:                      63936     
Total generated tokens:                  13077     
Request throughput (req/s):              10.73     
Output token throughput (tok/s):         2192.04   
Total Token throughput (tok/s):          12909.36  
---------------Time to First Token----------------
Mean TTFT (ms):                          1349.37   
Median TTFT (ms):                        1322.26   
P99 TTFT (ms):                           2360.83   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.09     
Median TPOT (ms):                        19.82     
P99 TPOT (ms):                           52.62     
---------------Inter-token Latency----------------
Mean ITL (ms):                           48.18     
Median ITL (ms):                         37.53     
P99 ITL (ms):                            286.35    
==================================================
SpecDecoding metrics: Draft acceptance rate: 42.1%, Mean acceptance length: 2.26, Accepted: 7793 tokens, Drafted: 18528 tokens, Per-position acceptance rate: 0.639, 0.342, 0.281
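
As a sanity check, these numbers are internally consistent: 7793 / 18528 ≈ 42.1%, and the mean acceptance length matches 1 (the token sampled by the target model each step) plus the sum of the per-position acceptance rates: 1 + 0.639 + 0.342 + 0.281 ≈ 2.26.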

Next Steps

  • Update the positions of the K/Vs of accepted draft tokens. This must be done for both the target and draft models.
  • Add more e2e tests specifically for drafting trees.
  • Explore dynamic trees, adjusting tree breadth based on the current sequence position.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small but essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


mergify bot commented Aug 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Aug 12, 2025

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a tree-based sampler for speculative decoding, which is a significant feature for improving performance. The changes are well-structured, with new components like TreeDrafterParams and TreeRejectionSampler to handle the tree-based logic. The Eagle proposer is also updated to support tree drafting and to output draft probabilities. My review has identified a few critical issues regarding the handling of irregular token trees and the correctness of probability calculations, which should be addressed to ensure the feature works as expected.

@TheEpicDolphin force-pushed the tree_sampler_v1 branch 2 times, most recently from 9c59df6 to 3da3a66, on August 16, 2025 at 19:44
@mergify bot added the llama (Related to Llama models) label and removed the needs-rebase label on Aug 16, 2025
Copy link

mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Aug 20, 2025
@TheEpicDolphin force-pushed the tree_sampler_v1 branch 17 times, most recently from 544ec15 to aa5dd93, on August 25, 2025 at 22:25
@TheEpicDolphin force-pushed the tree_sampler_v1 branch 8 times, most recently from ac2523d to f04916b, on August 26, 2025 at 23:28