[None][fix] Revert TP Sharding read from the model config (#6972) #7356
base: main
Conversation
…VIDIA#6972)" This reverts commit 2101d46. Signed-off-by: Lucas Liebenwein <[email protected]>
📝 Walkthrough
Splits the monolithic sharding detector into three transforms (column/row TP, DP BMM, EP), removes factory-driven sharding config and related fields/APIs, updates utils to new split-dimension semantics, and adjusts tests and default YAML to the new transform-based workflow. A minimal stand-in sketch of this registration/executor pattern follows the sequence diagram below.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant InferenceOptimizer
participant ShardingExecutor
participant ColumnRowShard as detect_column_row_shard
participant DpBmmShard as detect_dp_bmm_shard
participant DetectEpShard as detect_ep_shard
User->>InferenceOptimizer: configure(default.yaml)
InferenceOptimizer->>ShardingExecutor: run(sharding stage)
ShardingExecutor->>ColumnRowShard: apply(gm)
ColumnRowShard-->>ShardingExecutor: tp_transforms
ShardingExecutor->>DpBmmShard: apply(gm)
DpBmmShard-->>ShardingExecutor: bmm_transforms
ShardingExecutor->>DetectEpShard: apply(gm)
DetectEpShard-->>ShardingExecutor: ep_transforms
ShardingExecutor-->>InferenceOptimizer: sharding_config{tp,bmm,ep}
InferenceOptimizer-->>User: compiled plan
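To ground the walkthrough, here is a minimal, self-contained sketch of the registration-plus-executor pattern the diagram describes. The registry, config class, and detector bodies below are simplified stand-ins invented for illustration; only the three transform names are taken from this PR, and nothing here reflects the real TransformRegistry/BaseTransform interface.

```python
# Toy illustration of the detect-then-apply flow in the diagram above.
# All classes here are simplified stand-ins, not the TRT-LLM interface.
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ShardingConfig:
    """Accumulates the transforms each detector emits."""
    tp_transforms: List[str] = field(default_factory=list)
    bmm_transforms: List[str] = field(default_factory=list)
    ep_transforms: List[str] = field(default_factory=list)


class ToyTransformRegistry:
    _transforms: Dict[str, Callable[[ShardingConfig], None]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(fn: Callable[[ShardingConfig], None]):
            cls._transforms[name] = fn
            return fn
        return decorator

    @classmethod
    def get(cls, name: str) -> Callable[[ShardingConfig], None]:
        return cls._transforms[name]


@ToyTransformRegistry.register("detect_column_row_shard")
def detect_column_row_shard(cfg: ShardingConfig) -> None:
    cfg.tp_transforms.append("column/row TP split for linear groups")


@ToyTransformRegistry.register("detect_dp_bmm_shard")
def detect_dp_bmm_shard(cfg: ShardingConfig) -> None:
    cfg.bmm_transforms.append("split BMM batch dimension across ranks")


@ToyTransformRegistry.register("detect_ep_shard")
def detect_ep_shard(cfg: ShardingConfig) -> None:
    cfg.ep_transforms.append("distribute experts across ranks")


# The "executor" runs each detector in stage order and collects everything
# into one shared sharding config, mirroring the sequence diagram.
shared = ShardingConfig()
for name in ("detect_column_row_shard", "detect_dp_bmm_shard", "detect_ep_shard"):
    ToyTransformRegistry.get(name)(shared)
print(shared)
```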
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
✨ Finishing Touches
/bot run
PR_Github #16938 [ run ] triggered by Bot
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (11)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
1-1: Add NVIDIA 2025 Apache-2.0 header.
All Python sources must prepend the NVIDIA 2025 Apache-2.0 header. Please add it above the module docstring.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
271-299: Guard against missing meta for BMM validate.
Detection relies on node.meta['val']; export/shape-prop may not always populate it. Fail fast with a clear message.
Apply:
-        lhs_batch_size = lhs_tensor.meta["val"].shape[0]
-        rhs_batch_size = rhs_tensor.meta["val"].shape[0]
+        try:
+            lhs_batch_size = lhs_tensor.meta["val"].shape[0]
+            rhs_batch_size = rhs_tensor.meta["val"].shape[0]
+        except Exception:
+            ad_logger.warning(
+                "Missing/invalid meta on BMM inputs; shape inference is required before BMM sharding. Skipping %s.",
+                self,
+            )
+            return False
327-329: Don't set requires_grad=True on sharded params in inference pipeline.
This can increase memory and enable autograd unexpectedly.
Apply:
-    param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
+    param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=False)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)
1-1: Add NVIDIA 2025 Apache-2.0 header.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)
1-1: Add NVIDIA 2025 Apache-2.0 header.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
1-1: Add NVIDIA 2025 Apache-2.0 header.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
1-1: Add required NVIDIA 2025 Apache-2.0 header at file top.
The coding guidelines require the NVIDIA 2025 Apache-2.0 header on all Python files.
Apply this diff at the very top:
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
151-157: Revert detector key to pre-#6972 name "detect_sharding"
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py: change both "detect_column_row_shard" entries (lines 151 & 274) back to "detect_sharding"
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py: update the @TransformRegistry.register key at line 124
- tensorrt_llm/_torch/auto_deploy/config/default.yaml: change the "detect_column_row_shard" key at line 55 to "detect_sharding"
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
1-1: Add required NVIDIA 2025 Apache-2.0 header at file top.
The coding guidelines require the NVIDIA 2025 Apache-2.0 header on all Python files.
Apply this diff at the very top:
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
324-341: DpBmmShard also needs shape propagation enabled.
It reads meta shapes; provide a config that requires shape prop.
 @TransformRegistry.register("detect_dp_bmm_shard")
 class DpBmmShard(BaseTransform):
+    class Config(TransformConfig):
+        requires_shape_prop: bool = Field(default=True)
+        run_shape_prop: bool = Field(default=True)
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return DpBmmShard.Config
124-148: Restore legacy "detect_sharding" transform registration
Registration of "detect_column_row_shard" in tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py contradicts this PR's intent to revert #6972; remove it and re-add the original "detect_sharding" registration and its executor wiring.
🧹 Nitpick comments (8)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
12-12: Conform to "module namespace" import guideline for pydantic.
Avoid symbol imports; use the module namespace.
Apply:
-from pydantic import BaseModel, ConfigDict, Field
+import pydantic as pyd

-class ShardingTransformInfo(BaseModel, ABC):
+class ShardingTransformInfo(pyd.BaseModel, ABC):
@@
-    model_config = ConfigDict(frozen=True)  # Makes the model immutable and hashable
+    model_config = pyd.ConfigDict(frozen=True)  # Makes the model immutable and hashable

-class ShardingConfig(BaseModel):
+class ShardingConfig(pyd.BaseModel):
@@
-    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
-    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
-    ep_transforms: List[EPShardingInfo] = Field(default_factory=list)
+    tp_transforms: List[TPShardingInfo] = pyd.Field(default_factory=list)
+    bmm_transforms: List[BMMShardingInfo] = pyd.Field(default_factory=list)
+    ep_transforms: List[EPShardingInfo] = pyd.Field(default_factory=list)

Also applies to: 189-193, 477-483
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)
76-89: Optional: use filtered_nodes for clarity.
Iterate BMM nodes via filtered_nodes to avoid manual any(is_op(...)).
Apply:
-    run_test_transformed_gm(
+    run_test_transformed_gm(
         model,
         x,
         gm_transformed,
-        check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
+        check_transformed_graph=lambda gm: any(True for _ in filtered_nodes(gm.graph.nodes, op_expected))
         == (world_size > 1),
         _get_expected_num_params=_get_expected_num_params,
     )
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
32-32: Prefer module-namespace imports per guidelines.
Avoid importing classes/functions directly; import the module and reference symbols via the module.
Apply this minimal refactor:
-from .factory import ModelFactory, ModelFactoryRegistry
+from . import factory

And adjust usages:
-@ModelFactoryRegistry.register("AutoModelForCausalLM")
-class AutoModelForCausalLMFactory(ModelFactory):
+@factory.ModelFactoryRegistry.register("AutoModelForCausalLM")
+class AutoModelForCausalLMFactory(factory.ModelFactory):
@@
-@ModelFactoryRegistry.register("AutoModelForImageTextToText")
-class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
+@factory.ModelFactoryRegistry.register("AutoModelForImageTextToText")
+class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):

Also applies to: 64-65, 64-66, 333-345
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
316-324: Pattern-detection test ignores dist_op_expected.
Signature lists dist_op_expected but it's unused. Either remove it from the parametrization/signature or assert on the detected ops for completeness.
Apply:
-def test_sharding_pattern_detection(
-    model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int
-):
+def test_sharding_pattern_detection(
+    model_cls: Type[nn.Module], bias: bool, world_size: int
+):
@@
-    _run_pattern_detection_job(model_cls, bias, 0, world_size)
+    _run_pattern_detection_job(model_cls, bias, 0, world_size)

And drop dist_op_expected from the parametrize tuple accordingly.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)
27-31: Prefer module-namespace imports per guidelines.
Import modules rather than individual symbols; adjust call sites accordingly.
Example minimal change:
-from ...models.factory import ModelFactory
+from ...models import factory as model_factory
@@
-from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op
+from ...utils import node_utils

And adapt type hints/usages:
-    factory: ModelFactory,
+    factory: model_factory.ModelFactory,
@@
-    boundary_nodes = identify_regions_between_residuals(gm)
+    boundary_nodes = node_utils.identify_regions_between_residuals(gm)

Similarly replace is_linear_op/is_op with node_utils.is_linear_op/node_utils.is_op where used in this file.
166-181: Boundary/allowlist sets OK; consider adding aten.addmm if present in exported graphs.
If aten.addmm appears between linear groups, add it to allowed/shardable sets to prevent unnecessary fallback.
368-387: Uneven-batch code path is skipped; simplify start/end computation.
Since you early-continue when remainder != 0, the remainder-based index logic is dead code. Simplify to avoid confusion.
-    # Calculate start and end indices for this rank
-    if local_rank < remainder:
-        start_idx = local_rank * (base_size + 1)
-        end_idx = start_idx + base_size + 1
-    else:
-        start_idx = remainder + local_rank * base_size
-        end_idx = start_idx + base_size
+    # Calculate start and end indices for this rank (remainder == 0 here)
+    start_idx = local_rank * base_size
+    end_idx = start_idx + base_size
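As a sanity check of the simplified indexing, here is a standalone snippet (hypothetical helper name and numbers, not code from the PR) showing the even-split slices each rank receives once uneven batches are skipped:

```python
# Even-split batch slicing for the case the transform actually handles
# (batch_size % world_size == 0); illustrative helper only.
def batch_slice(batch_size: int, world_size: int, local_rank: int) -> range:
    assert batch_size % world_size == 0, "uneven batches are skipped upstream"
    base_size = batch_size // world_size
    start_idx = local_rank * base_size
    end_idx = start_idx + base_size
    return range(start_idx, end_idx)


# batch_size=8, world_size=4 -> rank 0 gets [0, 2), rank 1 gets [2, 4), ...
for rank in range(4):
    print(rank, list(batch_slice(8, 4, rank)))
```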
48-92: Executor applies transforms by name; add a guard for duplicate targets if needed.
Optional: de-duplicate target nodes before application to avoid repeated application when multiple detectors emit the same target.
-    for tp_transform in shared_config.sharding_config.tp_transforms:
+    seen = set()
+    for tp_transform in shared_config.sharding_config.tp_transforms:
+        if tp_transform.target_node in seen:
+            continue
+        seen.add(tp_transform.target_node)
         if check_and_apply(tp_transform):
             num_matches += 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (10)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/llm_args.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (0 hunks)
- tensorrt_llm/_torch/auto_deploy/models/hf.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7 hunks)
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/auto_deploy/llm_args.py
- tensorrt_llm/_torch/auto_deploy/models/factory.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cc,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cc,cxx,cu,cuh,py}
: If a constructor parameter name conflicts with a public member, add a trailing underscore to the parameter (e.g., foo_).
Use uppercase literal suffixes (e.g., 1234L not 1234l).
Use spaces, not tabs; indent by 4 spaces.
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
tensorrt_llm/_torch/auto_deploy/models/hf.py
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Target Python 3.8+ for all Python code.
Indent with 4 spaces; do not use tabs.
Maintain module namespace on imports; import the module/submodule, not individual classes/functions (e.g., from package.subpackage import foo; foo.SomeClass()).
Python filenames use snake_case (e.g., some_file.py).
Class names use PascalCase.
Function and method names use snake_case.
Local variable names use snake_case; if starting with a number, prefix with k_ (e.g., k_99th_percentile).
Global variables use UPPER_SNAKE_CASE and prefix G_ (e.g., G_MY_GLOBAL).
Constants use UPPER_SNAKE_CASE.
Avoid shadowing outer-scope variables.
Initialize all externally visible members of a class in __init__.
Prefer docstrings for interfaces used outside a file; use comments for local code within functions or local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline with docstrings placed under the definition.
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns).
Limit except clauses to specific exceptions; avoid bare except.
For duck-typing try/except, keep try blocks minimal and use else for the main logic.
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
tensorrt_llm/_torch/auto_deploy/models/hf.py
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.{cpp,cc,cxx,cu,h,hpp,hh,hxx,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA 2025 Apache-2.0 copyright header block at the top of all source files (.cpp, .h, .cu, .py).
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
tensorrt_llm/_torch/auto_deploy/models/hf.py
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  - ModelFactory (15-207)
  - ModelFactoryRegistry (210-228)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (6)
- tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  - ModelFactory (15-207)
  - register (214-219)
- tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  - CachedSequenceInterface (12-70)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  - identify_regions_between_residuals (292-345)
  - is_linear_op (240-252)
  - is_op (183-206)
- tensorrt_llm/_torch/modules/linear.py (1)
  - split_dim (48-49)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)
  - SplitDimension (182-186)
  - TPShardingInfo (225-260)
  - BMMShardingInfo (263-362)
  - EPShardingInfo (452-474)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (9)
  - TransformConfig (60-99)
  - TransformRegistry (381-409)
  - register (387-394)
  - BaseTransform (139-378)
  - get_config_class (161-166)
  - get_config_class (402-404)
  - TransformInfo (108-133)
  - _apply (368-378)
  - SharedConfig (51-57)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
- tensorrt_llm/_torch/modules/linear.py (1)
  - split_dim (48-49)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  - SplitDimension (182-186)
- tensorrt_llm/_torch/modules/linear.py (1)
  - split_dim (48-49)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2)
  - test_sharding (143-147)
  - _run_job (52-89)
- tensorrt_llm/_torch/auto_deploy/distributed/common.py (1)
  - spawn_multiprocess_job (240-244)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (12)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
235-247: TP dist_op vs split_dim checks look correct.
Row→all_gather, Column→all_reduce validations align with dist_lookup. No issues.
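For readers unfamiliar with the rule being checked, here is a small self-contained sketch of the split-dim-to-collective pairing; the enum values and function name are illustrative stand-ins, not the actual sharding_utils API:

```python
# A row split (weight dim 0, i.e. output features) shards the output and needs
# an all_gather; a column split (weight dim 1, i.e. input features) produces
# partial sums and needs an all_reduce.
from enum import IntEnum


class SplitDimension(IntEnum):
    ROW = 0     # split along weight dim 0 (output features)
    COLUMN = 1  # split along weight dim 1 (input features)


EXPECTED_DIST_OP = {
    SplitDimension.ROW: "all_gather",
    SplitDimension.COLUMN: "all_reduce",
}


def validate_tp_transform(split_dim: SplitDimension, dist_op: str) -> bool:
    """Return True if the requested collective matches the split dimension."""
    return EXPECTED_DIST_OP[split_dim] == dist_op


assert validate_tp_transform(SplitDimension.ROW, "all_gather")
assert not validate_tp_transform(SplitDimension.COLUMN, "all_gather")
```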
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)
67-73: Generated.
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
55-62: Verify revert should restore single detect_sharding
rg returned no detect_sharding entries; please confirm that the revert's intent is to reinstate the original single detect_sharding transform instead of the three granular detectors.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (2)
31-36: Gate param expectation change LGTM.
Counting gate weight+bias as num_experts*(hidden_size+1) matches the fact EP sharding doesn't shard the gate.
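A quick arithmetic check of that expectation (illustrative numbers, not the values used in the test):

```python
# The gate/router is a Linear(hidden_size -> num_experts); EP sharding keeps it
# replicated, so every rank still holds all of its parameters.
num_experts, hidden_size = 8, 16
gate_params = num_experts * hidden_size + num_experts  # weight + bias
assert gate_params == num_experts * (hidden_size + 1)  # 136 params in this example
```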
42-49: Config uses detect_ep_shard; PR claims a revert.
Same concern as BMM test: this doesn't reflect a revert. Please confirm desired behavior and update either the code or PR description.
Also applies to: 99-103
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
209-237: Confirm external usage of filtered_nodes before signature refactor
- Internal grep found no call sites beyond its definition and doc example.
- If you need to preserve the old target parameter for external consumers, apply the suggested signature change; otherwise drop legacy target support.
274-279: Config still references detect_column_row_shard for pattern detection.
Align the configured transform name with the pre-revert detector (if revert stands).
Use:
- "detect_column_row_shard": { + "detect_sharding": { "stage": "sharding", },
298-302: Spawn job wiring looks fine; ensure device_count≥2 coverage exists in CI.
No functional issue; just confirm CI actually exercises multi-GPU cases when device_count > 1.
219-224: The imports at the top of this test file will reveal whether it's using the post-#6972 API (new SplitDimension/TPShardingInfo module path) or the legacy one; once we have that, we can confirm if any import adjustments are needed.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
94-116: Simple shard path looks consistent (ROW + all_gather).
The semantics align with a conservative simple-shard. No changes requested.
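For context, here is a minimal sketch of what that conservative simple shard computes, using torch.chunk and concatenation as a single-process stand-in for the per-rank split and all_gather (shapes are made up):

```python
# Every linear weight is split along dim 0 (ROW); each rank computes a partial
# output and the shards are re-assembled, so the graph stays numerically equivalent.
import torch

world_size, out_features, in_features = 4, 8, 16
weight = torch.randn(out_features, in_features)
x = torch.randn(2, in_features)

# Each rank keeps a contiguous row slice of the weight.
shards = torch.chunk(weight, world_size, dim=0)
local_outs = [x @ w_r.T for w_r in shards]  # per-rank partial outputs
gathered = torch.cat(local_outs, dim=-1)    # stand-in for all_gather

assert torch.allclose(gathered, x @ weight.T, atol=1e-6)
```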
200-211: Fallbacks correctly route to simple-shard; log messages are helpful.
No code changes requested.
Also applies to: 243-249, 275-281
411-453: EP detector looks fine.
Detection and transform accumulation are straightforward.
class ColumnRowShardConfig(TransformConfig):
    """Configuration for column-row sharding."""

    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = Field(default=False)
    # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])
🛠️ Refactor suggestion
Require shape propagation: config must enable shapes for meta access.
ColumnRowShard relies on node.meta["val"].shape; set requires_shape_prop/run_shape_prop to True.
 class ColumnRowShardConfig(TransformConfig):
     """Configuration for column-row sharding."""
-    simple_shard_only: bool = Field(default=False)
+    simple_shard_only: bool = Field(default=False)
+    requires_shape_prop: bool = Field(default=True)
+    run_shape_prop: bool = Field(default=True)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
class ColumnRowShardConfig(TransformConfig):
    """Configuration for column-row sharding."""

    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = Field(default=False)
    # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])

class ColumnRowShardConfig(TransformConfig):
    """Configuration for column-row sharding."""

    simple_shard_only: bool = Field(default=False)
    requires_shape_prop: bool = Field(default=True)
    run_shape_prop: bool = Field(default=True)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py around lines
118 to 123, the ColumnRowShardConfig class needs to enable shape propagation
because ColumnRowShard uses node.meta["val"].shape; add boolean fields to turn
on shape propagation (e.g., requires_shape_prop: bool = True and run_shape_prop:
bool = True) or set the existing config flags that control shape propagation to
True by default so that meta shapes are available during transforms.
 # check if we are sharding the attention block
 if attention_nodes:
     if len(attention_nodes) > 1:
         # Column-row shard boundary region detection is probably wrong - there should be
         # only one attention operation. Fall back to simple shard.
         ad_logger.debug(f"More than one attention node: {unaccounted_nodes}")
         _append_simple_shard(
             nodes_linear, local_rank, world_size, shared_config.sharding_config
         )
-        elif config == "rowwise":
-            sharding_config.tp_transforms.append(
         continue
     # Extract head dimension. We cannot shard below the head_dim size.
     # Assume that head_dim is the last (innermost) dimension of the tensor
     min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
 else:
     min_local_shape = 1
 for i, group in enumerate(nodes_linear.values()):
🛠️ Refactor suggestion
Use explicit enum for split_dim and avoid mutating the attention set.
- Don’t pop from attention_nodes; read without mutation.
- Pass SplitDimension enums instead of raw ints for clarity.
- min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
+ min_local_shape = next(iter(attention_nodes)).meta["val"].shape[-1]
@@
- split_dim=i,
+ split_dim=(SplitDimension.COLUMN if i > 0 else SplitDimension.ROW),
Also applies to: 307-316
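To make the head_dim constraint behind min_local_shape concrete, here is a small standalone check (hypothetical function name and numbers; not code from the transform):

```python
# The TP split must not produce per-rank slices smaller than one attention head,
# i.e. (hidden_size / world_size) must stay >= head_dim.
def can_tp_shard(hidden_size: int, head_dim: int, world_size: int) -> bool:
    local_size = hidden_size // world_size
    return hidden_size % world_size == 0 and local_size >= head_dim


assert can_tp_shard(hidden_size=4096, head_dim=128, world_size=8)       # 512 per rank
assert not can_tp_shard(hidden_size=4096, head_dim=128, world_size=64)  # 64 < head_dim
```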
 class SplitDimension(IntEnum):
     """Enum for tensor split dimensions in sharding."""

-    # NOTE: The names COLUMN/ROW reflect the hugging face
-    # base_tp_plan sharding notation, but since we assume Y = W @ X^T,
-    # when splitting weight matrix W^T across columns, the actual split
-    # is over dimension 0
-    COLUMN = 0
-    ROW = 1
+    ROW = 0  # Split along rows (first dimension)
+    COLUMN = 1  # Split along columns (second dimension)
💡 Verification agent
🧩 Analysis chain
Verify SplitDimension ROW/COLUMN semantics across modules.
This file defines ROW=0, COLUMN=1, but tensorrt_llm/_torch/modules/linear.py::split_dim(cls, mode) returns 1 when mode==ROW (see provided snippet), which may invert expectations when mixing .value vs helper APIs.
Run to audit usages:
If mismatches exist, consider mapping dim via a single helper to avoid enum.value leaks.
🏁 Script executed:
#!/bin/bash
# Find SplitDimension definitions/usages and .value propagation
rg -nP -C2 '\bclass\s+SplitDimension\b|SplitDimension\.(ROW|COLUMN)|\.\s*value\b'
# Find call sites deriving dims for linear/matmul
rg -nP -C3 '\bsplit_dim\s*\(|_insert_sharded_matmul\(|torch\.ops\..*dist_all_(gather|reduce)'
Length of output: 266522
🏁 Script executed:
#!/bin/bash
# Show split_dim definition and implementation in linear.py
rg -nP -C5 "def split_dim" -n tensorrt_llm/_torch/modules/linear.py
Length of output: 425
Unify SplitDimension mapping with TensorParallelMode.split_dim
- sharding_utils.py defines SplitDimension.ROW=0, COLUMN=1 while TensorParallelMode.split_dim maps ROW→1 and COLUMN→0 (linear.py:48–50)
- Choose one source of truth (e.g. update SplitDimension values or use a central helper) to ensure .value and split_dim produce the same dimension ordering across modules
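One possible shape for that single source of truth, sketched here with illustrative enum values; the real SplitDimension and TensorParallelMode definitions live in sharding_utils.py and linear.py and would need to agree with whichever mapping is chosen:

```python
# Keep the enum symbolic and resolve the actual tensor dim through one helper,
# so no caller depends on the raw .value of either enum. Illustrative only.
from enum import Enum


class SplitDimension(Enum):
    ROW = "row"
    COLUMN = "column"


def weight_split_dim(split: SplitDimension) -> int:
    """Map the symbolic split to the dim of an nn.Linear weight [out, in]."""
    return {SplitDimension.ROW: 0, SplitDimension.COLUMN: 1}[split]


# Callers always go through the helper instead of using enum.value directly:
assert weight_split_dim(SplitDimension.ROW) == 0
assert weight_split_dim(SplitDimension.COLUMN) == 1
```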
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py around lines 182 to
187, SplitDimension defines ROW=0 and COLUMN=1 which conflicts with
TensorParallelMode.split_dim (linear.py lines ~48–50) that maps ROW→1 and
COLUMN→0; pick one source of truth and make them consistent: either flip the
integer values in SplitDimension (ROW=1, COLUMN=0) to match TensorParallelMode,
or refactor both modules to use a shared helper/enum (move SplitDimension into a
common utils module and import it in linear.py or make TensorParallelMode.lookup
consult SplitDimension.value) so that all code uses the same mapping for split
dimensions and update any tests/usages accordingly.
PR_Github #16938 [ run ] completed with state
This reverts commit 2101d46.
This is due to a regression in BMM sharding.
Summary by CodeRabbit
New Features
Breaking Changes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.