
Optimized ONNX Transform via Class Merging and Thread Pooling #546

Open · wants to merge 1 commit into main
129 changes: 60 additions & 69 deletions QEfficient/base/onnx_transforms.py
@@ -5,97 +5,88 @@
 #
 # ----------------------------------------------------------------------------
 
+import os
+from collections import namedtuple
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple
 
 import numpy as np
-from onnx import ModelProto, external_data_helper, numpy_helper
+from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper
 
 
 class OnnxTransform:
     """
     OnnxTransform is the base class for graph modifications on exported onnx.
     """
 
     def __init__(self):
-        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
+        raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.")
 
     @classmethod
     def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
         """
         Override this class to apply a transformation.
         :param model: The model's ONNX graph to transform
         :param kwargs: Parameters needed for specific transforms. All transforms should take **kwargs to ignore unneeded kwargs.
 
         :returns: ONNX graph after applying the transform
         :returns: Boolean indicating whether transform was applied
         """
         raise NotImplementedError("Use subclasses for ONNX transform")
 
 
-class FP16ClipTransform(OnnxTransform):
-    """
-    Clips the tensor values to be in FP16 range, but preserves -inf values.
-    """
-
-    @classmethod
-    def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
-        """
-        :param onnx_base_dir: Base directory to load tensors
-        """
-        finfo = np.finfo(np.float16)
-        fp16_max = finfo.max
-        fp16_min = finfo.min
-        transformed = False
-
-        for tensor in external_data_helper._get_all_tensors(model):
-            nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
-            if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
-                neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
-                clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
-
-                # Restore -inf values
-                if neg_inf_mask.any():
-                    clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
-
-                new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
-                tensor.CopyFrom(new_tensor)
-                transformed = True
-
-        return model, transformed
-
-
-class SplitTensorsTransform(OnnxTransform):
-    """
-    Split external tensors file
-    """
-
+class ClipAndSplitTransform(OnnxTransform):
     @classmethod
     def apply(
         cls,
         model: ModelProto,
         *,
         model_name: str,
         onnx_base_dir: Optional[str] = None,
-        file_chunk_size: int = 10 * 2**30,  # 10 GiB
+        apply_clip: bool = True,
+        apply_split: bool = True,
+        file_chunk_size: int = 10 * 2**30,
         size_threshold: int = 1024,
         **kwargs,
     ) -> Tuple[ModelProto, bool]:
         """
         :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data
         :param onnx_base_dir: Base directory to load tensors (if not already loaded).
         :param file_chunk_size: Chunk size to split external files into.
         :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
         """
-        file_num = 0
-        current_file_size = 0
-        transformed = False
         external_data_helper.load_external_data_for_model(model, onnx_base_dir)
-        for tensor in external_data_helper._get_all_tensors(model):
-            if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
-                transformed = True
-                current_file_size += tsize
-                if current_file_size > file_chunk_size:
-                    file_num += 1
-                    current_file_size = tsize
-                external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
-        return model, transformed
+        tensors = external_data_helper._get_all_tensors(model)
+
+        TensorInfo = namedtuple("TensorInfo", ["tensor", "tsize"])
+        tensor_infos = [
+            TensorInfo(tensor, len(tensor.raw_data) if tensor.HasField("raw_data") else 0) for tensor in tensors
+        ]
+
+        fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
+        file_num_tracker = {"num": 0, "size": 0}
+
+        def process_tensor(info: TensorInfo) -> bool:
+            tensor, tsize = info
+            transformed = False
+
+            if apply_clip and cls._clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max):
+                transformed = True
+
+            if apply_split and tsize > size_threshold:
+                if file_num_tracker["size"] + tsize > file_chunk_size:
+                    file_num_tracker["num"] += 1
+                    file_num_tracker["size"] = tsize
+                else:
+                    file_num_tracker["size"] += tsize
+
+                cls._split_tensor(tensor, model_name, file_num_tracker["num"])
+                transformed = True
+
+            return transformed
+
+        with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
+            transformed_flags = list(executor.map(process_tensor, tensor_infos))
+        return model, any(transformed_flags)
+
+    @staticmethod
+    def _clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max) -> bool:
+        if tensor.data_type != TensorProto.FLOAT:
+            return False
+
+        nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
+        if np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min):
+            neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
+            clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+            if neg_inf_mask.any():
+                clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+            new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
+            tensor.CopyFrom(new_tensor)
+            return True
+        return False
+
+    @staticmethod
+    def _split_tensor(tensor, model_name: str, file_num: int) -> None:
+        external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
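
For reviewers, a minimal usage sketch of the merged transform, assuming a locally exported model; the paths and model name below are hypothetical, and `apply_clip`/`apply_split` default to `True` as in the diff:

```python
import onnx

from QEfficient.base.onnx_transforms import ClipAndSplitTransform

# Hypothetical export location, for illustration only. The transform loads
# external data itself, so we skip loading it here.
model = onnx.load("exports/my_model.onnx", load_external_data=False)

# One pass over all tensors: clip FP32 initializers into FP16 range and
# re-shard external data into ~10 GiB files named my_model_{n}.onnx.data.
model, transformed = ClipAndSplitTransform.apply(
    model,
    model_name="my_model",
    onnx_base_dir="exports",
)

if transformed:
    onnx.save(model, "exports/my_model_transformed.onnx")
```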
4 changes: 2 additions & 2 deletions QEfficient/exporter/export_utils.py
@@ -17,7 +17,7 @@
 import torch
 from onnx import external_data_helper
 
-from QEfficient.base.onnx_transforms import FP16ClipTransform
+from QEfficient.base.onnx_transforms import ClipAndSplitTransform
 
 
 def export_onnx(
@@ -218,7 +218,7 @@ def fix_onnx_fp16(
     :str: Updated base name of exported ONNX model.
     """
     model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
-    model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)
+    model, fp16_fix = ClipAndSplitTransform.apply(model, onnx_base_dir=gen_models_path, apply_split=False)
 
     if fp16_fix:
         # Save FP16 model
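
The clipping rule that `fix_onnx_fp16` now delegates to `ClipAndSplitTransform` can be checked in isolation; a small numpy sketch of the same logic, with example values only:

```python
import numpy as np

fp16 = np.finfo(np.float16)
weights = np.array([1e5, -1e5, np.float32("-inf"), 0.5], dtype=np.float32)

# Clip to FP16 range but keep -inf, mirroring _clip_tensor in the diff above.
neg_inf_mask = np.isinf(weights) & (weights < 0)
clipped = np.clip(weights, fp16.min, fp16.max)
clipped = np.where(neg_inf_mask, np.float32("-inf"), clipped)

print(clipped)  # values clipped to +/-65504.0; -inf preserved
```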
4 changes: 2 additions & 2 deletions QEfficient/peft/auto.py
@@ -18,7 +18,7 @@
 from transformers.generation.streamers import BaseStreamer
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import ClipAndSplitTransform, OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
@@ -58,7 +58,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
     """
 
     _pytorch_transforms: List[PytorchTransform] = [CustomOpsTransform, KVCacheTransform, PeftModelInputsTransform]
-    _onnx_transforms: List[OnnxTransform] = [FP16ClipTransform, AdapterWeightsToInputsTransform, SplitTensorsTransform]
+    _onnx_transforms: List[OnnxTransform] = [ClipAndSplitTransform, AdapterWeightsToInputsTransform]
     _hf_auto_class = AutoPeftModelForCausalLM
 
     def __init__(self, model: nn.Module):
14 changes: 7 additions & 7 deletions QEfficient/transformers/models/modeling_auto.py
@@ -25,7 +25,7 @@
 
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import ClipAndSplitTransform
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -159,7 +159,7 @@ class QEFFAutoModel(QEFFTransformersBase):
 
     _hf_auto_class = AutoModel
     _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(self, model: nn.Module, pooling=None, **kwargs):
         super().__init__(model, **kwargs)
@@ -426,7 +426,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
         KVCacheTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(self, model: nn.modules, **kwargs):
         super().__init__(model, **kwargs)
@@ -483,7 +483,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
         VlmKVOffloadTransform,
         SplitGateUpWeightsTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(self, model, **kwargs):
         super().__init__(model, **kwargs)
@@ -898,7 +898,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, MultimodalUtilityMixin):
         VlmNoKVOffloadTransform,
         SplitGateUpWeightsTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(
         self,
@@ -1330,7 +1330,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitGateUpWeightsTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(
        self,
@@ -1896,7 +1896,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin):
 
     _hf_auto_class = AutoModelForSpeechSeq2Seq
     _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [ClipAndSplitTransform]
 
     def __init__(self, model: nn.Module, **kwargs):
         model_class_name = model.__class__.__name__
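
Each `_onnx_transforms` list above is presumably consumed by a shared export path in `QEFFBaseModel`; the sketch below shows that assumed pattern, not the actual implementation, to illustrate why merging the two transforms saves a full walk over the model's tensors:

```python
from typing import Any, List

from onnx import ModelProto


def run_onnx_transforms(model: ModelProto, transforms: List[Any], **kwargs) -> ModelProto:
    # Sketch: each transform returns (model, applied), so every entry in
    # _onnx_transforms costs one pass over the graph's tensors. A single
    # ClipAndSplitTransform entry replaces the two passes that the separate
    # FP16ClipTransform and SplitTensorsTransform entries used to make.
    for transform in transforms:
        model, _ = transform.apply(model, **kwargs)
    return model
```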