
[V1] Enable prefill optimization for Gemma3n #22628


Open · wants to merge 7 commits into main from gemma3n-fast-prefill-rebased

Conversation

@sarckk (Collaborator) commented Aug 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR adds an option, --kv-sharing-fast-prefill, to enable the prefill optimization for the Gemma3n model.

Background

In You Only Cache Once (YOCO, https://arxiv.org/abs/2405.05254), self-decoder layers generate KV caches while cross-decoder layers use cross-attention to reuse the shared KV cache. Since only the self-decoder layers generate KV caches, the cross-decoder layers do not need to run the prefill computation. Below is a figure from the YOCO paper illustrating the prefill optimization:

[Figure from the YOCO paper: cross-decoder layers reuse the self-decoder KV cache, so their prefill computation can be skipped]
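To make the idea concrete, here is a minimal PyTorch-style sketch of the YOCO scheme. It is illustrative only, not the vLLM/Gemma3n implementation; the CrossDecoderLayer class, the shapes, and the use of nn.MultiheadAttention are assumptions:

```python
# Illustrative sketch only: cross-decoder layers read a KV cache produced by
# the self-decoder, so during prefill they only need to process the tokens
# whose outputs are actually used (e.g. the last prompt token).
import torch
import torch.nn as nn

class CrossDecoderLayer(nn.Module):
    """Attends over a shared KV cache instead of building its own."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states: torch.Tensor,
                shared_k: torch.Tensor, shared_v: torch.Tensor) -> torch.Tensor:
        # hidden_states may hold only decode / last-prefill tokens,
        # while shared_k / shared_v cover the full prompt.
        out, _ = self.attn(hidden_states, shared_k, shared_v, need_weights=False)
        return out

layer = CrossDecoderLayer(hidden_size=64, num_heads=4)
prompt_states = torch.randn(1, 128, 64)      # [batch, prompt_len, hidden]
shared_k = shared_v = prompt_states          # stand-in for the cached K/V
last_token = prompt_states[:, -1:, :]        # only the token that needs logits
out = layer(last_token, shared_k, shared_v)  # shape [1, 1, 64]
```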

Design

In vLLM V1, the scheduler does not distinguish between prefill and decode. Instead, tokens for requests doing prefill and decode are batched together, as illustrated below (source: vLLM blog):

When we skip the prefill tokens in the cross-decoder layers, the batch size is therefore reduced during the model forward pass for those layers:

Without optimization enabled (baseline)

[Figure: baseline, where prefill and decode tokens are processed by every layer]

With optimization enabled (--kv-sharing-fast-prefill)

[Figure: with --kv-sharing-fast-prefill, prefill tokens are skipped in the cross-decoder layers, reducing the batch size]

With this change, we can no longer compile the top-level model, for two reasons:

  1. torch.compile in vLLM assumes the batch size stays the same within a single model forward pass. The traced graph is specialized on the batch size, which leads to silent incorrectness if the batch size changes mid-forward.
  2. CUDA graphs are shape-specialized, so replaying them with a changed batch size would produce incorrect results.

Solution: we split the layers into self-decoder and cross-decoder layers, and compile + CUDA-graph-capture them separately. For Gemma3n-E2B, which has 30 layers, the first 20 layers and the remaining 10 layers are grouped into independently compiled and CUDA-graph-captured modules.
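A rough sketch of this split, with illustrative names (the actual module classes in the PR differ); each half gets its own torch.compile wrapper so the two traced graphs can each specialize on a different batch size:

```python
import torch
import torch.nn as nn

class DecoderStack(nn.Module):
    """Runs a contiguous slice of decoder layers."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

def split_and_compile(layers: nn.ModuleList, num_self_decoder_layers: int = 20):
    # For Gemma3n-E2B: layers 0-19 build KV caches (self-decoder),
    # layers 20-29 reuse them (cross-decoder). Compiling the two stacks
    # separately keeps each traced graph's batch size consistent.
    self_decoder = torch.compile(DecoderStack(layers[:num_self_decoder_layers]))
    cross_decoder = torch.compile(DecoderStack(layers[num_self_decoder_layers:]))
    return self_decoder, cross_decoder
```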

Other changes required in this PR:

  • Build an attention metadata builder subclass for eligible layers so it can call make_kv_sharing_fast_prefill_common_attn_metadata to create attention metadata that excludes all prefill tokens. This requires passing logits_indices to CommonAttentionMetadata.
  • Create an attention metadata subclass for eligible layers that is an instance of KVSharingFastPrefillAttentionMetadata. It carries two additional fields (logits_indices_padded and num_logits_indices), which are needed to index into the hidden states in the model implementation so the shapes match what the new attention metadata expects.
  • Changes to the Gemma3n model implementation:
    • Change the hidden_states shape from [altup_num_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs] so that num_tokens (batch size) sits on dim 0. We cannot keep num_tokens on dim 1 because slicing along dim 1 would a) trip torch.compile tensor stride assertions, and b) fixing that by calling contiguous() on the slice would incur a memory copy and violate the CUDA graph static address constraint.
    • If the --kv-sharing-fast-prefill flag is passed, we take a separate self.fast_prefill_forward() path that uses the logits_indices_padded metadata to index into the subset of tokens the cross-decoder layers process (i.e. a reduced batch size), then merges the result back into the self-decoder output to produce the final output (see the sketch after the attention-group listing below).
    • If the --kv-sharing-fast-prefill flag is passed, we compile the self-decoder and cross-decoder submodules separately and pre-allocate static buffers for CUDA graph replay. If it is not passed (the default), we still compile the top-level Gemma3nTextModel.
    • Attention group changes. After this PR, the attention groups for Gemma3n-E2B look like this:
- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0]: 4, 9, 14, 19 
  - attn_groups[0][1]: 24, 29

- attn_groups[1]
  - attn_groups[1][0] layers: 0, 1, 2, 3

- attn_groups[2]
  - attn_groups[2][0] layers 5, 6, 7, 8

- attn_groups[3]
  - attn_groups[3][0] layers: 10, 11, 12, 13

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18
  - attn_groups[4][1] layers: 20, 21, 22, 23, 25, 26, 27, 28

Compared to trunk, the only difference is that there are extra groups under attn_groups[0] and attn_groups[4] for the layers that need a separate attention metadata builder for the fast prefill path. Previously it looked like this:

- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0]: 4, 9, 14, 19, 24, 29 

# attn_groups[1], attn_groups[2] and attn_groups[3] same

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18, 20, 21, 22, 23, 25, 26, 27, 28
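The sketch below (referenced from the fast-prefill bullet above) shows, with assumed names, how the reduced-batch cross-decoder pass and the merge back into the full-size hidden states could look. FastPrefillMetadata is an illustrative stand-in for the extra fields on KVSharingFastPrefillAttentionMetadata, and fast_prefill_forward is not the exact code in this PR:

```python
from dataclasses import dataclass
import torch

@dataclass
class FastPrefillMetadata:
    # Stand-in for the two extra fields described above.
    logits_indices_padded: torch.Tensor  # padded indices of decode / last-prefill tokens
    num_logits_indices: int              # number of valid (unpadded) indices

def fast_prefill_forward(self_decoder, cross_decoder,
                         hidden_states: torch.Tensor,
                         meta: FastPrefillMetadata) -> torch.Tensor:
    # hidden_states: [num_tokens, hidden_size]; num_tokens sits on dim 0 so
    # that slicing keeps the tensor contiguous (see the dim-ordering note above).
    hidden_states = self_decoder(hidden_states)

    # Reduced batch for the cross-decoder: only tokens that need logits.
    reduced = hidden_states.index_select(0, meta.logits_indices_padded)
    reduced = cross_decoder(reduced)

    # Merge back: unselected positions keep the self-decoder output,
    # selected positions take the cross-decoder output.
    output = hidden_states.clone()
    idx = meta.logits_indices_padded[:meta.num_logits_indices]
    output[idx] = reduced[:meta.num_logits_indices]
    return output
```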

Important

When prompt_logprobs is enabled, we can no longer use the fast prefill optimization: by skipping all but the last prefill tokens, the logits for the prompt tokens are no longer valid. For example, multiple-choice question (MCQ) evals use prompt_logprobs to get the logprobs of continuation tokens (e.g. lm-evaluation-harness), so using --kv-sharing-fast-prefill would yield inaccurate results. To prevent this, we disable fast prefill for any scheduling iteration where at least one request has prompt_logprobs set.
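A minimal sketch of this guard under assumed builder method names (not the exact vLLM code): when any request in the batch needs prompt logprobs, the builder falls back to regular attention metadata for that step.

```python
def build_attn_metadata(common_attn_metadata, builder):
    # `prompt_logprobs` here mirrors the flag added to CommonAttentionMetadata
    # in this PR; the two build_* methods are hypothetical helpers.
    if common_attn_metadata.prompt_logprobs:
        # Prompt logits must stay valid, so take the full-prefill path.
        return builder.build_full_prefill(common_attn_metadata)
    return builder.build_fast_prefill(common_attn_metadata)
```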

Follow ups

Test Plan

Evals

PORT=8000
vllm serve google/gemma-3n-E2B-it --disable-log-requests
lm_eval --model local-completions --tasks $TASKNAME \
    --model_args model=google/gemma-3n-E2B-it,base_url=http://127.0.0.1:$PORT/v1/completions,num_concurrent=200,tokenized_requests=False --batch_size auto --apply_chat_template --fewshot_as_multiturn

Ran gsm8k, mmlu, and mmlu_pro as $TASKNAME.

Unit tests

pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"
pytest tests/v1/e2e/test_kv_sharing_fast_prefill.py::test_kv_sharing_fast_prefill

Performance

VLLM_DISABLE_COMPILE_CACHE=1 python -m vllm.entrypoints.openai.api_server --model google/gemma-3n-E2B-it --disable-log-requests -tp 1 --port 8000 --no-enable-prefix-caching --max-num-seqs 128 --max-model-len=32768 --max-num-batched-tokens=8192 --kv-sharing-fast-prefill

Perform a sweep over max-concurrency and random-input-len with:
num_reqs = 256
max-num-batched-tokens = 8192
max-num-seqs = 128

python benchmarks/benchmark_serving.py --backend vllm --ignore-eos --port 8000 --model google/gemma-3n-E2B-it --dataset-name random --max-concurrency 8 --request-rate inf --num-prompts $num_reqs --random-input-len 8192 --random-output-len 150
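One way to drive the sweep; the concurrency and input-length grid values below are assumptions, not the exact values used for the results in this PR:

```python
# Sweep max-concurrency and random-input-len against the running server.
import itertools
import subprocess

for concurrency, input_len in itertools.product([8, 32, 128], [1024, 8192]):
    subprocess.run([
        "python", "benchmarks/benchmark_serving.py",
        "--backend", "vllm", "--ignore-eos", "--port", "8000",
        "--model", "google/gemma-3n-E2B-it",
        "--dataset-name", "random",
        "--max-concurrency", str(concurrency),
        "--request-rate", "inf",
        "--num-prompts", "256",
        "--random-input-len", str(input_len),
        "--random-output-len", "150",
    ], check=True)
```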

Test Result

Evals

Evals on par

| Run | gsm_8k.5-shot.strict-match | mmlu_pro.5-shot.custom_extract | mmlu.0-shot.acc |
| --- | --- | --- | --- |
| This PR (fast prefill) | 0.5466 | 0.3444 | 0.5558 (fast prefill disabled as prompt_logprobs=1) |
| This PR (full prefill) | 0.5413 | 0.3439 | 0.5560 |
| Base | 0.5474 | 0.3426 | 0.5560 |

Unit tests

Unit tests all pass

Performance

Mean TTFT and TPOT (ms):

[Figure: mean TTFT and TPOT (ms) across the benchmark sweep]

@gemini-code-assist (Contributor) bot left a comment:

Code Review

This pull request introduces a significant performance optimization for Gemma3n models by enabling a fast prefill path, inspired by the YOCO paper. The implementation is well-thought-out, involving a refactoring of the Gemma3n model into self-decoder and cross-decoder modules to work with torch.compile and dynamic batch sizes. The changes to attention metadata and the use of a conditional compilation decorator are clean solutions. Overall, the changes are robust and the associated refactoring of KV cache sharing logic improves the codebase. However, there is a critical gap in testing that should be addressed.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80013417.


mergify bot commented Aug 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 14, 2025
@heheda12345 (Collaborator) left a comment:

Thanks for your contribution. I took a look at the model runner changes. Will check gemma3n.py later.

@sarckk sarckk requested a review from yewentao256 as a code owner August 19, 2025 19:57
@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from a822572 to bcf331a Compare August 19, 2025 20:02
@mergify mergify bot removed the needs-rebase label Aug 19, 2025
@heheda12345 (Collaborator) left a comment:

Almost LGTM. But maybe we need more discussion about the custom attention abstraction @LucasWilkinson

embed_scale: torch.Tensor,
):
super().__init__()
self.decoder_layers = decoder_layers
Collaborator comment on this snippet:

Will there be any hidden problem when the decoder_layers are registered with both Gemma3nTextModel.layers and Gemma3nTextModel.self_decoder.decoder_layers in nn.Module? A cleaner solution would be to only register it in Gemma3nSelfDecoder (but need to update the weight loader, can do it in a follow-up PR after the model structure is finalized)

@sarckk (Collaborator, Author) replied:

Yes, I mostly did this for simplicity, as I couldn't really think of a case where it would be problematic (though there could be one). I do want to handle that in a separate PR if possible.

@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80013417.


mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 20, 2025
@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from 6464c15 to 21e9cac Compare August 20, 2025 18:27
@mergify mergify bot removed the needs-rebase label Aug 20, 2025
@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from 0d5a442 to 1949e8b Compare August 20, 2025 23:01
@@ -890,6 +898,9 @@ def _prepare_inputs(
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
prompt_logprobs=len(self.input_batch.num_prompt_logprobs) > 0,
@sarckk (Collaborator, Author) commented Aug 21, 2025:

@LucasWilkinson @heheda12345 WDYT? I know we discussed how we don't want the common attention metadata being a dumping ground for stuff different builders need, but in this case I'm not sure how to otherwise propagate this information from the model runner -> builder.

We need this in order to skip the prefill optimization if at least one request requests prompt token logprobs (as the logprobs would otherwise be invalid, similar to prefix caching). The options are:

  1. Raising an exception and letting a user request crash the server (which we don't want)
  2. Silent incorrectness + a warning message (also don't want)
  3. Disabling the optimization and returning correct logprobs at the expense of a slight perf regression (this solution)
  4. Not scheduling such requests in the first place and returning an error without crashing the server (AFAIK not possible in vLLM at the moment)

Collaborator:

Oh, I think another solution is to reject requests with logprobs when they are added to the server. For online inference, this won't crash the server. For offline inference, it can crash at the very beginning.

Labels: speculative-decoding, tpu (Related to Google TPUs), v1