
Conversation


@ameynaik-hub ameynaik-hub commented Aug 28, 2025

Summary by CodeRabbit

  • Performance
    • Improved generation speed and reduced latency for speculative decoding workflows.
    • More consistent throughput under high-load scenarios.
  • Refactor
    • Internal computation steps modularized and optimized with compiled helpers to streamline execution.
    • Preserved existing behavior and public interfaces; no user-facing API changes.

feat: perf improvement for gpt-oss eagle3

This PR adds selective torch.compile decorators to get optimized kernels.

Baseline: 822 TPS (median, p50) for BS1 with avg AR 2.44
After this PR: 838 TPS (median, p50) for BS1 with avg AR 2.44 (+2% improvement)
- End-to-end iteration time saving of 69.9 µs
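
For illustration, a minimal sketch of the pattern this PR applies, assuming a toy helper (compute_last_token_indices is a made-up name, not code from eagle3.py): a small tensor-manipulation step is factored into a function decorated with torch.compile so Inductor can fuse it into an optimized kernel.

import torch

# Sketch only: a tiny helper compiled in isolation, mirroring the selective
# torch.compile approach described above; not the actual eagle3.py code.
@torch.compile(options={"max-autotune": True})
def compute_last_token_indices(seq_lens: torch.Tensor) -> torch.Tensor:
    # Cumulative sequence lengths give the flat index of each sequence's last token.
    return torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1

seq_lens = torch.tensor([3, 5, 2])
print(compute_last_token_indices(seq_lens))  # tensor([2, 7, 9])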

@ameynaik-hub ameynaik-hub requested a review from a team as a code owner August 28, 2025 23:16

coderabbitai bot commented Aug 28, 2025

📝 Walkthrough


Refactors Eagle3OneModelWorker.forward and related routines by extracting inlined computations into multiple small torch.compile-wrapped helper functions for position/gather ID calculations, draft token updates, KV-length updates, input preparation, and acceptance logic. Control flow remains equivalent; public interfaces unchanged.

Changes

Cohort / File(s): Compiled helper refactors (tensorrt_llm/_torch/speculative/eagle3.py)
Summary: Moved inline logic into torch.compile-wrapped helpers: position/gather ID computation, draft token updates, KV-cache length updates, next token preparation, first-draft input prep, and acceptance sampling utilities. Updated forward and sample_and_accept_draft_tokens to use these helpers without changing external APIs.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Caller
  participant W as Eagle3OneModelWorker
  participant H as Compiled Helpers

  C->>W: forward(inputs, attn_metadata, ...)
  W->>H: calc_position_ids_and_last_tokens_idx()
  H-->>W: position_ids, last_tokens_idx

  alt first-draft path
    W->>H: prepare_1st_drafter_inputs(...)
    H-->>W: drafter_inputs
    W->>H: compute_gather_ids(...)
    H-->>W: gather_ids
  else subsequent steps
    W->>H: get_gather_ids(spec_metadata, batch_size)
    H-->>W: gather_ids
  end

  W->>H: update_draft_tokens_and_inputs(...)
  H-->>W: next_draft_tokens, hidden_states, position_ids

  W->>H: update_kv_lens_cuda(attn_metadata, batch_size)
  H-->>W: kv_lens_cuda

  W->>H: prepare_next_tokens(...)
  H-->>W: next_new_tokens
  W-->>C: outputs
sequenceDiagram
  autonumber
  participant W as Eagle3OneModelWorker
  participant H as Compiled Helpers

  W->>H: get_num_gens_and_accepted_tokens(...)
  H-->>W: num_gens, accepted_tokens_init

  W->>H: process_accepted_tokens(...)
  H-->>W: accepted_tokens, counters, updates

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • SimengLiu-nv
  • yilin-void
  • syuoni



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (8)
tensorrt_llm/_torch/speculative/eagle3.py (8)

1-1: Add NVIDIA 2025 copyright header.

Per guidelines, prepend the standard NVIDIA copyright header for 2025.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 from dataclasses import dataclass, field

5-5: Prefer module import style for torch.nn.

Guidelines say “keep module namespace; import the module, not individual names.” Use import torch.nn as nn.

-from torch import nn
+import torch.nn as nn

329-343: compute_gather_ids: semantics LGTM; minor robustness.

Computation matches prior inline logic. Consider hoisting this helper to module scope to avoid redefinition per forward call and improve compile cache reuse.
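
A sketch of the hoisting idea under hypothetical names (the real helper takes different arguments): a module-level compiled function is created once, so its compile cache can be reused across forward calls instead of a fresh closure being created (and potentially recompiled) on every call.

import torch

# Illustrative module-scope compiled helper: defined once, shared by all calls.
@torch.compile(options={"max-autotune": True})
def _gather_last_hidden(hidden_states: torch.Tensor, gather_ids: torch.Tensor) -> torch.Tensor:
    return hidden_states[gather_ids]

class WorkerSketch:
    def forward(self, hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        gather_ids = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
        # Calls the module-level helper rather than defining it inline here.
        return _gather_last_hidden(hidden_states, gather_ids)

print(WorkerSketch().forward(torch.randn(10, 4), torch.tensor([3, 5, 2])).shape)  # torch.Size([3, 4])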


380-384: Tiny in-place increment doesn’t merit @torch.compile.

This is a single in-place add; compiling it likely adds overhead and reduces readability. Inline it or hoist a shared compiled helper at module scope if measurements show benefit.
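
A sketch of the inline alternative (kv_lens_cuda below is a stand-in tensor, not the real attn_metadata field): the increment is a single element-wise kernel either way, so running it eagerly skips the wrapper's guard and dispatch overhead.

import torch

kv_lens_cuda = torch.zeros(8, dtype=torch.int32)  # stand-in for attn_metadata.kv_lens_cuda
batch_size = 4
# Plain eager in-place increment: already a single kernel, no @torch.compile needed.
kv_lens_cuda[:batch_size] += 1
print(kv_lens_cuda)  # tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)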


432-452: Initialization helper looks fine; hoist for better caching.

Logic mirrors existing behavior. Consider moving get_num_gens_and_accepted_tokens to module scope to avoid per-call redefinition and improve torch.compile cache reuse.


500-503: Remove leftover debug prints and shorten the long line.

Drop commented debug prints; keep the argmax line concise to satisfy E501.

-        # print(f"DBG : logits.shape: {logits.shape} {logits.dtype} {logits.device} logits stride: {logits.stride()}")
-        draft_tokens = torch.argmax(logits, dim=-1) # [num_tokens]
-        # print(f"DBG : draft_tokens.shape: {draft_tokens.shape} {draft_tokens.dtype} {draft_tokens.device} draft_tokens stride: {draft_tokens.stride()}")
+        draft_tokens = torch.argmax(logits, dim=-1)  # [num_tokens]

513-557: Compiling input prep is reasonable; consider a brief docstring.

Optional: a one-liner docstring clarifying shapes; helps future maintainers optimize further.
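
A sketch of the kind of docstring the comment suggests, in the Google style the path-based guidelines below call for (the signature, shapes, and body are hypothetical; the real helper takes different arguments and does more work):

def prepare_1st_drafter_inputs(input_ids, position_ids, hidden_states):
    """Assemble inputs for the first draft-model step.

    Args:
        input_ids: Accepted token ids, shape [num_tokens] (hypothetical).
        position_ids: Per-token positions, shape [num_tokens] (hypothetical).
        hidden_states: Target-model hidden states, shape [num_tokens, hidden_size].

    Returns:
        Dict of tensors consumed by the draft model's forward pass.
    """
    # Hypothetical body for illustration only.
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "hidden_states": hidden_states,
    }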


330-330: Fix Ruff E501 (line length) in flagged lines.

Wrap/break these lines to ≤120 chars to satisfy the linter.

Also applies to: 341-341, 432-432, 451-451, 457-457, 502-502

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between b093d94 and 7a95462.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/speculative/eagle3.py (7 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,cc,cxx,cu,py,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use spaces only; no tabs; indent by 4 spaces

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; no tabs
Keep module namespace on import; import the module, not individual names; use module.symbol
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; if starting with a number, prefix with k_ (e.g., k_99th_percentile)
Global variables use UPPER_SNAKE with G_ prefix (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE
Avoid shadowing variables from outer scopes
Initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with trailing docstrings
Avoid reflection when simple alternatives exist (e.g., avoid dict(**locals()) patterns)
Limit except clauses to specific exceptions; avoid bare except
When duck-typing with try/except, keep try body minimal and use else for logic

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cc,cxx,cu,h,hpp,hh,hxx,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/eagle3.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/speculative/eagle3.py (3)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
  • attn_metadata (68-69)
  • spec_metadata (60-61)
tensorrt_llm/_torch/attention_backend/interface.py (6)
  • seq_lens_cuda (225-226)
  • num_contexts (205-206)
  • num_contexts (209-212)
  • num_ctx_tokens (273-274)
  • AttentionMetadata (39-331)
  • num_seqs (255-259)
tensorrt_llm/_torch/speculative/mtp.py (2)
  • prepare_next_tokens (1253-1262)
  • sample_and_accept_draft_tokens (700-858)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/eagle3.py

330-330: Line too long (151 > 120)

(E501)


341-341: Line too long (130 > 120)

(E501)


432-432: Line too long (138 > 120)

(E501)


451-451: Line too long (146 > 120)

(E501)


457-457: Line too long (161 > 120)

(E501)


502-502: Line too long (154 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (2)

425-425: Good call: keep the whole method uncompiled.

Keeping sample_and_accept_draft_tokens selectively compiled (and not the entire method) aligns with the Triton 3.4 constraint noted above.


456-476: Acceptance logic parity: LGTM.

Shapes and cumprod-based acceptance count match prior implementation.
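
For readers unfamiliar with the pattern, a standalone sketch of cumprod-based acceptance counting (toy values, not tensors from the file): the cumulative product of per-position match flags zeroes out everything after the first mismatch, so its row-sum is the length of the leading run of accepted draft tokens.

import torch

# matches[b, i] == 1 if draft token i agreed with the target model's token.
matches = torch.tensor([[1, 1, 0, 1],
                        [1, 0, 1, 1]])
# cumprod zeroes every position after the first mismatch; summing per row then
# counts the accepted draft tokens for each sequence.
num_accepted = torch.cumprod(matches, dim=1).sum(dim=1)
print(num_accepted)  # tensor([2, 1])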

Comment on lines +298 to +306
@torch.compile(options={"max-autotune": True})
def calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata):
    position_ids = position_ids.squeeze(0)
    last_tokens_idx = torch.cumsum(
        attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
    return position_ids, last_tokens_idx

position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata)


🛠️ Refactor suggestion

Avoid passing Python objects into @torch.compile; pass tensors instead.

Accessing attn_metadata.seq_lens_cuda inside the compiled fn can trigger graph breaks and reduce cache hit rate. Pass the tensor directly.

-        @torch.compile(options={"max-autotune": True})
-        def calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata):
-            position_ids = position_ids.squeeze(0)
-            last_tokens_idx = torch.cumsum(
-                attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-            return position_ids, last_tokens_idx
+        @torch.compile(options={"max-autotune": True})
+        def calc_position_ids_and_last_tokens_idx(position_ids, seq_lens_cuda):
+            position_ids = position_ids.squeeze(0)
+            last_tokens_idx = torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1
+            return position_ids, last_tokens_idx
@@
-        position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata)
+        position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(
+            position_ids, attn_metadata.seq_lens_cuda
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@torch.compile(options={"max-autotune": True})
-def calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata):
-    position_ids = position_ids.squeeze(0)
-    last_tokens_idx = torch.cumsum(
-        attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-    return position_ids, last_tokens_idx
-position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(position_ids, attn_metadata)
+@torch.compile(options={"max-autotune": True})
+def calc_position_ids_and_last_tokens_idx(position_ids, seq_lens_cuda):
+    position_ids = position_ids.squeeze(0)
+    last_tokens_idx = torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1
+    return position_ids, last_tokens_idx
+position_ids, last_tokens_idx = calc_position_ids_and_last_tokens_idx(
+    position_ids, attn_metadata.seq_lens_cuda
+)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 298 to 306, the
@torch.compile function currently accesses the Python object
attn_metadata.seq_lens_cuda inside the compiled function which can cause graph
breaks; change the compiled function signature to accept the seq_lens_cuda
tensor (e.g., seq_lens_cuda) as an explicit tensor argument, use that tensor in
torch.cumsum to compute last_tokens_idx, and call the function passing
attn_metadata.seq_lens_cuda from the outer scope; keep position_ids.squeeze(0)
as before and preserve dtype=torch.long.

Comment on lines +346 to 349
def get_gather_ids(spec_metadata, batch_size):
    return spec_metadata.batch_indices_cuda[:batch_size]
gather_ids = get_gather_ids(spec_metadata, batch_size)
logits = draft_model.logits_processor(hidden_states[gather_ids],

🛠️ Refactor suggestion

Ensure index dtype is int64 for tensor indexing.

batch_indices_cuda is torch.int (int32). Advanced/fancy indexing is most robust with int64.

-                def get_gather_ids(spec_metadata, batch_size):
-                    return spec_metadata.batch_indices_cuda[:batch_size]
+                def get_gather_ids(spec_metadata, batch_size):
+                    return spec_metadata.batch_indices_cuda[:batch_size].long()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def get_gather_ids(spec_metadata, batch_size):
-    return spec_metadata.batch_indices_cuda[:batch_size]
+def get_gather_ids(spec_metadata, batch_size):
+    return spec_metadata.batch_indices_cuda[:batch_size].long()
 gather_ids = get_gather_ids(spec_metadata, batch_size)
 logits = draft_model.logits_processor(hidden_states[gather_ids],
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 346 to 349, the helper
returns spec_metadata.batch_indices_cuda which is torch.int (int32) and is used
for advanced indexing; change get_gather_ids to return the first batch_size
entries cast to int64 (e.g., .to(torch.int64) or .long()) to ensure robust
tensor indexing on CUDA while preserving device and slicing semantics.

Comment on lines +354 to +363
@torch.compile(options={"max-autotune": True})
def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
    next_draft_tokens.append(new_draft_token)
    hidden_states = hidden_states_to_save[gather_ids]
    position_ids = inputs["position_ids"][gather_ids] + 1
    return hidden_states, position_ids

hidden_states, position_ids = update_draft_tokens_and_inputs(
    new_draft_token, hidden_states_to_save, gather_ids, inputs
)

⚠️ Potential issue

Do not mutate Python lists inside @torch.compile (graph break).

next_draft_tokens.append(...) inside a compiled function introduces Python side effects that TorchDynamo can’t reliably trace/cache. Append outside; keep the compiled fn pure.

-            @torch.compile(options={"max-autotune": True})
-            def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
-                next_draft_tokens.append(new_draft_token)
-                hidden_states = hidden_states_to_save[gather_ids]
-                position_ids = inputs["position_ids"][gather_ids] + 1
-                return hidden_states, position_ids
-
-            hidden_states, position_ids = update_draft_tokens_and_inputs(
-                new_draft_token, hidden_states_to_save, gather_ids, inputs
-            )
+            @torch.compile(options={"max-autotune": True})
+            def update_draft_tokens_and_inputs(hidden_states_to_save, gather_ids, position_ids):
+                hidden_states = hidden_states_to_save[gather_ids]
+                position_ids = position_ids[gather_ids] + 1
+                return hidden_states, position_ids
+
+            next_draft_tokens.append(new_draft_token)
+            hidden_states, position_ids = update_draft_tokens_and_inputs(
+                hidden_states_to_save, gather_ids, inputs["position_ids"]
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-@torch.compile(options={"max-autotune": True})
-def update_draft_tokens_and_inputs(new_draft_token, hidden_states_to_save, gather_ids, inputs):
-    next_draft_tokens.append(new_draft_token)
-    hidden_states = hidden_states_to_save[gather_ids]
-    position_ids = inputs["position_ids"][gather_ids] + 1
-    return hidden_states, position_ids
-hidden_states, position_ids = update_draft_tokens_and_inputs(
-    new_draft_token, hidden_states_to_save, gather_ids, inputs
-)
+@torch.compile(options={"max-autotune": True})
+def update_draft_tokens_and_inputs(hidden_states_to_save, gather_ids, position_ids):
+    hidden_states = hidden_states_to_save[gather_ids]
+    position_ids = position_ids[gather_ids] + 1
+    return hidden_states, position_ids
+next_draft_tokens.append(new_draft_token)
+hidden_states, position_ids = update_draft_tokens_and_inputs(
+    hidden_states_to_save, gather_ids, inputs["position_ids"]
+)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 354 to 363, the
compiled function update_draft_tokens_and_inputs mutates a Python list
(next_draft_tokens.append(...)) which breaks TorchDynamo tracing; remove any
list mutation from the @torch.compile function, instead have the function return
the new draft token (or an index/flag) along with hidden_states and
position_ids, then perform next_draft_tokens.append(returned_token) immediately
after calling the compiled function so the compiled function remains pure and
side-effect free.

Comment on lines +401 to +413
@torch.compile(options={"max-autotune": True})
def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
    next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
    next_new_tokens = accepted_tokens[
        spec_metadata.batch_indices_cuda[:batch_size],
        num_accepted_tokens - 1
    ].unsqueeze(1)
    next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
    return next_draft_tokens_stacked, next_new_tokens

next_draft_tokens, next_new_tokens = prepare_next_tokens(
    next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens
)

🛠️ Refactor suggestion

Use int64 indices for row/col advanced indexing.

Cast both row and column indices to long() to avoid dtype-related surprises across PyTorch versions/backends.

-        def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
-            next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
-            next_new_tokens = accepted_tokens[
-                spec_metadata.batch_indices_cuda[:batch_size],
-                num_accepted_tokens - 1
-            ].unsqueeze(1)
+        def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
+            next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
+            row_idx = spec_metadata.batch_indices_cuda[:batch_size].long()
+            col_idx = (num_accepted_tokens - 1).long()
+            next_new_tokens = accepted_tokens[row_idx, col_idx].unsqueeze(1)
             next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
             return next_draft_tokens_stacked, next_new_tokens
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 @torch.compile(options={"max-autotune": True})
 def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens):
     next_draft_tokens_stacked = torch.stack(next_draft_tokens, dim=1)
-    next_new_tokens = accepted_tokens[
-        spec_metadata.batch_indices_cuda[:batch_size],
-        num_accepted_tokens - 1
-    ].unsqueeze(1)
+    row_idx = spec_metadata.batch_indices_cuda[:batch_size].long()
+    col_idx = (num_accepted_tokens - 1).long()
+    next_new_tokens = accepted_tokens[row_idx, col_idx].unsqueeze(1)
     next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens_stacked], dim=1)
     return next_draft_tokens_stacked, next_new_tokens
 next_draft_tokens, next_new_tokens = prepare_next_tokens(
     next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens
 )
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/eagle3.py around lines 401 to 413, the
advanced indexing into accepted_tokens uses spec_metadata.batch_indices_cuda and
num_accepted_tokens which may be wrong dtype on some PyTorch backends; cast both
the row and column indices to long() (int64) before indexing (e.g., use .long()/
.to(torch.long)) so the indexing operation is always performed with int64
tensors, and adjust the code to pass those cast indices into accepted_tokens and
any related indexing operations.
