@@ -566,7 +566,7 @@ def __init__(
         self.decoder_layers = decoder_layers
         self.layer_idx_start = layer_idx_start
         self.per_layer_model_projection = per_layer_model_projection
-        self.config = vllm_config.model_config.hf_config.text_config
+        self.config = vllm_config.model_config.hf_config
         self.embed_scale_per_layer = embed_scale_per_layer
         self.embed_tokens_per_layer = embed_tokens_per_layer
         self.per_layer_projection_norm = per_layer_projection_norm
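A note on the config swap above: downstream code in this module reads text-model fields (e.g. `num_hidden_layers`) from `self.config`, so storing `hf_config` directly assumes the config wired in here already exposes those fields rather than nesting them under `text_config`. A hypothetical illustration of the two config shapes involved (names are placeholders):

```python
from types import SimpleNamespace

# Hypothetical config shapes; the field name mirrors usage elsewhere in
# this file.
mm_config = SimpleNamespace(text_config=SimpleNamespace(num_hidden_layers=30))
text_config = SimpleNamespace(num_hidden_layers=30)

def resolve(hf_config):
    # Old code always dereferenced .text_config; the new code stores
    # hf_config as-is, assuming it already carries the text-model fields.
    return getattr(hf_config, "text_config", hf_config)

assert resolve(mm_config).num_hidden_layers == 30
assert resolve(text_config).num_hidden_layers == 30
```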
@@ -590,13 +590,9 @@ def get_per_layer_input_embeddings(
 
     def get_per_layer_inputs(
         self,
-        input_ids: torch.Tensor,
         hidden_states_0: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
-        per_layer_inputs = per_layer_inputs.reshape(
-            -1, self.config.num_hidden_layers,
-            self.config.hidden_size_per_layer_input)
         per_layer_projection = self.per_layer_model_projection(hidden_states_0)
         per_layer_projection = per_layer_projection.reshape(
             *hidden_states_0.shape[:-1],
@@ -605,8 +601,12 @@ def get_per_layer_inputs(
         )
         per_layer_projection = self.per_layer_projection_norm(
             per_layer_projection)
-        per_layer_inputs = per_layer_projection + per_layer_inputs
-        per_layer_inputs *= self.per_layer_input_scale
+        if per_layer_inputs is not None:
+            # Profiling run does not compute per_layer_inputs
+            per_layer_inputs = per_layer_projection + per_layer_inputs
+            per_layer_inputs *= self.per_layer_input_scale
+        else:
+            per_layer_inputs = per_layer_projection
         return per_layer_inputs
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
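The new `Optional` parameter handles vLLM's memory-profiling dry run, which, per the added comment, does not compute `per_layer_inputs`; in that case only the projection of `hidden_states_0` is used. A minimal standalone sketch of the same control flow, with illustrative sizes and a stand-in scale constant (not Gemma 3n's real values):

```python
import torch

# Illustrative sizes only (not Gemma 3n's real dimensions).
num_tokens, hidden_size = 4, 8
num_layers, per_layer_dim = 3, 5

projection = torch.nn.Linear(hidden_size, num_layers * per_layer_dim)
input_scale = 2.0 ** -0.5  # stand-in for self.per_layer_input_scale

def get_per_layer_inputs(hidden_states_0, per_layer_inputs=None):
    # Project hidden states down to one small vector per decoder layer.
    proj = projection(hidden_states_0).reshape(
        num_tokens, num_layers, per_layer_dim)
    if per_layer_inputs is not None:
        # Normal run: combine the projection with precomputed embeddings.
        return (proj + per_layer_inputs) * input_scale
    # Profiling run: no embeddings were computed, use the projection alone.
    return proj

h0 = torch.randn(num_tokens, hidden_size)
print(get_per_layer_inputs(h0).shape)  # profiling path: [4, 3, 5]
emb = torch.randn(num_tokens, num_layers, per_layer_dim)
print(get_per_layer_inputs(h0, emb).shape)  # normal path: [4, 3, 5]
```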
@@ -632,15 +632,16 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if inputs_embeds is not None:
             hidden_states_0 = inputs_embeds
         else:
             hidden_states_0 = self.get_input_embeddings(input_ids)
 
-        per_layer_inputs = self.get_per_layer_inputs(input_ids,
-                                                     hidden_states_0)
+        adjusted_per_layer_inputs = self.get_per_layer_inputs(
+            hidden_states_0, per_layer_inputs)
         hidden_states = self.altup_embed(hidden_states_0)
 
         # [altnum_inputs, num_tokens, hidden_size]
@@ -652,14 +653,14 @@ def forward(
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                per_layer_input=per_layer_inputs[:, layer_idx, :],
+                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
                 **kwargs,
             )
 
         # [num_tokens, hidden_size, altnum_inputs]
         hidden_states = hidden_states.permute(1, 2, 0)
 
-        return hidden_states, per_layer_inputs
+        return hidden_states, adjusted_per_layer_inputs
 
 
 # This enables torch.compile if --kv-sharing-fast-prefill passed
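The rename separates the raw `per_layer_inputs` argument from the `adjusted_per_layer_inputs` produced by `get_per_layer_inputs`; each decoder layer slices out its own row by `layer_idx`, and the adjusted tensor is also returned for reuse by the caller. A small shape sketch under assumed sizes:

```python
import torch

num_tokens, num_layers, per_layer_dim = 4, 3, 5  # assumed sizes
adjusted_per_layer_inputs = torch.randn(num_tokens, num_layers, per_layer_dim)

for layer_idx in range(num_layers):
    # Each decoder layer consumes only its own slice of the adjusted tensor.
    per_layer_input = adjusted_per_layer_inputs[:, layer_idx, :]
    assert per_layer_input.shape == (num_tokens, per_layer_dim)
```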
@@ -853,6 +854,7 @@ def fast_prefill_forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         logits_indices_padded, num_logits_indices = None, None
@@ -873,13 +875,14 @@ def fast_prefill_forward(
         # Copy inputs for cudagraph
         batch_size = positions.size(0)
         self.positions[:batch_size].copy_(positions)
-        # input_ids and inputs_embeds are allocated in model runner
-        self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
-            input_ids=input_ids,
-            positions=self.positions[:batch_size],
-            inputs_embeds=inputs_embeds,
-            **kwargs,
-        )
+        self_decoder_hidden_states, per_layer_inputs_adjusted = \
+            self.self_decoder(
+                input_ids=input_ids,
+                positions=self.positions[:batch_size],
+                inputs_embeds=inputs_embeds,
+                per_layer_inputs=per_layer_inputs,
+                **kwargs,
+            )
 
         if logits_indices_padded is None:
             logits_indices_padded = torch.arange(
@@ -903,7 +906,7 @@ def fast_prefill_forward(
         self.hidden_states[:num_padded_logits_indices].copy_(
             self_decoder_hidden_states[logits_indices_padded])
         self.per_layer_inputs[:num_padded_logits_indices].copy_(
-            per_layer_inputs[logits_indices_padded])
+            per_layer_inputs_adjusted[logits_indices_padded])
         cross_decoder_hidden_states = self.cross_decoder(
             positions=self.positions[:num_padded_logits_indices],
             hidden_states=self.hidden_states[:num_padded_logits_indices],
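The copies into `self.hidden_states` and `self.per_layer_inputs` exist because CUDA graph replay requires stable tensor addresses, so gathered rows are staged into preallocated buffers; the diff only swaps the copy source to the renamed `per_layer_inputs_adjusted`. A minimal sketch of that persistent-buffer pattern (capacity and width are illustrative):

```python
import torch

max_tokens, hidden_size = 16, 8  # illustrative buffer capacity and width

# Allocated once up front; CUDA graph replay needs stable addresses.
persistent_buf = torch.zeros(max_tokens, hidden_size)

def stage_for_graph(src: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    n = indices.numel()
    # Copy the gathered rows into the persistent buffer in place.
    persistent_buf[:n].copy_(src[indices])
    # Graph-captured code then reads a view of the buffer, not `src`.
    return persistent_buf[:n]

src = torch.randn(10, hidden_size)
out = stage_for_graph(src, torch.tensor([0, 3, 7]))
print(out.shape)  # torch.Size([3, 8])
```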
@@ -926,12 +929,14 @@ def normal_forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states, per_layer_inputs = self.self_decoder(
             input_ids=input_ids,
             positions=positions,
             inputs_embeds=inputs_embeds,
+            per_layer_inputs=per_layer_inputs,
             **kwargs,
         )
         hidden_states = self.cross_decoder(
@@ -966,25 +971,25 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        # Per layer inputs.
-        if input_ids is None:
-            raise ValueError("Passing None for input ids is not supported.")
-
         if self.fast_prefill_enabled:
             hidden_states = self.fast_prefill_forward(
                 input_ids,
                 positions,
                 inputs_embeds,
+                per_layer_inputs,
                 **kwargs,
             )
         else:
             hidden_states = self.normal_forward(
                 input_ids,
                 positions,
                 inputs_embeds,
+                per_layer_inputs,
                 **kwargs,
             )
         hidden_states = self.altup_unembed(hidden_states)
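With `per_layer_inputs` now supplied by the caller (and `input_ids` allowed to be `None` when embeddings are precomputed), the hard `ValueError` check could be dropped. A toy stand-in for the new dispatch, not the real Gemma3n class:

```python
from typing import Optional
import torch

class TinyModel(torch.nn.Module):
    """Toy stand-in for the new dispatch; not the real Gemma3n class."""

    def __init__(self, fast_prefill_enabled: bool = False):
        super().__init__()
        self.fast_prefill_enabled = fast_prefill_enabled

    def normal_forward(self, input_ids, positions, inputs_embeds,
                       per_layer_inputs, **kwargs):
        return inputs_embeds  # placeholder for the self/cross decoder work

    def fast_prefill_forward(self, input_ids, positions, inputs_embeds,
                             per_layer_inputs, **kwargs):
        return inputs_embeds  # placeholder

    def forward(self, input_ids, positions, per_layer_inputs=None,
                intermediate_tensors=None, inputs_embeds=None, **kwargs):
        # input_ids may now be None when inputs_embeds are precomputed;
        # intermediate_tensors is accepted for API parity (unused here).
        fwd = (self.fast_prefill_forward if self.fast_prefill_enabled
               else self.normal_forward)
        return fwd(input_ids, positions, inputs_embeds, per_layer_inputs,
                   **kwargs)

model = TinyModel()
out = model(input_ids=None, positions=torch.arange(4),
            inputs_embeds=torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 8])
```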