From 3c9039003c5c407a34380d5097734fa665d67473 Mon Sep 17 00:00:00 2001
From: Dhiraj Kumar Sah
Date: Wed, 19 Nov 2025 06:06:20 +0000
Subject: [PATCH 1/2] Created ReplicateKVHeadTransform to integrate the
 KV-head replication module within the QEfficient library.

The transform enables KV-head replication for both CausalLMs and VLMs. The
feature is enabled by passing the n_kv_head_repeat parameter when
initializing the QEff wrapper class for the corresponding model;
n_kv_head_repeat acts as the multiplier applied to the original number of KV
heads. The operation also updates num_key_value_heads in the respective
model's config and hash params and adds an orig_kv_heads parameter to them,
which allows us to export the same model with different numbers of KV heads
without causing a hash conflict.

Also added tests for both CausalLMs and VLMs that exercise this functionality
and compare the outputs of the PyTorch HF model and the AIC model. Two new
optional parameters, n_kv_head_repeat and test_kv_replicate, are added for
testing purposes. Setting test_kv_replicate to True replicates the KV heads
of every model so that the number of KV heads equals the number of attention
heads. This was done so that tests don't fail due to misalignment when simply
doubling num_key_value_heads would cause a divisibility error on num_heads.

Signed-off-by: Dhiraj Kumar Sah
---
 .../transformers/models/modeling_auto.py      |  17 ++
 .../transformers/models/pytorch_transforms.py | 149 ++++++++++++++++++
 .../models/test_causal_lm_models.py           |  88 +++++++++++
 .../models/test_image_text_to_text_models.py  |  74 ++++++++-
 4 files changed, 327 insertions(+), 1 deletion(-)

diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5f1ec51e6..cdff8161e 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -43,6 +43,7 @@
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
     PoolingTransform,
+    ReplicateKVHeadTransform,
     SamplerTransform,
     SpDTransform,
     VlmKVOffloadTransform,
@@ -888,6 +889,11 @@ def __init__(
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)

+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        # Since both modules use the entire config for hash creation, we're updating the params for consistency.
+ if replicate_kv_transformed: + self.lang_model.hash_params["config"] = model.config.to_diff_dict() + self.vision_model.hash_params["config"] = model.config.to_diff_dict() self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @@ -1570,6 +1576,9 @@ def __init__( self.model.config.text_config.use_cache = True else: self.model.config.use_cache = True + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + if replicate_kv_transformed: + self.hash_params["config"] = model.config.to_diff_dict() self.hash_params["qeff_auto_class"] = self.__class__.__name__ @classmethod @@ -2182,8 +2191,10 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + kwargs.update({"n_kv_head_repeat": n_kv_head_repeat}) return cls( model, kv_offload=kv_offload, @@ -2288,6 +2299,9 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) super().__init__(model, qaic_config=qaic_config, **kwargs) + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + if replicate_kv_transformed: + self.hash_params["config"] = model.config.to_diff_dict() self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config @@ -2389,7 +2403,10 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # InternVL causes an error if we pass the n_kv_head_repeat parameter + n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + kwargs.update({"n_kv_head_repeat": n_kv_head_repeat}) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..31b5e4ce6 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -9,6 +9,7 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +import torch from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -424,8 +425,12 @@ QEffWhisperPositionalEmbedding, ) from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.logging_utils import logger SPD_TARGET = "target" @@ -630,6 +635,150 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class ReplicateKVHeadTransform: + """ + Replicates KV heads in attention modules to match the number of KV heads in the target model. + This transform is used when the source model has fewer KV heads than required in target model. 
+ """ + + _module_mapping = { + QEffCodeGenForCausalLM, + QEffFalconForCausalLM, + QEffGPT2LMHeadModel, + QEffGPTJForCausalLM, + QEffLlamaForCausalLM, + QEffLlama4ForConditionalGeneration, + QEffLlavaForConditionalGeneration, + QEffLlavaNextForConditionalGeneration, + QEffMllamaForConditionalGeneration, + QEffGemmaForCausalLM, + QEffQwen3MoeForCausalLM, + QEffGemma2ForCausalLM, + QEffGemma3ForConditionalGeneration, + QEffPhi3ForCausalLM, + QEffPhiForCausalLM, + QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_ForConditionalGeneration, + QEffStarcoder2ForCausalLM, + QEffGPTBigCodeForCausalLM, + QEffOlmo2ForCausalLM, + } + _module_string_mapping = { + "InternVLChatModel", + } + + def _duplicate_weights_for_linear_layer( + layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int + ): + new_kv_heads = repeat * orig_kv_heads + if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)): + if head_dim % 8 != 0: + raise ValueError( + f"the value head_dim={head_dim} is not divisible by 8 which is \ + according to the assumption that model is 4-bit quantized." + ) + if hidden_size % layer.group_size != 0: + raise ValueError( + f"The value of hidden_size={hidden_size} is not divisible by \ + K_proj.group_size={layer.group_size}" + ) + + # Duplication of quantized weights + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1 + ).view(hidden_size, (new_kv_heads * head_dim) // 8) + # Duplication of quantized zero points + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8), + repeat, + 1, + ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8) + # Duplication of quantization scales + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim), + repeat, + 1, + ).view(hidden_size // layer.group_size, new_kv_heads * head_dim) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, FP8DeQuantLinear): + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) + layer.weight_scale.data = torch.repeat_interleave( + layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0 + ).view(new_kv_heads * head_dim, -1) + + else: + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) + if layer.bias is not None: + layer.bias.data = torch.repeat_interleave( + layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 + ).view(new_kv_heads * head_dim) + + def _get_text_model(model): + """ + Determine and return the appropriate text_model from a given model object. + """ + # Check for VLMs + if hasattr(model, "language_model"): + if hasattr(model.language_model, "model"): + return model.language_model.model + else: + return model.language_model + # Check for CausalLMs + if hasattr(model, "model"): + return model.model + + raise AttributeError("No suitable text model found in the provided model.") + + @classmethod + def apply(cls, model: nn.Module, **kwargs) -> nn.Module: + """ + Replicates KV heads in attention modules based on provided multiplier. + + Args: + model: The model to apply the transform to. + kwargs: Additional arguments for the transformation. 
Includes:
+                - n_kv_head_repeat: The number of times to repeat the KV heads.
+        """
+        n_repeat = kwargs.pop("n_kv_head_repeat", 1)
+        transformed = False
+        if n_repeat > 1:
+            if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping):
+                text_model = cls._get_text_model(model)
+
+                orig_kv_heads = text_model.config.num_key_value_heads
+                new_kv_heads = n_repeat * orig_kv_heads
+                text_model.config.orig_kv_heads = orig_kv_heads
+                text_model.config.num_key_value_heads = new_kv_heads
+
+                num_attention_heads = text_model.config.num_attention_heads
+                hidden_size = text_model.config.hidden_size
+
+                logger.warning(f"Original KV heads: {orig_kv_heads}")
+                logger.warning(f"Modified KV heads: {new_kv_heads}")
+                transformed = True
+                for block in text_model.layers:
+                    attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
+                    attn.num_key_value_heads = new_kv_heads
+                    attn.num_key_value_groups = num_attention_heads // new_kv_heads
+
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+            else:
+                raise NotImplementedError(
+                    f"Model class {model.__class__.__name__} is not supported for KV head replication."
+                )
+        return model, transformed
+
+
 class SpDTransform:
     """
     Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 321a466ab..782712d50 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -282,6 +282,72 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))


+def check_kv_repeat_causal_lm_pytorch_vs_ai100(
+    model_name: str,
+    prompt_len: int = Constants.PROMPT_LEN,
+    ctx_len: int = Constants.CTX_LEN,
+    n_layer: int = 1,
+    n_kv_head_repeat: int = 1,
+    config: Optional[AutoConfig] = None,
+    pytorch_hf_tokens: Optional[list] = None,
+):
+    """
+    Validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+        :prompt_len (int): Prompt length for the model to compile.
+        :ctx_len (int): Maximum context length to compile the model.
+        :n_layer (int): Number of layers for the Model.
+        :n_kv_head_repeat (int): Number of times to repeat KV heads.
+    """
+    replace_transformers_quantizers()
+    if config is None:
+        n_layer = get_custom_n_layers(model_name)
+        model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer)
+    else:
+        model_hf, _ = load_causal_lm_model(model_name, config=config)
+
+    tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
+    config = model_hf.config
+    batch_size = len(Constants.INPUT_STR)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        Constants.INPUT_STR,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+    )
+    if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
+
+    # TODO: Add support for custom repeat_kv in models to handle uneven replications.
+ # Generate n_kv_head_repeat from config so that divisibility error doesn't occur. + n_kv_head_repeat = config.num_attention_heads // config.num_key_value_heads + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), + pretrained_model_name_or_path=model_name, + n_kv_head_repeat=n_kv_head_repeat, + ) + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + ) + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + gen_len = len(pytorch_hf_tokens) + cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len] + assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), ( + "Tokens don't match for Pytorch HF output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + # FIXME: there should be a CB test here @pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x) def test_causal_lm_export_with_deprecated_api(model_name): @@ -360,6 +426,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): ) +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models_causal) +def test_check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name): + """ + Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + # Using fixed reference tokens for external models for specific test cases. + # These tokens are hardcoded, therefore will not match if the model config changes. + pytorch_hf_tokens = None + if model_name in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"] + + check_kv_repeat_causal_lm_pytorch_vs_ai100( + model_name=model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens + ) + + @pytest.mark.on_qaic @pytest.mark.regular @pytest.mark.qnn diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py index e6a145195..941c289e5 100644 --- a/tests/transformers/models/test_image_text_to_text_models.py +++ b/tests/transformers/models/test_image_text_to_text_models.py @@ -188,7 +188,7 @@ # "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg", # "Please describe the image in detail.", # 2, - # ), # commented becuase QNN Convertor is not supported for this model yet. 
+ # ), ] molmo_model_config = [ @@ -249,6 +249,14 @@ def set_num_layers(config, n_layer=1): return config +def get_text_config(config): + if hasattr(config, "text_config"): + return config.text_config + elif hasattr(config, "llm_config"): + return config.llm_config + return config + + def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, img_size: int, @@ -263,6 +271,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( num_devices: int = 1, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + n_kv_head_repeat: Optional[int] = None, + test_kv_replicate: Optional[bool] = None, ): model_config = {"model_name": model_name} model_config["img_size"] = img_size @@ -304,10 +314,15 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) streamer = TextStreamer(processor.tokenizer) pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs) + if test_kv_replicate: + text_config = get_text_config(config) + n_kv_head_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_config["model_name"], kv_offload=kv_offload, config=config, + n_kv_head_repeat=n_kv_head_repeat, ) # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model) @@ -428,6 +443,8 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( num_devices: int = 1, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + n_kv_head_repeat: Optional[int] = None, + test_kv_replicate: Optional[bool] = None, ): model_config = {"model_name": model_name} @@ -490,10 +507,15 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( ) pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config) + if test_kv_replicate: + text_config = get_text_config(config) + n_kv_head_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( model_config["model_name"], kv_offload=kv_offload, config=config, + n_kv_head_repeat=n_kv_head_repeat, ) # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model) # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( @@ -551,6 +573,34 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( ) +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config +) +def test_replicate_kv_pytorch_vs_ai100( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_size=img_size, + img_url=img_url, + query=query, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + test_kv_replicate=True, + ) + + @pytest.mark.on_qaic @pytest.mark.qnn @pytest.mark.multimodal @@ -608,6 +658,28 @@ def test_image_text_to_text_molmo_pytorch_vs_kv_vs_ort_vs_ai100( ) +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", intern_model_config +) +def test_replicate_kv_intern_pytorch_vs_ai100( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer +): + check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_url=img_url, + query=query, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + test_kv_replicate=True, + ) + + @pytest.mark.on_qaic @pytest.mark.multimodal @pytest.mark.parametrize( From c3b72d016e3b98c45612b30d38c779eff4e0a8a4 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Tue, 25 Nov 2025 09:11:21 +0000 Subject: [PATCH 2/2] Modified modeling_auto to use 1 as default value of n_kv_head_repeat. Doing so would prevent any issues during Transforms when we don't wish to apply it. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/transformers/models/modeling_auto.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cdff8161e..9787ad189 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2191,7 +2191,7 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None) + n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"n_kv_head_repeat": n_kv_head_repeat}) @@ -2404,7 +2404,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) # InternVL causes an error if we pass the n_kv_head_repeat parameter - n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None) + n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) kwargs.update({"n_kv_head_repeat": n_kv_head_repeat}) if qaic_config is not None:
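
Usage sketch (illustrative, not part of the patches above): a minimal example of how the new n_kv_head_repeat knob could be driven end to end, mirroring the flow of the tests added in patch 1. The model card name is an assumption (any supported GQA CausalLM would do), and the compile/generate arguments are the ones exercised in the tests rather than required values.

    # Hypothetical usage sketch; assumes a Cloud AI 100 device is available.
    from transformers import AutoConfig, AutoTokenizer

    from QEfficient import QEFFAutoModelForCausalLM

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative GQA model card

    # Derive the repeat factor the same way the new tests do, so that
    # num_attention_heads stays divisible by the replicated KV-head count.
    cfg = AutoConfig.from_pretrained(model_name)
    n_kv_head_repeat = cfg.num_attention_heads // cfg.num_key_value_heads

    # n_kv_head_repeat multiplies the original num_key_value_heads; the transform
    # updates config.num_key_value_heads and records config.orig_kv_heads, so the
    # export hash differs from the un-replicated model.
    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, n_kv_head_repeat=n_kv_head_repeat)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    qeff_model.compile(prefill_seq_len=32, ctx_len=128, num_cores=14)
    exec_info = qeff_model.generate(tokenizer, prompts=["My name is"])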