
Conversation

dongxuy04 (Collaborator) commented on Aug 28, 2025

Summary by CodeRabbit

  • New Features

    • Added optional setting to enable LM tensor-parallelism within Attention Data Parallel, improving scalability on multi-GPU setups.
    • Enhanced distributed generation: pre-LM gather and cross-rank logits handling for more consistent outputs under ADP + LM TP.
    • Speculative decoding updated to support ADP + LM TP paths for draft token selection.
  • Chores

    • Introduced additional runtime logs (shapes, ranks, paths) to aid debugging.
  • Public API

    • New user-facing argument to toggle LM TP in ADP is available in configuration and mapped through engine setup.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.

Njuapp and others added 2 commits August 28, 2025 03:39

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_utils.py (1)

398-406: Ensure embed_tokens is created with the same mapping/TP settings as lm_head under ADP+LM-TP when tie_word_embeddings=True.
Currently most models do

if config.mapping.is_first_pp_rank():
    self.vocab_embedding = Embedding(…, dtype=config.dtype)

omitting mapping=config.mapping (and tensor_parallel_mode=…), so embed_tokens.tp_size==1 even when lm_head.tp_size>1, triggering the asserts in modeling_utils.py. Pass the same mapping (and TP mode) into every Embedding(…) call for vocab/embed_tokens under ADP+LM-TP, or disable tie_word_embeddings when TP settings differ.
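
A minimal sketch of that change (constructor arguments and the TensorParallelMode name are assumptions for illustration; match whatever LMHead actually receives in this repo):

if config.mapping.is_first_pp_rank():
    self.vocab_embedding = Embedding(
        config.pretrained_config.vocab_size,    # illustrative accessor; use the model's real vocab-size field
        config.pretrained_config.hidden_size,
        dtype=config.dtype,
        mapping=config.mapping,                  # same mapping object as lm_head
        tensor_parallel_mode=TensorParallelMode.COLUMN,  # assumed to mirror lm_head's TP mode
    )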

🧹 Nitpick comments (8)
tensorrt_llm/_torch/models/modeling_speculative.py (1)

436-440: Remove noisy debug prints; gate with a debug flag if needed.

Printing per-step in distributed inference is costly and pollutes logs. Prefer logger with a runtime flag (e.g., enable_llm_debug()) or remove entirely.

Apply:

-            # print(f"lm_head.weight.data_ptr: {self.lm_head.weight.data_ptr()}")
-            # print(f"lm_head.weight.shape: {self.lm_head.weight.shape}")
-            print(f"In SpecDecOneEngineForCausalLM, before spec_worker, logits.shape: {logits.shape}")
-            # print(f"draft_model.lm_head.weight.data_ptr: {self.draft_model.lm_head.weight.data_ptr()}")
-            # print(f"draft_model.lm_head.weight.shape: {self.draft_model.lm_head.weight.shape}")
+            # Debug: add gated logging here if absolutely necessary.
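
If the shape really needs to be surfaced occasionally, a gated variant is a reasonable middle ground. A minimal sketch, assuming the project logger and the enable_llm_debug() helper mentioned above are importable here:

from tensorrt_llm.logger import logger

if enable_llm_debug():
    logger.debug(
        f"SpecDecOneEngineForCausalLM: pre-spec_worker logits.shape={logits.shape}")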
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

167-167: Remove redundant import inside hot path.

allgather is already imported at module scope; avoid re-import in forward.

Apply:

-            from ..distributed import allgather

1-27: Add NVIDIA 2025 SPDX header per repo guidelines.

This file lacks the required NVIDIA SPDX header. Keep third‑party license block, but prepend the NVIDIA header.

Apply at file top:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/modules/embedding.py (2)

72-74: Remove unconditional prints; they sync/slow and violate lint (E501).

Use a logger gated by a runtime flag, or drop them.

Apply:

-        print('*' * 50)
-        print(f"enable_attention_dp: {mapping.enable_attention_dp}, enable_lm_tp_in_adp: {getattr(mapping, 'enable_lm_tp_in_adp', False)}")
-        print(f"In LMHead, weight_shape: {weight_shape}")
+        # Optional: add gated debug logging here if necessary.

1-1: Add NVIDIA 2025 SPDX header per project policy.

File lacks required header.

Apply at file top:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/models/modeling_utils.py (2)

355-364: LMHead ADP/LM‑TP gating looks good; drop prints.

Creation paths match the new flag. Replace prints with logger or remove.

Apply:

-        if config.mapping.enable_attention_dp and not getattr(config.mapping, 'enable_lm_tp_in_adp', False):
-            print(f"In DecoderModelForCausalLM, creating LMHead without TP")
+        if config.mapping.enable_attention_dp and not getattr(config.mapping, 'enable_lm_tp_in_adp', False):
             self.lm_head = LMHead(
@@
-        else:
-            print(f"In DecoderModelForCausalLM, creating LMHead with TP")
+        else:
             # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig

1-1: Add NVIDIA 2025 SPDX header per policy.

Header missing.

Apply at file top:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/speculative/mtp.py (1)

1122-1143: ADP + LM TP handling looks correct but needs optimization.

The new branch for ADP + LM TP correctly:

  1. Gathers logits across all ranks
  2. Reshapes based on TP configuration
  3. Slices by current rank
  4. Computes argmax on the sliced data

However, the approach could be optimized to reduce memory usage.

Consider using the same optimized approach as the non-ADP branch (lines 1115-1121), which combines local max values with their indices before gathering, rather than gathering all logits, which is more memory-intensive.
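
A rough sketch of that direction, reusing the helpers the new branch already references in its commented-out lines (illustrative only, not a drop-in patch):

# Gather only each rank's local max value plus its global index instead of the
# full vocab-sized logits, then recover the argmax from the small gathered tensor.
combined = self.get_local_max_and_combined(logits)
gathered = allgather(combined, self.model_config.mapping, dim=-1)
draft_tokens = self.get_draft_tokens_from_gathered(gathered)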

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 08f9356 and 158f92d.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_utils.py (2 hunks)
  • tensorrt_llm/_torch/modules/embedding.py (2 hunks)
  • tensorrt_llm/_torch/modules/logits_processor.py (2 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (6 hunks)
  • tensorrt_llm/llmapi/llm_args.py (4 hunks)
  • tensorrt_llm/mapping.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespaces in imports: import the subpackage/module, not the symbol (from package.subpackage import foo; foo.SomeClass())
Naming: files snake_case; classes PascalCase; functions/methods snake_case; local variables snake_case (k_ prefix if starting with a number); globals G_ + UPPER_SNAKE_CASE; constants UPPER_SNAKE_CASE
Avoid shadowing outer-scope variables; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; reserve comments for function-internal or file-local interfaces
Use Google-style docstrings for classes and functions; inline docstrings for attributes/variables are allowed
Avoid reflection when straightforward code suffices (e.g., prefer explicit parameters over dict(**locals()))
Use narrow except clauses (e.g., catch FileNotFoundError instead of bare except)
For duck-typing try/except, keep try body minimal and use else for the main logic

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/modules/embedding.py
  • tensorrt_llm/mapping.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/modules/logits_processor.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/speculative/mtp.py
**/*.{cpp,cc,cxx,cu,h,hpp,hh,hxx,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header with current year to all source files

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/modules/embedding.py
  • tensorrt_llm/mapping.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/modules/logits_processor.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/speculative/mtp.py
🧠 Learnings (1)
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/models/modeling_utils.py (2)
tensorrt_llm/_torch/modules/embedding.py (1)
  • LMHead (15-118)
tensorrt_llm/_torch/modules/logits_processor.py (1)
  • LogitsProcessor (10-75)
tensorrt_llm/_torch/modules/logits_processor.py (4)
tensorrt_llm/_torch/attention_backend/interface.py (1)
  • AttentionMetadata (39-328)
tensorrt_llm/_torch/modules/linear.py (1)
  • Linear (1495-1704)
tensorrt_llm/_torch/distributed/communicator.py (1)
  • tp_size (46-47)
tensorrt_llm/mapping.py (1)
  • tp_rank (340-341)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)
tensorrt_llm/_torch/distributed/ops.py (1)
  • allgather (141-225)
tensorrt_llm/functional.py (1)
  • allgather (4142-4226)
tensorrt_llm/_torch/speculative/mtp.py (4)
tensorrt_llm/_torch/distributed/communicator.py (4)
  • tp_size (46-47)
  • allgather (94-95)
  • allgather (107-108)
  • tp_rank (54-55)
tensorrt_llm/_torch/distributed/ops.py (1)
  • allgather (141-225)
cpp/tensorrt_llm/thop/allgatherOp.cpp (2)
  • allgather (122-137)
  • allgather (122-122)
tensorrt_llm/mapping.py (1)
  • tp_rank (340-341)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/embedding.py

73-73: Line too long (139 > 120)

(E501)

tensorrt_llm/_torch/speculative/mtp.py

865-865: Line too long (134 > 120)

(E501)


868-868: Line too long (158 > 120)

(E501)


1139-1139: Line too long (140 > 120)

(E501)

🔇 Additional comments (13)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

171-177: Confirm downstream handles sharded logits correctly.

You temporarily set lm_head.gather_output=False, compute logits, then re-enable it, but the returned logits here are still shard-local. Ensure a later stage all-gathers logits along vocab or keep gather_output=True here.

Would you like me to scan for consumers of DeepseekV3MTPHead.forward to verify a subsequent vocab-dim all-gather?
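
If no later stage performs that gather, one option is to all-gather the shard-local logits along the vocab dimension right after the head. A sketch, assuming the same allgather helper used elsewhere in this PR and that the mapping is reachable at this point:

# Shard-local logits -> full-vocab logits on every rank (illustrative only)
logits = allgather(logits, self.model_config.mapping, dim=-1)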

tensorrt_llm/_torch/modules/embedding.py (1)

40-42: Gating LM TP under ADP is correct.

The guard allows LM TP only when explicitly enabled. Good defensive default.

tensorrt_llm/_torch/models/modeling_utils.py (1)

407-407: Init LogitsProcessor with config: OK.

Constructor change is reflected here; no issues.

tensorrt_llm/llmapi/llm_args.py (3)

228-228: LGTM! The new field enables LM tensor parallelism in attention data parallel mode.

The addition of enable_lm_tp_in_adp aligns with the PR objectives to enable LM TP for MTP under attention DP case.


1251-1255: Good addition for enabling LM TP in attention DP mode.

The field is correctly marked with status="beta" to indicate it's not yet stable, and the description is clear.


292-292: Proper propagation of the new flag to Mapping.

The enable_lm_tp_in_adp flag is correctly propagated from the internal _ParallelConfig to the public Mapping object, ensuring downstream components can access this configuration.

Also applies to: 1503-1503

tensorrt_llm/_torch/modules/logits_processor.py (5)

12-14: Constructor properly accepts ModelConfig for distributed logic paths.

The addition of model_config parameter enables the LogitsProcessor to access mapping configuration for ADP LM-TP scenarios.


33-34: Token count calculation looks correct.

The computation properly handles the reshaping to get the batch dimension.


36-52: All-gather logic for ADP + LM TP looks correct.

The implementation correctly:

  1. Checks for both enable_attention_dp and enable_lm_tp_in_adp flags
  2. Pads hidden states to match max tokens across ranks
  3. Performs all-gather operation along the token dimension

Good use of getattr with a default value for backward compatibility.


55-63: Clean toggling of gather_output flag.

The logic correctly disables gather_output during LM head forward when appropriate, preventing double gathering.


67-72: Tensor manipulation for TP looks correct.

The reshaping and slicing logic properly handles the distributed tensor layout.

tensorrt_llm/_torch/speculative/mtp.py (2)

477-493: Proper padding of hidden states for uniform token count.

The padding logic correctly aligns hidden states to all_rank_max_num_tokens before passing to the shared head, ensuring consistent tensor dimensions across ranks.
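
For reference, the alignment described above amounts to something like the following sketch (the all_rank_max_num_tokens attribute on attn_metadata is assumed from this review, not verified against the file):

import torch.nn.functional as F

pad_len = attn_metadata.all_rank_max_num_tokens - hidden_states.shape[0]
if pad_len > 0:
    # Pad along the token dimension so every rank presents the same shape to the shared head
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len))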


1246-1246: Properly passes iter parameter to draft_sampler.

The change correctly passes the iteration index to the updated draft_sampler method.

Comment on lines +163 to +169
        # Add pre-lm gather logic
        if (self.model_config.mapping.enable_attention_dp and
            getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
            # ADP + LM TP mode: perform All-Gather before LM_head
            from ..distributed import allgather
            hidden_states = allgather(hidden_states, self.model_config.mapping, dim=0)


⚠️ Potential issue

Fix potential shape mismatch in ADP all-gather by passing per-rank sizes.

Without sizes, allgather requires identical lengths across ranks. In ADP, per-rank token counts can differ; this risks incorrect results or runtime assertions.

Apply:

-        # Add pre-lm gather logic
-        if (self.model_config.mapping.enable_attention_dp and 
-            getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
-            # ADP + LM TP mode: perform All-Gather before LM_head
-            from ..distributed import allgather
-            hidden_states = allgather(hidden_states, self.model_config.mapping, dim=0)
+        # Add pre-lm gather logic
+        if (self.model_config.mapping.enable_attention_dp and
+            getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
+            # ADP + LM TP mode: perform All-Gather before LM_head with per-rank sizes
+            hidden_states = allgather(
+                hidden_states,
+                self.model_config.mapping,
+                dim=0,
+                sizes=getattr(attn_metadata, "all_rank_num_tokens", None),
+            )

Committable suggestion skipped: line range outside the PR's diff.

        logits = logits.view(self.model_config.mapping.tp_size,
                             local_batch_size, -1)
        logits = logits[self.model_config.mapping.tp_rank][:token_count]
        print(f"In LogitsProcessor, final logits.shape: {logits.shape}")

⚠️ Potential issue

Remove or guard debug print statement.

Debug print statements should not be left in production code.

-        print(f"In LogitsProcessor, final logits.shape: {logits.shape}")
+        # Uncomment for debugging:
+        # print(f"In LogitsProcessor, final logits.shape: {logits.shape}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        print(f"In LogitsProcessor, final logits.shape: {logits.shape}")
        # Uncomment for debugging:
        # print(f"In LogitsProcessor, final logits.shape: {logits.shape}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/logits_processor.py around line 73, remove the
stray debug print statement or replace it with a guarded debug-level log: either
delete the print entirely or call the module/logger debug method (e.g.,
logger.debug) or wrap it behind a runtime debug/config flag so that it only runs
in development; ensure no direct stdout prints remain in production code.

Comment on lines +858 to +869
                print(
                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}"
                )

                # context
                accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]

                print(
                    f"In sample_and_accept_draft_tokens, accepted_tokens.shape: {accepted_tokens.shape}, num_contexts: {num_contexts}"
                )
                print(
                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}, num_gens: {num_gens}, mtp_num_modules: {mtp_num_modules}"
                )

⚠️ Potential issue

Remove debug print statements from production code.

Multiple debug print statements should be removed or properly guarded.

-                print(
-                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}"
-                )
+                # Uncomment for debugging:
+                # print(f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}")

                 # context
                 accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
-                print(
-                    f"In sample_and_accept_draft_tokens, accepted_tokens.shape: {accepted_tokens.shape}, num_contexts: {num_contexts}"
-                )
-                print(
-                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}, num_gens: {num_gens}, mtp_num_modules: {mtp_num_modules}"
-                )
+                # Uncomment for debugging:
+                # print(f"In sample_and_accept_draft_tokens, accepted_tokens.shape: {accepted_tokens.shape}, num_contexts: {num_contexts}")
+                # print(f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}, num_gens: {num_gens}, mtp_num_modules: {mtp_num_modules}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
                print(
                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}"
                )
                # context
                accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
                print(
                    f"In sample_and_accept_draft_tokens, accepted_tokens.shape: {accepted_tokens.shape}, num_contexts: {num_contexts}"
                )
                print(
                    f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}, num_gens: {num_gens}, mtp_num_modules: {mtp_num_modules}"
                )
                # Uncomment for debugging:
                # print(f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}")
                # context
                accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
                # Uncomment for debugging:
                # print(f"In sample_and_accept_draft_tokens, accepted_tokens.shape: {accepted_tokens.shape}, num_contexts: {num_contexts}")
                # print(f"In sample_and_accept_draft_tokens, target_tokens.shape: {target_tokens.shape}, num_gens: {num_gens}, mtp_num_modules: {mtp_num_modules}")
🧰 Tools
🪛 Ruff (0.12.2)

865-865: Line too long (134 > 120)

(E501)


868-868: Line too long (158 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/mtp.py around lines 858 to 869, there are
multiple raw print() debug statements; remove them or replace them with a proper
logger.debug() (or guard them behind an existing debug flag) so production code
does not emit uncontrolled stdout. Specifically, delete the three print(...)
calls or change them to process-wide logger.debug(...) with contextual messages
including variable values, or wrap them in if self.debug / if debug_mode: ... so
they only run in debug mode; ensure any used logger is already configured and
imported.

@@ -1074,6 +1097,7 @@ def get_draft_tokens_from_gathered(self, gathered):
     def draft_sampler(
         self,
         logits: torch.Tensor,
+        iter: int,

💡 Verification agent

🧩 Analysis chain

Breaking change: draft_sampler now requires iter parameter.

The addition of the iter parameter changes the public API of draft_sampler. Ensure all callers are updated accordingly.


🏁 Script executed:

#!/bin/bash
# Search for all calls to draft_sampler to ensure they pass the iter parameter
rg -n "draft_sampler\(" --type py

Length of output: 306


Include the new iter argument in all draft_sampler calls
tensorrt_llm/_torch/speculative/mtp.py:492 — update the call from

self.draft_sampler(logits)

to

self.draft_sampler(logits, i)
🤖 Prompt for AI Agents
tensorrt_llm/_torch/speculative/mtp.py around lines 492 and anywhere else where
draft_sampler is invoked: the new function signature includes a second parameter
iter: int, so update all calls like self.draft_sampler(logits) to pass the
current loop/index variable (e.g., self.draft_sampler(logits, i) or
self.draft_sampler(logits, iter) depending on your local variable name). Ensure
the variable you pass exists in scope and matches the intended iteration
counter, and update any tests/call sites accordingly.

Comment on lines +1122 to +1143
        elif (self.model_config is not None
              and hasattr(self.model_config, 'mapping')
              and self.model_config.mapping.tp_size > 1) and (
                  self.model_config.mapping.enable_attention_dp and getattr(
                      self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
            # For ADP + LM TP mode, we need to find the global argmax across all TP ranks
            # First, get local argmax and max values
            # print(f"In draft_sampler, initial logits.shape: {logits.shape}")
            # combined = self.get_local_max_and_combined(logits)
            gathered = allgather(logits, self.model_config.mapping, dim=-1)
            batch_size = logits.shape[0]
            local_batch_size = batch_size // self.model_config.mapping.tp_size
            gathered = gathered.view(self.model_config.mapping.tp_size,
                                     local_batch_size, -1)
            sliced_gathered = gathered[self.model_config.mapping.tp_rank]
            # print(f"In draft_sampler, gathered.shape: {gathered.shape}")
            print(
                f"In draft_sampler, iter: {iter}, rank: {self.model_config.mapping.tp_rank}, sliced_gathered.shape: {sliced_gathered.shape}"
            )
            # draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
            draft_tokens = torch.argmax(sliced_gathered,
                                        dim=-1).type(torch.int32)

⚠️ Potential issue

Remove debug print statement.

Line 1139 contains a debug print that should be removed.

-            print(
-                f"In draft_sampler, iter: {iter}, rank: {self.model_config.mapping.tp_rank}, sliced_gathered.shape: {sliced_gathered.shape}"
-            )
+            # Uncomment for debugging:
+            # print(f"In draft_sampler, iter: {iter}, rank: {self.model_config.mapping.tp_rank}, sliced_gathered.shape: {sliced_gathered.shape}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        elif (self.model_config is not None
              and hasattr(self.model_config, 'mapping')
              and self.model_config.mapping.tp_size > 1) and (
                  self.model_config.mapping.enable_attention_dp and getattr(
                      self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
            # For ADP + LM TP mode, we need to find the global argmax across all TP ranks
            # First, get local argmax and max values
            # print(f"In draft_sampler, initial logits.shape: {logits.shape}")
            # combined = self.get_local_max_and_combined(logits)
            gathered = allgather(logits, self.model_config.mapping, dim=-1)
            batch_size = logits.shape[0]
            local_batch_size = batch_size // self.model_config.mapping.tp_size
            gathered = gathered.view(self.model_config.mapping.tp_size,
                                     local_batch_size, -1)
            sliced_gathered = gathered[self.model_config.mapping.tp_rank]
            # print(f"In draft_sampler, gathered.shape: {gathered.shape}")
            print(
                f"In draft_sampler, iter: {iter}, rank: {self.model_config.mapping.tp_rank}, sliced_gathered.shape: {sliced_gathered.shape}"
            )
            # draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
            draft_tokens = torch.argmax(sliced_gathered,
                                        dim=-1).type(torch.int32)
        elif (self.model_config is not None
              and hasattr(self.model_config, 'mapping')
              and self.model_config.mapping.tp_size > 1) and (
                  self.model_config.mapping.enable_attention_dp and getattr(
                      self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
            # For ADP + LM TP mode, we need to find the global argmax across all TP ranks
            # First, get local argmax and max values
            # print(f"In draft_sampler, initial logits.shape: {logits.shape}")
            # combined = self.get_local_max_and_combined(logits)
            gathered = allgather(logits, self.model_config.mapping, dim=-1)
            batch_size = logits.shape[0]
            local_batch_size = batch_size // self.model_config.mapping.tp_size
            gathered = gathered.view(self.model_config.mapping.tp_size,
                                     local_batch_size, -1)
            sliced_gathered = gathered[self.model_config.mapping.tp_rank]
            # print(f"In draft_sampler, gathered.shape: {gathered.shape}")
            # Uncomment for debugging:
            # print(f"In draft_sampler, iter: {iter}, rank: {self.model_config.mapping.tp_rank}, sliced_gathered.shape: {sliced_gathered.shape}")
            # draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
            draft_tokens = torch.argmax(sliced_gathered,
                                        dim=-1).type(torch.int32)
🧰 Tools
🪛 Ruff (0.12.2)

1139-1139: Line too long (140 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/mtp.py around lines 1122 to 1143, remove the
debug print statement at line 1139 (the print starting with "In draft_sampler,
iter: {iter}...") — delete that print call and, if needed for future debugging,
replace it with a proper logger.debug call using the module/logger used
elsewhere; otherwise simply remove the line so no raw print remains.

Comment on lines +145 to 146
                 enable_lm_tp_in_adp=False):
        # set default values for non-moe cases

⚠️ Potential issue

Include new flag in equality/hash/serialization; otherwise caches/config round-trips break.

enable_lm_tp_in_adp is stored but omitted from eq, hash, and to_dict. Two Mapping instances with different values will compare equal and serialize without the flag.

Apply:

@@
     def __eq__(self, other):
@@
-        return (self.world_size == other.world_size and self.rank == other.rank
+        return (self.world_size == other.world_size and self.rank == other.rank
                 and self.gpus_per_node == other.gpus_per_node
                 and self.cp_size == other.cp_size
                 and self.tp_size == other.tp_size
                 and self.moe_cluster_size == other.moe_cluster_size
                 and self.pp_size == other.pp_size
                 and self.moe_tp_size == other.moe_tp_size
                 and self.moe_ep_size == other.moe_ep_size
                 and self.attn_tp_size == other.attn_tp_size
                 and self.attn_cp_size == other.attn_cp_size
                 and self.cp_config == other.cp_config
-                and self.auto_parallel == other.auto_parallel)
+                and self.auto_parallel == other.auto_parallel
+                and self.enable_attention_dp == other.enable_attention_dp
+                and self.enable_lm_tp_in_adp == other.enable_lm_tp_in_adp)
@@
     def __hash__(self):
         return hash((
             self.world_size,
             self.rank,
             self.gpus_per_node,
             self.cp_size,
             self.tp_size,
             self.pp_size,
             self.moe_tp_size,
             self.moe_cluster_size,
             self.moe_ep_size,
             self.attn_tp_size,
             self.attn_cp_size,
             # note: we do not allow updating cp_config after initialization
             tuple(sorted(self.cp_config.items())),
             self.auto_parallel,
+            self.enable_attention_dp,
+            self.enable_lm_tp_in_adp,
         ))
@@
     def to_dict(self):
         return {
             'world_size': self.world_size,
             'rank': self.rank,
             'gpus_per_node': self.gpus_per_node,
             'cp_size': self.cp_size,
             'tp_size': self.tp_size,
             'pp_size': self.pp_size,
             'moe_tp_size': self.moe_tp_size,
             'moe_cluster_size': self.moe_cluster_size,
             'moe_ep_size': self.moe_ep_size,
             'attn_tp_size': self.attn_tp_size,
             'attn_cp_size': self.attn_cp_size,
             'cp_config': self.cp_config,
             'auto_parallel': self.auto_parallel,
+            'enable_attention_dp': self.enable_attention_dp,
+            'enable_lm_tp_in_adp': self.enable_lm_tp_in_adp,
         }

Also applies to: 227-229
