Transformers version 4.55 upgrade #542


Draft: wants to merge 8 commits into main

140 changes: 65 additions & 75 deletions QEfficient/transformers/cache_utils.py
@@ -6,10 +6,11 @@
# -----------------------------------------------------------------------------


from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache
from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache

from QEfficient.customop import (
CtxGatherFunc,
@@ -23,6 +24,59 @@
)


class QEffDynamicLayer(DynamicLayer):
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Update the cache
if self.keys is None:
self.keys = key_states
self.values = value_states
k_out, v_out = self.keys, self.values
else:
position_ids = cache_kwargs.get("position_ids")
            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs

# Scatter
if batch_index is not None:
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)

self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states)

self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states)
else:
self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)

k_out, v_out = self.keys, self.values

# Gather
ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out
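
# Editorial sketch, not part of this PR's diff: a hypothetical helper illustrating how the
# layer-level update path above behaves in eager mode. It assumes QEffDynamicLayer() can be
# constructed with no arguments (as transformers 4.55's DynamicLayer allows) and uses
# illustrative shapes and positions only; the custom Ctx* ops are exercised exactly as in
# QEffDynamicLayer.update.
def _qeff_dynamic_layer_update_sketch() -> tuple[torch.Tensor, torch.Tensor]:
    ctx_len, n_heads, head_dim = 16, 8, 64
    layer = QEffDynamicLayer()

    # Prefill: the first update stores the fixed-shape KV buffers as-is.
    layer.update(
        torch.zeros(1, n_heads, ctx_len, head_dim),
        torch.zeros(1, n_heads, ctx_len, head_dim),
    )

    # Decode: scatter one new token at position 8, then gather back the full-length buffers.
    pos = torch.tensor([[8]])  # position_ids, shape (batch, seq)
    k_out, v_out = layer.update(
        torch.rand(1, n_heads, 1, head_dim),
        torch.rand(1, n_heads, 1, head_dim),
        {"position_ids": pos},
    )
    # Outputs keep the fixed (1, n_heads, ctx_len, head_dim) shape; entries beyond position 8
    # are masked (v_out zeroed) rather than trimmed, which keeps exported shapes static.
    return k_out, v_out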


class QEffDynamicCache(DynamicCache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -36,6 +90,16 @@ class QEffDynamicCache(DynamicCache):

"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
# Remove layer_classes if present to avoid duplicate argument
kwargs.pop("layer_classes", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
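
    # Editorial note, not part of this PR's diff: with this __init__, the cache is composed of
    # transformers 4.55 layer objects (self.layers of QEffDynamicLayer) rather than the older
    # key_cache/value_cache lists. An illustrative construction, with arbitrary example shapes:
    #
    #     legacy = [(torch.rand(1, 8, 16, 64), torch.rand(1, 8, 16, 64)) for _ in range(2)]
    #     cache = QEffDynamicCache(ddp_cache_data=legacy)
    #     assert len(cache.layers) == 2
    #     assert isinstance(cache.layers[0], QEffDynamicLayer)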

def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
"""
Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -113,80 +177,6 @@ def read_only(self, layer_idx, cache_kwargs):
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

Return:
A tuple containing the updated key and value states.
"""
# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs

# Scatter
if batch_index is not None:
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)

self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)

self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Gather
ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out
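
    # Editorial note, not part of this PR's diff: the scatter/gather logic from the `update`
    # override removed above now lives in QEffDynamicLayer.update; under the transformers 4.55
    # layout the base Cache.update is expected to delegate to
    # self.layers[layer_idx].update(key_states, value_states, cache_kwargs) (an assumption
    # about the upstream API rather than something visible in this diff).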

def update3D(
self,
key_states: torch.Tensor,