feat: update flashinfer ar oneshot params #22108
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request updates the flashinfer all-reduce fusion parameters by removing the use_oneshot argument from the trtllm_allreduce_fusion function call. This change aligns with a recent update in the flashinfer library, where this parameter is now auto-deduced. The change is correct and necessary to maintain compatibility with the updated dependency.
@yyihuang do you know when the next release will be?
@mgoin the new flashinfer release has been ready since yesterday.
@@ -457,7 +457,6 @@ def call_trtllm_fused_allreduce_norm(
     hidden_dim=allreduce_in.shape[-1],
     workspace_ptrs=_FI_WORKSPACE_TENSOR,
     launch_with_pdl=launch_with_pdl,
-    use_oneshot=True,
Do you have more context on this? At the very least, the comment above might need to be updated: "For the sizes that are smaller than the max size, we only use flashinfer one shot allreduce". Is there a test we can add for this?
We deduce the strategy by token num: https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/comm/trtllm_ar.py#L826.
Different kernels are then dispatched based on this strategy: https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/comm/trtllm_allreduce_fusion.cuh#L1388-L1400
We can add a unit test with token_num > 128 if needed. Running some general model tests would also be okay.
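For illustration, here is a minimal Python sketch of that kind of token-count-based dispatch. The threshold constant and function name are assumptions made up for this sketch; the real deduction lives in the linked trtllm_ar.py.

```python
# Illustrative sketch only: the real strategy deduction is in flashinfer's
# flashinfer/comm/trtllm_ar.py (linked above). The threshold value and the
# function name here are assumptions for illustration.
ONESHOT_TOKEN_NUM_THRESHOLD = 128  # assumed cutoff, echoing the token_num > 128 test idea


def pick_allreduce_strategy(token_num: int) -> str:
    """Pick one-shot for small token counts, two-shot otherwise (hypothetical)."""
    return "oneshot" if token_num <= ONESHOT_TOKEN_NUM_THRESHOLD else "twoshot"


if __name__ == "__main__":
    for token_num in (1, 64, 128, 129, 1024):
        print(f"token_num={token_num} -> {pick_allreduce_strategy(token_num)}")
```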
You might also be interested in the first PR of the series on flashinfer's allreduce_fusion: #20691
@yyihuang Do you have benchmarking results for one-shot vs. two-shot? First, the choice of two-shot should depend not only on token_num but also on world_size, similar to what is done in custom_all_reduce.cuh. Second, in my benchmarking, two-shot only made sense on Hopper, whereas one-shot was better across all workloads on Blackwell.
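To make that suggestion concrete, here is a hypothetical sketch of a dispatch that considers both message size and world size. The per-world-size byte thresholds are invented for this example and are not taken from custom_all_reduce.cuh or flashinfer.

```python
# Hypothetical heuristic combining token count and world size, in the spirit of
# the comment above. The per-world-size byte thresholds are invented for this
# sketch and are NOT the values used in custom_all_reduce.cuh or flashinfer.
ONESHOT_MAX_BYTES_BY_WORLD_SIZE = {2: 512 * 1024, 4: 256 * 1024, 8: 128 * 1024}


def should_use_oneshot(token_num: int, hidden_dim: int, world_size: int,
                       bytes_per_elem: int = 2) -> bool:
    """Return True if this (made-up) heuristic would pick the one-shot kernel."""
    msg_bytes = token_num * hidden_dim * bytes_per_elem
    limit = ONESHOT_MAX_BYTES_BY_WORLD_SIZE.get(world_size, 128 * 1024)
    return msg_bytes <= limit


if __name__ == "__main__":
    # Small decode-style batch on TP=4 -> one-shot; large prefill on TP=8 -> two-shot.
    print(should_use_oneshot(token_num=16, hidden_dim=4096, world_size=4))
    print(should_use_oneshot(token_num=2048, hidden_dim=4096, world_size=8))
```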
Hi @ilmarkov, I benchmarked on H200. Let me do more benchmarks on Blackwell.
Thanks for your benchmark, @ilmarkov! Let's keep this as a draft PR, since we did not get a speedup from this auto-deduction, until we figure out the problem shapes and use cases for each strategy across DL frameworks. In TensorRT-LLM this is used for the min-latency case (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/allreduceOp.cpp#L453), which might not be the target case here, or there might be some framework differences.
Signed-off-by: Avery Yingyi Huang <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Essential Elements of an Effective PR Description Checklist
(Optional) Documentation update, e.g. supported_models.md and examples for a new model.
Purpose
Adjust for use_oneshot auto-deduction.
The interface in flashinfer is updated in flashinfer-ai/flashinfer#1365 (in the next release).
Test Plan
vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --no-enable-prefix-caching -tp 4 --compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level":3}'
Test Result
main branch:
============ Serving Benchmark Result ============
Successful requests: 600
Request rate configured (RPS): 10.00
Benchmark duration (s): 60.63
Total input tokens: 305347
Total generated tokens: 90000
Request throughput (req/s): 9.90
Output token throughput (tok/s): 1484.31
Total Token throughput (tok/s): 6520.18
---------------Time to First Token----------------
Mean TTFT (ms): 18.26
Median TTFT (ms): 17.57
P99 TTFT (ms): 28.37
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 4.61
Median TPOT (ms): 4.60
P99 TPOT (ms): 4.89
---------------Inter-token Latency----------------
Mean ITL (ms): 4.61
Median ITL (ms): 4.40
P99 ITL (ms): 8.97
this branch:
============ Serving Benchmark Result ============
Successful requests: 600
Request rate configured (RPS): 10.00
Benchmark duration (s): 60.64
Total input tokens: 305347
Total generated tokens: 90000
Request throughput (req/s): 9.89
Output token throughput (tok/s): 1484.19
Total Token throughput (tok/s): 6519.67
---------------Time to First Token----------------
Mean TTFT (ms): 17.77
Median TTFT (ms): 17.51
P99 TTFT (ms): 27.39
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 4.60
Median TPOT (ms): 4.60
P99 TPOT (ms): 4.85
---------------Inter-token Latency----------------
Mean ITL (ms): 4.60
Median ITL (ms): 4.41
P99 ITL (ms): 8.94
(Optional) Documentation Update