Commit c17be77

Added Decoder layer class in Qeff for granite (#628)
This PR introduces the decoder layer class in QEff for Granite, which is required for subfunction support. While there are alternative ways to achieve this, implementing it now ensures future compatibility: if we later add CB (continuous batching) support for Granite, the QEff Granite decoder layer will be necessary. Signed-off-by: abhishek-singh591 <[email protected]>
1 parent 999f465 commit c17be77

File tree

2 files changed (+78 lines, −0 lines)

QEfficient/transformers/models/granite/modeling_granite.py

Lines changed: 75 additions & 0 deletions
@@ -17,6 +17,7 @@
 from transformers.models.granite.modeling_granite import (
     GraniteAttention,
     GraniteConfig,
+    GraniteDecoderLayer,
     GraniteForCausalLM,
     GraniteModel,
     GraniteRotaryEmbedding,
@@ -173,6 +174,80 @@ def forward(
         return attn_output, attn_weights
 
 
+class QEffGraniteDecoderLayer(GraniteDecoderLayer):
+    """
+    Copied from GraniteDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
+    The only difference is:
+    - adds a new `batch_index` arg for CB models, although CB is not supported yet.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        batch_index: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            batch_index=batch_index,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states * self.residual_multiplier
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * self.residual_multiplier  # main diff with Llama
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
 class QEffGraniteModel(GraniteModel):
     def forward(
         self,
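Since the two `residual_multiplier` lines above carry the whole behavioral difference from Llama, here is a minimal, self-contained sketch of that arithmetic in plain torch. Everything in it (`granite_block`, the `Identity` stand-ins) is a hypothetical illustration, not QEfficient or transformers code:

import torch
import torch.nn as nn

def granite_block(hidden, attn, mlp, norm1, norm2, residual_multiplier):
    # Attention branch: Llama adds the branch output directly; Granite
    # scales it by residual_multiplier first.
    hidden = hidden + attn(norm1(hidden)) * residual_multiplier
    # MLP branch: the same scaling applies (the "main diff with Llama").
    hidden = hidden + mlp(norm2(hidden)) * residual_multiplier
    return hidden

# Toy usage with identity stand-ins to make the scaling visible.
x = torch.ones(1, 2, 4)
out = granite_block(x, nn.Identity(), nn.Identity(),
                    nn.Identity(), nn.Identity(),
                    residual_multiplier=0.5)
print(out[0, 0, 0].item())  # 1 + 1*0.5 = 1.5, then 1.5 + 1.5*0.5 = 2.25

With `residual_multiplier = 1.0` this reduces to the standard pre-norm residual pattern used by Llama, which is why the rest of the layer can be copied essentially verbatim.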

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,7 @@
 from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
 from transformers.models.granite.modeling_granite import (
     GraniteAttention,
+    GraniteDecoderLayer,
     GraniteForCausalLM,
     GraniteModel,
     GraniteRMSNorm,
@@ -268,6 +269,7 @@
 )
 from QEfficient.transformers.models.granite.modeling_granite import (
     QEffGraniteAttention,
+    QEffGraniteDecoderLayer,
     QEffGraniteForCausalLM,
     QEffGraniteModel,
 )
@@ -531,6 +533,7 @@ class KVCacheTransform(ModuleMappingTransform):
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
         GraniteAttention: QEffGraniteAttention,
+        GraniteDecoderLayer: QEffGraniteDecoderLayer,
         # GraniteMoe
         GraniteMoeModel: QEffGraniteMoeModel,
         GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
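For context on how this mapping takes effect: `KVCacheTransform` extends `ModuleMappingTransform`, which replaces mapped module classes across the model. Below is a hedged sketch of the general class-swap mechanism a mapping-based transform of this shape can use; `apply_module_mapping` is a hypothetical helper for illustration, not the actual QEfficient API:

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    # Walk every sub-module; where its class appears in the mapping
    # (e.g. {GraniteDecoderLayer: QEffGraniteDecoderLayer}), re-point
    # the instance to the QEff subclass. Parameters and buffers are
    # untouched, so the pretrained weights keep working.
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
    return model

Because only the class is re-pointed, the loaded weights stay in place while the QEff `forward` overrides, such as the `batch_index` argument added in this commit, take effect.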
