17 changes: 17 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -43,6 +43,7 @@
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
ReplicateKVHeadTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
@@ -888,6 +889,11 @@ def __init__(

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
# Since both modules use the entire config for hash creation, we're updating the params for consistency.
if replicate_kv_transformed:
self.lang_model.hash_params["config"] = model.config.to_diff_dict()
self.vision_model.hash_params["config"] = model.config.to_diff_dict()
self.continuous_batching = continuous_batching
self.input_shapes, self.output_names = None, None

@@ -1570,6 +1576,9 @@ def __init__(
self.model.config.text_config.use_cache = True
else:
self.model.config.use_cache = True
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()
self.hash_params["qeff_auto_class"] = self.__class__.__name__

@classmethod
@@ -2182,8 +2191,10 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1)

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
return cls(
model,
kv_offload=kv_offload,
@@ -2288,6 +2299,9 @@ def __init__(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
self.model.qaic_config = qaic_config
@@ -2389,7 +2403,10 @@ def from_pretrained(
kv_offload = kwargs.pop("kv_offload", None)

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
# InternVL raises an error if n_kv_head_repeat is passed to the HF loader, so pop it here and restore it afterwards.
n_kv_head_repeat = kwargs.pop("n_kv_head_repeat", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
kwargs.update({"n_kv_head_repeat": n_kv_head_repeat})
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

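Taken together, these hunks thread a new `n_kv_head_repeat` kwarg from the user-facing wrappers into `ReplicateKVHeadTransform`, popping it before the underlying Hugging Face loader (per the InternVL note above) and restoring it for the wrapper constructor. A minimal usage sketch, mirroring how the new test below instantiates the causal-LM wrapper directly; the checkpoint name and repeat factor are placeholders, not values from this PR:

```python
import copy

from transformers import AutoModelForCausalLM

from QEfficient import QEFFAutoModelForCausalLM

# Placeholder GQA checkpoint; QEfficient's from_pretrained forces eager attention,
# so do the same when loading the HF model manually.
model_hf = AutoModelForCausalLM.from_pretrained("my-org/my-gqa-model", attn_implementation="eager")

# n_kv_head_repeat is consumed by ReplicateKVHeadTransform inside __init__, which
# bumps num_key_value_heads and duplicates the k_proj/v_proj weights accordingly.
qeff_model = QEFFAutoModelForCausalLM(
    copy.deepcopy(model_hf),
    pretrained_model_name_or_path="my-org/my-gqa-model",
    n_kv_head_repeat=2,
)
```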
149 changes: 149 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -9,6 +9,7 @@
from types import MethodType
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
@@ -424,8 +425,12 @@
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
from QEfficient.utils.logging_utils import logger

SPD_TARGET = "target"

@@ -630,6 +635,150 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
return model, transformed


class ReplicateKVHeadTransform:
"""
Replicates KV heads in attention modules so the model exposes the desired number of KV heads.
This transform is used when the source checkpoint has fewer KV heads than the target configuration requires.
"""

_module_mapping = {
QEffCodeGenForCausalLM,
QEffFalconForCausalLM,
QEffGPT2LMHeadModel,
QEffGPTJForCausalLM,
QEffLlamaForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlavaForConditionalGeneration,
QEffLlavaNextForConditionalGeneration,
QEffMllamaForConditionalGeneration,
QEffGemmaForCausalLM,
QEffQwen3MoeForCausalLM,
QEffGemma2ForCausalLM,
QEffGemma3ForConditionalGeneration,
QEffPhi3ForCausalLM,
QEffPhiForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen_2_5_vl_ForConditionalGeneration,
QEffStarcoder2ForCausalLM,
QEffGPTBigCodeForCausalLM,
QEffOlmo2ForCausalLM,
}
_module_string_mapping = {
"InternVLChatModel",
}

@staticmethod
def _duplicate_weights_for_linear_layer(
layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
):
new_kv_heads = repeat * orig_kv_heads
if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
if head_dim % 8 != 0:
raise ValueError(
f"the value head_dim={head_dim} is not divisible by 8 which is \
according to the assumption that model is 4-bit quantized."
)
if hidden_size % layer.group_size != 0:
raise ValueError(
f"The value of hidden_size={hidden_size} is not divisible by \
K_proj.group_size={layer.group_size}"
)

# Duplication of quantized weights
layer.qweight.data = torch.repeat_interleave(
layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
).view(hidden_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantized zero points
layer.qzeros.data = torch.repeat_interleave(
layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
repeat,
1,
).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantization scales
layer.scales.data = torch.repeat_interleave(
layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
repeat,
1,
).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
layer.out_features = layer.out_features * repeat

elif isinstance(layer, FP8DeQuantLinear):
layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
layer.weight_scale.data = torch.repeat_interleave(
layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim, -1)

else:
layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
if layer.bias is not None:
layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)

@staticmethod
def _get_text_model(model):
"""
Determine and return the appropriate text_model from a given model object.
"""
# Check for VLMs
if hasattr(model, "language_model"):
if hasattr(model.language_model, "model"):
return model.language_model.model
else:
return model.language_model
# Check for CausalLMs
if hasattr(model, "model"):
return model.model

raise AttributeError("No suitable text model found in the provided model.")

@classmethod
def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
"""
Replicates KV heads in attention modules based on the provided multiplier.

Args:
model: The model to apply the transform to.
kwargs: Additional arguments for the transformation. Includes:
- n_kv_head_repeat: The number of times to repeat the KV heads.
"""
n_repeat = kwargs.pop("n_kv_head_repeat", 1)
transformed = False
if n_repeat > 1:
if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping):
text_model = cls._get_text_model(model)

orig_kv_heads = text_model.config.num_key_value_heads
new_kv_heads = n_repeat * orig_kv_heads
text_model.config.orig_kv_heads = orig_kv_heads
text_model.config.num_key_value_heads = new_kv_heads

num_attention_heads = text_model.config.num_attention_heads
hidden_size = text_model.config.hidden_size

logger.warning(f"Original KV heads: {orig_kv_heads}")
logger.warning(f"Modified KV heads: {new_kv_heads}")
transformed = True
for block in text_model.layers:
attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
attn.num_key_value_heads = new_kv_heads
attn.num_key_value_groups = num_attention_heads // new_kv_heads

cls._duplicate_weights_for_linear_layer(
attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
)
cls._duplicate_weights_for_linear_layer(
attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
)
else:
raise NotImplementedError(
f"Model class {model.__class__.__name__} is not supported for KV head replication."
)
return model, transformed


class SpDTransform:
"""
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
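The weight surgery above can be sanity-checked in isolation. The following standalone sketch (made-up sizes, unquantized `nn.Linear` path only) replays the same `torch.repeat_interleave` pattern as `_duplicate_weights_for_linear_layer` and confirms that each original KV head's projection is simply repeated, so replicas produce identical key states:

```python
import torch
from torch import nn

# Made-up sizes: 2 original KV heads, head_dim=4, hidden_size=8, repeat factor 3.
orig_kv_heads, head_dim, hidden_size, repeat = 2, 4, 8, 3
new_kv_heads = orig_kv_heads * repeat

k_proj = nn.Linear(hidden_size, orig_kv_heads * head_dim, bias=False)

# Same reshape -> repeat_interleave -> flatten pattern as the unquantized branch:
# view as (orig_kv_heads, head_dim, hidden_size) and repeat along dim 0.
k_proj.weight.data = torch.repeat_interleave(
    k_proj.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)

# Heads 0-2 are copies of original head 0, heads 3-5 of original head 1, so the
# attention output is unchanged while num_key_value_heads grows from 2 to 6.
x = torch.randn(1, hidden_size)
k_states = (x @ k_proj.weight.data.T).view(new_kv_heads, head_dim)
assert torch.allclose(k_states[0], k_states[1]) and torch.allclose(k_states[1], k_states[2])
```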
88 changes: 88 additions & 0 deletions tests/transformers/models/test_causal_lm_models.py
@@ -282,6 +282,72 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))


def check_kv_repeat_causal_lm_pytorch_vs_ai100(
model_name: str,
prompt_len: int = Constants.PROMPT_LEN,
ctx_len: int = Constants.CTX_LEN,
n_layer: int = 1,
n_kv_head_repeat: int = 1,
config: Optional[AutoConfig] = None,
pytorch_hf_tokens: Optional[list] = None,
):
"""
Validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads.
``Mandatory`` Args:
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
:prompt_len (int): Prompt length for the model to compile.
:ctx_len (int): Maximum context length to compile the model.
:n_layer (int): Number of layers for the Model.
:n_kv_head_repeat (int): Number of times to repeat KV heads.
"""
replace_transformers_quantizers()
if config is None:
n_layer = get_custom_n_layers(model_name)
model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer)
else:
model_hf, _ = load_causal_lm_model(model_name, config=config)

tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
config = model_hf.config
batch_size = len(Constants.INPUT_STR)
api_runner = ApiRunner(
batch_size,
tokenizer,
config,
Constants.INPUT_STR,
Constants.PROMPT_LEN,
Constants.CTX_LEN,
)
if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)

# TODO: Add support for custom repeat_kv in models to handle uneven replication.
# Derive n_kv_head_repeat from the config so that a divisibility error doesn't occur.
n_kv_head_repeat = config.num_attention_heads // config.num_key_value_heads
qeff_model = QEFFAutoModelForCausalLM(
copy.deepcopy(model_hf),
pretrained_model_name_or_path=model_name,
n_kv_head_repeat=n_kv_head_repeat,
)

if not get_available_device_id():
pytest.skip("No available devices to run model on Cloud AI 100")
qpc_path = qeff_model.compile(
prefill_seq_len=prompt_len,
ctx_len=ctx_len,
num_cores=14,
mxfp6=False,
aic_enable_depth_first=False,
)
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
gen_len = len(pytorch_hf_tokens)
cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len]
assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), (
"Tokens don't match for Pytorch HF output and Cloud AI 100 output."
)
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))


# FIXME: there should be a CB test here
@pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x)
def test_causal_lm_export_with_deprecated_api(model_name):
@@ -360,6 +426,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
)


@pytest.mark.nightly
@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models_causal)
def test_check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name):
"""
Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads.
``Mandatory`` Args:
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
"""
n_layer = get_custom_n_layers(model_name)

# Using fixed reference tokens for external models for specific test cases.
# These tokens are hardcoded, therefore will not match if the model config changes.
pytorch_hf_tokens = None
if model_name in ModelConfig.EXTERNAL_MODELS:
pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"]

check_kv_repeat_causal_lm_pytorch_vs_ai100(
model_name=model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens
)


@pytest.mark.on_qaic
@pytest.mark.regular
@pytest.mark.qnn
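The new helper derives the repeat factor as `num_attention_heads // num_key_value_heads`, i.e. it replicates KV heads all the way up to full multi-head attention so the divisibility problem mentioned in its TODO cannot occur. A small illustration with hypothetical GQA head counts (not taken from any tested model):

```python
from transformers import AutoConfig

# Hypothetical GQA layout: 32 query heads sharing 8 KV heads.
config = AutoConfig.for_model("llama", num_attention_heads=32, num_key_value_heads=8)

# Same formula as check_kv_repeat_causal_lm_pytorch_vs_ai100: repeat up to MHA.
n_kv_head_repeat = config.num_attention_heads // config.num_key_value_heads
assert n_kv_head_repeat == 4

# After ReplicateKVHeadTransform, the model would advertise 8 * 4 = 32 KV heads,
# so num_key_value_groups becomes 32 // 32 == 1 in every attention block.
assert config.num_attention_heads % (n_kv_head_repeat * config.num_key_value_heads) == 0
```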