diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..063a8a64a 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -6,10 +6,11 @@ # ----------------------------------------------------------------------------- +from collections.abc import Iterable from typing import Any, Dict, List, Optional, Tuple import torch -from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache +from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache from QEfficient.customop import ( CtxGatherFunc, @@ -23,6 +24,59 @@ ) +class QEffDynamicLayer(DynamicLayer): + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + + # Scatter + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + + k_out, v_out = self.keys, self.values + + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out + + class QEffDynamicCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -36,6 +90,16 @@ class QEffDynamicCache(DynamicCache): """ + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # Remove layer_classes if present to avoid duplicate argument + kwargs.pop("layer_classes", None) + from transformers.cache_utils import Cache # Import here to avoid circular import + + Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs) + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 
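# Hedged sketch (editor's illustration, not part of the patch): what the new
# QEffDynamicLayer.update does conceptually -- scatter the step's K/V into a
# pre-allocated cache at `position_ids`, then read the whole context back with
# positions beyond the current maximum zeroed out. Plain torch indexing stands in
# for the ONNX-exportable CtxScatterFunc/CtxGatherFunc custom ops; shapes are toy values.
import torch

def toy_layer_update(cache_k, cache_v, key_states, value_states, position_ids):
    # cache_*: [batch, heads, ctx_len, head_dim]; key/value_states: [batch, heads, seq, head_dim]
    batch = torch.arange(cache_k.shape[0])[:, None]                       # [batch, 1]
    cache_k[batch, :, position_ids] = key_states.transpose(1, 2)          # scatter K at position_ids
    cache_v[batch, :, position_ids] = value_states.transpose(1, 2)        # scatter V at position_ids

    ctx_indices = torch.arange(cache_k.shape[2])[None, None, :]           # [1, 1, ctx_len]
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # [batch, 1, 1]
    invalid_mask = ctx_indices > gather_limit                             # [batch, 1, ctx_len]
    v_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(cache_v), cache_v)
    return cache_k, v_out

# Decode step writing position 3 into a toy cache with context length 8.
B, H, CTX, D = 1, 2, 8, 4
k_out, v_out = toy_layer_update(
    torch.zeros(B, H, CTX, D), torch.zeros(B, H, CTX, D),
    torch.randn(B, H, 1, D), torch.randn(B, H, 1, D),
    torch.tensor([[3]]),
)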
@@ -113,80 +177,6 @@ def read_only(self, layer_idx, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) - else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], position_ids, value_states - ) - - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Gather - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - - return k_out, v_out - def update3D( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index e75181424..aebb2c446 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -47,7 +47,7 @@ def _attn( attn_weights = torch.matmul(query, key.transpose(-1, -2)) - attn_weights = attn_weights / self.scale_attn + # attn_weights = attn_weights / self.scale_attn # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
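# Hedged usage sketch (editor's illustration, not part of the patch): the per-layer-index
# `update` deleted above now lives in QEffDynamicLayer, so attention layers call the shared
# cache's `update(key, value, layer_idx, cache_kwargs)` and let it dispatch to the layer,
# instead of wrapping legacy (key, value) tuples in a throwaway cache on every call.
# Shapes and the single pre-allocated layer below are hypothetical.
import torch
from QEfficient.transformers.cache_utils import QEffDynamicCache

B, H, CTX, D = 1, 16, 128, 64
past_key_values = QEffDynamicCache(
    ddp_cache_data=[(torch.zeros(B, H, CTX, D), torch.zeros(B, H, CTX, D))]  # layer 0
)
cache_kwargs = {"position_ids": torch.tensor([[5]]), "batch_index": None}
key = torch.randn(B, H, 1, D)
value = torch.randn(B, H, 1, D)
k, v = past_key_values.update(key, value, 0, cache_kwargs)  # handled by QEffDynamicLayer.update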
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` @@ -57,6 +57,7 @@ def _attn( # Apply the attention mask attn_weights = torch.where(attention_mask, mask_value, attn_weights) + attn_weights = attn_weights / self.scale_attn attn_weights = nn.Softmax(dim=-1)(attn_weights) attn_weights = attn_weights.to(value.dtype) attn_weights = self.attn_dropout(attn_weights) @@ -124,23 +125,14 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - # Update the cache_kwargs with position_ids for Cloud AI 100 - past_key_value = layer_past cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_dim, + "cache_position": cache_position, "position_ids": position_ids, - "batch_index": batch_index, } - pkv = QEffDynamicCache() - pkv.key_cache.append(past_key_value[0]) - pkv.value_cache.append(past_key_value[1]) - key, value = pkv.update(key, value, 0, cache_kwargs) - - if use_cache is True: - # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32. - # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38 - present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0]) - else: - present = None + key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) @@ -148,12 +140,7 @@ def forward( attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffCodeGenModel(CodeGenModel): @@ -167,7 +154,7 @@ class QEffCodeGenModel(CodeGenModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -179,7 +166,14 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. 
+ """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -200,20 +194,34 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if use_cache and not isinstance(past_key_values, Cache): + if past_key_values is None: + past_key_values = QEffDynamicCache() + else: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + """# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + """ + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length = inputs_embeds.shape[1] + if cache_position is None: + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) + + # causal_mask = self._update_causal_mask( + # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + # ) # Attention mask. 
if attention_mask is not None: @@ -237,7 +245,7 @@ def forward( elif attention_mask is None: # 4d mask is passed through the layers - attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length) + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -245,32 +253,25 @@ def forward( # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) + output_shape = (-1, seq_length, hidden_states.size(-1)) - output_shape = input_shape + (hidden_states.size(-1),) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( - hidden_states=hidden_states, - layer_past=layer_past, + hidden_states, + layer_past=past_key_values, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -281,11 +282,9 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.ln_f(hidden_states) @@ -295,11 +294,13 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -389,7 +390,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -400,15 +401,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + hidden_states = attn_outputs + feed_forward_hidden_states + residual - return outputs # hidden_states, present, (attentions) + return hidden_states, attn_weights diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py 
b/QEfficient/transformers/models/falcon/modeling_falcon.py index 79e3ebc01..26638bc5d 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -135,8 +135,7 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = key_layer.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 0cefbcfee..c9e78f1de 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -149,9 +149,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 173da1798..be3ba942d 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -155,9 +155,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index a2b84c139..d68a65430 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -9,13 +9,14 @@ import torch from torch import nn +from transformers import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model -from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffEncoderDecoderCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -63,18 +64,29 @@ class QEffGPT2Attention(GPT2Attention): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: 
Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + if past_key_value is not None: + if isinstance(past_key_value, QEffEncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " @@ -82,31 +94,39 @@ def forward( ) query_states = self.q_attn(hidden_states) - key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask + + # Try to get key/value states from cache if possible + if past_key_value is not None and is_updated: + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) + else: query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) shape_q = (*query_states.shape[:-1], -1, self.head_dim) - shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query_states = query_states.view(shape_q).transpose(1, 2) - key_states = key_states.view(shape_kv).transpose(1, 2) - value_states = value_states.view(shape_kv).transpose(1, 2) - if layer_past is not None: - # Added for optimized GPT Attention for AI 100 KV Retention + if (past_key_value is not None and not is_cross_attention) or ( + past_key_value is not None and is_cross_attention and not is_updated + ): + # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - pkv = QEffDynamicCache() - pkv.key_cache.append(layer_past[0]) - pkv.value_cache.append(layer_past[1]) - key_states, value_states = pkv.update(key_states, value_states, 0, cache_kwargs) - - if use_cache is True: - present = (pkv.key_cache[0], pkv.value_cache[0]) - else: - present = None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = 
eager_attention_forward attn_output, attn_weights = attention_interface( @@ -122,11 +142,7 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffGPT2Block(GPT2Block): @@ -139,7 +155,7 @@ class QEffGPT2Block(GPT2Block): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -154,9 +170,9 @@ def forward( ]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_output, self_attn_weights = self.attn( hidden_states, - layer_past=layer_past, + past_key_value=past_key_value, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -164,8 +180,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -180,18 +194,17 @@ def forward( hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( + cross_attn_outputs, cross_attn_weights = self.crossattention( hidden_states, + past_key_value=past_key_value, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) - attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -199,10 +212,11 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if encoder_hidden_states is not None: + outputs += (cross_attn_weights,) return outputs # hidden_states, present, (attentions, cross_attentions) @@ -256,14 +270,23 @@ def forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + return_legacy_cache = False + if use_cache: + if past_key_values is None: + return_legacy_cache = True + past_key_values = QEffDynamicCache() + elif not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if self.config.add_cross_attention and not isinstance(past_key_values, QEffEncoderDecoderCache): + past_key_values = QEffEncoderDecoderCache(past_key_values, QEffDynamicCache()) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = torch.arange( + past_seen_tokens, 
past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ).unsqueeze(0) # Attention mask. if attention_mask is not None: @@ -271,9 +294,10 @@ def forward( attention_mask = attention_mask[:, None, None, :] attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + elif attention_mask is None: # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_length, None) + attention_mask = _create_causal_mask(position_ids, past_seen_tokens, None) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -306,19 +330,17 @@ def forward( output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i in range(len(self.h)): - block, layer_past = self.h[i], past_key_values[i] + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - layer_past=layer_past, + past_key_value=past_key_values, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -329,8 +351,6 @@ def forward( output_attentions=output_attentions, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -344,9 +364,17 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_values = past_key_values if use_cache else None + if return_legacy_cache: + past_key_values = ( + past_key_values.self_attention_cache.to_legacy_cache() + if self.config.add_cross_attention + else past_key_values.to_legacy_cache() + ) + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 13b308547..2a2d47d6d 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -140,10 +140,8 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8f840b4b4..c085f6a5e 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -9,6 +9,7 @@ import torch 
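# Hedged sketch (editor's illustration, not part of the patch): the GPT2/MPT forwards above
# now normalise whatever the caller hands in to a Cache object on entry and convert back to
# legacy tuples on exit only when the caller used them. Toy shapes below.
import torch
from QEfficient.transformers.cache_utils import QEffDynamicCache

legacy = ((torch.zeros(1, 12, 8, 64), torch.zeros(1, 12, 8, 64)),)  # one layer, toy shapes
cache = QEffDynamicCache.from_legacy_cache(legacy)                  # entry: wrap legacy tuples
past_seen = cache.get_seq_length()                                  # 8 cached positions
legacy_again = cache.to_legacy_cache()                              # exit: only if return_legacy_cache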
import torch.nn.functional as F +from torch import nn from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -127,15 +128,17 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + bsz, q_len, _ = hidden_states.size() - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -149,23 +152,46 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface = eager_attention_forward - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - dropout = 0.0 if not self.training else self.attention_dropout - attn_weights = F.dropout(attn_weights, p=dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + # upcast 
attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights class QEffGraniteMoeModel(GraniteMoeModel): @@ -212,9 +238,14 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama - return_legacy_cache = False + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + # if not isinstance(past_key_values, (type(None), Cache)): + # raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + # if use_cache and past_key_values is None: + # past_key_values = QEffDynamicCache() + if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True if past_key_values is None: past_key_values = QEffDynamicCache() else: @@ -230,39 +261,26 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions ) - hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = None + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - elif batch_index is not None: + if batch_index is not None: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -287,9 +305,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -298,15 +313,15 @@ def forward( # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -486,6 +501,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -537,6 +553,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -544,7 +561,8 @@ def forward( 
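# Hedged toy (editor's illustration, not part of the patch): `logits_to_keep` follows the
# upstream transformers convention -- an int keeps only the last N positions before the LM
# head (0 keeps all), while a tensor selects explicit indices. The slicing used below:
import torch

hidden_states = torch.randn(2, 7, 16)  # [batch, seq, hidden], toy shapes
for logits_to_keep in (0, 1):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept = hidden_states[:, slice_indices, :]
    # logits_to_keep=0 -> kept.shape[1] == 7 (all positions); logits_to_keep=1 -> 1 (last only)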
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling loss = None diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 21516ff5f..5ff6352fa 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -88,7 +88,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) + kv_seq_len = past_key_value.get_seq_length(layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index a285f00dc..ce9537ff3 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -150,9 +150,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index f5e60c5de..3cda9bac7 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -106,7 +106,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) @@ -359,7 +359,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) + kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx, cache_position) cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 60b1c929d..0288ea270 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,7 +24,6 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - logger, repeat_kv, rotate_half, ) @@ -32,6 +31,7 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger class QEffMistralRotaryEmbedding(MistralRotaryEmbedding): @@ -159,9 +159,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index ef51c3421..292a4c42d 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -156,7 +156,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 89d474e15..9bf6a4422 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -12,6 +12,7 @@ import torch import torch.utils.checkpoint from torch import nn +from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -50,20 +51,12 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - if len(past_key_value) != 0: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - pkv = QEffDynamicCache() - pkv.key_cache.append(past_key_value[0]) - pkv.value_cache.append(past_key_value[1]) - key_states, value_states = pkv.update(key_states, value_states, 0, cache_kwargs) - if use_cache: - past_key_value = (pkv.key_cache[0], pkv.value_cache[0]) - else: - past_key_value = None + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale - query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value.get_seq_length() if position_bias is not None: if len(position_bias.shape) != 3: @@ -137,15 +130,7 @@ def forward( # MLP. 
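# Hedged sketch (editor's illustration, not part of the patch): across these models the patch
# swaps `past_key_value.get_usable_length(kv_seq_len, layer_idx)` for
# `past_key_value.get_seq_length(layer_idx)`, which reports how many positions that layer's
# cache already holds; the rotary embedding is then sized from it. Toy shapes below.
import torch
from QEfficient.transformers.cache_utils import QEffDynamicCache

B, H, CTX, D = 1, 8, 32, 64
cache = QEffDynamicCache(ddp_cache_data=[(torch.zeros(B, H, CTX, D), torch.zeros(B, H, CTX, D))])
kv_seq_len = cache.get_seq_length(0)  # 32 here: the full pre-allocated context length
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  # as in the attention blocks above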
output = self.ffn(layernorm_output, residual) - outputs = (output,) - - if use_cache: - outputs += (past_key_value,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs # hidden_states, present, attentions + return output, attn_weights class QEFfMptModel(MptModel): @@ -190,18 +175,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + hidden_states = inputs_embeds - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device) @@ -213,13 +198,13 @@ def forward( elif attention_mask is None: causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length) - for block, layer_past in zip(self.blocks, past_key_values): + for block in self.blocks: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -228,24 +213,27 @@ def forward( position_bias=alibi, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Add last hidden state hidden_states = self.norm_f(hidden_states) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 602a73c84..79fd69394 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -155,8 +155,8 @@ def forward( query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = 
qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 00a3989d8..d8d09361d 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -163,9 +163,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index e3db4b490..b51c6eaf7 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -82,8 +82,6 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index b2a35c005..f7e5f3981 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -44,7 +44,19 @@ def login_and_download_hf_lm(model_name, *args, **kwargs): model_path = hf_download( repo_id=model_name, cache_dir=cache_dir, - ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.msgpack", "*.h5", "*.pth"], + ignore_patterns=[ + "*.txt", + "*.onnx", + "*.ot", + "*.md", + "*.tflite", + "*.pdf", + "*.msgpack", + "*.h5", + "*.pth", + "*.pt", + "*.bin", + ], ) return model_path diff --git a/pyproject.toml b/pyproject.toml index 479736c22..e3c1e803d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,8 @@ classifiers = [ ] requires-python = ">=3.8,<3.11" dependencies = [ - "transformers==4.51.3", - "huggingface-hub==0.30.0", + "transformers==4.55.0", + "huggingface-hub==0.34.0", "hf_transfer==0.1.9", "peft==0.13.2", "datasets==2.20.0",
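# Hedged sketch (editor's illustration, not part of the patch): the expanded ignore_patterns
# above also skips *.pt and *.bin weight files, presumably so only safetensors checkpoints are
# fetched. The repo's hf_download helper takes this list; the equivalent call directly against
# huggingface_hub is shown below with an example repo id.
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="gpt2",  # example repo id
    ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf",
                     "*.msgpack", "*.h5", "*.pth", "*.pt", "*.bin"],
)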