diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5f1ec51e6..9787ad189 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -43,6 +43,7 @@
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
     PoolingTransform,
+    ReplicateKVHeadTransform,
     SamplerTransform,
     SpDTransform,
     VlmKVOffloadTransform,
@@ -888,6 +889,11 @@ def __init__(
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        # Since both sub-modules use the entire config for hash creation, update the hash params on both for consistency.
+        if replicate_kv_transformed:
+            self.lang_model.hash_params["config"] = model.config.to_diff_dict()
+            self.vision_model.hash_params["config"] = model.config.to_diff_dict()
         self.continuous_batching = continuous_batching
         self.input_shapes, self.output_names = None, None
@@ -1570,6 +1576,9 @@ def __init__(
             self.model.config.text_config.use_cache = True
         else:
             self.model.config.use_cache = True
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        if replicate_kv_transformed:
+            self.hash_params["config"] = model.config.to_diff_dict()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
 
     @classmethod
@@ -2182,8 +2191,10 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1)
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
         return cls(
             model,
             kv_offload=kv_offload,
@@ -2288,6 +2299,9 @@ def __init__(
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
 
         super().__init__(model, qaic_config=qaic_config, **kwargs)
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        if replicate_kv_transformed:
+            self.hash_params["config"] = model.config.to_diff_dict()
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
         self.model.qaic_config = qaic_config
@@ -2389,7 +2403,10 @@ def from_pretrained(
         kv_offload = kwargs.pop("kv_offload", None)
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
 
+        # InternVL errors out if n_kv_head_repeat is passed to the HF loader, so pop it here and restore it afterwards.
+        n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1)
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
 
         if qaic_config is not None:
             qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 773ce178c..31b5e4ce6 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -9,6 +9,7 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union
 
+import torch
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -424,8 +425,12 @@
     QEffWhisperPositionalEmbedding,
 )
 from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
+from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
+from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
 from QEfficient.transformers.sampler.sampler import sampler_forward
 from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
+from QEfficient.utils.logging_utils import logger
 
 SPD_TARGET = "target"
@@ -630,6 +635,150 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         return model, transformed
 
 
+class ReplicateKVHeadTransform:
+    """
+    Replicates KV heads in attention modules to match the number of KV heads required by the target model.
+    This transform is used when the source model has fewer KV heads than the target model requires.
+    """
+
+    _module_mapping = {
+        QEffCodeGenForCausalLM,
+        QEffFalconForCausalLM,
+        QEffGPT2LMHeadModel,
+        QEffGPTJForCausalLM,
+        QEffLlamaForCausalLM,
+        QEffLlama4ForConditionalGeneration,
+        QEffLlavaForConditionalGeneration,
+        QEffLlavaNextForConditionalGeneration,
+        QEffMllamaForConditionalGeneration,
+        QEffGemmaForCausalLM,
+        QEffQwen3MoeForCausalLM,
+        QEffGemma2ForCausalLM,
+        QEffGemma3ForConditionalGeneration,
+        QEffPhi3ForCausalLM,
+        QEffPhiForCausalLM,
+        QEffQwen2ForCausalLM,
+        QEffQwen_2_5_vl_ForConditionalGeneration,
+        QEffStarcoder2ForCausalLM,
+        QEffGPTBigCodeForCausalLM,
+        QEffOlmo2ForCausalLM,
+    }
+    _module_string_mapping = {
+        "InternVLChatModel",
+    }
+
+    @staticmethod
+    def _duplicate_weights_for_linear_layer(
+        layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+    ):
+        new_kv_heads = repeat * orig_kv_heads
+        if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
+            if head_dim % 8 != 0:
+                raise ValueError(
+                    f"The value head_dim={head_dim} is not divisible by 8, which violates "
+                    "the assumption that the model is 4-bit quantized."
+                )
+            if hidden_size % layer.group_size != 0:
+                raise ValueError(
+                    f"The value of hidden_size={hidden_size} is not divisible by "
+                    f"K_proj.group_size={layer.group_size}"
+                )
+
+            # Duplication of quantized weights
+            layer.qweight.data = torch.repeat_interleave(
+                layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
+            ).view(hidden_size, (new_kv_heads * head_dim) // 8)
+            # Duplication of quantized zero points
+            layer.qzeros.data = torch.repeat_interleave(
+                layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
+                repeat,
+                1,
+            ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
+            # Duplication of quantization scales
+            layer.scales.data = torch.repeat_interleave(
+                layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
+                repeat,
+                1,
+            ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
+            layer.out_features = layer.out_features * repeat
+
+        elif isinstance(layer, FP8DeQuantLinear):
+            layer.weight.data = torch.repeat_interleave(
+                layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+            ).view(new_kv_heads * head_dim, hidden_size)
+            layer.weight_scale.data = torch.repeat_interleave(
+                layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+            ).view(new_kv_heads * head_dim, -1)
+
+        else:
+            layer.weight.data = torch.repeat_interleave(
+                layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+            ).view(new_kv_heads * head_dim, hidden_size)
+            if layer.bias is not None:
+                layer.bias.data = torch.repeat_interleave(
+                    layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
+                ).view(new_kv_heads * head_dim)
+
+    @staticmethod
+    def _get_text_model(model):
+        """
+        Determine and return the appropriate text_model from a given model object.
+        """
+        # Check for VLMs
+        if hasattr(model, "language_model"):
+            if hasattr(model.language_model, "model"):
+                return model.language_model.model
+            else:
+                return model.language_model
+        # Check for CausalLMs
+        if hasattr(model, "model"):
+            return model.model
+
+        raise AttributeError("No suitable text model found in the provided model.")
+
+    @classmethod
+    def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
+        """
+        Replicates KV heads in attention modules based on the provided multiplier.
+
+        Args:
+            model: The model to apply the transform to.
+            kwargs: Additional arguments for the transformation. Includes:
+                - n_kv_head_repeat: The number of times to repeat the KV heads.
+        """
+        n_repeat = kwargs.pop("n_kv_head_repeat", 1)
+        transformed = False
+        if n_repeat > 1:
+            if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping):
+                text_model = cls._get_text_model(model)
+
+                orig_kv_heads = text_model.config.num_key_value_heads
+                new_kv_heads = n_repeat * orig_kv_heads
+                text_model.config.orig_kv_heads = orig_kv_heads
+                text_model.config.num_key_value_heads = new_kv_heads
+
+                num_attention_heads = text_model.config.num_attention_heads
+                hidden_size = text_model.config.hidden_size
+
+                logger.warning(f"Original KV heads: {orig_kv_heads}")
+                logger.warning(f"Modified KV heads: {new_kv_heads}")
+                transformed = True
+                for block in text_model.layers:
+                    attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
+                    attn.num_key_value_heads = new_kv_heads
+                    attn.num_key_value_groups = num_attention_heads // new_kv_heads
+
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+            else:
+                raise NotImplementedError(
+                    f"Model class {model.__class__.__name__} is not supported for KV head replication."
+                )
+        return model, transformed
+
+
 class SpDTransform:
     """
     Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 321a466ab..782712d50 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -282,6 +282,72 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
 
 
+def check_kv_repeat_causal_lm_pytorch_vs_ai100(
+    model_name: str,
+    prompt_len: int = Constants.PROMPT_LEN,
+    ctx_len: int = Constants.CTX_LEN,
+    n_layer: int = 1,
+    n_kv_head_repeat: int = 1,
+    config: Optional[AutoConfig] = None,
+    pytorch_hf_tokens: Optional[list] = None,
+):
+    """
+    Validate the PyTorch model against the Cloud AI 100 model with the original KV heads replicated.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+        :prompt_len (int): Prompt length for the model to compile.
+        :ctx_len (int): Maximum context length to compile the model.
+        :n_layer (int): Number of layers for the Model.
+        :n_kv_head_repeat (int): Number of times to repeat KV heads.
+    """
+    replace_transformers_quantizers()
+    if config is None:
+        n_layer = get_custom_n_layers(model_name)
+        model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer)
+    else:
+        model_hf, _ = load_causal_lm_model(model_name, config=config)
+
+    tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
+    config = model_hf.config
+    batch_size = len(Constants.INPUT_STR)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        Constants.INPUT_STR,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+    )
+    if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
+
+    # TODO: Add support for custom repeat_kv in models to handle uneven replications.
+    # Derive n_kv_head_repeat from the config so that no divisibility error occurs.
+    n_kv_head_repeat = config.num_attention_heads // config.num_key_value_heads
+    qeff_model = QEFFAutoModelForCausalLM(
+        copy.deepcopy(model_hf),
+        pretrained_model_name_or_path=model_name,
+        n_kv_head_repeat=n_kv_head_repeat,
+    )
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+    qpc_path = qeff_model.compile(
+        prefill_seq_len=prompt_len,
+        ctx_len=ctx_len,
+        num_cores=14,
+        mxfp6=False,
+        aic_enable_depth_first=False,
+    )
+    exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
+    gen_len = len(pytorch_hf_tokens)
+    cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len]
+    assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), (
+        "Tokens don't match for PyTorch HF output and Cloud AI 100 output."
+    )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
+
 # FIXME: there should be a CB test here
 @pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x)
 def test_causal_lm_export_with_deprecated_api(model_name):
@@ -360,6 +426,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     )
 
 
+@pytest.mark.nightly
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models_causal)
+def test_check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model against the Cloud AI 100 model with the original KV heads replicated.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    n_layer = get_custom_n_layers(model_name)
+
+    # Use fixed reference tokens for external models in specific test cases.
+    # These tokens are hardcoded and therefore will not match if the model config changes.
+    pytorch_hf_tokens = None
+    if model_name in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"]
+
+    check_kv_repeat_causal_lm_pytorch_vs_ai100(
+        model_name=model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.regular
 @pytest.mark.qnn
diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py
index e6a145195..941c289e5 100644
--- a/tests/transformers/models/test_image_text_to_text_models.py
+++ b/tests/transformers/models/test_image_text_to_text_models.py
@@ -188,7 +188,7 @@
     #     "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
     #     "Please describe the image in detail.",
     #     2,
-    # ),  # commented becuase QNN Convertor is not supported for this model yet.
+    # ),
 ]
 
 molmo_model_config = [
@@ -249,6 +249,14 @@ def set_num_layers(config, n_layer=1):
     return config
 
 
+def get_text_config(config):
+    """Return the nested text/LLM sub-config of a multimodal config, or the config itself if none exists."""
+    if hasattr(config, "text_config"):
+        return config.text_config
+    elif hasattr(config, "llm_config"):
+        return config.llm_config
+    return config
+
+
 def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     model_name: str,
     img_size: int,
@@ -263,6 +271,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     num_devices: int = 1,
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
+    n_kv_head_repeat: Optional[int] = None,
+    test_kv_replicate: Optional[bool] = None,
 ):
     model_config = {"model_name": model_name}
     model_config["img_size"] = img_size
@@ -304,10 +314,15 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
         inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
     streamer = TextStreamer(processor.tokenizer)
     pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs)
+    if test_kv_replicate:
+        text_config = get_text_config(config)
+        n_kv_head_repeat = text_config.num_attention_heads // text_config.num_key_value_heads
+
     qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
         model_config["model_name"],
         kv_offload=kv_offload,
         config=config,
+        n_kv_head_repeat=n_kv_head_repeat,
     )
 
     # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model)
@@ -428,6 +443,8 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     num_devices: int = 1,
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
+    n_kv_head_repeat: Optional[int] = None,
+    test_kv_replicate: Optional[bool] = None,
 ):
     model_config = {"model_name": model_name}
 
@@ -490,10 +507,15 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     )
     pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config)
 
+    if test_kv_replicate:
+        text_config = get_text_config(config)
+        n_kv_head_repeat = text_config.num_attention_heads // text_config.num_key_value_heads
+
     qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
         model_config["model_name"],
         kv_offload=kv_offload,
         config=config,
+        n_kv_head_repeat=n_kv_head_repeat,
     )
     # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model)
     # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
@@ -551,6 +573,34 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     )
 
 
+@pytest.mark.on_qaic
+@pytest.mark.multimodal
+@pytest.mark.parametrize(
+    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config
+)
+def test_replicate_kv_pytorch_vs_ai100(
+    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer
+):
+    """
+    Test function to validate the PyTorch model against the Cloud AI 100 model with replicated KV heads, without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        max_gen_len=NEW_GENERATION_TOKENS,
+        img_size=img_size,
+        img_url=img_url,
+        query=query,
+        n_layer=n_layer,
+        batch_size=batch_size,
+        kv_offload=kv_offload,
+        test_kv_replicate=True,
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.qnn
 @pytest.mark.multimodal
@@ -608,6 +658,28 @@ def test_image_text_to_text_molmo_pytorch_vs_kv_vs_ort_vs_ai100(
     )
 
 
+@pytest.mark.on_qaic
+@pytest.mark.multimodal
+@pytest.mark.parametrize(
+    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", intern_model_config
+)
+def test_replicate_kv_intern_pytorch_vs_ai100(
+    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer
+):
+    check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        max_gen_len=NEW_GENERATION_TOKENS,
+        img_url=img_url,
+        query=query,
+        n_layer=n_layer,
+        batch_size=batch_size,
+        kv_offload=kv_offload,
+        test_kv_replicate=True,
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.multimodal
 @pytest.mark.parametrize(
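
Usage sketch: the new n_kv_head_repeat option is consumed by ReplicateKVHeadTransform through the QEff auto classes shown above. A minimal sketch, assuming a Cloud AI 100 device is available; the checkpoint name and compile settings are illustrative placeholders, and the repeat factor should be chosen so the resulting KV head count divides num_attention_heads (the tests above use num_attention_heads // num_key_value_heads for exactly this reason).

    from transformers import AutoTokenizer

    from QEfficient import QEFFAutoModelForCausalLM

    model_name = "meta-llama/Llama-3.2-1B"  # placeholder GQA checkpoint
    # from_pretrained pops n_kv_head_repeat before the HF load; ReplicateKVHeadTransform then
    # duplicates the k_proj/v_proj weights and updates num_key_value_heads accordingly.
    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, n_kv_head_repeat=4)
    qeff_model.compile(prefill_seq_len=32, ctx_len=128, num_cores=14, mxfp6=False)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    qeff_model.generate(tokenizer, prompts=["My name is"])

The duplication itself is a per-head torch.repeat_interleave followed by a reshape, as in _duplicate_weights_for_linear_layer; a toy check of the unquantized k_proj/v_proj case with made-up sizes:

    import torch

    orig_kv_heads, head_dim, hidden_size, repeat = 2, 4, 8, 3
    w = torch.randn(orig_kv_heads * head_dim, hidden_size)  # k_proj.weight layout
    # Repeat each KV head's rows `repeat` times, keeping per-head blocks contiguous.
    w_rep = torch.repeat_interleave(w.view(orig_kv_heads, head_dim, hidden_size), repeat, 0)
    w_rep = w_rep.view(repeat * orig_kv_heads * head_dim, hidden_size)
    assert w_rep.shape == (repeat * orig_kv_heads * head_dim, hidden_size)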