
Conversation


@LucasWilkinson LucasWilkinson commented Aug 14, 2025

Purpose

Don't create pooling_metadata when it's not needed (overhead introduced in #21985)

cc @maxdebayser


Test Plan

branch   rate   req/s    median TTFT (ms) median TPOT (ms)
------   ----   -----    ---------------- ----------------
MAIN     1.0    0.99     39.36            7.29            
MAIN     10.0   9.87     42.77            9.50            
PR       1.0    0.99     38.57            7.19            
PR       10.0   9.87     43.01            9.43            

Server commands:
MAIN: 	.venv/bin/vllm serve --model meta-llama/Meta-Llama-3-8B-Instruct --host localhost --port 3333 -tp 1 --no-enable-prefix-caching --disable-log-stats --trust-remote-code 
PR:   	.venv/bin/vllm serve --model meta-llama/Meta-Llama-3-8B-Instruct --host localhost --port 3333 -tp 1 --no-enable-prefix-caching --disable-log-stats --trust-remote-code 

Client command (template):
CLIENT:	vllm bench serve --model meta-llama/Meta-Llama-3-8B-Instruct --host localhost --port 3333 --dataset-name random --random-input-len 1000 --random-output-len 100 --random-range-ratio 0 --num-prompts <req_rate * 60> --request-rate <req_rate> --save-result --result-dir <results_dir> --result-filename <filename> --ignore-eos --trust-remote-code

Test Result

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization in _init_model_kwargs by avoiding the creation of pooling_metadata when there are no pooling requests. The change correctly checks for the presence of pooling requests before creating expensive objects. However, there's a minor maintainability concern in how num_pooling_reqs is derived, which could lead to inconsistencies in the future. I've provided a suggestion to improve the robustness of this logic.

Comment on lines +344 to 351
num_pooling_reqs = len(self.input_batch.pooling_params)

if num_pooling_reqs == 0:
    return model_kwargs

pooling_params = self.input_batch.pooling_metadata.pooling_params

assert num_pooling_reqs == num_reqs

Severity: high

While the optimization is correct, this implementation introduces a potential for future bugs. num_pooling_reqs is derived from self.input_batch.pooling_params, but the pooling_params variable that is used in the rest of the function is derived from self.input_batch.pooling_metadata. If the logic inside the pooling_metadata property changes in the future (e.g., to filter some requests), len(pooling_params) might no longer be equal to the originally computed num_pooling_reqs, which could lead to subtle issues.

To make the code more robust, it's better to derive num_pooling_reqs directly from the pooling_params list after it has been created. The initial check can be simplified to check if self.input_batch.pooling_params is empty.

Suggested change:

-num_pooling_reqs = len(self.input_batch.pooling_params)
-
-if num_pooling_reqs == 0:
-    return model_kwargs
-
-pooling_params = self.input_batch.pooling_metadata.pooling_params
-
-assert num_pooling_reqs == num_reqs
+if not self.input_batch.pooling_params:
+    return model_kwargs
+
+pooling_params = self.input_batch.pooling_metadata.pooling_params
+num_pooling_reqs = len(pooling_params)
+
+assert num_pooling_reqs == num_reqs


@DarkLight1337 DarkLight1337 left a comment


Although I think this only avoids attribute access, not creating the pooling params?


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@LucasWilkinson
Collaborator Author

Although I think this only avoids attribute access, not creating the pooling params?

Python is sometimes a blessing and a curse haha; that's the gotcha, it's actually a @property :/ so it is doing work:

@property
def pooling_metadata(self) -> PoolingMetadata:
    if len(self.pooling_params) == 0:
        pooling_params = []
    else:
        # Note, for now this assumes that all request in the batch
        # are either sampling or pooling requests
        assert len(self.req_ids) == len(self.pooling_params)
        pooling_params = [
            self.pooling_params[req_id] for req_id in self.req_ids
        ]
    return PoolingMetadata(
        prompt_lens=torch.from_numpy(
            self.num_prompt_tokens[:self.num_reqs]).to(self.device),
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        pooling_params=pooling_params,
    )

@DarkLight1337
Copy link
Member

Maybe we should make this a method to avoid hiding this fact... but anyway, let's merge this PR first
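A hypothetical sketch of that refactor (names here are illustrative, not the actual vLLM change): turning the property into an explicit method makes the construction cost visible at every call site.

```python
from dataclasses import dataclass, field


@dataclass
class InputBatch:
    """Toy stand-in for the batch object; not the real vLLM class."""
    pooling_params: dict = field(default_factory=dict)

    # Before: `@property def pooling_metadata(self)` -- reading
    # `batch.pooling_metadata` silently builds a fresh object.
    # After: a plain method -- the parentheses at the call site
    # signal that real work happens here.
    def build_pooling_metadata(self):
        return list(self.pooling_params.values())


batch = InputBatch(pooling_params={"req-1": "MEAN"})
metadata = batch.build_pooling_metadata()  # the () makes the cost explicit
assert metadata == ["MEAN"]
```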

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 14, 2025 05:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 14, 2025
@vllm-bot vllm-bot merged commit 829b9a6 into vllm-project:main Aug 14, 2025
49 of 53 checks passed
@maxdebayser
Copy link
Contributor

Thanks for the fix!

juuice-lee pushed a commit to juuice-lee/vllm-moe.code that referenced this pull request Aug 18, 2025
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
Gh0u1L5 pushed a commit to Gh0u1L5/vllm that referenced this pull request Aug 21, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
dumb0002 pushed a commit to dumb0002/vllm that referenced this pull request Aug 28, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
4 participants