Perf/gpt oss eagle #7353
Conversation
📝 Walkthrough
Refactors Eagle3OneModelWorker.forward and related routines by extracting inlined computations into multiple small torch.compile-wrapped helper functions for position/gather-ID calculations, draft-token updates, KV-length updates, input preparation, and acceptance logic. Control flow remains equivalent; public interfaces are unchanged.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant C as Caller
    participant W as Eagle3OneModelWorker
    participant H as Compiled Helpers
    C->>W: forward(inputs, attn_metadata, ...)
    W->>H: calc_position_ids_and_last_tokens_idx()
    H-->>W: position_ids, last_tokens_idx
    alt first-draft path
        W->>H: prepare_1st_drafter_inputs(...)
        H-->>W: drafter_inputs
        W->>H: compute_gather_ids(...)
        H-->>W: gather_ids
    else subsequent steps
        W->>H: get_gather_ids(spec_metadata, batch_size)
        H-->>W: gather_ids
    end
    W->>H: update_draft_tokens_and_inputs(...)
    H-->>W: next_draft_tokens, hidden_states, position_ids
    W->>H: update_kv_lens_cuda(attn_metadata, batch_size)
    H-->>W: kv_lens_cuda
    W->>H: prepare_next_tokens(...)
    H-->>W: next_new_tokens
    W-->>C: outputs
```

```mermaid
sequenceDiagram
    autonumber
    participant W as Eagle3OneModelWorker
    participant H as Compiled Helpers
    W->>H: get_num_gens_and_accepted_tokens(...)
    H-->>W: num_gens, accepted_tokens_init
    W->>H: process_accepted_tokens(...)
    H-->>W: accepted_tokens, counters, updates
```
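To make the refactor concrete, here is a minimal, hedged sketch of the pattern; the helper name, signature, and the wrapper function are illustrative stand-ins based on the snippets quoted later in this review, not the exact eagle3.py code:

```python
import torch


# Sketch of the extraction pattern: an inlined computation from forward() is
# moved into a small helper wrapped with @torch.compile so it can be fused and
# autotuned, while the caller's control flow stays the same.
@torch.compile(options={"max-autotune": True})
def calc_last_tokens_idx(seq_lens_cuda: torch.Tensor) -> torch.Tensor:
    # Index of the last token of each sequence along the flattened token axis.
    return torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1


def forward_sketch(position_ids: torch.Tensor, seq_lens_cuda: torch.Tensor):
    # Before the refactor the cumsum lived inline here; afterwards forward()
    # just invokes the compiled helper and returns the same values.
    position_ids = position_ids.squeeze(0)
    last_tokens_idx = calc_last_tokens_idx(seq_lens_cuda)
    return position_ids, last_tokens_idx
```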
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Suggested reviewers
Actionable comments posted: 4
🧹 Nitpick comments (8)
tensorrt_llm/_torch/speculative/eagle3.py (8)
1-1: Add NVIDIA 2025 copyright header.
Per guidelines, prepend the standard NVIDIA copyright header for 2025.
Apply this diff at the top of the file:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
 from dataclasses import dataclass, field
```
5-5: Prefer module import style for torch.nn.
Guidelines say "keep module namespace; import the module, not individual names." Use import torch.nn as nn.

```diff
-from torch import nn
+import torch.nn as nn
```
329-343: compute_gather_ids: semantics LGTM; minor robustness.
Computation matches the prior inline logic. Consider hoisting this helper to module scope to avoid redefining it on every forward call and to improve compile cache reuse; a sketch of the pattern follows.
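A minimal sketch of the suggested hoist, assuming a simplified placeholder body (the real compute_gather_ids logic in eagle3.py is more involved):

```python
import torch


# Module scope: defined and compiled once at import time, so torch.compile
# caching is keyed to a single function object instead of a fresh closure
# created on every forward() call.
@torch.compile(options={"max-autotune": True})
def compute_gather_ids(seq_lens_cuda: torch.Tensor) -> torch.Tensor:
    # Placeholder body standing in for the real gather-id computation.
    return torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1


class WorkerSketch:
    """Hypothetical worker showing only the call-site change."""

    def forward(self, attn_metadata):
        # Before: a nested, decorated compute_gather_ids was re-created here
        # on every call. After: just call the module-level compiled helper.
        return compute_gather_ids(attn_metadata.seq_lens_cuda)
```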
380-384: Tiny in-place increment doesn't merit @torch.compile.
This is a single in-place add; compiling it likely adds overhead and reduces readability. Inline it, or hoist a shared compiled helper to module scope if measurements show a benefit (see the sketch below).
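A sketch of the inlined alternative, assuming the helper merely advances the cached KV lengths by one token per active sequence (the exact statement in eagle3.py may differ):

```python
import torch


def advance_kv_lens(kv_lens_cuda: torch.Tensor, batch_size: int) -> None:
    # Assumed form of the update: one new token per active sequence.
    # A single in-place add is already a single CUDA kernel, so wrapping it in
    # @torch.compile adds compile/guard overhead without a matching payoff.
    kv_lens_cuda[:batch_size] += 1


# Usage sketch: advance_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
```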
432-452: Initialization helper looks fine; hoist for better caching.
Logic mirrors existing behavior. Consider moving get_num_gens_and_accepted_tokens to module scope to avoid per-call redefinition and improve torch.compile cache reuse.
500-503: Remove leftover debug prints and shorten the long line.
Drop the commented debug prints; keep the argmax line concise to satisfy E501.

```diff
-# print(f"DBG : logits.shape: {logits.shape} {logits.dtype} {logits.device} logits stride: {logits.stride()}")
-draft_tokens = torch.argmax(logits, dim=-1)  # [num_tokens]
-# print(f"DBG : draft_tokens.shape: {draft_tokens.shape} {draft_tokens.dtype} {draft_tokens.device} draft_tokens stride: {draft_tokens.stride()}")
+draft_tokens = torch.argmax(logits, dim=-1)  # [num_tokens]
```
513-557: Compiling input prep is reasonable; consider a brief docstring.
Optional: a short docstring clarifying tensor shapes would help future maintainers optimize further; an illustrative example follows.
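An illustrative docstring of the kind suggested; the signature and shapes below are placeholders, not the actual prepare_1st_drafter_inputs interface in eagle3.py:

```python
import torch


@torch.compile(options={"max-autotune": True})
def prepare_1st_drafter_inputs(input_ids: torch.Tensor,
                               hidden_states: torch.Tensor,
                               last_tokens_idx: torch.Tensor):
    """Gather last-token inputs for the first drafter step.

    Args:
        input_ids: [num_tokens] flattened token ids (placeholder shape).
        hidden_states: [num_tokens, hidden_size] target-model activations.
        last_tokens_idx: [batch_size] int64 index of each sequence's last token.

    Returns:
        Tensors gathered at the last-token positions, ready for the drafter.
    """
    return input_ids[last_tokens_idx], hidden_states[last_tokens_idx]
```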
330-330: Fix Ruff E501 (line length) on the flagged lines.
Wrap or break these lines to ≤120 characters to satisfy the linter.
Also applies to: 341-341, 432-432, 451-451, 457-457, 502-502
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/speculative/eagle3.py (7 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,cc,cxx,cu,py,h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use spaces only; no tabs; indent by 4 spaces
Files:
tensorrt_llm/_torch/speculative/eagle3.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Code must target Python 3.8+
Indent with 4 spaces; no tabs
Keep module namespace on import; import the module, not individual names; use module.symbol
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; if starting with a number, prefix with k_ (e.g., k_99th_percentile)
Global variables use UPPER_SNAKE with G_ prefix (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE
Avoid shadowing variables from outer scopes
Initialize all externally visible members in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with trailing docstrings
Avoid reflection when simple alternatives exist (e.g., avoid dict(**locals()) patterns)
Limit except clauses to specific exceptions; avoid bare except
When duck-typing with try/except, keep try body minimal and use else for logic
Files:
tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cc,cxx,cu,h,hpp,hh,hxx,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/speculative/eagle3.py (3)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
- attn_metadata (68-69)
- spec_metadata (60-61)
tensorrt_llm/_torch/attention_backend/interface.py (6)
- seq_lens_cuda (225-226)
- num_contexts (205-206)
- num_contexts (209-212)
- num_ctx_tokens (273-274)
- AttentionMetadata (39-331)
- num_seqs (255-259)
tensorrt_llm/_torch/speculative/mtp.py (2)
- prepare_next_tokens (1253-1262)
- sample_and_accept_draft_tokens (700-858)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/eagle3.py
330-330: Line too long (151 > 120)
(E501)
341-341: Line too long (130 > 120)
(E501)
432-432: Line too long (138 > 120)
(E501)
451-451: Line too long (146 > 120)
(E501)
457-457: Line too long (161 > 120)
(E501)
502-502: Line too long (154 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (2)
425-425: Good call: keep the whole method uncompiled.
Keeping sample_and_accept_draft_tokens selectively compiled (rather than compiling the entire method) aligns with the Triton 3.4 constraint noted above.
456-476: Acceptance logic parity: LGTM.
Shapes and the cumprod-based acceptance count match the prior implementation; a sketch of that counting scheme appears below.
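For context, a minimal sketch of cumprod-based acceptance counting as commonly used in speculative decoding (illustrative only, not the exact eagle3.py code):

```python
import torch


def count_accepted(draft_tokens: torch.Tensor,
                   target_tokens: torch.Tensor) -> torch.Tensor:
    """Count how many leading draft tokens match the target, per sequence.

    Both inputs are [batch_size, num_draft_tokens]; the result is [batch_size].
    """
    matches = (draft_tokens == target_tokens).int()  # 1 where tokens agree
    # cumprod zeroes everything after the first mismatch, so the row sum is
    # the length of the accepted prefix.
    return torch.cumprod(matches, dim=1).sum(dim=1)


# Example: drafts [5, 7, 9] vs targets [5, 7, 2] -> 2 accepted tokens.
```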
```python
@torch.compile(options={"max-autotune": True})
def calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata):
    position_ids = position_ids.squeeze(0)
    last_tokens_idx = torch.cumsum(
        attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
    return position_ids, last_tokens_idx

position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata)
```
🛠️ Refactor suggestion
Avoid passing Python objects into @torch.compile; pass tensors instead.
Accessing attn_metadata.seq_lens_cuda inside the compiled fn can trigger graph breaks and reduce the cache hit rate. Pass the tensor directly.

```diff
-@torch.compile(options={"max-autotune": True})
-def calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata):
-    position_ids = position_ids.squeeze(0)
-    last_tokens_idx = torch.cumsum(
-        attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-    return position_ids, last_tokens_idx
+@torch.compile(options={"max-autotune": True})
+def calc_position_ids_and_last_tokens_idx(position_ids, seq_lens_cuda):
+    position_ids = position_ids.squeeze(0)
+    last_tokens_idx = torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1
+    return position_ids, last_tokens_idx
@@
-position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata)
+position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(
+    position_ids, attn_metadata.seq_lens_cuda
+)
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 298 to 306, the
@torch.compile function currently accesses the Python object
attn_metadata.seq_lens_cuda inside the compiled function which can cause graph
breaks; change the compiled function signature to accept the seq_lens_cuda
tensor (e.g., seq_lens_cuda) as an explicit tensor argument, use that tensor in
torch.cumsum to compute last_tokens_idx, and call the function passing
attn_metadata.seq_lens_cuda from the outer scope; keep position_ids.squeeze(0)
as before and preserve dtype=torch.long.
```python
def get_gather_ids(spec_metadata, batch_size):
    return spec_metadata.batch_indices_cuda[:batch_size]

gather_ids = get_gather_ids(spec_metadata, batch_size)
logits = draft_model.logits_processor(hidden_states[gather_ids],
```
🛠️ Refactor suggestion
Ensure index dtype is int64 for tensor indexing.
batch_indices_cuda is torch.int (int32). Advanced/fancy indexing is most robust with int64.

```diff
-def get_gather_ids(spec_metadata, batch_size):
-    return spec_metadata.batch_indices_cuda[:batch_size]
+def get_gather_ids(spec_metadata, batch_size):
+    return spec_metadata.batch_indices_cuda[:batch_size].long()
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 346 to 349, the helper
returns spec_metadata.batch_indices_cuda which is torch.int (int32) and is used
for advanced indexing; change get_gather_ids to return the first batch_size
entries cast to int64 (e.g., .to(torch.int64) or .long()) to ensure robust
tensor indexing on CUDA while preserving device and slicing semantics.
```python
@torch.compile(options={"max-autotune": True})
def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
    next_draft_tokens.append(new_draft_token)
    hidden_states = hidden_states_to_save[gather_ids]
    position_ids = inputs["position_ids"][gather_ids] + 1
    return hidden_states, position_ids

hidden_states, position_ids = update_draft_tokens_and_inputs(
    new_draft_token, hidden_states_to_save, gather_ids, inputs
)
```
Do not mutate Python lists inside @torch.compile (graph break).
next_draft_tokens.append(...) inside a compiled function introduces Python side effects that TorchDynamo can't reliably trace/cache. Append outside; keep the compiled fn pure.

```diff
-@torch.compile(options={"max-autotune": True})
-def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
-    next_draft_tokens.append(new_draft_token)
-    hidden_states = hidden_states_to_save[gather_ids]
-    position_ids = inputs["position_ids"][gather_ids] + 1
-    return hidden_states, position_ids
-
-hidden_states, position_ids = update_draft_tokens_and_inputs(
-    new_draft_token, hidden_states_to_save, gather_ids, inputs
-)
+@torch.compile(options={"max-autotune": True})
+def update_draft_tokens_and_inputs(hidden_states_to_save, gather_ids, position_ids):
+    hidden_states = hidden_states_to_save[gather_ids]
+    position_ids = position_ids[gather_ids] + 1
+    return hidden_states, position_ids
+
+next_draft_tokens.append(new_draft_token)
+hidden_states, position_ids = update_draft_tokens_and_inputs(
+    hidden_states_to_save, gather_ids, inputs["position_ids"]
+)
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 354 to 363, the
compiled function update_draft_tokens_and_inputs mutates a Python list
(next_draft_tokens.append(...)) which breaks TorchDynamo tracing; remove any
list mutation from the @torch.compile function, instead have the function return
the new draft token (or an index/flag) along with hidden_states and
position_ids, then perform next_draft_tokens.append(returned_token) immediately
after calling the compiled function so the compiled function remains pure and
side-effect free.
```python
@torch.compile(options={"max-autotune": True})
def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
    next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
    next_new_tokens = accepted_tokens[
        spec_metadata.batch_indices_cuda[:batch_size],
        num_accepted_tokens - 1
    ].unsqueeze(1)
    next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
    return next_draft_tokens_stacked, next_new_tokens

next_draft_tokens, next_new_tokens = prepare_next_tokens(
    next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens
)
```
🛠️ Refactor suggestion
Use int64 indices for row/col advanced indexing.
Cast both row and column indices to long() to avoid dtype-related surprises across PyTorch versions/backends.

```diff
-def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
-    next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
-    next_new_tokens = accepted_tokens[
-        spec_metadata.batch_indices_cuda[:batch_size],
-        num_accepted_tokens - 1
-    ].unsqueeze(1)
+def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
+    next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
+    row_idx = spec_metadata.batch_indices_cuda[:batch_size].long()
+    col_idx = (num_accepted_tokens - 1).long()
+    next_new_tokens = accepted_tokens[row_idx, col_idx].unsqueeze(1)
     next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
     return next_draft_tokens_stacked, next_new_tokens
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 401 to 413, the
advanced indexing into accepted_tokens uses spec_metadata.batch_indices_cuda and
num_accepted_tokens which may be wrong dtype on some PyTorch backends; cast both
the row and column indices to long() (int64) before indexing (e.g., use .long()/
.to(torch.long)) so the indexing operation is always performed with int64
tensors, and adjust the code to pass those cast indices into accepted_tokens and
any related indexing operations.
Summary by CodeRabbit
feat: performance improvement for gpt-oss Eagle3.
This PR adds selective torch.compile decorators to obtain optimized kernels.