Conversation

lucaslie
Member

@lucaslie lucaslie commented Aug 29, 2025

This reverts commit 2101d46.

This is due to a regression in BMM sharding.

Summary by CodeRabbit

  • New Features

    • Introduces granular sharding detectors: column/row sharding, expert-parallel (EP) detection, and DP BMM sharding.
    • Adopts a boundary-driven column–row strategy for more accurate GEMM sharding; improved MOE/DP BMM handling.
  • Breaking Changes

    • Renames sharding detector: detect_sharding → detect_column_row_shard; adds detect_ep_shard and detect_dp_bmm_shard.
    • Removes config options: use_sharding_from_factory and sharding_dims; deprecates factory-based sharding configs.
    • Inverts split-dimension semantics (ROW/COLUMN) and switches simple sharding to row-wise.
    • Simplifies sharding config to explicit TP/BMM/EP transform lists.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
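
For example, a run that disables artifact reuse, turns off fail-fast, and limits testing to a single GPU type might look like this (the flag values are illustrative, not a recommended configuration):

/bot run --disable-reuse-test --disable-fail-fast --gpu-type "A30"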

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.
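
For example (the comment text is illustrative): /bot skip --comment "Docs-only change; previous pipeline results still apply"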

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@lucaslie lucaslie self-assigned this Aug 29, 2025
@lucaslie lucaslie requested a review from a team as a code owner August 29, 2025 03:19
@lucaslie lucaslie requested a review from nvchenghaoz August 29, 2025 03:19
@lucaslie lucaslie added the bug Something isn't working label Aug 29, 2025
Contributor

coderabbitai bot commented Aug 29, 2025

📝 Walkthrough

Splits the monolithic sharding detector into three transforms (column/row TP, DP BMM, EP), removes factory-driven sharding config and related fields/APIs, updates utils to new split-dimension semantics, and adjusts tests and default YAML to the new transform-based workflow.

Changes

Cohort / File(s) Summary
Config: transform entries
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Replaces detect_sharding with detect_column_row_shard, adds detect_ep_shard and detect_dp_bmm_shard; removes use_sharding_from_factory and sharding_dims options.
CLI/Args: config fields
tensorrt_llm/_torch/auto_deploy/llm_args.py
Removes AutoDeployConfig fields use_sharding_from_factory and sharding_dims.
Model factory: sharding config removal
tensorrt_llm/_torch/auto_deploy/models/factory.py
Deletes ShardingConfigSource enum and ModelFactory.get_sharding_config; removes _sharding_config init and Enum import.
HF models: sharding setup removal
tensorrt_llm/_torch/auto_deploy/models/hf.py
Drops ShardingConfigSource import/usage; removes _set_sharding_config methods and invocations in CausalLM and ImageTextToText factories.
Transforms: sharding overhaul
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Renames Sharding→ColumnRowShard with new ColumnRowShardConfig; adds DpBmmShard and DetectEpShard classes; switches simple shard to ROW split; removes factory/predefined config paths and related imports.
Utils: node filtering API
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Simplifies filtered_nodes to filter only by ops; removes target parameter and related logic.
Utils: sharding config and semantics
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Inverts SplitDimension (ROW=0, COLUMN=1); simplifies ShardingConfig to only tp/bmm/ep transforms; removes validation/predefined-config helpers; updates TPShardingInfo.validate; adds model_config to ShardingTransformInfo.
Tests: detector and expectations
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py, .../test_ep_sharding.py, .../test_tp_sharding.py
Switches configs to detect_dp_bmm_shard/detect_ep_shard/detect_column_row_shard; removes factory-based flags/predefined plans; updates split-dimension expectations and function signatures; adjusts EP gate sizing logic.
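
For reference, a minimal sketch of the simplified, transform-list-based sharding config described above (field names are taken from the reviewed diff; the stub info classes below are placeholders for illustration only, not the real definitions):

from typing import List

from pydantic import BaseModel, Field


class TPShardingInfo(BaseModel):  # stub for illustration only
    target_node: str = ""


class BMMShardingInfo(BaseModel):  # stub for illustration only
    target_node: str = ""


class EPShardingInfo(BaseModel):  # stub for illustration only
    target_node: str = ""


class ShardingConfig(BaseModel):
    """Explicit per-family transform lists (TP / BMM / EP), as in the reviewed diff."""

    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
    ep_transforms: List[EPShardingInfo] = Field(default_factory=list)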

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant InferenceOptimizer
  participant ShardingExecutor
  participant ColumnRowShard as detect_column_row_shard
  participant DpBmmShard as detect_dp_bmm_shard
  participant DetectEpShard as detect_ep_shard

  User->>InferenceOptimizer: configure(default.yaml)
  InferenceOptimizer->>ShardingExecutor: run(sharding stage)
  ShardingExecutor->>ColumnRowShard: apply(gm)
  ColumnRowShard-->>ShardingExecutor: tp_transforms
  ShardingExecutor->>DpBmmShard: apply(gm)
  DpBmmShard-->>ShardingExecutor: bmm_transforms
  ShardingExecutor->>DetectEpShard: apply(gm)
  DetectEpShard-->>ShardingExecutor: ep_transforms
  ShardingExecutor-->>InferenceOptimizer: sharding_config{tp,bmm,ep}
  InferenceOptimizer-->>User: compiled plan

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

AutoDeploy

Suggested reviewers

  • nvchenghaoz
  • nzmora-nvidia
  • Fridah-nv

@lucaslie lucaslie changed the title from Revert "[TRTLLM-6342][feat] TP Sharding read from the model config (#6972)" to [None][fix] Revert TP Sharding read from the model config (#6972) Aug 29, 2025
@lucaslie
Member Author

/bot run

@lucaslie lucaslie requested review from greg-kwasniewski1 and suyoggupta and removed request for nvchenghaoz August 29, 2025 03:22
@lucaslie lucaslie moved this from Backlog to In review in AutoDeploy Board Aug 29, 2025
@tensorrt-cicd
Collaborator

PR_Github #16938 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (11)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

1-1: Add NVIDIA 2025 Apache-2.0 header.

All Python sources must prepend the NVIDIA 2025 Apache-2.0 header. Please add it above the module docstring.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

271-299: Guard against missing meta for BMM validate.

Detection relies on node.meta['val']; export/shape-prop may not always populate it. Fail fast with a clear message.

Apply:

-        lhs_batch_size = lhs_tensor.meta["val"].shape[0]
-        rhs_batch_size = rhs_tensor.meta["val"].shape[0]
+        try:
+            lhs_batch_size = lhs_tensor.meta["val"].shape[0]
+            rhs_batch_size = rhs_tensor.meta["val"].shape[0]
+        except Exception:
+            ad_logger.warning(
+                "Missing/invalid meta on BMM inputs; shape inference is required before BMM sharding. Skipping %s.",
+                self,
+            )
+            return False

327-329: Don’t set requires_grad=True on sharded params in inference pipeline.

This can increase memory and enable autograd unexpectedly.

Apply:

-                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
+                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=False)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)

1-1: Add NVIDIA 2025 Apache-2.0 header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)

1-1: Add NVIDIA 2025 Apache-2.0 header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

1-1: Add NVIDIA 2025 Apache-2.0 header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

1-1: Add required NVIDIA 2025 Apache-2.0 header at file top.

The coding guidelines require the NVIDIA 2025 Apache-2.0 header on all Python files.

Apply this diff at the very top:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

151-157: Revert detector key to pre-#6972 name “detect_sharding”

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py: change both "detect_column_row_shard" entries (lines 151 & 274) back to "detect_sharding"
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py: update the @TransformRegistry.register key at line 124
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml: change the detect_column_row_shard key at line 55 to detect_sharding
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

1-1: Add required NVIDIA 2025 Apache-2.0 header at file top.

The coding guidelines require the NVIDIA 2025 Apache-2.0 header on all Python files.

Apply this diff at the very top:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

324-341: DpBmmShard also needs shape propagation enabled.

It reads meta shapes; provide a config that requires shape prop.

 @TransformRegistry.register("detect_dp_bmm_shard")
 class DpBmmShard(BaseTransform):
+    class Config(TransformConfig):
+        requires_shape_prop: bool = Field(default=True)
+        run_shape_prop: bool = Field(default=True)
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return DpBmmShard.Config

124-148: Restore legacy detect_sharding transform registration
Registration of "detect_column_row_shard" in tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py contradicts this PR’s intent to revert #6972—remove it and re-add the original "detect_sharding" registration and its executor wiring.

🧹 Nitpick comments (8)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

12-12: Conform to “module namespace” import guideline for pydantic.

Avoid symbol imports; use the module namespace.

Apply:

-from pydantic import BaseModel, ConfigDict, Field
+import pydantic as pyd
-class ShardingTransformInfo(BaseModel, ABC):
+class ShardingTransformInfo(pyd.BaseModel, ABC):
@@
-    model_config = ConfigDict(frozen=True)  # Makes the model immutable and hashable
+    model_config = pyd.ConfigDict(frozen=True)  # Makes the model immutable and hashable
-class ShardingConfig(BaseModel):
+class ShardingConfig(pyd.BaseModel):
@@
-    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
-    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
-    ep_transforms: List[EPShardingInfo] = Field(default_factory=list)
+    tp_transforms: List[TPShardingInfo] = pyd.Field(default_factory=list)
+    bmm_transforms: List[BMMShardingInfo] = pyd.Field(default_factory=list)
+    ep_transforms: List[EPShardingInfo] = pyd.Field(default_factory=list)

Also applies to: 189-193, 477-483

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)

76-89: Optional: use filtered_nodes for clarity.

Iterate BMM nodes via filtered_nodes to avoid manual any(is_op(...)).

Apply:

-    run_test_transformed_gm(
+    run_test_transformed_gm(
         model,
         x,
         gm_transformed,
-        check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
+        check_transformed_graph=lambda gm: any(True for _ in filtered_nodes(gm.graph.nodes, op_expected))
         == (world_size > 1),
         _get_expected_num_params=_get_expected_num_params,
     )
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

32-32: Prefer module-namespace imports per guidelines.

Avoid importing classes/functions directly; import the module and reference symbols via the module.

Apply this minimal refactor:

-from .factory import ModelFactory, ModelFactoryRegistry
+from . import factory

And adjust usages:

-@ModelFactoryRegistry.register("AutoModelForCausalLM")
-class AutoModelForCausalLMFactory(ModelFactory):
+@factory.ModelFactoryRegistry.register("AutoModelForCausalLM")
+class AutoModelForCausalLMFactory(factory.ModelFactory):
@@
-@ModelFactoryRegistry.register("AutoModelForImageTextToText")
-class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
+@factory.ModelFactoryRegistry.register("AutoModelForImageTextToText")
+class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):

Also applies to: 64-65, 64-66, 333-345

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

316-324: Pattern-detection test ignores dist_op_expected.

Signature lists dist_op_expected but it’s unused. Either remove it from the parametrization/signature or assert on the detected ops for completeness.

Apply:

-def test_sharding_pattern_detection(
-    model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int
-):
+def test_sharding_pattern_detection(
+    model_cls: Type[nn.Module], bias: bool, world_size: int
+):
@@
-    _run_pattern_detection_job(model_cls, bias, 0, world_size)
+    _run_pattern_detection_job(model_cls, bias, 0, world_size)

And drop dist_op_expected from the parametrize tuple accordingly.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)

27-31: Prefer module-namespace imports per guidelines.

Import modules rather than individual symbols; adjust call sites accordingly.

Example minimal change:

-from ...models.factory import ModelFactory
+from ...models import factory as model_factory
@@
-from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op
+from ...utils import node_utils

And adapt type hints/usages:

-        factory: ModelFactory,
+        factory: model_factory.ModelFactory,
@@
-        boundary_nodes = identify_regions_between_residuals(gm)
+        boundary_nodes = node_utils.identify_regions_between_residuals(gm)

Similarly replace is_linear_op/is_op with node_utils.is_linear_op/node_utils.is_op where used in this file.


166-181: Boundary/allowlist sets OK; consider adding aten.addmm if present in exported graphs.

If aten.addmm appears between linear groups, add it to allowed/shardable sets to prevent unnecessary fallback.


368-387: Uneven-batch code path is skipped; simplify start/end computation.

Since you early-continue when remainder != 0, the remainder-based index logic is dead code. Simplify to avoid confusion.

-            # Calculate start and end indices for this rank
-            if local_rank < remainder:
-                start_idx = local_rank * (base_size + 1)
-                end_idx = start_idx + base_size + 1
-            else:
-                start_idx = remainder + local_rank * base_size
-                end_idx = start_idx + base_size
+            # Calculate start and end indices for this rank (remainder == 0 here)
+            start_idx = local_rank * base_size
+            end_idx = start_idx + base_size

48-92: Executor applies transforms by name; add a guard for duplicate targets if needed.

Optional: de-duplicate target nodes before application to avoid repeated application when multiple detectors emit the same target.

-        for tp_transform in shared_config.sharding_config.tp_transforms:
+        seen = set()
+        for tp_transform in shared_config.sharding_config.tp_transforms:
+            if tp_transform.target_node in seen:
+                continue
+            seen.add(tp_transform.target_node)
             if check_and_apply(tp_transform):
                 num_matches += 1
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between ce580ce and 30d5f1a.

📒 Files selected for processing (10)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/llm_args.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/factory.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/llm_args.py
  • tensorrt_llm/_torch/auto_deploy/models/factory.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cc,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cc,cxx,cu,cuh,py}: If a constructor parameter name conflicts with a public member, add a trailing underscore to the parameter (e.g., foo_).
Use uppercase literal suffixes (e.g., 1234L not 1234l).
Use spaces, not tabs; indent by 4 spaces.

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Target Python 3.8+ for all Python code.
Indent with 4 spaces; do not use tabs.
Maintain module namespace on imports; import the module/submodule, not individual classes/functions (e.g., from package.subpackage import foo; foo.SomeClass()).
Python filenames use snake_case (e.g., some_file.py).
Class names use PascalCase.
Function and method names use snake_case.
Local variable names use snake_case; if starting with a number, prefix with k_ (e.g., k_99th_percentile).
Global variables use UPPER_SNAKE_CASE and prefix G_ (e.g., G_MY_GLOBAL).
Constants use UPPER_SNAKE_CASE.
Avoid shadowing outer-scope variables.
Initialize all externally visible members of a class in init.
Prefer docstrings for interfaces used outside a file; use comments for local code within functions or local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline with docstrings placed under the definition.
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns).
Limit except clauses to specific exceptions; avoid bare except.
For duck-typing try/except, keep try blocks minimal and use else for the main logic.

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.{cpp,cc,cxx,cu,h,hpp,hh,hxx,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA 2025 Apache-2.0 copyright header block at the top of all source files (.cpp, .h, .cu, .py).

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  • ModelFactory (15-207)
  • ModelFactoryRegistry (210-228)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6)
tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  • ModelFactory (15-207)
  • register (214-219)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • identify_regions_between_residuals (292-345)
  • is_linear_op (240-252)
  • is_op (183-206)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)
  • SplitDimension (182-186)
  • TPShardingInfo (225-260)
  • BMMShardingInfo (263-362)
  • EPShardingInfo (452-474)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (9)
  • TransformConfig (60-99)
  • TransformRegistry (381-409)
  • register (387-394)
  • BaseTransform (139-378)
  • get_config_class (161-166)
  • get_config_class (402-404)
  • TransformInfo (108-133)
  • _apply (368-378)
  • SharedConfig (51-57)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • SplitDimension (182-186)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2)
  • test_sharding (143-147)
  • _run_job (52-89)
tensorrt_llm/_torch/auto_deploy/distributed/common.py (1)
  • spawn_multiprocess_job (240-244)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (12)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

235-247: TP dist_op vs split_dim checks look correct.

Row→all_gather, Column→all_reduce validations align with dist_lookup. No issues.
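
For context, a minimal sketch of the mapping being approved here (assuming the ROW=0/COLUMN=1 enum from this PR; names are illustrative):

from enum import IntEnum


class SplitDimension(IntEnum):
    """Mirrors the reviewed enum; for illustration only."""

    ROW = 0
    COLUMN = 1


def expected_dist_op(split_dim: SplitDimension) -> str:
    """Row-split outputs are gathered; column-split partial results are all-reduced."""
    return "all_gather" if split_dim is SplitDimension.ROW else "all_reduce"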

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)

67-73: Generated.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

55-62: Verify revert should restore single detect_sharding
rg returned no detect_sharding entries; please confirm that the revert’s intent is to reinstate the original single detect_sharding transform instead of the three granular detectors.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (2)

31-36: Gate param expectation change LGTM.

Counting gate weight+bias as num_experts*(hidden_size+1) matches the fact EP sharding doesn’t shard the gate.
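
For example, with a hypothetical hidden_size of 32 and 8 experts, the unsharded gate contributes 8 * (32 + 1) = 264 parameters on every rank.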


42-49: Config uses detect_ep_shard; PR claims a revert.

Same concern as BMM test: this doesn’t reflect a revert. Please confirm desired behavior and update either the code or PR description.

Also applies to: 99-103

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

209-237: Confirm external usage of filtered_nodes before signature refactor

  • Internal grep found no call sites beyond its definition and doc example.
  • If you need to preserve the old target parameter for external consumers, apply the suggested signature change; otherwise drop legacy target support.
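
If preserving compatibility turned out to be necessary, a thin wrapper is one option (a sketch only; the exact filtered_nodes signature and wrapper name are assumptions):

from tensorrt_llm._torch.auto_deploy.utils.node_utils import filtered_nodes


def filtered_nodes_with_target(nodes, ops, target=None):
    """Hypothetical shim: delegate to the ops-only filter, then apply the legacy target check."""
    for node in filtered_nodes(nodes, ops):
        if target is None or node.target == target:
            yield node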
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (3)

274-279: Config still references detect_column_row_shard for pattern detection.

Align the configured transform name with the pre-revert detector (if revert stands).

Use:

-            "detect_column_row_shard": {
+            "detect_sharding": {
                 "stage": "sharding",
             },

298-302: Spawn job wiring looks fine; ensure device_count≥2 coverage exists in CI.

No functional issue; just confirm CI actually exercises multi-GPU cases when device_count > 1.


219-224: The imports at the top of this test file will reveal whether it’s using the post-#6972 API (new SplitDimension/TPShardingInfo module path) or the legacy one—once we have that, we can confirm if any import adjustments are needed.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

94-116: Simple shard path looks consistent (ROW + all_gather).

The semantics align with a conservative simple-shard. No changes requested.


200-211: Fallbacks correctly route to simple-shard; log messages are helpful.

No code changes requested.

Also applies to: 243-249, 275-281


411-453: EP detector looks fine.

Detection and transform accumulation are straightforward.

Comment on lines +118 to 123
class ColumnRowShardConfig(TransformConfig):
"""Configuration for column-row sharding."""

simple_shard_only: bool = Field(default=False)
use_sharding_from_factory: bool = Field(default=False)
# Which sharding families to run: any subset of {"tp", "ep", "bmm"}
sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])


Contributor

🛠️ Refactor suggestion

Require shape propagation: config must enable shapes for meta access.

ColumnRowShard relies on node.meta["val"].shape; set requires_shape_prop/run_shape_prop to True.

 class ColumnRowShardConfig(TransformConfig):
     """Configuration for column-row sharding."""
-    simple_shard_only: bool = Field(default=False)
+    simple_shard_only: bool = Field(default=False)
+    requires_shape_prop: bool = Field(default=True)
+    run_shape_prop: bool = Field(default=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-class ColumnRowShardConfig(TransformConfig):
-    """Configuration for column-row sharding."""
-    simple_shard_only: bool = Field(default=False)
-    use_sharding_from_factory: bool = Field(default=False)
-    # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
-    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])
+class ColumnRowShardConfig(TransformConfig):
+    """Configuration for column-row sharding."""
+    simple_shard_only: bool = Field(default=False)
+    requires_shape_prop: bool = Field(default=True)
+    run_shape_prop: bool = Field(default=True)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py around lines
118 to 123, the ColumnRowShardConfig class needs to enable shape propagation
because ColumnRowShard uses node.meta["val"].shape; add boolean fields to turn
on shape propagation (e.g., requires_shape_prop: bool = True and run_shape_prop:
bool = True) or set the existing config flags that control shape propagation to
True by default so that meta shapes are available during transforms.

Comment on lines +286 to +301
            # check if we are sharding the attention block
            if attention_nodes:
                if len(attention_nodes) > 1:
                    # Column-row shard boundary region detection is probably wrong - there should be
                    # only one attention operation. Fall back to simple shard.
                    ad_logger.debug(f"More than one attention node: {unaccounted_nodes}")
                    _append_simple_shard(
                        nodes_linear, local_rank, world_size, shared_config.sharding_config
                    )
-                elif config == "rowwise":
-                    sharding_config.tp_transforms.append(
+                    continue
                # Extract head dimension. We cannot shard below the head_dim size.
                # Assume that head_dim is the last (innermost) dimension of the tensor
                min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
            else:
                min_local_shape = 1
            for i, group in enumerate(nodes_linear.values()):
Contributor

🛠️ Refactor suggestion

Use explicit enum for split_dim and avoid mutating the attention set.

  • Don’t pop from attention_nodes; read without mutation.
  • Pass SplitDimension enums instead of raw ints for clarity.
-                min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
+                min_local_shape = next(iter(attention_nodes)).meta["val"].shape[-1]
@@
-                            split_dim=i,
+                            split_dim=(SplitDimension.COLUMN if i > 0 else SplitDimension.ROW),

Also applies to: 307-316

Comment on lines 182 to 187
 class SplitDimension(IntEnum):
     """Enum for tensor split dimensions in sharding."""
 
-    # NOTE: The names COLUMN/ROW reflect the hugging face
-    # base_tp_plan sharding notation, but since we assume Y = W @ X^T,
-    # when splitting weight matrix W^T across columns, the actual split
-    # is over dimension 0
-    COLUMN = 0
-    ROW = 1
+    ROW = 0  # Split along rows (first dimension)
+    COLUMN = 1  # Split along columns (second dimension)

Contributor

💡 Verification agent

🧩 Analysis chain

Verify SplitDimension ROW/COLUMN semantics across modules.

This file defines ROW=0, COLUMN=1, but tensorrt_llm/_torch/modules/linear.py::split_dim(cls, mode) returns 1 when mode==ROW (see provided snippet), which may invert expectations when mixing .value vs helper APIs.

Run to audit usages:

If mismatches exist, consider mapping dim via a single helper to avoid enum.value leaks.


🏁 Script executed:

#!/bin/bash
# Find SplitDimension definitions/usages and .value propagation
rg -nP -C2 '\bclass\s+SplitDimension\b|SplitDimension\.(ROW|COLUMN)|\.\s*value\b'
# Find call sites deriving dims for linear/matmul
rg -nP -C3 '\bsplit_dim\s*\(|_insert_sharded_matmul\(|torch\.ops\..*dist_all_(gather|reduce)'

Length of output: 266522


🏁 Script executed:

#!/bin/bash
# Show split_dim definition and implementation in linear.py
rg -nP -C5 "def split_dim" -n tensorrt_llm/_torch/modules/linear.py

Length of output: 425


Unify SplitDimension mapping with TensorParallelMode.split_dim

  • sharding_utils.py defines SplitDimension.ROW=0, COLUMN=1 while TensorParallelMode.split_dim maps ROW→1 and COLUMN→0 (linear.py:48–50)
  • Choose one source of truth (e.g. update SplitDimension values or use a central helper) to ensure .value and split_dim produce the same dimension ordering across modules
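
As an illustration of the "central helper" option, one possible shape (module location and helper name are assumptions, not the existing API):

from enum import IntEnum


class SplitDimension(IntEnum):
    """One shared enum for which logical axis of the weight is sharded."""

    ROW = 0
    COLUMN = 1


def weight_slice_dim(split: SplitDimension) -> int:
    """Return the index used to slice a [out_features, in_features] weight.

    Reproduces the TensorParallelMode mapping cited above (ROW -> 1, COLUMN -> 0) so both
    modules derive the dimension from one helper instead of reading .value directly.
    """
    return 0 if split is SplitDimension.COLUMN else 1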
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py around lines 182 to
187, SplitDimension defines ROW=0 and COLUMN=1 which conflicts with
TensorParallelMode.split_dim (linear.py lines ~48–50) that maps ROW→1 and
COLUMN→0; pick one source of truth and make them consistent: either flip the
integer values in SplitDimension (ROW=1, COLUMN=0) to match TensorParallelMode,
or refactor both modules to use a shared helper/enum (move SplitDimension into a
common utils module and import it in linear.py or make TensorParallelMode.lookup
consult SplitDimension.value) so that all code uses the same mapping for split
dimensions and update any tests/usages accordingly.

@tensorrt-cicd
Collaborator

PR_Github #16938 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12724 completed with status: 'FAILURE'

Labels
bug Something isn't working
Projects
Status: In review