17 | 17 | from transformers.models.granite.modeling_granite import ( |
18 | 18 | GraniteAttention, |
19 | 19 | GraniteConfig, |
| 20 | + GraniteDecoderLayer, |
20 | 21 | GraniteForCausalLM, |
21 | 22 | GraniteModel, |
22 | 23 | GraniteRotaryEmbedding, |
@@ -173,6 +174,80 @@ def forward( |
173 | 174 | return attn_output, attn_weights |
174 | 175 |
175 | 176 |
| 177 | +class QEffGraniteDecoderLayer(GraniteDecoderLayer): |
| 178 | + """ |
| 179 | + Copied from GraniteDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py |
| 180 | + The only differences are: |
| 181 | + - Adds a new `batch_index` argument for continuous-batching (CB) models, although CB is not supported yet. |
| 182 | + """ |
| 183 | + |
| 184 | + def forward( |
| 185 | + self, |
| 186 | + hidden_states: torch.Tensor, |
| 187 | + attention_mask: Optional[torch.Tensor] = None, |
| 188 | + position_ids: Optional[torch.LongTensor] = None, |
| 189 | + past_key_values: Optional[Cache] = None, |
| 190 | + output_attentions: Optional[bool] = False, |
| 191 | + batch_index: Optional[torch.LongTensor] = None, |
| 192 | + use_cache: Optional[bool] = False, |
| 193 | + cache_position: Optional[torch.LongTensor] = None, |
| 194 | + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| 195 | + **kwargs, |
| 196 | + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| 197 | + """ |
| 198 | + Args: |
| 199 | + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
| 200 | + attention_mask (`torch.FloatTensor`, *optional*): |
| 201 | + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, |
| 202 | + query_sequence_length, key_sequence_length)` if default attention is used. |
| 203 | + output_attentions (`bool`, *optional*): |
| 204 | + Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| 205 | + returned tensors for more detail. |
| 206 | + use_cache (`bool`, *optional*): |
| 207 | + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
| 208 | + (see `past_key_values`). |
| 209 | + past_key_values (`Cache`, *optional*): cached past key and value projection states |
| 210 | + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| 211 | + Indices depicting the position of the input sequence tokens in the sequence |
| 212 | + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): |
| 213 | + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, |
| 214 | + with `head_dim` being the embedding dimension of each attention head. |
| 215 | + kwargs (`dict`, *optional*): |
| 216 | + Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code |
| 217 | + into the model |
| 218 | + """ |
| 219 | + residual = hidden_states |
| 220 | + |
| 221 | + hidden_states = self.input_layernorm(hidden_states) |
| 222 | + # Self Attention |
| 223 | + hidden_states, self_attn_weights = self.self_attn( |
| 224 | + hidden_states=hidden_states, |
| 225 | + attention_mask=attention_mask, |
| 226 | + position_ids=position_ids, |
| 227 | + past_key_values=past_key_values, |
| 228 | + output_attentions=output_attentions, |
| 229 | + batch_index=batch_index, |
| 230 | + use_cache=use_cache, |
| 231 | + cache_position=cache_position, |
| 232 | + position_embeddings=position_embeddings, |
| 233 | + **kwargs, |
| 234 | + ) |
| 235 | + hidden_states = residual + hidden_states * self.residual_multiplier |
| 236 | + |
| 237 | + # Fully Connected |
| 238 | + residual = hidden_states |
| 239 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 240 | + hidden_states = self.mlp(hidden_states) |
| 241 | + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama |
| 242 | + |
| 243 | + outputs = (hidden_states,) |
| 244 | + |
| 245 | + if output_attentions: |
| 246 | + outputs += (self_attn_weights,) |
| 247 | + |
| 248 | + return outputs |
| 249 | + |
| 250 | + |
176 | 251 | class QEffGraniteModel(GraniteModel): |
177 | 252 | def forward( |
178 | 253 | self, |
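For context on the `# main diff with Llama` comment in the added layer: Granite scales every sublayer output by a fixed `residual_multiplier` from the config before the residual add, whereas Llama adds the sublayer output directly. Below is a minimal standalone sketch of that pattern; the module choices, sizes, and multiplier value are illustrative assumptions, not the real `GraniteDecoderLayer`.

```python
import torch
import torch.nn as nn

class TinyGraniteStyleLayer(nn.Module):
    """Standalone sketch of Granite's scaled residual connection:
    `residual + residual_multiplier * sublayer(x)`. Llama's equivalent
    is a plain `residual + sublayer(x)`. Sizes, submodules, and the
    multiplier value here are illustrative only."""

    def __init__(self, hidden: int = 64, residual_multiplier: float = 0.22):
        super().__init__()
        # In the real model this scalar comes from GraniteConfig.
        self.residual_multiplier = residual_multiplier
        self.input_layernorm = nn.LayerNorm(hidden)
        self.post_attention_layernorm = nn.LayerNorm(hidden)
        self.self_attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden, 4 * hidden), nn.SiLU(), nn.Linear(4 * hidden, hidden)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        x = self.input_layernorm(hidden_states)
        x, _ = self.self_attn(x, x, x, need_weights=False)
        hidden_states = residual + x * self.residual_multiplier  # scaled residual add

        residual = hidden_states
        x = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + x * self.residual_multiplier  # Llama would use `residual + x`

layer = TinyGraniteStyleLayer()
out = layer(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])
```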
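The `batch_index` argument that `QEffGraniteDecoderLayer.forward` threads into `self_attn` is a placeholder for continuous batching (CB), which the docstring notes is not yet supported for this model. As a hedged illustration of the general idea only (none of the names or shapes below come from this repo's cache API), CB-style caches are typically updated by scattering each active request's new key/value entries into its own preallocated slot along the batch dimension:

```python
import torch

# All shapes are illustrative: 4 preallocated cache slots, 16 positions, head_dim 8.
kv_cache = torch.zeros(4, 16, 8)

# Two active requests occupy cache slots 3 and 1 and are both decoding position 5.
batch_index = torch.tensor([3, 1])
position = 5
new_kv = torch.randn(2, 8)  # one new key/value row per active request

# Scatter each request's new entry into its own slot along the batch dimension.
kv_cache[batch_index, position] = new_kv
```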