-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Perf/gpt oss eagle #7353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Perf/gpt oss eagle #7353
Changes from all commits
b0116f7
e642f66
7a95462
7093c1b
63c19db
4331b71
6b2d3f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -316,17 +316,30 @@ def forward(self, input_ids, position_ids, hidden_states, logits, | |||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for i in range(self.max_draft_len): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if i == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] * | ||||||||||||||||||||||||||||||||||||||||||||||||||||
(self.max_draft_len + 1)).long() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids_gen = (start_ids_gen + | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens[num_contexts:] - 1 + | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.num_ctx_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids = torch.concat( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) # 7us saving | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def compute_gather_ids(spec_metadata, num_gens, self_max_draft_len, num_accepted_tokens, num_contexts, attn_metadata): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] * | ||||||||||||||||||||||||||||||||||||||||||||||||||||
(self_max_draft_len + 1)).long() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids_gen = (start_ids_gen + | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens[num_contexts:] - 1 + | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.num_ctx_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids = torch.concat( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return gather_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids = compute_gather_ids( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
spec_metadata, num_gens, self.max_draft_len, num_accepted_tokens, num_contexts, attn_metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# All of the seq_len are 1, use batch_indices_cuda as gather_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids = spec_metadata.batch_indices_cuda[:batch_size] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_gather_ids(spec_metadata, batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return spec_metadata.batch_indices_cuda[:batch_size] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gather_ids = get_gather_ids(spec_metadata, batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.guided_decoder is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
new_tokens = inputs["input_ids"][gather_ids] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -353,10 +366,17 @@ def forward(self, input_ids, position_ids, hidden_states, logits, | |||||||||||||||||||||||||||||||||||||||||||||||||||
draft_step=i) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
new_draft_token = self.draft_decoder(logits, draft_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens.append(new_draft_token) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# update inputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
hidden_states = hidden_states_to_save[gather_ids] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
position_ids = inputs["position_ids"][gather_ids] + 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens.append(new_draft_token) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
hidden_states = hidden_states_to_save[gather_ids] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
position_ids = inputs["position_ids"][gather_ids] + 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return hidden_states, position_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
hidden_states, position_ids = update_draft_tokens_and_inputs( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
new_draft_token, hidden_states_to_save, gather_ids, inputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# update attn_metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if i == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata._seq_lens[:batch_size].fill_(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -373,7 +393,10 @@ def forward(self, input_ids, position_ids, hidden_states, logits, | |||||||||||||||||||||||||||||||||||||||||||||||||||
self.max_draft_len - num_accepted_tokens[num_contexts:]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.kv_lens_cuda[:num_contexts] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
elif hasattr(attn_metadata, 'kv_lens_cuda'): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.kv_lens_cuda[:batch_size] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def update_kv_lens_cuda(attn_metadata, batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.kv_lens_cuda[:batch_size] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
update_kv_lens_cuda(attn_metadata, batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# support attention dp | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if spec_metadata.all_rank_num_tokens is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -384,18 +407,24 @@ def forward(self, input_ids, position_ids, hidden_states, logits, | |||||||||||||||||||||||||||||||||||||||||||||||||||
"attn_metadata": attn_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"spec_metadata": spec_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens = torch.stack(next_draft_tokens, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# restore attn_metadata to support cuda graph | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.restore_from_spec_dec() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.on_update() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# prepare next new tokens to support overlap scheduler | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_new_tokens = accepted_tokens[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
spec_metadata.batch_indices_cuda[:batch_size], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens - 1].unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_new_tokens = accepted_tokens[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
spec_metadata.batch_indices_cuda[:batch_size], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens - 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
].unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return next_draft_tokens_stacked, next_new_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens, next_new_tokens = prepare_next_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+415
to
+427
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Use int64 indices for row/col advanced indexing. Cast both row and column indices to - def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
- next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
- next_new_tokens = accepted_tokens[
- spec_metadata.batch_indices_cuda[:batch_size],
- num_accepted_tokens - 1
- ].unsqueeze(1)
+ def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
+ next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
+ row_idx = spec_metadata.batch_indices_cuda[:batch_size].long()
+ col_idx = (num_accepted_tokens - 1).long()
+ next_new_tokens = accepted_tokens[row_idx, col_idx].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
return next_draft_tokens_stacked, next_new_tokens 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata.use_spec_decoding = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -407,43 +436,61 @@ def forward(self, input_ids, position_ids, hidden_states, logits, | |||||||||||||||||||||||||||||||||||||||||||||||||||
'next_new_tokens': next_new_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# @torch.compile(options={"max-autotune": True}) # dont apply on all | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def sample_and_accept_draft_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
attn_metadata: AttentionMetadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
spec_metadata: Eagle3OneModelSpecMetadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size = attn_metadata.num_seqs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_contexts = attn_metadata.num_contexts | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens = batch_size - num_contexts | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
if logits.dim() == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = logits.unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# The return buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dtype=torch.int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device=logits.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens = torch.ones(batch_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dtype=torch.int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device=logits.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) # this and below compile saves 8us ; torch.compile on argmax spoils the performance | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_num_gens_and_accepted_tokens(logits, attn_metadata, max_draft_len): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size = attn_metadata.num_seqs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_contexts = attn_metadata.num_contexts | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens = batch_size - num_contexts | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
if logits.dim() == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logits = logits.unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# The return buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens = torch.empty((batch_size, (max_draft_len + 1)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dtype=torch.int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device=logits.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens = torch.ones(batch_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dtype=torch.int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
device=logits.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
return num_gens, num_contexts, num_accepted_tokens, accepted_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens, num_contexts, num_accepted_tokens, accepted_tokens = get_num_gens_and_accepted_tokens(logits, attn_metadata, self.max_draft_len) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Do greedy sampling for the input logits | ||||||||||||||||||||||||||||||||||||||||||||||||||||
target_tokens = torch.argmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# context | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# generation | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gen_target_tokens = target_tokens[num_contexts:].reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens, self.max_draft_len + 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens[num_contexts:, :] = gen_target_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
draft_tokens = spec_metadata.draft_tokens.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens, self.max_draft_len) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens[num_contexts:] += torch.cumprod( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dim=-1).sum(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def process_accepted_tokens(target_tokens, num_contexts, num_gens, accepted_tokens, gen_target_tokens, draft_tokens, num_accepted_tokens, max_draft_len): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# context | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# generation | ||||||||||||||||||||||||||||||||||||||||||||||||||||
gen_target_tokens_reshaped = gen_target_tokens.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens, max_draft_len + 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens[num_contexts:, :] = gen_target_tokens_reshaped | ||||||||||||||||||||||||||||||||||||||||||||||||||||
draft_tokens_reshaped = draft_tokens.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_gens, max_draft_len) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens[num_contexts:] += torch.cumprod( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
(draft_tokens_reshaped == gen_target_tokens_reshaped[:, :max_draft_len]).int(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dim=-1).sum(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return accepted_tokens, num_accepted_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
accepted_tokens, num_accepted_tokens = process_accepted_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
target_tokens, num_contexts, num_gens, accepted_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
target_tokens[num_contexts:], spec_metadata.draft_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
num_accepted_tokens, self.max_draft_len | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return accepted_tokens, num_accepted_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# @torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def draft_decoder( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -464,8 +511,7 @@ def draft_decoder( | |||||||||||||||||||||||||||||||||||||||||||||||||||
[batch_size * max_draft_len] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Draft token ids. Flattened. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
''' | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
draft_tokens = torch.argmax(logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
draft_tokens = torch.argmax(logits, dim=-1) # [num_tokens] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Apply d2t (offsets between draft model dictionary and main model dictionary). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (d2t := getattr(draft_model.model, "d2t", None)) is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -475,6 +521,7 @@ def draft_decoder( | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
return draft_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.compile(options={"max-autotune": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
def prepare_1st_drafter_inputs( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
input_ids: torch.LongTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not mutate Python lists inside @torch.compile (graph break).
next_draft_tokens.append(...)
inside a compiled function introduces Python side effects that TorchDynamo can’t reliably trace/cache. Append outside; keep the compiled fn pure.📝 Committable suggestion
🤖 Prompt for AI Agents