diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 451f18b471a..041d51e73d8 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -52,11 +52,13 @@ transforms: quantize_moe: stage: pattern_matcher # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config. - detect_sharding: + detect_column_row_shard: stage: sharding simple_shard_only: false - use_sharding_from_factory: false - sharding_dims: ['tp', 'ep', 'dp'] + detect_ep_shard: + stage: sharding + detect_dp_bmm_shard: + stage: sharding # TODO: (hg) need to ensure run_shape_prop after sharding. sharding_transform_executor: stage: sharding diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 617f5b58fc3..9811274a8bc 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -159,17 +159,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): "If False, auto-detect and use column+row (all_reduce) sharding when possible.", ) - use_sharding_from_factory: bool = Field( - default=False, - description="If True, use sharding from the model factory. If False, use sharding from the " - "AutoDeployConfig.", - ) - - sharding_dims: List[str] = Field( - default=["tp", "ep", "dp"], - description="The sharding methods to apply by the heuristic sharding stage.", - ) - compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( Field( default="torch-compile", diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 4cf0a093ee2..42a30402537 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -2,7 +2,6 @@ import copy from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Callable, Dict, Optional, Type import torch @@ -13,13 +12,6 @@ from ..utils.logger import ad_logger -class ShardingConfigSource(Enum): - """Enum for factory source.""" - - HUGGINGFACE = "huggingface" - UNKNOWN = "unknown" - - class ModelFactory(ABC): """An interface to return and correctly initialize a model from a desired source. @@ -46,8 +38,6 @@ def __init__( self.max_seq_len = max_seq_len self._prefetched_model_path: Optional[str] = None self._prefetched_tokenizer_path: Optional[str] = None - self._sharding_config: Dict[str, Any] = {} - self._sharding_config["source"] = ShardingConfigSource.UNKNOWN @property def model(self) -> Optional[str]: @@ -106,10 +96,6 @@ def get_quant_config(self) -> Dict: """Returns the quantization config for this model or None if not quantized.""" return {} - def get_sharding_config(self) -> Dict: - """Returns the sharding config for this model.""" - return self._sharding_config - def get_cache_config(self) -> CacheConfig: """Return the cache configuration for the model. 
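Note: with `detect_sharding` split into per-dimension transforms, each sharding family is enabled by listing its transform explicitly instead of via `use_sharding_from_factory` or `sharding_dims`. A minimal sketch of an optimizer config in the new style, mirroring the `InferenceOptimizer(None, {...})` calls in the updated unit tests further down in this diff (not part of the patch itself):

```python
# Sketch only: enable the three sharding detections plus the executor individually.
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer

optimizer = InferenceOptimizer(
    None,  # factory is None in the unit tests
    {
        # TP heuristic (column/row or simple shard); simple_shard_only is its only knob now
        "detect_column_row_shard": {"stage": "sharding", "simple_shard_only": False},
        # EP sharding for MoE ops and DP sharding for BMMs are opted into independently
        "detect_ep_shard": {"stage": "sharding"},
        "detect_dp_bmm_shard": {"stage": "sharding"},
        # executor that applies the collected TP/EP/BMM transforms to the graph
        "sharding_transform_executor": {"stage": "sharding"},
    },
)
```

The tests then invoke the optimizer on an exported module, e.g. `optimizer(None, gm)`, and inspect `optimizer.shared_config.sharding_config` for the detected transforms.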
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 0a35690c684..ec1da12bc99 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -29,7 +29,7 @@ from ..custom_ops.attention_interface import CacheConfig from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger -from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource +from .factory import ModelFactory, ModelFactoryRegistry from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry @@ -94,9 +94,6 @@ def __init__(self, *args, **kwargs): assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}" self.model_kwargs["torch_dtype"] = dtype - # set sharding config source to huggingface - self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE - @property def autoconfig_from_pretrained(self): return AutoConfig.from_pretrained @@ -164,9 +161,6 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module: if hasattr(model, "post_init"): model.post_init() - # if present, initialize sharding config. We need head_dim for colwise sharding. - self._set_sharding_config(model.config) - # patch forward method model.forward = types.MethodType(self._simple_forward, model) @@ -174,20 +168,6 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module: return model - def _set_sharding_config(self, model_config: PretrainedConfig): - """Set the sharding config for the model.""" - self._sharding_config["head_dim"] = 1 - if hasattr(model_config, "base_model_tp_plan"): - self._sharding_config["tp_plan"] = model_config.base_model_tp_plan - if hasattr(model_config, "head_dim") and model_config.head_dim is not None: - self._sharding_config["head_dim"] = model_config.head_dim - elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"): - self._sharding_config["head_dim"] = ( - model_config.hidden_size // model_config.num_attention_heads - ) - if hasattr(model_config, "num_hidden_layers"): - self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers - def get_quant_config(self) -> Dict: """Returns the quantization config for this model or an empty dict if not quantized.""" if self._quant_config_reader is not None: @@ -359,19 +339,6 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory): }, } - def _set_sharding_config(self, model_config: PretrainedConfig): - """Override the sharding config for the model with text_config.""" - super()._set_sharding_config(model_config) - - if hasattr(model_config, "text_config"): - text_config = model_config.text_config - if hasattr(text_config, "base_model_tp_plan"): - self._sharding_config["tp_plan"] = text_config.base_model_tp_plan - if hasattr(text_config, "head_dim"): - self._sharding_config["head_dim"] = text_config.head_dim - if hasattr(text_config, "num_hidden_layers"): - self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers - @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index c37b627240f..b4ed58c5d32 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -17,7 +17,6 @@ """ import operator -import re from collections import defaultdict from typing import DefaultDict, Dict, 
List, Set, Tuple, Type @@ -25,15 +24,10 @@ from pydantic import Field from torch.fx import GraphModule, Node -from ...models.factory import ModelFactory, ShardingConfigSource +from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger -from ...utils.node_utils import ( - filtered_nodes, - identify_regions_between_residuals, - is_linear_op, - is_op, -) +from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op from ...utils.sharding_utils import ( BMMShardingInfo, EPShardingInfo, @@ -111,7 +105,7 @@ def _append_simple_shard( tp_shards.append( TPShardingInfo( target_node=n.name, - split_dim=SplitDimension.COLUMN, + split_dim=SplitDimension.ROW, rank=rank, world_size=world_size, dist_op="all_gather", @@ -121,17 +115,14 @@ def _append_simple_shard( sharding_config.tp_transforms.extend(tp_shards) -class ShardingTransformConfig(TransformConfig): - """Configuration for sharding transformations.""" +class ColumnRowShardConfig(TransformConfig): + """Configuration for column-row sharding.""" simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = Field(default=False) - # Which sharding families to run: any subset of {"tp", "ep", "bmm"} - sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) -@TransformRegistry.register("detect_sharding") -class Sharding(BaseTransform): +@TransformRegistry.register("detect_column_row_shard") +class ColumnRowShard(BaseTransform): """A transformation to apply sharding to the model following tensor parallelism. The transformation is based on the following steps: @@ -149,11 +140,11 @@ class Sharding(BaseTransform): splitting, e.g., the individual heads into smaller shards. """ - config: ShardingTransformConfig + config: ColumnRowShardConfig @classmethod def get_config_class(cls) -> Type[TransformConfig]: - return ShardingTransformConfig + return ColumnRowShardConfig def _apply( self, @@ -171,395 +162,167 @@ def _apply( ) assert isinstance(gm, GraphModule), "Expecting GraphModule" - shared_config.sharding_config.rank = local_rank - shared_config.sharding_config.world_size = world_size - shared_config.sharding_config.predefined_config = ( - factory.get_sharding_config() if factory else {} - ) - shared_config.sharding_config.factory_source = ( - shared_config.sharding_config.predefined_config.get( - "source", ShardingConfigSource.UNKNOWN - ) - if factory - else ShardingConfigSource.UNKNOWN - ) - shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only - shared_config.sharding_config.use_sharding_from_factory = ( - self.config.use_sharding_from_factory - ) - - sharding_config = shared_config.sharding_config - sharding_config.validate_config() - - if ( - shared_config.sharding_config.use_sharding_from_factory - and len(shared_config.sharding_config.get_predefined_config()) > 0 - ): - ad_logger.info("Applying sharding from config") - factory_info = detect_sharding_from_factory_config(gm, sharding_config) - return gm, factory_info - shared_config.sharding_config.sharding_dims = self.config.sharding_dims - - ad_logger.info( - f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" - ) - # run TP sharding across ranks - if "tp" in shared_config.sharding_config.sharding_dims: - tp_info = detect_column_row_shard(gm, sharding_config) - else: - tp_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + # find boundary nodes of regions 
we want to shard + boundary_nodes = identify_regions_between_residuals(gm) + + # TODO: continue updating these lists + # pointwise ops that don't affect the sharder + pointwise_ops = { + torch.ops.aten.gelu, + torch.ops.aten.leaky_relu, + torch.ops.aten.mul, + torch.ops.aten.relu, + torch.ops.aten.sigmoid, + torch.ops.aten.silu, + torch.ops.aten.tanh, + torch.ops.aten.contiguous, + } + + # acceptable attention nodes between sharded GEMMs + shardable_attention_nodes = { + torch.ops.auto_deploy.torch_attention_sdpa, + torch.ops.auto_deploy.torch_attention_grouped_sdpa, + torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, + } + + # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an + # attention node because we know that those ops must be compatible with the attention op. Now + # since the attention op is shardable, we will assume those are as well if used in conjunction + # with the attention op. + shardable_nodes_with_attention = { + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.auto_deploy.flashinfer_rope, + operator.getitem, + } + + # let's look at linear nodes we can identify between pairs of boundary nodes + # There is three potential cases we can handle: + # 1. No linear nodes: + # --> just continue + # 2. Two groups of linear nodes and we can account for all to the view nodes: + # --> row_split (dim 0) 1st group + check for supported nodes + + # col_split (dim 1) 2nd group + all_reduce output of 2nd group + # 3. Linear nodes that are not in two groups or we cannot account for all nodes: + # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output + num_shards = 0 + for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): + # we iterate through all nodes between the two boundary nodes and store linear nodes + # sorted by their input activation node. We also store remaining nodes. 
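As a sanity check on cases 2 and 3 above, here is a small standalone sketch (plain tensors, not the transform API) of why both patterns preserve the unsharded result: a dim-0 split of the first GEMM plus a dim-1 split of the second needs one all_reduce, while the simple-shard fallback splits every weight on dim 0 and all_gathers the outputs on the last dimension:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)    # [batch, hidden]
w1 = torch.randn(16, 8)  # first linear:  y = x @ w1.T
w2 = torch.randn(8, 16)  # second linear: out = relu(y) @ w2.T
ref = torch.relu(x @ w1.T) @ w2.T
world_size = 2

# Case 2: row-split (dim 0) the first weight, col-split (dim 1) the second, all_reduce.
partials = [
    torch.relu(x @ w1.chunk(world_size, dim=0)[r].T) @ w2.chunk(world_size, dim=1)[r].T
    for r in range(world_size)
]
assert torch.allclose(sum(partials), ref, atol=1e-5)  # sum stands in for all_reduce

# Case 3 (simple shard): row-split (dim 0) a single weight, all_gather outputs on dim -1.
shards = [x @ w1_shard.T for w1_shard in w1.chunk(world_size, dim=0)]
assert torch.allclose(torch.cat(shards, dim=-1), x @ w1.T, atol=1e-5)  # cat ~ all_gather
```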
+ nodes_linear: DefaultDict[Node, List[Node]] = defaultdict(list) + attention_nodes: Set[Node] = set() + attention_related_nodes: Set[Node] = set() + unaccounted_nodes: Set[Node] = set() + current_node = n_start + while current_node != n_end: + if is_linear_op(current_node, include_quantization=True): + nodes_linear[current_node.args[0]].append(current_node) + elif is_op(current_node, shardable_attention_nodes): + attention_nodes.add(current_node) + elif is_op(current_node, shardable_nodes_with_attention): + attention_related_nodes.add(current_node) + elif not is_op(current_node, pointwise_ops): + unaccounted_nodes.add(current_node) + current_node = current_node.next + assert current_node, "Could not identify next node" + + # nothing to shard + if len(nodes_linear) == 0: + continue - # run EP sharding across ranks - if "ep" in shared_config.sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + num_shards += 1 - # run BMM sharding across ranks - if "bmm" in shared_config.sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + if self.config.simple_shard_only: + ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue - info = TransformInfo( - skipped=tp_info.skipped and ep_info.skipped and dp_bmm_info.skipped, - num_matches=tp_info.num_matches + ep_info.num_matches + dp_bmm_info.num_matches, - is_clean=tp_info.is_clean and ep_info.is_clean and dp_bmm_info.is_clean, - has_valid_shapes=tp_info.has_valid_shapes - and ep_info.has_valid_shapes - and dp_bmm_info.has_valid_shapes, - ) - return gm, info + # simple shard when we have != 2 groups of linear nodes + if len(nodes_linear) != 2: + ad_logger.debug(f"Linear groups: {nodes_linear}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue + # let's look at the unnacounted nodes. They are okay as long as they fall before the + # first linear node or after the last linear node, i.e., outside the sharded region + lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} + lin_nodes_passed: Set[Node] = set() + current_node = n_start + while current_node != n_end: + # check if this is another linear node + if current_node in lin_nodes_flat: + lin_nodes_passed.add(current_node) + + # check if we are OUTSIDE sharded region + if len(lin_nodes_passed) == 0 or lin_nodes_passed == lin_nodes_flat: + # remove node from unaccounted nodes since we are outside and it doesn't matter + unaccounted_nodes.discard(current_node) + attention_related_nodes.discard(current_node) + attention_nodes.discard(current_node) + + current_node = current_node.next + + # let's post-process the attention-related nodes + # we can disregard them if we also see attention nodes and we assume they are compatible + if len(attention_nodes) > 0: + attention_related_nodes.clear() + + # check if any unaccounted nodes are left. 
If so, do a simply shard + if unaccounted_nodes or attention_related_nodes: + ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config + ) + continue -def detect_sharding_from_factory_config( - gm: GraphModule, - sharding_config: ShardingConfig, -) -> TransformInfo: - """ - Create sharding transformations from the predefined config. - TODO: currently, it applies only to TP sharding. - Args: - gm: Graph module to apply transformations to - sharding_config: Predefined sharding configuration - """ - # check if config is valid. - # 1. it is a Dict[str, str] - # 2. the keys are of format "module.submodule.subsubmodule..." - # 3. the wildcard "*" is allowed in the keys - # 4. the allowed values are: - # - "colwise" - # - "rowwise" - # - "sequence_parallel" - # - "local_colwise" - # - "local_rowwise" - # - "local" - # - "gather" - # The following constraints are based on - # https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249 - - factory_config = sharding_config.get_predefined_config() - head_dim = factory_config["head_dim"] - tp_plan = factory_config["tp_plan"] - - rank, world_size = sharding_config.rank, sharding_config.world_size - - # If the node is inside the attention module, we need to set min_local_shape to the - # head_dim - otherwise, we would risk splitting the heads into smaller shards. - # TODO: is there a better way to check if we are in attention module? - attn_names = [ - "attention", - "Attention", - "attn", - "Attn", - "q_proj", - "k_proj", - "v_proj", - "o_proj", - ] - - num_shards = 0 - num_simple_shards = 0 - num_row_col_shards = 0 - - for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op): - # use node's weight name to get the module name - module_name = lin_node.args[1].target - - if any(attn_name in module_name for attn_name in attn_names): - min_local_shape = head_dim - else: - min_local_shape = 1 - - # use regex to find if module_name matches any of the keys in sharding_config - for key in tp_plan.keys(): - pattern_string = "*" + key + "*" - # convert it to regex. Escape dots, replace * with .* - # First, we substitute * with an unlikely character, e.g. @ - # Then we escape dots, and finally we replace @ with .* - pattern_string = pattern_string.replace("*", "@") - pattern_regex = re.escape(pattern_string).replace("@", ".*") - if re.match(pattern_regex, module_name): - num_shards += 1 - # we have a match. Get the config for this layer - config = tp_plan[key] - if config == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, - ) + # If we can account for all sharded nodes, we can do a two-way shard + # --> row_split (dim 0) + col_split (dim 1) + all_reduce + + # check if we are sharding the attention block + if attention_nodes: + if len(attention_nodes) > 1: + # Column-row shard boundary region detection is probably wrong - there should be + # only one attention operation. Fall back to simple shard. + ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") + _append_simple_shard( + nodes_linear, local_rank, world_size, shared_config.sharding_config ) - elif config == "rowwise": - sharding_config.tp_transforms.append( + continue + # Extract head dimension. We cannot shard below the head_dim size. 
+ # Assume that head_dim is the last (innermost) dimension of the tensor + min_local_shape = attention_nodes.pop().meta["val"].shape[-1] + else: + min_local_shape = 1 + for i, group in enumerate(nodes_linear.values()): + for n in group: + if i > 0: + dist_op = "all_reduce" + else: + dist_op = None + shared_config.sharding_config.tp_transforms.append( TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.ROW, - rank=rank, + target_node=n.name, + split_dim=i, + rank=local_rank, world_size=world_size, - dist_op="all_reduce", + dist_op=dist_op, min_local_shape=min_local_shape, ) ) - num_row_col_shards += 1 - elif "sequence" in config: - # TODO: Sequence parallelism is not supported yet. - ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") - elif "local" in config: - # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. - ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.") - elif "gather" in config: - # Simple shard (row + all_gather) - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, - ) - ) - num_simple_shards += 1 - else: - ad_logger.warning("Invalid sharding config. Skipping.") - # after successful match, break the loop - break - - ad_logger.info( - f"Applied {num_shards} TP shards (simple: {num_simple_shards}, " - f"row-col pattern: {num_row_col_shards})" - ) - return TransformInfo( - skipped=False, - num_matches=len(sharding_config.tp_transforms), - is_clean=False, - has_valid_shapes=False, - ) - - -def detect_column_row_shard( - gm: GraphModule, - sharding_config: ShardingConfig, -) -> TransformInfo: - """A transformation to apply sharding to the model following tensor parallelism. - - The transformation is based on the following steps: - - 1. Identify boundary nodes between residual nodes to identify shardable regions. - 2. Identify the GEMM nodes that can be sharded - 3. Trace through the subgraph using DFS/BFS between each pair of boundary nodes - 4. Account for each node in the trace to ensure the op is correct even after sharding. This is - necessary to ensure that the sharding is correct and we need to be able to account for - **all** nodes in the subgraph. The subgraph here is defined as the region between the first - linear node to the last linear node of an identified sharding region. - # 5. Shard the GEMM nodes or skip accordingly. - - min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism - splitting, e.g., the individual heads into smaller shards. 
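The `min_local_shape` guard is what keeps TP from splitting an attention projection below a full head. A hypothetical helper capturing that constraint (the name and signature are illustrative, not the library API):

```python
def tp_split_is_legal(out_features: int, world_size: int, min_local_shape: int) -> bool:
    """Each rank must keep a whole multiple of min_local_shape (head_dim, or 1 for MLPs)."""
    if out_features % world_size:
        return False
    return (out_features // world_size) % min_local_shape == 0

# 32 heads x head_dim 128 (out_features = 4096): fine on 8 ranks, not on 64 ranks,
# since 4096 / 64 = 64 would leave each rank with half a head.
assert tp_split_is_legal(4096, 8, 128)
assert not tp_split_is_legal(4096, 64, 128)
```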
- """ - ad_logger.debug("Before sharding graph: " + str(gm)) - - rank, world_size = sharding_config.rank, sharding_config.world_size - if world_size < 2: - ad_logger.info("Skipping TP sharding for single device") - return TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - - assert isinstance(gm, GraphModule), "Expecting GraphModule" - - ad_logger.info("Running TP sharding detection") - - # find boundary nodes of regions we want to shard - boundary_nodes = identify_regions_between_residuals(gm) - - # TODO: continue updating these lists - # pointwise ops that don't affect the sharder - pointwise_ops = { - torch.ops.aten.gelu, - torch.ops.aten.leaky_relu, - torch.ops.aten.mul, - torch.ops.aten.relu, - torch.ops.aten.sigmoid, - torch.ops.aten.silu, - torch.ops.aten.tanh, - torch.ops.aten.contiguous, - } - - # acceptable attention nodes between sharded GEMMs - shardable_attention_nodes = { - torch.ops.auto_deploy.torch_attention_sdpa, - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, - } - - # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an - # attention node because we know that those ops must be compatible with the attention op. Now - # since the attention op is shardable, we will assume those are as well if used in conjunction - # with the attention op. - shardable_nodes_with_attention = { - torch.ops.aten.view, - torch.ops.aten.reshape, - torch.ops.auto_deploy.flashinfer_rope, - operator.getitem, - } - - # let's look at linear nodes we can identify between pairs of boundary nodes - # There is three potential cases we can handle: - # 1. No linear nodes: - # --> just continue - # 2. Two groups of linear nodes and we can account for all to the view nodes: - # --> row_split (dim 0) 1st group + check for supported nodes + - # col_split (dim 1) 2nd group + all_reduce output of 2nd group - # 3. Linear nodes that are not in two groups or we cannot account for all nodes: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output - num_shards = 0 - num_simple_shards = 0 - num_row_col_shards = 0 - for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): - # we iterate through all nodes between the two boundary nodes and store linear nodes - # sorted by their input activation node. We also store remaining nodes. 
- nodes_linear: DefaultDict[Node, List[Node]] = defaultdict(list) - attention_nodes: Set[Node] = set() - attention_related_nodes: Set[Node] = set() - unaccounted_nodes: Set[Node] = set() - current_node = n_start - while current_node != n_end: - if is_linear_op(current_node, include_quantization=True): - nodes_linear[current_node.args[0]].append(current_node) - elif is_op(current_node, shardable_attention_nodes): - attention_nodes.add(current_node) - elif is_op(current_node, shardable_nodes_with_attention): - attention_related_nodes.add(current_node) - elif not is_op(current_node, pointwise_ops): - unaccounted_nodes.add(current_node) - current_node = current_node.next - assert current_node, "Could not identify next node" - - # nothing to shard - if len(nodes_linear) == 0: - continue - - num_shards += 1 - - if sharding_config.simple_shard_only: - ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 - continue - - # simple shard when we have != 2 groups of linear nodes - if len(nodes_linear) != 2: - ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 - continue - - # let's look at the unnacounted nodes. They are okay as long as they fall before the - # first linear node or after the last linear node, i.e., outside the sharded region - lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} - lin_nodes_passed: Set[Node] = set() - current_node = n_start - while current_node != n_end: - # check if this is another linear node - if current_node in lin_nodes_flat: - lin_nodes_passed.add(current_node) - - # check if we are OUTSIDE sharded region - if len(lin_nodes_passed) == 0 or lin_nodes_passed == lin_nodes_flat: - # remove node from unaccounted nodes since we are outside and it doesn't matter - unaccounted_nodes.discard(current_node) - attention_related_nodes.discard(current_node) - attention_nodes.discard(current_node) - - current_node = current_node.next - - # let's post-process the attention-related nodes - # we can disregard them if we also see attention nodes and we assume they are compatible - if len(attention_nodes) > 0: - attention_related_nodes.clear() - - # check if any unaccounted nodes are left. If so, do a simply shard - if unaccounted_nodes or attention_related_nodes: - ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 - continue - - # If we can account for all sharded nodes, we can do a two-way shard - # --> row_split (dim 0) + col_split (dim 1) + all_reduce - - # check if we are sharding the attention block - if attention_nodes: - if len(attention_nodes) > 1: - # Column-row shard boundary region detection is probably wrong - there should be - # only one attention operation. Fall back to simple shard. - ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 - continue - # Extract head dimension. We cannot shard below the head_dim size. 
- # Assume that head_dim is the last (innermost) dimension of the tensor - min_local_shape = attention_nodes.pop().meta["val"].shape[-1] - else: - min_local_shape = 1 - for i, group in enumerate(nodes_linear.values()): - for n in group: - if i > 0: - dist_op = "all_reduce" - else: - dist_op = None - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=n.name, - split_dim=i, - rank=rank, - world_size=world_size, - dist_op=dist_op, - min_local_shape=min_local_shape, - ) - ) - num_row_col_shards += 1 - ad_logger.info( - f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})" - ) - return TransformInfo( - skipped=False, num_matches=num_shards, is_clean=False, has_valid_shapes=False - ) + info = TransformInfo( + skipped=False, num_matches=num_shards, is_clean=False, has_valid_shapes=False + ) + return gm, info -def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo: +@TransformRegistry.register("detect_dp_bmm_shard") +class DpBmmShard(BaseTransform): """A transformation to apply sharding to batched matrix multiplications in the graph. We'll shard the BMM nodes by slicing the batch dimension of input tensors into world_size number of slices. @@ -568,107 +331,122 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra We'll also assume that the inputs to BMM are broadcasted across the devices already. """ - ad_logger.debug("Before sharding graph: " + str(gm)) - rank, world_size = sharding_config.rank, sharding_config.world_size - if world_size < 2: - ad_logger.info("Skipping DP BMM sharding for single device") - return TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - assert isinstance(gm, GraphModule), "Expecting GraphModule" + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + local_rank, world_size = shared_config.local_rank, shared_config.world_size + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + assert isinstance(gm, GraphModule), "Expecting GraphModule" - num_bmm_shards = 0 + num_bmm_shards = 0 - for node in gm.graph.nodes: - if not is_op(node, {torch.ops.aten.bmm}): - continue + for node in gm.graph.nodes: + if not is_op(node, {torch.ops.aten.bmm}): + continue - ad_logger.debug(f"Found BMM node: {node}") + # Get the input tensors + lhs_tensor = node.args[0] + rhs_tensor = node.args[1] - # Get the input tensors - lhs_tensor = node.args[0] - rhs_tensor = node.args[1] + # Check batch sizes from meta information + lhs_batch_size = lhs_tensor.meta["val"].shape[0] + rhs_batch_size = rhs_tensor.meta["val"].shape[0] - # Check batch sizes from meta information - lhs_batch_size = lhs_tensor.meta["val"].shape[0] - rhs_batch_size = rhs_tensor.meta["val"].shape[0] + assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match" + bmm_batch_size = lhs_batch_size - assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match" - bmm_batch_size = lhs_batch_size + # Calculate balanced distribution + base_size = bmm_batch_size // world_size + remainder = bmm_batch_size % world_size - # Calculate balanced distribution - base_size = bmm_batch_size // world_size - remainder = bmm_batch_size % world_size + # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support 
uneven splits at the moment. + if remainder: + ad_logger.warning( + f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. " + f"This will result in uneven distribution of work across devices. Skipping." + ) + continue - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. - if remainder: - ad_logger.warning( - f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. " - f"This will result in uneven distribution of work across devices. Skipping." + # Calculate start and end indices for this rank + if local_rank < remainder: + start_idx = local_rank * (base_size + 1) + end_idx = start_idx + base_size + 1 + else: + start_idx = remainder + local_rank * base_size + end_idx = start_idx + base_size + + shared_config.sharding_config.bmm_transforms.append( + BMMShardingInfo( + target_node=node.name, + rank=local_rank, + world_size=world_size, + start_idx=start_idx, + end_idx=end_idx, + ) ) - continue - - # Calculate start and end indices for this rank - if rank < remainder: - start_idx = rank * (base_size + 1) - end_idx = start_idx + base_size + 1 - else: - start_idx = remainder + rank * base_size - end_idx = start_idx + base_size - - sharding_config.bmm_transforms.append( - BMMShardingInfo( - target_node=node.name, - rank=rank, - world_size=world_size, - start_idx=start_idx, - end_idx=end_idx, + ad_logger.debug( + f"Sharding BMM for rank {local_rank}: " + f"batch_size={bmm_batch_size}, " + f"start_idx={start_idx}, end_idx={end_idx}" ) + + num_bmm_shards += 1 + + info = TransformInfo( + skipped=False, num_matches=num_bmm_shards, is_clean=False, has_valid_shapes=False ) - ad_logger.debug( - f"Sharding BMM for rank {rank}: batch_size={bmm_batch_size}, start_idx={start_idx}, end_idx={end_idx}" - ) + return gm, info + + +@TransformRegistry.register("detect_ep_shard") +class DetectEpShard(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + local_rank, world_size = shared_config.local_rank, shared_config.world_size - num_bmm_shards += 1 - - ad_logger.debug("After sharding BMM: " + str(gm)) - ad_logger.info(f"Found {num_bmm_shards} BMM shards") - - return TransformInfo( - skipped=False, num_matches=num_bmm_shards, is_clean=False, has_valid_shapes=False - ) - - -def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> TransformInfo: - ad_logger.debug("Before sharding graph: " + str(gm)) - - rank, world_size = sharding_config.rank, sharding_config.world_size - if world_size < 2: - ad_logger.info("Skipping EP sharding for single device") - return TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - - assert isinstance(gm, GraphModule), "Expecting GraphModule" - num_moe_patterns = 0 - for node in list(gm.graph.nodes): - if not is_op( - node, - ( - torch.ops.auto_deploy.torch_moe, - torch.ops.auto_deploy.torch_quant_fp8_moe, - torch.ops.auto_deploy.torch_quant_fp4_moe, - ), - ): - continue - sharding_config.ep_transforms.append( - EPShardingInfo( - target_node=node.name, - rank=rank, - world_size=world_size, + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - ) - num_moe_patterns += 1 - ad_logger.info(f"Found {num_moe_patterns} MoE patterns") + assert isinstance(gm, GraphModule), "Expecting GraphModule" + 
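The per-rank index arithmetic in `DpBmmShard` above gives each rank a contiguous, non-overlapping slice of the BMM batch, with any remainder assigned to the lowest ranks (even though the transform currently skips uneven batches). A standalone check of that arithmetic, extracted verbatim from the logic above:

```python
def bmm_slice(rank: int, world_size: int, batch_size: int):
    base_size = batch_size // world_size
    remainder = batch_size % world_size
    if rank < remainder:
        start_idx = rank * (base_size + 1)
        end_idx = start_idx + base_size + 1
    else:
        start_idx = remainder + rank * base_size
        end_idx = start_idx + base_size
    return start_idx, end_idx

# batch of 8 BMMs on 4 ranks -> slices of size 2 covering the whole batch exactly once
assert [bmm_slice(r, 4, 8) for r in range(4)] == [(0, 2), (2, 4), (4, 6), (6, 8)]
```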
num_moe_patterns = 0 + for node in list(gm.graph.nodes): + if not is_op( + node, + ( + torch.ops.auto_deploy.torch_moe, + torch.ops.auto_deploy.torch_quant_fp8_moe, + torch.ops.auto_deploy.torch_quant_fp4_moe, + ), + ): + continue + shared_config.sharding_config.ep_transforms.append( + EPShardingInfo( + target_node=node.name, + rank=local_rank, + world_size=world_size, + ) + ) + num_moe_patterns += 1 - return TransformInfo( - skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False - ) + info = TransformInfo( + skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 6cc98616d47..48f06c70e60 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -207,48 +207,34 @@ def is_op(node: Node, ops: Union[OperatorLike, Iterable[OperatorLike]]) -> bool: def filtered_nodes( - nodes: Iterable[Node], - target: Union[Callable[[Node], bool], Union[OperatorLike, Iterable[OperatorLike]]] = None, - ops: Union[OperatorLike, Iterable[OperatorLike]] = None, + nodes: Iterable[Node], ops: Union[OperatorLike, Iterable[OperatorLike]] ) -> Iterable[Node]: - """Iterate over nodes that are filtered by the given operations or target function. + """Iterate over nodes that are filtered by the given operations. This utility function simplifies the common pattern of iterating through nodes - and filtering by operation type or custom function. + and filtering by operation type. Args: nodes: Iterable of nodes to filter (e.g., gm.graph.nodes) - target: Either a callable function that takes a Node and returns bool, - or operation(s) to match against (deprecated, use ops parameter) - ops: Operation(s) to match against (preferred over target for operations) + ops: Operation(s) to match against Yields: - Node: Nodes that match the given operations or target function + Node: Nodes that match the given operations Example: - # Using callable function: - for node in filtered_nodes(gm.graph.nodes, is_linear_op): + # Instead of: + for node in gm.graph.nodes: + if not is_op(node, torch.ops.aten.linear): + continue # process node - # Using operations: - for node in filtered_nodes(gm.graph.nodes, ops=torch.ops.aten.linear): - # process node - - # Using multiple operations: - for node in filtered_nodes(gm.graph.nodes, ops=[torch.ops.aten.linear, torch.ops.aten.bmm]): + # Use: + for node in filtered_nodes(gm.graph.nodes, torch.ops.aten.linear): # process node """ - # Handle the case where target is a callable function - if callable(target) and not isinstance(target, (OpOverloadPacket, OpOverload)): - for node in nodes: - if target(node): - yield node - else: - # Handle the case where target or ops contains operations - operations = ops if ops is not None else target - for node in nodes: - if is_op(node, operations): - yield node + for node in nodes: + if is_op(node, ops): + yield node def is_linear_op(node: Node, include_quantization: bool = False) -> bool: diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 7f9833a2d18..e0c8cd65cac 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -5,14 +5,13 @@ from abc import ABC, abstractmethod from enum import IntEnum from functools import partial -from typing import Any, 
Callable, Dict, List, Literal, Optional +from typing import Callable, Dict, List, Literal, Optional import torch import torch.nn as nn -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field from torch.fx import GraphModule, Node -from ..models.factory import ShardingConfigSource from ..utils.logger import ad_logger from .node_utils import extract_param_names_from_lin_node, is_op, num_users_of_weight_node from .quantization_utils import QuantizationImpl @@ -183,12 +182,8 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to class SplitDimension(IntEnum): """Enum for tensor split dimensions in sharding.""" - # NOTE: The names COLUMN/ROW reflect the hugging face - # base_tp_plan sharding notation, but since we assume Y = W @ X^T, - # when splitting weight matrix W^T across columns, the actual split - # is over dimension 0 - COLUMN = 0 - ROW = 1 + ROW = 0 # Split along rows (first dimension) + COLUMN = 1 # Split along columns (second dimension) class ShardingTransformInfo(BaseModel, ABC): @@ -237,16 +232,16 @@ class TPShardingInfo(ShardingTransformInfo): def validate(self, gm: GraphModule = None, node: Node = None) -> bool: """Validate the transformation configuration.""" if self.dist_op is not None: - if self.split_dim == SplitDimension.COLUMN: + if self.split_dim == SplitDimension.ROW: if self.dist_op == "all_reduce": ad_logger.warning( - f"Column split is only supported for all_gather. Skipping {self}." + f"Row split is only supported for all_gather. Skipping {self}." ) return False - if self.split_dim == SplitDimension.ROW: + if self.split_dim == SplitDimension.COLUMN: if self.dist_op == "all_gather": ad_logger.warning( - f"Row split is only supported for all_reduce. Skipping {self}." + f"Column split is only supported for all_reduce. Skipping {self}." ) return False return True @@ -482,98 +477,6 @@ def apply(self, gm: GraphModule, node: Node) -> None: class ShardingConfig(BaseModel): """Configuration for sharding the model.""" - factory_source: ShardingConfigSource = Field(default=ShardingConfigSource.UNKNOWN) - rank: int = Field(default=0) - world_size: int = Field(default=1) - predefined_config: Optional[Dict[str, Any]] = None - simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = False - sharding_dims: List[str] = Field(default_factory=list) tp_transforms: List[TPShardingInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) - - @model_validator(mode="after") - def _validate_and_normalize(self): - # Normalize empty dict to None for "no config" - if isinstance(self.predefined_config, dict) and not self.predefined_config: - self.predefined_config = None - # Validate only if provided - if self.predefined_config is not None: - self.validate_config() - return self - - def validate_config(self) -> bool: - if self.factory_source != ShardingConfigSource.HUGGINGFACE: - ad_logger.warning( - "Sharding config is currently only supported for HuggingFace. Skipping." - ) - # invalidate the config - self.predefined_config = {} - return False - - if not isinstance(self.predefined_config, dict): - ad_logger.warning("Sharding config is not a dictionary. Skipping.") - # invalidate the config - self.predefined_config = {} - return False - - if "head_dim" not in self.predefined_config: - ad_logger.warning("Sharding config does not contain head_dim. 
Skipping.") - # invalidate the config - self.predefined_config = {} - return False - - if "tp_plan" not in self.predefined_config or self.predefined_config["tp_plan"] is None: - ad_logger.warning("Sharding config does not contain tp_plan. Skipping.") - # invalidate the config - self.predefined_config = {} - return False - tp_plan = self.predefined_config["tp_plan"] - - values = set(tp_plan.values()) - allowed_values = { - "colwise", # row split and no collective - "rowwise", # column split and all-reduce - "gather", # simple shard (row + all_gather) - # TODO: remaining values are not supported yet. - # They require hybrid EP+TP and/or SP support. - # "sequence_parallel", # sequence parallelism - # "local_colwise", - # "local_rowwise", - # "local_packed_rowwise", - # "local", - } - if not values.issubset(allowed_values): - ad_logger.warning("Sharding config contains invalid values. Skipping.") - # invalidate the config - self.predefined_config = {} - return False - return True - - def get_predefined_config(self) -> Dict[str, Any]: - return self.predefined_config - - -def _append_simple_shard( - nodes_linear: Dict[Node, List[Node]], - rank: int, - world_size: int, - sharding_config: ShardingConfig, -) -> None: - # for every linear node: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] - for node_group in nodes_linear.values(): - for n in node_group: - tp_shards.append( - TPShardingInfo( - target_node=n.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, - ) - ) - sharding_config.tp_transforms.extend(tp_shards) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index 41f3ddb36ae..f47e38b9947 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -64,9 +64,8 @@ def _run_job( gm_transformed = InferenceOptimizer( None, { - "detect_sharding": { + "detect_dp_bmm_shard": { "stage": "sharding", - "use_sharding_from_factory": False, }, "sharding_transform_executor": { "stage": "sharding", @@ -125,9 +124,8 @@ def _run_pattern_detection_job( optimizer = InferenceOptimizer( None, { - "detect_sharding": { + "detect_dp_bmm_shard": { "stage": "sharding", - "use_sharding_from_factory": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 0d8c7a33936..8a95771a3a6 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -24,16 +24,11 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None: ).to(device=device, dtype=torch.bfloat16) x = model.get_input(device=device, dtype=torch.bfloat16) - if world_size > num_experts: - print(f"world_size {world_size} > num_experts {num_experts}, skipping test") - return - def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: if world_size <= 1: return num_p_og # the gate's weight and bias node - # NOTE:gate layer is also distributed using simple_shard during tp_transform - n_gate = num_experts 
* (hidden_size + 1) # // world_size + n_gate = num_experts * (hidden_size + 1) num_experts_per_rank = num_experts // world_size if rank == world_size - 1: num_experts_per_rank += num_experts % world_size @@ -44,10 +39,8 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: gm_transformed = InferenceOptimizer( None, { - "detect_sharding": { + "detect_ep_shard": { "stage": "sharding", - "use_sharding_from_factory": False, - "sharding_dims": ["ep"], }, "sharding_transform_executor": { "stage": "sharding", @@ -103,9 +96,8 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> optimizer = InferenceOptimizer( None, { - "detect_sharding": { + "detect_ep_shard": { "stage": "sharding", - "use_sharding_from_factory": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 802ec15b5bf..016dc659060 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -19,36 +19,6 @@ from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op -base_model_tp_plan = { - "q_proj": "colwise", - "k_proj": "colwise", - "v_proj": "colwise", - "o_proj": "rowwise", - "gate_proj": "colwise", - "up_proj": "colwise", - "down_proj": "rowwise", - "linear1": "colwise", - "linear2": "rowwise", - "linear": "gather", - # "input_layernorm.weight": "sequence_parallel", - # "post_attention_layernorm.weight": "sequence_parallel", - # "norm.weight": "sequence_parallel", - # "shared_expert.gate_proj": "local_colwise", - # "shared_expert.up_proj": "local_colwise", - # "shared_expert.down_proj": "local_rowwise", - # "experts.gate_up_proj": "local_packed_rowwise", - # "experts.down_proj": "local_colwise", - # "experts": "local", - "feed_forward": "gather", - "self": "gather", - "weight": "gather", -} - -predefined_config = { - "head_dim": 8, - "tp_plan": base_model_tp_plan, -} - class GQA_Block(nn.Module): def __init__( @@ -111,7 +81,6 @@ def _run_job( model_cls: nn.Module, dist_op_expected: str, bias: bool, - from_config: bool, rank: int, world_size: int, ) -> None: @@ -179,9 +148,8 @@ def verify_local_weight_sizes(gm) -> bool: gm_transformed = InferenceOptimizer( None, { - "detect_sharding": { + "detect_column_row_shard": { "stage": "sharding", - "use_sharding_from_factory": from_config, }, "sharding_transform_executor": { "stage": "sharding", @@ -212,7 +180,6 @@ def _run_pattern_detection_job( bias: bool, rank: int, world_size: int, - from_config: bool, ) -> None: # init model and input batch_size = 4 @@ -249,10 +216,10 @@ def _run_pattern_detection_job( # for O layer, we expect: # dim = 1, add_dist = True if "o_proj" in node.args[1].name: - dim = SplitDimension.ROW + dim = SplitDimension.COLUMN dist_op = "all_reduce" else: - dim = SplitDimension.COLUMN + dim = SplitDimension.ROW dist_op = None expected_transformations.append( TPShardingInfo( @@ -270,10 +237,10 @@ def _run_pattern_detection_job( # linear1 should be sharded on dim=0, add_dist=False, min_local_shape=1 # linear2 should be sharded on dim=1, add_dist=True, min_local_shape=1 if "linear1" in node.args[1].name: - dim = SplitDimension.COLUMN + dim = SplitDimension.ROW dist_op = None else: - dim = SplitDimension.ROW 
+ dim = SplitDimension.COLUMN dist_op = "all_reduce" expected_transformations.append( TPShardingInfo( @@ -292,7 +259,7 @@ def _run_pattern_detection_job( expected_transformations.append( TPShardingInfo( target_node=node.name, - split_dim=SplitDimension.COLUMN, # Simple shard uses dim=0 + split_dim=SplitDimension.ROW, # Simple shard uses dim=0 rank=rank, world_size=world_size, dist_op="all_gather", @@ -304,9 +271,8 @@ def _run_pattern_detection_job( optimizer = InferenceOptimizer( None, { - "detect_sharding": { + "detect_column_row_shard": { "stage": "sharding", - "use_sharding_from_factory": from_config, }, }, ) @@ -315,15 +281,12 @@ def _run_pattern_detection_job( _ = optimizer(None, gm) detected_transformations = optimizer.shared_config.sharding_config.tp_transforms - print(f"detected_transformations: {detected_transformations}") - print(f"expected_transformations: {expected_transformations}") # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) @pytest.mark.parametrize("device_count", get_device_counts()) @pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("from_config", [False, True]) @pytest.mark.parametrize( "model_cls, dist_op_expected", ( @@ -332,22 +295,15 @@ def _run_pattern_detection_job( (GQA_Block, "torch_dist_all_reduce"), ), ) -def test_sharding( - model_cls: Type[nn.Module], - dist_op_expected: str, - bias: bool, - device_count: int, - from_config: bool, -): +def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int): dist_common.spawn_multiprocess_job( - job=partial(_run_job, model_cls, dist_op_expected, bias, from_config), + job=partial(_run_job, model_cls, dist_op_expected, bias), size=device_count, ) @pytest.mark.parametrize("world_size", [1, 8]) @pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("from_config", [False, True]) @pytest.mark.parametrize( "model_cls, dist_op_expected", ( @@ -357,19 +313,11 @@ def test_sharding( ), ) def test_sharding_pattern_detection( - model_cls: Type[nn.Module], - dist_op_expected: str, - bias: bool, - world_size: int, - from_config: bool, + model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int ): """Test pattern detection logic without distributed execution. This test verifies only the pattern detection logic with provided world_size. No need to run distributed job, can be run on single process. """ - _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config) - - -if __name__ == "__main__": - _run_pattern_detection_job(nn.Linear, False, 0, 8, False) + _run_pattern_detection_job(model_cls, bias, 0, world_size)
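
With `SplitDimension` now naming the literal weight dimension (ROW = dim 0, COLUMN = dim 1), the expectations encoded in the tests above can be summarized with a short sketch; node names and the `min_local_shape` value are illustrative, not taken from a specific model:

```python
from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import SplitDimension, TPShardingInfo

# q/k/v and up/gate projections: shard weights along dim 0, no collective on the output
qkv = TPShardingInfo(
    target_node="q_proj",          # illustrative node name
    split_dim=SplitDimension.ROW,
    rank=0,
    world_size=8,
    dist_op=None,
    min_local_shape=128,           # assumed head_dim, so heads are never split
)
# o_proj / down_proj: shard weights along dim 1 and all_reduce the partial outputs
o_proj = TPShardingInfo(
    target_node="o_proj",
    split_dim=SplitDimension.COLUMN,
    rank=0,
    world_size=8,
    dist_op="all_reduce",
    min_local_shape=128,
)
# validate() enforces the pairing: ROW pairs with all_gather (or no dist_op),
# COLUMN pairs with all_reduce (or no dist_op).
assert qkv.validate() and o_proj.validate()
```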