145 changes: 96 additions & 49 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -316,17 +316,30 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
next_draft_tokens = []
for i in range(self.max_draft_len):
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
@torch.compile(options={"max-autotune": True}) # 7us saving
def compute_gather_ids(spec_metadata, num_gens, self_max_draft_len, num_accepted_tokens, num_contexts, attn_metadata):
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self_max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
return gather_ids

gather_ids = compute_gather_ids(
spec_metadata, num_gens, self.max_draft_len, num_accepted_tokens, num_contexts, attn_metadata
)

else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

@torch.compile(options={"max-autotune": True})
def get_gather_ids(spec_metadata, batch_size):
return spec_metadata.batch_indices_cuda[:batch_size]
gather_ids = get_gather_ids(spec_metadata, batch_size)


if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
@@ -353,10 +366,17 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
draft_step=i)

new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1

@torch.compile(options={"max-autotune": True})
def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
next_draft_tokens.append(new_draft_token)
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
return hidden_states, position_ids

hidden_states, position_ids = update_draft_tokens_and_inputs(
new_draft_token, hidden_states_to_save, gather_ids, inputs
)
Comment on lines +370 to +379
⚠️ Potential issue

Do not mutate Python lists inside @torch.compile (graph break).

next_draft_tokens.append(...) inside a compiled function introduces Python side effects that TorchDynamo can’t reliably trace/cache. Append outside; keep the compiled fn pure.

-            @torch.compile(options={"max-autotune": True})
-            def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
-                next_draft_tokens.append(new_draft_token)
-                hidden_states = hidden_states_to_save[gather_ids]
-                position_ids = inputs["position_ids"][gather_ids] + 1
-                return hidden_states, position_ids
-
-            hidden_states, position_ids = update_draft_tokens_and_inputs(
-                new_draft_token, hidden_states_to_save, gather_ids, inputs
-            )
+            @torch.compile(options={"max-autotune": True})
+            def update_draft_tokens_and_inputs(hidden_states_to_save, gather_ids, position_ids):
+                hidden_states = hidden_states_to_save[gather_ids]
+                position_ids = position_ids[gather_ids] + 1
+                return hidden_states, position_ids
+
+            next_draft_tokens.append(new_draft_token)
+            hidden_states, position_ids = update_draft_tokens_and_inputs(
+                hidden_states_to_save, gather_ids, inputs["position_ids"]
+            )
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 354 to 363, the
compiled function update_draft_tokens_and_inputs mutates a Python list
(next_draft_tokens.append(...)), which breaks TorchDynamo tracing. Remove any
list mutation from the @torch.compile function; instead, have the function return
the new draft token (or an index/flag) along with hidden_states and
position_ids, then call next_draft_tokens.append(returned_token) immediately
after the compiled function so the compiled function remains pure and
side-effect free.
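
A minimal standalone sketch of the pattern this comment asks for (illustrative names, not the PR's code): the list append stays in eager Python, and the compiled function only performs tensor work, so TorchDynamo can trace it without graph breaks.

import torch

@torch.compile(options={"max-autotune": True})
def gather_and_advance(hidden_states_to_save, position_ids, gather_ids):
    # Pure tensor ops only: safe to trace and cache as a single graph.
    hidden = hidden_states_to_save[gather_ids]
    pos = position_ids[gather_ids] + 1
    return hidden, pos

def drafting_step(next_draft_tokens, new_draft_token,
                  hidden_states_to_save, position_ids, gather_ids):
    # The Python-side list mutation happens outside the compiled region.
    next_draft_tokens.append(new_draft_token)
    return gather_and_advance(hidden_states_to_save, position_ids, gather_ids)

torch._dynamo.explain can be run on the compiled function to check that it traces with no graph breaks.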

# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
@@ -373,7 +393,10 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
@torch.compile(options={"max-autotune": True})
def update_kv_lens_cuda(attn_metadata, batch_size):
attn_metadata.kv_lens_cuda[:batch_size] += 1
update_kv_lens_cuda(attn_metadata, batch_size)
# support attention dp
if spec_metadata.all_rank_num_tokens is not None:
spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
@@ -384,18 +407,24 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

# restore attn_metadata to support cuda graph
attn_metadata.restore_from_spec_dec()
attn_metadata.on_update()

# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
@torch.compile(options={"max-autotune": True})
def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1
].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
return next_draft_tokens_stacked, next_new_tokens

next_draft_tokens, next_new_tokens = prepare_next_tokens(
next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens
)
Comment on lines +415 to +427
🛠️ Refactor suggestion

Use int64 indices for row/col advanced indexing.

Cast both row and column indices to long() to avoid dtype-related surprises across PyTorch versions/backends.

-        def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
-            next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
-            next_new_tokens = accepted_tokens[
-                spec_metadata.batch_indices_cuda[:batch_size],
-                num_accepted_tokens - 1
-            ].unsqueeze(1)
+        def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
+            next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
+            row_idx = spec_metadata.batch_indices_cuda[:batch_size].long()
+            col_idx = (num_accepted_tokens - 1).long()
+            next_new_tokens = accepted_tokens[row_idx, col_idx].unsqueeze(1)
             next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
             return next_draft_tokens_stacked, next_new_tokens
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 401 to 413, the
advanced indexing into accepted_tokens uses spec_metadata.batch_indices_cuda and
num_accepted_tokens, which may be the wrong dtype on some PyTorch backends. Cast
both the row and column indices to int64 before indexing (e.g., with .long() or
.to(torch.long)) so the indexing is always performed with int64 tensors, and
pass those cast indices into accepted_tokens and any related indexing
operations.


attn_metadata.use_spec_decoding = True

@@ -407,43 +436,61 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
'next_new_tokens': next_new_tokens,
}
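
The "# 7us saving"-style annotations in this diff describe per-call savings; a rough, illustrative harness (not part of this PR) for sanity-checking such micro-savings is to time an eager and a compiled variant of the same small region after a warmup.

import time
import torch

def bump_kv_lens_eager(kv_lens, batch_size):
    kv_lens[:batch_size] += 1

bump_kv_lens_compiled = torch.compile(bump_kv_lens_eager, options={"max-autotune": True})

def us_per_call(fn, kv_lens, batch_size, iters=1000):
    for _ in range(10):              # warmup; triggers compilation for the compiled variant
        fn(kv_lens, batch_size)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(kv_lens, batch_size)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6   # wall-clock microseconds per call

if torch.cuda.is_available():
    kv_lens = torch.zeros(64, dtype=torch.int32, device="cuda")
    print(us_per_call(bump_kv_lens_eager, kv_lens, 32),
          us_per_call(bump_kv_lens_compiled, kv_lens, 32))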

# @torch.compile(options={"max-autotune": True}) # dont apply on all
def sample_and_accept_draft_tokens(
self,
logits: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: Eagle3OneModelSpecMetadata,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts

if logits.dim() == 1:
logits = logits.unsqueeze(0)

# The return buffer
accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
@torch.compile(options={"max-autotune": True}) # this and below compile saves 8us ; torch.compile on argmax spoils the performance
def get_num_gens_and_accepted_tokens(logits, attn_metadata, max_draft_len):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts

if logits.dim() == 1:
logits = logits.unsqueeze(0)

# The return buffer
accepted_tokens = torch.empty((batch_size, (max_draft_len + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)

return num_gens, num_contexts, num_accepted_tokens, accepted_tokens

num_gens, num_contexts, num_accepted_tokens, accepted_tokens = get_num_gens_and_accepted_tokens(logits, attn_metadata, self.max_draft_len)

# Do greedy sampling for the input logits
target_tokens = torch.argmax(logits, dim=-1)
# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]

# generation
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, self.max_draft_len + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, self.max_draft_len)
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
dim=-1).sum(1)

@torch.compile(options={"max-autotune": True})
def process_accepted_tokens(target_tokens, num_contexts, num_gens, accepted_tokens, gen_target_tokens, draft_tokens, num_accepted_tokens, max_draft_len):
# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]

# generation
gen_target_tokens_reshaped = gen_target_tokens.reshape(
num_gens, max_draft_len + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens_reshaped
draft_tokens_reshaped = draft_tokens.reshape(
num_gens, max_draft_len)
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens_reshaped == gen_target_tokens_reshaped[:, :max_draft_len]).int(),
dim=-1).sum(1)
return accepted_tokens, num_accepted_tokens

accepted_tokens, num_accepted_tokens = process_accepted_tokens(
target_tokens, num_contexts, num_gens, accepted_tokens,
target_tokens[num_contexts:], spec_metadata.draft_tokens,
num_accepted_tokens, self.max_draft_len
)
return accepted_tokens, num_accepted_tokens
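
As a toy, self-contained illustration of the cumprod acceptance count used above (made-up values, not TRT-LLM buffers): it counts the longest prefix of draft tokens that matches the greedy target tokens, plus the always-accepted bonus token.

import torch

draft  = torch.tensor([[7, 3, 9],     # first two positions match, then diverge
                       [4, 4, 4]])    # diverges at position 0
target = torch.tensor([[7, 3, 1],
                       [5, 4, 4]])

matches = (draft == target).int()             # [[1, 1, 0], [0, 1, 1]]
prefix  = torch.cumprod(matches, dim=-1)      # [[1, 1, 0], [0, 0, 0]]
num_accepted = 1 + prefix.sum(dim=1)          # tensor([3, 1]); the +1 is the bonus token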

# @torch.compile(options={"max-autotune": True})
def draft_decoder(
self,
logits: torch.Tensor,
@@ -464,8 +511,7 @@ def draft_decoder(
[batch_size * max_draft_len]
Draft token ids. Flattened.
'''

draft_tokens = torch.argmax(logits, dim=-1)
draft_tokens = torch.argmax(logits, dim=-1) # [num_tokens]

# Apply d2t (offsets between draft model dictionary and main model dictionary).
if (d2t := getattr(draft_model.model, "d2t", None)) is not None:
@@ -475,6 +521,7 @@

return draft_tokens
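
The d2t lookup itself is collapsed in this hunk; as a toy sketch of the assumed semantics (d2t[i] holding the offset from draft-vocabulary id i to the main-model vocabulary), the mapping could look like this:

import torch

d2t = torch.tensor([0, 0, 5, 5])                 # hypothetical offset table
draft_tokens = torch.tensor([1, 2, 3])
main_tokens = draft_tokens + d2t[draft_tokens]   # tensor([1, 7, 8])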

@torch.compile(options={"max-autotune": True})
def prepare_1st_drafter_inputs(
self,
input_ids: torch.LongTensor,