Commit dde295a

Deduplicate Transformers backend code using inheritance (#21461)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 6d8d0a2 commit dde295a
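
In short, the constructor and weight-loading logic that previously lived in the standalone TransformersModel helper, and was duplicated between TransformersForCausalLM and TransformersForMultimodalLM, now lives once in a shared TransformersBase; the two model classes become thin subclasses of it. The sketch below is an illustrative simplification rather than the vLLM classes themselves: TextConfigSketch and the names ending in "Sketch" are invented stand-ins, and the real load_weights delegates to vLLM's AutoWeightsLoader instead of the hand-rolled loop shown here.

from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class TextConfigSketch:
    # Invented stand-in for the Hugging Face text config referenced in the diff.
    tie_word_embeddings: bool = True


class TransformersBaseSketch:
    """Stand-in for TransformersBase: shared set-up plus a single load_weights."""

    def __init__(self, *, text_config: TextConfigSketch) -> None:
        self.text_config = text_config
        # Child classes update this; load_weights consumes it (this mirrors
        # the `self.skip_prefixes` hook added by the commit).
        self.skip_prefixes: Optional[list] = None

    def load_weights(self, weights: Iterable[tuple]) -> set:
        # The real method delegates to AutoWeightsLoader; this loop only
        # mimics the skip_prefixes behaviour to show where the hook is used.
        loaded = set()
        for name, _tensor in weights:
            if self.skip_prefixes and any(
                    name.startswith(p) for p in self.skip_prefixes):
                continue
            loaded.add(name)
        return loaded


class CausalLMSketch(TransformersBaseSketch):
    """Stand-in for TransformersForCausalLM: only adds LM-head concerns."""

    def __init__(self, *, text_config: TextConfigSketch) -> None:
        super().__init__(text_config=text_config)
        # Tied word embeddings -> no separate lm_head weights to load.
        if text_config.tie_word_embeddings:
            self.skip_prefixes = ["lm_head."]


class MultimodalLMSketch(CausalLMSketch):
    """Stand-in for TransformersForMultimodalLM: reuses the causal-LM logic."""


model = CausalLMSketch(text_config=TextConfigSketch(tie_word_embeddings=True))
print(model.load_weights([("model.embed_tokens.weight", None),
                          ("lm_head.weight", None)]))
# {'model.embed_tokens.weight'} -- lm_head is skipped because embeddings are tied

The only per-class differences left after the refactor are which prefixes to skip at load time and what extra machinery (lm_head, multimodal embeddings) each subclass sets up; everything else is inherited.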

File tree

1 file changed: +49 additions, -150 deletions

vllm/model_executor/models/transformers.py

Lines changed: 49 additions & 150 deletions
@@ -39,7 +39,6 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)

 logger = init_logger(__name__)

@@ -414,64 +413,63 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)


-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")

-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config

         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()

+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
         # vLLM handles interleaved sliding window attention by creating a new
         # interleaved_sliding_window attribute and deleting the sliding_window
         # attribute. This breaks the constructors in Transformers so we
         # temporarily add the attribute back to construct the model.
         config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
             config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)

         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
         # method after v4.54.0 is released
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
             self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
             )

         self.pipeline_parallel()
         self.tensor_parallel()

         # Input embeddings
-        text_config = config.get_text_config()
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                 ))

         # Attention layers
@@ -481,8 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.init_parameters(self.model)

         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))

     def pipeline_parallel(self):
         """
@@ -654,78 +652,40 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 @support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                     prefix=prefix)
-        self.model = self.transformers_model.model
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]

         if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings())

-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
         else:
             self.lm_head = PPMissingLayer()

-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                        intermediate_tensors,
-                                                        inputs_embeds)
-        return model_output
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -735,23 +695,12 @@ def compute_logits(
                                        sampling_metadata)
         return logits

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-

 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
     hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
         })

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.config = config
         self.dtype = vllm_config.model_config.dtype

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                     prefix=prefix)
-        self.model = self.transformers_model.model
-        text_config = config.get_text_config()
-
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ def forward(
                 input_ids, multimodal_embeds)
             input_ids = None

-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                        intermediate_tensors,
-                                                        inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
         return model_output

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values = kwargs.pop("pixel_values", None)
         pixel_values = pixel_values if pixel_values is not None else kwargs.pop(

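One subtlety the deduplicated load_weights relies on: it reads self.skip_prefixes and self.hf_to_vllm_mapper through the instance, so whichever concrete class is instantiated supplies those values, e.g. TransformersForMultimodalLM's backwards-compatibility WeightsMapper, without overriding load_weights itself. (The base class presumably provides a default mapper; that attribute is not visible in the hunks above.) Below is a toy illustration of that attribute lookup; WeightsMapperSketch and its prefix rule are invented for the example and are not vLLM's WeightsMapper API.

from typing import Iterable, Optional


class WeightsMapperSketch:
    """Toy name mapper; the real class is vLLM's WeightsMapper."""

    def __init__(self, orig_to_new_prefix: Optional[dict] = None) -> None:
        self.orig_to_new_prefix = orig_to_new_prefix or {}

    def map_name(self, name: str) -> str:
        # Rewrite the first matching prefix, mirroring prefix-based remapping.
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


class BaseSketch:
    # Default mapper; subclasses may shadow this class attribute, and the
    # base-class load_weights picks the override up through `self`.
    hf_to_vllm_mapper = WeightsMapperSketch()

    def load_weights(self, weights: Iterable[tuple]) -> set:
        return {self.hf_to_vllm_mapper.map_name(name) for name, _ in weights}


class MultimodalSketch(BaseSketch):
    # Mirrors TransformersForMultimodalLM shadowing hf_to_vllm_mapper for
    # older checkpoint layouts (this particular rule is made up).
    hf_to_vllm_mapper = WeightsMapperSketch(
        orig_to_new_prefix={"language_model.": "model."})


weights = [("language_model.embed_tokens.weight", None)]
print(BaseSketch().load_weights(weights))        # {'language_model.embed_tokens.weight'}
print(MultimodalSketch().load_weights(weights))  # {'model.embed_tokens.weight'}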