Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional

import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
Expand Down Expand Up @@ -45,9 +43,10 @@ class QEFFBaseModel(ABC):
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module) -> None:
def __init__(self, model: torch.nn.Module, onnx_slim_transfom: bool = False) -> None:
super().__init__()
self.model = model
self.onnx_slim_transform = onnx_slim_transfom
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
Expand Down Expand Up @@ -119,6 +118,7 @@ def _export(
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
onnx_slim_transform: bool = False,
export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
Expand Down Expand Up @@ -146,7 +146,6 @@ def _export(
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)

# Create input_names from example_inputs

input_names = []
for param in inspect.signature(self.model.forward).parameters:
if param in example_inputs:
Expand Down Expand Up @@ -183,11 +182,14 @@ def _export(
**export_kwargs,
)
logger.info("Pytorch export successful")

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
"temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
"enable_onnx_slim_transform": onnx_slim_transform,
"onnx_base_dir": str(tmp_onnx_dir),


}
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)
Expand Down Expand Up @@ -248,8 +250,7 @@ def _compile(
For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
self.export()

self.export()
onnx_path = Path(onnx_path or self.onnx_path)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
Expand Down Expand Up @@ -368,5 +369,4 @@ def _compile(
)

self.qpc_path = qpc_path

return qpc_path
34 changes: 33 additions & 1 deletion QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Optional, Tuple

import numpy as np
import onnx
import onnxslim
from onnx import ModelProto, external_data_helper, numpy_helper


Expand Down Expand Up @@ -36,7 +38,6 @@ class FP16ClipTransform(OnnxTransform):
"""
Clips the tensor values to be in FP16 range, but preserves -inf values.
"""

@classmethod
def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
"""
Expand Down Expand Up @@ -99,3 +100,34 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform(OnnxTransform):
    """
    Applies onnx-slim graph simplification to the given ONNX model.

    The transform is opt-in: it only runs when the caller passes
    ``enable_onnx_slim_transform=True`` in kwargs. When disabled, the model
    is returned untouched and no other kwargs are required.
    """

    @classmethod
    def apply(
        cls,
        model: ModelProto,
        *,
        onnx_base_dir: Optional[str] = None,
        **kwargs,
    ) -> Tuple[ModelProto, bool]:
        """
        Simplify the model with onnxslim and persist the result.

        :param model: ONNX model to (optionally) simplify.
        :param onnx_base_dir: Directory containing external data files.
            Unused here; accepted for interface parity with the other
            ``OnnxTransform`` subclasses.
        :enable_onnx_slim_transform (kwarg): If True, run the onnx-slim pass.
        :temp_onnx_path (kwarg): Destination path for the slimmed model.
            Required only when the transform is enabled.
        :returns: Tuple of (resulting model, whether a transformation ran).
        :raises RuntimeError: If the transform is enabled but
            ``temp_onnx_path`` was not provided.
        """
        # Guard clause: when the transform is disabled this must be a pure
        # no-op — in particular it must NOT demand temp_onnx_path (the
        # original raised even on the disabled path).
        if not kwargs.get("enable_onnx_slim_transform", False):
            return model, False

        temp_onnx_path = kwargs.get("temp_onnx_path", None)
        if not temp_onnx_path:
            err_str = "temp_onnx_path is required for onnx-slim transform."
            raise RuntimeError(err_str)

        slimmed_model = onnxslim.slim(model)
        # Persist the slimmed graph so the downstream export pipeline picks
        # up the simplified file from the temporary ONNX directory.
        onnx.save(slimmed_model, temp_onnx_path)
        return slimmed_model, True
Loading