Commit 271f14c (parent: dd98cce)

Fix rebase

Signed-off-by: Yong Hoon Shin <[email protected]>

File tree: 5 files changed, +39 −33 lines

vllm/attention/layers/chunked_local_attention.py

Lines changed: 2 additions & 2 deletions
@@ -40,15 +40,15 @@ def __init__(self,
                                                        kv_cache_dtype,
                                                        block_size)
 
-            prefix = \
+            backend_prefix = \
                 f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
 
             def build_preprocess_fn(cm: CommonAttentionMetadata):
                 return make_local_attention_virtual_batches(
                     attention_chunk_size, cm, block_size)
 
             attn_backend = create_custom_attention_backend(
-                prefix, underlying_attn_backend, build_preprocess_fn)
+                backend_prefix, underlying_attn_backend, build_preprocess_fn)
         else:
             # in v0 the local attention is handled inside the backends
             attn_backend = None
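
The rename appears intended to keep the backend name string separate from the layer's own `prefix` constructor argument, so the module-path prefix is not overwritten before it is used elsewhere (my reading of the change; the commit message does not say). A minimal sketch of that pattern, with made-up class and function names rather than the real vLLM ones:

# Minimal sketch (hypothetical names, not the real vLLM classes) of why the
# local variable is renamed: the layer already receives a module-path
# `prefix` argument, and reusing that name for the backend key would
# silently overwrite it.

def make_backend_prefix(attention_chunk_size: int, block_size: int) -> str:
    # Encode the parameters that make the wrapped backend unique, so layers
    # with different chunk/block sizes never share a cached custom backend.
    return f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

class ChunkedLocalAttentionSketch:
    def __init__(self, prefix: str, attention_chunk_size: int, block_size: int):
        backend_prefix = make_backend_prefix(attention_chunk_size, block_size)
        self.backend_name = backend_prefix + "Backend"
        self.prefix = prefix  # module path stays intact

layer = ChunkedLocalAttentionSketch("model.layers.0.self_attn",
                                    attention_chunk_size=1024, block_size=16)
print(layer.prefix, layer.backend_name)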

vllm/config/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -3683,6 +3683,12 @@ def __post_init__(self):
             # local attention.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
 
+        if self.cache_config.kv_sharing_fast_prefill:
+            # There is an IMA issue currently when using fast prefill with
+            # hybrid kv cache manager (e.g. interleaved sliding window)
+            # TODO(sarckk): investigate and fix
+            self.scheduler_config.disable_hybrid_kv_cache_manager = True
+
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
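
The new block follows the surrounding `__post_init__` pattern: when a feature combination is known to be broken (here, fast prefill with the hybrid KV cache manager, pending the IMA, presumably illegal-memory-access, issue noted in the TODO), the safe setting is forced rather than raising an error. A stripped-down sketch of that pattern with hypothetical config classes, not the real vLLM ones:

# Hypothetical, stripped-down sketch of the guard pattern added above:
# two features known to conflict are resolved in __post_init__ by forcing
# the safe configuration instead of failing at runtime.
from dataclasses import dataclass, field

@dataclass
class CacheConfigSketch:
    kv_sharing_fast_prefill: bool = False

@dataclass
class SchedulerConfigSketch:
    disable_hybrid_kv_cache_manager: bool = False

@dataclass
class VllmConfigSketch:
    cache_config: CacheConfigSketch = field(default_factory=CacheConfigSketch)
    scheduler_config: SchedulerConfigSketch = field(
        default_factory=SchedulerConfigSketch)

    def __post_init__(self):
        if self.cache_config.kv_sharing_fast_prefill:
            # Fast prefill currently conflicts with the hybrid KV cache
            # manager, so the manager is disabled for now.
            self.scheduler_config.disable_hybrid_kv_cache_manager = True

cfg = VllmConfigSketch(
    cache_config=CacheConfigSketch(kv_sharing_fast_prefill=True))
assert cfg.scheduler_config.disable_hybrid_kv_cache_manager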

vllm/config/cache.py

Lines changed: 1 addition & 6 deletions
@@ -151,7 +151,7 @@ def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
         return {key: str(value) for key, value in self.__dict__.items()}
-
+
     def _verify_kv_sharing_fast_prefill(self) -> None:
         if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
             raise NotImplementedError(
@@ -169,11 +169,6 @@ def _verify_args(self) -> Self:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
-        if self.kv_sharing_fast_prefill:
-            logger.warning_once(
-                "--kv-sharing-fast-prefill is currently work in progress "
-                "and not functional yet (i.e. no prefill savings)")
-
         return self
 
     def _verify_cache_dtype(self) -> None:

vllm/model_executor/models/gemma3n.py

Lines changed: 29 additions & 24 deletions
@@ -566,7 +566,7 @@ def __init__(
         self.decoder_layers = decoder_layers
         self.layer_idx_start = layer_idx_start
         self.per_layer_model_projection = per_layer_model_projection
-        self.config = vllm_config.model_config.hf_config.text_config
+        self.config = vllm_config.model_config.hf_config
         self.embed_scale_per_layer = embed_scale_per_layer
         self.embed_tokens_per_layer = embed_tokens_per_layer
         self.per_layer_projection_norm = per_layer_projection_norm
@@ -590,13 +590,9 @@ def get_per_layer_input_embeddings(
 
     def get_per_layer_inputs(
         self,
-        input_ids: torch.Tensor,
         hidden_states_0: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
-        per_layer_inputs = per_layer_inputs.reshape(
-            -1, self.config.num_hidden_layers,
-            self.config.hidden_size_per_layer_input)
         per_layer_projection = self.per_layer_model_projection(hidden_states_0)
         per_layer_projection = per_layer_projection.reshape(
             *hidden_states_0.shape[:-1],
@@ -605,8 +601,12 @@
         )
         per_layer_projection = self.per_layer_projection_norm(
             per_layer_projection)
-        per_layer_inputs = per_layer_projection + per_layer_inputs
-        per_layer_inputs *= self.per_layer_input_scale
+        if per_layer_inputs is not None:
+            # Profiling run does not compute per_layer_inputs
+            per_layer_inputs = per_layer_projection + per_layer_inputs
+            per_layer_inputs *= self.per_layer_input_scale
+        else:
+            per_layer_inputs = per_layer_projection
        return per_layer_inputs
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
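
For reference, the combine-or-fallback logic above in self-contained form (dummy dimensions; `proj`, `norm`, and the scale constant stand in for the real per-layer projection, its norm, and `per_layer_input_scale`):

# Self-contained sketch of the per-layer-input combination introduced above.
from typing import Optional
import torch

num_tokens, hidden, num_layers, hidden_per_layer = 4, 32, 3, 8
proj = torch.nn.Linear(hidden, num_layers * hidden_per_layer)
norm = torch.nn.LayerNorm(hidden_per_layer)
per_layer_input_scale = 2 ** -0.5

def combine(hidden_states_0: torch.Tensor,
            per_layer_inputs: Optional[torch.Tensor]) -> torch.Tensor:
    per_layer_projection = proj(hidden_states_0).reshape(
        num_tokens, num_layers, hidden_per_layer)
    per_layer_projection = norm(per_layer_projection)
    if per_layer_inputs is not None:
        # Normal run: add the precomputed per-layer token embeddings.
        return (per_layer_projection + per_layer_inputs) * per_layer_input_scale
    # Profiling run: no per-layer embeddings were computed upstream,
    # so only the projection is used.
    return per_layer_projection

h0 = torch.randn(num_tokens, hidden)
ple = torch.randn(num_tokens, num_layers, hidden_per_layer)
print(combine(h0, ple).shape, combine(h0, None).shape)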
@@ -632,15 +632,16 @@
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if inputs_embeds is not None:
             hidden_states_0 = inputs_embeds
         else:
             hidden_states_0 = self.get_input_embeddings(input_ids)
 
-        per_layer_inputs = self.get_per_layer_inputs(input_ids,
-                                                     hidden_states_0)
+        adjusted_per_layer_inputs = self.get_per_layer_inputs(
+            hidden_states_0, per_layer_inputs)
         hidden_states = self.altup_embed(hidden_states_0)
 
         # [altnum_inputs, num_tokens, hidden_size]
@@ -652,14 +653,14 @@
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                per_layer_input=per_layer_inputs[:, layer_idx, :],
+                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
                 **kwargs,
             )
 
         # [num_tokens, hidden_size, altnum_inputs]
         hidden_states = hidden_states.permute(1, 2, 0)
 
-        return hidden_states, per_layer_inputs
+        return hidden_states, adjusted_per_layer_inputs
 
 
 # This enables torch.compile if --kv-sharing-fast-prefill passed
@@ -853,6 +854,7 @@ def fast_prefill_forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         logits_indices_padded, num_logits_indices = None, None
@@ -873,13 +875,14 @@
         # Copy inputs for cudagraph
         batch_size = positions.size(0)
         self.positions[:batch_size].copy_(positions)
-        # input_ids and inputs_embeds are allocated in model runner
-        self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
-            input_ids=input_ids,
-            positions=self.positions[:batch_size],
-            inputs_embeds=inputs_embeds,
-            **kwargs,
-        )
+        self_decoder_hidden_states, per_layer_inputs_adjusted = \
+            self.self_decoder(
+                input_ids=input_ids,
+                positions=self.positions[:batch_size],
+                inputs_embeds=inputs_embeds,
+                per_layer_inputs=per_layer_inputs,
+                **kwargs,
+            )
 
         if logits_indices_padded is None:
             logits_indices_padded = torch.arange(
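
The context lines above copy variable-size inputs into preallocated buffers (`self.positions[:batch_size].copy_(...)`) before the decoder runs, the usual arrangement for CUDA graph replay, where input addresses must stay fixed between capture and replay. A stripped-down sketch of that buffer pattern, with made-up sizes and no actual graph capture:

# Hypothetical sketch of the persistent-buffer pattern around the
# self_decoder call: fixed-capacity buffers are allocated once, and each
# step copies the live inputs into a prefix slice of those buffers so a
# captured graph can always read from the same memory addresses.
import torch

MAX_TOKENS, HIDDEN = 8192, 64
positions_buf = torch.zeros(MAX_TOKENS, dtype=torch.long)
hidden_buf = torch.zeros(MAX_TOKENS, HIDDEN)

def run_step(positions: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
    n = positions.size(0)
    positions_buf[:n].copy_(positions)  # in-place copy; buffer address is stable
    hidden_buf[:n].copy_(hidden)
    # The real code would now call the (possibly graph-captured) decoder on
    # positions_buf[:n] / hidden_buf[:n]; here we just return a slice.
    return hidden_buf[:n]

out = run_step(torch.arange(5), torch.randn(5, HIDDEN))
print(out.shape)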
@@ -903,7 +906,7 @@
         self.hidden_states[:num_padded_logits_indices].copy_(
             self_decoder_hidden_states[logits_indices_padded])
         self.per_layer_inputs[:num_padded_logits_indices].copy_(
-            per_layer_inputs[logits_indices_padded])
+            per_layer_inputs_adjusted[logits_indices_padded])
         cross_decoder_hidden_states = self.cross_decoder(
             positions=self.positions[:num_padded_logits_indices],
             hidden_states=self.hidden_states[:num_padded_logits_indices],
@@ -926,12 +929,14 @@ def normal_forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states, per_layer_inputs = self.self_decoder(
             input_ids=input_ids,
             positions=positions,
             inputs_embeds=inputs_embeds,
+            per_layer_inputs=per_layer_inputs,
             **kwargs,
         )
         hidden_states = self.cross_decoder(
@@ -966,25 +971,25 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        # Per layer inputs.
-        if input_ids is None:
-            raise ValueError("Passing None for input ids is not supported.")
-
         if self.fast_prefill_enabled:
             hidden_states = self.fast_prefill_forward(
                 input_ids,
                 positions,
                 inputs_embeds,
+                per_layer_inputs,
                 **kwargs,
             )
         else:
             hidden_states = self.normal_forward(
                 input_ids,
                 positions,
                 inputs_embeds,
+                per_layer_inputs,
                 **kwargs,
             )
         hidden_states = self.altup_unembed(hidden_states)

vllm/model_executor/models/gemma3n_mm.py

Lines changed: 1 addition & 1 deletion
@@ -624,7 +624,7 @@ def get_input_embeddings(
         # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
         # them here, as the model forward has only access to the input_embeds.
         if input_ids is not None:
-            per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
+            per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings(
                 input_ids)
             per_layer_inputs = per_layer_inputs.reshape(
                 -1, self.config.text_config.num_hidden_layers,
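
Together with the gemma3n.py changes above, the multimodal wrapper now computes the per-layer token embeddings once from `input_ids`, reshapes them to `[num_tokens, num_layers, hidden_per_layer]`, and hands them to the language model as `per_layer_inputs`. A toy version of that lookup-and-reshape step (hypothetical embedding table and sizes, not the real PLE module):

# Toy sketch of the caller-side step: look up per-layer embeddings from
# token ids, then reshape to [num_tokens, num_layers, hidden_per_layer]
# before passing them to the language model as `per_layer_inputs`.
import torch

vocab_size, num_layers, hidden_per_layer = 100, 3, 8
# Stand-in for get_per_layer_input_embeddings: one flat embedding per token
# covering all layers at once.
embed = torch.nn.Embedding(vocab_size, num_layers * hidden_per_layer)

input_ids = torch.tensor([1, 5, 7, 42])
per_layer_inputs = embed(input_ids).reshape(-1, num_layers, hidden_per_layer)
print(per_layer_inputs.shape)  # torch.Size([4, 3, 8])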
