from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                         SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)

logger = init_logger(__name__)

@@ -414,64 +413,63 @@ def __exit__(self, exc_type, exc_value, traceback):
            setattr(self.config, key, value)


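# For context: `ConfigOverride` (whose `__exit__` is shown above) temporarily
# overrides attributes on the HF config and restores the originals afterwards.
# A minimal, self-contained sketch of that temporary-override pattern (not part
# of this diff; names below are illustrative only):
from contextlib import contextmanager


@contextmanager
def override_attrs(obj, **overrides):
    _missing = object()
    saved = {key: getattr(obj, key, _missing) for key in overrides}
    for key, value in overrides.items():
        setattr(obj, key, value)
    try:
        yield obj
    finally:
        for key, value in saved.items():
            if value is _missing:
                delattr(obj, key)
            else:
                setattr(obj, key, value)

# Usage sketch: re-expose `sliding_window` from `interleaved_sliding_window`
# only while the Transformers constructor runs.
# with override_attrs(hf_config, sliding_window=hf_config.interleaved_sliding_window):
#     model = AutoModel.from_config(hf_config)
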
-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
        # vLLM handles interleaved sliding window attention by creating a new
        # interleaved_sliding_window attribute and deleting the sliding_window
        # attribute. This breaks the constructors in Transformers so we
        # temporarily add the attribute back to construct the model.
        config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
            config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method after v4.54.0 is released
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"), config_override:
            self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
-        text_config = config.get_text_config()
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                ))

        # Attention layers
@@ -481,8 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
@@ -654,78 +652,40 @@ def forward(

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


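# A simplified, self-contained illustration (not part of this diff) of what
# `skip_prefixes` and a checkpoint-name mapper mean during weight loading.
# This sketches the idea behind delegating to AutoWeightsLoader above; it is
# not vLLM's actual implementation, and the helper name is illustrative.
import torch


def load_named_weights(module: torch.nn.Module,
                       weights,  # iterable of (name, tensor) pairs
                       skip_prefixes=None,
                       rename=lambda name: name) -> set[str]:
    params = dict(module.named_parameters())
    loaded: set[str] = set()
    for name, tensor in weights:
        name = rename(name)  # e.g. map HF checkpoint names to vLLM names
        if skip_prefixes and any(name.startswith(p) for p in skip_prefixes):
            continue  # e.g. skip "lm_head." when embeddings are tied
        if name in params:
            params[name].data.copy_(tensor)
            loaded.add(name)
    return loaded
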
@support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]

        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.get_input_embeddings())

-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
        else:
            self.lm_head = PPMissingLayer()

-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
-        return model_output
-
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
@@ -735,23 +695,12 @@ def compute_logits(
                                       sampling_metadata)
        return logits

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-

@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
    # Backwards compatibility for prev released models. State dicts back then
    # had different formats and cannot be loaded with `AutoModel` mapping as is
    hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.config = config
        self.dtype = vllm_config.model_config.dtype

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
-        text_config = config.get_text_config()
-
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ def forward(
                input_ids, multimodal_embeds)
            input_ids = None

-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
        return model_output
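# A toy sketch (not part of this diff) of the "merge multimodal embeddings,
# then run the text model" pattern used above: embeddings at placeholder
# image-token positions are replaced by projected image features before the
# language model runs. The placeholder-id argument is illustrative.
import torch


def merge_placeholder_embeddings(input_ids: torch.Tensor,
                                 text_embeds: torch.Tensor,
                                 image_embeds: torch.Tensor,
                                 image_token_id: int) -> torch.Tensor:
    merged = text_embeds.clone()
    mask = input_ids == image_token_id  # positions of image placeholders
    # Assumes one image-feature row per placeholder token, in order.
    merged[mask] = image_embeds.to(dtype=merged.dtype)
    return merged
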

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
    def get_multimodal_embeddings(self, **kwargs):
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_values = pixel_values if pixel_values is not None else kwargs.pop(