diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ff341a68..bb3444a0 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -392,15 +392,18 @@ def compress_model(self, model: Module): for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): if prefix in module_to_scheme or prefix in sparse_compression_targets: - module_device = get_execution_device(module).type - is_meta = module_device == "meta" + module_device = get_execution_device(module) + is_meta = module_device.type == "meta" exec_device = "meta" if is_meta else "cpu" onloading_device = "meta" if is_meta else module_device # in the future, support compression on same device with align_module_device(module, execution_device=exec_device): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # quantization first if prefix in module_to_scheme: @@ -421,7 +424,7 @@ def compress_model(self, model: Module): # remove any existing parameters offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with compressed parameters @@ -458,7 +461,10 @@ def decompress_model(self, model: Module): if prefix in module_to_scheme or prefix in sparse_compression_targets: # in the future, support decompression on same device with align_module_device(module, execution_device="cpu"): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # sparsity first if prefix in sparse_compression_targets: @@ -483,7 +489,7 @@ def 
decompress_model(self, model: Module): # remove any existing parameters exec_device = get_execution_device(module) offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with decompressed parameters @@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module): def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]: """ - Returns a dictionary which maps quantized module names to their quantization schemes + Returns a dictionary which maps quantized module names to their quantization + schemes. Only includes modules with weight quantization """ return { fix_fsdp_module_name(name): module.quantization_scheme for name, module in model.named_modules() - if is_module_quantized(module) + if ( + hasattr(module, "quantization_scheme") and + module.quantization_scheme.weights is not None + ) } @@ -785,4 +795,4 @@ def override_quantization_status( try: yield finally: - config.quantization_status = original_status + config.quantization_status = original_status \ No newline at end of file diff --git a/src/compressed_tensors/modeling/README.md b/src/compressed_tensors/modeling/README.md new file mode 100644 index 00000000..c5e6d3b6 --- /dev/null +++ b/src/compressed_tensors/modeling/README.md @@ -0,0 +1 @@ +This folder contains code which models existing `torch` logic as used by `transformers` \ No newline at end of file diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 00000000..85c4205c --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,159 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict, defaultdict +from typing import TYPE_CHECKING, Callable, Optional + +import torch +from compressed_tensors.utils import getattr_chain +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.llama.modeling_llama import eager_attention_forward + + +if TYPE_CHECKING: + from compressed_tensors.quantization import QuantizationArgs, QuantizationStatus + + +__all__ = ["CompressedAttentionImpl", "enable_compressed_attention", "call_attn_impl"] + + +ActivationHookFn = Callable[[torch.nn.Module, torch.Tensor], None] + + +class CompressedAttentionImpl(torch.nn.Module): + """ + Callable attention implementation which applies transforms, calibration, and + quantization if applicable. Can be hooked with calibration hooks in order to + trigger quantization observers. 
+ + :param attn_implementation: original attention implementation to call after hooks + """ + + NAME = "compressed_attention" + ATTN_IMPL = "eager" + _ATTN_IMPLS = dict() + + @classmethod + def from_module(cls, module: torch.nn.Module): + if module not in cls._ATTN_IMPLS: + cls._ATTN_IMPLS[module] = cls() + return cls._ATTN_IMPLS[module] + + def __init__(self): + super().__init__() + self.query_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() + self.key_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() + self.value_hooks: OrderedDict[int, ActivationHookFn] = OrderedDict() + + def register_query_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.query_hooks) + self.query_hooks[handle.id] = hook + + return handle + + def register_key_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.key_hooks) + self.key_hooks[handle.id] = hook + + return handle + + def register_value_hook(self, hook: ActivationHookFn) -> RemovableHandle: + handle = RemovableHandle(self.value_hooks) + self.value_hooks[handle.id] = hook + + return handle + + def forward( + self, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, + ): + from compressed_tensors.quantization import forward_quantize + + for hook in self.query_hooks.values(): + output = hook(self, query) + if output is not None: + query = output + + for hook in self.key_hooks.values(): + output = hook(self, key) + if output is not None: + key = output + + for hook in self.value_hooks.values(): + output = hook(self, value) + if output is not None: + value = output + + # TODO: attnq + # 2. 
calibrate/ apply quantization + # args_path = "quantization_scheme.input_activations" + # status_path = "quantization_status" + # input_args: Optional[QuantizationArgs] = getattr_chain( + # module, args_path, None + # ) + # status: Optional[QuantizationStatus] = getattr(module, status_path, None) + # if input_args is not None and status in ( + # QuantizationStatus.CALIBRATION, + # QuantizationStatus.FROZEN, + # ): + # query = forward_quantize(module, query, "q", input_args) + # key = forward_quantize(module, key, "k", input_args) + # value = forward_quantize(module, value, "v", input_args) + + # 3. apply original attention function + # `eager_attention_forward` is duplicated across models by design + # assume that llama implementation is representative of all attention functions + # see: https://github.com/huggingface/transformers/issues/38541#issuecomment-2958567250 # noqa: 501 + + attention_fn: Callable = ( + eager_attention_forward + # if self.ATTN_IMPL == "eager" + # else ALL_ATTENTION_FUNCTIONS[self.ATTN_IMPL] + ) + # print(self.ATTN_IMPL) + return attention_fn( + module, query, key, value, attention_mask, scaling, dropout, **kwargs + ) + + +def enable_compressed_attention(model: torch.nn.Module): + """ + Enables transforms, calibration, and quantization for an attention implementation. + This function can safely be called multiple times on the same model. 
+ + :param model: model to enable compressed quantization for + :return: singleton instance of `CompressedAttentionImpl` + """ + if not isinstance(model, PreTrainedModel): + return + + attn_impl = getattr(model.config, "_attn_implementation", "eager") + + CompressedAttentionImpl.ATTN_IMPL = attn_impl + AttentionInterface.register(CompressedAttentionImpl.NAME, call_attn_impl) + model.config._attn_implementation = CompressedAttentionImpl.NAME + # model.set_attention_implementation(CompressedAttentionImpl.NAME) + + +def call_attn_impl(module: torch.nn.Module, *args, **kwargs): + return CompressedAttentionImpl.from_module(module)(module, *args, **kwargs) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba..b8dafd25 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -22,6 +22,7 @@ import torch from compressed_tensors.config import CompressionFormat +from compressed_tensors.modeling.attention import enable_compressed_attention from compressed_tensors.quantization.lifecycle.compressed import ( compress_quantized_weights, ) @@ -144,6 +145,9 @@ def apply_quantization_config( for target in scheme.targets: target_to_scheme[target] = scheme + # enable attention calibration/ quantization + enable_compressed_attention(model) + if run_compressed: from compressed_tensors.linear.compressed_linear import CompressedLinear diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index a5d4c8c2..970043f0 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch +from compressed_tensors.modeling.attention import enable_compressed_attention from compressed_tensors.transform import TransformConfig, TransformFactory @@ -27,6 +28,8 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): :param model: model to apply config to :param config: transform config to apply """ + enable_compressed_attention(model) + for name, scheme in config.config_groups.items(): factory = TransformFactory.from_scheme(scheme, name=name) factory.apply_to_model(model) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 1fdfa121..0849917a 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,7 +18,7 @@ import torch import torch.nn.utils.parametrize as P -from compressed_tensors import InternalModule, match_named_modules +from compressed_tensors.modeling.attention import CompressedAttentionImpl from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -26,9 +26,11 @@ TransformScheme, ) from compressed_tensors.utils import ( + InternalModule, align_module_device, delete_offload_module, has_offloaded_params, + match_named_modules, patch_attr, register_offload_module, update_offload_parameter, @@ -111,15 +113,15 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" - transform = self.create_transform(module, args) - self.transforms.append(transform) - register_offload_module(module, transform_name, transform) # register input transformation hook if args.location == TransformLocation.INPUT: + transform = self.create_transform(module, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) def input_hook(_, args): - input = args[0] + input = args[0] if isinstance(args, tuple) else 
args return transform(input) module.register_forward_pre_hook(input_hook, prepend=True) @@ -129,6 +131,9 @@ def input_hook(_, args): TransformLocation.WEIGHT_INPUT, TransformLocation.WEIGHT_OUTPUT, ): + transform = self.create_transform(module, args) + register_offload_module(module, transform_name, transform) + # fuse transform into weight assert hasattr(module, "weight") with torch.no_grad(), align_module_device(module): @@ -140,6 +145,7 @@ def input_hook(_, args): if has_offloaded_params(module): raise ValueError("Offloaded training is not supported") P.register_parametrization(module, "weight", transform) + self.transforms.append(transform) else: # transform is no longer needed (unfusing is not supported) @@ -147,15 +153,47 @@ def input_hook(_, args): # register output transformation hook elif args.location == TransformLocation.OUTPUT: + transform = self.create_transform(module, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) def output_hook(_, _input, output): return transform(output) module.register_forward_hook(output_hook) - # other locations such as q_attn and k_attn have not been implemented + # query hook registered to `CompressedAttentionImpl` + elif args.location in TransformLocation.ATTN_Q: + # TODO: makes name assumptions. Maybe we can target q_proj in the config + # then assume parent? Not sure + transform = self.create_transform(module.q_proj, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) + + attention_impl = CompressedAttentionImpl.from_module(module) + + def query_hook(_, query): + return transform(query) + + attention_impl.register_query_hook(query_hook) + + # key hook registered to `CompressedAttentionImpl` + elif args.location in TransformLocation.ATTN_K: + # TODO: makes name assumptions. Maybe we can target k_proj in the config + # then assume parent? 
Not sure + transform = self.create_transform(module.k_proj, args) + self.transforms.append(transform) + register_offload_module(module, transform_name, transform) + + attention_impl = CompressedAttentionImpl.from_module(module) + + def key_hook(_, key): + return transform(key) + + attention_impl.register_key_hook(key_hook) + else: - raise NotImplementedError() + raise ValueError() def _update_tied_weights(self): """ diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index e94d4d2d..582510c8 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -33,8 +33,8 @@ class TransformLocation(str, Enum): | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 - | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 - | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 + | `ATTN_Q` | online | query_states | `this.ATTN_K` | # noqa: E501 + | `ATTN_K` | online | key_states | `this.ATTN_Q` | # noqa: E501 | -------------------------------------------------------------------------------------------------------- | # noqa: E501 """ @@ -42,8 +42,8 @@ class TransformLocation(str, Enum): WEIGHT_INPUT = "weight_input" WEIGHT_OUTPUT = "weight_output" OUTPUT = "output" - K_CACHE = "k_cache" - Q_ATTN = "q_attn" + ATTN_Q = "attn_q" + ATTN_K = "attn_k" class TransformArgs(BaseModel, use_enum_values=True): diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index 18a7dc3a..4aa7b746 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -59,47 
+59,13 @@ def 
get_transform_size( def apply_transform_weight( - weight: torch.Tensor, + transform_weight: torch.Tensor, value: torch.Tensor, location: TransformLocation, module_type: type[torch.nn.Module], ) -> torch.Tensor: """ - :param weight: transform weight to apply - :param value: value to apply weight to - :param location: determines how weight should be applied - :param model_type: result of type(module), passed in to determine application of - weight transform. This is needed because torch uses convention: - - torch.nn.Linear(in_features,out_features) has weight shape - (out_features, in_features) - - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape - (num_embeddings, embedding_dim) - The transform has to account for Linear's transposed weights - :return: value after weight has been applied - """ - # get function used to apply transform - fn, axis = _get_transform_method(module_type, location) - - # reshape for head_dim - head_dim = weight.shape[0] - num_heads = value.shape[axis] // head_dim - value = value.unflatten(axis, (num_heads, head_dim)) - - # apply transform - value = fn(weight, value) - - # [undo] reshape for head_dim - value = value.flatten(axis - 1, axis) - - return value - - -def _get_transform_method( - module_type: type[torch.nn.Module], - location: TransformLocation, -) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]: - """ - Using the transform location, determine how to apply the transform weight to the + Using the transform location, apply the transform_weight to the given value wrt linear weights. 
For more info on input and output transforms, see `TransformLocation` @@ -129,51 +95,89 @@ def _get_transform_method( = y U = yh - :param weight: transform weight to apply - :param value: value to apply weight to + :param transform_weight: transform weight to apply + :param value: value to apply transform_weight to :param location: determines how weight should be applied - :return: value after transform weight has been applied + :param module_type: result of type(module), passed in to determine application of + weight transform + :return: value after transform_weight has been applied """ - fn = axis = None + + assert transform_weight.shape[0] == transform_weight.shape[1] if module_type == torch.nn.Linear: - if location == TransformLocation.INPUT: - fn = lambda weight, value: value @ weight - axis = -1 + if location in ( + TransformLocation.INPUT, + TransformLocation.ATTN_Q, + TransformLocation.ATTN_K, + ): + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.WEIGHT_INPUT: - fn = lambda weight, value: value @ weight.T - axis = -1 + # equivalent to (transform_weight @ value.T).T + return _multihead_matmul(value, transform_weight.T) elif location == TransformLocation.WEIGHT_OUTPUT: - fn = lambda weight, value: weight.T @ value - axis = -2 + # equivalent to (value.T @ transform_weight).T + return _multihead_matmul(transform_weight.T, value) elif location == TransformLocation.OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) # similar derivation to torch.nn.Linear, but `y = (x W)` - if module_type == torch.nn.Embedding: + elif module_type == torch.nn.Embedding: if location == TransformLocation.INPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.WEIGHT_INPUT: - fn = lambda weight, value: weight @ value - axis = -1 + return _multihead_matmul( + transform_weight, + value, + ) 
elif location == TransformLocation.WEIGHT_OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) elif location == TransformLocation.OUTPUT: - fn = lambda weight, value: value @ weight - axis = -1 + return _multihead_matmul(value, transform_weight) + + raise NotImplementedError( + f"Applying transforms to {module_type} {location} is not supported" + ) - if fn is None: - raise NotImplementedError( - f"Applying transforms to {module_type} {location} is not supported" - ) - return fn, axis +def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs A @ B for last two dims of two matrices A and B that possibly + have different shapes, as is the case in multi-headed dimension. If + shapes are different, this is equivalent to converting the last two dims + of the smaller matrix into a block-diagonal matrix with the same shape as + the last two dims of the larger matrix. + + E.g. if A is half the size of B, this function will perform + [[A ] @ B + [ A]] + + If B is a third of the size of A, this function will perform + A @ [[B ] + [ B ] + [ B]] + + This function will error out if the shapes are not evenly divisible + + :param A: left-hand tensor + :param B: right-hand tensor + :return: result + """ + if A.shape[-1] > B.shape[-2]: + head_dim = B.shape[-2] + num_heads = A.shape[-1] // head_dim + A = A.unflatten(-1, (num_heads, head_dim)) + return (A @ B).flatten(-2, -1) + elif A.shape[-1] < B.shape[-2]: + head_dim = A.shape[-1] + num_heads = B.shape[-2] // head_dim + B = B.unflatten(-2, (num_heads, head_dim)) + return (A @ B).flatten(-3, -2) + else: + return A @ B diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index e08e4d49..83b2c8a6 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -14,8 +14,13 @@ import pytest import torch +from compressed_tensors.modeling.attention import call_attn_impl from 
compressed_tensors.transform import TransformArgs, TransformFactory from transformers import PretrainedConfig, PreTrainedModel +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaRotaryEmbedding, +) class TransformableModel(PreTrainedModel): @@ -34,65 +39,46 @@ def forward(self, x): return x -class MockAttention(torch.nn.Module): +class MockAttentionModel(PreTrainedModel): def __init__( - self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + skip_pos_embeddings: bool = False, + attn_implementation: str = "eager", ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - - self.num_key_value_groups = num_attention_heads // num_key_value_heads - self.head_dim = hidden_size // num_attention_heads - self.scaling = self.head_dim**-0.5 - assert hidden_size >= num_attention_heads * self.head_dim - - self.q_proj = torch.nn.Linear( - hidden_size, num_attention_heads * self.head_dim, bias=False - ) - self.k_proj = torch.nn.Linear( - hidden_size, num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = torch.nn.Linear( - hidden_size, num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = torch.nn.Linear( - num_attention_heads * self.head_dim, hidden_size, bias=False + config = PretrainedConfig( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=0.0, + attention_bias=False, + max_position_embeddings=128, + rope_theta=500000.0, + _attn_implementation_internal=attn_implementation, + _attn_implementation_autoset=False, ) + super().__init__(config) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.attn = LlamaAttention(config, layer_idx=0) + self.skip_pos_embeddings = skip_pos_embeddings def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, 
seq_len, hidden_size = hidden_states.shape - hidden_shape = (batch_size, seq_len, -1, self.head_dim) + assert hidden_states.size(1) <= self.config.max_position_embeddings - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not self.skip_pos_embeddings: + position_ids = torch.arange(hidden_states.size(1)).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + zeros = torch.zeros(hidden_states.size(1), dtype=hidden_states.dtype) + position_embeddings = (zeros, zeros) - key_states = self.repeat_kv(key_states, self.num_key_value_groups) - value_states = self.repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + attn_output, _attn_weights = self.attn( + hidden_states, position_embeddings=position_embeddings, attention_mask=None ) - attn_weights = torch.nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape((batch_size, seq_len, -1)).contiguous() - - return self.o_proj(attn_output) - - def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return attn_output @pytest.fixture(scope="function") diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index c0225636..f5face63 100644 --- 
a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -22,7 +22,7 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch -from tests.test_transform.conftest import MockAttention +from tests.test_transform.conftest import MockAttentionModel from tests.testing_utils import requires_accelerate, requires_gpu @@ -122,33 +122,53 @@ def test_correctness_attention_heads(type, randomize, head_dim): hidden_size = 64 num_attention_heads = 8 - attention = MockAttention( + model = MockAttentionModel( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_key_value_heads=head_dim, + skip_pos_embeddings=False, + attn_implementation="eager", # TODO: fails with sdpa ) input = torch.rand(17, 5, hidden_size) - true_output = attention(input) + true_output = model(input) config = TransformConfig( config_groups={ - "": TransformScheme( + # "R3": TransformScheme( + # type=type, + # randomize=randomize, + # head_dim=head_dim, + # apply=[ + # TransformArgs(targets="attn.q_proj", location="output"), + # TransformArgs(targets="attn.k_proj", location="output"), + # ], + # ), + "R3": TransformScheme( + type=type, + randomize=randomize, + head_dim=head_dim, + apply=[ + TransformArgs(targets="attn", location="attn_q"), + TransformArgs(targets="attn", location="attn_k"), + ], + ), + "R2": TransformScheme( type=type, randomize=randomize, head_dim=head_dim, apply=[ - TransformArgs(targets="v_proj", location="weight_output"), + TransformArgs(targets="attn.v_proj", location="weight_output"), TransformArgs( - targets="o_proj", location="weight_input", inverse=True + targets="attn.o_proj", location="weight_input", inverse=True ), ], - ) + ), } ) - apply_transform_config(attention, config) + apply_transform_config(model, config) - output = attention(input) + output = model(input) assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)