tensorrt_llm/_torch/auto_deploy/config/default.yaml (8 changes: 5 additions & 3 deletions)
@@ -52,11 +52,13 @@ transforms:
   quantize_moe:
     stage: pattern_matcher
   # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
-  detect_sharding:
+  detect_column_row_shard:
     stage: sharding
     simple_shard_only: false
-    use_sharding_from_factory: false
-    sharding_dims: ['tp', 'ep', 'dp']
+  detect_ep_shard:
+    stage: sharding
+  detect_dp_bmm_shard:
+    stage: sharding
   # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
     stage: sharding
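Note: for orientation, a sketch of the sharding-related portion of the transforms section in default.yaml after this change, reconstructed only from the hunk above (entries outside the hunk are omitted):

    # Sketch reconstructed from the diff hunk above; other transform entries omitted.
    transforms:
      quantize_moe:
        stage: pattern_matcher
      # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
      detect_column_row_shard:
        stage: sharding
        simple_shard_only: false
      detect_ep_shard:
        stage: sharding
      detect_dp_bmm_shard:
        stage: sharding
      # TODO: (hg) need to ensure run_shape_prop after sharding.
      sharding_transform_executor:
        stage: sharding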
tensorrt_llm/_torch/auto_deploy/llm_args.py (11 changes: 0 additions & 11 deletions)
@@ -159,17 +159,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "If False, auto-detect and use column+row (all_reduce) sharding when possible.",
     )
 
-    use_sharding_from_factory: bool = Field(
-        default=False,
-        description="If True, use sharding from the model factory. If False, use sharding from the "
-        "AutoDeployConfig.",
-    )
-
-    sharding_dims: List[str] = Field(
-        default=["tp", "ep", "dp"],
-        description="The sharding methods to apply by the heuristic sharding stage.",
-    )
-
     compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
         Field(
             default="torch-compile",
tensorrt_llm/_torch/auto_deploy/models/factory.py (14 changes: 0 additions & 14 deletions)
@@ -2,7 +2,6 @@
 
 import copy
 from abc import ABC, abstractmethod
-from enum import Enum
 from typing import Any, Callable, Dict, Optional, Type
 
 import torch
@@ -13,13 +12,6 @@
 from ..utils.logger import ad_logger
 
 
-class ShardingConfigSource(Enum):
-    """Enum for factory source."""
-
-    HUGGINGFACE = "huggingface"
-    UNKNOWN = "unknown"
-
-
 class ModelFactory(ABC):
     """An interface to return and correctly initialize a model from a desired source.
 
@@ -46,8 +38,6 @@ def __init__(
         self.max_seq_len = max_seq_len
         self._prefetched_model_path: Optional[str] = None
         self._prefetched_tokenizer_path: Optional[str] = None
-        self._sharding_config: Dict[str, Any] = {}
-        self._sharding_config["source"] = ShardingConfigSource.UNKNOWN
 
     @property
     def model(self) -> Optional[str]:
@@ -106,10 +96,6 @@ def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or None if not quantized."""
         return {}
 
-    def get_sharding_config(self) -> Dict:
-        """Returns the sharding config for this model."""
-        return self._sharding_config
-
     def get_cache_config(self) -> CacheConfig:
         """Return the cache configuration for the model.
 
tensorrt_llm/_torch/auto_deploy/models/hf.py (35 changes: 1 addition & 34 deletions)
@@ -29,7 +29,7 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
-from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
+from .factory import ModelFactory, ModelFactoryRegistry
 from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry
 
 
@@ -94,9 +94,6 @@ def __init__(self, *args, **kwargs):
         assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
         self.model_kwargs["torch_dtype"] = dtype
 
-        # set sharding config source to huggingface
-        self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
-
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
@@ -164,30 +161,13 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         if hasattr(model, "post_init"):
             model.post_init()
 
-        # if present, initialize sharding config. We need head_dim for colwise sharding.
-        self._set_sharding_config(model.config)
-
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
 
         model.eval()
 
         return model
 
-    def _set_sharding_config(self, model_config: PretrainedConfig):
-        """Set the sharding config for the model."""
-        self._sharding_config["head_dim"] = 1
-        if hasattr(model_config, "base_model_tp_plan"):
-            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
-        if hasattr(model_config, "head_dim") and model_config.head_dim is not None:
-            self._sharding_config["head_dim"] = model_config.head_dim
-        elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
-            self._sharding_config["head_dim"] = (
-                model_config.hidden_size // model_config.num_attention_heads
-            )
-        if hasattr(model_config, "num_hidden_layers"):
-            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
-
     def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or an empty dict if not quantized."""
         if self._quant_config_reader is not None:
@@ -359,19 +339,6 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }
 
-    def _set_sharding_config(self, model_config: PretrainedConfig):
-        """Override the sharding config for the model with text_config."""
-        super()._set_sharding_config(model_config)
-
-        if hasattr(model_config, "text_config"):
-            text_config = model_config.text_config
-            if hasattr(text_config, "base_model_tp_plan"):
-                self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
-            if hasattr(text_config, "head_dim"):
-                self._sharding_config["head_dim"] = text_config.head_dim
-            if hasattr(text_config, "num_hidden_layers"):
-                self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
-
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config