Skip to content

Commit ea26341

Browse files
Added Continuous Batching (CB) Support for Subfunctions (#642)
Resolved compilation issues by adding CB support for subfunctions, ensuring compatibility across CB and non-CB models. Signed-off-by: abhishek-singh591 <[email protected]>
1 parent 2ef06c2 commit ea26341

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

QEfficient/base/onnx_transforms.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,26 @@
1111
import torch
1212
from onnx import ModelProto, external_data_helper, numpy_helper
1313

14-
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
14+
from QEfficient.customop.ctx_scatter_gather import (
15+
CtxGather,
16+
CtxGather3D,
17+
CtxGatherFunc,
18+
CtxGatherFunc3D,
19+
CtxScatter,
20+
CtxScatter3D,
21+
CtxScatterFunc,
22+
CtxScatterFunc3D,
23+
)
24+
from QEfficient.customop.ctx_scatter_gather_cb import (
25+
CtxGatherCB,
26+
CtxGatherCB3D,
27+
CtxGatherFuncCB,
28+
CtxGatherFuncCB3D,
29+
CtxScatterCB,
30+
CtxScatterCB3D,
31+
CtxScatterFuncCB,
32+
CtxScatterFuncCB3D,
33+
)
1534
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
1635

1736

@@ -113,7 +132,13 @@ class CustomOpTransform(OnnxTransform):
113132
_custom_ops: Dict[str, Tuple[Any, Any]] = {
114133
"CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm),
115134
"CtxScatterFunc": (CtxScatterFunc, CtxScatter),
135+
"CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
116136
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
137+
"CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
138+
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
139+
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
140+
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
141+
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
117142
}
118143

119144
@classmethod

0 commit comments

Comments
 (0)