|
| 1 | +# ----------------------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | +# |
| 6 | +# ----------------------------------------------------------------------------- |
| 7 | + |
| 8 | +from typing import Callable, List, Optional, Tuple, Union |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import nn |
| 12 | +from transformers.cache_utils import Cache |
| 13 | +from transformers.modeling_outputs import ( |
| 14 | + BaseModelOutputWithPast, |
| 15 | + CausalLMOutputWithPast, |
| 16 | +) |
| 17 | +from transformers.models.olmo2.modeling_olmo2 import ( |
| 18 | + Olmo2Attention, |
| 19 | + Olmo2Config, |
| 20 | + Olmo2DecoderLayer, |
| 21 | + Olmo2ForCausalLM, |
| 22 | + Olmo2Model, |
| 23 | + Olmo2RotaryEmbedding, |
| 24 | + repeat_kv, |
| 25 | + rotate_half, |
| 26 | +) |
| 27 | + |
| 28 | +from QEfficient.transformers.cache_utils import QEffDynamicCache |
| 29 | +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask |
| 30 | + |
| 31 | + |
| 32 | +class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding): |
| 33 | + """ |
| 34 | + Copied from Olmo2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py |
| 35 | + The only differences are: |
| 36 | + - Add static sin/cos computations. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__(self, config: Olmo2Config, device=None): |
| 40 | + super().__init__(config=config) |
| 41 | + |
| 42 | + self._set_cos_sin_cache( |
| 43 | + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() |
| 44 | + ) |
| 45 | + |
| 46 | + def _set_cos_sin_cache(self, seq_len, device, dtype): |
| 47 | + self.max_seq_len_cached = seq_len |
| 48 | + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) |
| 49 | + |
| 50 | + freqs = torch.outer(t, self.inv_freq) |
| 51 | + |
| 52 | + emb = torch.cat((freqs, freqs), dim=-1) |
| 53 | + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) |
| 54 | + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) |
| 55 | + |
| 56 | + def forward(self, x, seq_len=None): |
| 57 | + # x: [bs, num_attention_heads, seq_len, head_size] |
| 58 | + if seq_len > self.max_seq_len_cached: |
| 59 | + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
| 60 | + |
| 61 | + return ( |
| 62 | + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, |
| 63 | + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, |
| 64 | + ) |
| 65 | + |
| 66 | + |
| 67 | +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): |
| 68 | + """Applies Rotary Position Embedding to the query and key tensors. |
| 69 | +
|
| 70 | + Args: |
| 71 | + q (`torch.Tensor`): The query tensor. |
| 72 | + k (`torch.Tensor`): The key tensor. |
| 73 | + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| 74 | + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| 75 | + position_ids (`torch.Tensor`): |
| 76 | + The position indices of the tokens corresponding to the query and key tensors. For example, this can be |
| 77 | + used to pass offsetted position ids when working with a KV-cache. |
| 78 | + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| 79 | + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| 80 | + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| 81 | + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| 82 | + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| 83 | + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| 84 | + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| 85 | + Returns: |
| 86 | + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| 87 | + """ |
| 88 | + cos = cos[position_ids].unsqueeze(unsqueeze_dim) |
| 89 | + sin = sin[position_ids].unsqueeze(unsqueeze_dim) |
| 90 | + |
| 91 | + # Apply rotation |
| 92 | + q_embed = (q * cos) + (rotate_half(q) * sin) |
| 93 | + k_embed = (k * cos) + (rotate_half(k) * sin) |
| 94 | + # Cast back to original dtype |
| 95 | + return q_embed.to(q.dtype), k_embed.to(k.dtype) |
| 96 | + |
| 97 | + |
| 98 | +def eager_attention_forward( |
| 99 | + module: nn.Module, |
| 100 | + query: torch.Tensor, |
| 101 | + key: torch.Tensor, |
| 102 | + value: torch.Tensor, |
| 103 | + attention_mask: Optional[torch.Tensor], |
| 104 | + scaling: float, |
| 105 | + **kwargs, |
| 106 | +): |
| 107 | + key_states = repeat_kv(key, module.num_key_value_groups) |
| 108 | + value_states = repeat_kv(value, module.num_key_value_groups) |
| 109 | + |
| 110 | + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
| 111 | + if attention_mask is not None: |
| 112 | + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) |
| 113 | + |
| 114 | + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
| 115 | + attn_output = torch.matmul(attn_weights, value_states) |
| 116 | + attn_output = attn_output.transpose(1, 2).contiguous() |
| 117 | + |
| 118 | + return attn_output, attn_weights |
| 119 | + |
| 120 | + |
| 121 | +class QEffOlmo2Attention(Olmo2Attention): |
| 122 | + """Multi-headed attention from 'Attention Is All You Need' paper""" |
| 123 | + |
| 124 | + def __qeff_init__(self): |
| 125 | + self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) |
| 126 | + |
| 127 | + def forward( |
| 128 | + self, |
| 129 | + hidden_states: torch.Tensor, |
| 130 | + attention_mask: Optional[torch.Tensor], |
| 131 | + position_ids: Optional[torch.LongTensor] = None, |
| 132 | + past_key_value: Optional[Cache] = None, |
| 133 | + batch_index: Optional[torch.LongTensor] = None, |
| 134 | + **kwargs, |
| 135 | + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 136 | + input_shape = hidden_states.shape[:-1] |
| 137 | + hidden_shape = (*input_shape, -1, self.head_dim) |
| 138 | + |
| 139 | + query_states = self.q_norm(self.q_proj(hidden_states)) |
| 140 | + key_states = self.k_norm(self.k_proj(hidden_states)) |
| 141 | + value_states = self.v_proj(hidden_states) |
| 142 | + |
| 143 | + query_states = query_states.view(hidden_shape).transpose(1, 2) |
| 144 | + key_states = key_states.view(hidden_shape).transpose(1, 2) |
| 145 | + value_states = value_states.view(hidden_shape).transpose(1, 2) |
| 146 | + |
| 147 | + kv_seq_len = key_states.shape[-2] |
| 148 | + |
| 149 | + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
| 150 | + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 151 | + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
| 152 | + |
| 153 | + if past_key_value is not None: |
| 154 | + # sin and cos are specific to RoPE models; cache_position needed for the static cache |
| 155 | + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} |
| 156 | + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| 157 | + |
| 158 | + attention_interface: Callable = eager_attention_forward |
| 159 | + |
| 160 | + attn_output, attn_weights = attention_interface( |
| 161 | + self, |
| 162 | + query_states, |
| 163 | + key_states, |
| 164 | + value_states, |
| 165 | + attention_mask, |
| 166 | + scaling=self.scaling, |
| 167 | + **kwargs, |
| 168 | + ) |
| 169 | + |
| 170 | + attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| 171 | + attn_output = self.o_proj(attn_output) |
| 172 | + return attn_output, attn_weights, past_key_value |
| 173 | + |
| 174 | + |
| 175 | +class QEffOlmo2DecoderLayer(Olmo2DecoderLayer): |
| 176 | + """ |
| 177 | + Copied from Olmo2DecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py |
| 178 | + The only differences are: |
| 179 | + - add new args batch idx for the CB models |
| 180 | + """ |
| 181 | + |
| 182 | + def forward( |
| 183 | + self, |
| 184 | + hidden_states: torch.Tensor, |
| 185 | + attention_mask: Optional[torch.Tensor] = None, |
| 186 | + position_ids: Optional[torch.LongTensor] = None, |
| 187 | + past_key_value: Optional[Cache] = None, |
| 188 | + batch_index: Optional[torch.LongTensor] = None, |
| 189 | + output_attentions: Optional[bool] = False, |
| 190 | + use_cache: Optional[bool] = False, |
| 191 | + cache_position: Optional[torch.LongTensor] = None, |
| 192 | + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC |
| 193 | + **kwargs, |
| 194 | + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| 195 | + residual = hidden_states |
| 196 | + |
| 197 | + # Self Attention |
| 198 | + hidden_states, self_attn_weights, present_key_value = self.self_attn( |
| 199 | + hidden_states=hidden_states, |
| 200 | + attention_mask=attention_mask, |
| 201 | + position_ids=position_ids, |
| 202 | + past_key_value=past_key_value, |
| 203 | + batch_index=batch_index, |
| 204 | + output_attentions=output_attentions, |
| 205 | + use_cache=use_cache, |
| 206 | + cache_position=cache_position, |
| 207 | + position_embeddings=position_embeddings, |
| 208 | + **kwargs, |
| 209 | + ) |
| 210 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 211 | + hidden_states = residual + hidden_states |
| 212 | + |
| 213 | + # Fully Connected |
| 214 | + residual = hidden_states |
| 215 | + hidden_states = self.mlp(hidden_states) |
| 216 | + hidden_states = self.post_feedforward_layernorm(hidden_states) |
| 217 | + hidden_states = residual + hidden_states |
| 218 | + |
| 219 | + outputs = (hidden_states,) |
| 220 | + |
| 221 | + if output_attentions: |
| 222 | + outputs += (self_attn_weights,) |
| 223 | + if use_cache: |
| 224 | + outputs += (present_key_value,) |
| 225 | + |
| 226 | + return outputs |
| 227 | + |
| 228 | + |
| 229 | +class QEffOlmo2Model(Olmo2Model): |
| 230 | + """ |
| 231 | + Copied from Olmo2Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py |
| 232 | + The only differences are: |
| 233 | + - add new args cache idx for the kv retention |
| 234 | + """ |
| 235 | + |
| 236 | + def forward( |
| 237 | + self, |
| 238 | + input_ids: torch.LongTensor = None, |
| 239 | + attention_mask: Optional[torch.Tensor] = None, |
| 240 | + position_ids: Optional[torch.LongTensor] = None, |
| 241 | + past_key_values: Optional[Cache] = None, |
| 242 | + batch_index: Optional[torch.LongTensor] = None, |
| 243 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 244 | + use_cache: Optional[bool] = None, |
| 245 | + output_attentions: Optional[bool] = None, |
| 246 | + output_hidden_states: Optional[bool] = None, |
| 247 | + return_dict: Optional[bool] = None, |
| 248 | + cache_position: Optional[torch.LongTensor] = None, |
| 249 | + **kwargs, |
| 250 | + ) -> Union[Tuple, BaseModelOutputWithPast]: |
| 251 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 252 | + output_hidden_states = ( |
| 253 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 254 | + ) |
| 255 | + use_cache = use_cache if use_cache is not None else self.config.use_cache |
| 256 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 257 | + |
| 258 | + if (input_ids is None) ^ (inputs_embeds is not None): |
| 259 | + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| 260 | + |
| 261 | + if inputs_embeds is None: |
| 262 | + inputs_embeds = self.embed_tokens(input_ids) |
| 263 | + |
| 264 | + return_legacy_cache = False |
| 265 | + if use_cache and not isinstance(past_key_values, Cache): |
| 266 | + return_legacy_cache = True |
| 267 | + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) |
| 268 | + |
| 269 | + if cache_position is None: |
| 270 | + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| 271 | + cache_position = torch.arange( |
| 272 | + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| 273 | + ) |
| 274 | + if position_ids is None: |
| 275 | + position_ids = cache_position.unsqueeze(0) |
| 276 | + |
| 277 | + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) |
| 278 | + |
| 279 | + # embed positions |
| 280 | + hidden_states = inputs_embeds |
| 281 | + |
| 282 | + # decoder layers |
| 283 | + all_hidden_states = () if output_hidden_states else None |
| 284 | + all_self_attns = () if output_attentions else None |
| 285 | + |
| 286 | + for decoder_layer in self.layers[: self.config.num_hidden_layers]: |
| 287 | + if output_hidden_states: |
| 288 | + all_hidden_states += (hidden_states,) |
| 289 | + |
| 290 | + layer_outputs = decoder_layer( |
| 291 | + hidden_states, |
| 292 | + attention_mask=causal_mask, |
| 293 | + position_ids=position_ids, |
| 294 | + past_key_value=past_key_values, |
| 295 | + batch_index=batch_index, |
| 296 | + output_attentions=output_attentions, |
| 297 | + use_cache=use_cache, |
| 298 | + cache_position=cache_position, |
| 299 | + **kwargs, |
| 300 | + ) |
| 301 | + |
| 302 | + hidden_states = layer_outputs[0] |
| 303 | + |
| 304 | + if output_attentions: |
| 305 | + all_self_attns += (layer_outputs[1],) |
| 306 | + |
| 307 | + hidden_states = self.norm(hidden_states) |
| 308 | + |
| 309 | + # add hidden states from the last decoder layer |
| 310 | + if output_hidden_states: |
| 311 | + all_hidden_states += (hidden_states,) |
| 312 | + |
| 313 | + if return_legacy_cache: |
| 314 | + past_key_values = past_key_values.to_legacy_cache() |
| 315 | + |
| 316 | + output = BaseModelOutputWithPast( |
| 317 | + last_hidden_state=hidden_states, |
| 318 | + past_key_values=past_key_values if use_cache else None, |
| 319 | + hidden_states=all_hidden_states, |
| 320 | + attentions=all_self_attns, |
| 321 | + ) |
| 322 | + return output if return_dict else output.to_tuple() |
| 323 | + |
| 324 | + |
| 325 | +class QEffOlmo2ForCausalLM(Olmo2ForCausalLM): |
| 326 | + """ |
| 327 | + Copied from Olmo2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py |
| 328 | + The only differences are: |
| 329 | + - add new args cache idx for the kv retention |
| 330 | + """ |
| 331 | + |
| 332 | + def forward( |
| 333 | + self, |
| 334 | + input_ids: torch.LongTensor = None, |
| 335 | + attention_mask: Optional[torch.Tensor] = None, |
| 336 | + position_ids: Optional[torch.LongTensor] = None, |
| 337 | + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
| 338 | + batch_index: Optional[torch.LongTensor] = None, |
| 339 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 340 | + use_cache: Optional[bool] = None, |
| 341 | + output_attentions: Optional[bool] = None, |
| 342 | + output_hidden_states: Optional[bool] = None, |
| 343 | + return_dict: Optional[bool] = None, |
| 344 | + cache_position: Optional[torch.LongTensor] = None, |
| 345 | + **kwargs, |
| 346 | + ) -> Union[Tuple, CausalLMOutputWithPast]: |
| 347 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 348 | + output_hidden_states = ( |
| 349 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 350 | + ) |
| 351 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 352 | + |
| 353 | + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
| 354 | + outputs = self.model( |
| 355 | + input_ids=input_ids, |
| 356 | + attention_mask=attention_mask, |
| 357 | + position_ids=position_ids, |
| 358 | + past_key_values=past_key_values, |
| 359 | + batch_index=batch_index, |
| 360 | + inputs_embeds=inputs_embeds, |
| 361 | + use_cache=use_cache, |
| 362 | + output_attentions=output_attentions, |
| 363 | + output_hidden_states=output_hidden_states, |
| 364 | + return_dict=return_dict, |
| 365 | + cache_position=cache_position, |
| 366 | + **kwargs, |
| 367 | + ) |
| 368 | + |
| 369 | + # Cast to INT32 to avoid issue while running in ONNXRT |
| 370 | + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) |
| 371 | + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] |
| 372 | + |
| 373 | + logits = self.lm_head(hidden_states) |
| 374 | + logits = logits.float() |
| 375 | + |
| 376 | + return CausalLMOutputWithPast( |
| 377 | + loss=None, |
| 378 | + logits=logits, |
| 379 | + past_key_values=outputs.past_key_values, |
| 380 | + hidden_states=outputs.hidden_states, |
| 381 | + attentions=outputs.attentions, |
| 382 | + ) |
0 commit comments