
Commit 1dfdea6

Created ReplicateKVHeadTransform to integrate a KV-head replication module into the QEfficient library.

The transform enables KV-head replication for both CausalLMs and VLMs. It is enabled by passing the n_kv_head_repeat parameter when initializing the QEff wrapper class for the corresponding model; n_kv_head_repeat acts as a multiplier on the original KV-head count. The operation also updates the model's config and hash params: num_key_value_heads is set to the new count and a new parameter, orig_kv_heads, records the original count. This allows the same model to be exported with different numbers of KV heads without causing a hash conflict.

Tests were added for both CausalLMs and VLMs to compare the outputs of the PyTorch HF model and the AIC model. Two new optional parameters, n_kv_head_repeat and test_kv_replicate, were added for testing purposes. Setting test_kv_replicate to True replicates the KV heads of every model so that the number of KV heads equals the number of attention heads; this prevents test failures from misalignment when simply doubling num_key_value_heads would break divisibility with num_heads.

Signed-off-by: Dhiraj Kumar Sah <[email protected]>
1 parent 04f1ad7 commit 1dfdea6
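
A minimal usage sketch (not part of the commit) of how the new parameter is intended to be used, based on the wrapper and test changes below. The model card and repeat factor are placeholders; any GQA checkpoint covered by the transform's module mapping would work the same way.

```python
from QEfficient import QEFFAutoModelForCausalLM

# n_kv_head_repeat multiplies the original KV-head count; the transform updates
# config.num_key_value_heads and records config.orig_kv_heads so that exports with
# different repeat factors do not collide on the same hash.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    "org/some-gqa-checkpoint",  # placeholder model card
    n_kv_head_repeat=4,
)
qeff_model.compile(prefill_seq_len=32, ctx_len=128, num_cores=14)
```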

File tree

4 files changed: +327 -1 lines changed


QEfficient/transformers/models/modeling_auto.py

Lines changed: 17 additions & 0 deletions

@@ -43,6 +43,7 @@
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
     PoolingTransform,
+    ReplicateKVHeadTransform,
     SamplerTransform,
     SpDTransform,
     VlmKVOffloadTransform,
@@ -883,6 +884,11 @@ def __init__(
         self.config = model.config
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        # Since both modules use the entire config for hash creation, we're updating the params for consistency.
+        if replicate_kv_transformed:
+            self.lang_model.hash_params["config"] = model.config.to_diff_dict()
+            self.vision_model.hash_params["config"] = model.config.to_diff_dict()
         self.continuous_batching = continuous_batching
         self.input_shapes, self.output_names = None, None

@@ -1511,6 +1517,9 @@ def __init__(
             self.model.config.text_config.use_cache = True
         else:
             self.model.config.use_cache = True
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        if replicate_kv_transformed:
+            self.hash_params["config"] = model.config.to_diff_dict()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__

     @classmethod
@@ -2063,7 +2072,9 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None)
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
         return cls(
             model,
             kv_offload=kv_offload,
@@ -2164,6 +2175,9 @@ def __init__(
         # Set use_cache=True to get KV values as output during ONNX export
         model.config.use_cache = True
         super().__init__(model, qaic_config=qaic_config, **kwargs)
+        self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
+        if replicate_kv_transformed:
+            self.hash_params["config"] = model.config.to_diff_dict()
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
         self.model.qaic_config = qaic_config
@@ -2265,7 +2279,10 @@ def from_pretrained(
         kv_offload = kwargs.pop("kv_offload", None)

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        # InternVL causes an error if we pass the n_kv_head_repeat parameter
+        n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", None)
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
         if qaic_config is not None:
             qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

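To see why the hash_params["config"] refresh above is needed, here is a small sketch (synthetic LlamaConfig with illustrative numbers, not from the commit) of what the transform effectively does to the config and why two exports with different repeat factors no longer collide:

```python
from transformers import LlamaConfig

# Illustrative GQA config: 16 query heads sharing 4 KV heads.
config = LlamaConfig(num_attention_heads=16, num_key_value_heads=4)
before = config.to_diff_dict()

# What ReplicateKVHeadTransform.apply() effectively does to the config for repeat=2.
config.orig_kv_heads = config.num_key_value_heads
config.num_key_value_heads = 2 * config.num_key_value_heads
after = config.to_diff_dict()

# hash_params["config"] is rebuilt from to_diff_dict(), so the replicated export hashes
# differently from the original one: the dicts differ in num_key_value_heads and the
# newly added orig_kv_heads entry.
assert before["num_key_value_heads"] == 4
assert after["num_key_value_heads"] == 8 and after["orig_kv_heads"] == 4
```
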
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 149 additions & 0 deletions

@@ -9,6 +9,7 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

+import torch
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -424,8 +425,12 @@
     QEffWhisperPositionalEmbedding,
 )
 from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
+from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
+from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
 from QEfficient.transformers.sampler.sampler import sampler_forward
 from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
+from QEfficient.utils.logging_utils import logger

 SPD_TARGET = "target"

@@ -630,6 +635,150 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         return model, transformed


+class ReplicateKVHeadTransform:
+    """
+    Replicates KV heads in attention modules to match the number of KV heads in the target model.
+    This transform is used when the source model has fewer KV heads than required in target model.
+    """
+
+    _module_mapping = {
+        QEffCodeGenForCausalLM,
+        QEffFalconForCausalLM,
+        QEffGPT2LMHeadModel,
+        QEffGPTJForCausalLM,
+        QEffLlamaForCausalLM,
+        QEffLlama4ForConditionalGeneration,
+        QEffLlavaForConditionalGeneration,
+        QEffLlavaNextForConditionalGeneration,
+        QEffMllamaForConditionalGeneration,
+        QEffGemmaForCausalLM,
+        QEffQwen3MoeForCausalLM,
+        QEffGemma2ForCausalLM,
+        QEffGemma3ForConditionalGeneration,
+        QEffPhi3ForCausalLM,
+        QEffPhiForCausalLM,
+        QEffQwen2ForCausalLM,
+        QEffQwen_2_5_vl_ForConditionalGeneration,
+        QEffStarcoder2ForCausalLM,
+        QEffGPTBigCodeForCausalLM,
+        QEffOlmo2ForCausalLM,
+    }
+    _module_string_mapping = {
+        "InternVLChatModel",
+    }
+
+    def _duplicate_weights_for_linear_layer(
+        layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+    ):
+        new_kv_heads = repeat * orig_kv_heads
+        if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
+            if head_dim % 8 != 0:
+                raise ValueError(
+                    f"the value head_dim={head_dim} is not divisible by 8 which is \
+                    according to the assumption that model is 4-bit quantized."
+                )
+            if hidden_size % layer.group_size != 0:
+                raise ValueError(
+                    f"The value of hidden_size={hidden_size} is not divisible by \
+                    K_proj.group_size={layer.group_size}"
+                )
+
+            # Duplication of quantized weights
+            layer.qweight.data = torch.repeat_interleave(
+                layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
+            ).view(hidden_size, (new_kv_heads * head_dim) // 8)
+            # Duplication of quantized zero points
+            layer.qzeros.data = torch.repeat_interleave(
+                layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
+                repeat,
+                1,
+            ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
+            # Duplication of quantization scales
+            layer.scales.data = torch.repeat_interleave(
+                layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
+                repeat,
+                1,
+            ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
+            layer.out_features = layer.out_features * repeat
+
+        elif isinstance(layer, FP8DeQuantLinear):
+            layer.weight.data = torch.repeat_interleave(
+                layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+            ).view(new_kv_heads * head_dim, hidden_size)
+            layer.weight_scale.data = torch.repeat_interleave(
+                layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+            ).view(new_kv_heads * head_dim, -1)
+
+        else:
+            layer.weight.data = torch.repeat_interleave(
+                layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+            ).view(new_kv_heads * head_dim, hidden_size)
+            if layer.bias is not None:
+                layer.bias.data = torch.repeat_interleave(
+                    layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
+                ).view(new_kv_heads * head_dim)
+
+    def _get_text_model(model):
+        """
+        Determine and return the appropriate text_model from a given model object.
+        """
+        # Check for VLMs
+        if hasattr(model, "language_model"):
+            if hasattr(model.language_model, "model"):
+                return model.language_model.model
+            else:
+                return model.language_model
+        # Check for CausalLMs
+        if hasattr(model, "model"):
+            return model.model
+
+        raise AttributeError("No suitable text model found in the provided model.")
+
+    @classmethod
+    def apply(cls, model: nn.Module, **kwargs) -> nn.Module:
+        """
+        Replicates KV heads in attention modules based on provided multiplier.
+
+        Args:
+            model: The model to apply the transform to.
+            kwargs: Additional arguments for the transformation. Includes:
+                - n_kv_head_repeat: The number of times to repeat the KV heads.
+        """
+        n_repeat = kwargs.pop("n_kv_head_repeat", 1)
+        transformed = False
+        if n_repeat > 1:
+            if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping):
+                text_model = cls._get_text_model(model)
+
+                orig_kv_heads = text_model.config.num_key_value_heads
+                new_kv_heads = n_repeat * orig_kv_heads
+                text_model.config.orig_kv_heads = orig_kv_heads
+                text_model.config.num_key_value_heads = new_kv_heads
+
+                num_attention_heads = text_model.config.num_attention_heads
+                hidden_size = text_model.config.hidden_size
+
+                logger.warning(f"Original KV heads: {orig_kv_heads}")
+                logger.warning(f"Modified KV heads: {new_kv_heads}")
+                transformed = True
+                for block in text_model.layers:
+                    attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
+                    attn.num_key_value_heads = new_kv_heads
+                    attn.num_key_value_groups = num_attention_heads // new_kv_heads
+
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+                    cls._duplicate_weights_for_linear_layer(
+                        attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
+                    )
+            else:
+                raise NotImplementedError(
+                    f"Model class {model.__class__.__name__} is not supported for KV head replication."
+                )
+        return model, transformed
+
+
 class SpDTransform:
     """
     Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.

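For orientation, here is a self-contained sketch (not part of the commit) of the unquantized branch of _duplicate_weights_for_linear_layer: each KV head's rows in k_proj are repeated in place, so attention with the enlarged head count produces exactly the keys that repeat_kv would otherwise materialize at runtime. The tensor sizes are illustrative.

```python
import torch
from torch import nn

hidden_size, head_dim, orig_kv_heads, repeat = 64, 16, 2, 2
new_kv_heads = orig_kv_heads * repeat

k_proj = nn.Linear(hidden_size, orig_kv_heads * head_dim, bias=False)
x = torch.randn(1, 5, hidden_size)

# Keys before replication, expanded per query group the way repeat_kv would at runtime.
k_before = k_proj(x).view(1, 5, orig_kv_heads, head_dim).repeat_interleave(repeat, dim=2)

# Replicate the per-head weight blocks, mirroring the transform's unquantized branch.
k_proj.weight.data = torch.repeat_interleave(
    k_proj.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
k_proj.out_features = new_kv_heads * head_dim  # keep layer metadata in sync with the new weight

# Keys after replication are already materialized once per repeated head.
k_after = k_proj(x).view(1, 5, new_kv_heads, head_dim)
assert torch.allclose(k_before, k_after)
```
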
tests/transformers/models/test_causal_lm_models.py

Lines changed: 88 additions & 0 deletions

@@ -282,6 +282,72 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))


+def check_kv_repeat_causal_lm_pytorch_vs_ai100(
+    model_name: str,
+    prompt_len: int = Constants.PROMPT_LEN,
+    ctx_len: int = Constants.CTX_LEN,
+    n_layer: int = 1,
+    n_kv_head_repeat: int = 1,
+    config: Optional[AutoConfig] = None,
+    pytorch_hf_tokens: Optional[list] = None,
+):
+    """
+    Validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+        :prompt_len (int): Prompt length for the model to compile.
+        :ctx_len (int): Maximum context length to compile the model.
+        :n_layer (int): Number of layers for the Model.
+        :n_kv_head_repeat (int): Number of times to repeat KV heads.
+    """
+    replace_transformers_quantizers()
+    if config is None:
+        n_layer = get_custom_n_layers(model_name)
+        model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer)
+    else:
+        model_hf, _ = load_causal_lm_model(model_name, config=config)
+
+    tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
+    config = model_hf.config
+    batch_size = len(Constants.INPUT_STR)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        Constants.INPUT_STR,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+    )
+    if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
+
+    # TODO: Add support for custom repeat_kv in models to handle uneven replications.
+    # Generate n_kv_head_repeat from config so that a divisibility error doesn't occur.
+    n_kv_head_repeat = config.num_attention_heads // config.num_key_value_heads
+    qeff_model = QEFFAutoModelForCausalLM(
+        copy.deepcopy(model_hf),
+        pretrained_model_name_or_path=model_name,
+        n_kv_head_repeat=n_kv_head_repeat,
+    )
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+    qpc_path = qeff_model.compile(
+        prefill_seq_len=prompt_len,
+        ctx_len=ctx_len,
+        num_cores=14,
+        mxfp6=False,
+        aic_enable_depth_first=False,
+    )
+    exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
+    gen_len = len(pytorch_hf_tokens)
+    cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len]
+    assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), (
+        "Tokens don't match for Pytorch HF output and Cloud AI 100 output."
+    )
+    assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
+
 # FIXME: there should be a CB test here
 @pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x)
 def test_causal_lm_export_with_deprecated_api(model_name):
@@ -360,6 +426,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     )


+@pytest.mark.nightly
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models_causal)
+def test_check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    n_layer = get_custom_n_layers(model_name)
+
+    # Using fixed reference tokens for external models for specific test cases.
+    # These tokens are hardcoded, therefore will not match if the model config changes.
+    pytorch_hf_tokens = None
+    if model_name in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"]
+
+    check_kv_repeat_causal_lm_pytorch_vs_ai100(
+        model_name=model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.regular
 @pytest.mark.qnn

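The test derives n_kv_head_repeat from the config rather than using a fixed value, so the new KV-head count always divides num_attention_heads. A small worked example with illustrative numbers, not taken from any specific model in the test suite:

```python
# Llama-style GQA layout used only for illustration: 32 query heads, 8 KV heads.
num_attention_heads, num_key_value_heads = 32, 8

# The test picks the multiplier that makes the layout MHA-equivalent ...
n_kv_head_repeat = num_attention_heads // num_key_value_heads  # 4
new_kv_heads = n_kv_head_repeat * num_key_value_heads          # 32
assert num_attention_heads % new_kv_heads == 0                 # num_key_value_groups stays integral

# ... whereas an arbitrary factor (e.g. 3, giving 24 KV heads) could leave
# num_attention_heads % new_kv_heads != 0 and break the grouping.
```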