From 3dc9817c432e4a2ad9a693f8dcfff4e89aa1c75c Mon Sep 17 00:00:00 2001 From: w-yyh <1799470752@qq.com> Date: Thu, 27 Nov 2025 15:39:41 +0800 Subject: [PATCH 1/2] add --- .../transformers/gemma3_text/modeling.py | 52 ++++++------- paddleformers/transformers/llama/modeling.py | 59 +++++---------- paddleformers/transformers/phi3/modeling.py | 52 +++++-------- .../transformers/qwen2_5_vl/modeling.py | 73 ++++++++----------- paddleformers/transformers/qwen3/modeling.py | 2 +- 5 files changed, 91 insertions(+), 147 deletions(-) diff --git a/paddleformers/transformers/gemma3_text/modeling.py b/paddleformers/transformers/gemma3_text/modeling.py index 2d544fa43a..008bef8196 100644 --- a/paddleformers/transformers/gemma3_text/modeling.py +++ b/paddleformers/transformers/gemma3_text/modeling.py @@ -28,6 +28,7 @@ from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger from ..activations import ACT2FN +from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_masks_and_row_indices from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel @@ -274,7 +275,7 @@ def forward( hidden_states: paddle.Tensor, position_embeddings: Tuple[paddle.Tensor, paddle.Tensor], attention_mask: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, position_ids: Optional[Tuple[paddle.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, @@ -299,10 +300,8 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: - key_states = paddle.concat([past_key_value[0], key_states], axis=2) - value_states = paddle.concat([past_key_value[1], value_states], axis=2) - past_key_value = (key_states, value_states) if use_cache else None + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) if attn_mask_startend_row_indices is None and attention_mask is None: self.attn_implementation = "sdpa" @@ -327,7 +326,7 @@ def forward( attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Gemma3DecoderLayer(nn.Layer): @@ -349,7 +348,7 @@ def forward( hidden_states: paddle.Tensor, position_embeddings: Tuple[paddle.Tensor, paddle.Tensor], attention_mask: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, position_ids: Optional[paddle.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -360,11 +359,11 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, position_ids=position_ids, output_attentions=output_attentions, use_cache=use_cache, @@ -384,8 +383,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) if type(outputs) is tuple and len(outputs) == 1: outputs = outputs[0] @@ -551,7 +548,7 @@ def 
recompute_training( hidden_states: paddle.Tensor, position_ids: Optional[paddle.Tensor], attention_mask: paddle.Tensor, - past_key_value: paddle.Tensor, + past_key_values: Cache, output_attentions: bool, use_cache: bool, position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, @@ -568,7 +565,7 @@ def custom_forward(*inputs): hidden_states, position_embeddings, attention_mask, - past_key_value, + past_key_values, position_ids, output_attentions, use_cache, @@ -583,7 +580,7 @@ def forward( attention_mask: Optional[paddle.Tensor] = None, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.LongTensor] = None, - past_key_values: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -613,11 +610,9 @@ def forward( # [bs, seq_len, dim] inputs_embeds = self.embed_tokens(input_ids) - cache_length = 0 - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - else: - cache_length = past_key_values[0][0].shape[-2] + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if self.sequence_parallel: # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] @@ -660,12 +655,10 @@ def forward( hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: @@ -674,7 +667,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_value, + past_key_values=past_key_values, position_ids=position_ids, output_attentions=output_attentions, use_cache=use_cache, @@ -687,7 +680,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_value, + past_key_values=past_key_values, position_ids=position_ids, output_attentions=output_attentions, use_cache=use_cache, @@ -704,24 +697,21 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - # Final Norm hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - # Return outputs if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -812,7 +802,7 @@ def forward( input_ids: Optional[paddle.LongTensor] = None, attention_mask: 
Optional[paddle.Tensor] = None, position_ids: Optional[paddle.LongTensor] = None, - past_key_values: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.FloatTensor] = None, labels: Optional[paddle.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index c3a4b33dc6..65faf015bf 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -29,6 +29,7 @@ from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger +from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_masks_and_row_indices from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model @@ -144,7 +145,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: paddle.Tensor, - past_key_value: list[paddle.Tensor] | None = None, + past_key_values: Cache | None = None, attention_mask: paddle.Tensor | None = None, attn_mask_startend_row_indices: paddle.Tensor | None = None, position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None, @@ -166,16 +167,14 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - key_states = paddle.concat([past_key_value[0], key_states], axis=2) - value_states = paddle.concat([past_key_value[1], value_states], axis=2) - past_key_value = [key_states, value_states] if use_cache else None + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] if self.config._attn_implementation != "sdpa": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( + attn_output, attn_weights = attention_interface( self, query=query_states, key=key_states, @@ -188,7 +187,7 @@ def forward( if self.config.sequence_parallel: attn_output = attn_output.reshape([-1, attn_output.shape[-1]]) attn_output = self.o_proj(attn_output) - return attn_output, past_key_value + return attn_output, attn_weights class LlamaDecoderLayer(nn.Layer): @@ -222,22 +221,17 @@ def forward( attn_mask_startend_row_indices: paddle.Tensor | None = None, position_ids: paddle.Tensor | None = None, position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None, - past_key_value: list[paddle.Tensor] | None = None, + past_key_values: Cache | None = None, use_cache: bool = False, - ) -> ( - tuple[paddle.Tensor] - | tuple[paddle.Tensor, paddle.Tensor] - | tuple[paddle.Tensor, list[paddle.Tensor]] - | tuple[paddle.Tensor, paddle.Tensor, list[paddle.Tensor]] - ): + ) -> (tuple[paddle.Tensor] | tuple[paddle.Tensor, paddle.Tensor]): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, current_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, position_embeddings=position_embeddings, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = residual + hidden_states @@ -248,9 +242,6 @@ def forward( 
hidden_states = residual + hidden_states outputs = (hidden_states,) - if use_cache: - outputs += (current_key_value,) - # for pipeline parallel if len(outputs) == 1 and isinstance(outputs, tuple): outputs = outputs[0] @@ -510,7 +501,7 @@ def forward( input_ids: paddle.Tensor | None = None, attention_mask: paddle.Tensor | None = None, position_ids: paddle.Tensor | None = None, - past_key_values: tuple[list[paddle.Tensor] | None] | None = None, + past_key_values: Cache | None = None, inputs_embeds: paddle.Tensor | None = None, attn_mask_startend_row_indices: paddle.Tensor | None = None, use_cache: bool | None = None, @@ -534,12 +525,9 @@ def forward( inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) inputs_embeds = ScatterOp.apply(inputs_embeds) - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - kv_seq_len = 0 - else: - assert past_key_values[0] is not None, "past_key_values[0] should not be None if provided" - kv_seq_len = past_key_values[0][0].shape[2] + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + kv_seq_len = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = ( @@ -563,11 +551,9 @@ def forward( all_hidden_states = [] if output_hidden_states else None hidden_states = inputs_embeds - next_key_values = [] if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states.append(hidden_states) - past_key_value: list[paddle.Tensor] | None = past_key_values[idx] # type: ignore[index] has_gradient = not hidden_states.stop_gradient if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: layer_outputs = self.recompute_training( @@ -577,7 +563,7 @@ def forward( attn_mask_startend_row_indices, position_ids, position_embeddings, - past_key_value, + past_key_values, use_cache, ) else: @@ -587,13 +573,11 @@ def forward( attn_mask_startend_row_indices=attn_mask_startend_row_indices, position_ids=position_ids, position_embeddings=position_embeddings, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple | list) else layer_outputs - if use_cache: - next_key_values.append(layer_outputs[1]) hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -602,20 +586,17 @@ def forward( ) all_hidden_states = tuple(all_hidden_states) if all_hidden_states else None - next_key_values = tuple(next_key_values) if next_key_values else None if not return_dict: outputs = [] outputs.append(hidden_states) - if use_cache: - outputs.append(next_key_values) if output_hidden_states: outputs.append(all_hidden_states) return tuple(outputs) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_key_values, + past_key_values=past_key_values, hidden_states=all_hidden_states, ) @@ -628,7 +609,7 @@ def recompute_training( attn_mask_startend_row_indices: paddle.Tensor | None, position_ids: paddle.Tensor, position_embeddings: paddle.Tensor, - past_key_value: list[paddle.Tensor] | None, + past_key_values: Cache | None, use_cache: bool, ): hidden_states = recompute( @@ -638,7 +619,7 @@ def recompute_training( attn_mask_startend_row_indices, position_ids, position_embeddings, - past_key_value, + past_key_values, use_cache, ) return hidden_states @@ -665,7 +646,7 @@ def forward( labels: paddle.Tensor | None = None, loss_mask: paddle.Tensor | None = 
None, use_cache: bool = False, - past_key_values: tuple[list[paddle.Tensor]] | None = None, + past_key_values: Cache | None = None, output_hidden_states: bool | None = False, return_dict: bool = False, # true when decode, false when pretrain & eval **kwargs, diff --git a/paddleformers/transformers/phi3/modeling.py b/paddleformers/transformers/phi3/modeling.py index 576fc0c8c9..eda82aed42 100644 --- a/paddleformers/transformers/phi3/modeling.py +++ b/paddleformers/transformers/phi3/modeling.py @@ -14,7 +14,7 @@ """Paddle Phi3 model.""" from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import paddle from paddle import nn @@ -30,6 +30,7 @@ from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger +from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_masks_and_row_indices from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model @@ -110,7 +111,7 @@ def forward( hidden_states: paddle.Tensor, position_embeddings: tuple[paddle.Tensor, paddle.Tensor], attention_mask: Optional[paddle.Tensor], - past_key_value: Optional[paddle.Tensor] = None, + past_key_values: Optional[Cache] = None, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, batch_size: Optional[int] = None, use_cache: bool = False, @@ -148,10 +149,8 @@ def forward( value_states = value_states.transpose(1, 2) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - key_states = paddle.cat([past_key_value[0], key_states], axis=2) - value_states = paddle.cat([past_key_value[1], value_states], axis=2) - past_key_value = (key_states, value_states) if use_cache else None + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -174,7 +173,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Phi3DecoderLayer(nn.Layer): @@ -215,7 +214,7 @@ def forward( hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -224,12 +223,12 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, position_ids=position_ids, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, attn_mask_startend_row_indices=attn_mask_startend_row_indices, use_cache=use_cache, output_attentions=output_attentions, @@ -244,8 +243,6 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) if type(outputs) is 
tuple and len(outputs) == 1: outputs = outputs[0] return outputs @@ -416,7 +413,7 @@ def recompute_training_full( position_ids: Optional[paddle.Tensor], attention_mask: paddle.Tensor, output_attentions: bool, - past_key_value: paddle.Tensor, + past_key_values: Cache, use_cache: bool, position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, attn_mask_startend_row_indices=None, @@ -432,7 +429,7 @@ def custom_forward(*inputs): hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, use_cache, position_embeddings, output_attentions, @@ -446,7 +443,7 @@ def forward( input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_values: Optional[List[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -482,11 +479,9 @@ def forward( # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) - cache_length = 0 - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - else: - cache_length = past_key_values[0][0].shape[1] + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) @@ -512,12 +507,10 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: layer_outputs = self.recompute_training_full( @@ -529,7 +522,7 @@ def forward( ], position_ids=position_ids, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, position_embeddings=position_embeddings, ) @@ -543,7 +536,7 @@ def forward( ], position_ids=position_ids, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, position_embeddings=position_embeddings, ) @@ -556,20 +549,15 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache] if v is not None) + return tuple(v for v in [hidden_states, past_key_values] if v is not None) - return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache) + return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values) class Phi3ForCausalLM(Phi3PreTrainedModel): @@ -652,7 +640,7 @@ def forward( input_ids: Optional[paddle.LongTensor] = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.LongTensor] = None, - 
past_key_values: Optional[List[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.FloatTensor] = None, labels: Optional[paddle.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/paddleformers/transformers/qwen2_5_vl/modeling.py b/paddleformers/transformers/qwen2_5_vl/modeling.py index 263c812414..2ab2b2df53 100644 --- a/paddleformers/transformers/qwen2_5_vl/modeling.py +++ b/paddleformers/transformers/qwen2_5_vl/modeling.py @@ -37,6 +37,7 @@ from ...nn.lm_head import LMHead as GeneralLMHead from ...nn.mlp import MLP from ...nn.norm import Norm as GeneralNorm +from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_masks_and_row_indices from ..model_outputs import BaseModelOutputWithPast, ModelOutput from ..model_utils import PretrainedModel @@ -673,10 +674,7 @@ def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> padd class Qwen2_5_VLModelOutputWithPast(ModelOutput): """ Args: - past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - + past_key_values (`Cache)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -693,7 +691,7 @@ class Qwen2_5_VLModelOutputWithPast(ModelOutput): """ last_hidden_state: Optional[paddle.Tensor] = None - past_key_values: Optional[Tuple[paddle.Tensor]] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None rope_deltas: Optional[paddle.Tensor] = None @@ -872,7 +870,7 @@ def forward( hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, # default true position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, @@ -900,10 +898,8 @@ def forward( ) # [bs, num_head, seq_len, head_dim] - if past_key_value is not None: - key_states = paddle.cat([past_key_value[0], key_states], axis=2) - value_states = paddle.cat([past_key_value[1], value_states], axis=2) - past_key_value = (key_states, value_states) if use_cache else None + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -925,7 +921,7 @@ def forward( attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Qwen2_5_VLDecoderLayer(nn.Layer): @@ -967,7 +963,7 @@ def forward( hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, 
output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_embeddings: Optional[tuple[paddle.Tensor, paddle.Tensor]] = None, @@ -979,14 +975,13 @@ def forward( hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`paddle.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): - If set to `True`, `past_key_value` key value states are returned and can be used to speed up decoding - (see `past_key_value`). - past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). position_embeddings (`Tuple[paddle.Tensor, paddle.Tensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1000,11 +995,11 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -1024,9 +1019,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1069,7 +1061,7 @@ def recompute_training_full( attention_mask: Tensor, position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]], position_ids: Optional[paddle.Tensor], - past_key_value: Optional[Tuple[paddle.Tensor]], + past_key_values: Optional[Cache], output_attentions: bool, use_cache: bool, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, @@ -1087,7 +1079,7 @@ def custom_forward(*inputs): attention_mask, position_embeddings, position_ids, - past_key_value, + past_key_values, output_attentions, use_cache, attn_mask_startend_row_indices, @@ -1101,7 +1093,7 @@ def forward( input_ids: paddle.Tensor = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_values: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1135,11 +1127,9 @@ def forward( ) use_cache = False - cache_length = 0 - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - else: - cache_length = past_key_values[0][0].shape[2] + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1190,13 +1180,11 @@ def forward( # decoder layers all_hidden_states 
= () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_cache = () if use_cache else None for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: layer_outputs = self.recompute_training_full( @@ -1205,7 +1193,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_embeddings=position_embeddings, position_ids=text_position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, attn_mask_startend_row_indices=attn_mask_startend_row_indices_mapping[ @@ -1220,7 +1208,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_embeddings=position_embeddings, position_ids=text_position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, attn_mask_startend_row_indices=attn_mask_startend_row_indices_mapping[ @@ -1232,8 +1220,6 @@ def forward( hidden_states = layer_outputs[0] - next_cache = next_cache + (layer_outputs[-1],) if use_cache else None - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1244,10 +1230,12 @@ def forward( all_hidden_states += (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1528,7 +1516,7 @@ def forward( input_ids: paddle.Tensor = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_values: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1633,10 +1621,7 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - + past_key_values (`Cache)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
rope_deltas (`paddle.Tensor` of shape `(batch_size, )`, *optional*): @@ -1645,7 +1630,7 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): loss: Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None - past_key_values: Optional[Tuple[paddle.Tensor]] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[paddle.Tensor]] = None attentions: Optional[tuple[paddle.Tensor]] = None rope_deltas: Optional[paddle.Tensor] = None @@ -1697,7 +1682,7 @@ def forward( input_ids: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, position_ids: Optional[paddle.Tensor] = None, - past_key_values: Optional[Tuple[paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[paddle.Tensor] = None, labels: Optional[paddle.Tensor] = None, use_cache: Optional[bool] = None, diff --git a/paddleformers/transformers/qwen3/modeling.py b/paddleformers/transformers/qwen3/modeling.py index 3d279ddec2..a09771b3de 100644 --- a/paddleformers/transformers/qwen3/modeling.py +++ b/paddleformers/transformers/qwen3/modeling.py @@ -551,7 +551,7 @@ def recompute_training_full( layer_module: nn.Layer, hidden_states: Tensor, attention_mask: Tensor, - past_key_values: Tensor, + past_key_values: Cache, use_cache: bool, position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, attn_mask_startend_row_indices=None, From f489635999e2bcb3658955ee941e7675ba75fd29 Mon Sep 17 00:00:00 2001 From: w-yyh <1799470752@qq.com> Date: Mon, 1 Dec 2025 15:45:17 +0800 Subject: [PATCH 2/2] fix --- paddleformers/peft/prefix/prefix_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddleformers/peft/prefix/prefix_model.py b/paddleformers/peft/prefix/prefix_model.py index 538779a906..ce55ed0624 100644 --- a/paddleformers/peft/prefix/prefix_model.py +++ b/paddleformers/peft/prefix/prefix_model.py @@ -24,6 +24,7 @@ import paddle.nn as nn from paddle.distributed import fleet +from ...transformers.cache_utils import DynamicCache from ...transformers.model_utils import ( _add_variant, _load_state_dict_into_model, @@ -273,8 +274,13 @@ def _get_past_key_values(self, batch_size): if self.postprocess_past_key_value is not None: past_key_values = self.postprocess_past_key_value(past_key_values) - - return past_key_values + past_key_values_cache = DynamicCache() + if isinstance(past_key_values, tuple): + for layer_idx, (key_state, value_state) in enumerate(past_key_values): + past_key_values_cache.update(key_state, value_state, layer_idx) + else: + return past_key_values + return past_key_values_cache def train(self): self.training = True
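
A minimal usage sketch of the cache interface this series migrates to, assuming only the `DynamicCache` behaviour exercised in the diffs above (`DynamicCache()` construction, `get_seq_length()`, and `update(key_states, value_states, layer_idx)` returning the accumulated per-layer tensors); the layer count and tensor shapes below are illustrative, not taken from any of the touched models:

    import paddle
    from paddleformers.transformers.cache_utils import DynamicCache

    cache = DynamicCache()  # one cache object shared by every decoder layer
    num_layers, bsz, num_heads, head_dim = 2, 1, 4, 8

    for step in range(3):  # three single-token decoding steps
        cache_length = cache.get_seq_length()  # tokens already cached (0 on the first step)
        for layer_idx in range(num_layers):
            k = paddle.randn([bsz, num_heads, 1, head_dim])
            v = paddle.randn([bsz, num_heads, 1, head_dim])
            # update() appends along the sequence axis and hands back the full
            # per-layer key/value tensors, replacing the old pattern of
            # paddle.concat([past_key_value[0], key_states], axis=2) plus a
            # per-layer (key, value) tuple threaded through every forward call.
            k_all, v_all = cache.update(k, v, layer_idx)
        assert k_all.shape[2] == cache_length + 1

Because the cache is mutated in place, the decoder layers no longer need to return `present_key_value`, and the model hands the same `past_key_values` object back to the caller; that is what the removed `next_decoder_cache` bookkeeping in each modeling file previously amounted to.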