Skip to content
Open

add #3033

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 21 additions & 31 deletions paddleformers/transformers/gemma3_text/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ...nn.pp_model import GeneralModelForCausalLMPipe
from ...utils.log import logger
from ..activations import ACT2FN
from ..cache_utils import Cache, DynamicCache
from ..masking_utils import create_causal_masks_and_row_indices
from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ..model_utils import PretrainedModel
Expand Down Expand Up @@ -257,7 +258,7 @@ def forward(
hidden_states: paddle.Tensor,
position_embeddings: Tuple[paddle.Tensor, paddle.Tensor],
attention_mask: Optional[paddle.Tensor] = None,
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
past_key_values: Optional[Cache] = None,
position_ids: Optional[Tuple[paddle.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
Expand Down Expand Up @@ -309,10 +310,8 @@ def forward(
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
past_key_value = (key_states, value_states) if use_cache else None
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

if attn_mask_startend_row_indices is None and attention_mask is None:
self.attn_implementation = "sdpa"
Expand All @@ -337,7 +336,7 @@ def forward(
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class Gemma3DecoderLayer(nn.Layer):
Expand All @@ -359,7 +358,7 @@ def forward(
hidden_states: paddle.Tensor,
position_embeddings: Tuple[paddle.Tensor, paddle.Tensor],
attention_mask: Optional[paddle.Tensor] = None,
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
past_key_values: Optional[Cache] = None,
position_ids: Optional[paddle.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
Expand All @@ -370,11 +369,11 @@ def forward(
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
past_key_value=past_key_value,
past_key_values=past_key_values,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=use_cache,
Expand All @@ -394,8 +393,6 @@ def forward(

if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if type(outputs) is tuple and len(outputs) == 1:
outputs = outputs[0]

Expand Down Expand Up @@ -608,7 +605,7 @@ def recompute_training(
hidden_states: paddle.Tensor,
position_ids: Optional[paddle.Tensor],
attention_mask: paddle.Tensor,
past_key_value: paddle.Tensor,
past_key_values: Cache,
output_attentions: bool,
use_cache: bool,
position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
Expand All @@ -625,7 +622,7 @@ def custom_forward(*inputs):
hidden_states,
position_embeddings,
attention_mask,
past_key_value,
past_key_values,
position_ids,
output_attentions,
use_cache,
Expand All @@ -640,7 +637,7 @@ def forward(
attention_mask: Optional[paddle.Tensor] = None,
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
position_ids: Optional[paddle.LongTensor] = None,
past_key_values: Optional[Tuple[paddle.Tensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[paddle.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -670,11 +667,9 @@ def forward(
# [bs, seq_len, dim]
inputs_embeds = self.embed_tokens(input_ids)

cache_length = 0
if past_key_values is None:
past_key_values = tuple([None] * len(self.layers))
else:
cache_length = past_key_values[0][0].shape[-2]
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0

if self.sequence_parallel:
# [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
Expand Down Expand Up @@ -717,12 +712,10 @@ def forward(
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

for idx, (decoder_layer) in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient
if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient:
Expand All @@ -731,7 +724,7 @@ def forward(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
past_key_value=past_key_value,
past_key_values=past_key_values,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=use_cache,
Expand All @@ -744,7 +737,7 @@ def forward(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
past_key_value=past_key_value,
past_key_values=past_key_values,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=use_cache,
Expand All @@ -761,24 +754,21 @@ def forward(
if output_attentions:
all_self_attns += (layer_outputs[1],)

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

# Final Norm
hidden_states = self.norm(hidden_states)

if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None

# Return outputs
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
Expand Down Expand Up @@ -869,7 +859,7 @@ def forward(
input_ids: Optional[paddle.LongTensor] = None,
attention_mask: Optional[paddle.Tensor] = None,
position_ids: Optional[paddle.LongTensor] = None,
past_key_values: Optional[Tuple[paddle.Tensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[paddle.FloatTensor] = None,
labels: Optional[paddle.LongTensor] = None,
use_cache: Optional[bool] = None,
Expand Down
59 changes: 20 additions & 39 deletions paddleformers/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ...nn.norm import Norm as GeneralNorm
from ...nn.pp_model import GeneralModelForCausalLMPipe
from ...utils.log import logger
from ..cache_utils import Cache, DynamicCache
from ..masking_utils import create_causal_masks_and_row_indices
from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ..model_utils import PretrainedModel, register_base_model
Expand Down Expand Up @@ -144,7 +145,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
def forward(
self,
hidden_states: paddle.Tensor,
past_key_value: list[paddle.Tensor] | None = None,
past_key_values: Cache | None = None,
attention_mask: paddle.Tensor | None = None,
attn_mask_startend_row_indices: paddle.Tensor | None = None,
position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
Expand All @@ -166,16 +167,14 @@ def forward(
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
past_key_value = [key_states, value_states] if use_cache else None
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
if self.config._attn_implementation != "sdpa":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, _ = attention_interface(
attn_output, attn_weights = attention_interface(
self,
query=query_states,
key=key_states,
Expand All @@ -188,7 +187,7 @@ def forward(
if self.config.sequence_parallel:
attn_output = attn_output.reshape([-1, attn_output.shape[-1]])
attn_output = self.o_proj(attn_output)
return attn_output, past_key_value
return attn_output, attn_weights


class LlamaDecoderLayer(nn.Layer):
Expand Down Expand Up @@ -222,22 +221,17 @@ def forward(
attn_mask_startend_row_indices: paddle.Tensor | None = None,
position_ids: paddle.Tensor | None = None,
position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
past_key_value: list[paddle.Tensor] | None = None,
past_key_values: Cache | None = None,
use_cache: bool = False,
) -> (
tuple[paddle.Tensor]
| tuple[paddle.Tensor, paddle.Tensor]
| tuple[paddle.Tensor, list[paddle.Tensor]]
| tuple[paddle.Tensor, paddle.Tensor, list[paddle.Tensor]]
):
) -> (tuple[paddle.Tensor] | tuple[paddle.Tensor, paddle.Tensor]):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, current_key_value = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
position_embeddings=position_embeddings,
past_key_value=past_key_value,
past_key_values=past_key_values,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
Expand All @@ -248,9 +242,6 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)

if use_cache:
outputs += (current_key_value,)

# for pipeline parallel
if len(outputs) == 1 and isinstance(outputs, tuple):
outputs = outputs[0]
Expand Down Expand Up @@ -510,7 +501,7 @@ def forward(
input_ids: paddle.Tensor | None = None,
attention_mask: paddle.Tensor | None = None,
position_ids: paddle.Tensor | None = None,
past_key_values: tuple[list[paddle.Tensor] | None] | None = None,
past_key_values: Cache | None = None,
inputs_embeds: paddle.Tensor | None = None,
attn_mask_startend_row_indices: paddle.Tensor | None = None,
use_cache: bool | None = None,
Expand All @@ -534,12 +525,9 @@ def forward(
inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
inputs_embeds = ScatterOp.apply(inputs_embeds)

if past_key_values is None:
past_key_values = tuple([None] * len(self.layers))
kv_seq_len = 0
else:
assert past_key_values[0] is not None, "past_key_values[0] should not be None if provided"
kv_seq_len = past_key_values[0][0].shape[2]
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
kv_seq_len = past_key_values.get_seq_length() if past_key_values is not None else 0

if position_ids is None:
position_ids = (
Expand All @@ -563,11 +551,9 @@ def forward(
all_hidden_states = [] if output_hidden_states else None

hidden_states = inputs_embeds
next_key_values = [] if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)
past_key_value: list[paddle.Tensor] | None = past_key_values[idx] # type: ignore[index]
has_gradient = not hidden_states.stop_gradient
if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient:
layer_outputs = self.recompute_training(
Expand All @@ -577,7 +563,7 @@ def forward(
attn_mask_startend_row_indices,
position_ids,
position_embeddings,
past_key_value,
past_key_values,
use_cache,
)
else:
Expand All @@ -587,13 +573,11 @@ def forward(
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
position_ids=position_ids,
position_embeddings=position_embeddings,
past_key_value=past_key_value,
past_key_values=past_key_values,
use_cache=use_cache,
)

hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple | list) else layer_outputs
if use_cache:
next_key_values.append(layer_outputs[1])

hidden_states = self.norm(hidden_states)
if output_hidden_states:
Expand All @@ -602,20 +586,17 @@ def forward(
)

all_hidden_states = tuple(all_hidden_states) if all_hidden_states else None
next_key_values = tuple(next_key_values) if next_key_values else None

if not return_dict:
outputs = []
outputs.append(hidden_states)
if use_cache:
outputs.append(next_key_values)
if output_hidden_states:
outputs.append(all_hidden_states)
return tuple(outputs)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_key_values,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
)

Expand All @@ -628,7 +609,7 @@ def recompute_training(
attn_mask_startend_row_indices: paddle.Tensor | None,
position_ids: paddle.Tensor,
position_embeddings: paddle.Tensor,
past_key_value: list[paddle.Tensor] | None,
past_key_values: Cache | None,
use_cache: bool,
):
hidden_states = recompute(
Expand All @@ -638,7 +619,7 @@ def recompute_training(
attn_mask_startend_row_indices,
position_ids,
position_embeddings,
past_key_value,
past_key_values,
use_cache,
)
return hidden_states
Expand All @@ -665,7 +646,7 @@ def forward(
labels: paddle.Tensor | None = None,
loss_mask: paddle.Tensor | None = None,
use_cache: bool = False,
past_key_values: tuple[list[paddle.Tensor]] | None = None,
past_key_values: Cache | None = None,
output_hidden_states: bool | None = False,
return_dict: bool = False, # true when decode, false when pretrain & eval
**kwargs,
Expand Down
Loading
Loading