From 7bda89200ce744ec2ddc128dd98c554cad868e18 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 24 Nov 2025 05:14:27 +0000 Subject: [PATCH 1/3] adding files with modifications for dynamo enabled onnx exporter Signed-off-by: Sharvari Medhe --- QEfficient/base/modeling_qeff.py | 26 +- QEfficient/customop/rms_norm.py | 9 +- QEfficient/transformers/cache_utils.py | 175 ++++++++--- .../models/gemma3/modeling_gemma3.py | 173 ++++++++++- .../models/llava/modeling_llava.py | 146 +++++++++- .../models/mistral3/modeling_mistral3.py | 144 ++++++++- .../models/mllama/modeling_mllama.py | 206 ++++++++++++- .../transformers/models/modeling_auto.py | 275 +++++++++++++++++- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 175 ++++++++++- QEfficient/utils/constants.py | 4 +- QEfficient/utils/custom_op_utils.py | 232 +++++++++++++++ 11 files changed, 1500 insertions(+), 65 deletions(-) create mode 100644 QEfficient/utils/custom_op_utils.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 72f5c050e..8196e73d1 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -22,6 +22,8 @@ from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile +from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxScatter +from QEfficient.customop.rms_norm import CustomRMSNorm from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.cache_utils import InvalidIndexProvider from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export @@ -184,6 +186,8 @@ def _export( export_dir: Optional[str] = None, offload_pt_weights: bool = True, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, + dynamic_shapes: Optional[Dict[str, Dict[int, any]]] = None, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -250,6 +254,7 @@ def _export( try: # Initialize the registry with your custom ops export_kwargs = {} if export_kwargs is None else export_kwargs + export_kwargs["dynamo"] = use_dynamo if use_onnx_subfunctions: warnings.warn( "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." @@ -261,14 +266,26 @@ def _export( self._onnx_transforms.append(RenameFunctionOutputsTransform) self._onnx_transforms.append(CustomOpTransform) + if use_dynamo: + dynamic_axes = None + export_kwargs["report"] = True + # export_kwargs["verify"] =True + export_kwargs["custom_translation_table"] = { + torch.ops.qefficient.rms_norm.default: CustomRMSNorm, + torch.ops.qefficient.ctx_gather.default: CtxGather, + torch.ops.qefficient.ctx_scatter.default: CtxScatter, + } + torch.onnx.export( self.model, - (example_inputs,), - str(tmp_onnx_path), + args=(), + kwargs=example_inputs, + f=str(tmp_onnx_path), input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=constants.ONNX_EXPORT_OPSET, + dynamic_shapes=dynamic_shapes, + opset_version=18, **export_kwargs, ) logger.info("PyTorch export successful") @@ -323,6 +340,7 @@ def _compile( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -350,7 +368,7 @@ def _compile( """ if onnx_path is None and self.onnx_path is None: - self.export(use_onnx_subfunctions=use_onnx_subfunctions) + self.export(use_onnx_subfunctions=use_onnx_subfunctions, use_dynamo=use_dynamo) onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/customop/rms_norm.py b/QEfficient/customop/rms_norm.py index 913f22a17..802f676f2 100644 --- a/QEfficient/customop/rms_norm.py +++ b/QEfficient/customop/rms_norm.py @@ -10,17 +10,19 @@ from torch import nn from QEfficient.utils import constants +from QEfficient.utils.custom_op_utils import select_interface ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) @onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)) -def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float): +def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float) -> onnxscript.FLOAT: weight = ops.Cast(weight, to=1) variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1) epsilon = ops.Expand(epsilon, ops.Shape(variance)) hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon)) - return weight * hidden_states + output = weight * hidden_states + return output class CustomRMSNormFunc(torch.autograd.Function): @@ -51,7 +53,8 @@ def __init__(self, hidden_size, eps=1e-05): self.weight = torch.nn.Parameter(torch.ones(hidden_size)) def forward(self, hidden_states): - return CustomRMSNormFunc.apply( + rms_interface = select_interface(CustomRMSNormFunc.apply, torch.ops.qefficient.rms_norm) + return rms_interface( hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps ) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 292fe0487..e64939653 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -22,6 +22,7 @@ CtxScatterFuncCB, CtxScatterFuncCB3D, ) +from QEfficient.utils.custom_op_utils import select_interface class InvalidIndexProvider: @@ -43,7 +44,7 @@ def _get_invalid_idx_value(cls): int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise) """ if torch.onnx.is_in_onnx_export(): - if cls.SUBFUNC_ENABLED: + if cls.SUBFUNC_ENABLED or torch._dynamo.is_compiling(): return 0 else: return torch.iinfo(torch.int32).max @@ -78,11 +79,19 @@ def read_only(self, cache_kwargs): ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + ctx_gather_cb_interface = select_interface( + CtxGatherFuncCB.apply, + torch.ops.qefficient.ctx_gather_cb, + ) + k_out = ctx_gather_cb_interface(k_out, batch_index, ctx_indices, ctx_len) + v_out = ctx_gather_cb_interface(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + ctx_gather_interface = select_interface( + CtxGatherFunc.apply, + torch.ops.qefficient.ctx_gather, + ) + k_out = ctx_gather_interface(k_out, ctx_indices, ctx_len) + v_out = ctx_gather_interface(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -112,11 +121,19 @@ def write_only(self, key_states, value_states, cache_kwargs): invalid_scatter_index = torch.iinfo(torch.int32).max scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) - self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + ctx_scatter_cb_interface = select_interface( + CtxScatterFuncCB.apply, + torch.ops.qefficient.ctx_scatter_cb, + ) + self.keys = ctx_scatter_cb_interface(self.keys, batch_index, scatter_position_ids, key_states) + self.values = ctx_scatter_cb_interface(self.values, batch_index, scatter_position_ids, value_states) else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + ctx_scatter_interface = select_interface( + CtxScatterFunc.apply, + torch.ops.qefficient.ctx_scatter, + ) + self.keys = ctx_scatter_interface(self.keys, position_ids, key_states) + self.values = ctx_scatter_interface(self.values, position_ids, value_states) def update( self, @@ -152,12 +169,19 @@ def update( invalid_scatter_index = torch.iinfo(torch.int32).max scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) - - self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + ctx_scatter_cb_interface = select_interface( + CtxScatterFuncCB.apply, + torch.ops.qefficient.ctx_scatter_cb, + ) + self.keys = ctx_scatter_cb_interface(self.keys, batch_index, scatter_position_ids, key_states) + self.values = ctx_scatter_cb_interface(self.values, batch_index, scatter_position_ids, value_states) else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + ctx_scatter_interface = select_interface( + CtxScatterFunc.apply, + torch.ops.qefficient.ctx_scatter, + ) + self.keys = ctx_scatter_interface(self.keys, position_ids, key_states) + self.values = ctx_scatter_interface(self.values, position_ids, value_states) k_out, v_out = self.keys, self.values @@ -170,12 +194,22 @@ def update( invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + ctx_gather_cb_interface = select_interface( + CtxGatherFuncCB.apply, + torch.ops.qefficient.ctx_gather_cb, + ) + k_out = ctx_gather_cb_interface(k_out, batch_index, ctx_indices, ctx_len) + v_out = ctx_gather_cb_interface(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + ctx_gather_interface = select_interface( + CtxGatherFunc.apply, + torch.ops.qefficient.ctx_gather, + ) + k_out = ctx_gather_interface(k_out, ctx_indices, ctx_len) + v_out = ctx_gather_interface(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -215,12 +249,19 @@ def update3D( invalid_scatter_index = torch.iinfo(torch.int32).max scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - self.keys = CtxScatterFuncCB3D.apply(self.keys, batch_index, scatter_position_ids, key_states) - - self.values = CtxScatterFuncCB3D.apply(self.values, batch_index, scatter_position_ids, value_states) + ctx_scatter_cb_3d_interface = select_interface( + CtxScatterFuncCB3D.apply, + torch.ops.qefficient.ctx_scatter_cb_3d, + ) + self.keys = ctx_scatter_cb_3d_interface(self.keys, batch_index, scatter_position_ids, key_states) + self.values = ctx_scatter_cb_3d_interface(self.values, batch_index, scatter_position_ids, value_states) else: - self.keys = CtxScatterFunc3D.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc3D.apply(self.values, position_ids, value_states) + ctx_scatter_3d_interface = select_interface( + CtxScatterFunc3D.apply, + torch.ops.qefficient.ctx_scatter_3d, + ) + self.keys = ctx_scatter_3d_interface(self.keys, position_ids, key_states) + self.values = ctx_scatter_3d_interface(self.values, position_ids, value_states) k_out, v_out = self.keys, self.values @@ -234,12 +275,21 @@ def update3D( else: invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: - k_out = CtxGatherFuncCB3D.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB3D.apply(v_out, batch_index, ctx_indices) + ctx_gather_cb_3d_interface = select_interface( + CtxGatherFuncCB3D.apply, + torch.ops.qefficient.ctx_gather_cb_3d, + ) + k_out = ctx_gather_cb_3d_interface(k_out, batch_index, ctx_indices) + v_out = ctx_gather_cb_3d_interface(v_out, batch_index, ctx_indices) else: - k_out = CtxGatherFunc3D.apply(k_out, ctx_indices) - v_out = CtxGatherFunc3D.apply(v_out, ctx_indices) + ctx_gather_3d_interface = select_interface( + CtxGatherFunc3D.apply, + torch.ops.qefficient.ctx_gather_3d, + ) + k_out = ctx_gather_3d_interface(k_out, ctx_indices) + v_out = ctx_gather_3d_interface(v_out, ctx_indices) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) @@ -429,8 +479,13 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( + + ctx_scatter_interface = select_interface( + CtxScatterFunc.apply, + torch.ops.qefficient.ctx_scatter, + ) + self.key_cache[layer_idx] = ctx_scatter_interface(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = ctx_scatter_interface( self.value_cache[layer_idx], kv_position_ids, value_states ) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -449,8 +504,13 @@ def update( final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) + + ctx_gather_interface = select_interface( + CtxGatherFunc.apply, + torch.ops.qefficient.ctx_gather, + ) + k_out = ctx_gather_interface(k_out, final_indices, ctx_len) + v_out = ctx_gather_interface(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -529,8 +589,13 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( + + ctx_scatter_interface = select_interface( + CtxScatterFunc.apply, + torch.ops.qefficient.ctx_scatter, + ) + self.key_cache[layer_idx] = ctx_scatter_interface(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = ctx_scatter_interface( self.value_cache[layer_idx], kv_position_ids, value_states ) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -554,8 +619,13 @@ def update( final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) + + ctx_gather_interface = select_interface( + CtxGatherFunc.apply, + torch.ops.qefficient.ctx_gather, + ) + k_out = ctx_gather_interface(k_out, final_indices, ctx_len) + v_out = ctx_gather_interface(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -644,15 +714,26 @@ def update( scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids) else: scatter_position_ids = kv_position_ids - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + + ctx_scatter_cb_interface = select_interface( + CtxScatterFuncCB.apply, + torch.ops.qefficient.ctx_scatter_cb, + ) + self.key_cache[layer_idx] = ctx_scatter_cb_interface( self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states ) - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx] = ctx_scatter_cb_interface( self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states ) else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( + ctx_scatter_interface = select_interface( + CtxScatterFunc.apply, + torch.ops.qefficient.ctx_scatter, + ) + self.key_cache[layer_idx] = ctx_scatter_interface( + self.key_cache[layer_idx], kv_position_ids, key_states + ) + self.value_cache[layer_idx] = ctx_scatter_interface( self.value_cache[layer_idx], kv_position_ids, value_states ) @@ -674,11 +755,19 @@ def update( ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + ctx_gather_cb_interface = select_interface( + CtxGatherFuncCB.apply, + torch.ops.qefficient.ctx_gather_cb, + ) + k_out = ctx_gather_cb_interface(k_out, batch_index, ctx_indices, ctx_len) + v_out = ctx_gather_cb_interface(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + ctx_gather_interface = select_interface( + CtxGatherFunc.apply, + torch.ops.qefficient.ctx_gather, + ) + k_out = ctx_gather_interface(k_out, ctx_indices, ctx_len) + v_out = ctx_gather_interface(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 398259d8b..6d9c9f5e1 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -6,10 +6,11 @@ # ----------------------------------------------------------------------------- import copy -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn +from torch.export import Dim from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -33,6 +34,7 @@ from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.custom_op_utils import select_interface class GemmaRMSNormFunc(torch.autograd.Function): @@ -59,7 +61,8 @@ class QEffGemma3CustomRMSNormAIC(nn.Module): """ def forward(self, hidden_states): - return GemmaRMSNormFunc.apply( + rms_interface = select_interface(GemmaRMSNormFunc.apply, torch.ops.qefficient.rms_norm) + return rms_interface( hidden_states, self.weight.float() + 1.0, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps, @@ -816,6 +819,172 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names + def get_onnx_dynamic_shapes( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + ) -> Dict[str, Any]: + """ + - Handles past_key_values as a list of (key, value) pairs per layer + + For kv_offload=False, dynamic_shapes corresponds to the combined forward: + forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths?) + + For kv_offload=True, it returns: + { + "vision": { ... dynamic_shapes for vision ... }, + "lang": { ... dynamic_shapes for language, including past_key_values ... }, + } + """ + + num_layers = self.language_model.config.num_hidden_layers + config = self.language_model.config + + layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 + has_sliding_window = hasattr(config, "sliding_window") + # sliding_window = getattr(config, "sliding_window", None) + + # Registry of Dim objects so that dims with the same name share the same Dim + dim_registry: Dict[str, Dim] = {} + + def get_dim(dim_name: str) -> Dim: + if dim_name in dim_registry: + return dim_registry[dim_name] + if dim_name == "batch_size": + d = Dim(dim_name, min=1, max=1024) + elif "seq_len" in dim_name: + d = Dim(dim_name, min=2, max=4095) + elif "img_size" in dim_name: + d = Dim.STATIC + elif "mm_tokens_per_image" in dim_name: + d = Dim(dim_name, min=1, max=4096) + elif "ctx_len" in dim_name: + d = Dim(dim_name, min=2, max=4095) + elif "sliding_window" in dim_name: + d = Dim(dim_name, min=2, max=4095) + elif "idx" in dim_name: + d = Dim.STATIC + elif "comp_ctx_lengths" in dim_name: + d = Dim(dim_name, min=1, max=4096) + else: + d = Dim(dim_name, min=1, max=4096) + dim_registry[dim_name] = d + return d + + def build_past_kv_shapes() -> List[Tuple[Dict[int, Any], Dict[int, Any]]]: + """ + Returns: + list of length num_layers, each element is: + (past_key_shape_dict, past_value_shape_dict) + past_* tensor shape: (batch_size, num_key_value_heads, cache_len, head_dim) + where cache_len is either ctx_len or sliding_window, depending on layer. + """ + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + for i in range(num_layers): + # Decide whether this layer uses global ctx_len or sliding_window + if has_sliding_window and ((i + 1) % layer_switch): + # sliding-window layer + cache_len_dim = get_dim("sliding_window") + else: + # global cache layer + cache_len_dim = get_dim("ctx_len") + + past_key_shape = { + 0: get_dim("batch_size"), + 2: cache_len_dim, + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: cache_len_dim, + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + return past_kv_shapes + + # kv_offload=True → separate vision/lang exports + if kv_offload: + # Vision encoder: pixel_values only + vision_dynamic_shapes: Dict[str, Dict[int, Any]] = { + "pixel_values": { + 0: get_dim("batch_size"), + 2: get_dim("img_size"), + 3: get_dim("img_size"), + } + } + + # Language decoder wrapper forward: + # forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths=None) + lang_dynamic_shapes: Dict[str, Any] = { + "input_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "vision_embeds": { + 0: get_dim("batch_size"), + 1: get_dim("mm_tokens_per_image"), + }, + "position_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "image_idx": { + 0: get_dim("idx"), + 1: get_dim("idx"), + }, + } + + lang_dynamic_shapes["past_key_values"] = build_past_kv_shapes() + + if comp_ctx_lengths is not None: + lang_dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths"), + } + + return { + "vision": vision_dynamic_shapes, + "lang": lang_dynamic_shapes, + } + + # kv_offload=False → combined forward + # Combined forward signature in QEffGemma3ForConditionalGeneration: + # forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths=None) + + dynamic_shapes: Dict[str, Any] = {} + + # pixel_values: (batch, 3, img_size, img_size) + dynamic_shapes["pixel_values"] = { + 0: get_dim("batch_size"), + 2: get_dim("img_size"), + 3: get_dim("img_size"), + } + + # input_ids: (batch, seq_len) + dynamic_shapes["input_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # position_ids: (batch, seq_len) + dynamic_shapes["position_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # image_idx: currently (1, 1); we keep dims static + dynamic_shapes["image_idx"] = { + 0: get_dim("idx"), + 1: get_dim("idx"), + } + + # past_key_values: list[num_layers] of (past_key, past_value) + dynamic_shapes["past_key_values"] = build_past_kv_shapes() + + if comp_ctx_lengths is not None: + dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths"), + } + + return dynamic_shapes + def get_dummy_pkv_cache(self, config, batch_size, seq_len): n_heads = config.num_key_value_heads d_head = config.head_dim diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index dc6653db0..5c370eeb9 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,11 +5,12 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.utils.checkpoint +from torch.export import Dim from transformers.models.llava.modeling_llava import ( LlavaForConditionalGeneration, ) @@ -17,7 +18,7 @@ from QEfficient.utils._utils import IOInfo from QEfficient.utils.logging_utils import logger -BS = 1 +BS = 4 NUM_CHANNEL = 3 SEQ_LEN = 592 CTX_LEN = 1024 @@ -298,6 +299,147 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} return dynamic_axes + def get_onnx_dynamic_shapes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + """ + - Handles past_key_values as a list of (key, value) pairs per layer + + For kv_offload=False (combined image_text_to_text forward), it returns a flat dict + whose top-level keys match the forward signature: + + forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values) + + For kv_offload=True, it returns: + { + "vision": { ... dynamic_shapes for vision ... }, + "lang": { ... dynamic_shapes for language, including past_key_values ... }, + } + """ + + num_layers = self.config.text_config.num_hidden_layers + + # Registry of Dim objects so that dims with the same name share the same Dim + dim_registry: Dict[str, Dim] = {} + + def get_dim(dim_name: str) -> Dim: + if dim_name in dim_registry: + return dim_registry[dim_name] + if dim_name == "batch_size": + d = Dim(dim_name, min=2, max=1023) + elif "seq_len" in dim_name: + d = Dim(dim_name, min=2, max=4095) + elif "img_size" in dim_name: + # Image spatial dims are fixed for Llava(336x336) + d = Dim.STATIC + elif "num_images" in dim_name: + d = Dim(dim_name, min=1, max=16) + elif "img_tiles" in dim_name: + d = Dim(dim_name, min=1, max=32) + elif "idx" in dim_name: + # image_idx dimensions kept static + d = Dim.STATIC + elif "ctx_len" in dim_name: + d = Dim(dim_name, min=2, max=4096) + else: + d = Dim(dim_name, min=1, max=4096) + dim_registry[dim_name] = d + return d + + # kv_offload=True → separate vision/lang exports + if kv_offload: + # Vision: pixel_values: (batch, 3, img_size, img_size) + vision_dynamic_shapes: Dict[str, Dict[int, Any]] = { + "pixel_values": { + 0: get_dim("batch_size"), + 2: get_dim("img_size"), + 3: get_dim("img_size"), + } + } + + # Lang: input_ids, position_ids, vision_embeds, image_idx, past_kv + lang_dynamic_shapes: Dict[str, Any] = { + "input_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "position_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "vision_embeds": { + 0: get_dim("batch_size"), + 1: get_dim("vision_size"), + }, + "image_idx": { + 0: get_dim("idx"), + 1: get_dim("idx"), + }, + } + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + for _ in range(num_layers): + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + + lang_dynamic_shapes["past_key_values"] = past_kv_shapes + + return { + "vision": vision_dynamic_shapes, + "lang": lang_dynamic_shapes, + } + + else: + # Combined image_text_to_text forward: + # forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values) + + dynamic_shapes: Dict[str, Any] = {} + + # pixel_values: (batch, 3, img_size, img_size) + dynamic_shapes["pixel_values"] = { + 0: get_dim("batch_size"), + 2: get_dim("img_size"), + 3: get_dim("img_size"), + } + + # input_ids: (batch, seq_len) + dynamic_shapes["input_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # position_ids: (batch, seq_len) + dynamic_shapes["position_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # image_idx: (1, 1) or (batch, 1); we keep idx dims static + dynamic_shapes["image_idx"] = { + 0: get_dim("idx"), + 1: get_dim("idx"), + } + + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + for _ in range(num_layers): + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + + dynamic_shapes["past_key_values"] = past_kv_shapes + + return dynamic_shapes + def get_output_names(self, kv_offload: bool = False): vision_output_names = ["vision_embeds"] lang_output_names = ["logits"] diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 694ed4cde..20610916b 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,11 +5,12 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint +from torch.export import Dim from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutput from transformers.models.mistral3.modeling_mistral3 import ( @@ -44,6 +45,11 @@ def qeff_generate_block_attention_mask(patch_embeds_list, tensor): block_end_idx = custom_cumsum(torch.tensor(patch_embeds_list)) block_start_idx = custom_cumsum(torch.tensor([0] + patch_embeds_list[:-1])) for start, end in zip(block_start_idx.tolist(), block_end_idx.tolist()): + torch._check(start >= 0) + torch._check(end >= 0) + torch._check(start <= 48400) + torch._check(end >= start) + torch._check(end <= 48400) causal_mask[start:end, start:end] = 0 causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) return causal_mask @@ -65,7 +71,7 @@ def forward( pixel_values: tensor of token features for all tokens of all images of shape (N_toks, D) """ - # pass images through initial convolution independently + # # pass images through initial convolution independently patch_embeds = self.patch_conv(pixel_values) patch_embeds_list = [ embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] @@ -433,6 +439,140 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} return dynamic_axes + def get_onnx_dynamic_shapes( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + ) -> Dict[str, Any]: + num_layers = self.config.text_config.num_hidden_layers + + # Registry so that dims with the same name share a single Dim instance + dim_registry: Dict[str, Dim] = {} + + def get_dim(dim_name: str) -> Dim: + if dim_name in dim_registry: + return dim_registry[dim_name] + + if dim_name == "batch_size": + d = Dim(dim_name, min=1, max=1024) + elif dim_name == "seq_len": + d = Dim(dim_name, min=2, max=4096) + elif dim_name == "ctx_len": + d = Dim(dim_name, min=2, max=4096) + elif dim_name == "image_size": + d = Dim(dim_name, min=2, max=4096) + elif dim_name == "vision_size": + d = Dim(dim_name, min=1, max=65536) + elif "idx" in dim_name: + d = Dim.STATIC + elif "comp_ctx_lengths" in dim_name: + d = Dim(dim_name, min=1, max=4096) + else: + d = Dim(dim_name, min=1, max=4096) + + dim_registry[dim_name] = d + return d + + def build_past_kv_shapes() -> list[list[Dict[int, Any]]]: + pkv_shapes: list[list[Dict[int, Any]]] = [] + for _ in range(num_layers): + key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + pkv_shapes.append([key_shape, value_shape]) + return pkv_shapes + + # kv_offload=True: separate vision/lang exports + if kv_offload: + # Vision encoder wrapper: + # QEFFMistral3EncoderWrapper.forward(self, pixel_values) + vision_dynamic_shapes: Dict[str, Dict[int, Any]] = { + "pixel_values": { + 0: get_dim("batch_size"), + 2: get_dim("image_size"), + 3: get_dim("image_size"), + } + } + + # Language decoder wrapper: + # QEFFMistral3DecoderWrapper.forward( + # self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths=None + # ) + lang_dynamic_shapes: Dict[str, Any] = { + "input_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "vision_embeds": { + 0: get_dim("vision_size"), + }, + "position_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "image_idx": { + 0: get_dim("idx"), + 1: get_dim("idx"), + }, + # Nested list[num_layers] of [key_shape, value_shape] + "past_key_values": build_past_kv_shapes(), + } + + if comp_ctx_lengths is not None: + lang_dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths"), + } + + return { + "vision": vision_dynamic_shapes, + "lang": lang_dynamic_shapes, + } + + # kv_offload=False: combined forward + # Combined forward of QEffMistral3ForConditionalGeneration: + # forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths=None) + + dynamic_shapes: Dict[str, Any] = {} + + # pixel_values: (batch_size, 3, image_size, image_size) + dynamic_shapes["pixel_values"] = { + 0: get_dim("batch_size"), + 2: get_dim("image_size"), + 3: get_dim("image_size"), + } + + # input_ids: (batch_size, seq_len) + dynamic_shapes["input_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # position_ids: (batch_size, seq_len) + dynamic_shapes["position_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # image_idx: currently (1,1); keep static dims + dynamic_shapes["image_idx"] = { + 0: get_dim("idx"), + 1: get_dim("idx"), + } + + # past_key_values: list[num_layers] of [key, value], each (batch_size, num_heads, ctx_len, head_dim) + dynamic_shapes["past_key_values"] = build_past_kv_shapes() + + if comp_ctx_lengths is not None: + dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths"), + } + return dynamic_shapes + def get_output_names(self, kv_offload: bool = False): vision_output_names = ["vision_embeds"] lang_output_names = ["logits"] diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index a3cb4273d..4459942fa 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -7,11 +7,12 @@ """PyTorch Mllama model.""" -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn +from torch.export import Dim from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutput, @@ -1093,6 +1094,209 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} return dynamic_axes + def get_onnx_dynamic_shapes( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + ): + txt_cfg = self.config.get_text_config() + num_hidden_layers = txt_cfg.num_hidden_layers + cross_attention_layers = txt_cfg.cross_attention_layers + + # Registry so dims with the same name share the same Dim object + dim_registry: Dict[str, Dim] = {} + + def get_dim(dim_name: str) -> Dim: + if dim_name in dim_registry: + return dim_registry[dim_name] + + if dim_name == "batch_size": + d = Dim(dim_name, min=1, max=1024) + elif "seq_len" in dim_name: + d = Dim(dim_name, min=1, max=4096) + elif "ctx_len" in dim_name: + d = Dim(dim_name, min=1, max=4096) + elif "max_num_images" in dim_name: + d = Dim(dim_name, min=1, max=16) + elif "max_num_tiles" in dim_name or "img_tiles" in dim_name: + d = Dim(dim_name, min=1, max=32) + elif "img_size" in dim_name: + d = Dim(dim_name, min=64, max=1024) + elif "idx" in dim_name: + # keep image_idx static for now + d = Dim.STATIC + elif "image_tokens_len" in dim_name: + d = Dim(dim_name, min=1, max=8192) + elif "comp_ctx_lengths" in dim_name: + d = Dim(dim_name, min=1, max=4096) + else: + d = Dim(dim_name, min=1, max=4096) + + dim_registry[dim_name] = d + return d + + # kv_offload=True: separate vision and language exports + if kv_offload: + # pixel_values: (batch_size, max_num_images, max_num_tiles, 3, img_size, img_size) + vision_dynamic_shapes: Dict[str, Dict[int, Any]] = { + "pixel_values": { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + 2: get_dim("max_num_tiles"), + 4: get_dim("img_size"), + 5: get_dim("img_size"), + }, + "aspect_ratio_ids": { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + }, + "aspect_ratio_mask": { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + 2: get_dim("max_num_tiles"), + }, + } + + # Language side: input_ids, image_idx, cross_attention_mask, position_ids, past_key_values, [comp_ctx_lengths] + lang_dynamic_shapes: Dict[str, Any] = { + "input_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "position_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "cross_attention_mask": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + 2: get_dim("max_num_images"), + 3: get_dim("max_num_tiles"), + }, + # attention mask included in dynamic axes but since it is not used in the forward signature, omiting it here + "image_idx": { + 0: get_dim("idx"), + 1: get_dim("idx"), + }, + } + + # Build past_key_values shapes as list of (key, value) dicts + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + for layer_idx in range(num_hidden_layers): + if layer_idx in cross_attention_layers: + # KV over image tokens + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("image_tokens_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("image_tokens_len"), + } + else: + # KV over ctx_len + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + + lang_dynamic_shapes["past_key_values"] = past_kv_shapes + + if comp_ctx_lengths is not None: + lang_dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths_dim"), + } + + return { + "vision": vision_dynamic_shapes, + "lang": lang_dynamic_shapes, + } + + # kv_offload=False: combined export + dynamic_shapes: Dict[str, Any] = {} + + # pixel_values: (batch_size, max_num_images, max_num_tiles, 3, img_size, img_size) + dynamic_shapes["pixel_values"] = { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + 2: get_dim("max_num_tiles"), + 4: get_dim("img_size"), + 5: get_dim("img_size"), + } + + # aspect_ratio_ids: (batch_size, max_num_images) + dynamic_shapes["aspect_ratio_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + } + + # aspect_ratio_mask: (batch_size, max_num_images, max_num_tiles) + dynamic_shapes["aspect_ratio_mask"] = { + 0: get_dim("batch_size"), + 1: get_dim("max_num_images"), + 2: get_dim("max_num_tiles"), + } + + # input_ids: (batch_size, seq_len) + dynamic_shapes["input_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # image_idx: (1, 1) or (batch, 1); kept static by design + dynamic_shapes["image_idx"] = { + 0: get_dim("idx"), + 1: get_dim("idx"), + } + + # cross_attention_mask: (batch_size, seq_len, max_num_images, max_num_tiles) + dynamic_shapes["cross_attention_mask"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + 2: get_dim("max_num_images"), + 3: get_dim("max_num_tiles"), + } + + # position_ids: (batch_size, seq_len) + dynamic_shapes["position_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # past_key_values: list[(past_key, past_value)] for each layer + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + for layer_idx in range(num_hidden_layers): + if layer_idx in cross_attention_layers: + # KV over image tokens + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("image_tokens_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("image_tokens_len"), + } + else: + # KV over ctx_len + past_key_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim("batch_size"), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + + dynamic_shapes["past_key_values"] = tuple(past_kv_shapes) + + return dynamic_shapes + def get_output_names(self, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cbff5be91..35c065e34 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -13,6 +13,8 @@ import numpy as np import torch import torch.nn as nn + +# Optional: helps type hints from transformers import ( AutoImageProcessor, AutoModel, @@ -318,7 +320,15 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def convert_dynamic_axes_to_dynamic_shapes(self, dynamic_axes): + pass + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -350,12 +360,18 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names = ["output"] + dynamic_shapes = None + if use_dynamo: + dynamic_shapes = self.convert_dynamic_axes_to_dynamic_shapes(dynamic_axes) + return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, + dynamic_shapes=dynamic_shapes, ) def compile( @@ -369,6 +385,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -441,6 +458,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) @@ -610,9 +628,11 @@ def export( inputs, output_names, dynamic_axes, + dynamic_shapes, export_dir=None, offload_pt_weights=True, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, ): """ Exports the vision encoder component to ONNX format. @@ -641,9 +661,11 @@ def export( inputs, output_names, dynamic_axes, + dynamic_shapes=dynamic_shapes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, ) def compile( @@ -657,6 +679,7 @@ def compile( aic_num_cores, custom_io, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -700,6 +723,7 @@ def compile( aic_num_cores=aic_num_cores, custom_io=custom_io, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) @@ -771,9 +795,11 @@ def export( inputs, output_names, dynamic_axes, + dynamic_shapes, export_dir=None, offload_pt_weights=True, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, ): """ Exports the language decoder component to ONNX format. @@ -802,9 +828,11 @@ def export( inputs, output_names, dynamic_axes, + dynamic_shapes=dynamic_shapes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, ) def compile( @@ -818,6 +846,7 @@ def compile( aic_num_cores, custom_io, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -861,6 +890,7 @@ def compile( aic_num_cores=aic_num_cores, custom_io=custom_io, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) @@ -1022,6 +1052,7 @@ def export( self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **kwargs, ) -> str: """ @@ -1045,6 +1076,7 @@ def export( A list containing the paths to the generated ONNX graph files for both components. """ # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. + try: inputs = self.model.get_dummy_inputs( kv_offload=True, @@ -1061,6 +1093,12 @@ def export( dynamic_axes = self.model.get_onnx_dynamic_axes( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) + + dynamic_shapes = None + if use_dynamo: + dynamic_shapes = self.model.get_onnx_dynamic_shapes( + kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode + ) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1070,6 +1108,8 @@ def export( export_dir=export_dir, offload_pt_weights=False, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, + dynamic_shapes=dynamic_shapes["vision"], ) self.lang_model.export( inputs["lang"], @@ -1078,6 +1118,8 @@ def export( export_dir=export_dir, offload_pt_weights=True, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, + dynamic_shapes=dynamic_shapes["lang"], ) return self.onnx_path @@ -1101,6 +1143,7 @@ def compile( skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -1216,6 +1259,7 @@ def compile( ): self.export( use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -1235,6 +1279,7 @@ def compile( custom_io=custom_io_vision, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) @@ -1264,6 +1309,7 @@ def compile( custom_io=custom_io_lang, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) return self.qpc_path @@ -1689,6 +1735,7 @@ def export( self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **kwargs, ) -> str: """ @@ -1706,15 +1753,21 @@ def export( str Path to the generated ONNX graph file. """ + inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() + dynamic_shapes = None + if use_dynamo: + dynamic_shapes = self.model.get_onnx_dynamic_shapes(comp_ctx_lengths=self.comp_ctx_lengths_decode) return self._export( - inputs, - output_names, - dynamic_axes, + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, export_dir=export_dir, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, + dynamic_shapes=dynamic_shapes, ) def compile( @@ -1734,6 +1787,7 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -1844,6 +1898,7 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) return self.qpc_path @@ -2501,7 +2556,205 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: + # def convert_dynamic_axes_to_dynamic_shapes(self, dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]: + # pass + + # def convert_dynamic_axes_to_dynamic_shapes( + # self, + # example_inputs: Dict[str, Any], + # dynamic_axes: Dict[str, Dict[int, str]], + # ) -> dynamic_shapes_api.ShapesCollection: + # """ + # Convert torch.onnx-style dynamic_axes to a torch.export ShapesCollection. + + # Parameters + # ---------- + # example_inputs : Dict[str, Any] + # The same example_inputs dict used for export. Keys here should + # correspond to the names in dynamic_axes (e.g. "input_ids", + # "position_ids", "past_key.0", etc.), so we can map from names + # to actual tensors. + + # dynamic_axes : Dict[str, Dict[int, str]] + # Mapping from tensor name to a dict {dim_index: dim_name}, as + # used in torch.onnx.export. + + # Example: + # { + # "input_ids": {0: "batch_size", 1: "seq_len"}, + # "position_ids": {0: "batch_size", 1: "seq_len"}, + # "past_key.0": {0: "full_batch_size", 2: "ctx_len"}, + # } + + # Returns + # ------- + # ShapesCollection + # A ShapesCollection suitable for the `dynamic_shapes` argument + # to torch.export, keyed by the actual tensors in example_inputs. + # """ + + # # Cache Dim objects by name so same name => same symbolic dimension + # dim_cache: Dict[str, Dim] = {} + + # # def get_dim(name: str) -> Dim: + # # if name not in dim_cache: + # # # You can add min/max here if needed: Dim(name, min=1, max=...) + # # dim_cache[name] = Dim(name) + # # return dim_cache[name] + + # def get_dim(name: str) -> Dim: + # if name not in dim_cache: + # # Patch: special min/max for specific dimension names + # if name == "batch_size": + # # Example: allow batch_size from 1 to 1024 + # dim_cache[name] = Dim("batch_size", min=1, max=1024) + # elif name == "seq_len": + # # Example: disallow tiny seq_len if needed; adjust as per model + # # Here: seq_len between 2 and 4096 (tweak as appropriate) + # dim_cache[name] = Dim("seq_len", min=2, max=4096) + # elif name == "ctx_len": + # # Example: context length from 1 to 8192 + # dim_cache[name] = Dim("ctx_len", min=1, max=8192) + # else: + # # Default: no explicit bounds + # dim_cache[name] = Dim(name) + # return dim_cache[name] + + # # Create a ShapesCollection + # shapes = dynamic_shapes_api.ShapesCollection() + + # # Iterate over each name in dynamic_axes + # for name, axes_spec in dynamic_axes.items(): + # if name not in example_inputs: + # # If your naming doesn’t match 1:1, you’ll need custom mapping here. + # # For now, we just skip and optionally log. + # # print(f"[WARN] dynamic_axes key '{name}' not found in example_inputs, skipping") + # continue + + # tensor = example_inputs[name] + + # # Only tensors (or wrapped ints) should be used as keys in ShapesCollection + # if not isinstance(tensor, torch.Tensor): + # # If it's an int or something else, you’d need _IntWrapper logic + # # from the docs. For now, we skip non-tensors. + # # print(f"[WARN] example_inputs['{name}'] is not a Tensor, skipping") + # continue + + # # axes_spec: Dict[int, str] → we build a dict {dim_idx: Dim} + # dim_mapping: Dict[int, Any] = {} + # for dim_idx, dim_name in axes_spec.items(): + # if isinstance(dim_name, str): + # dim_mapping[dim_idx] = get_dim(dim_name) + # else: + # # If already a Dim/int, just propagate + # dim_mapping[dim_idx] = dim_name + + # # Assign this mapping to the ShapesCollection for this tensor + # # Using dict {dim_index: Dim} form, as allowed by ShapesCollection + # shapes[tensor] = dim_mapping + + # return shapes + def convert_dynamic_axes_to_dynamic_shapes(self, dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]: + """ + Convert ONNX dynamic_axes format to torch.export dynamic_shapes format + + Args: + dynamic_axes: ONNX format like {"input_ids": {0: "batch_size", 1: "seq_len"}} + + Returns: + dynamic_shapes: torch.export format with Dim objects matching model forward args + """ + from torch.export import Dim + + # Create dimension registry to reuse Dim objects with same names + dim_registry = {} + dynamic_shapes = {} + + # Handle regular model inputs (not past_key_values) + # These match the QEffLlamaForCausalLM forward signature: + # input_ids, attention_mask, position_ids, past_key_values, batch_index, etc. + for input_name, axes_map in dynamic_axes.items(): + if not input_name.startswith("past_"): + input_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + # Create or reuse Dim object for this dimension name + if dim_name not in dim_registry: + if dim_name == "batch_size": + dim_registry[dim_name] = Dim("batch_size") + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim("seq_len", min=2, max=131071) + elif "ctx_len" in dim_name: + dim_registry[dim_name] = Dim("ctx_len", min=2, max=131071) + else: + dim_registry[dim_name] = Dim.DYNAMIC + + input_dynamic_shapes[axis_idx] = dim_registry[dim_name] + + dynamic_shapes[input_name] = input_dynamic_shapes + + # Handle past_key_values specially - collect all past_key.X and past_value.X + past_keys = {} + past_values = {} + + for input_name, axes_map in dynamic_axes.items(): + if input_name.startswith("past_key."): + layer_idx = int(input_name.split(".")[1]) + layer_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + if dim_name not in dim_registry: + if dim_name == "batch_size": + dim_registry[dim_name] = Dim("batch_size") + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim("seq_len", min=2, max=131071) + elif "ctx_len" in dim_name: + dim_registry[dim_name] = Dim("ctx_len", min=2, max=131071) + else: + dim_registry[dim_name] = Dim.DYNAMIC + layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] + past_keys[layer_idx] = layer_dynamic_shapes + + elif input_name.startswith("past_value."): + layer_idx = int(input_name.split(".")[1]) + layer_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + if dim_name not in dim_registry: + if dim_name == "batch_size": + dim_registry[dim_name] = Dim("batch_size") + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim("seq_len", min=2, max=131071) + elif "ctx_len" in dim_name: + dim_registry[dim_name] = Dim("ctx_len", min=2, max=131071) + else: + dim_registry[dim_name] = Dim.DYNAMIC + layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] + past_values[layer_idx] = layer_dynamic_shapes + + # Reconstruct past_key_values as nested structure if we have past keys/values + if past_keys or past_values: + max_layer = max(list(past_keys.keys()) + list(past_values.keys())) + past_kv_shapes = [] + + for layer_idx in range(max_layer + 1): + layer_shapes = [] + if layer_idx in past_keys: + layer_shapes.append(past_keys[layer_idx]) + else: + layer_shapes.append({}) + + if layer_idx in past_values: + layer_shapes.append(past_values[layer_idx]) + else: + layer_shapes.append({}) + + past_kv_shapes.append(layer_shapes) + + dynamic_shapes["past_key_values"] = past_kv_shapes + + return dynamic_shapes + + def export( + self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, use_dynamo: bool = False, **kwargs + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2516,6 +2769,8 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = If not provided, the default export directory is used. use_onnx_subfunctions: bool, optional whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + use_dynamo: bool, optional + whether to enable dynamo during export. Returns ------- str @@ -2606,12 +2861,18 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = dynamic_axes=dynamic_axes, ) + dynamic_shapes = None + if use_dynamo: + dynamic_shapes = self.convert_dynamic_axes_to_dynamic_shapes(dynamic_axes) + return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, + dynamic_shapes=dynamic_shapes, offload_pt_weights=kwargs.get("offload_pt_weights", True), ) @@ -2824,6 +3085,7 @@ def compile( num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, + use_dynamo: bool = False, **compiler_options, ) -> str: """ @@ -2867,6 +3129,8 @@ def compile( the decode stage. If None, compiles for both stages. Default is None. use_onnx_subfunctions: bool, optional whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + use_dynamo: bool,optional + whether to enable dynamo during export **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -3029,6 +3293,7 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + use_dynamo=use_dynamo, **compiler_options, ) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index baffb44c5..cf852c65a 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -7,11 +7,12 @@ import math import os -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +from torch.export import Dim from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel from transformers.cache_utils import Cache from transformers.modeling_outputs import ( @@ -1121,6 +1122,178 @@ def get_onnx_dynamic_axes( dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} return dynamic_axes + def get_onnx_dynamic_shapes( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + ): + num_layers = self.config.text_config.num_hidden_layers + + # Registry of Dim objects so that dims with the same name share the same Dim + dim_registry: Dict[str, Dim] = {} + + def get_dim(dim_name: str) -> Dim: + if dim_name in dim_registry: + return dim_registry[dim_name] + + if dim_name == "batch_size": + d = Dim(dim_name, min=1, max=1024) + elif dim_name == "vision_batch_size": + d = Dim(dim_name, min=1, max=1024) + elif dim_name == "full_batch_size": + d = Dim(dim_name, min=1, max=2048) + elif "seq_len" in dim_name: + d = Dim(dim_name, min=1, max=4096) + elif dim_name == "ctx_len": + d = Dim(dim_name, min=1, max=4096) + elif dim_name == "grid_height": + d = Dim(dim_name, min=1, max=65536) + elif dim_name == "grid_width": + d = Dim(dim_name, min=1, max=8192) + elif dim_name == "grid_h": + d = Dim(dim_name, min=1, max=1024) + elif dim_name == "grid_w": + d = Dim(dim_name, min=1, max=1024) + elif "vision_size" in dim_name: + d = Dim(dim_name, min=1, max=65536) + elif "comp_ctx_lengths" in dim_name: + d = Dim(dim_name, min=1, max=4096) + else: + # fallback + d = Dim(dim_name, min=1, max=4096) + + dim_registry[dim_name] = d + return d + + # kv_offload == True: separate vision/lang exports + if kv_offload: + # Vision encoder path: + # pixel_values: (grid_height, grid_width) + # image_grid_thw: (batch_size, 1, grid_h, grid_w) + vision_dynamic_shapes: Dict[str, Dict[int, Any]] = { + "pixel_values": { + 0: get_dim("grid_height"), + 1: get_dim("grid_width"), + }, + "image_grid_thw": { + 0: get_dim("batch_size"), + 2: get_dim("grid_h"), + 3: get_dim("grid_w"), + }, + } + + # Language decoder path: + # input_ids: (batch_size, seq_len) + # position_ids: (3, batch_size, seq_len) + # vision_embeds: (vision_batch_size, vision_size, hidden_size) + # image_idx: (1, 1) --here treated as input to lang subgraph + lang_dynamic_shapes: Dict[str, Any] = { + "input_ids": { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + }, + "position_ids": { + 1: get_dim("batch_size"), + 2: get_dim("seq_len"), + }, + "vision_embeds": { + 0: get_dim("vision_batch_size"), + 1: get_dim("vision_size"), + }, + "image_idx": { + 0: Dim.STATIC, + 1: Dim.STATIC, + }, + } + + # KV cache for decoder: list of (key, value) per layer: + # key/value: (batch_size or full_batch_size, num_heads, ctx_len, head_dim) + past_kv_shapes: List[Tuple[Dict[int, Any], Dict[int, Any]]] = [] + batch_dim_name = "full_batch_size" if continuous_batching else "batch_size" + for _ in range(num_layers): + past_key_shape = { + 0: get_dim(batch_dim_name), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim(batch_dim_name), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append((past_key_shape, past_value_shape)) + + lang_dynamic_shapes["past_key_values"] = past_kv_shapes + + if continuous_batching: + # batch_index: often (batch_size, 1) or (batch_size,) + lang_dynamic_shapes["batch_index"] = { + 0: get_dim("batch_size"), + } + + if comp_ctx_lengths is not None: + lang_dynamic_shapes["comp_ctx_lengths"] = { + 0: get_dim("comp_ctx_lengths"), + } + + return { + "vision": vision_dynamic_shapes, + "lang": lang_dynamic_shapes, + } + + # kv_offload == False: single combined export + dynamic_shapes: Dict[str, Any] = {} + + # pixel_values: (grid_height, grid_width) + dynamic_shapes["pixel_values"] = { + 0: get_dim("grid_height"), + 1: get_dim("grid_width"), + } + + # image_grid_thw: (batch_size, 1, grid_h, grid_w) + dynamic_shapes["image_grid_thw"] = { + 0: get_dim("batch_size"), + 2: get_dim("grid_h"), + 3: get_dim("grid_w"), + } + + # input_ids: (batch_size, seq_len) + dynamic_shapes["input_ids"] = { + 0: get_dim("batch_size"), + 1: get_dim("seq_len"), + } + + # position_ids: (3, batch_size, seq_len) + # (dim 0 == 3 is static and omitted from dynamic spec) + dynamic_shapes["position_ids"] = { + 1: get_dim("batch_size"), + 2: get_dim("seq_len"), + } + + # past_key_values: list[(key, value) ...] per layer, + # key/value: (batch_size or full_batch_size, num_heads, ctx_len, head_dim) + past_kv_shapes: List[List[Dict[int, Any], Dict[int, Any]]] = [] + batch_dim_name = "full_batch_size" if continuous_batching else "batch_size" + for _ in range(num_layers): + past_key_shape = { + 0: get_dim(batch_dim_name), + 2: get_dim("ctx_len"), + } + past_value_shape = { + 0: get_dim(batch_dim_name), + 2: get_dim("ctx_len"), + } + past_kv_shapes.append([past_key_shape, past_value_shape]) + + dynamic_shapes["past_key_values"] = past_kv_shapes + dynamic_shapes["kwargs"] = { + "image_idx": { + 0: Dim.STATIC, + 1: Dim.STATIC, + }, + } + + return dynamic_shapes + def get_output_names(self, kv_offload: bool = False): vision_output_names = ["vision_embeds"] lang_output_names = ["logits"] diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 1504bdae5..4d6cb219d 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -13,7 +13,7 @@ ROOT_DIR = os.path.dirname(QEFF_DIR) QEFF_CACHE_DIR_NAME = "qeff_cache" -ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 +ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 4 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep @@ -83,7 +83,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 17 +ONNX_EXPORT_OPSET = 18 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"] DEFAULT_AIC_HW_VERSION = "ai100" diff --git a/QEfficient/utils/custom_op_utils.py b/QEfficient/utils/custom_op_utils.py new file mode 100644 index 000000000..e9262ddf0 --- /dev/null +++ b/QEfficient/utils/custom_op_utils.py @@ -0,0 +1,232 @@ +import torch + + +def select_interface(eager_impl, custom_op_impl): + use_custom_op = torch._dynamo.is_compiling() + return custom_op_impl if use_custom_op else eager_impl + + +@torch.library.custom_op("qefficient::rms_norm", mutates_args=()) +def rms_norm_op(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: + """Custom RMS Norm operation for QEfficient""" + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + return weight * hidden_states + + +@rms_norm_op.register_fake +def _(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: + """Fake implementation for torch.export - just returns tensor with same shape/dtype""" + return torch.empty_like(hidden_states) + + +@torch.library.custom_op("qefficient::ctx_scatter", mutates_args=()) +def ctx_scatter_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Custom context scatter operation""" + result = data.clone() + batch_idx = torch.arange(result.shape[0]).view(-1, 1, 1) + head_idx = torch.arange(result.shape[1]).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + result[batch_idx, head_idx, ctx_idx] = updates + return result + + +@ctx_scatter_op.register_fake +def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Fake implementation for torch.export - just returns data tensor with same shape/dtype""" + return torch.empty_like(data) + + +@torch.library.custom_op("qefficient::ctx_scatter_3d", mutates_args=()) +def ctx_scatter_3d_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Custom 3D context scatter operation""" + # Clone the data to avoid aliasing issues with torch.library.custom_op + result = data.clone() + batch_idx = torch.arange(result.shape[0]).view(-1, 1) + ctx_idx = position_ids + result[batch_idx, ctx_idx] = updates + return result + + +@ctx_scatter_3d_op.register_fake +def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Fake implementation for torch.export - just returns data tensor with same shape/dtype""" + return torch.empty_like(data) + + +@torch.library.custom_op("qefficient::ctx_gather_3d", mutates_args=()) +def ctx_gather_3d_op(data: torch.Tensor, ctx_indices: torch.Tensor) -> torch.Tensor: + """Custom 3D context gather operation""" + batch_indices = torch.arange(data.shape[0]).view(-1, 1) + return data[batch_indices, ctx_indices] + + +@ctx_gather_3d_op.register_fake +def _(data: torch.Tensor, ctx_indices: torch.Tensor) -> torch.Tensor: + """Fake implementation for torch.export""" + # Return tensor with shape [batch_size, seq_len] + batch_size = data.shape[0] + seq_len = ctx_indices.shape[1] + return torch.empty(batch_size, seq_len, dtype=data.dtype, device=data.device) + + +@torch.library.custom_op("qefficient::ctx_gather", mutates_args=()) +def ctx_gather_op(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int) -> torch.Tensor: + batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + +@ctx_gather_op.register_fake +def _( + data: torch.Tensor, + ctx_indices: torch.Tensor, + comp_ctx_len: int, +) -> torch.Tensor: + return torch.empty_like(data) + + +# SCATTER CB (4D with heads, context, etc.) +@torch.library.custom_op("qefficient::ctx_scatter_cb", mutates_args=()) +def ctx_scatter_cb_op( + data: torch.Tensor, + batch_index: torch.Tensor, + position_ids: torch.Tensor, + updates: torch.Tensor, +) -> torch.Tensor: + """ + Custom 4D context scatter op with batch_index (CB version). + Semantics: same as CtxScatterFuncCB.forward, but returns a new tensor. + """ + # Clone to avoid aliasing issues with custom_op + result = data.clone() + batch_idx = batch_index.view(-1, 1, 1) + head_idx = torch.arange(result.shape[1], device=result.device).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + result[batch_idx, head_idx, ctx_idx] = updates + return result + + +@ctx_scatter_cb_op.register_fake +def _( + data: torch.Tensor, + batch_index: torch.Tensor, + position_ids: torch.Tensor, + updates: torch.Tensor, +) -> torch.Tensor: + """ + Fake implementation for torch.export: correct shape/dtype/device, values don't matter. + Output shape matches data. + """ + return torch.empty_like(data) + + +# SCATTER CB 3D +@torch.library.custom_op("qefficient::ctx_scatter_cb_3d", mutates_args=()) +def ctx_scatter_cb_3d_op( + data: torch.Tensor, + batch_index: torch.Tensor, + position_ids: torch.Tensor, + updates: torch.Tensor, +) -> torch.Tensor: + """ + Custom 3D context scatter op with batch_index (CB3D version). + Semantics: same as CtxScatterFuncCB3D.forward but returns new tensor. + """ + result = data.clone() + batch_idx = batch_index.view(-1, 1) + ctx_idx = position_ids + result[batch_idx, ctx_idx] = updates + return result + + +@ctx_scatter_cb_3d_op.register_fake +def _( + data: torch.Tensor, + batch_index: torch.Tensor, + position_ids: torch.Tensor, + updates: torch.Tensor, +) -> torch.Tensor: + """ + Fake implementation for torch.export: same shape/dtype/device as data. + """ + return torch.empty_like(data) + + +# GATHER CB +@torch.library.custom_op("qefficient::ctx_gather_cb", mutates_args=()) +def ctx_gather_cb_op( + data: torch.Tensor, + batch_index: torch.Tensor, + ctx_indices: torch.Tensor, +) -> torch.Tensor: + """ + Custom 4D context gather op with batch_index (CB version). + Semantics: similar to CtxGatherFuncCB.forward. + """ + batch_indices = batch_index.view(-1, 1, 1) + head_indices = torch.arange(data.shape[1], device=data.device).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + +@ctx_gather_cb_op.register_fake +def _( + data: torch.Tensor, + batch_index: torch.Tensor, + ctx_indices: torch.Tensor, +) -> torch.Tensor: + """ + Fake implementation for torch.export. + + We derive the output shape from input shapes: + - batch_size: from batch_index + - num_heads: from data + - seq_len: from ctx_indices (dimension 1, typically) + - hidden dims: from data (starting from dim 3) + """ + batch_size = batch_index.shape[0] + num_heads = data.shape[1] + seq_len = ctx_indices.shape[1] + + # Remaining feature dimensions (e.g., head_dim or more) + feature_shape = data.shape[3:] # could be () if 3D + + out_shape = (batch_size, num_heads, seq_len, *feature_shape) + return torch.empty(out_shape, dtype=data.dtype, device=data.device) + + +# GATHER CB 3D +@torch.library.custom_op("qefficient::ctx_gather_cb_3d", mutates_args=()) +def ctx_gather_cb_3d_op( + data: torch.Tensor, + batch_index: torch.Tensor, + ctx_indices: torch.Tensor, +) -> torch.Tensor: + """ + Custom 3D context gather op with batch_index (CB3D version). + Semantics: similar to CtxGatherFuncCB3D.forward. + """ + batch_indices = batch_index.view(-1, 1) + return data[batch_indices, ctx_indices] + + +@ctx_gather_cb_3d_op.register_fake +def _( + data: torch.Tensor, + batch_index: torch.Tensor, + ctx_indices: torch.Tensor, +) -> torch.Tensor: + """ + Fake implementation for torch.export. + + Output shape: + - batch_size: from batch_index + - seq_len: from ctx_indices (dim 1) + - any trailing dims from data + """ + batch_size = batch_index.shape[0] + seq_len = ctx_indices.shape[1] + feature_shape = data.shape[2:] # if data is [B, C], this is () + + out_shape = (batch_size, seq_len, *feature_shape) + return torch.empty(out_shape, dtype=data.dtype, device=data.device) From b4fe1dfdeb0a86203c90e732d94510c5b55c01b5 Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 24 Nov 2025 07:50:16 +0000 Subject: [PATCH 2/3] updating the pyproject.toml file for torch 2.9.1 Signed-off-by: Sharvari Medhe --- QEfficient/base/modeling_qeff.py | 10 ++++++++++ pyproject.toml | 9 +++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 8196e73d1..bcd024bd7 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -266,6 +266,15 @@ def _export( self._onnx_transforms.append(RenameFunctionOutputsTransform) self._onnx_transforms.append(CustomOpTransform) + # fx_graph = torch.export.export( + # self.model, + # args=(), + # kwargs=example_inputs, #IMPORTANT CHANGE: passing all inputs in kwargs rather than as a rigid tuple in args + # dynamic_shapes=dynamic_shapes, + # **export_kwargs, + # strict=True, + # ) + # result = fx_graph.module()(**example_inputs) if use_dynamo: dynamic_axes = None export_kwargs["report"] = True @@ -470,6 +479,7 @@ def _compile( command.append(f"-custom-IO-list-file={custom_io_yaml}") command.append(f"-aic-binary-dir={qpc_path}") + print(command) logger.info(f"Running compiler: {' '.join(command)}") try: subprocess.run(command, capture_output=True, check=True) diff --git a/pyproject.toml b/pyproject.toml index ea3c3405d..a5d3651d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,22 +28,23 @@ dependencies = [ "multidict==6.0.4", "urllib3<2", "sentencepiece==0.2.0", - "onnx==1.18.0", - "onnxruntime==1.22", + "onnx==1.19.1", + "onnxruntime==1.23.2", "numpy==1.26.4", "protobuf==6.31.0", - "onnxscript==0.2.5", + "onnxscript==0.5.6", "pillow===10.4.0", "sympy", "tensorboard", "fire", "py7zr", "torchmetrics==1.7.0", - "torch==2.7.0; platform_machine=='aarch64'", + "torch==2.9.1; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", + "torch@https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", ] [project.optional-dependencies] From 2ff1ae131634ee9b86c6ed2a05e768c22542b614 Mon Sep 17 00:00:00 2001 From: smedhe Date: Wed, 26 Nov 2025 13:58:11 +0530 Subject: [PATCH 3/3] Update modeling_auto.py Signed-off-by: smedhe --- .../transformers/models/modeling_auto.py | 99 +------------------ 1 file changed, 1 insertion(+), 98 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 35c065e34..d0c312661 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2556,104 +2556,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - # def convert_dynamic_axes_to_dynamic_shapes(self, dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]: - # pass - - # def convert_dynamic_axes_to_dynamic_shapes( - # self, - # example_inputs: Dict[str, Any], - # dynamic_axes: Dict[str, Dict[int, str]], - # ) -> dynamic_shapes_api.ShapesCollection: - # """ - # Convert torch.onnx-style dynamic_axes to a torch.export ShapesCollection. - - # Parameters - # ---------- - # example_inputs : Dict[str, Any] - # The same example_inputs dict used for export. Keys here should - # correspond to the names in dynamic_axes (e.g. "input_ids", - # "position_ids", "past_key.0", etc.), so we can map from names - # to actual tensors. - - # dynamic_axes : Dict[str, Dict[int, str]] - # Mapping from tensor name to a dict {dim_index: dim_name}, as - # used in torch.onnx.export. - - # Example: - # { - # "input_ids": {0: "batch_size", 1: "seq_len"}, - # "position_ids": {0: "batch_size", 1: "seq_len"}, - # "past_key.0": {0: "full_batch_size", 2: "ctx_len"}, - # } - - # Returns - # ------- - # ShapesCollection - # A ShapesCollection suitable for the `dynamic_shapes` argument - # to torch.export, keyed by the actual tensors in example_inputs. - # """ - - # # Cache Dim objects by name so same name => same symbolic dimension - # dim_cache: Dict[str, Dim] = {} - - # # def get_dim(name: str) -> Dim: - # # if name not in dim_cache: - # # # You can add min/max here if needed: Dim(name, min=1, max=...) - # # dim_cache[name] = Dim(name) - # # return dim_cache[name] - - # def get_dim(name: str) -> Dim: - # if name not in dim_cache: - # # Patch: special min/max for specific dimension names - # if name == "batch_size": - # # Example: allow batch_size from 1 to 1024 - # dim_cache[name] = Dim("batch_size", min=1, max=1024) - # elif name == "seq_len": - # # Example: disallow tiny seq_len if needed; adjust as per model - # # Here: seq_len between 2 and 4096 (tweak as appropriate) - # dim_cache[name] = Dim("seq_len", min=2, max=4096) - # elif name == "ctx_len": - # # Example: context length from 1 to 8192 - # dim_cache[name] = Dim("ctx_len", min=1, max=8192) - # else: - # # Default: no explicit bounds - # dim_cache[name] = Dim(name) - # return dim_cache[name] - - # # Create a ShapesCollection - # shapes = dynamic_shapes_api.ShapesCollection() - - # # Iterate over each name in dynamic_axes - # for name, axes_spec in dynamic_axes.items(): - # if name not in example_inputs: - # # If your naming doesn’t match 1:1, you’ll need custom mapping here. - # # For now, we just skip and optionally log. - # # print(f"[WARN] dynamic_axes key '{name}' not found in example_inputs, skipping") - # continue - - # tensor = example_inputs[name] - - # # Only tensors (or wrapped ints) should be used as keys in ShapesCollection - # if not isinstance(tensor, torch.Tensor): - # # If it's an int or something else, you’d need _IntWrapper logic - # # from the docs. For now, we skip non-tensors. - # # print(f"[WARN] example_inputs['{name}'] is not a Tensor, skipping") - # continue - - # # axes_spec: Dict[int, str] → we build a dict {dim_idx: Dim} - # dim_mapping: Dict[int, Any] = {} - # for dim_idx, dim_name in axes_spec.items(): - # if isinstance(dim_name, str): - # dim_mapping[dim_idx] = get_dim(dim_name) - # else: - # # If already a Dim/int, just propagate - # dim_mapping[dim_idx] = dim_name - - # # Assign this mapping to the ShapesCollection for this tensor - # # Using dict {dim_index: Dim} form, as allowed by ShapesCollection - # shapes[tensor] = dim_mapping - - # return shapes + def convert_dynamic_axes_to_dynamic_shapes(self, dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]: """ Convert ONNX dynamic_axes format to torch.export dynamic_shapes format