diff --git a/QEfficient/transformers/models/olmo2/__init__.py b/QEfficient/transformers/models/olmo2/__init__.py new file mode 100644 index 000000000..e467e4d4c --- /dev/null +++ b/QEfficient/transformers/models/olmo2/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py new file mode 100644 index 000000000..f5a0658e2 --- /dev/null +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -0,0 +1,382 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.olmo2.modeling_olmo2 import ( + Olmo2Attention, + Olmo2Config, + Olmo2DecoderLayer, + Olmo2ForCausalLM, + Olmo2Model, + Olmo2RotaryEmbedding, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding): + """ + Copied from Olmo2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: Olmo2Config, device=None): + super().__init__(config=config) + + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffOlmo2Attention(Olmo2Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffOlmo2DecoderLayer(Olmo2DecoderLayer): + """ + Copied from Olmo2DecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py + The only differences are: + - add new args batch idx for the CB models + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffOlmo2Model(Olmo2Model): + """ + Copied from Olmo2Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffOlmo2ForCausalLM(Olmo2ForCausalLM): + """ + Copied from Olmo2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..7aae89e0d 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -121,6 +121,13 @@ MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel +from transformers.models.olmo2.modeling_olmo2 import ( + Olmo2Attention, + Olmo2DecoderLayer, + Olmo2ForCausalLM, + Olmo2Model, + Olmo2RMSNorm, +) from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import ( Phi3Attention, @@ -285,6 +292,12 @@ QEffMptForCausalLM, QEFfMptModel, ) +from QEfficient.transformers.models.olmo2.modeling_olmo2 import ( + QEffOlmo2Attention, + QEffOlmo2DecoderLayer, + QEffOlmo2ForCausalLM, + QEffOlmo2Model, +) from QEfficient.transformers.models.phi.modeling_phi import ( QEffPhiAttention, QEffPhiDecoderLayer, @@ -339,6 +352,7 @@ class CustomOpsTransform(ModuleMappingTransform): GraniteRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, + Olmo2RMSNorm: CustomRMSNormAIC, } @@ -462,6 +476,11 @@ class KVCacheTransform(ModuleMappingTransform): GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, + # Olmo2 + Olmo2Attention: QEffOlmo2Attention, + Olmo2DecoderLayer: QEffOlmo2DecoderLayer, + Olmo2Model: QEffOlmo2Model, + Olmo2ForCausalLM: QEffOlmo2ForCausalLM, # Whisper encoder and decoder layers WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, WhisperAttention: QEffWhisperAttention, diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 3195c4828..428ab2607 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -47,6 +47,7 @@ "ibm-granite/granite-3.1-2b-instruct", "ibm-granite/granite-guardian-3.1-2b", "hpcai-tech/grok-1", + "allenai/OLMo-2-0425-1B", ] test_models_qnn = [ diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index d0250d899..6447acf99 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -30,6 +30,7 @@ ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [