diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 610f8e752dbd..8cd95605cdfa 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -39,7 +39,6 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -414,40 +413,40 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)
 
 
-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
         # vLLM handles interleaved sliding window attention by creating a new
         # interleaved_sliding_window attribute and deleting the sliding_window
         # attribute. This breaks the constructors in Transformers so we
         # temporarily add the attribute back to construct the model.
         config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
             config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -455,23 +454,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
             self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
             )
 
         self.pipeline_parallel()
         self.tensor_parallel()
 
         # Input embeddings
-        text_config = config.get_text_config()
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                 ))
 
         # Attention layers
@@ -481,8 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.init_parameters(self.model)
 
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))
 
     def pipeline_parallel(self):
         """
@@ -654,78 +652,40 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 @support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]
 
         if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings())
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
         else:
             self.lm_head = PPMissingLayer()
 
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
-        return model_output
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -735,23 +695,12 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-
 
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
     hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
         })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        self.config = config
         self.dtype = vllm_config.model_config.dtype
 
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
-        text_config = config.get_text_config()
-
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ def forward(
                     input_ids, multimodal_embeds)
                 input_ids = None
 
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
         return model_output
 
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values = kwargs.pop("pixel_values", None)
         pixel_values = pixel_values if pixel_values is not None else kwargs.pop(