Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxScatter
from QEfficient.customop.rms_norm import CustomRMSNorm
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.cache_utils import InvalidIndexProvider
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
Expand Down Expand Up @@ -184,6 +186,8 @@ def _export(
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
use_onnx_subfunctions: bool = False,
use_dynamo: bool = False,
dynamic_shapes: Optional[Dict[str, Dict[int, any]]] = None,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
Expand Down Expand Up @@ -250,6 +254,7 @@ def _export(
try:
# Initialize the registry with your custom ops
export_kwargs = {} if export_kwargs is None else export_kwargs
export_kwargs["dynamo"] = use_dynamo
if use_onnx_subfunctions:
warnings.warn(
"The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
Expand All @@ -261,14 +266,35 @@ def _export(
self._onnx_transforms.append(RenameFunctionOutputsTransform)
self._onnx_transforms.append(CustomOpTransform)

# fx_graph = torch.export.export(
# self.model,
# args=(),
# kwargs=example_inputs, #IMPORTANT CHANGE: passing all inputs in kwargs rather than as a rigid tuple in args
# dynamic_shapes=dynamic_shapes,
# **export_kwargs,
# strict=True,
# )
# result = fx_graph.module()(**example_inputs)
if use_dynamo:
dynamic_axes = None
export_kwargs["report"] = True
# export_kwargs["verify"] =True
export_kwargs["custom_translation_table"] = {
torch.ops.qefficient.rms_norm.default: CustomRMSNorm,
torch.ops.qefficient.ctx_gather.default: CtxGather,
torch.ops.qefficient.ctx_scatter.default: CtxScatter,
}

torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
args=(),
kwargs=example_inputs,
f=str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
dynamic_shapes=dynamic_shapes,
opset_version=18,
**export_kwargs,
)
logger.info("PyTorch export successful")
Expand Down Expand Up @@ -323,6 +349,7 @@ def _compile(
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
use_dynamo: bool = False,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -350,7 +377,7 @@ def _compile(
"""

if onnx_path is None and self.onnx_path is None:
self.export(use_onnx_subfunctions=use_onnx_subfunctions)
self.export(use_onnx_subfunctions=use_onnx_subfunctions, use_dynamo=use_dynamo)

onnx_path = Path(onnx_path or self.onnx_path)
compile_dir = Path(compile_dir or onnx_path.parent)
Expand Down Expand Up @@ -452,6 +479,7 @@ def _compile(
command.append(f"-custom-IO-list-file={custom_io_yaml}")

command.append(f"-aic-binary-dir={qpc_path}")
print(command)
logger.info(f"Running compiler: {' '.join(command)}")
try:
subprocess.run(command, capture_output=True, check=True)
Expand Down
9 changes: 6 additions & 3 deletions QEfficient/customop/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,19 @@
from torch import nn

from QEfficient.utils import constants
from QEfficient.utils.custom_op_utils import select_interface

ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET))


@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1))
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float):
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float) -> onnxscript.FLOAT:
weight = ops.Cast(weight, to=1)
variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1)
epsilon = ops.Expand(epsilon, ops.Shape(variance))
hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon))
return weight * hidden_states
output = weight * hidden_states
return output


class CustomRMSNormFunc(torch.autograd.Function):
Expand Down Expand Up @@ -51,7 +53,8 @@ def __init__(self, hidden_size, eps=1e-05):
self.weight = torch.nn.Parameter(torch.ones(hidden_size))

def forward(self, hidden_states):
return CustomRMSNormFunc.apply(
rms_interface = select_interface(CustomRMSNormFunc.apply, torch.ops.qefficient.rms_norm)
return rms_interface(
hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps
)

Expand Down
Loading
Loading