class ProxyModuleMappingTransform(PytorchTransform):
    """
    Swaps the class of matching PyTorch modules in place, driven by the
    ``_module_mapping`` class variable.

    Every module yielded by ``model.named_modules()`` is tested against the base
    types in ``_module_mapping``. On a match the module's ``__class__`` is
    reassigned to the mapped replacement type, so the existing instance (with its
    parameters and state) is kept and only its class changes.

    ``nn.Linear`` entries are special-cased: a linear module is replaced only when
    the last component of its dotted name is ``lm_head``.
    """

    # Base module type -> replacement type. Only annotated here, never assigned;
    # subclasses supply the actual mapping.
    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]]

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """
        Apply the mapping to ``model`` in place.

        Args:
            model: Root module whose sub-modules (and the root itself) are inspected.

        Returns:
            The same ``model`` object and a flag that is ``True`` when at least one
            module had its class replaced.
        """
        transformed = False
        for name, module in model.named_modules():
            for base_type, repl_type in cls._module_mapping.items():
                if not isinstance(module, base_type):
                    continue

                # Linear layers are proxied only when they are the LM head. The root
                # module has an empty name, which never equals "lm_head".
                if base_type is nn.Linear and name.split(".")[-1] != "lm_head":
                    continue

                try:
                    # Perform in-place class replacement (preserve parameters/state)
                    module.__class__ = repl_type
                except Exception as e:
                    # Best effort: report the failure and try the remaining entries.
                    logger.warning(f"Failed to replace module {name} ({base_type}) -> {repl_type}: {e}")
                    continue

                transformed = True
                # A module is swapped at most once. Without this, the freshly
                # replaced instance would be matched against the remaining entries
                # and could be swapped again if its new type subclasses another
                # mapped base type.
                break

        return model, transformed
import torch
from torch import nn


class QeffProxyEmbedding(nn.Module):
    """
    Weight-free stand-in for an embedding layer.

    Instead of looking tokens up in an embedding table, ``forward`` casts its
    input to float and repeats every value ``embedding_dim`` times along a new
    trailing dimension.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """
        Args:
            num_embeddings: Stored on the instance; not used by ``forward``.
            embedding_dim: Size of the trailing dimension produced by ``forward``.
        """
        # nn.Module must be initialised first; otherwise the instance lacks the
        # hook/parameter registries and calling it raises AttributeError.
        super().__init__()
        self.embed_tokens = None
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Expand a 2-D tensor to ``(d0, d1, embedding_dim)``.

        Args:
            hidden_states: 2-D tensor (presumably ``(batch, seq_len)`` token ids -
                confirm against callers).

        Returns:
            Float tensor in which every input value is broadcast along the new last
            dimension. ``expand`` returns a view, so no data is copied.
        """
        inputs_embeds = torch.unsqueeze(hidden_states.float(), 2).expand(-1, -1, self.embedding_dim)
        return inputs_embeds


class QeffProxyLinear(nn.Module):
    """
    Weight-free stand-in for a linear layer; ``forward`` returns its input unchanged.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        """
        Args:
            in_features: Stored on the instance; not used by ``forward``.
            out_features: Stored on the instance; not used by ``forward``.
            bias: Accepted for signature compatibility; ignored.
        """
        # See QeffProxyEmbedding.__init__: the module is unusable without this call.
        super().__init__()
        self.lm_head = None
        # Stored for introspection only; forward() does not read them.
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` as is (identity)."""
        return hidden_states
import torch.nn as nn

from QEfficient.base.pytorch_transforms import ProxyModuleMappingTransform
from QEfficient.proxy import QeffProxyEmbedding, QeffProxyLinear


class QeffProxyModuleTransform(ProxyModuleMappingTransform):
    """
    Module-mapping transform that pairs ``nn.Embedding`` with ``QeffProxyEmbedding``
    and ``nn.Linear`` with ``QeffProxyLinear``.

    The replacement logic itself lives in ``ProxyModuleMappingTransform``; this
    class only supplies the mapping.
    """

    # Base torch module type -> QEfficient proxy type.
    _module_mapping = {nn.Embedding: QeffProxyEmbedding, nn.Linear: QeffProxyLinear}
""" + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -2246,6 +2255,7 @@ def from_pretrained( NotImplementedError If `continuous_batching` is provided as True. """ + enable_proxy = kwargs.pop("enable_proxy", False) # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. if continuous_batching and not kv_offload: NotImplementedError("Continuous batching is not supported for kv_offload = False") @@ -2259,6 +2269,9 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + if enable_proxy and kv_offload: + logger.info("Proxy Model Enabled for QEfficient Model") + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) return cls( model, kv_offload=kv_offload, @@ -2348,6 +2361,10 @@ def __init__( if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("LMHeadModel")): raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}") + if kwargs.pop("enable_proxy", False): + self._pytorch_transforms.append(QeffProxyModuleTransform) + logger.info("Proxy Model Enabled for QEfficient Model") + # TODO: remove from version 1.20 if kwargs.pop("full_batch_size", None): continuous_batching = True @@ -2452,6 +2469,7 @@ def from_pretrained( QEFFAutoModelForCausalLM An instance initialized with the pretrained weights. 
""" + enable_proxy = kwargs.pop("enable_proxy", False) if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -2472,7 +2490,7 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class - + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( model, @@ -3062,6 +3080,7 @@ def generate( **kwargs : Additional keyword arguments. Currently supports: - `generation_len (int, optional)`: The maximum number of tokens to generate. + - `write_io (bool, optional)`: Whether to save the io files. Returns ------- @@ -3079,6 +3098,7 @@ def generate( if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) + write_io = kwargs.pop("write_io", False) return QEfficient.cloud_ai_100_exec_kv( tokenizer=tokenizer, qpc_path=self.qpc_path, @@ -3090,6 +3110,7 @@ def generate( automation=kwargs.pop("automation", False), iteration=kwargs.pop("iteration", 1), is_tlm=self.is_tlm, + write_io_dir=os.path.join(os.path.dirname(self.onnx_path), "io_dir") if write_io else None, **kwargs, ) else: diff --git a/examples/proxy_model_export.py b/examples/proxy_model_export.py new file mode 100644 index 000000000..ea5607450 --- /dev/null +++ b/examples/proxy_model_export.py @@ -0,0 +1,17 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

# Model identifier shared by the model and tokenizer loaders below.
MODEL_NAME = "gpt2"

# enable_proxy=True selects the proxy model export, i.e. the model is exported
# with its embedding and LM head layers disabled.
model = QEFFAutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    num_hidden_layers=2,
    enable_proxy=True,
)
model.compile(num_cores=16)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# write_io=True additionally saves the io files of the run.
model.generate(prompts=["Hi there!!"], tokenizer=tokenizer, write_io=True)