Deduplicate Transformers backend code using inheritance #21461

Merged (1 commit) on Jul 24, 2025
199 changes: 49 additions & 150 deletions vllm/model_executor/models/transformers.py
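
For orientation: the diff below removes the standalone TransformersModel helper and folds its logic into a TransformersBase class that the concrete model classes inherit from. A minimal sketch of the resulting hierarchy, with toy placeholder bodies (the interface mixins SupportsQuant, SupportsLoRA, SupportsPP and SupportsMultiModal are omitted; only the inheritance relationships mirror the diff):

# Minimal sketch, not vLLM's real constructors.
class TransformersBase:
    """Owns config handling, parallelism setup, forward and load_weights."""
    def __init__(self) -> None:
        # Subclasses steer weight loading via this attribute instead of
        # overriding load_weights.
        self.skip_prefixes: list[str] | None = None

class TransformersForCausalLM(TransformersBase):
    """Adds lm_head and logits computation on top of the base."""
    def __init__(self) -> None:
        super().__init__()
        self.skip_prefixes = ["lm_head."]  # e.g. when embeddings are tied

class TransformersForMultimodalLM(TransformersForCausalLM):
    """Only multimodal input handling differs; the LM head is reused."""

m = TransformersForMultimodalLM()
print(m.skip_prefixes)  # ['lm_head.'] -- inherited from the causal-LM init
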
@@ -39,7 +39,6 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix)

logger = init_logger(__name__)

@@ -414,64 +413,63 @@ def __exit__(self, exc_type, exc_value, traceback):
setattr(self.config, key, value)


class TransformersModel:
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
logger.info("Using Transformers backend.")

config: PretrainedConfig = vllm_config.model_config.hf_config
cache_config: CacheConfig = vllm_config.cache_config
device_config: DeviceConfig = vllm_config.device_config
model_config: ModelConfig = vllm_config.model_config
parallel_config: ParallelConfig = vllm_config.parallel_config
quant_config: QuantizationConfig = vllm_config.quant_config

self.config = config
self.text_config = config.get_text_config()
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
self.parallel_config = parallel_config
self.quant_config = quant_config
self.config: PretrainedConfig = vllm_config.model_config.hf_config
self.text_config: PretrainedConfig = self.config.get_text_config()
self.cache_config: CacheConfig = vllm_config.cache_config
self.device_config: DeviceConfig = vllm_config.device_config
self.model_config: ModelConfig = vllm_config.model_config
self.parallel_config: ParallelConfig = vllm_config.parallel_config
self.quant_config: QuantizationConfig = vllm_config.quant_config

self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()

# To be updated in child classes for use in `load_weights`
self.skip_prefixes: Optional[list[str]] = None

# vLLM handles interleaved sliding window attention by creating a new
# interleaved_sliding_window attribute and deleting the sliding_window
# attribute. This breaks the constructors in Transformers so we
# temporarily add the attribute back to construct the model.
config_override = nullcontext()
if hasattr(config, "interleaved_sliding_window"):
if hasattr(self.config, "interleaved_sliding_window"):
config_override = ConfigOverride(
config, sliding_window=config.interleaved_sliding_window)
self.config,
sliding_window=self.config.interleaved_sliding_window)

# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
self.model: PreTrainedModel = AutoModel.from_config(
config,
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)

self.pipeline_parallel()
self.tensor_parallel()

# Input embeddings
text_config = config.get_text_config()
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
self.model.set_input_embeddings(
VocabParallelEmbedding(
text_config.vocab_size,
text_config.hidden_size,
org_num_embeddings=text_config.vocab_size,
quant_config=quant_config,
self.text_config.vocab_size,
self.text_config.hidden_size,
org_num_embeddings=self.text_config.vocab_size,
quant_config=self.quant_config,
))

# Attention layers
@@ -481,8 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.init_parameters(self.model)

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
text_config.hidden_size))
make_empty_intermediate_tensors_factory(
["hidden_states"], self.text_config.hidden_size))

def pipeline_parallel(self):
"""
@@ -654,78 +652,40 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())

loaded_params = set[str]()
for name, loaded_weight in weights:
# Use "model" instead of base_model_prefix because
# the base model attribute in vLLM is always `model`
if not name.startswith(prefix := "model."):
name = prefix + name

if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
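
The manual parameter loop is replaced by AutoWeightsLoader, with per-subclass behaviour expressed through self.skip_prefixes and hf_to_vllm_mapper. A toy illustration of the skip-prefix filtering idea (this is not vLLM's actual loader):

from typing import Iterable, Optional

def filter_weights(weights: Iterable[tuple[str, object]],
                   skip_prefixes: Optional[list[str]]):
    # Drop any checkpoint tensor whose name starts with a skipped prefix,
    # e.g. "lm_head." when the head is tied to the input embeddings.
    skip = tuple(skip_prefixes or ())
    for name, tensor in weights:
        if skip and name.startswith(skip):
            continue
        yield name, tensor

kept = dict(filter_weights([("model.embed_tokens.weight", ...),
                            ("lm_head.weight", ...)],
                           skip_prefixes=["lm_head."]))
print(sorted(kept))  # ['model.embed_tokens.weight']
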


@support_torch_compile
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
class TransformersForCausalLM(TransformersBase):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config: QuantizationConfig = vllm_config.quant_config

self.config = config
super().__init__(vllm_config=vllm_config, prefix=prefix)

self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
# Tell `TransformersBase.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes = ["lm_head."]

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
if self.text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
self,
hidden_states: torch.Tensor,
@@ -735,23 +695,12 @@ def compute_logits(
sampling_metadata)
return logits

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = ["lm_head."
] if self.config.tie_word_embeddings else None
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
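
Setting skip_prefixes = ["lm_head."] for tied-embedding models (above) works because a tied checkpoint has no separate lm_head tensor to load; the head shares the embedding weights. A plain-torch illustration of that sharing, with arbitrary sizes (not vLLM's ParallelLMHead.tie_weights):

import torch.nn as nn

embed = nn.Embedding(32000, 4096)
lm_head = nn.Linear(4096, 32000, bias=False)

# Tie: both modules now reference the same parameter tensor, so loading
# the embedding weight is enough for both.
lm_head.weight = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
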


@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
SupportsPP, SupportsMultiModal):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]

class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
})
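
The hf_to_vllm_mapper above (its contents are collapsed in this view) renames checkpoint tensors from older state-dict layouts to the module names vLLM expects. A hypothetical illustration of prefix remapping; the mapping below is invented for the example, not the PR's actual table:

def remap(name: str, prefix_map: dict[str, str]) -> str:
    # Rewrite the first matching prefix; leave unmatched names untouched.
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

example_map = {"language_model.model.": "model.language_model."}  # invented
print(remap("language_model.model.embed_tokens.weight", example_map))
# -> model.language_model.embed_tokens.weight
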

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config: QuantizationConfig = vllm_config.quant_config
super().__init__(vllm_config=vllm_config, prefix=prefix)

self.config = config
self.dtype = vllm_config.model_config.dtype

self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
text_config = config.get_text_config()

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = text_config.vocab_size
self.lm_head = ParallelLMHead(
text_config.vocab_size,
text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
text_config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()

self.make_empty_intermediate_tensors = (
self.transformers_model.make_empty_intermediate_tensors)

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ def forward(
input_ids, multimodal_embeds)
input_ids = None

model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
model_output = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds)
return model_output
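
The multimodal forward now only merges the projected multimodal features into the text embeddings and then defers to the inherited causal-LM forward. A toy sketch of that merge step (the placeholder token id and shapes are made up):

import torch

def merge_embeddings(input_ids, text_embeds, mm_embeds, placeholder_id):
    # Rows whose token id is the multimodal placeholder are replaced by the
    # projected image/audio features; text rows are left untouched.
    merged = text_embeds.clone()
    merged[input_ids == placeholder_id] = mm_embeds
    return merged

ids = torch.tensor([1, 32000, 32000, 2])   # two placeholder tokens
text = torch.zeros(4, 8)                   # [seq_len, hidden]
mm = torch.ones(2, 8)                      # two feature rows
out = merge_embeddings(ids, text, mm, placeholder_id=32000)
print(out.sum(dim=-1))                     # tensor([0., 8., 8., 0.])
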

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"lm_head."
] if self.config.get_text_config().tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop(