|
11 | 11 | import torch |
12 | 12 | from onnx import ModelProto, external_data_helper, numpy_helper |
13 | 13 |
|
14 | | -from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc |
| 14 | +from QEfficient.customop.ctx_scatter_gather import ( |
| 15 | + CtxGather, |
| 16 | + CtxGather3D, |
| 17 | + CtxGatherFunc, |
| 18 | + CtxGatherFunc3D, |
| 19 | + CtxScatter, |
| 20 | + CtxScatter3D, |
| 21 | + CtxScatterFunc, |
| 22 | + CtxScatterFunc3D, |
| 23 | +) |
| 24 | +from QEfficient.customop.ctx_scatter_gather_cb import ( |
| 25 | + CtxGatherCB, |
| 26 | + CtxGatherCB3D, |
| 27 | + CtxGatherFuncCB, |
| 28 | + CtxGatherFuncCB3D, |
| 29 | + CtxScatterCB, |
| 30 | + CtxScatterCB3D, |
| 31 | + CtxScatterFuncCB, |
| 32 | + CtxScatterFuncCB3D, |
| 33 | +) |
15 | 34 | from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc |
16 | 35 |
|
17 | 36 |
|
@@ -113,7 +132,13 @@ class CustomOpTransform(OnnxTransform): |
113 | 132 | _custom_ops: Dict[str, Tuple[Any, Any]] = { |
114 | 133 | "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), |
115 | 134 | "CtxScatterFunc": (CtxScatterFunc, CtxScatter), |
| 135 | + "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), |
116 | 136 | "CtxGatherFunc": (CtxGatherFunc, CtxGather), |
| 137 | + "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), |
| 138 | + "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), |
| 139 | + "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), |
| 140 | + "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), |
| 141 | + "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), |
117 | 142 | } |
118 | 143 |
|
119 | 144 | @classmethod |
|
0 commit comments