
Commit 2b7d575

Merge branch 'quic:main' into ft
2 parents bcb5cc2 + c17be77

File tree

2 files changed: +78 -0 lines changed

QEfficient/transformers/models/granite/modeling_granite.py

Lines changed: 75 additions & 0 deletions
@@ -17,6 +17,7 @@
 from transformers.models.granite.modeling_granite import (
     GraniteAttention,
     GraniteConfig,
+    GraniteDecoderLayer,
     GraniteForCausalLM,
     GraniteModel,
     GraniteRotaryEmbedding,
@@ -173,6 +174,80 @@ def forward(
         return attn_output, attn_weights


+class QEffGraniteDecoderLayer(GraniteDecoderLayer):
+    """
+    Copied from GraniteDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
+    The only differences are:
+    - adds a new arg `batch_index` for the CB models, although it is not supported yet.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        batch_index: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            batch_index=batch_index,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states * self.residual_multiplier
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * self.residual_multiplier  # main diff with Llama
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
 class QEffGraniteModel(GraniteModel):
     def forward(
         self,
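The `# main diff with Llama` comment marks Granite's scaled residual connections: both the attention and MLP branch outputs are multiplied by `self.residual_multiplier` before being added back, whereas Llama effectively uses a multiplier of 1.0. A minimal sketch of that update (shapes and the multiplier value here are illustrative; Granite reads `residual_multiplier` from its config):

import torch

# Illustrative values only: Granite takes residual_multiplier from GraniteConfig.
residual_multiplier = 0.22
residual = torch.randn(1, 4, 8)    # (batch, seq_len, embed_dim)
branch_out = torch.randn(1, 4, 8)  # output of self-attention or the MLP

# Granite scales the branch before the residual add (Llama: residual + branch_out)
hidden_states = residual + branch_out * residual_multiplier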

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,7 @@
 from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
 from transformers.models.granite.modeling_granite import (
     GraniteAttention,
+    GraniteDecoderLayer,
     GraniteForCausalLM,
     GraniteModel,
     GraniteRMSNorm,
@@ -268,6 +269,7 @@
 )
 from QEfficient.transformers.models.granite.modeling_granite import (
     QEffGraniteAttention,
+    QEffGraniteDecoderLayer,
     QEffGraniteForCausalLM,
     QEffGraniteModel,
 )
@@ -531,6 +533,7 @@ class KVCacheTransform(ModuleMappingTransform):
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
         GraniteAttention: QEffGraniteAttention,
+        GraniteDecoderLayer: QEffGraniteDecoderLayer,
         # GraniteMoe
         GraniteMoeModel: QEffGraniteMoeModel,
         GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
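With this entry in the `KVCacheTransform` mapping, every `GraniteDecoderLayer` instance in a loaded model is swapped to `QEffGraniteDecoderLayer` when the transform is applied. As a rough sketch of how a module-mapping transform of this kind can work (illustrative only; the helper name and details below are assumptions, not QEfficient's actual implementation):

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    """Swap each mapped module's class in place (hypothetical helper)."""
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            # Reassigning __class__ keeps all parameters and buffers intact;
            # only method resolution changes, so the patched forward() is used.
            module.__class__ = replacement
    return model

Because the swap only changes method resolution, no weights are re-initialized, which is what makes the subclass-and-override pattern (QEffGraniteDecoderLayer overriding forward) viable here.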
