diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp
index 6b4d08e1ab7b5f..fd5c2bdbaaac7e 100644
--- a/src/plugins/intel_cpu/src/nodes/deconv.cpp
+++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp
@@ -39,6 +39,10 @@
 #include "nodes/common/blocked_desc_creator.h"
 #include "nodes/common/dnnl_executor.h"
 #include "nodes/executors/deconv_list.hpp"
+#include "utils/arch_macros.h"
+#if defined(OPENVINO_ARCH_ARM64)
+#    include "nodes/executors/aarch64/jit_deconv3d.hpp"
+#endif
 #include "nodes/executors/executor.hpp"
 #include "nodes/node_config.h"
 #include "onednn/dnnl.h"
@@ -634,8 +638,8 @@
         return AclDeconvExecutorBuilder::customIsSupported(deconvAttrs, srcMemoryDescs, dstMemoryDescs);
     };
-    useACL = checkDesc(LayoutType::nspc) || checkDesc(LayoutType::ncsp);
-    if (useACL) {
+
+    if (checkDesc(LayoutType::nspc) || checkDesc(LayoutType::ncsp)) {
         return;
     }
 #endif
@@ -788,22 +792,18 @@ VectorDims Deconvolution::shapeInferInternal(const VectorDims& inDims, std::vect
 }
 
 void Deconvolution::execute(const dnnl::stream& strm) {
-    if (useACL) {
+    if (execPtrFactory) {
         std::vector srcMemory;
-        for (size_t i = 0; i < getOriginalInputsNumber(); i++) {
+        for (size_t i = 0; i < getOriginalInputsNumber(); i++)
             srcMemory.push_back(getSrcMemoryAtPort(i));
-        }
         std::vector dstMemory;
-        for (size_t i = 0; i < getOriginalOutputsNumber(); i++) {
+        for (size_t i = 0; i < getOriginalOutputsNumber(); i++)
             dstMemory.push_back(getDstMemoryAtPort(i));
-        }
-        // TODO: need to pass post ops data
-        execPtrDeconvACL->exec(srcMemory, dstMemory, nullptr);
+        execPtrFactory->exec(srcMemory, dstMemory, nullptr);
         return;
     }
     CPU_NODE_ASSERT(execPtr, "executor is not compiled");
-
     execPtr->exec(primArgs, strm);
 
     if (externOutShape) {
@@ -965,7 +965,9 @@ void Deconvolution::prepareParams() {
     auto* selected_pd = getSelectedPrimitiveDescriptor();
     CPU_NODE_ASSERT(selected_pd, "Preferable primitive descriptor is not set.");
-    if (useACL) {
+    // Minimal integration: always try factory path (ACL/JIT) with early-packing ctor;
+    // fall back to oneDNN path if factory does not provide an executor.
+    {
         if (isDynamicNode()) {
             initPaddingR(getParentEdgeAt(0)->getMemory().getDescPtr()->getShape(),
                          getChildEdgeAt(0)->getMemory().getDescPtr()->getShape());
@@ -979,12 +981,24 @@
             dstMemoryDescs.push_back(getChildEdgeAt(i)->getMemory().getDescWithType());
         }
 
-        execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutor(deconvAttrs,
-                                                                             srcMemoryDescs,
-                                                                             dstMemoryDescs,
-                                                                             *attr);
-        selected_pd->setImplementationType(execPtrDeconvACL->getImplType());
-        return;
+        std::vector srcMemoriesEarly;
+        for (size_t i = 0; i < getOriginalInputsNumber(); i++) {
+            srcMemoriesEarly.push_back(getSrcMemoryAtPort(i));
+        }
+
+        try {
+            auto factory = selected_pd->getExecutorFactoryAs();
+            if (factory) {
+                auto exec = factory->makeExecutorWithMem(deconvAttrs, srcMemoryDescs, dstMemoryDescs, *attr, srcMemoriesEarly);
+                if (exec) {
+                    execPtrFactory = exec;
+                    selected_pd->setImplementationType(execPtrFactory->getImplType());
+                    return;
+                }
+            }
+        } catch (...)
{ + // Fallback to oneDNN path when factory isn't applicable + } } auto inMemoryDesc = getParentEdgeAt(0)->getMemory().getDescWithType(); auto outMemoryDesc = getChildEdgeAt(0)->getMemory().getDescWithType(); @@ -1296,10 +1310,66 @@ bool Deconvolution::canFuseBias() const { } void Deconvolution::initSupportedPrimitiveDescriptors() { - if (!useACL) { - Node::initSupportedPrimitiveDescriptors(); - return; + // Prefer AArch64 JIT deconv for 5D FP16/FP32 on ARM64 regardless of ACL +#if defined(OPENVINO_ARCH_ARM64) + { + const auto rank = getInputShapeAtPort(0).getRank(); + const bool is5D = (rank == 5); + const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 && + getOriginalInputPrecisionAtPort(1) == ov::element::f16 && + getOriginalOutputPrecisionAtPort(0) == ov::element::f16; + const bool fp32_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f32 && + getOriginalInputPrecisionAtPort(1) == ov::element::f32 && + getOriginalOutputPrecisionAtPort(0) == ov::element::f32; + if (is5D && (fp16_ok || fp32_ok)) { + auto [inDims, outDims] = makeDummyInOutShape(); + auto tmpInShape = Shape(inDims); + auto tmpOutShape = Shape(outDims); + initPaddingR(tmpInShape, tmpOutShape); + + const auto& creatorsMap = BlockedDescCreator::getCommonCreators(); + NodeConfig config; + config.inConfs.resize(getParentEdges().size()); + config.outConfs.resize(getOriginalOutputsNumber()); + + auto setDesc = [&](size_t port, bool isInput) { + const auto prec = + isInput ? getOriginalInputPrecisionAtPort(port) : getOriginalOutputPrecisionAtPort(port); + const auto& shp = isInput ? getInputShapeAtPort(port) : getOutputShapeAtPort(port); + auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp); + if (isInput) + config.inConfs[port].setMemDesc(d); + else + config.outConfs[port].setMemDesc(d); + }; + setDesc(0, true); + setDesc(1, true); + for (size_t i = 2; i < getParentEdges().size(); ++i) + setDesc(i, true); + setDesc(0, false); + + std::vector srcMemoryDescs; + srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims())); + for (size_t i = 1; i < config.inConfs.size(); i++) + srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); + std::vector dstMemoryDescs; + dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims())); + for (size_t i = 1; i < config.outConfs.size(); i++) + dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); + + auto factory = + std::make_shared(deconvAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority())); + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory); + return; + } } +#endif + + Node::initSupportedPrimitiveDescriptors(); + return; auto [inDims, outDims] = makeDummyInOutShape(); auto tmpInShape = Shape(inDims); diff --git a/src/plugins/intel_cpu/src/nodes/deconv.h b/src/plugins/intel_cpu/src/nodes/deconv.h index 2dda217f287845..1612a1c4e359dd 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.h +++ b/src/plugins/intel_cpu/src/nodes/deconv.h @@ -73,7 +73,8 @@ class Deconvolution : public Node { AttrPtr initPrimitiveAttr() override; AttrPtr makePrimitiveAttr(const VectorDims& dims); std::vector getAvailableFormatsForDims(const Shape& dims) const override; - std::shared_ptr execPtrDeconvACL = nullptr; + // Factory-based executor (JIT/ACL), created via DeconvExecutorFactory + std::shared_ptr execPtrFactory = nullptr; private: using executorPtr = std::shared_ptr; @@ -101,7 +102,6 @@ class 
Deconvolution : public Node { VectorDims dnnlCompatibleWeiDims; VectorDims expectedBiasDims; - bool useACL = false; DeconvAttrs deconvAttrs; Shape inShape, outShape; diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp new file mode 100644 index 00000000000000..785058714af73e --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -0,0 +1,2344 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "nodes/executors/aarch64/jit_conv3d.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cpu_memory.h" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/core/type/float16.hpp" +#include "utils/cpu_utils.hpp" + +using namespace dnnl::impl::cpu::aarch64; + +namespace ov::intel_cpu { + +JitConv3DKernelF16::JitConv3DKernelF16() = default; + +void JitConv3DKernelF16::create_ker() { + jit_generator::create_kernel(); + ker_ = jit_kernel_cast(jit_ker()); +} + +void JitConv3DKernelF16::gen_minimal_kernel() { + using namespace Xbyak_aarch64; + const XReg reg_args = abi_param1; // x0 + const XReg reg_src = x1; // const uint16_t* src + const XReg reg_wei = x2; // const uint16_t* wei + const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats + const XReg reg_tail = x5; // size_t tail + const XReg reg_src_stride = x6; // size_t src_stride (bytes) + const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) + const XReg reg_acc = x8; // float* acc + const XReg reg_acc2 = x9; // float* acc2 (optional) + + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, ptr(reg_args, 40)); + ldr(reg_tail, ptr(reg_args, 48)); + ldr(reg_src_stride, ptr(reg_args, 56)); + ldr(reg_wei_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 88)); + ldr(reg_acc2, ptr(reg_args, 96)); + + Label Lsingle, Ldone; + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + cbz(reg_acc2, Lsingle); + b(Ldual_kx); + + L(Ldual_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_kw_cnt, ptr(reg_args, 104)); + ldr(reg_src_dx, ptr(reg_args, 112)); + ldr(reg_wei_dx, ptr(reg_args, 120)); + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; + const XReg reg_wei_blk_stride2 = x10; + ldr(reg_wei_blk_stride2, ptr(reg_args, 80)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + + // helpers to emit identical load patterns without runtime cost + auto emit_src8 = [&](const XReg& src, const XReg& src_stride) { + ld1(VReg(0).h[0], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[1], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[2], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[3], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[4], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[5], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[6], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[7], ptr(src)); + }; + auto emit_wei8_pair = [&](const XReg& wei, const XReg& wei2, const 
XReg& wei_stride, const XReg& wei_blk_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + ld1(VReg8H(2), ptr(wei2)); + add(wei, wei, wei_blk_stride); + add(wei2, wei2, wei_blk_stride); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[0], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[1], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[2], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[3], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[4], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[5], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[6], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + ld1(VReg(2).h[7], ptr(wei2)); + L(Lw_done); + }; + auto emit_wei8_single_blk = [&](const XReg& wei, const XReg& wei_stride, const XReg& wei_blk_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + add(wei, wei, wei_blk_stride); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + L(Lw_done); + }; + auto emit_wei8_single16 = [&](const XReg& wei, const XReg& wei_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + add(wei, wei, 16); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + L(Lw_done); + }; + L(Lkx_d); + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + Label Lrep_d_kx; + L(Lrep_d_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + emit_src8(reg_src, reg_src_stride); + emit_wei8_pair(reg_wei, reg_wei2, reg_wei_stride, reg_wei_blk_stride2); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d_kx); + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + 
b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + // reduce and store accumulators + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + // Dual-OC path: v20 (oc0), v21 (oc1) + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + + Label Lrep_d, Ltail_prep_d, Ltail_done_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d); + emit_src8(reg_src, reg_src_stride); + // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2 + Label Ldw_np_d, Ldw_done_d; + cmp(reg_wei_stride, 2); + b(NE, Ldw_np_d); + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, 16); + add(reg_wei2, reg_wei2, 16); + b(Ldw_done_d); + L(Ldw_np_d); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, 
reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ldw_done_d); + // MAC into v20/v21 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); + + L(Ltail_prep_d); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + // lanes 0..7 guarded by tail + { + Label Ltail_done_d; + cmp(reg_tail, 0); b(LE, Ltail_done_d); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_d); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_d); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_d); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_d); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_d); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_d); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_d); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ltail_done_d); + } + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // horizontal add and store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), 
VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + // Single-OC path + L(Lsingle); + b(Lsingle_kx); + // Single-OC with in-kernel kx loop + L(Lsingle_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + const XReg s_kw_cnt = x12; + const XReg s_src_dx = x13; + const XReg s_wei_dx = x14; + ldr(s_kw_cnt, ptr(reg_args, 104)); + ldr(s_src_dx, ptr(reg_args, 112)); + ldr(s_wei_dx, ptr(reg_args, 120)); + const XReg s_src_base = x15; + const XReg s_wei_base = x16; + const XReg s_wei_blk_stride2 = x10; + ldr(s_wei_blk_stride2, ptr(reg_args, 80)); + mov(s_src_base, reg_src); + mov(s_wei_base, reg_wei); + cbnz(s_kw_cnt, Lkx_s); + mov(s_kw_cnt, 1); + Label Lrep_s_kx; + L(Lkx_s); + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, s_src_base); + mov(reg_wei, s_wei_base); + L(Lrep_s_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // weights (vector fast path if stride==2) + emit_wei8_single_blk(reg_wei, reg_wei_stride, s_wei_blk_stride2); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s_kx); + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + { + Label Ltail_done_s_kx; + cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s_kx); + } + fmlal(VReg4S(20), VReg4H(0), 
VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(s_kw_cnt, s_kw_cnt, 1); + add(s_src_base, s_src_base, s_src_dx); + add(s_wei_base, s_wei_base, s_wei_dx); + cbnz(s_kw_cnt, Lkx_s); + // reduce/store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + Label Lrep_s, Ltail_prep_s, Ltail_done_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s); + emit_src8(reg_src, reg_src_stride); + emit_wei8_single16(reg_wei, reg_wei_stride); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + // Tail (single) + L(Ltail_prep_s); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + { + Label Ltail_done_s; + cmp(reg_tail, 0); b(LE, Ltail_done_s); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_s); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s); + } + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + + L(Ldone); + ret(); +} + +void JitConv3DKernelF16::gen_optimized_kernel() { + using namespace Xbyak_aarch64; + // abi_param1 -> args pointer + const XReg reg_args = abi_param1; + // Use call-clobbered registers to avoid saving/restoring callee-saved regs + const XReg reg_src = x1; + const XReg reg_wei = x2; + const XReg reg_wei2 = x10; + const XReg reg_reps = x3; // number of full 8-lane blocks + const XReg reg_tail = x9; // remaining channels (< 8) + const XReg reg_src_stride = x4; + const XReg reg_wei_stride = x5; + const XReg reg_acc = x6; + const XReg reg_acc2 = x11; + const XReg reg_wei3 = x19; + const XReg reg_wei4 = x20; + const XReg reg_acc3 = x21; + const XReg reg_acc4 = x22; + + // Prolog: save callee-saved we will use (x19-x29) and LR (x30) + stp(XReg(19), XReg(20), pre_ptr(sp, -16)); + stp(XReg(21), XReg(22), pre_ptr(sp, -16)); + stp(XReg(23), XReg(24), pre_ptr(sp, -16)); + stp(XReg(25), 
XReg(26), pre_ptr(sp, -16)); + stp(XReg(27), XReg(28), pre_ptr(sp, -16)); + stp(XReg(29), XReg(30), pre_ptr(sp, -16)); + + // Load args + ldr(reg_src, ptr(reg_args)); // src + ldr(reg_wei, ptr(reg_args, 8)); // wei + ldr(reg_wei2, ptr(reg_args, 16)); // wei2 (optional) + ldr(reg_wei3, ptr(reg_args, 24)); // wei3 (optional) + ldr(reg_wei4, ptr(reg_args, 32)); // wei4 (optional) + ldr(reg_reps, ptr(reg_args, 40)); // repeats + ldr(reg_tail, ptr(reg_args, 48)); // tail (<= 8) + ldr(reg_src_stride, ptr(reg_args, 56)); // src_stride bytes + ldr(reg_wei_stride, ptr(reg_args, 64)); // wei_stride bytes + const XReg reg_src_blk_stride = x7; + const XReg reg_wei_blk_stride = x8; + ldr(reg_src_blk_stride, ptr(reg_args, 72)); // src_blk_stride bytes + ldr(reg_wei_blk_stride, ptr(reg_args, 80)); // wei_blk_stride bytes + ldr(reg_acc, ptr(reg_args, 88)); // acc (float*) + ldr(reg_acc2, ptr(reg_args, 96)); // acc2 (float* or 0) + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_kw_cnt, ptr(reg_args, 104)); // kw count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dx, ptr(reg_args, 112)); // src dx step in bytes (x dimension) + ldr(reg_wei_dx, ptr(reg_args, 120)); // wei dx step in bytes (x dimension) + ldr(reg_acc3, ptr(reg_args, 128)); // acc3 (float* or 0) + ldr(reg_acc4, ptr(reg_args, 136)); // acc4 (float* or 0) + const XReg reg_kh_cnt = x26; + const XReg reg_src_dy = x27; + const XReg reg_wei_dy = x28; + ldr(reg_kh_cnt, ptr(reg_args, 144)); // kh count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) + ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) + + // Optionally force single-ky iteration for stability on certain platforms + if (m_force_single_kh_) { + mov(reg_kh_cnt, 1); + } + eor(reg_acc4, reg_acc4, reg_acc4); + + Label Lsingle, Lend_all; + // If acc4 != 0, run quad-OC; else if acc2 != 0, run dual-OC; else single-OC. 
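For reference, the ldr(..., ptr(reg_args, offset)) sequence above reads a flat, 8-byte-strided argument block. A minimal host-side sketch of that layout follows; the struct and field names are illustrative (the real declaration presumably lives in jit_conv3d.hpp, which is not shown here), while the offsets and meanings are taken directly from the loads and comments above:

#include <cstddef>
#include <cstdint>

// Illustrative only: names are assumptions; offsets mirror the ldr() calls above.
struct jit_conv3d_call_args_sketch {
    const uint16_t* src;        // +0   fp16 source, strided across channels
    const uint16_t* wei;        // +8   weights for oc0
    const uint16_t* wei2;       // +16  weights for oc1 (0 -> single-OC path)
    const uint16_t* wei3;       // +24  weights for oc2 (quad-OC path)
    const uint16_t* wei4;       // +32  weights for oc3 (0 -> not quad)
    size_t repeats;             // +40  number of full 8-channel blocks
    size_t tail;                // +48  remaining channels (< 8)
    size_t src_stride;          // +56  byte stride between channels in src
    size_t wei_stride;          // +64  byte stride between channels in wei (2 -> packed fast path)
    size_t src_blk_stride;      // +72  byte step to the next 8-channel src block
    size_t wei_blk_stride;      // +80  byte step to the next 8-channel wei block
    float* acc;                 // +88  fp32 accumulator for oc0 (read-modify-write)
    float* acc2;                // +96  fp32 accumulator for oc1 (0 -> single-OC)
    size_t kw_cnt;              // +104 kernel-width trip count (0 -> single position)
    size_t src_dx;              // +112 src byte step per kx
    size_t wei_dx;              // +120 wei byte step per kx
    float* acc3;                // +128 fp32 accumulator for oc2
    float* acc4;                // +136 fp32 accumulator for oc3 (0 -> not quad)
    size_t kh_cnt;              // +144 kernel-height trip count (optimized kernel only)
    size_t src_dy;              // +152 src byte step per ky
    size_t wei_dy;              // +160 wei byte step per ky
};
static_assert(offsetof(jit_conv3d_call_args_sketch, acc) == 88, "layout above assumes an LP64 AArch64 ABI");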
+ Label Lq_entry; + cbnz(reg_acc4, Lq_entry); + cbz(reg_acc2, Lsingle); + + // ---------------- Quad-OC path ---------------- + L(Lq_entry); + { + // Zero v20..v23 + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + eor(VReg16B(22), VReg16B(22), VReg16B(22)); + eor(VReg16B(23), VReg16B(23), VReg16B(23)); + + Label Lq_ky_loop, Lq_kx_loop, Lq_loop, Lq_after_loop, Lq_after_fill, Lq_after_kx, Lq_np; + // Packed fast path if wei_stride == 2 + cmp(reg_wei_stride, 2); + b(NE, Lq_np); + + // Save repeats and bases for kw loop + const XReg q_reps_init = x15; + const XReg q_src_base = x16; + const XReg q_wei_base = x17; + // avoid x18 on Apple; use callee-saved and restore at epilog + const XReg q_wei2_base = x23; + const XReg q_wei3_base = x24; + const XReg q_wei4_base = x25; + const XReg q_kw_init = x28; + const XReg q_kh_work = x29; + mov(q_reps_init, reg_reps); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + mov(q_wei3_base, reg_wei3); + mov(q_wei4_base, reg_wei4); + mov(q_kw_init, reg_kw_cnt); + mov(q_kh_work, reg_kh_cnt); + + // ky loop + L(Lq_ky_loop); + // kx loop entry + L(Lq_kx_loop); + mov(reg_reps, q_reps_init); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + mov(reg_wei3, q_wei3_base); + mov(reg_wei4, q_wei4_base); + mov(reg_kw_cnt, q_kw_init); + + // Main repeats loop + L(Lq_loop); + cmp(reg_reps, 0); + b(LE, Lq_after_loop); + // Load 8 src lanes + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load 4 weight vectors + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + ld1(VReg8H(3), ptr(reg_wei3)); + ld1(VReg8H(4), ptr(reg_wei4)); + // Accumulate into v20..v23 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + // Advance block pointers (src already advanced by 8*src_stride during lane loads) + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + add(reg_wei3, reg_wei3, reg_wei_blk_stride); + add(reg_wei4, reg_wei4, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Lq_loop); + + // Tail <=8 + L(Lq_after_loop); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + eor(VReg16B(3), VReg16B(3), VReg16B(3)); + eor(VReg16B(4), VReg16B(4), VReg16B(4)); + cmp(reg_tail, 0); + b(LE, Lq_after_fill); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + ld1(VReg(3).h[0], ptr(reg_wei3)); + ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); 
+ add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Lq_after_fill); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + ld1(VReg(3).h[1], ptr(reg_wei3)); + ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Lq_after_fill); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + ld1(VReg(3).h[2], ptr(reg_wei3)); + ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Lq_after_fill); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + ld1(VReg(3).h[3], ptr(reg_wei3)); + ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Lq_after_fill); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + ld1(VReg(3).h[4], ptr(reg_wei3)); + ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Lq_after_fill); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + ld1(VReg(3).h[5], ptr(reg_wei3)); + ld1(VReg(4).h[5], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Lq_after_fill); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + ld1(VReg(3).h[6], ptr(reg_wei3)); + ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Lq_after_fill); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(3).h[7], ptr(reg_wei3)); + ld1(VReg(4).h[7], ptr(reg_wei4)); + L(Lq_after_fill); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + add(q_wei3_base, q_wei3_base, reg_wei_dx); + add(q_wei4_base, q_wei4_base, reg_wei_dx); + subs(reg_kw_cnt, reg_kw_cnt, 1); + b(GT, Lq_kx_loop); + + // reduce/store 4 accumulators + 
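Right before the faddp reduction that follows, v20..v23 each hold four fp32 partial sums for one output channel. A hedged scalar reference for what this packed quad-OC fast path accumulates into acc..acc4 (conv3d_quad_ref and all of its parameters are illustrative and not part of the patch; fully packed weights are assumed, and the fp16-to-fp32 widening mirrors fmlal/fmlal2):

#include <cstddef>
#include "openvino/core/type/float16.hpp"

// Reference model: channels are walked as `repeats` full 8-lane blocks plus a tail,
// while the kx/ky loops step the base pointers by src_dx/wei_dx and src_dy/wei_dy,
// mirroring the register updates in the JIT loop above.
static void conv3d_quad_ref(const ov::float16* src, const ov::float16* wei[4],
                            size_t channels, size_t src_stride_elems,
                            size_t kw, size_t src_dx_elems, size_t wei_dx_elems,
                            size_t kh, size_t src_dy_elems, size_t wei_dy_elems,
                            float* acc[4]) {
    for (size_t oc = 0; oc < 4; ++oc) {
        float sum = 0.0f;
        for (size_t ky = 0; ky < kh; ++ky) {
            for (size_t kx = 0; kx < kw; ++kx) {
                const ov::float16* s = src + ky * src_dy_elems + kx * src_dx_elems;
                const ov::float16* w = wei[oc] + ky * wei_dy_elems + kx * wei_dx_elems;
                for (size_t c = 0; c < channels; ++c) {
                    // fmlal/fmlal2 widen fp16 lanes to fp32 before multiply-accumulate
                    sum += static_cast<float>(s[c * src_stride_elems]) * static_cast<float>(w[c]);
                }
            }
        }
        *acc[oc] += sum;  // matches the ldr/fadd/str read-modify-write on acc..acc4
    }
}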
faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); + faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); + faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + add(q_src_base, q_src_base, reg_src_dy); + add(q_wei_base, q_wei_base, reg_wei_dy); + add(q_wei2_base, q_wei2_base, reg_wei_dy); + add(q_wei3_base, q_wei3_base, reg_wei_dy); + add(q_wei4_base, q_wei4_base, reg_wei_dy); + subs(q_kh_work, q_kh_work, 1); + b(GT, Lq_ky_loop); + + // After ky loop: reduce/store + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); + fadd(SReg(2), SReg(2), SReg(22)); + str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); + fadd(SReg(3), SReg(3), SReg(23)); + str(SReg(3), ptr(reg_acc4)); + b(Lend_all); + + // Not-packed fallback for quad: do lane-wise loads for 4 outputs + L(Lq_np); + // Zero v20..v23 + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + eor(VReg16B(22), VReg16B(22), VReg16B(22)); + eor(VReg16B(23), VReg16B(23), VReg16B(23)); + // single block (not looped here) — rely on host collapsing work + // lanes 0..7 + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + ld1(VReg(3).h[0], ptr(reg_wei3)); + ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + ld1(VReg(3).h[1], ptr(reg_wei3)); + ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + ld1(VReg(3).h[2], ptr(reg_wei3)); + ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + ld1(VReg(3).h[3], ptr(reg_wei3)); + ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + ld1(VReg(3).h[4], ptr(reg_wei3)); + ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + ld1(VReg(3).h[5], ptr(reg_wei3)); + ld1(VReg(4).h[5], 
ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + ld1(VReg(3).h[6], ptr(reg_wei3)); + ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(3).h[7], ptr(reg_wei3)); + ld1(VReg(4).h[7], ptr(reg_wei4)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + // reduce/store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); + faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); + faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); + fadd(SReg(2), SReg(2), SReg(22)); + str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); + fadd(SReg(3), SReg(3), SReg(23)); + str(SReg(3), ptr(reg_acc4)); + b(Lend_all); + } + // ---------------- Dual-OC path ---------------- + { + // Zero FP32 accumulators v20.4s and v21.4s + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + Label Lnp_d, Ld_after_fill, Ld_after_loop, Ld_kx_loop, Ld_after_kx; + // If not packed, fallback to non-packed path below (no kw-loop optimization) + cmp(reg_wei_stride, 2); + b(NE, Lnp_d); + // Save initial repeats and bases + const XReg reg_reps_init = x15; + const XReg reg_src_base = x16; + const XReg reg_wei_base = x17; + const XReg reg_wei2_base = x18; + mov(reg_reps_init, reg_reps); + mov(reg_src_base, reg_src); + mov(reg_wei_base, reg_wei); + mov(reg_wei2_base, reg_wei2); + const XReg reg_kw_init = x28; + const XReg reg_kh_work = x29; + mov(reg_kw_init, reg_kw_cnt); + mov(reg_kh_work, reg_kh_cnt); + // ky-loop wrapper around kx-loop (packed fast path) + Label Ld_ky_loop; + L(Ld_ky_loop); + // Reset per-ky state and restore kw counter + mov(reg_reps, reg_reps_init); + mov(reg_src, reg_src_base); + mov(reg_wei, reg_wei_base); + mov(reg_wei2, reg_wei2_base); + mov(reg_kw_cnt, reg_kw_init); + // kw-loop (only if kw_cnt > 0); if kw_cnt==0 -> process a single position + L(Ld_kx_loop); + // Main loop over full 8-lane channel blocks + Label Ld_loop; + L(Ld_loop); + cmp(reg_reps, 0); + b(LE, Ld_after_loop); + // Load 8 src half lanes with strides between channels + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], 
ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load 8 half weights for oc0 and oc1 as vectors + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + // MAC for oc0 → v20, oc1 → v21 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance pointers to next block (src already advanced by 8*src_stride) + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Ld_loop); + // Tail processing (<=8) + L(Ld_after_loop); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ld_after_fill); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ld_after_fill); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ld_after_fill); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ld_after_fill); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ld_after_fill); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ld_after_fill); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ld_after_fill); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ld_after_fill); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ld_after_fill); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance base pointers for next kx + add(reg_src_base, reg_src_base, reg_src_dx); + add(reg_wei_base, reg_wei_base, reg_wei_dx); + add(reg_wei2_base, reg_wei2_base, reg_wei_dx); + // Decrement kw_cnt and loop + subs(reg_kw_cnt, reg_kw_cnt, 1); + 
b(GT, Ld_kx_loop); + // After kx loop for one ky, advance to next ky if any + add(reg_src_base, reg_src_base, reg_src_dy); + add(reg_wei_base, reg_wei_base, reg_wei_dy); + add(reg_wei2_base, reg_wei2_base, reg_wei_dy); + subs(reg_kh_work, reg_kh_work, 1); + b(GT, Ld_ky_loop); + // Fallthrough to store + b(Ld_after_kx); + + // Not-packed path: pairwise lane loads for both wei/wei2 (no kw-loop optimization) + L(Lnp_d); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + // Accumulate + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance to next block + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + sub(reg_reps, reg_reps, 1); + b(Ld_loop); + + // Reduce and store both accumulators + L(Ld_after_kx); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Lend_all); + } + + // ---------------- Single-OC path (default) ---------------- + L(Lsingle); + // Zero FP32 accumulator v20.4s + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + + Label Lloop2, Lloop, Lafter_loop, Lafter_fill, Lnp1, Lend1, Lnp2, Lend2, Lnot_packed, Lkx_loop_s, Lafter_kx_s, + Ls_loop, Ls_after_loop, Ls_after_fill; + + // Unrolled-by-2 loop over full 8-lane channel blocks (default single position) + b(Lloop); + L(Lloop2); + cmp(reg_reps, 2); + b(LT, Lloop); + + // First block (packed fast path if available) + cmp(reg_wei_stride, 2); + b(NE, Lnp1); + // packed: src 
lanes (8), then vector-load weights + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg8H(1), ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lend1); + L(Lnp1); + // not packed: pairwise lanes + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lend1); + + // Second block + cmp(reg_wei_stride, 2); + b(NE, Lnp2); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg8H(1), ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lend2); + L(Lnp2); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], 
ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lend2); + + sub(reg_reps, reg_reps, 2); + b(Lloop2); + + // Single-block loop for remaining one block + L(Lloop); + cmp(reg_reps, 0); + b(EQ, Lafter_loop); + + // Prepare containers for src/wei half vectors in v0/v1 (fully overwritten by loads) + + // Choose packed-weight fast path if wei_stride == 2 bytes + cmp(reg_wei_stride, 2); + b(NE, Lnot_packed); + + // Packed path: fill src lanes, vector-load wei + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // load 8 half wei as one vector + ld1(VReg8H(1), ptr(reg_wei)); + // Multiply-accumulate (fp16 widen to fp32): lower and upper halves + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // Advance pointers to next block + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Lloop); + + // Not-packed path: fill src/wei lanes pairwise and accumulate + L(Lnot_packed); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + // Widening MAC for both halves + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // Advance pointers to next block (src already 
advanced by 8*src_stride) + add(reg_wei, reg_wei, reg_wei_stride); + sub(reg_reps, reg_reps, 1); + b(Lloop); + + L(Lafter_loop); + + // Tail processing (<= 8) + // Prepare containers for src/wei half vectors in v0/v1 + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + // lane 0 + cmp(reg_tail, 0); + b(LE, Lafter_fill); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 1 + cmp(reg_tail, 1); + b(LE, Lafter_fill); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 2 + cmp(reg_tail, 2); + b(LE, Lafter_fill); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 3 + cmp(reg_tail, 3); + b(LE, Lafter_fill); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 4 + cmp(reg_tail, 4); + b(LE, Lafter_fill); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 5 + cmp(reg_tail, 5); + b(LE, Lafter_fill); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 6 + cmp(reg_tail, 6); + b(LE, Lafter_fill); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 7 + cmp(reg_tail, 7); + b(LE, Lafter_fill); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + + L(Lafter_fill); + // Accumulate tail using widening fp16 MAC + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + + // Horizontal add v20 + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + + // Load *acc, add, store back + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + + L(Lend_all); + // Epilog: restore callee-saved + ldp(XReg(29), XReg(30), post_ptr(sp, 16)); + ldp(XReg(27), XReg(28), post_ptr(sp, 16)); + ldp(XReg(25), XReg(26), post_ptr(sp, 16)); + ldp(XReg(23), XReg(24), post_ptr(sp, 16)); + ldp(XReg(21), XReg(22), post_ptr(sp, 16)); + ldp(XReg(19), XReg(20), post_ptr(sp, 16)); + ret(); +} + +void JitConv3DKernelF16::generate() { + // Keep body small for clang-tidy readability-function-size + gen_minimal_kernel(); + gen_optimized_kernel(); +} + +void JitConv3DKernelF32::create_ker() { + jit_generator::create_kernel(); + ker_ = reinterpret_cast(const_cast(jit_ker())); +} + +void JitConv3DKernelF32::generate() { + using namespace Xbyak_aarch64; + + const XReg reg_args = abi_param1; + + const XReg reg_src = x1; + const XReg reg_wei = x2; + const XReg reg_wei2 = x3; + const XReg reg_reps = x4; + const XReg reg_tail = x5; + const XReg reg_src_stride = x6; + const XReg reg_wei_stride = x7; + const XReg reg_src_blk_stride = x8; + const XReg reg_wei_blk_stride = x9; + const XReg reg_acc = x10; + const XReg reg_acc2 = x11; + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, 
ptr(reg_args, 24)); + ldr(reg_tail, ptr(reg_args, 32)); + ldr(reg_src_stride, ptr(reg_args, 40)); + ldr(reg_wei_stride, ptr(reg_args, 48)); + ldr(reg_src_blk_stride, ptr(reg_args, 56)); + ldr(reg_wei_blk_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 72)); + ldr(reg_acc2, ptr(reg_args, 80)); + ldr(reg_kw_cnt, ptr(reg_args, 88)); + ldr(reg_src_dx, ptr(reg_args, 96)); + ldr(reg_wei_dx, ptr(reg_args, 104)); + + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; + + Label Lsingle, Ldone; + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + + cbz(reg_acc2, Lsingle); + b(Ldual_kx); + + L(Ldual_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + + L(Lkx_d); + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + + Label Lrep_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + Label Lw_np_d, Lw_done_d; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_d); + ld1(VReg4S(1), ptr(reg_wei)); + ld1(VReg4S(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + b(Lw_done_d); + L(Lw_np_d); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + ld1(VReg(2).s[3], ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + L(Lw_done_d); + add(reg_src, reg_src, reg_src_stride); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); + + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[3], ptr(reg_src)); + 
ld1(VReg(1).s[3], ptr(reg_wei)); + ld1(VReg(2).s[3], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + L(Lsingle); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + cbnz(reg_kw_cnt, Lsingle_kx); + mov(reg_kw_cnt, 1); + + L(Lsingle_kx); + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + + Label Lrep_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + Label Lw_np_s, Lw_done_s; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_s); + ld1(VReg4S(1), ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lw_done_s); + L(Lw_np_s); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lw_done_s); + add(reg_src, reg_src, reg_src_stride); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[3], ptr(reg_src)); + ld1(VReg(1).s[3], ptr(reg_wei)); + L(Ltail_done_s_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lsingle_kx); + + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + b(Ldone); + + L(Ldone); + ret(); +} + + +[[maybe_unused]] static inline auto ptr_f16(const MemoryPtr& mem) -> const uint16_t* { + return reinterpret_cast(mem->getData()); +} +[[maybe_unused]] static inline auto ptr_f16(MemoryPtr& mem) -> uint16_t* { + return reinterpret_cast(mem->getData()); +} + 
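Note: the FP16 and FP32 kernels above share one calling contract: accumulate into *acc (and optionally *acc2) a dot product over repeats * LANES + tail channels (LANES is 8 for FP16, 4 for FP32), repeated kw_cnt times (0 is treated as 1), advancing the src/wei bases by src_dx/wei_dx bytes per kx tap. The scalar sketch below is one reading of that contract for the FP32 variant, assuming the callers' convention that wei_blk_stride == 4 * wei_stride so packed and strided weights index identically; it is illustrative only, not part of the patch, but may help when checking the JIT output against a reference.

// Scalar reference for the FP32 kernel's calling contract (illustrative sketch).
// Assumes jit_conv3d_f32_call_args from jit_conv3d.hpp and <cstdint>/<cstddef> are in scope.
static void conv3d_f32_dot_reference(const jit_conv3d_f32_call_args& a) {
    const size_t taps = a.kw_cnt ? a.kw_cnt : 1;     // kw_cnt == 0 means a single tap
    const size_t channels = a.repeats * 4 + a.tail;  // 4 fp32 lanes per 128-bit block
    const auto* src0 = reinterpret_cast<const uint8_t*>(a.src);
    const auto* wei0 = reinterpret_cast<const uint8_t*>(a.wei);
    const auto* wei1 = reinterpret_cast<const uint8_t*>(a.wei2);
    float acc = 0.0F;
    float acc2 = 0.0F;
    for (size_t kx = 0; kx < taps; ++kx) {
        for (size_t c = 0; c < channels; ++c) {
            const float s = *reinterpret_cast<const float*>(src0 + kx * a.src_dx + c * a.src_stride);
            acc += s * *reinterpret_cast<const float*>(wei0 + kx * a.wei_dx + c * a.wei_stride);
            if (a.acc2 != nullptr) {
                acc2 += s * *reinterpret_cast<const float*>(wei1 + kx * a.wei_dx + c * a.wei_stride);
            }
        }
    }
    *a.acc += acc;
    if (a.acc2 != nullptr) {
        *a.acc2 += acc2;
    }
}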
+JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, + const MemoryArgs& memory, + const ExecutorContext::CPtr& context) + : m_attrs(attrs), + m_memory(memory) { + (void)context; + m_threadsNum = static_cast(parallel_get_max_threads()); + // Decide precision from src tensor + ov::element::Type sp = ov::element::dynamic; + auto it = memory.find(ARG_SRC); + if (it != memory.end() && it->second && it->second->getDescPtr()) { + sp = it->second->getDescPtr()->getPrecision(); + } + m_is_fp32 = (sp == ov::element::f32); + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + // default to fp16 + m_ip_kernel = std::make_unique(); + m_ip_kernel->create_ker(); + } + + // Extract optional PReLU post-op (per-tensor or per-channel), but keep disabled by default. + for (const auto& po : m_attrs.postOps) { + if (const auto* const ss = std::any_cast(&po)) { + if (ss->type() == ScaleShiftPostOp::Type::prelu) { + m_has_prelu = true; + m_prelu_slopes = ss->scales(); + break; + } + } + } + + // Early weight packing (only if shapes are static). Kept inside executor per policy. + prepare_weights_early(m_memory); +} + +bool JitConv3DExecutor::supports(const ConvConfig& cfg) { + // Require 5D NCDHW, FP16 or FP32 src/wei/dst, group=1, no dilation, stride 1 or 2 + if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) + return false; + if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) + return false; + + const auto& s = cfg.descs.at(ARG_SRC)->getShape(); + const auto& w = cfg.descs.at(ARG_WEI)->getShape(); + const auto& d = cfg.descs.at(ARG_DST)->getShape(); + if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) + return false; + + const auto sp = cfg.descs.at(ARG_SRC)->getPrecision(); + const auto wp = cfg.descs.at(ARG_WEI)->getPrecision(); + const auto dp = cfg.descs.at(ARG_DST)->getPrecision(); + const bool f16_ok = (sp == ov::element::f16 && wp == ov::element::f16 && dp == ov::element::f16); + const bool f32_ok = (sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32); + if (!(f16_ok || f32_ok)) + return false; + + // group == 1: weights rank==5 (no groups) + if (w.getRank() != 5) + return false; + + // dilation == 0 + for (auto v : cfg.attrs.dilation) { + if (v != 0) + return false; + } + // stride in [1,2] if set + for (auto v : cfg.attrs.stride) { + if (!(v == 1 || v == 2)) + return false; + } + return true; +} + +void JitConv3DExecutor::prepare_weights_early(const MemoryArgs& memory) { + // Guard: only when shapes are static + auto src_it = memory.find(ARG_SRC); + auto wei_it = memory.find(ARG_WEI); + if (src_it == memory.end() || wei_it == memory.end() || !src_it->second || !wei_it->second) + return; + const auto& s = src_it->second->getDescPtr()->getShape(); + const auto& w = wei_it->second->getDescPtr()->getShape(); + if (!s.isStatic() || !w.isStatic()) + return; + if (m_is_fp32) { + ensure_weights_packed_f32(memory); + } else { + ensure_weights_packed(memory); + } +} + +void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { + // NCDHW + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + auto dst = memory.at(ARG_DST); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims(); + + const size_t N = srcDims[0]; + const size_t C = srcDims[1]; + const size_t ID = srcDims[2], IH = 
srcDims[3], IW = srcDims[4]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1; + const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1; + const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1; + + const size_t PD0 = m_attrs.paddingL.size() > 0 ? static_cast(m_attrs.paddingL[0]) : 0; + const size_t PH0 = m_attrs.paddingL.size() > 1 ? static_cast(m_attrs.paddingL[1]) : 0; + const size_t PW0 = m_attrs.paddingL.size() > 2 ? static_cast(m_attrs.paddingL[2]) : 0; + + const uint16_t* src_p = ptr_f16(src); + uint16_t* dst_p = ptr_f16(dst); + + auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) -> size_t { + return (((n * C + c) * ID + z) * IH + y) * IW + x; + }; + auto index_dst = [&](size_t n, size_t oc, size_t z, size_t y, size_t x) -> size_t { + return (((n * OC + oc) * OD + z) * OH + y) * OW + x; + }; + + // Prepare packed weights once + ensure_weights_packed(memory); + + auto worker = [&](size_t n, size_t oc_quad, size_t od) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = oc0 + 1; + const size_t oc2 = oc0 + 2; + const size_t oc3 = oc0 + 3; + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + const int64_t iz0 = static_cast(od * SD) - static_cast(PD0); + for (size_t oh = 0; oh < OH; ++oh) { + const int64_t iy0 = static_cast(oh * SH) - static_cast(PH0); + for (size_t ow = 0; ow < OW; ++ow) { + const int64_t ix0 = static_cast(ow * SW) - static_cast(PW0); + + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + const size_t src_c_stride_elems = ID * IH * IW; + + if (SD == 1 && SH == 1 && SW == 1) { + const ptrdiff_t kz_lo = std::max(0, -iz0); + const ptrdiff_t kz_hi = + std::min(static_cast(KD) - 1, static_cast(ID) - 1 - iz0); + const ptrdiff_t ky_lo = std::max(0, -iy0); + const ptrdiff_t ky_hi = + std::min(static_cast(KH) - 1, static_cast(IH) - 1 - iy0); + const ptrdiff_t kx_lo = std::max(0, -ix0); + const ptrdiff_t kx_hi = + std::min(static_cast(KW) - 1, static_cast(IW) - 1 - ix0); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(iz0 + kz); + // iy/ix for ky_lo/kx_lo not needed; use iy2/ix2 per ky below + + // Loop over ky in host; kernel handles kx via kw_cnt (always packed) + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy2 = static_cast(iy0 + ky); + const size_t ix2 = static_cast(ix0 + kx_lo); + const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); + auto run_pair = [&](float* acc, float* acc2, size_t base0, size_t base1) { + jit_conv3d_call_args aa{}; + aa.src = src_p + s_base2; + aa.src_stride = src_c_stride_elems * sizeof(uint16_t); + aa.src_blk_stride = aa.src_stride * 8; + aa.acc = acc; + aa.acc2 = acc2; + aa.repeats = C / 8; + aa.tail = C % 8; + aa.kw_cnt = kw_count; + aa.src_dx = sizeof(uint16_t); + aa.wei = m_wei_packed.data() + base0; + if (acc2) aa.wei2 = m_wei_packed.data() + base1; + aa.wei_stride = sizeof(uint16_t); + aa.wei_blk_stride = aa.wei_stride * 8; + aa.wei_dx = m_padded_C * sizeof(uint16_t); + (*m_ip_kernel)(&aa); + }; + const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + const size_t base1 = has_oc1 ? 
(((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C : 0; + run_pair(&acc0, has_oc1 ? &acc1 : nullptr, base0, base1); + if (has_oc2) { + const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + const size_t base3 = has_oc3 ? (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C : 0; + run_pair(&acc2, has_oc3 ? &acc3 : nullptr, base2, base3); + } + } + } + } + } else { + for (size_t kz = 0; kz < KD; ++kz) { + const int64_t iz = iz0 + static_cast(kz); + if (iz < 0 || iz >= static_cast(ID)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const int64_t iy = iy0 + static_cast(ky); + if (iy < 0 || iy >= static_cast(IH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const int64_t ix = ix0 + static_cast(kx); + if (ix < 0 || ix >= static_cast(IW)) + continue; + const size_t s_base0 = index_src(n, + 0, + static_cast(iz), + static_cast(iy), + static_cast(ix)); + auto run_pair2 = [&](float* acc, float* acc2, size_t base0, size_t base1) { + jit_conv3d_call_args aa{}; + aa.src = src_p + s_base0; + aa.src_stride = src_c_stride_elems * sizeof(uint16_t); + aa.src_blk_stride = aa.src_stride * 8; + aa.acc = acc; + aa.acc2 = acc2; + const size_t pack_base0 = base0; + aa.wei = m_wei_packed.data() + pack_base0; + aa.repeats = C / 8; + aa.tail = C % 8; + aa.wei_stride = sizeof(uint16_t); + aa.wei_blk_stride = aa.wei_stride * 8; + if (acc2) { + const size_t pack_base1 = base1; + aa.wei2 = m_wei_packed.data() + pack_base1; + } + (*m_ip_kernel)(&aa); + }; + const size_t b0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t b1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C : 0; + run_pair2(&acc0, has_oc1 ? &acc1 : nullptr, b0, b1); + if (has_oc2) { + const size_t b2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t b3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C : 0; + run_pair2(&acc2, has_oc3 ? 
&acc3 : nullptr, b2, b3); + } + } + } + } + } + // No optional post-ops in product mode + + dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits(); + if (has_oc1) + dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits(); + if (has_oc2) + dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits(); + if (has_oc3) + dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits(); + } + } + }; + + ov::parallel_for3d(N, (OC + 3) / 4, OD, worker); +} + +void JitConv3DExecutor::execute(const MemoryArgs& memory) { + if (m_is_fp32) + run_naive_fp32(memory); + else + run_naive_fp16(memory); +} + +void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) { + if (m_wei_packed_ready_f32) + return; + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + if (srcDims.size() != 5 || weiDims.size() != 5) + return; + const size_t C = srcDims[1]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_C_f32 = (C + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_C_f32; + m_wei_packed_f32.assign(total, 0.0f); + const auto* wsrc = reinterpret_cast(wei->getData()); + + auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + const size_t blk = c / 4; + const size_t lane = c % 4; + return base + blk * 4 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t c = 0; c < C; ++c) { + m_wei_packed_f32[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f32 = true; + // no global cache store +} + +void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + auto dst = memory.at(ARG_DST); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims(); + + const size_t N = srcDims[0]; + const size_t C = srcDims[1]; + const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1; + const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1; + const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1; + + const ptrdiff_t PD0 = m_attrs.paddingL.size() > 0 ? m_attrs.paddingL[0] : 0; + const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0; + const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ? 
m_attrs.paddingL[2] : 0; + + const auto* src_p = reinterpret_cast(src->getData()); + const auto* wei_p = reinterpret_cast(wei->getData()); + auto* dst_p = reinterpret_cast(dst->getData()); + + auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * C + c) * ID + z) * IH + y) * IW + x; + }; + auto index_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) { + return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; + }; + + const size_t src_c_stride_elems = ID * IH * IW; // elements between channels + const size_t wei_c_stride_elems = KD * KH * KW; // elements between weight channels + + ensure_weights_packed_f32(memory); + + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = std::min(oc0 + 1, OC); + const size_t oc2 = std::min(oc0 + 2, OC); + const size_t oc3 = std::min(oc0 + 3, OC); + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + + for (size_t od = 0; od < OD; ++od) { + const ptrdiff_t iz0 = static_cast(od) * static_cast(SD) - PD0; + for (size_t oh = 0; oh < OH; ++oh) { + const ptrdiff_t iy0 = static_cast(oh) * static_cast(SH) - PH0; + for (size_t ow = 0; ow < OW; ++ow) { + const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; + + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; + + if (SD == 1 && SH == 1 && SW == 1) { + const ptrdiff_t kz_lo = std::max(0, -iz0); + const ptrdiff_t kz_hi = + std::min(static_cast(KD) - 1, static_cast(ID) - 1 - iz0); + const ptrdiff_t ky_lo = std::max(0, -iy0); + const ptrdiff_t ky_hi = + std::min(static_cast(KH) - 1, static_cast(IH) - 1 - iy0); + const ptrdiff_t kx_lo = std::max(0, -ix0); + const ptrdiff_t kx_hi = + std::min(static_cast(KW) - 1, static_cast(IW) - 1 - ix0); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(iz0 + kz); + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy = static_cast(iy0 + ky); + const size_t ix = static_cast(ix0 + kx_lo); + const size_t s_base = index_src(n, 0, iz, iy, ix); + + if (m_wei_packed_ready_f32) { + // pair 0 + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? 
&acc3 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + } else { + // generic path: non-packed weights + const size_t w0 = index_wei(oc0, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w1 = has_oc1 ? index_wei(oc1, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + a.wei = wei_p + w0; + if (has_oc1) + a.wei2 = wei_p + w1; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = sizeof(float); + (*m_ip_kernel_f32)(&a); + if (has_oc2) { + const size_t w2 = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a2{a}; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? &acc3 : nullptr; + a2.wei = wei_p + w2; + if (has_oc3) a2.wei2 = wei_p + w3; + (*m_ip_kernel_f32)(&a2); + } + } + } + } + } + } else { + // generic spatial stride path (host loops over all taps) + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t iz = iz0 + static_cast(kz); + if (iz < 0 || iz >= static_cast(ID)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy = iy0 + static_cast(ky); + if (iy < 0 || iy >= static_cast(IH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix = ix0 + static_cast(kx); + if (ix < 0 || ix >= static_cast(IW)) + continue; + const size_t s_base = index_src(n, + 0, + static_cast(iz), + static_cast(iy), + static_cast(ix)); + auto run_pair_f32 = [&](float* acc, float* acc2, const float* w0, const float* w1) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = acc; + a.acc2 = acc2; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = w0; + if (w1) a.wei2 = w1; + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + }; + const size_t base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + const size_t base1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32 : 0; + run_pair_f32(&acc0, has_oc1 ? &acc1 : nullptr, + m_wei_packed_f32.data() + base0, + has_oc1 ? m_wei_packed_f32.data() + base1 : nullptr); + if (has_oc2) { + const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + const size_t base3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32 : 0; + run_pair_f32(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f32.data() + base2, + has_oc3 ? 
m_wei_packed_f32.data() + base3 : nullptr); + } + } + } + } + } + + // Store FP32 accumulators directly to FP32 destination + dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; + if (has_oc1) + dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; + if (has_oc2) + dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; + if (has_oc3) + dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; + } + } + } + }); +} + +} // namespace ov::intel_cpu +void ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& memory) { + if (m_wei_packed_ready) + return; + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + if (srcDims.size() != 5 || weiDims.size() != 5) + return; + const size_t C = srcDims[1]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_C = (C + 7) / 8 * 8; + // Pack layout: [OC, KD, KH, KW, Ct] + const size_t total = OC * KD * KH * KW * m_padded_C; + m_wei_packed.assign(total, static_cast(0)); + const auto* wsrc = reinterpret_cast(wei->getData()); + + auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t blk = c / 8; + const size_t lane = c % 8; + return base + blk * 8 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t c = 0; c < C; ++c) { + m_wei_packed[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready = true; + // no global cache store +} diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp new file mode 100644 index 00000000000000..d9100f7d780e90 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -0,0 +1,153 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "nodes/executors/convolution_config.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "onednn/iml_type_mapper.h" +#include + +namespace ov::intel_cpu { + +struct jit_conv3d_call_args { + const uint16_t* src; // f16 base ptr + const uint16_t* wei; // f16 base ptr + const uint16_t* wei2; // optional second oc f16 base ptr (can be null) + const uint16_t* wei3; // optional third oc f16 base ptr (can be null) + const uint16_t* wei4; // optional fourth oc f16 base ptr (can be null) + size_t repeats; // number of full 8-channel blocks + size_t tail; // remaining channels (< 8) + size_t src_stride; // stride between channels in bytes + size_t wei_stride; // stride between channels in bytes + size_t src_blk_stride; // stride between successive 8-channel blocks in bytes + size_t wei_blk_stride; // stride between successive 8-channel blocks in bytes + float* acc; // f32 accumulator + float* acc2; // optional second f32 accumulator (can be null) + size_t kw_cnt; // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single + size_t src_dx; // bytes to advance src base between 
successive kx taps + size_t wei_dx; // bytes to advance weights base between successive kx taps + float* acc3; // optional third f32 accumulator (can be null) + float* acc4; // optional fourth f32 accumulator (can be null) + size_t kh_cnt; // number of taps along H to iterate (stride=1 fast path); 0 or 1 -> single + size_t src_dy; // bytes to advance src base between successive ky taps + size_t wei_dy; // bytes to advance weights base between successive ky taps +}; + +struct jit_conv3d_f32_call_args { + const float* src; + const float* wei; + const float* wei2; + size_t repeats; + size_t tail; + size_t src_stride; + size_t wei_stride; + size_t src_blk_stride; + size_t wei_blk_stride; + float* acc; + float* acc2; + size_t kw_cnt; + size_t src_dx; + size_t wei_dx; +}; + +class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF16) + using jit_fn = void (*)(const jit_conv3d_call_args*); + + JitConv3DKernelF16(); + + void create_ker(); + inline void operator()(const jit_conv3d_call_args* p) const { + ker_(p); + } + +private: + void generate() override; + + void gen_minimal_kernel(); + void gen_optimized_kernel(); + + jit_fn ker_{nullptr}; + bool m_force_single_kh_{true}; +public: + void set_force_single_kh(bool v) { m_force_single_kh_ = v; } +}; + +class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF32) + using jit_fn = void (*)(const jit_conv3d_f32_call_args*); + + JitConv3DKernelF32() = default; + + void create_ker(); + inline void operator()(const jit_conv3d_f32_call_args* p) const { + ker_(p); + } + +private: + void generate() override; + + jit_fn ker_{nullptr}; +}; + +class JitConv3DExecutor : public Executor { +public: + JitConv3DExecutor(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); + + bool update(const MemoryArgs& memory) override { + m_memory = memory; + return true; + } + void execute(const MemoryArgs& memory) override; + void execute() override { + execute(m_memory); + } + void exec([[maybe_unused]] const std::vector& src, + [[maybe_unused]] const std::vector& dst) override {} + + [[nodiscard]] impl_desc_type implType() const override { + return impl_desc_type::jit_asimd; + } + + static bool supports(const ConvConfig& cfg); + + void prepare_weights_early(const MemoryArgs& memory); + +private: + void run_naive_fp16(const MemoryArgs& memory); + void ensure_weights_packed(const MemoryArgs& memory); + void run_naive_fp32(const MemoryArgs& memory); + void ensure_weights_packed_f32(const MemoryArgs& memory); + + std::unique_ptr m_ip_kernel; + std::unique_ptr m_ip_kernel_f32; + + ConvAttrs m_attrs; + MemoryArgs m_memory; + size_t m_threadsNum{0}; + bool m_is_fp32{false}; + + std::vector m_wei_packed; + bool m_wei_packed_ready{false}; + size_t m_padded_C{0}; + std::vector m_wei_packed_f32; + bool m_wei_packed_ready_f32{false}; + size_t m_padded_C_f32{0}; + + bool m_has_prelu{false}; + std::vector m_prelu_slopes; +}; + +using JitConv3DExecutorPtr = std::shared_ptr; + +} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp new file mode 100644 index 00000000000000..7c864786b3dc3c --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -0,0 +1,1679 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + 
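For orientation before the packing code below: the deconvolution executor repacks weights from the framework layout [IC, OC, KD, KH, KW] (or [G, ICg, OCg, KD, KH, KW] for grouped deconvolution) into [OC, KD, KH, KW, ICt], with the input-channel dimension padded to a multiple of 8 for FP16 (4 for FP32) so the kernel can vector-load consecutive channels per tap. The following is a standalone sketch of the no-group FP16 mapping, mirroring idx_wei_src/idx_wei_pack in ensure_weights_packed_f16; it is illustrative only and not part of the patch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Repack deconv weights wei[ic][oc][kz][ky][kx] into packed[oc][kz][ky][kx][ic],
// with IC padded up to a multiple of 8 (zero-filled lanes for the channel tail).
inline std::vector<uint16_t> pack_deconv_wei_f16(const uint16_t* wei,
                                                 size_t IC, size_t OC,
                                                 size_t KD, size_t KH, size_t KW) {
    const size_t padded_IC = (IC + 7) / 8 * 8;
    std::vector<uint16_t> packed(OC * KD * KH * KW * padded_IC, 0);
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t kz = 0; kz < KD; ++kz)
            for (size_t ky = 0; ky < KH; ++ky)
                for (size_t kx = 0; kx < KW; ++kx)
                    for (size_t ic = 0; ic < IC; ++ic) {
                        const size_t src_idx = (((ic * OC + oc) * KD + kz) * KH + ky) * KW + kx;
                        const size_t dst_idx = (((oc * KD + kz) * KH + ky) * KW + kx) * padded_IC
                                               + (ic / 8) * 8 + (ic % 8);  // == ic; kept in blk/lane form to mirror the executor
                        packed[dst_idx] = wei[src_idx];
                    }
    return packed;
}
// With this layout a tap's 8-channel block is contiguous (wei_stride == sizeof(uint16_t),
// wei_blk_stride == 16 bytes), and stepping to the next kx tap is a fixed wei_dx of
// padded_IC * sizeof(uint16_t) bytes, which is the convention the executor code below relies on.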
+#include "nodes/executors/aarch64/jit_deconv3d.hpp" + +#include +#include +#include +#include + +#include "cpu_memory.h" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/core/type/float16.hpp" + +namespace ov::intel_cpu { + +bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& /*attr*/) { + deconvAttrs = attrs; + m_srcDescs = srcDescs; + m_dstDescs = dstDescs; + // Choose kernel by precision + const auto prec = m_srcDescs[0]->getPrecision(); + m_is_fp32 = (prec == ov::element::f32); + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + m_ip_kernel_f16 = std::make_unique(); + m_ip_kernel_f16->create_ker(); + } + return true; +} + +void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector& src) { + if (m_wei_packed_ready_f16) + return; + // src[1] holds weights for deconv with shape: + // - no-group: [IC, OC, KD, KH, KW] + // - group: [G, ICg, OCg, KD, KH, KW] + const auto& weiDims = src[1]->getStaticDims(); + const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f16 = (IC + 7) / 8 * 8; + const size_t total = OC * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_f16.assign(total, static_cast(0)); + + auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t blk = ic / 8; + const size_t lane = ic % 8; + return base + blk * 8 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_f16[idx_wei_pack(oc, ic, kz, ky, kx)] = + wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f16 = true; + return; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f16 = (ICg + 7) / 8 * 8; // per-group padding + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_f16.assign(total, static_cast(0)); + + auto idx_wei_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) -> size_t { + // layout [G, ICg, OCg, KD, KH, KW] + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + auto idx_wei_pack = [&](size_t oc_global, size_t icg, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t blk = icg / 8; + const size_t lane = icg % 8; + return base + blk * 8 + lane; + }; + + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_f16[idx_wei_pack(oc_global, icg, kz, ky, kx)] 
= + wsrc[idx_wei_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + } + m_wei_packed_ready_f16 = true; + return; + } +} + +void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector& src) { + if (m_wei_packed_ready_f32) + return; + const auto& weiDims = src[1]->getStaticDims(); + const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f32 = (IC + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_f32.assign(total, 0.0F); + + auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t blk = ic / 4; + const size_t lane = ic % 4; + return base + blk * 4 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_f32[idx_wei_pack(oc, ic, kz, ky, kx)] = + wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f32 = true; + return; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f32 = (ICg + 3) / 4 * 4; + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_f32.assign(total, 0.0F); + + auto idx_wei_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + auto idx_wei_pack = [&](size_t oc_global, size_t icg, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t blk = icg / 4; + const size_t lane = icg % 4; + return base + blk * 4 + lane; + }; + + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_f32[idx_wei_pack(oc_global, icg, kz, ky, kx)] = + wsrc[idx_wei_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + } + m_wei_packed_ready_f32 = true; + return; + } +} + +// Alternative even/odd packing for S=2 (FP16) +void JitDeconv3DExecutor::ensure_weights_packed_s2_f16(const std::vector& src) { + if (m_wei_packed_s2_ready_f16) + return; + const auto& weiDims = src[1]->getStaticDims(); + const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f16 = (IC + 7) / 8 * 8; + const size_t total = OC * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_s2_f16.assign(total, static_cast(0)); + auto idx_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < 
KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + // evens + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_s2_f16[base + (ic / 8) * 8 + (ic % 8)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; + } + } + // odds + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_s2_f16[base + (ic / 8) * 8 + (ic % 8)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_s2_ready_f16 = true; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f16 = (ICg + 7) / 8 * 8; + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_s2_f16.assign(total, static_cast(0)); + auto idx_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) { + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_s2_f16[base + (icg / 8) * 8 + (icg % 8)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_s2_f16[base + (icg / 8) * 8 + (icg % 8)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + } + m_wei_packed_s2_ready_f16 = true; + } +} + + +void JitDeconv3DExecutor::exec(const std::vector& src, + const std::vector& dst, + const void* /*post_ops_data_*/) { + if (m_is_fp32) { + exec_fp32(src, dst); + } else { + exec_fp16(src, dst); + } +} + + +void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { + if (src.size() < 2 || !src[0] || !src[1] || !src[0]->getDescPtr() || !src[1]->getDescPtr()) + return; + const auto& s = src[0]->getDescPtr()->getShape(); + const auto& w = src[1]->getDescPtr()->getShape(); + if (!s.isStatic() || !w.isStatic()) + return; + if (m_is_fp32) { + ensure_weights_packed_f32(src); + } else { + ensure_weights_packed_f16(src); + ensure_weights_packed_s2_f16(src); + } +} + +void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const std::vector& dst) { + // NCDHW, fp16: compute each output pixel (n, oc, od, oh, ow) as a sum over (ic, kz, ky, kx) + const auto& srcDims = src[0]->getStaticDims(); + const auto& weiDims = src[1]->getStaticDims(); + const auto& dstDims = dst[0]->getStaticDims(); + + const size_t N = srcDims[0]; + const size_t IC = srcDims[1]; + const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; + // Use OC from dst to support grouped weights layout + const size_t OC = dstDims[1]; + // Weights: no-group [IC, OC, KD, KH, KW]; grouped [G, ICg, OCg, KD, KH, KW] + const bool grouped = weiDims.size() == 6; + [[maybe_unused]] const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? 
weiDims[1] : IC; + const size_t OCg = grouped ? weiDims[2] : OC; + const size_t KD = weiDims[grouped ? 3 : 2], KH = weiDims[grouped ? 4 : 3], KW = weiDims[grouped ? 5 : 4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = deconvAttrs.stride.size() > 0 ? static_cast(deconvAttrs.stride[0]) : 1; + const size_t SH = deconvAttrs.stride.size() > 1 ? static_cast(deconvAttrs.stride[1]) : 1; + const size_t SW = deconvAttrs.stride.size() > 2 ? static_cast(deconvAttrs.stride[2]) : 1; + + const ptrdiff_t PD0 = deconvAttrs.paddingL.size() > 0 ? deconvAttrs.paddingL[0] : 0; + const ptrdiff_t PH0 = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL[1] : 0; + const ptrdiff_t PW0 = deconvAttrs.paddingL.size() > 2 ? deconvAttrs.paddingL[2] : 0; + + const auto* src_p = reinterpret_cast(src[0]->getData()); + const auto* wei_p = reinterpret_cast(src[1]->getData()); + auto* dst_p = reinterpret_cast(dst[0]->getData()); + + auto idx_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * IC + c) * ID + z) * IH + y) * IW + x; + }; + auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + // weight: no-group [IC, OC, KD, KH, KW]; grouped [G, ICg, OCg, KD, KH, KW] + auto idx_wei = [&](size_t ic_or_icg, size_t oc_global, size_t kz, size_t ky, size_t kx) { + if (!grouped) { + return ((((ic_or_icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + const size_t g = oc_global / OCg; + const size_t ocg = oc_global % OCg; + return ((((((g * ICg + ic_or_icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + + // Strides in elements + const size_t src_c_stride_elems = ID * IH * IW; + const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; + + // Always prepare packed weights (standard + S=2 even/odd) + ensure_weights_packed_f16(src); + ensure_weights_packed_s2_f16(src); + + // Effective dilations are stored as (dilation - 1) inside attrs; convert to actual factors + const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; + const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; + const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; + + auto worker = [&](size_t n, size_t oc_quad, size_t od) { + const size_t oc0 = oc_quad * 4; + const size_t g = OCg ? (oc0 / OCg) : 0; + const size_t ocg0 = OCg ? 
(oc0 % OCg) : oc0; + const size_t oc1 = oc0 + 1; + const size_t oc2 = oc0 + 2; + const size_t oc3 = oc0 + 3; + const bool has_oc1 = (ocg0 + 1) < OCg && oc1 < OC; + const bool has_oc2 = (ocg0 + 2) < OCg && oc2 < OC; + const bool has_oc3 = (ocg0 + 3) < OCg && oc3 < OC; + const size_t n_base = n * IC * ID * IH * IW; + { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow_ = 0; ow_ < OW; ++ow_) { + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; + + if (SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { + // Fast path: contiguous tap ranges, no modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID) + 1); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH) + 1); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW) + 1); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t id = static_cast(tzd - kz); + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + (void)kx_hi; + (void)kx_lo; + if (m_wei_packed_ready_f16) { + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t ihh = static_cast(tyd - ky); + const size_t s_base_x0 = s_base_row + ihh * IW + static_cast(txd); + // Precompute ky-dependent packed bases (no kx loop in-kernel) + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + // Start from rightmost tap to keep src_dx positive (+1 element per kx) + const size_t s_base0 = s_base_x0 - static_cast(kx_hi); + // Packed weights advance by padded_IC per kx; start from leftmost kx + const size_t base0 = (py0 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? (py1 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + // pair 0: oc0/oc1 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = m_wei_packed_f16.data() + base0; + if (has_oc1) + a.wei2 = m_wei_packed_f16.data() + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1: oc2/oc3 + if (has_oc2) { + const size_t pz2 = (oc2 * KD + static_cast(kz)) * KH; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + const size_t py2 = (pz2 + static_cast(ky)) * KW; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + const size_t base2 = (py2 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? 
(py3 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = m_wei_packed_f16.data() + base2; + if (has_oc3) + a.wei2 = m_wei_packed_f16.data() + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } else { + { + // In-kernel kx only + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t ihh2 = static_cast(tyd - ky); + const size_t s_base_row2 = n_base + (g * ICg) * src_c_stride_elems + src_z_off + ihh2 * IW; + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + const size_t s_base0 = s_base_row2 + static_cast(txd - kx_hi); + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + a.wei = wei_p + w_base0; + if (has_oc1) a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w_base3 = has_oc3 ? 
idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + a.wei = wei_p + w_base2; + if (has_oc3) a.wei2 = wei_p + w_base3; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } + } + } + } + } else if (SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1) { + // Fast path S=2, dil=1 (packed weights): parity-filtered taps without modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + // X2 micro-tiling over output width for stride=2: compute (ow, ow+2) together when possible + if ((ow_ + 2) < OW) { + float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ + float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 + + const ptrdiff_t txd1 = static_cast(ow_ + 2) + PW0; + const ptrdiff_t kx_lo1 = std::max(0, txd1 - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi1 = std::min(static_cast(KW) - 1, txd1); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + + + + // Even/odd S=2 packing selection + const uint16_t* wei_pack_ptr_tile2 = m_wei_packed_s2_f16.data(); + auto pack_index_eo_tile2 = [&](size_t py, size_t kx) { + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); + }; + + // Pass A: kx subset for txd (ow_) + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw0 = static_cast((txd - kx) / 2); + if (iw0 >= IW) continue; + const size_t iw1 = iw0 + 1; // for ow_+2 + const size_t s_base0 = s_base_row + ih * IW + iw0; + // pair 0 for ow_ + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0a; + a.acc2 = has_oc1 ? 
&acc1a : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + if (iw1 < IW) { + const size_t s_base1 = s_base0 + 1; + // pair 0 for ow_+2 + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + } + // pair 1 (oc2/oc3), ow_ + if (has_oc2) { + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2a; + a.acc2 = has_oc3 ? &acc3a : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + // pair 1 for ow_+2 + if (has_oc2 && (iw1 < IW)) { + const size_t s_base1 = s_base0 + 1; + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2b; + a.acc2 = has_oc3 ? 
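+ // X2 tiling invariant: txd1 = txd + 2 and the stride is 2, so ow_ and ow_+2 share the same kx parity and
+ // their source columns differ by exactly one (iw1 = iw0 + 1); the same packed tap is reused with the source
+ // pointer shifted by a single element (s_base1 = s_base0 + 1).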
&acc3b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + } + + // Pass B: extra kx subset for txd1 (ow_+2) only (complement of Pass A) + for (ptrdiff_t kx = kx_lo1 + ((txd1 - kx_lo1) & 1); kx <= kx_hi1; kx += 2) { + const ptrdiff_t iw0_tmp = (txd - kx) / 2; // may be out-of-range or wrong parity for ow_ + const bool covered_in_A = (kx >= kx_lo && kx <= kx_hi && (((txd - kx) & 1) == 0) && (iw0_tmp >= 0 && iw0_tmp < static_cast(IW))); + if (covered_in_A) continue; // already accumulated in Pass A for ow_+2 + const ptrdiff_t iw1_tmp = (txd1 - kx) / 2; + if (iw1_tmp < 0 || iw1_tmp >= static_cast(IW)) continue; + const size_t iw1 = static_cast(iw1_tmp); + const size_t s_base1 = s_base_row + ih * IW + iw1; + // pair 0 for ow_+2 only + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + if (has_oc2) { + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2b; + a.acc2 = has_oc3 ? 
&acc3b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + } + } + } + } + + // Optional fused bias for both outputs (ow_, ow_+2) + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += bias_ptr[oc0]; + if (has_oc1) acc1a += bias_ptr[oc1]; + if (has_oc2) acc2a += bias_ptr[oc2]; + if (has_oc3) acc3a += bias_ptr[oc3]; + acc0b += bias_ptr[oc0]; + if (has_oc1) acc1b += bias_ptr[oc1]; + if (has_oc2) acc2b += bias_ptr[oc2]; + if (has_oc3) acc3b += bias_ptr[oc3]; + } else if (bprec == ov::element::f16) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1a += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2a += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3a += static_cast(ov::float16(bias_ptr[oc3])); + acc0b += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1b += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2b += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3b += static_cast(ov::float16(bias_ptr[oc3])); + } + } + + // Store both outputs + dst_p[idx_dst(n, oc0, od, oh, ow_)] = ov::float16(acc0a).to_bits(); + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1a).to_bits(); + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2a).to_bits(); + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3a).to_bits(); + + const size_t ow2 = ow_ + 2; + dst_p[idx_dst(n, oc0, od, oh, ow2)] = ov::float16(acc0b).to_bits(); + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow2)] = ov::float16(acc1b).to_bits(); + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow2)] = ov::float16(acc2b).to_bits(); + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow2)] = ov::float16(acc3b).to_bits(); + + ow_ += 2; // skip next two positions (we computed ow and ow+2); for-loop ++ will advance to ow+3 + continue; + } + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? 
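+ // Parity filter: starting at kz_lo + ((tzd - kz_lo) & 1) and stepping by 2 visits exactly the taps with
+ // (tzd - kz) even, i.e. the (tzd - kz) % 2 == 0 divisibility test of the generic path, with id = (tzd - kz) / 2
+ // (ky/ih and kx/iw follow the same pattern). Within a row the batched call walks kx in steps of 2 while the
+ // source column iw = (txd - kx) / 2 moves by -1, hence the negative src_dx below.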
(pz3 + static_cast(ky)) * KW : 0; + const uint16_t* wei_pack_ptr = m_wei_packed_s2_f16.data(); + auto pack_index_eo = [&](size_t py, size_t kx) { + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); + }; + const ptrdiff_t kx_start = kx_lo + ((txd - kx_lo) & 1); + const size_t iw_start = static_cast((txd - kx_start) / 2); + if (iw_start >= IW) continue; + const size_t s_base0 = s_base_row + ih * IW + iw_start; + const size_t kw_count = static_cast((kx_hi - kx_start) / 2 + 1); + // pair 0 + { + const size_t base0 = pack_index_eo(py0, static_cast(kx_start)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo(py1, static_cast(kx_start)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = static_cast(-static_cast(sizeof(uint16_t))); + a.wei = wei_pack_ptr + base0; + if (has_oc1) a.wei2 = wei_pack_ptr + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + const size_t base2 = pack_index_eo(py2, static_cast(kx_start)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo(py3, static_cast(kx_start)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = static_cast(-static_cast(sizeof(uint16_t))); + a.wei = wei_pack_ptr + base2; + if (has_oc3) a.wei2 = wei_pack_ptr + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } + { + // Per-tap parity stepping + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + const uint16_t* wei_pack_ptr_orig = m_wei_packed_s2_f16.data(); + auto pack_index_eo_orig = [&](size_t py, size_t kx) { + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? 
(even_count + (kx / 2)) : (kx / 2)); + }; + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw = static_cast((txd - kx) / 2); + if (iw >= IW) continue; + const size_t s_base0 = s_base_row + ih * IW + iw; + // pair 0 + { + const size_t base0 = pack_index_eo_orig(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_orig(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_orig + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_orig + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + const size_t base2 = pack_index_eo_orig(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_orig(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_orig + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_orig + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + (*m_ip_kernel_f16)(&a); + } + } + } + } + } + } + + } else { + // Generic path (stride/dilation) + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t id_num = + static_cast(od) + PD0 - static_cast(kz * dilD); + if (SD == 0) + continue; + if (id_num % static_cast(SD) != 0) + continue; + const ptrdiff_t id_idx = id_num / static_cast(SD); + if (id_idx < 0 || id_idx >= static_cast(ID)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy_num = + static_cast(oh) + PH0 - static_cast(ky * dilH); + if (SH == 0) + continue; + if (iy_num % static_cast(SH) != 0) + continue; + const ptrdiff_t ih_idx = iy_num / static_cast(SH); + if (ih_idx < 0 || ih_idx >= static_cast(IH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix_num = + static_cast(ow_) + PW0 - static_cast(kx * dilW); + if (SW == 0) + continue; + if (ix_num % static_cast(SW) != 0) + continue; + const ptrdiff_t iw_idx = ix_num / static_cast(SW); + if (iw_idx < 0 || iw_idx >= static_cast(IW)) + continue; + + const size_t s_base0 = idx_src(n, + g * ICg, + static_cast(id_idx), + static_cast(ih_idx), + static_cast(iw_idx)); + + auto run_pair = [&](float* acc, float* acc2, const uint16_t* w0, const uint16_t* w1) { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = acc; + a.acc2 = acc2; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = w0; + if (w1) a.wei2 = w1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + (*m_ip_kernel_f16)(&a); + }; + const size_t base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const 
size_t base1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16 : 0; + run_pair(&acc0, has_oc1 ? &acc1 : nullptr, + m_wei_packed_f16.data() + base0, + has_oc1 ? m_wei_packed_f16.data() + base1 : nullptr); + + if (has_oc2) { + const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16 : 0; + run_pair(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f16.data() + base2, + has_oc3 ? m_wei_packed_f16.data() + base3 : nullptr); + } + } + } + } + } + // Optional fused bias for deconv + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += bias_ptr[oc0]; + if (has_oc1) + acc1 += bias_ptr[oc1]; + if (has_oc2) + acc2 += bias_ptr[oc2]; + if (has_oc3) + acc3 += bias_ptr[oc3]; + } else if (bprec == ov::element::f16) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) + acc1 += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) + acc2 += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) + acc3 += static_cast(ov::float16(bias_ptr[oc3])); + } + } + + dst_p[idx_dst(n, oc0, od, oh, ow_)] = ov::float16(acc0).to_bits(); + if (has_oc1) + dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1).to_bits(); + if (has_oc2) + dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2).to_bits(); + if (has_oc3) + dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3).to_bits(); + } + } + } + }; + + ov::parallel_for3d(N, (OC + 3) / 4, OD, worker); +} + +void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const std::vector& dst) { + // NCDHW, f32 + const auto& srcDims = src[0]->getStaticDims(); + const auto& weiDims = src[1]->getStaticDims(); + const auto& dstDims = dst[0]->getStaticDims(); + + const size_t N = srcDims[0]; + const size_t IC = srcDims[1]; + const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; + const size_t OC = dstDims[1]; + const bool grouped = weiDims.size() == 6; + [[maybe_unused]] const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? weiDims[1] : IC; + const size_t OCg = grouped ? weiDims[2] : OC; + const size_t KD = weiDims[grouped ? 3 : 2], KH = weiDims[grouped ? 4 : 3], KW = weiDims[grouped ? 5 : 4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = deconvAttrs.stride.size() > 0 ? static_cast(deconvAttrs.stride[0]) : 1; + const size_t SH = deconvAttrs.stride.size() > 1 ? static_cast(deconvAttrs.stride[1]) : 1; + const size_t SW = deconvAttrs.stride.size() > 2 ? static_cast(deconvAttrs.stride[2]) : 1; + + const ptrdiff_t PD0 = deconvAttrs.paddingL.size() > 0 ? deconvAttrs.paddingL[0] : 0; + const ptrdiff_t PH0 = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL[1] : 0; + const ptrdiff_t PW0 = deconvAttrs.paddingL.size() > 2 ? 
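+ // exec_fp32 mirrors exec_fp16, but the kernel operates on 4-lane f32 blocks: repeats = ICg / 4,
+ // tail = ICg % 4, and block strides are 4x the per-channel stride (vs 8 half-precision lanes in the f16 path).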
deconvAttrs.paddingL[2] : 0; + + const auto* src_p = reinterpret_cast(src[0]->getData()); + const auto* wei_p = reinterpret_cast(src[1]->getData()); + auto* dst_p = reinterpret_cast(dst[0]->getData()); + + auto idx_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * IC + c) * ID + z) * IH + y) * IW + x; + }; + auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + auto idx_wei = [&](size_t ic_or_icg, size_t oc_global, size_t kz, size_t ky, size_t kx) { + if (!grouped) { + return ((((ic_or_icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + const size_t g = oc_global / OCg; + const size_t ocg = oc_global % OCg; + return ((((((g * ICg + ic_or_icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + + const size_t src_c_stride_elems = ID * IH * IW; + const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; + + ensure_weights_packed_f32(src); + // Output dilations + const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; + const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; + const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; + + + + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t g = OCg ? (oc0 / OCg) : 0; + const size_t ocg0 = OCg ? (oc0 % OCg) : oc0; + const size_t oc1 = std::min(oc0 + 1, OC); + const size_t oc2 = std::min(oc0 + 2, OC); + const size_t oc3 = std::min(oc0 + 3, OC); + const bool has_oc1 = (ocg0 + 1) < OCg && oc1 < OC; + const bool has_oc2 = (ocg0 + 2) < OCg && oc2 < OC; + const bool has_oc3 = (ocg0 + 3) < OCg && oc3 < OC; + + for (size_t od = 0; od < OD; ++od) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow_ = 0; ow_ < OW; ++ow_) { + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; + + if (SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { + // contiguous tap range in each dimension + const ptrdiff_t tz_pos = static_cast(od) + PD0; + const ptrdiff_t ty_pos = static_cast(oh) + PH0; + const ptrdiff_t tx_pos = static_cast(ow_) + PW0; + const ptrdiff_t kz_lo = std::max(0, tz_pos - static_cast(ID) + 1); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tz_pos); + const ptrdiff_t ky_lo = std::max(0, ty_pos - static_cast(IH) + 1); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, ty_pos); + const ptrdiff_t kx_lo = std::max(0, tx_pos - static_cast(IW) + 1); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, tx_pos); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const auto iz_idx = static_cast(tz_pos - kz); + const auto ky_base = static_cast(ky_lo); + const auto iy0 = static_cast(ty_pos - ky_lo); + const auto ix0 = static_cast(tx_pos - kx_lo); + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy_idx = static_cast(ty_pos - ky); + const size_t ix_idx = ix0; + (void)iy0; + (void)ky_base; + const size_t s_base = idx_src(n, g * ICg, iz_idx, iy_idx, ix_idx); + + // pair 0 + { + jit_conv3d_f32_call_args args{}; + args.src = src_p + s_base; + args.src_stride = src_c_stride_elems * sizeof(float); + args.src_blk_stride = args.src_stride * 4; + args.acc = &acc0; + args.acc2 = has_oc1 ? 
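+ // Each worker handles a quad of consecutive output channels (oc0..oc3); the kernel accumulates two channels
+ // per call via acc/acc2, so a quad needs at most two calls (pair 0 and pair 1) per tap.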
&acc1 : nullptr; + args.repeats = ICg / 4; + args.tail = ICg % 4; + args.kw_cnt = kw_count; + args.src_dx = static_cast(-static_cast(sizeof(float))); + if (true) { + const size_t base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args.wei2 = m_wei_packed_f32.data() + base1; + } + args.wei_stride = sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; + args.wei_dx = m_padded_IC_f32 * sizeof(float); + } else { /* unreachable: raw weights path removed */ } + (*m_ip_kernel_f32)(&args); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args args2{}; + args2.src = src_p + s_base; + args2.src_stride = src_c_stride_elems * sizeof(float); + args2.src_blk_stride = args2.src_stride * 4; + args2.acc = &acc2; + args2.acc2 = has_oc3 ? &acc3 : nullptr; + args2.repeats = ICg / 4; + args2.tail = ICg % 4; + args2.kw_cnt = kw_count; + args2.src_dx = static_cast(-static_cast(sizeof(float))); + if (true) { + const size_t base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args2.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args2.wei2 = m_wei_packed_f32.data() + base3; + } + args2.wei_stride = sizeof(float); + args2.wei_blk_stride = args2.wei_stride * 4; + args2.wei_dx = m_padded_IC_f32 * sizeof(float); + } else { /* unreachable: raw weights path removed */ } + (*m_ip_kernel_f32)(&args2); + } + } + } + } + } else if (SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1) { + // Fast path S=2, dil=1 (packed weights preferred): parity-filtered taps without modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + // X2 micro-tiling over output width for stride=2: compute (ow_, ow_+2) together (disabled for FP32) + if (false && (ow_ + 2) < OW) { + float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ + float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 + const ptrdiff_t txd1 = static_cast(ow_ + 2) + PW0; + const ptrdiff_t kx_lo1 = std::max(0, txd1 - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi1 = std::min(static_cast(KW) - 1, txd1); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n * IC * ID * IH * IW + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? 
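+ // This X2 micro-tiling branch is compiled out for FP32 (the enclosing `if (false && ...)` never executes);
+ // the per-tap stride-2 loop further down is the active FP32 path.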
(oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + + // Pass A: main kx set valid for ow_ + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw0 = static_cast((txd - kx) / 2); + if (iw0 >= IW) continue; + const size_t iw1 = iw0 + 1; // for ow_+2 + const size_t s_base0 = s_base_row + ih * IW + iw0; + // pair 0 for ow_ + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0a; + a.acc2 = has_oc1 ? &acc1a : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + // For ow_+2 if in-bounds + if (iw1 < IW) { + const size_t s_base1 = s_base0 + 1; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base0; + if (has_oc1) { + const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base1; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + // pair 1 (oc2/oc3) for ow_ + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2a; + a.acc2 = has_oc3 ? 
&acc3a : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + // pair 1 for ow_+2 + if (has_oc2 && (iw1 < IW)) { + const size_t s_base1_b = s_base0 + 1; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1_b; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2b; + a.acc2 = has_oc3 ? &acc3b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + } + + // Pass B: extra kx set valid only for ow_+2 (complement by parity/bounds) + for (ptrdiff_t kx = kx_lo1 + ((txd1 - kx_lo1) & 1); kx <= kx_hi1; kx += 2) { + const ptrdiff_t iw0_tmp = (txd - kx) / 2; + const bool covered_in_A = (kx >= kx_lo && kx <= kx_hi && (((txd - kx) & 1) == 0) && + (iw0_tmp >= 0 && iw0_tmp < static_cast(IW))); + if (covered_in_A) continue; + const ptrdiff_t iw1_tmp = (txd1 - kx) / 2; + if (iw1_tmp < 0 || iw1_tmp >= static_cast(IW)) continue; + const size_t iw1 = static_cast(iw1_tmp); + const size_t s_base1 = s_base_row + ih * IW + iw1; + // pair 0 for ow_+2 only + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2b; + a.acc2 = has_oc3 ?
&acc3b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + } + } + } + } + + // Optional fused bias for both outputs + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += bias_ptr[oc0]; + if (has_oc1) acc1a += bias_ptr[oc1]; + if (has_oc2) acc2a += bias_ptr[oc2]; + if (has_oc3) acc3a += bias_ptr[oc3]; + acc0b += bias_ptr[oc0]; + if (has_oc1) acc1b += bias_ptr[oc1]; + if (has_oc2) acc2b += bias_ptr[oc2]; + if (has_oc3) acc3b += bias_ptr[oc3]; + } else if (bprec == ov::element::f16) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1a += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2a += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3a += static_cast(ov::float16(bias_ptr[oc3])); + acc0b += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1b += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2b += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3b += static_cast(ov::float16(bias_ptr[oc3])); + } + } + + // Store results for both outputs + dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0a; + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1a; + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2a; + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3a; + + const size_t ow2 = ow_ + 2; + dst_p[idx_dst(n, oc0, od, oh, ow2)] = acc0b; + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow2)] = acc1b; + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow2)] = acc2b; + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow2)] = acc3b; + + ow_ += 2; // skip next two positions; for-loop ++ will advance to ow_+3 + continue; + } + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? 
(pz3 + static_cast(ky)) * KW : 0; + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw = static_cast((txd - kx) / 2); + if (iw >= IW) continue; + // Base source offset for this (id, ih, iw) + const size_t s_base0 = idx_src(n, g * ICg, id, ih, iw); + // pair 0 (oc0, oc1) + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + // pair 1 (oc2, oc3) + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (true) { + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { /* unreachable */ } + (*m_ip_kernel_f32)(&a); + } + } + } + } + } + } else { + // Generic path (stride/dilation) + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t id_num = + static_cast(od) + PD0 - static_cast(kz * dilD); + if (SD == 0) + continue; + if (id_num % static_cast(SD) != 0) + continue; + const ptrdiff_t id_idx = id_num / static_cast(SD); + if (id_idx < 0 || id_idx >= static_cast(ID)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy_num = + static_cast(oh) + PH0 - static_cast(ky * dilH); + if (SH == 0) + continue; + if (iy_num % static_cast(SH) != 0) + continue; + const ptrdiff_t ih_idx = iy_num / static_cast(SH); + if (ih_idx < 0 || ih_idx >= static_cast(IH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix_num = + static_cast(ow_) + PW0 - static_cast(kx * dilW); + if (SW == 0) + continue; + if (ix_num % static_cast(SW) != 0) + continue; + const ptrdiff_t iw_idx = ix_num / static_cast(SW); + if (iw_idx < 0 || iw_idx >= static_cast(IW)) + continue; + + const size_t s_base0 = idx_src(n, + g * ICg, + static_cast(id_idx), + static_cast(ih_idx), + static_cast(iw_idx)); + + auto run_pair_f32 = [&](float* acc, float* acc2, const float* w0, const float* w1) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = acc; + a.acc2 = acc2; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = w0; + if (w1) a.wei2 = w1; + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + }; + const size_t pb0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pb1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32 : 0; + run_pair_f32(&acc0, has_oc1 ? 
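+ // Packed-weight indexing: tap (oc, kz, ky, kx) starts at (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32,
+ // i.e. the padded input-channel run is innermost so the kernel can stream it contiguously.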
&acc1 : nullptr, + m_wei_packed_f32.data() + pb0, + has_oc1 ? m_wei_packed_f32.data() + pb1 : nullptr); + if (has_oc2) { + const size_t pb2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pb3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32 : 0; + run_pair_f32(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f32.data() + pb2, + has_oc3 ? m_wei_packed_f32.data() + pb3 : nullptr); + } + } + } + } + } + // Optional bias (support f32 or f16 input bias) + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += bias_ptr[oc0]; + if (has_oc1) + acc1 += bias_ptr[oc1]; + if (has_oc2) + acc2 += bias_ptr[oc2]; + if (has_oc3) + acc3 += bias_ptr[oc3]; + } else if (bprec == ov::element::f16) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) + acc1 += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) + acc2 += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) + acc3 += static_cast(ov::float16(bias_ptr[oc3])); + } + } + + dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0; + if (has_oc1) + dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; + if (has_oc2) + dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2; + if (has_oc3) + dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3; + } + } + } + }); +} + +bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const { + // Support 5D NCDHW, fp16 and fp32 + if (srcDescs.size() < 2 || dstDescs.empty()) + return false; + const auto src0_rank = srcDescs[0]->getShape().getRank(); + const auto wei_rank = srcDescs[1]->getShape().getRank(); + const auto dst0_rank = dstDescs[0]->getShape().getRank(); + if (src0_rank != 5 || (wei_rank != 5 && wei_rank != 6) || dst0_rank != 5) { + return false; + } + const auto src0_prec = srcDescs[0]->getPrecision(); + const auto src1_prec = srcDescs[1]->getPrecision(); + const auto dst0_prec = dstDescs[0]->getPrecision(); + const bool fp16_ok = + (src0_prec == ov::element::f16 && src1_prec == ov::element::f16 && dst0_prec == ov::element::f16); + const bool fp32_ok = + (src0_prec == ov::element::f32 && src1_prec == ov::element::f32 && dst0_prec == ov::element::f32); + return fp16_ok || fp32_ok; +} + +} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp new file mode 100644 index 00000000000000..b9d961f5693d48 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -0,0 +1,89 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "nodes/executors/aarch64/jit_conv3d.hpp" +#include "nodes/executors/deconv.hpp" + +namespace ov::intel_cpu { + +class JitDeconv3DExecutor : public DeconvExecutor { +public: + explicit JitDeconv3DExecutor(ExecutorContext::CPtr context) : DeconvExecutor(std::move(context)) {} + // Early weight preparation in ctor (src[0]=input, src[1]=weights; guards dynamic shapes) + JitDeconv3DExecutor(const std::vector& src, ExecutorContext::CPtr context) + : DeconvExecutor(std::move(context)) { + if (!src.empty() && src[0] && src[0]->getDescPtr()) { + const auto prec = src[0]->getDescPtr()->getPrecision(); + 
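+        // The source precision detected here drives which kernel is constructed below (f32 vs f16) and the
+        // matching packed-weight format that prepare_weights_early() is expected to set up.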
m_is_fp32 = (prec == ov::element::f32); + } + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + m_ip_kernel_f16 = std::make_unique(); + m_ip_kernel_f16->create_ker(); + } + prepare_weights_early(src); + } + ~JitDeconv3DExecutor() override = default; + + bool init(const DeconvAttrs& deconvAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) override; + + void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) override; + + [[nodiscard]] impl_desc_type getImplType() const override { + return impl_desc_type::jit_asimd; + } + + void prepare_weights_early(const std::vector& src); + +private: + std::vector m_srcDescs; + std::vector m_dstDescs; + std::unique_ptr m_ip_kernel_f16; + std::unique_ptr m_ip_kernel_f32; + bool m_is_fp32{false}; + + std::vector m_wei_packed_f16; + std::vector m_wei_packed_f32; + std::vector m_wei_packed_s2_f16; + bool m_wei_packed_ready_f16{false}; + bool m_wei_packed_ready_f32{false}; + bool m_wei_packed_s2_ready_f16{false}; + size_t m_padded_IC_f16{0}; + size_t m_padded_IC_f32{0}; + + void ensure_weights_packed_f16(const std::vector& src); + void ensure_weights_packed_f32(const std::vector& src); + void ensure_weights_packed_s2_f16(const std::vector& src); + void exec_fp16(const std::vector& src, const std::vector& dst); + void exec_fp32(const std::vector& src, const std::vector& dst); +}; + +class AArch64JitDeconvExecutorBuilder : public DeconvExecutorBuilder { +public: + [[nodiscard]] bool isSupported(const DeconvAttrs& attrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const override; + [[nodiscard]] DeconvExecutorPtr makeExecutor(ExecutorContext::CPtr context) const override { + return std::make_shared(context); + } + // Helper to create executor with early packing in constructor + [[nodiscard]] DeconvExecutorPtr makeExecutorWithMem(ExecutorContext::CPtr context, + const std::vector& src) const { + return std::make_shared(src, context); + } +}; + +} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp index 78a45bd10bb76f..50c66c2f03b89e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp @@ -32,6 +32,9 @@ #if defined(OV_CPU_WITH_ACL) # include "nodes/executors/acl/acl_conv.hpp" #endif +#if defined(OPENVINO_ARCH_ARM64) +# include "nodes/executors/aarch64/jit_conv3d.hpp" +#endif namespace ov::intel_cpu { @@ -101,6 +104,16 @@ struct CreateOptimalConfigAclLowp { template <> const std::vector>& getImplementations() { static const std::vector> convolutionImplementations { + OV_CPU_INSTANCE_ARM64( + "convolution_jit_aarch64_3d_ncsp", ExecutorType::Jit, OperationType::Convolution, + [](const ConvConfig& config, [[maybe_unused]] const MemoryFormatFilter& memoryFormatFilter) -> bool { + return JitConv3DExecutor::supports(config); + }, + // Request plain ncsp layouts + CreateOptimalConfigDefault{{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}}, + AcceptsAnyShape, + CreateDefault{} + ) OV_CPU_INSTANCE_DNNL_X64( "convolution_dnnl_nspc_nspc", ExecutorType::Dnnl, OperationType::Convolution, // supports @@ -245,6 +258,20 @@ const std::vector>& getImplementations() { AcceptsAnyShape, CreateDnnlDefault{} ) + OV_CPU_INSTANCE_DNNL_X64( + 
"convolution_dnnl_nspc_nspc_unconditional_x64", ExecutorType::Dnnl, OperationType::Convolution, + // supports + [](const ConvConfig& config, const MemoryFormatFilter& memoryFormatFilter) -> bool { + // Unconditionally allow nspc path for x64 for shapes where backup may decline. + VERIFY(MatchesMemoryFormatFilter(config.descs, LayoutConfig{LayoutType::nspc, LayoutType::ncsp, LayoutType::nspc, LayoutType::nspc}, + memoryFormatFilter, dnnlConvolutionMappingNotation), MEMORY_FORMAT_MISMATCH); + VERIFY(!isQuantized(config), UNSUPPORTED_SRC_PRECISIONS); + return true; + }, + CreateOptimalConfigDefault{{LayoutType::nspc, LayoutType::ncsp, LayoutType::nspc, LayoutType::nspc}}, + AcceptsAnyShape, + CreateDnnlDefault{} + ) OV_CPU_INSTANCE_ACL( "convolution_dnnl_nspc_nspc_unconditional_acl", ExecutorType::Dnnl, OperationType::Convolution, // supports diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp index 2a5d7a087610f2..12b6b34737061c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp @@ -7,6 +7,9 @@ #include #include "utils/arch_macros.h" +#if defined(OPENVINO_ARCH_ARM64) +# include "nodes/executors/aarch64/jit_deconv3d.hpp" +#endif #if defined(OV_CPU_WITH_ACL) # include @@ -19,7 +22,9 @@ namespace ov::intel_cpu { const std::vector& getDeconvExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; + // Prefer ACL builder first for stability/perf; fallback to AArch64 JIT if ACL not supported + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) + OV_CPU_INSTANCE_ARM64(ExecutorType::Jit, std::make_shared())}; return descs; } diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp index cfa7eb0c8549d8..753c404b7b8d6f 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp @@ -15,6 +15,9 @@ #if defined(OV_CPU_WITH_ACL) # include "acl/acl_deconv.hpp" #endif +#if defined(OPENVINO_ARCH_ARM64) +# include "nodes/executors/aarch64/jit_deconv3d.hpp" +#endif namespace ov::intel_cpu { @@ -69,6 +72,47 @@ class DeconvExecutorFactory : public ExecutorFactoryLegacy { OPENVINO_THROW("DeconvExecutorFactory: Supported executor is not found"); } + // ARM64 helper: build executor and allow constructor-time early packing using input/weight memories + DeconvExecutorPtr makeExecutorWithMem(const DeconvAttrs& deconvAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr, + const std::vector& srcMemories) { + auto build = [&](const DeconvExecutorDesc* desc) -> DeconvExecutorPtr { + // If this is our AArch64 JIT builder, construct with memories to trigger early packing in ctor +#if defined(OPENVINO_ARCH_ARM64) + if (auto jitBuilder = std::dynamic_pointer_cast(desc->builder)) { + auto executor = jitBuilder->makeExecutorWithMem(context, srcMemories); + if (executor->init(deconvAttrs, srcDescs, dstDescs, attr)) { + return executor; + } + } +#endif + // Fallback to regular path + auto executor = desc->builder->makeExecutor(context); + if (executor->init(deconvAttrs, srcDescs, dstDescs, attr)) { + return executor; + } + DeconvExecutorPtr ptr = nullptr; + return ptr; + }; + + if (chosenDesc) { + if (auto executor = build(chosenDesc)) { + return executor; + } + } + + for (const auto& sd : supportedDescs) { + if 
(auto executor = build(&sd)) { + chosenDesc = &sd; + return executor; + } + } + + OPENVINO_THROW("DeconvExecutorFactory: Supported executor is not found (with memories)"); + } + private: std::vector supportedDescs; const DeconvExecutorDesc* chosenDesc = nullptr;