From 203605d62bb3c52fd574b342ad28b520fab6071e Mon Sep 17 00:00:00 2001
From: Alexander Nesterov
Date: Thu, 16 Oct 2025 13:50:35 +0200
Subject: [PATCH 01/20] Add support for AArch64 JIT-based 3D Deconvolution and Convolution Executors

---
 src/plugins/intel_cpu/src/nodes/deconv.cpp    |   92 +
 .../nodes/executors/aarch64/jit_conv3d.cpp    | 1485 +++++++++++++++++
 .../nodes/executors/aarch64/jit_conv3d.hpp    |  109 ++
 .../nodes/executors/aarch64/jit_deconv3d.cpp  |  339 ++++
 .../nodes/executors/aarch64/jit_deconv3d.hpp  |   51 +
 .../executors/convolution_implementations.cpp |   14 +
 .../src/nodes/executors/deconv_list.cpp       |    8 +-
 7 files changed, 2097 insertions(+), 1 deletion(-)
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp

diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp
index 6b4d08e1ab7b5f..ba3bed2e65ea66 100644
--- a/src/plugins/intel_cpu/src/nodes/deconv.cpp
+++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp
@@ -1296,7 +1296,99 @@ bool Deconvolution::canFuseBias() const {
 }
 
 void Deconvolution::initSupportedPrimitiveDescriptors() {
+    // Prefer AArch64 JIT deconv for 5D FP16 on ARM64 regardless of ACL
+#if defined(OPENVINO_ARCH_ARM64)
+    {
+        const auto rank = getInputShapeAtPort(0).getRank();
+        const bool is5D = (rank == 5);
+        const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 &&
+                             getOriginalInputPrecisionAtPort(1) == ov::element::f16 &&
+                             getOriginalOutputPrecisionAtPort(0) == ov::element::f16;
+        if (is5D && fp16_ok) {
+            auto [inDims, outDims] = makeDummyInOutShape();
+            auto tmpInShape = Shape(inDims);
+            auto tmpOutShape = Shape(outDims);
+            initPaddingR(tmpInShape, tmpOutShape);
+
+            const auto& creatorsMap = BlockedDescCreator::getCommonCreators();
+            NodeConfig config;
+            config.inConfs.resize(getParentEdges().size());
+            config.outConfs.resize(getOriginalOutputsNumber());
+
+            auto setDesc = [&](size_t port, bool isInput) {
+                const auto prec = isInput ? getOriginalInputPrecisionAtPort(port)
+                                          : getOriginalOutputPrecisionAtPort(port);
+                const auto& shp = isInput ? getInputShapeAtPort(port)
+                                          : getOutputShapeAtPort(port);
+                auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp);
+                if (isInput) config.inConfs[port].setMemDesc(d);
+                else config.outConfs[port].setMemDesc(d);
+            };
+            setDesc(0, true);
+            setDesc(1, true);
+            for (size_t i = 2; i < getParentEdges().size(); ++i) setDesc(i, true);
+            setDesc(0, false);
+
+            std::vector<MemoryDescPtr> srcMemoryDescs;
+            srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims()));
+            for (size_t i = 1; i < config.inConfs.size(); i++) srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone());
+            std::vector<MemoryDescPtr> dstMemoryDescs;
+            dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims()));
+            for (size_t i = 1; i < config.outConfs.size(); i++) dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone());
+
+            auto factory = std::make_shared<DeconvExecutorFactory>(
+                deconvAttrs, srcMemoryDescs, dstMemoryDescs,
+                std::make_shared<ExecutorContext>(context, getImplPriority()));
+            supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory);
+            useACL = true;  // reuse factory-based execution path
+            return;
+        }
+    }
+#endif
+    // If ACL path is not selected, try AArch64 JIT factory for 5D FP16
     if (!useACL) {
+#if defined(OPENVINO_ARCH_ARM64)
+        const auto rank = getInputShapeAtPort(0).getRank();
+        const bool is5D = (rank == 5);
+        const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 &&
+                             getOriginalInputPrecisionAtPort(1) == ov::element::f16 &&
+                             getOriginalOutputPrecisionAtPort(0) == ov::element::f16;
+        if (is5D && fp16_ok) {
+            auto [inDims, outDims] = makeDummyInOutShape();
+            auto tmpInShape = Shape(inDims);
+            auto tmpOutShape = Shape(outDims);
+            initPaddingR(tmpInShape, tmpOutShape);
+
+            const auto& creatorsMap = BlockedDescCreator::getCommonCreators();
+            NodeConfig config;
+            config.inConfs.resize(getParentEdges().size());
+            config.outConfs.resize(getOriginalOutputsNumber());
+
+            auto setDesc = [&](size_t port, const Shape& shape, bool isInput) {
+                const auto prec = isInput ? getOriginalInputPrecisionAtPort(port)
+                                          : getOriginalOutputPrecisionAtPort(port);
+                const auto& shp = isInput ? getInputShapeAtPort(port)
+                                          : getOutputShapeAtPort(port);
+                auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp);
+                if (isInput) config.inConfs[port].setMemDesc(d); else config.outConfs[port].setMemDesc(d);
+            };
+            setDesc(0, tmpInShape, true);
+            setDesc(1, Shape(getInputShapeAtPort(1).getStaticDims()), true);
+            for (size_t i = 2; i < getParentEdges().size(); ++i) setDesc(i, Shape(getInputShapeAtPort(i).getStaticDims()), true);
+            setDesc(0, tmpOutShape, false);
+
+            std::vector<MemoryDescPtr> srcMemoryDescs;
+            srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims()));
+            for (size_t i = 1; i < config.inConfs.size(); i++) srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone());
+            std::vector<MemoryDescPtr> dstMemoryDescs;
+            dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims()));
+            for (size_t i = 1; i < config.outConfs.size(); i++) dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone());
+
+            auto factory = std::make_shared<DeconvExecutorFactory>(
+                deconvAttrs, srcMemoryDescs, dstMemoryDescs,
+                std::make_shared<ExecutorContext>(context, getImplPriority()));
+            supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory);
+            return;
+        }
+#endif
         Node::initSupportedPrimitiveDescriptors();
         return;
     }
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
new file mode 100644
index 00000000000000..bb1d417dc324d6
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
@@ -0,0 +1,1485 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "nodes/executors/aarch64/jit_conv3d.hpp"
+
+#include <any>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "cpu_memory.h"
+#include "memory_desc/cpu_memory_desc.h"
+#include "nodes/executors/implementation_utils.hpp"
+#include "openvino/core/parallel.hpp"
+#include "openvino/core/type/element_type.hpp"
+#include "utils/general_utils.h"
+#include "openvino/core/type/float16.hpp"
+// helper for jit_kernel_cast
+#include "utils/cpu_utils.hpp"
+// no direct NEON intrinsics are used here; we rely on Xbyak_aarch64
+
+using namespace dnnl::impl::cpu::aarch64;
+
+namespace ov::intel_cpu {
+
+// --------------------------- JIT kernel (placeholder) ---------------------------
+JitConv3DKernelF16::JitConv3DKernelF16() {}
+
+void JitConv3DKernelF16::create_ker() {
+    jit_generator::create_kernel();
+    ker_ = jit_kernel_cast<decltype(ker_)>(jit_ker());
+}
+
+void JitConv3DKernelF16::generate() {
+    using namespace Xbyak_aarch64;
+
+    // Stable minimal kernel: dual-OC or single-OC accumulation over C in 8-lane blocks + tail.
+    // Avoid callee-saved registers and any in-kernel spatial loops to ensure ABI safety on macOS arm64.
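[Editor's note] A minimal scalar sketch of the per-call contract the kernel body below implements, assuming the non-packed layout in which src and wei each advance by a fixed byte stride per channel (the packed wei_stride == 2 fast path instead jumps the weight pointer by wei_blk_stride after each 8-lane block). The dual-OC variant runs the same accumulation against wei2/acc2 in parallel. The helper name is hypothetical; parameter names mirror the jit_conv3d_call_args fields loaded below.

    static void conv3d_channel_dot_ref(const uint16_t* src, const uint16_t* wei,
                                       size_t repeats, size_t tail,
                                       size_t src_stride, size_t wei_stride, float* acc) {
        float sum = 0.0f;
        const size_t channels = repeats * 8 + tail;  // full 8-lane blocks plus remainder
        for (size_t c = 0; c < channels; ++c) {
            // fmlal/fmlal2 widen the fp16 operands to fp32 before multiply-accumulate
            sum += static_cast<float>(ov::float16::from_bits(*src)) *
                   static_cast<float>(ov::float16::from_bits(*wei));
            src = reinterpret_cast<const uint16_t*>(reinterpret_cast<const uint8_t*>(src) + src_stride);
            wei = reinterpret_cast<const uint16_t*>(reinterpret_cast<const uint8_t*>(wei) + wei_stride);
        }
        *acc += sum;  // the kernel accumulates into *acc rather than overwriting it
    }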
+ + { + const XReg reg_args = abi_param1; // x0 + // Load essential arguments (absolute offsets from jit_conv3d_call_args) + const XReg reg_src = x1; // const uint16_t* src + const XReg reg_wei = x2; // const uint16_t* wei + const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats + const XReg reg_tail = x5; // size_t tail + const XReg reg_src_stride = x6; // size_t src_stride (bytes) + const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) + const XReg reg_acc = x8; // float* acc + const XReg reg_acc2 = x9; // float* acc2 (optional) + + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, ptr(reg_args, 40)); + ldr(reg_tail, ptr(reg_args, 48)); + ldr(reg_src_stride, ptr(reg_args, 56)); + ldr(reg_wei_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 88)); + ldr(reg_acc2, ptr(reg_args, 96)); + + Label Lsingle, Ldone; + // Additional labels for kx-loop variants + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + cbz(reg_acc2, Lsingle); + // Jump to in-kernel kx loop dual-OC path (safe, call-clobbered only) + b(Ldual_kx); + + // Dual-OC with in-kernel kx loop (v20 for oc0, v21 for oc1) + L(Ldual_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + // Load kx-loop controls and set bases + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_kw_cnt, ptr(reg_args, 104)); + ldr(reg_src_dx, ptr(reg_args, 112)); + ldr(reg_wei_dx, ptr(reg_args, 120)); + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; + const XReg reg_wei_blk_stride2 = x10; + ldr(reg_wei_blk_stride2, ptr(reg_args, 80)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + // Treat kw_cnt==0 as 1 + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + L(Lkx_d); + // Reset pointers and repeats for this kx + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + // Channel repeats over 8-lane blocks + Label Lrep_d_kx; + L(Lrep_d_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + // Load src lanes into v0 + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load weights for oc0/oc1 (vector fast path if stride==2) + Label Lw_np_d, Lw_done_d2; + cmp(reg_wei_stride, 2); + b(NE, Lw_np_d); + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_blk_stride2); + add(reg_wei2, reg_wei2, reg_wei_blk_stride2); + b(Lw_done_d2); + L(Lw_np_d); + ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, 
reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Lw_done_d2); + // MAC into accumulators + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d_kx); + // Tail handling per kx + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // advance bases to next kx and continue + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + 
add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + // reduce and store accumulators + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + // Dual-OC path: v20 (oc0), v21 (oc1) + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + + Label Lrep_d, Ltail_prep_d, Ltail_done_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d); + // Load src lanes (v0) + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load wei lanes for oc0 (v1) and oc1 (v2) + ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + // MAC into v20/v21 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); + + // Tail handling + L(Ltail_prep_d); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + // lanes 0..7 guarded by tail + cmp(reg_tail, 0); b(LE, Ltail_done_d); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_d); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_d); + ld1(VReg(0).h[2], 
ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_d); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_d); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_d); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_d); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_d); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + + L(Ltail_done_d); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // horizontal add and store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + // Single-OC path + L(Lsingle); + // Jump to in-kernel kx loop single-OC path + b(Lsingle_kx); + // Single-OC with in-kernel kx loop + L(Lsingle_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + const XReg s_kw_cnt = x12; + const XReg s_src_dx = x13; + const XReg s_wei_dx = x14; + ldr(s_kw_cnt, ptr(reg_args, 104)); + ldr(s_src_dx, ptr(reg_args, 112)); + ldr(s_wei_dx, ptr(reg_args, 120)); + const XReg s_src_base = x15; + const XReg s_wei_base = x16; + const XReg s_wei_blk_stride2 = x10; + ldr(s_wei_blk_stride2, ptr(reg_args, 80)); + mov(s_src_base, reg_src); + mov(s_wei_base, reg_wei); + cbnz(s_kw_cnt, Lkx_s); + mov(s_kw_cnt, 1); + Label Lrep_s_kx; + L(Lkx_s); + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, s_src_base); + mov(reg_wei, s_wei_base); + L(Lrep_s_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // weights (vector fast path if stride==2) + Label Lw_np_s, Lw_done_s2; + cmp(reg_wei_stride, 2); + b(NE, Lw_np_s); + ld1(VReg8H(1), ptr(reg_wei)); + add(reg_wei, 
reg_wei, s_wei_blk_stride2); + b(Lw_done_s2); + L(Lw_np_s); + ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + L(Lw_done_s2); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s_kx); + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s_kx); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // advance bases + sub(s_kw_cnt, s_kw_cnt, 1); + add(s_src_base, s_src_base, s_src_dx); + add(s_wei_base, s_wei_base, s_wei_dx); + cbnz(s_kw_cnt, Lkx_s); + // reduce/store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + Label Lrep_s, Ltail_prep_s, Ltail_done_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s); + // src lanes + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // wei lanes + ld1(VReg(1).h[0], 
ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + // Tail (single) + L(Ltail_prep_s); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); b(LE, Ltail_done_s); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_s); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + + L(Ldone); + ret(); + } + return; + + // abi_param1 -> args pointer + const XReg reg_args = abi_param1; + // Use call-clobbered registers to avoid saving/restoring callee-saved regs + const XReg reg_src = x1; + const XReg reg_wei = x2; + const XReg reg_wei2 = x10; + const XReg reg_reps = x3; // number of full 8-lane blocks + const XReg reg_tail = x9; // remaining channels (< 8) + const XReg reg_src_stride = x4; + const XReg reg_wei_stride = x5; + const XReg reg_acc = x6; + const XReg reg_acc2 = x11; + const XReg reg_wei3 = x19; + const XReg reg_wei4 = x20; + const XReg reg_acc3 = x21; + const XReg reg_acc4 = x22; + + // Prolog: save callee-saved we will use (x19-x29) and LR (x30) + stp(XReg(19), XReg(20), pre_ptr(sp, -16)); + stp(XReg(21), XReg(22), pre_ptr(sp, -16)); + stp(XReg(23), XReg(24), pre_ptr(sp, -16)); + stp(XReg(25), XReg(26), pre_ptr(sp, -16)); + stp(XReg(27), XReg(28), pre_ptr(sp, -16)); + stp(XReg(29), XReg(30), pre_ptr(sp, -16)); + + // Load args + ldr(reg_src, ptr(reg_args)); // src + ldr(reg_wei, ptr(reg_args, 
8)); // wei + ldr(reg_wei2, ptr(reg_args, 16)); // wei2 (optional) + ldr(reg_wei3, ptr(reg_args, 24)); // wei3 (optional) + ldr(reg_wei4, ptr(reg_args, 32)); // wei4 (optional) + ldr(reg_reps, ptr(reg_args, 40)); // repeats + ldr(reg_tail, ptr(reg_args, 48)); // tail (<= 8) + ldr(reg_src_stride, ptr(reg_args, 56)); // src_stride bytes + ldr(reg_wei_stride, ptr(reg_args, 64)); // wei_stride bytes + const XReg reg_src_blk_stride = x7; + const XReg reg_wei_blk_stride = x8; + ldr(reg_src_blk_stride, ptr(reg_args, 72)); // src_blk_stride bytes + ldr(reg_wei_blk_stride, ptr(reg_args, 80)); // wei_blk_stride bytes + ldr(reg_acc, ptr(reg_args, 88)); // acc (float*) + ldr(reg_acc2, ptr(reg_args, 96)); // acc2 (float* or 0) + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_kw_cnt, ptr(reg_args, 104)); // kw count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dx, ptr(reg_args, 112)); // src dx step in bytes (x dimension) + ldr(reg_wei_dx, ptr(reg_args, 120)); // wei dx step in bytes (x dimension) + ldr(reg_acc3, ptr(reg_args, 128)); // acc3 (float* or 0) + ldr(reg_acc4, ptr(reg_args, 136)); // acc4 (float* or 0) + const XReg reg_kh_cnt = x26; + const XReg reg_src_dy = x27; + const XReg reg_wei_dy = x28; + ldr(reg_kh_cnt, ptr(reg_args, 144)); // kh count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) + ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) + + // Force single-ky iteration inside kernel to simplify and ensure stability + mov(reg_kh_cnt, 1); + // Disable quad-OC path for stability (use dual or single) + eor(reg_acc4, reg_acc4, reg_acc4); + + Label Lsingle, Lend_all; + // If acc4 != 0, run quad-OC; else if acc2 != 0, run dual-OC; else single-OC. 
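[Editor's note] The ldr offsets above pin down the layout of the argument block. Reconstructed here as a sketch (field names are assumptions; the authoritative definition is jit_conv3d_call_args in jit_conv3d.hpp):

    struct jit_conv3d_call_args {  // byte offsets used by the loads above
        const uint16_t* src;       // +0
        const uint16_t* wei;       // +8    weights for oc0
        const uint16_t* wei2;      // +16   oc1 weights; nullptr selects single-OC
        const uint16_t* wei3;      // +24   oc2 weights (quad path)
        const uint16_t* wei4;      // +32   oc3 weights (quad path, disabled above)
        size_t repeats;            // +40   full 8-channel blocks
        size_t tail;               // +48   remaining channels (< 8)
        size_t src_stride;         // +56   bytes between channels in src
        size_t wei_stride;         // +64   bytes between channels in wei (2 == packed)
        size_t src_blk_stride;     // +72   bytes to the next 8-channel src block
        size_t wei_blk_stride;     // +80   bytes to the next 8-channel wei block
        float* acc;                // +88   fp32 accumulator for oc0 (read-modify-write)
        float* acc2;               // +96   oc1 accumulator, or nullptr
        size_t kw_cnt;             // +104  kx trip count (0 is treated as 1)
        size_t src_dx;             // +112  src advance per kx, bytes
        size_t wei_dx;             // +120  wei advance per kx, bytes
        float* acc3;               // +128  oc2 accumulator, or nullptr
        float* acc4;               // +136  oc3 accumulator (zeroed above for stability)
        size_t kh_cnt;             // +144  ky trip count (forced to 1 above)
        size_t src_dy;             // +152  src advance per ky, bytes
        size_t wei_dy;             // +160  wei advance per ky, bytes
    };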
+ Label Lq_entry; + cbnz(reg_acc4, Lq_entry); + cbz(reg_acc2, Lsingle); + + // ---------------- Quad-OC path ---------------- + L(Lq_entry); + { + // Zero v20..v23 + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + eor(VReg16B(22), VReg16B(22), VReg16B(22)); + eor(VReg16B(23), VReg16B(23), VReg16B(23)); + + Label Lq_ky_loop, Lq_kx_loop, Lq_loop, Lq_after_loop, Lq_after_fill, Lq_after_kx, Lq_np; + // Packed fast path if wei_stride == 2 + cmp(reg_wei_stride, 2); + b(NE, Lq_np); + + // Save repeats and bases for kw loop + const XReg q_reps_init = x15; + const XReg q_src_base = x16; + const XReg q_wei_base = x17; + // avoid x18 on Apple; use callee-saved and restore at epilog + const XReg q_wei2_base = x23; + const XReg q_wei3_base = x24; + const XReg q_wei4_base = x25; + const XReg q_kw_init = x28; + const XReg q_kh_work = x29; + mov(q_reps_init, reg_reps); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + mov(q_wei3_base, reg_wei3); + mov(q_wei4_base, reg_wei4); + mov(q_kw_init, reg_kw_cnt); + mov(q_kh_work, reg_kh_cnt); + + // ky loop + L(Lq_ky_loop); + // kx loop entry + L(Lq_kx_loop); + mov(reg_reps, q_reps_init); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + mov(reg_wei3, q_wei3_base); + mov(reg_wei4, q_wei4_base); + mov(reg_kw_cnt, q_kw_init); + + // Main repeats loop + L(Lq_loop); + cmp(reg_reps, 0); + b(LE, Lq_after_loop); + // Load 8 src lanes + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load 4 weight vectors + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + ld1(VReg8H(3), ptr(reg_wei3)); + ld1(VReg8H(4), ptr(reg_wei4)); + // Accumulate into v20..v23 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + // Advance block pointers (src already advanced by 8*src_stride during lane loads) + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + add(reg_wei3, reg_wei3, reg_wei_blk_stride); + add(reg_wei4, reg_wei4, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Lq_loop); + + // Tail <=8 + L(Lq_after_loop); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + eor(VReg16B(3), VReg16B(3), VReg16B(3)); + eor(VReg16B(4), VReg16B(4), VReg16B(4)); + cmp(reg_tail, 0); b(LE, Lq_after_fill); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); ld1(VReg(3).h[0], ptr(reg_wei3)); ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, 
reg_wei_stride); + cmp(reg_tail, 1); b(LE, Lq_after_fill); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); ld1(VReg(3).h[1], ptr(reg_wei3)); ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Lq_after_fill); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); ld1(VReg(3).h[2], ptr(reg_wei3)); ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Lq_after_fill); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); ld1(VReg(3).h[3], ptr(reg_wei3)); ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Lq_after_fill); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); ld1(VReg(3).h[4], ptr(reg_wei3)); ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Lq_after_fill); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); ld1(VReg(3).h[5], ptr(reg_wei3)); ld1(VReg(4).h[5], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Lq_after_fill); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); ld1(VReg(3).h[6], ptr(reg_wei3)); ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Lq_after_fill); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); ld1(VReg(3).h[7], ptr(reg_wei3)); ld1(VReg(4).h[7], ptr(reg_wei4)); + L(Lq_after_fill); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + + // advance bases for next kx + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + add(q_wei3_base, q_wei3_base, reg_wei_dx); + add(q_wei4_base, q_wei4_base, reg_wei_dx); + subs(reg_kw_cnt, reg_kw_cnt, 1); + b(GT, Lq_kx_loop); + + // reduce/store 4 accumulators + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), 
VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + // advance bases for next ky and continue if any + add(q_src_base, q_src_base, reg_src_dy); + add(q_wei_base, q_wei_base, reg_wei_dy); + add(q_wei2_base, q_wei2_base, reg_wei_dy); + add(q_wei3_base, q_wei3_base, reg_wei_dy); + add(q_wei4_base, q_wei4_base, reg_wei_dy); + subs(q_kh_work, q_kh_work, 1); + b(GT, Lq_ky_loop); + + // After ky loop: reduce/store + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); fadd(SReg(2), SReg(2), SReg(22)); str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); fadd(SReg(3), SReg(3), SReg(23)); str(SReg(3), ptr(reg_acc4)); + b(Lend_all); + + // Not-packed fallback for quad: do lane-wise loads for 4 outputs + L(Lq_np); + // Zero v20..v23 + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + eor(VReg16B(22), VReg16B(22), VReg16B(22)); + eor(VReg16B(23), VReg16B(23), VReg16B(23)); + // single block (not looped here) — rely on host collapsing work + // lanes 0..7 + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); ld1(VReg(3).h[0], ptr(reg_wei3)); ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); ld1(VReg(3).h[1], ptr(reg_wei3)); ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); ld1(VReg(3).h[2], ptr(reg_wei3)); ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); ld1(VReg(3).h[3], ptr(reg_wei3)); ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); ld1(VReg(3).h[4], ptr(reg_wei3)); ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); ld1(VReg(3).h[5], ptr(reg_wei3)); ld1(VReg(4).h[5], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); 
add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); ld1(VReg(3).h[6], ptr(reg_wei3)); ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); ld1(VReg(3).h[7], ptr(reg_wei3)); ld1(VReg(4).h[7], ptr(reg_wei4)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + // reduce/store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); fadd(SReg(2), SReg(2), SReg(22)); str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); fadd(SReg(3), SReg(3), SReg(23)); str(SReg(3), ptr(reg_acc4)); + b(Lend_all); + } + // ---------------- Dual-OC path ---------------- + { + // Zero FP32 accumulators v20.4s and v21.4s + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + Label Lnp_d, Ld_after_fill, Ld_after_loop, Ld_kx_loop, Ld_after_kx; + // If not packed, fallback to non-packed path below (no kw-loop optimization) + cmp(reg_wei_stride, 2); + b(NE, Lnp_d); + // Save initial repeats and bases + const XReg reg_reps_init = x15; + const XReg reg_src_base = x16; + const XReg reg_wei_base = x17; + const XReg reg_wei2_base = x18; + mov(reg_reps_init, reg_reps); + mov(reg_src_base, reg_src); + mov(reg_wei_base, reg_wei); + mov(reg_wei2_base, reg_wei2); + const XReg reg_kw_init = x28; + const XReg reg_kh_work = x29; + mov(reg_kw_init, reg_kw_cnt); + mov(reg_kh_work, reg_kh_cnt); + // ky-loop wrapper around kx-loop (packed fast path) + Label Ld_ky_loop; + L(Ld_ky_loop); + // Reset per-ky state and restore kw counter + mov(reg_reps, reg_reps_init); + mov(reg_src, reg_src_base); + mov(reg_wei, reg_wei_base); + mov(reg_wei2, reg_wei2_base); + mov(reg_kw_cnt, reg_kw_init); + // kw-loop (only if kw_cnt > 0); if kw_cnt==0 -> process a single position + L(Ld_kx_loop); + // Main loop over full 8-lane channel blocks + Label Ld_loop; + L(Ld_loop); + cmp(reg_reps, 0); + b(LE, Ld_after_loop); + // Load 8 src half lanes with strides between channels + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, 
reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load 8 half weights for oc0 and oc1 as vectors + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + // MAC for oc0 → v20, oc1 → v21 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance pointers to next block (src already advanced by 8*src_stride) + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Ld_loop); + // Tail processing (<=8) + L(Ld_after_loop); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ld_after_fill); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ld_after_fill); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ld_after_fill); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ld_after_fill); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ld_after_fill); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ld_after_fill); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ld_after_fill); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ld_after_fill); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ld_after_fill); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance base pointers for next kx + add(reg_src_base, reg_src_base, reg_src_dx); + add(reg_wei_base, reg_wei_base, reg_wei_dx); + add(reg_wei2_base, reg_wei2_base, reg_wei_dx); + // Decrement kw_cnt and loop + subs(reg_kw_cnt, reg_kw_cnt, 1); + b(GT, Ld_kx_loop); + // After kx loop for one ky, advance to next ky if any + add(reg_src_base, reg_src_base, reg_src_dy); + add(reg_wei_base, reg_wei_base, reg_wei_dy); + add(reg_wei2_base, reg_wei2_base, reg_wei_dy); + subs(reg_kh_work, reg_kh_work, 1); + b(GT, Ld_ky_loop); + // Fallthrough to store + b(Ld_after_kx); + 
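[Editor's note] The widening MACs used throughout rely on the AArch64 FP16 extension (FEAT_FHM): fmlal multiplies the lower four fp16 lanes of each operand and accumulates into four fp32 lanes, while fmlal2 consumes the upper four lanes, so each fmlal/fmlal2 pair folds all eight fp16 products of a block into one .4S accumulator without intermediate fp16 rounding. A scalar sketch of one pair plus the faddp reduction emitted at the store sites (the helper is hypothetical):

    static float fmlal_pair_then_reduce(float acc4[4], const ov::float16 a[8], const ov::float16 b[8]) {
        for (int i = 0; i < 4; ++i)  // fmlal: lower halves, widened to fp32
            acc4[i] += static_cast<float>(a[i]) * static_cast<float>(b[i]);
        for (int i = 0; i < 4; ++i)  // fmlal2: upper halves
            acc4[i] += static_cast<float>(a[4 + i]) * static_cast<float>(b[4 + i]);
        // the two faddp's then pairwise-add the four lanes down to one scalar
        return (acc4[0] + acc4[1]) + (acc4[2] + acc4[3]);
    }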
+ // Not-packed path: pairwise lane loads for both wei/wei2 (no kw-loop optimization) + L(Lnp_d); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + // Accumulate + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // Advance to next block + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + sub(reg_reps, reg_reps, 1); + b(Ld_loop); + + // Reduce and store both accumulators + L(Ld_after_kx); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Lend_all); + } + + // ---------------- Single-OC path (default) ---------------- + L(Lsingle); + // Zero FP32 accumulator v20.4s + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + + Label Lloop2, Lloop, Lafter_loop, Lafter_fill, Lnp1, Lend1, Lnp2, Lend2, Lnot_packed, Lkx_loop_s, Lafter_kx_s, Ls_loop, Ls_after_loop, Ls_after_fill; + + // Unrolled-by-2 loop over full 8-lane channel blocks (default single position) + b(Lloop); + L(Lloop2); + cmp(reg_reps, 2); + b(LT, Lloop); + + // First block (packed fast path if available) + cmp(reg_wei_stride, 2); + b(NE, Lnp1); + // packed: src lanes (8), then vector-load weights + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, 
reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg8H(1), ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lend1); + L(Lnp1); + // not packed: pairwise lanes + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lend1); + + // Second block + cmp(reg_wei_stride, 2); + b(NE, Lnp2); + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg8H(1), ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lend2); + L(Lnp2); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, 
reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lend2); + + sub(reg_reps, reg_reps, 2); + b(Lloop2); + + // Single-block loop for remaining one block + L(Lloop); + cmp(reg_reps, 0); + b(EQ, Lafter_loop); + + // Prepare containers for src/wei half vectors in v0/v1 (fully overwritten by loads) + + // Choose packed-weight fast path if wei_stride == 2 bytes + cmp(reg_wei_stride, 2); + b(NE, Lnot_packed); + + // Packed path: fill src lanes, vector-load wei + ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // load 8 half wei as one vector + ld1(VReg8H(1), ptr(reg_wei)); + // Multiply-accumulate (fp16 widen to fp32): lower and upper halves + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // Advance pointers to next block + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_blk_stride); + sub(reg_reps, reg_reps, 1); + b(Lloop); + + // Not-packed path: fill src/wei lanes pairwise and accumulate + L(Lnot_packed); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + // Widening MAC for both halves + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // Advance pointers to next block (src already advanced by 8*src_stride) + add(reg_wei, reg_wei, reg_wei_stride); + sub(reg_reps, reg_reps, 1); + b(Lloop); + + L(Lafter_loop); + + // Tail processing (<= 8) + // Prepare containers for src/wei half vectors in v0/v1 + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + // lane 0 + cmp(reg_tail, 0); + b(LE, Lafter_fill); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + // lane 1 + cmp(reg_tail, 1); + b(LE, Lafter_fill); 
+    ld1(VReg(0).h[1], ptr(reg_src));
+    ld1(VReg(1).h[1], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 2
+    cmp(reg_tail, 2);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[2], ptr(reg_src));
+    ld1(VReg(1).h[2], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 3
+    cmp(reg_tail, 3);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[3], ptr(reg_src));
+    ld1(VReg(1).h[3], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 4
+    cmp(reg_tail, 4);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[4], ptr(reg_src));
+    ld1(VReg(1).h[4], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 5
+    cmp(reg_tail, 5);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[5], ptr(reg_src));
+    ld1(VReg(1).h[5], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 6
+    cmp(reg_tail, 6);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[6], ptr(reg_src));
+    ld1(VReg(1).h[6], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    // lane 7
+    cmp(reg_tail, 7);
+    b(LE, Lafter_fill);
+    ld1(VReg(0).h[7], ptr(reg_src));
+    ld1(VReg(1).h[7], ptr(reg_wei));
+
+    L(Lafter_fill);
+    // Accumulate tail using widening fp16 MAC
+    fmlal(VReg4S(20), VReg4H(0), VReg4H(1));
+    fmlal2(VReg4S(20), VReg4H(0), VReg4H(1));
+
+    // Horizontal add v20
+    faddp(VReg4S(20), VReg4S(20), VReg4S(20));
+    faddp(VReg2S(20), VReg2S(20), VReg2S(20));
+
+    // Load *acc, add, store back
+    ldr(SReg(0), ptr(reg_acc));
+    fadd(SReg(0), SReg(0), SReg(20));
+    str(SReg(0), ptr(reg_acc));
+
+    L(Lend_all);
+    // Epilog: restore callee-saved
+    ldp(XReg(29), XReg(30), post_ptr(sp, 16));
+    ldp(XReg(27), XReg(28), post_ptr(sp, 16));
+    ldp(XReg(25), XReg(26), post_ptr(sp, 16));
+    ldp(XReg(23), XReg(24), post_ptr(sp, 16));
+    ldp(XReg(21), XReg(22), post_ptr(sp, 16));
+    ldp(XReg(19), XReg(20), post_ptr(sp, 16));
+    ret();
+}
+
+// --------------------------- Executor ---------------------------
+
+static inline const uint16_t* ptr_f16(const MemoryPtr& m) {
+    return reinterpret_cast<const uint16_t*>(m->getData());
+}
+static inline uint16_t* ptr_f16(MemoryPtr& m) {
+    return reinterpret_cast<uint16_t*>(m->getData());
+}
+
+JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs,
+                                     const MemoryArgs& memory,
+                                     const ExecutorContext::CPtr& context)
+    : m_attrs(attrs), m_memory(memory) {
+    (void)context;
+    m_threadsNum = static_cast<size_t>(parallel_get_max_threads());
+    // Initialize minimal inner-product kernel (scaffold)
+    m_ip_kernel = std::make_unique<JitConv3DKernelF16>();
+    m_ip_kernel->create_ker();
+
+    // Extract optional PReLU post-op (per-tensor or per-channel), but keep disabled by default.
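+    // (PostOps entries are assumed to be type-erased std::any values here, so a
+    // typed std::any_cast probe per entry suffices; only the first PReLU is
+    // captured, and applying it stays gated behind m_apply_post_ops.)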
+    for (const auto& po : m_attrs.postOps) {
+        if (const auto* const ss = std::any_cast<ScaleShiftPostOp>(&po)) {
+            if (ss->type() == ScaleShiftPostOp::Type::prelu) {
+                m_has_prelu = true;
+                m_prelu_slopes = ss->scales();
+                break;
+            }
+        }
+    }
+}
+
+bool JitConv3DExecutor::supports(const ConvConfig& cfg) {
+    // Require 5D NCDHW, FP16 src/wei/dst, group=1, no dilation, stride 1 or 2
+    if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false;
+    if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) return false;
+
+    const auto& s = cfg.descs.at(ARG_SRC)->getShape();
+    const auto& w = cfg.descs.at(ARG_WEI)->getShape();
+    const auto& d = cfg.descs.at(ARG_DST)->getShape();
+    if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) return false;
+
+    const auto sp = cfg.descs.at(ARG_SRC)->getPrecision();
+    const auto wp = cfg.descs.at(ARG_WEI)->getPrecision();
+    const auto dp = cfg.descs.at(ARG_DST)->getPrecision();
+    if (!(sp == ov::element::f16 && wp == ov::element::f16 && dp == ov::element::f16)) return false;
+
+    // group == 1: weights rank==5 (no groups)
+    if (w.getRank() != 5) return false;
+
+    // dilation == 0
+    for (auto v : cfg.attrs.dilation) {
+        if (v != 0) return false;
+    }
+    // stride in [1,2] if set
+    for (auto v : cfg.attrs.stride) {
+        if (!(v == 1 || v == 2)) return false;
+    }
+    return true;
+}
+
+void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) {
+    // NCDHW
+    auto src = memory.at(ARG_SRC);
+    auto wei = memory.at(ARG_WEI);
+    auto dst = memory.at(ARG_DST);
+    const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
+    const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
+    const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims();
+
+    const size_t N = srcDims[0];
+    const size_t C = srcDims[1];
+    const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4];
+    const size_t OC = weiDims[0];
+    const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
+    const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4];
+
+    const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1;
+    const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1;
+    const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1;
+
+    const size_t PD0 = m_attrs.paddingL.size() > 0 ? static_cast<size_t>(m_attrs.paddingL[0]) : 0;
+    const size_t PH0 = m_attrs.paddingL.size() > 1 ? static_cast<size_t>(m_attrs.paddingL[1]) : 0;
+    const size_t PW0 = m_attrs.paddingL.size() > 2 ? 
static_cast(m_attrs.paddingL[2]) : 0; + + const uint16_t* src_p = ptr_f16(src); + const uint16_t* wei_p = ptr_f16(wei); + uint16_t* dst_p = ptr_f16(dst); + + auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) -> size_t { + return (((n * C + c) * ID + z) * IH + y) * IW + x; + }; + auto index_dst = [&](size_t n, size_t oc, size_t z, size_t y, size_t x) -> size_t { + return (((n * OC + oc) * OD + z) * OH + y) * OW + x; + }; + auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + }; + + // Prepare packed weights once + ensure_weights_packed(memory); + + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = oc0 + 1; + const size_t oc2 = oc0 + 2; + const size_t oc3 = oc0 + 3; + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + for (size_t od = 0; od < OD; ++od) { + const int64_t iz0 = static_cast(od * SD) - static_cast(PD0); + for (size_t oh = 0; oh < OH; ++oh) { + const int64_t iy0 = static_cast(oh * SH) - static_cast(PH0); + for (size_t ow = 0; ow < OW; ++ow) { + const int64_t ix0 = static_cast(ow * SW) - static_cast(PW0); + + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + const size_t src_c_stride_elems = ID * IH * IW; + const size_t wei_c_stride_elems = KD * KH * KW; + + if (SD == 1 && SH == 1 && SW == 1) { + const ptrdiff_t kz_lo = std::max(0, -iz0); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, + static_cast(ID) - 1 - iz0); + const ptrdiff_t ky_lo = std::max(0, -iy0); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, + static_cast(IH) - 1 - iy0); + const ptrdiff_t kx_lo = std::max(0, -ix0); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, + static_cast(IW) - 1 - ix0); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(iz0 + kz); + const size_t kh_count = static_cast(ky_hi - ky_lo + 1); + const size_t iy = static_cast(iy0 + ky_lo); + const size_t ix = static_cast(ix0 + kx_lo); + const size_t s_base0 = index_src(n, 0, iz, iy, ix); + + if (m_wei_packed_ready) { + // Loop over ky in host; kernel handles kx via kw_cnt + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy2 = static_cast(iy0 + ky); + const size_t ix2 = static_cast(ix0 + kx_lo); + const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); + jit_conv3d_call_args a{}; + a.src = src_p + s_base2; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t pack_base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei = m_wei_packed.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base1; + } + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_C * sizeof(uint16_t); + (*m_ip_kernel)(&a); + if (has_oc2) { + jit_conv3d_call_args a2{}; + a2.src = src_p + s_base2; + a2.src_stride = a.src_stride; + a2.src_blk_stride = a.src_blk_stride; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? 
&acc3 : nullptr; + a2.repeats = a.repeats; + a2.tail = a.tail; + a2.kw_cnt = a.kw_cnt; + a2.src_dx = a.src_dx; + const size_t pack_base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a2.wei = m_wei_packed.data() + pack_base2; + if (has_oc3) { + const size_t pack_base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a2.wei2 = m_wei_packed.data() + pack_base3; + } + a2.wei_stride = a.wei_stride; + a2.wei_blk_stride = a.wei_blk_stride; + a2.wei_dx = a.wei_dx; + (*m_ip_kernel)(&a2); + } + } + } else { + // Non-packed: keep ky loop outside and issue dual calls + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy2 = static_cast(iy0 + ky); + const size_t ix2 = static_cast(ix0 + kx_lo); + const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); + const size_t w0_base = index_wei(oc0, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w1_base = has_oc1 ? index_wei(oc1, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base2; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = wei_p + w0_base; + if (has_oc1) a.wei2 = wei_p + w1_base; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel)(&a); + } + if (has_oc2) { + const size_t w2_base = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w3_base = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base2; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = wei_p + w2_base; + if (has_oc3) a.wei2 = wei_p + w3_base; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel)(&a); + } + } + } + } + } + } else { + for (size_t kz = 0; kz < KD; ++kz) { + const int64_t iz = iz0 + static_cast(kz); + if (iz < 0 || iz >= static_cast(ID)) continue; + for (size_t ky = 0; ky < KH; ++ky) { + const int64_t iy = iy0 + static_cast(ky); + if (iy < 0 || iy >= static_cast(IH)) continue; + for (size_t kx = 0; kx < KW; ++kx) { + const int64_t ix = ix0 + static_cast(kx); + if (ix < 0 || ix >= static_cast(IW)) continue; + const size_t s_base0 = index_src(n, 0, static_cast(iz), static_cast(iy), static_cast(ix)); + const size_t w0_base = index_wei(oc0, 0, kz, ky, kx); + const size_t w1_base = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; // used logically, kernel advances by stride once after 8 lanes + a.acc = &acc0; + a.acc2 = has_oc1 ? 
&acc1 : nullptr; + + if (m_wei_packed_ready) { + // packed index: ((((oc*KD + kz)*KH + ky)*KW + kx)*paddedC) + const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei = m_wei_packed.data() + pack_base0; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = sizeof(uint16_t); // contiguous halves + a.wei_blk_stride = a.wei_stride * 8; // logical + if (has_oc1) { + const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base1; + } + (*m_ip_kernel)(&a); + } else { + a.wei = wei_p + w0_base; + if (has_oc1) a.wei2 = wei_p + w1_base; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + (*m_ip_kernel)(&a); + } + } + // pair 1 + if (has_oc2) { + const size_t w2_base = index_wei(oc2, 0, kz, ky, kx); + const size_t w3_base = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; // used logically, kernel advances by stride once after 8 lanes + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + if (m_wei_packed_ready) { + const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei = m_wei_packed.data() + pack_base2; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = sizeof(uint16_t); // contiguous halves + a.wei_blk_stride = a.wei_stride * 8; // logical + if (has_oc3) { + const size_t pack_base3 = (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base3; + } + (*m_ip_kernel)(&a); + } else { + a.wei = wei_p + w2_base; + if (has_oc3) a.wei2 = wei_p + w3_base; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + (*m_ip_kernel)(&a); + } + } + } + } + } + } + // Optional fused bias (disabled by default) + if (m_apply_post_ops && m_attrs.withBias && memory.count(ARG_BIAS) && memory.at(ARG_BIAS)) { + auto bia = memory.at(ARG_BIAS); + const auto bprec = bia->getDescPtr()->getPrecision(); + if (bprec == ov::element::f32) { + const float* b = reinterpret_cast(bia->getData()); + acc0 += b[oc0]; + if (has_oc1) acc1 += b[oc1]; + if (has_oc2) acc2 += b[oc2]; + if (has_oc3) acc3 += b[oc3]; + } else if (bprec == ov::element::f16) { + const uint16_t* b = reinterpret_cast(bia->getData()); + acc0 += static_cast(ov::float16(b[oc0])); + if (has_oc1) acc1 += static_cast(ov::float16(b[oc1])); + if (has_oc2) acc2 += static_cast(ov::float16(b[oc2])); + if (has_oc3) acc3 += static_cast(ov::float16(b[oc3])); + } + } + + // Optional fused PReLU (apply after bias) — disabled by default + if (m_apply_post_ops && m_has_prelu && !m_prelu_slopes.empty()) { + const auto slope_at = [&](size_t oc) -> float { + return m_prelu_slopes.size() == 1 ? 
m_prelu_slopes[0] + : m_prelu_slopes[std::min(oc, m_prelu_slopes.size() - 1)]; + }; + const float s0 = slope_at(oc0); + if (acc0 < 0.f) acc0 *= s0; + if (has_oc1) { const float s1 = slope_at(oc1); if (acc1 < 0.f) acc1 *= s1; } + if (has_oc2) { const float s2 = slope_at(oc2); if (acc2 < 0.f) acc2 *= s2; } + if (has_oc3) { const float s3 = slope_at(oc3); if (acc3 < 0.f) acc3 *= s3; } + } + + dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits(); + if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits(); + if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits(); + if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits(); + } + } + } + }); +} + +void JitConv3DExecutor::execute(const MemoryArgs& memory) { + run_naive_fp16(memory); +} + +} // namespace ov::intel_cpu +void ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& memory) { + if (m_wei_packed_ready) return; + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + if (srcDims.size() != 5 || weiDims.size() != 5) return; + const size_t C = srcDims[1]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_C = (C + 7) / 8 * 8; + // Pack layout: [OC, KD, KH, KW, Ct] + const size_t total = OC * KD * KH * KW * m_padded_C; + m_wei_packed.assign(total, static_cast(0)); + const uint16_t* wsrc = reinterpret_cast(wei->getData()); + + auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t blk = c / 8; + const size_t lane = c % 8; + return base + blk * 8 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t c = 0; c < C; ++c) { + m_wei_packed[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready = true; +} diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp new file mode 100644 index 00000000000000..5941b378eb2f1d --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "nodes/executors/convolution_config.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "onednn/iml_type_mapper.h" + +// Xbyak AArch64 JIT +#include + +namespace ov::intel_cpu { + +struct jit_conv3d_call_args { + const uint16_t* src; // f16 base ptr + const uint16_t* wei; // f16 base ptr + const uint16_t* wei2; // optional second oc f16 base ptr (can be null) + const uint16_t* wei3; // optional third oc f16 base ptr (can be null) + const uint16_t* wei4; // optional fourth oc f16 base ptr (can be null) + size_t repeats; // number of full 8-channel blocks + size_t tail; // remaining channels (< 8) + size_t src_stride; // stride between channels in bytes + size_t 
wei_stride; // stride between channels in bytes + size_t src_blk_stride; // stride between successive 8-channel blocks in bytes + size_t wei_blk_stride; // stride between successive 8-channel blocks in bytes + float* acc; // f32 accumulator + float* acc2; // optional second f32 accumulator (can be null) + size_t kw_cnt; // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single + size_t src_dx; // bytes to advance src base between successive kx taps + size_t wei_dx; // bytes to advance weights base between successive kx taps + float* acc3; // optional third f32 accumulator (can be null) + float* acc4; // optional fourth f32 accumulator (can be null) + size_t kh_cnt; // number of taps along H to iterate (stride=1 fast path); 0 or 1 -> single + size_t src_dy; // bytes to advance src base between successive ky taps + size_t wei_dy; // bytes to advance weights base between successive ky taps +}; + +class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF16) + using jit_fn = void (*)(const jit_conv3d_call_args*); + + JitConv3DKernelF16(); + + void create_ker(); + inline void operator()(const jit_conv3d_call_args* p) const { + ker_(p); + } + +private: + void generate() override; + + jit_fn ker_{nullptr}; +}; + +// AArch64 JIT Convolution (FP16) executor for 3D conv (NCDHW) +class JitConv3DExecutor : public Executor { +public: + JitConv3DExecutor(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); + + bool update(const MemoryArgs& memory) override { + m_memory = memory; + return true; + } + void execute(const MemoryArgs& memory) override; + void execute() override { execute(m_memory); } + void exec([[maybe_unused]] const std::vector& src, + [[maybe_unused]] const std::vector& dst) override {} + + [[nodiscard]] impl_desc_type implType() const override { return impl_desc_type::jit_asimd; } + + static bool supports(const ConvConfig& cfg); + +private: + // Simple reference fallback (parallelized) using FP16 data; correctness-first + void run_naive_fp16(const MemoryArgs& memory); + void ensure_weights_packed(const MemoryArgs& memory); + + // Minimal inner-product kernel (fp16 x fp16 -> f32 accumulation) + std::unique_ptr m_ip_kernel; + + ConvAttrs m_attrs; + MemoryArgs m_memory; + size_t m_threadsNum{0}; + + // Packed weights: layout [OC, KD, KH, KW, Ct] where Ct is 8-lane channel tiles + std::vector m_wei_packed; + bool m_wei_packed_ready{false}; + size_t m_padded_C{0}; + + // Optional fused PReLU (per-tensor or per-channel). Extracted from attrs.postOps. + bool m_has_prelu{false}; + std::vector m_prelu_slopes; // size 1 (per-tensor) or OC (per-channel) + + // Gate executor-side post-ops (bias, PReLU). Disabled per user request for measurements. 
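+    // For reference, a minimal invocation sketch of the fp16 kernel declared
+    // above (illustrative values; field meanings per jit_conv3d_call_args):
+    //   jit_conv3d_call_args a{};
+    //   a.src = src_base;                               // f16 element at (n, c=0, iz, iy, ix)
+    //   a.wei = wei_base;                               // f16 weights for one oc
+    //   a.repeats = C / 8;                              // full 8-channel blocks
+    //   a.tail = C % 8;                                 // remaining lanes, zero-filled
+    //   a.src_stride = ID * IH * IW * sizeof(uint16_t); // channel pitch in bytes
+    //   a.wei_stride = sizeof(uint16_t);                // packed weights: contiguous
+    //   a.acc = &acc0;                                  // f32 accumulator updated in place
+    //   kernel(&a);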
+    bool m_apply_post_ops{false};
+};
+
+using JitConv3DExecutorPtr = std::shared_ptr<JitConv3DExecutor>;
+
+} // namespace ov::intel_cpu
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp
new file mode 100644
index 00000000000000..5cae3d6be3c1c5
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp
@@ -0,0 +1,339 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "nodes/executors/aarch64/jit_deconv3d.hpp"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+
+#include "cpu_memory.h"
+#include "memory_desc/cpu_memory_desc.h"
+#include "openvino/core/parallel.hpp"
+#include "openvino/core/type/element_type.hpp"
+#include "openvino/core/type/float16.hpp"
+
+namespace ov::intel_cpu {
+
+static inline const uint16_t* as_f16(const MemoryPtr& m) {
+    return reinterpret_cast<const uint16_t*>(m->getData());
+}
+static inline uint16_t* as_f16(MemoryPtr& m) {
+    return reinterpret_cast<uint16_t*>(m->getData());
+}
+
+bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs,
+                               const std::vector<MemoryDescPtr>& srcDescs,
+                               const std::vector<MemoryDescPtr>& dstDescs,
+                               const dnnl::primitive_attr& /*attr*/) {
+    deconvAttrs = attrs;
+    m_srcDescs = srcDescs;
+    m_dstDescs = dstDescs;
+    // Initialize AArch64 ip kernel (fp16 x fp16 -> f32 accum)
+    m_ip_kernel = std::make_unique<JitConv3DKernelF16>();
+    m_ip_kernel->create_ker();
+    return true;
+}
+
+void JitDeconv3DExecutor::ensure_weights_packed(const std::vector<MemoryCPtr>& src) {
+    if (m_wei_packed_ready) return;
+    // src[1] holds weights for deconv with shape [IC, OC, KD, KH, KW]
+    const auto& weiDims = src[1]->getStaticDims();
+    if (weiDims.size() != 5) return;
+    const size_t IC = weiDims[0];
+    const size_t OC = weiDims[1];
+    const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
+    m_padded_IC = (IC + 7) / 8 * 8;
+    const size_t total = OC * KD * KH * KW * m_padded_IC;
+    m_wei_packed.assign(total, static_cast<uint16_t>(0));
+    const uint16_t* wsrc = reinterpret_cast<const uint16_t*>(src[1]->getData());
+
+    auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t {
+        return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx;
+    };
+    auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t {
+        const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC;
+        const size_t blk = ic / 8;
+        const size_t lane = ic % 8;
+        return base + blk * 8 + lane;
+    };
+
+    for (size_t oc = 0; oc < OC; ++oc) {
+        for (size_t kz = 0; kz < KD; ++kz) {
+            for (size_t ky = 0; ky < KH; ++ky) {
+                for (size_t kx = 0; kx < KW; ++kx) {
+                    for (size_t ic = 0; ic < IC; ++ic) {
+                        m_wei_packed[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)];
+                    }
+                }
+            }
+        }
+    }
+    m_wei_packed_ready = true;
+}
+
+void JitDeconv3DExecutor::exec(const std::vector<MemoryCPtr>& src,
+                               const std::vector<MemoryPtr>& dst,
+                               const void* /*post_ops_data_*/) {
+    // NCDHW, fp16: compute each output pixel (n, oc, od, oh, ow) as a sum over (ic, kz, ky, kx)
+    const auto& srcDims = src[0]->getStaticDims();
+    const auto& weiDims = src[1]->getStaticDims();
+    const auto& dstDims = dst[0]->getStaticDims();
+
+    const size_t N = srcDims[0];
+    const size_t IC = srcDims[1];
+    const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4];
+    // Deconv weights layout: [IC, OC, KD, KH, KW]
+    const size_t OC = weiDims[1];
+    const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
+    const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4];
+
+    const size_t SD = deconvAttrs.stride.size() > 0 ? 
static_cast(deconvAttrs.stride[0]) : 1; + const size_t SH = deconvAttrs.stride.size() > 1 ? static_cast(deconvAttrs.stride[1]) : 1; + const size_t SW = deconvAttrs.stride.size() > 2 ? static_cast(deconvAttrs.stride[2]) : 1; + + const ptrdiff_t PD0 = deconvAttrs.paddingL.size() > 0 ? deconvAttrs.paddingL[0] : 0; + const ptrdiff_t PH0 = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL[1] : 0; + const ptrdiff_t PW0 = deconvAttrs.paddingL.size() > 2 ? deconvAttrs.paddingL[2] : 0; + + const auto* src_p = reinterpret_cast(src[0]->getData()); + const auto* wei_p = reinterpret_cast(src[1]->getData()); + auto* dst_p = reinterpret_cast(dst[0]->getData()); + + auto idx_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * IC + c) * ID + z) * IH + y) * IW + x; + }; + auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + // weight [IC, OC, KD, KH, KW] + auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { + return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + + // Strides in elements + const size_t src_c_stride_elems = ID * IH * IW; + const size_t wei_ic_stride_elems = OC * KD * KH * KW; + + ensure_weights_packed(src); + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = oc0 + 1; + const size_t oc2 = oc0 + 2; + const size_t oc3 = oc0 + 3; + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + const size_t n_base = n * IC * ID * IH * IW; + for (size_t od = 0; od < OD; ++od) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow_ = 0; ow_ < OW; ++ow_) { + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + if (SD == 1 && SH == 1 && SW == 1) { + // Fast path: contiguous tap ranges, no modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID) + 1); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH) + 1); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW) + 1); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t id = static_cast(tzd - kz); + const size_t src_z_off = id * IH * IW; + const size_t kh_count = static_cast(ky_hi - ky_lo + 1); + const size_t ihh = static_cast(tyd - ky_lo); + const size_t src_y_off = ihh * IW; + size_t s_base_row = n_base + src_z_off + src_y_off; + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + if (m_wei_packed_ready) { + const size_t s_base0 = s_base_row + static_cast(txd - kx_lo); + // Compute packed bases for ky_lo + size_t pack_base_z0 = (oc0 * KD + static_cast(kz)) * KH; + size_t pack_base_z1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + size_t pack_base_z2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + size_t pack_base_z3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + size_t pack_base_y0 = (pack_base_z0 + static_cast(ky_lo)) * KW; + size_t pack_base_y1 = has_oc1 ? (pack_base_z1 + static_cast(ky_lo)) * KW : 0; + size_t pack_base_y2 = has_oc2 ? (pack_base_z2 + static_cast(ky_lo)) * KW : 0; + size_t pack_base_y3 = has_oc3 ? 
(pack_base_z3 + static_cast(ky_lo)) * KW : 0; + const size_t pack_base0 = (pack_base_y0 + static_cast(kx_lo)) * m_padded_IC; + const size_t pack_base1 = has_oc1 ? (pack_base_y1 + static_cast(kx_lo)) * m_padded_IC : 0; + const size_t pack_base2 = has_oc2 ? (pack_base_y2 + static_cast(kx_lo)) * m_padded_IC : 0; + const size_t pack_base3 = has_oc3 ? (pack_base_y3 + static_cast(kx_lo)) * m_padded_IC : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + // Compute only oc0/oc1 in this call; oc2/oc3 will be handled by a second dual call + a.repeats = IC / 8; + a.tail = IC % 8; + a.kw_cnt = kw_count; + a.kh_cnt = kh_count; + a.src_dx = sizeof(uint16_t); + a.src_dy = IW * sizeof(uint16_t); + a.wei = m_wei_packed.data() + pack_base0; + if (has_oc1) a.wei2 = m_wei_packed.data() + pack_base1; + // oc2/oc3 handled in a follow-up dual call + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC * sizeof(uint16_t); + a.wei_dy = KW * m_padded_IC * sizeof(uint16_t); + (*m_ip_kernel)(&a); + } else { + // Generic ky+kx loops (not packed) + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t ihh2 = static_cast(tyd - ky); + size_t s_base_row2 = n_base + src_z_off + ihh2 * IW; + size_t iww = static_cast(txd - kx_lo); + for (ptrdiff_t kx = kx_lo; kx <= kx_hi; ++kx, ++iww) { + const size_t s_base0 = s_base_row2 + iww; + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = IC / 8; + a.tail = IC % 8; + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); + const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)) : 0; + a.wei = wei_p + w_base0; + if (has_oc1) a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + (*m_ip_kernel)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = IC / 8; + a.tail = IC % 8; + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); + const size_t w_base3 = has_oc3 ? 
idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)) : 0; + a.wei = wei_p + w_base2; + if (has_oc3) a.wei2 = wei_p + w_base3; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + (*m_ip_kernel)(&a); + } + } + } + } + } + } + } else { + // Generic path (stride > 1): keep modulus checks + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t iz_num = static_cast(od) + PD0 - static_cast(kz); + if (SD == 0) continue; + if (iz_num % static_cast(SD) != 0) continue; + const ptrdiff_t id = iz_num / static_cast(SD); + if (id < 0 || id >= static_cast(ID)) continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); + if (SH == 0) continue; + if (iy_num % static_cast(SH) != 0) continue; + const ptrdiff_t ihh = iy_num / static_cast(SH); + if (ihh < 0 || ihh >= static_cast(IH)) continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix_num = static_cast(ow_) + PW0 - static_cast(kx); + if (SW == 0) continue; + if (ix_num % static_cast(SW) != 0) continue; + const ptrdiff_t iww = ix_num / static_cast(SW); + if (iww < 0 || iww >= static_cast(IW)) continue; + + const size_t s_base0 = idx_src(n, 0, static_cast(id), static_cast(ihh), static_cast(iww)); + const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); + const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, kz, ky, kx) : 0; + + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = IC / 8; + a.tail = IC % 8; + if (m_wei_packed_ready) { + const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC; + a.wei = m_wei_packed.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC; + a.wei2 = m_wei_packed.data() + pack_base1; + } + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + } else { + a.wei = wei_p + w_base0; + if (has_oc1) a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + } + (*m_ip_kernel)(&a); + } + } + } + } + // Optional fused bias for deconv + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const float* b = reinterpret_cast(src[2]->getData()); + acc0 += b[oc0]; + if (has_oc1) acc1 += b[oc1]; + if (has_oc2) acc2 += b[oc2]; + if (has_oc3) acc3 += b[oc3]; + } else if (bprec == ov::element::f16) { + const uint16_t* b = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(b[oc0])); + if (has_oc1) acc1 += static_cast(ov::float16(b[oc1])); + if (has_oc2) acc2 += static_cast(ov::float16(b[oc2])); + if (has_oc3) acc3 += static_cast(ov::float16(b[oc3])); + } + } + + dst_p[idx_dst(n, oc0, od, oh, ow_)] = ov::float16(acc0).to_bits(); + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1).to_bits(); + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2).to_bits(); + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3).to_bits(); + } + } + } + }); +} + +bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const { + // Support 5D NCDHW, fp16 only for now + if (srcDescs.size() < 2 || dstDescs.empty()) return false; + if 
(srcDescs[0]->getShape().getRank() != 5 || srcDescs[1]->getShape().getRank() != 5 ||
+        dstDescs[0]->getShape().getRank() != 5) {
+        return false;
+    }
+    const auto prec = srcDescs[0]->getPrecision();
+    if (!(prec == ov::element::f16 && srcDescs[1]->getPrecision() == ov::element::f16 &&
+          dstDescs[0]->getPrecision() == ov::element::f16)) {
+        return false;
+    }
+    return true;
+}
+
+} // namespace ov::intel_cpu
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp
new file mode 100644
index 00000000000000..a7348f4eac39b7
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp
@@ -0,0 +1,51 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "nodes/executors/deconv.hpp"
+#include "nodes/executors/aarch64/jit_conv3d.hpp"
+
+namespace ov::intel_cpu {
+
+class JitDeconv3DExecutor : public DeconvExecutor {
+public:
+    explicit JitDeconv3DExecutor(ExecutorContext::CPtr context) : DeconvExecutor(std::move(context)) {}
+    ~JitDeconv3DExecutor() override = default;
+
+    bool init(const DeconvAttrs& deconvAttrs,
+              const std::vector<MemoryDescPtr>& srcDescs,
+              const std::vector<MemoryDescPtr>& dstDescs,
+              const dnnl::primitive_attr& attr) override;
+
+    void exec(const std::vector<MemoryCPtr>& src,
+              const std::vector<MemoryPtr>& dst,
+              const void* post_ops_data_) override;
+
+    [[nodiscard]] impl_desc_type getImplType() const override { return impl_desc_type::jit_asimd; }
+
+private:
+    std::vector<MemoryDescPtr> m_srcDescs;
+    std::vector<MemoryDescPtr> m_dstDescs;
+    std::unique_ptr<JitConv3DKernelF16> m_ip_kernel;
+    std::vector<uint16_t> m_wei_packed;
+    bool m_wei_packed_ready{false};
+    size_t m_padded_IC{0};
+    void ensure_weights_packed(const std::vector<MemoryCPtr>& src);
+};
+
+class AArch64JitDeconvExecutorBuilder : public DeconvExecutorBuilder {
+public:
+    [[nodiscard]] bool isSupported(const DeconvAttrs& attrs,
+                                   const std::vector<MemoryDescPtr>& srcDescs,
+                                   const std::vector<MemoryDescPtr>& dstDescs) const override;
+    [[nodiscard]] DeconvExecutorPtr makeExecutor(ExecutorContext::CPtr context) const override {
+        return std::make_shared<JitDeconv3DExecutor>(context);
+    }
+};
+
+} // namespace ov::intel_cpu
diff --git a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp
index 78a45bd10bb76f..1c088fe0478ae9 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp
@@ -32,6 +32,9 @@
 #if defined(OV_CPU_WITH_ACL)
 #    include "nodes/executors/acl/acl_conv.hpp"
 #endif
+#if defined(OPENVINO_ARCH_ARM64)
+#    include "nodes/executors/aarch64/jit_conv3d.hpp"
+#endif
 
 namespace ov::intel_cpu {
 
@@ -101,6 +104,17 @@ struct CreateOptimalConfigAclLowp {
 template <>
 const std::vector<ExecutorImplementation<ConvAttrs>>& getImplementations() {
     static const std::vector<ExecutorImplementation<ConvAttrs>> convolutionImplementations {
+        OV_CPU_INSTANCE_ARM64(
+            "convolution_jit_aarch64_3d_fp16_ncsp", ExecutorType::Jit, OperationType::Convolution,
+            // supports: prefer our AArch64 JIT whenever attrs/shapes permit
+            [](const ConvConfig& config, [[maybe_unused]] const MemoryFormatFilter& memoryFormatFilter) -> bool {
+                return JitConv3DExecutor::supports(config);
+            },
+            // Ask for plain ncsp layouts to avoid being filtered out
+            CreateOptimalConfigDefault{{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}},
+            AcceptsAnyShape,
+            CreateDefault<JitConv3DExecutor>{}
+        )
         OV_CPU_INSTANCE_DNNL_X64(
             "convolution_dnnl_nspc_nspc", ExecutorType::Dnnl, 
            OperationType::Convolution,
            // supports
diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp
index 2a5d7a087610f2..b77d3d62fa1d5f 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp
@@ -7,6 +7,9 @@
 #include
 
 #include "utils/arch_macros.h"
+#if defined(OPENVINO_ARCH_ARM64)
+#    include "nodes/executors/aarch64/jit_deconv3d.hpp"
+#endif
 
 #if defined(OV_CPU_WITH_ACL)
 #    include
@@ -19,7 +22,10 @@ namespace ov::intel_cpu {
 
 const std::vector<DeconvExecutorDesc>& getDeconvExecutorsList() {
     static std::vector<DeconvExecutorDesc> descs = {
-        OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared<AclDeconvExecutorBuilder>())};
+        // Prefer ACL builder first for stability/perf; fallback to AArch64 JIT if ACL not supported
+        OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared<AclDeconvExecutorBuilder>())
+        OV_CPU_INSTANCE_ARM64(ExecutorType::Jit, std::make_shared<AArch64JitDeconvExecutorBuilder>())
+    };
 
     return descs;
 }

From f77cee0042866b7ab002f3c7f63d06ec8341b103 Mon Sep 17 00:00:00 2001
From: Alexander Nesterov
Date: Thu, 16 Oct 2025 15:25:32 +0200
Subject: [PATCH 02/20] Optimize AArch64 JIT Conv3D executor by adding vector
 fast path for `wei_stride == 2`.

---
 .../nodes/executors/aarch64/jit_conv3d.cpp    | 22 +++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
index bb1d417dc324d6..4b6a7f1ba3fb1f 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp
@@ -211,7 +211,16 @@ void JitConv3DKernelF16::generate() {
     ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride);
     ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride);
     ld1(VReg(0).h[7], ptr(reg_src));
-    // Load wei lanes for oc0 (v1) and oc1 (v2)
+    // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2
+    Label Ldw_np_d, Ldw_done_d;
+    cmp(reg_wei_stride, 2);
+    b(NE, Ldw_np_d);
+    ld1(VReg8H(1), ptr(reg_wei));
+    ld1(VReg8H(2), ptr(reg_wei2));
+    add(reg_wei, reg_wei, 16);
+    add(reg_wei2, reg_wei2, 16);
+    b(Ldw_done_d);
+    L(Ldw_np_d);
     ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
     ld1(VReg(2).h[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride);
     ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
@@ -227,6 +236,7 @@
     ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
     ld1(VReg(2).h[6], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride);
     ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2));
+    L(Ldw_done_d);
     // MAC into v20/v21
     fmlal(VReg4S(20), VReg4H(0), VReg4H(1));
     fmlal2(VReg4S(20), VReg4H(0), VReg4H(1));
@@ -389,7 +399,14 @@ void JitConv3DKernelF16::generate() {
     ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride);
     ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride);
     ld1(VReg(0).h[7], ptr(reg_src));
-    // wei lanes
+    // wei lanes — vector fast path if wei_stride==2
+    Label Ldw_np_s, Ldw_done_s;
+    cmp(reg_wei_stride, 2);
+    b(NE, Ldw_np_s);
+    ld1(VReg8H(1), ptr(reg_wei));
+    add(reg_wei, reg_wei, 16);
+    b(Ldw_done_s);
+    L(Ldw_np_s);
     ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
     ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
     ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride);
@@ 
-398,6 +415,7 @@ void JitConv3DKernelF16::generate() { ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ldw_done_s); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); sub(reg_reps, reg_reps, 1); From 32487512b424bbd706590219a75bd642fe9ea7cd Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Thu, 16 Oct 2025 16:19:29 +0200 Subject: [PATCH 03/20] Add FP32 support to AArch64 JIT-based 3D Deconvolution and Convolution Executors --- src/plugins/intel_cpu/src/nodes/deconv.cpp | 14 +- .../nodes/executors/aarch64/jit_conv3d.cpp | 305 +++++++++- .../nodes/executors/aarch64/jit_conv3d.hpp | 12 + .../executors/aarch64/jit_conv3d_f32.cpp | 555 ++++++++++++++++++ .../executors/aarch64/jit_conv3d_f32.hpp | 88 +++ .../nodes/executors/aarch64/jit_deconv3d.cpp | 330 +++++++++-- .../nodes/executors/aarch64/jit_deconv3d.hpp | 23 +- .../executors/convolution_implementations.cpp | 5 +- 8 files changed, 1277 insertions(+), 55 deletions(-) create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index ba3bed2e65ea66..941d255bca1e9c 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -1296,7 +1296,7 @@ bool Deconvolution::canFuseBias() const { } void Deconvolution::initSupportedPrimitiveDescriptors() { - // Prefer AArch64 JIT deconv for 5D FP16 on ARM64 regardless of ACL + // Prefer AArch64 JIT deconv for 5D FP16/FP32 on ARM64 regardless of ACL #if defined(OPENVINO_ARCH_ARM64) { const auto rank = getInputShapeAtPort(0).getRank(); @@ -1304,7 +1304,10 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 && getOriginalInputPrecisionAtPort(1) == ov::element::f16 && getOriginalOutputPrecisionAtPort(0) == ov::element::f16; - if (is5D && fp16_ok) { + const bool fp32_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f32 && + getOriginalInputPrecisionAtPort(1) == ov::element::f32 && + getOriginalOutputPrecisionAtPort(0) == ov::element::f32; + if (is5D && (fp16_ok || fp32_ok)) { auto [inDims, outDims] = makeDummyInOutShape(); auto tmpInShape = Shape(inDims); auto tmpOutShape = Shape(outDims); @@ -1344,7 +1347,7 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { } } #endif - // If ACL path is not selected, try AArch64 JIT factory for 5D FP16 + // If ACL path is not selected, try AArch64 JIT factory for 5D FP16/FP32 if (!useACL) { #if defined(OPENVINO_ARCH_ARM64) const auto rank = getInputShapeAtPort(0).getRank(); @@ -1352,7 +1355,10 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 && getOriginalInputPrecisionAtPort(1) == ov::element::f16 && getOriginalOutputPrecisionAtPort(0) == ov::element::f16; - if (is5D && fp16_ok) { + const bool fp32_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f32 && + getOriginalInputPrecisionAtPort(1) == ov::element::f32 && + getOriginalOutputPrecisionAtPort(0) == ov::element::f32; + if (is5D && (fp16_ok || fp32_ok)) { auto [inDims, outDims] = makeDummyInOutShape(); auto tmpInShape = Shape(inDims); auto tmpOutShape = Shape(outDims); diff --git 
a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 4b6a7f1ba3fb1f..cca73a81a6ecff 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -1109,9 +1109,21 @@ JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, : m_attrs(attrs), m_memory(memory) { (void)context; m_threadsNum = static_cast(parallel_get_max_threads()); - // Initialize minimal inner-product kernel (scaffold) - m_ip_kernel = std::make_unique(); - m_ip_kernel->create_ker(); + // Decide precision from src tensor + ov::element::Type sp = ov::element::dynamic; + auto it = memory.find(ARG_SRC); + if (it != memory.end() && it->second && it->second->getDescPtr()) { + sp = it->second->getDescPtr()->getPrecision(); + } + m_is_fp32 = (sp == ov::element::f32); + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + // default to fp16 + m_ip_kernel = std::make_unique(); + m_ip_kernel->create_ker(); + } // Extract optional PReLU post-op (per-tensor or per-channel), but keep disabled by default. for (const auto& po : m_attrs.postOps) { @@ -1126,7 +1138,7 @@ JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, } bool JitConv3DExecutor::supports(const ConvConfig& cfg) { - // Require 5D NCDHW, FP16 src/wei/dst, group=1, no dilation, stride 1 or 2 + // Require 5D NCDHW, FP16/FP32 src/wei/dst, group=1, no dilation, stride 1 or 2 if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false; if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) return false; @@ -1138,7 +1150,9 @@ bool JitConv3DExecutor::supports(const ConvConfig& cfg) { const auto sp = cfg.descs.at(ARG_SRC)->getPrecision(); const auto wp = cfg.descs.at(ARG_WEI)->getPrecision(); const auto dp = cfg.descs.at(ARG_DST)->getPrecision(); - if (!(sp == ov::element::f16 && wp == ov::element::f16 && dp == ov::element::f16)) return false; + const bool f16_ok = (sp == ov::element::f16 && wp == ov::element::f16 && dp == ov::element::f16); + const bool f32_ok = (sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32); + if (!(f16_ok || f32_ok)) return false; // group == 1: weights rank==5 (no groups) if (w.getRank() != 5) return false; @@ -1458,7 +1472,286 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { } void JitConv3DExecutor::execute(const MemoryArgs& memory) { - run_naive_fp16(memory); + if (m_is_fp32) run_naive_fp32(memory); else run_naive_fp16(memory); +} + +void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) { + if (m_wei_packed_ready_f32) return; + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + if (srcDims.size() != 5 || weiDims.size() != 5) return; + const size_t C = srcDims[1]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_C_f32 = (C + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_C_f32; + m_wei_packed_f32.assign(total, 0.0f); + const float* wsrc = reinterpret_cast(wei->getData()); + + auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t 
oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + const size_t blk = c / 4; + const size_t lane = c % 4; + return base + blk * 4 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t c = 0; c < C; ++c) { + m_wei_packed_f32[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f32 = true; +} + +void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { + auto src = memory.at(ARG_SRC); + auto wei = memory.at(ARG_WEI); + auto dst = memory.at(ARG_DST); + const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); + const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); + const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims(); + + const size_t N = srcDims[0]; + const size_t C = srcDims[1]; + const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; + const size_t OC = weiDims[0]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1; + const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1; + const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1; + + const ptrdiff_t PD0 = m_attrs.paddingL.size() > 0 ? m_attrs.paddingL[0] : 0; + const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0; + const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ? m_attrs.paddingL[2] : 0; + + const float* src_p = reinterpret_cast(src->getData()); + const float* wei_p = reinterpret_cast(wei->getData()); + float* dst_p = reinterpret_cast(dst->getData()); + + auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * C + c) * ID + z) * IH + y) * IW + x; + }; + auto index_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) { + return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + }; + + const size_t src_c_stride_elems = ID * IH * IW; // elements between channels + const size_t wei_c_stride_elems = KD * KH * KW; // elements between weight channels + + ensure_weights_packed_f32(memory); + + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = std::min(oc0 + 1, OC); + const size_t oc2 = std::min(oc0 + 2, OC); + const size_t oc3 = std::min(oc0 + 3, OC); + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + + for (size_t od = 0; od < OD; ++od) { + const ptrdiff_t iz0 = static_cast(od) * static_cast(SD) - PD0; + for (size_t oh = 0; oh < OH; ++oh) { + const ptrdiff_t iy0 = static_cast(oh) * static_cast(SH) - PH0; + for (size_t ow = 0; ow < OW; ++ow) { + const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; + + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + if (SD == 1 && SH == 1 && SW == 1) { + const ptrdiff_t kz_lo = std::max(0, -iz0); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, + static_cast(ID) - 1 - iz0); + const ptrdiff_t ky_lo = std::max(0, -iy0); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, + static_cast(IH) - 1 - iy0); + const ptrdiff_t kx_lo = std::max(0, -ix0); + const 
ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, + static_cast(IW) - 1 - ix0); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(iz0 + kz); + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy = static_cast(iy0 + ky); + const size_t ix = static_cast(ix0 + kx_lo); + const size_t s_base = index_src(n, 0, iz, iy, ix); + + if (m_wei_packed_ready_f32) { + // pair 0 + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + } else { + // generic path: non-packed weights + const size_t w0 = index_wei(oc0, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w1 = has_oc1 ? index_wei(oc1, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + a.wei = wei_p + w0; + if (has_oc1) a.wei2 = wei_p + w1; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = sizeof(float); + (*m_ip_kernel_f32)(&a); + if (has_oc2) { + const size_t w2 = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a2{}; + a2.src = src_p + s_base; + a2.src_stride = a.src_stride; + a2.src_blk_stride = a.src_blk_stride; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? 
&acc3 : nullptr; + a2.repeats = a.repeats; + a2.tail = a.tail; + a2.kw_cnt = a.kw_cnt; + a2.src_dx = a.src_dx; + a2.wei = wei_p + w2; + if (has_oc3) a2.wei2 = wei_p + w3; + a2.wei_stride = a.wei_stride; + a2.wei_blk_stride = a.wei_blk_stride; + a2.wei_dx = a.wei_dx; + (*m_ip_kernel_f32)(&a2); + } + } + } + } + } + } else { + // generic spatial stride path (host loops over all taps) + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t iz = iz0 + static_cast(kz); + if (iz < 0 || iz >= static_cast(ID)) continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy = iy0 + static_cast(ky); + if (iy < 0 || iy >= static_cast(IH)) continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix = ix0 + static_cast(kx); + if (ix < 0 || ix >= static_cast(IW)) continue; + const size_t s_base = index_src(n, 0, static_cast(iz), static_cast(iy), static_cast(ix)); + // pair 0 + { + const size_t w0 = index_wei(oc0, 0, kz, ky, kx); + const size_t w1 = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_p + w0; + if (has_oc1) a.wei2 = wei_p + w1; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + const size_t w2 = index_wei(oc2, 0, kz, ky, kx); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? 
&acc3 : nullptr;
+                            a.repeats = C / 4;
+                            a.tail = C % 4;
+                            a.kw_cnt = 1;
+                            a.src_dx = 0;
+                            a.wei = wei_p + w2;
+                            if (has_oc3) a.wei2 = wei_p + w3;
+                            a.wei_stride = wei_c_stride_elems * sizeof(float);
+                            a.wei_blk_stride = a.wei_stride * 4;
+                            a.wei_dx = 0;
+                            (*m_ip_kernel_f32)(&a);
+                        }
+                    }
+                }
+            }
+            }
+
+            // Store
+            dst_p[index_dst(n, oc0, od, oh, ow)] = acc0;
+            if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = acc1;
+            if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = acc2;
+            if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = acc3;
+        }
+        }
+        }
+    });
+}
 
 }  // namespace ov::intel_cpu
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
index 5941b378eb2f1d..08eabf13c39e6a 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
@@ -16,6 +16,8 @@
 // Xbyak AArch64 JIT
 #include <cpu/aarch64/jit_generator.hpp>
+// FP32 kernel
+#include "nodes/executors/aarch64/jit_conv3d_f32.hpp"
 
 namespace ov::intel_cpu {
 
@@ -83,18 +85,28 @@ class JitConv3DExecutor : public Executor {
     // Simple reference fallback (parallelized) using FP16 data; correctness-first
     void run_naive_fp16(const MemoryArgs& memory);
     void ensure_weights_packed(const MemoryArgs& memory);
+    // FP32 path
+    void run_naive_fp32(const MemoryArgs& memory);
+    void ensure_weights_packed_f32(const MemoryArgs& memory);
 
     // Minimal inner-product kernel (fp16 x fp16 -> f32 accumulation)
     std::unique_ptr<JitConv3DKernelF16> m_ip_kernel;
+    // Minimal inner-product kernel (fp32 x fp32 -> f32 accumulation)
+    std::unique_ptr<JitConv3DKernelF32> m_ip_kernel_f32;
 
     ConvAttrs m_attrs;
     MemoryArgs m_memory;
     size_t m_threadsNum{0};
+    bool m_is_fp32{false};
 
     // Packed weights: layout [OC, KD, KH, KW, Ct] where Ct is 8-lane channel tiles
     std::vector<uint16_t> m_wei_packed;
     bool m_wei_packed_ready{false};
     size_t m_padded_C{0};
+    // FP32 packed weights: [OC, KD, KH, KW, Ct=4]
+    std::vector<float> m_wei_packed_f32;
+    bool m_wei_packed_ready_f32{false};
+    size_t m_padded_C_f32{0};
 
     // Optional fused PReLU (per-tensor or per-channel). Extracted from attrs.postOps.
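     // (Assumed semantics, for reference: y = x >= 0 ? x : alpha * x, with a
     // scalar alpha for per-tensor PReLU or alpha[c] for per-channel PReLU.)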
bool m_has_prelu{false}; diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp new file mode 100644 index 00000000000000..fbd31aa3ed1ccc --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -0,0 +1,555 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" + +#include +#include +#include +#include +#include +#include + +#include "cpu_memory.h" +#include "memory_desc/cpu_memory_desc.h" +#include "nodes/executors/implementation_utils.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/element_type.hpp" +#include "utils/general_utils.h" +// helper for jit_kernel_cast +#include "utils/cpu_utils.hpp" + +using namespace dnnl::impl::cpu::aarch64; + +namespace ov::intel_cpu { + +// --------------------------- JIT kernel (FP32) --------------------------- +void JitConv3DKernelF32::create_ker() { + jit_generator::create_kernel(); + ker_ = jit_kernel_cast(jit_ker()); +} +void JitConv3DKernelF32::generate() { + using namespace Xbyak_aarch64; + + const XReg reg_args = abi_param1; // x0 + + const XReg reg_src = x1; // const float* src + const XReg reg_wei = x2; // const float* wei + const XReg reg_wei2 = x3; // const float* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats (C/4) + const XReg reg_tail = x5; // size_t tail (C%4) + const XReg reg_src_stride = x6; // bytes between channels + const XReg reg_wei_stride = x7; // bytes between channels + const XReg reg_src_blk_stride = x8; // bytes between successive 4-ch blocks + const XReg reg_wei_blk_stride = x9; // bytes between successive 4-ch blocks + const XReg reg_acc = x10; // float* acc + const XReg reg_acc2 = x11; // float* acc2 (optional) + const XReg reg_kw_cnt = x12; // taps along W + const XReg reg_src_dx = x13; // bytes to step src base per kx + const XReg reg_wei_dx = x14; // bytes to step wei base per kx + + // Load args by struct offsets (see jit_conv3d_f32_call_args) + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, ptr(reg_args, 24)); + ldr(reg_tail, ptr(reg_args, 32)); + ldr(reg_src_stride, ptr(reg_args, 40)); + ldr(reg_wei_stride, ptr(reg_args, 48)); + ldr(reg_src_blk_stride, ptr(reg_args, 56)); + ldr(reg_wei_blk_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 72)); + ldr(reg_acc2, ptr(reg_args, 80)); + ldr(reg_kw_cnt, ptr(reg_args, 88)); + ldr(reg_src_dx, ptr(reg_args, 96)); + ldr(reg_wei_dx, ptr(reg_args, 104)); + + // Work registers for base pointers per kx + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; // avoid x18 + + Label Lsingle, Ldone; + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + + cbz(reg_acc2, Lsingle); + b(Ldual_kx); + + // ---------------- Dual-OC with in-kernel kx loop ---------------- + L(Ldual_kx); + // accumulators v20 (oc0), v21 (oc1) + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + + // Save bases + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + // Treat kw_cnt==0 as 1 + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + + // kx loop + L(Lkx_d); + // Reset per-kx pointers and repeats + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + 
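+    // Per-kx reset: reg_reps was counted down to zero by the previous pass, so
+    // it is reloaded from the args struct, and the walking src/wei pointers are
+    // rewound from the saved per-kx base registers before the channel loop.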
mov(reg_wei2, q_wei2_base); + + // repeats loop over channel tiles of 4 + Label Lrep_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + // src lanes -> v0.s[0..3] + ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + // wei lanes: vector fast path if stride==4 bytes + Label Lw_np_d, Lw_done_d; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_d); + ld1(VReg4S(1), ptr(reg_wei)); + ld1(VReg4S(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + b(Lw_done_d); + L(Lw_np_d); + ld1(VReg(1).s[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); + L(Lw_done_d); + // MAC + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); + + // Tail per kx + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); ld1(VReg(1).s[0], ptr(reg_wei)); ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[1], ptr(reg_src)); ld1(VReg(1).s[1], ptr(reg_wei)); ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[2], ptr(reg_src)); ld1(VReg(1).s[2], ptr(reg_wei)); ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[3], ptr(reg_src)); ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + // advance bases to next kx + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + // horizontal add and store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + // ---------------- Single-OC with in-kernel kx loop ---------------- + L(Lsingle); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + cbnz(reg_kw_cnt, Lsingle_kx); + mov(reg_kw_cnt, 1); + + L(Lsingle_kx); + // Reset per-kx 
pointers and repeats + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + + Label Lrep_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + // src lanes + ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + // wei lanes: vector fast path if stride==4 + Label Lw_np_s, Lw_done_s; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_s); + ld1(VReg4S(1), ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lw_done_s); + L(Lw_np_s); + ld1(VReg(1).s[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + L(Lw_done_s); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + // Tail single + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[1], ptr(reg_src)); ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[2], ptr(reg_src)); ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[3], ptr(reg_src)); ld1(VReg(1).s[3], ptr(reg_wei)); + L(Ltail_done_s_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + + // advance to next kx + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lsingle_kx); + + // reduce and store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + b(Ldone); + + L(Ldone); + ret(); +} + +// --------------------------- Executor (FP32) --------------------------- +JitConv3DExecutorF32::JitConv3DExecutorF32(const ConvAttrs& attrs, + const MemoryArgs& memory, + const ExecutorContext::CPtr& /*context*/) : m_attrs(attrs) { + m_memory = memory; + m_ip_kernel = std::make_unique(); + m_ip_kernel->create_ker(); +} + +bool JitConv3DExecutorF32::supports(const ConvConfig& cfg) { + if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false; + if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) return false; + const auto& s = cfg.descs.at(ARG_SRC)->getShape(); + const auto& w = cfg.descs.at(ARG_WEI)->getShape(); + const auto& d = cfg.descs.at(ARG_DST)->getShape(); + if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) return false; + const auto sp = cfg.descs.at(ARG_SRC)->getPrecision(); + const auto wp = cfg.descs.at(ARG_WEI)->getPrecision(); + const auto dp = cfg.descs.at(ARG_DST)->getPrecision(); + if (!(sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32)) return false; + if (w.getRank() != 5) return false; // groups unsupported here + for (auto v : cfg.attrs.dilation) { if (v != 
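+    // The dilation check below assumes the oneDNN convention where dilation is
+    // stored as (d - 1); rejecting any non-zero entry keeps only dilation == 1.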
0) return false; }
+    for (auto v : cfg.attrs.stride) { if (!(v == 1 || v == 2)) return false; }
+    return true;
+}
+
+void JitConv3DExecutorF32::ensure_weights_packed(const MemoryArgs& memory) {
+    if (m_wei_packed_ready) return;
+    auto src = memory.at(ARG_SRC);
+    auto wei = memory.at(ARG_WEI);
+    const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
+    const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
+    if (srcDims.size() != 5 || weiDims.size() != 5) return;
+    const size_t C = srcDims[1];
+    const size_t OC = weiDims[0];
+    const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
+    m_padded_C = (C + 3) / 4 * 4;
+    const size_t total = OC * KD * KH * KW * m_padded_C;
+    m_wei_packed.assign(total, 0.0f);
+    const float* wsrc = reinterpret_cast<const float*>(wei->getData());
+
+    auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
+        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+    };
+    auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
+        const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C;
+        const size_t blk = c / 4;
+        const size_t lane = c % 4;
+        return base + blk * 4 + lane;
+    };
+
+    for (size_t oc = 0; oc < OC; ++oc) {
+        for (size_t kz = 0; kz < KD; ++kz) {
+            for (size_t ky = 0; ky < KH; ++ky) {
+                for (size_t kx = 0; kx < KW; ++kx) {
+                    for (size_t c = 0; c < C; ++c) {
+                        m_wei_packed[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)];
+                    }
+                }
+            }
+        }
+    }
+    m_wei_packed_ready = true;
+}
+
+void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) {
+    auto src = memory.at(ARG_SRC);
+    auto wei = memory.at(ARG_WEI);
+    auto dst = memory.at(ARG_DST);
+    const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
+    const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
+    const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims();
+
+    const size_t N = srcDims[0];
+    const size_t C = srcDims[1];
+    const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4];
+    const size_t OC = weiDims[0];
+    const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
+    const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4];
+
+    const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1;
+    const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1;
+    const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1;
+
+    const ptrdiff_t PD0 = m_attrs.paddingL.size() > 0 ? m_attrs.paddingL[0] : 0;
+    const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0;
+    const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ?
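+    // Output-to-input mapping used below (stride-1 sketch): ix0 = ow * SW - PW0,
+    // and tap kx reads input column ix0 + kx, so the in-bounds taps form the
+    // contiguous range kx in [max(0, -ix0), min(KW - 1, IW - 1 - ix0)].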
m_attrs.paddingL[2] : 0; + + const float* src_p = reinterpret_cast(src->getData()); + const float* wei_p = reinterpret_cast(wei->getData()); + float* dst_p = reinterpret_cast(dst->getData()); + + auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * C + c) * ID + z) * IH + y) * IW + x; + }; + auto index_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) { + return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + }; + + const size_t src_c_stride_elems = ID * IH * IW; // elements between channels + const size_t wei_c_stride_elems = KD * KH * KW; // elements between weight channels + + ensure_weights_packed(memory); + + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = std::min(oc0 + 1, OC); + const size_t oc2 = std::min(oc0 + 2, OC); + const size_t oc3 = std::min(oc0 + 3, OC); + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + + for (size_t od = 0; od < OD; ++od) { + const ptrdiff_t iz0 = static_cast(od) * static_cast(SD) - PD0; + for (size_t oh = 0; oh < OH; ++oh) { + const ptrdiff_t iy0 = static_cast(oh) * static_cast(SH) - PH0; + for (size_t ow = 0; ow < OW; ++ow) { + const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; + + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + if (SD == 1 && SH == 1 && SW == 1) { + const ptrdiff_t kz_lo = std::max(0, -iz0); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, + static_cast(ID) - 1 - iz0); + const ptrdiff_t ky_lo = std::max(0, -iy0); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, + static_cast(IH) - 1 - iy0); + const ptrdiff_t kx_lo = std::max(0, -ix0); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, + static_cast(IW) - 1 - ix0); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(iz0 + kz); + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy = static_cast(iy0 + ky); + const size_t ix = static_cast(ix0 + kx_lo); + const size_t s_base = index_src(n, 0, iz, iy, ix); + + if (m_wei_packed_ready) { + // dual pairs: (oc0,oc1), (oc2,oc3) + // pair 0 + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei = m_wei_packed.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei2 = m_wei_packed.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C * sizeof(float); + (*m_ip_kernel)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? 
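+                                    // Packed layout is [OC][KD][KH][KW][Ct] with Ct = C rounded up to 4:
+                                    // channels are unit-stride (wei_stride == sizeof(float)), which lets
+                                    // the kernel take its vector ld1 fast path, and one kx step advances
+                                    // the weight base by m_padded_C floats (wei_dx).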
&acc3 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei = m_wei_packed.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + a.wei2 = m_wei_packed.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_C * sizeof(float); + (*m_ip_kernel)(&a); + } + } else { + // generic path: kx loop in kernel, but weights non-packed + const size_t w0 = index_wei(oc0, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w1 = has_oc1 ? index_wei(oc1, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + a.wei = wei_p + w0; + if (has_oc1) a.wei2 = wei_p + w1; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = sizeof(float); + (*m_ip_kernel)(&a); + + if (has_oc2) { + const size_t w2 = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a2{}; + a2.src = src_p + s_base; + a2.src_stride = a.src_stride; + a2.src_blk_stride = a.src_blk_stride; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? &acc3 : nullptr; + a2.repeats = a.repeats; + a2.tail = a.tail; + a2.kw_cnt = a.kw_cnt; + a2.src_dx = a.src_dx; + a2.wei = wei_p + w2; + if (has_oc3) a2.wei2 = wei_p + w3; + a2.wei_stride = a.wei_stride; + a2.wei_blk_stride = a.wei_blk_stride; + a2.wei_dx = a.wei_dx; + (*m_ip_kernel)(&a2); + } + } + } + } + } + } else { + // generic spatial stride path (host loops over all taps) + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t iz = iz0 + static_cast(kz); + if (iz < 0 || iz >= static_cast(ID)) continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy = iy0 + static_cast(ky); + if (iy < 0 || iy >= static_cast(IH)) continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix = ix0 + static_cast(kx); + if (ix < 0 || ix >= static_cast(IW)) continue; + const size_t s_base = index_src(n, 0, static_cast(iz), static_cast(iy), static_cast(ix)); + // pair 0 + { + const size_t w0 = index_wei(oc0, 0, kz, ky, kx); + const size_t w1 = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_p + w0; + if (has_oc1) a.wei2 = wei_p + w1; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel)(&a); + } + // pair 1 + if (has_oc2) { + const size_t w2 = index_wei(oc2, 0, kz, ky, kx); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? 
&acc3 : nullptr; + a.repeats = C / 4; + a.tail = C % 4; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_p + w2; + if (has_oc3) a.wei2 = wei_p + w3; + a.wei_stride = wei_c_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel)(&a); + } + } + } + } + } + + // Store + dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; + if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; + if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; + if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; + } + } + } + }); +} + +void JitConv3DExecutorF32::execute(const MemoryArgs& memory) { + run_naive_fp32(memory); +} + +} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp new file mode 100644 index 00000000000000..54cdc0b07b9ff2 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp @@ -0,0 +1,88 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "nodes/executors/convolution_config.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" + +// Xbyak AArch64 JIT +#include + +namespace ov::intel_cpu { + +struct jit_conv3d_f32_call_args { + const float* src; // f32 base ptr + const float* wei; // f32 base ptr (oc0) + const float* wei2; // optional second oc f32 base ptr (can be null) + size_t repeats; // number of full 4-channel blocks + size_t tail; // remaining channels (< 4) + size_t src_stride; // stride between channels in bytes + size_t wei_stride; // stride between channels in bytes + size_t src_blk_stride; // stride between successive 4-channel blocks in bytes + size_t wei_blk_stride; // stride between successive 4-channel blocks in bytes + float* acc; // f32 accumulator + float* acc2; // optional second f32 accumulator (can be null) + size_t kw_cnt; // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single + size_t src_dx; // bytes to advance src base between successive kx taps + size_t wei_dx; // bytes to advance weights base between successive kx taps +}; + +class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF32) + using jit_fn = void (*)(const jit_conv3d_f32_call_args*); + + JitConv3DKernelF32() = default; + + void create_ker(); + inline void operator()(const jit_conv3d_f32_call_args* p) const { ker_(p); } + +private: + void generate() override; + + jit_fn ker_{nullptr}; +}; + +// AArch64 JIT Convolution (FP32) executor for 3D conv (NCDHW) +class JitConv3DExecutorF32 : public Executor { +public: + JitConv3DExecutorF32(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); + + bool update(const MemoryArgs& memory) override { + m_memory = memory; + return true; + } + void execute(const MemoryArgs& memory) override; + void execute() override { execute(m_memory); } + void exec([[maybe_unused]] const std::vector& src, + [[maybe_unused]] const std::vector& dst) override {} + + [[nodiscard]] impl_desc_type implType() const override { return impl_desc_type::jit_asimd; } + + static bool supports(const ConvConfig& cfg); + +private: + void run_naive_fp32(const MemoryArgs& memory); + void ensure_weights_packed(const MemoryArgs& memory); + + std::unique_ptr m_ip_kernel; + + ConvAttrs m_attrs; + MemoryArgs m_memory; + + 
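+    // Packed lazily on first execute(); size is OC * KD * KH * KW * padded_C
+    // floats, where padded_C rounds the channel count up to the 4-lane tile.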
std::vector m_wei_packed; // [OC, KD, KH, KW, Ct=4] + bool m_wei_packed_ready{false}; + size_t m_padded_C{0}; +}; + +using JitConv3DExecutorF32Ptr = std::shared_ptr; + +} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 5cae3d6be3c1c5..9a7a107f89d6a1 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -30,30 +30,37 @@ bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, deconvAttrs = attrs; m_srcDescs = srcDescs; m_dstDescs = dstDescs; - // Initialize AArch64 ip kernel (fp16 x fp16 -> f32 accum) - m_ip_kernel = std::make_unique(); - m_ip_kernel->create_ker(); + // Choose kernel by precision + const auto prec = m_srcDescs[0]->getPrecision(); + m_is_fp32 = (prec == ov::element::f32); + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + m_ip_kernel_f16 = std::make_unique(); + m_ip_kernel_f16->create_ker(); + } return true; } -void JitDeconv3DExecutor::ensure_weights_packed(const std::vector& src) { - if (m_wei_packed_ready) return; +void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector& src) { + if (m_wei_packed_ready_f16) return; // src[1] holds weights for deconv with shape [IC, OC, KD, KH, KW] const auto& weiDims = src[1]->getStaticDims(); if (weiDims.size() != 5) return; const size_t IC = weiDims[0]; const size_t OC = weiDims[1]; const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; - m_padded_IC = (IC + 7) / 8 * 8; - const size_t total = OC * KD * KH * KW * m_padded_IC; - m_wei_packed.assign(total, static_cast(0)); + m_padded_IC_f16 = (IC + 7) / 8 * 8; + const size_t total = OC * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_f16.assign(total, static_cast(0)); const uint16_t* wsrc = reinterpret_cast(src[1]->getData()); auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; }; auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC; + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; const size_t blk = ic / 8; const size_t lane = ic % 8; return base + blk * 8 + lane; @@ -64,18 +71,63 @@ void JitDeconv3DExecutor::ensure_weights_packed(const std::vector& s for (size_t ky = 0; ky < KH; ++ky) { for (size_t kx = 0; kx < KW; ++kx) { for (size_t ic = 0; ic < IC; ++ic) { - m_wei_packed[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + m_wei_packed_f16[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f16 = true; +} + +void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector& src) { + if (m_wei_packed_ready_f32) return; + const auto& weiDims = src[1]->getStaticDims(); + if (weiDims.size() != 5) return; + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f32 = (IC + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_f32.assign(total, 0.0f); + const float* wsrc = reinterpret_cast(src[1]->getData()); + + auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((ic) * OC + oc) * KD + kz) 
* KH + ky) * KW + kx;
+    };
+    auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t {
+        const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32;
+        const size_t blk = ic / 4;
+        const size_t lane = ic % 4;
+        return base + blk * 4 + lane;
+    };
+
+    for (size_t oc = 0; oc < OC; ++oc) {
+        for (size_t kz = 0; kz < KD; ++kz) {
+            for (size_t ky = 0; ky < KH; ++ky) {
+                for (size_t kx = 0; kx < KW; ++kx) {
+                    for (size_t ic = 0; ic < IC; ++ic) {
+                        m_wei_packed_f32[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)];
                     }
                 }
             }
         }
     }
-    m_wei_packed_ready = true;
+    m_wei_packed_ready_f32 = true;
 }
 
 void JitDeconv3DExecutor::exec(const std::vector<MemoryCPtr>& src,
                                const std::vector<MemoryCPtr>& dst,
                                const void* /*post_ops_data_*/) {
+    if (m_is_fp32) {
+        exec_fp32(src, dst);
+    } else {
+        exec_fp16(src, dst);
+    }
+}
+
+void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryCPtr>& src,
+                                    const std::vector<MemoryCPtr>& dst) {
     // NCDHW, fp16: compute each output pixel (n, oc, od, oh, ow) as a sum over (ic, kz, ky, kx)
     const auto& srcDims = src[0]->getStaticDims();
     const auto& weiDims = src[1]->getStaticDims();
@@ -116,7 +168,7 @@ void JitDeconv3DExecutor::exec(const std::vector<MemoryCPtr>& src,
     const size_t src_c_stride_elems = ID * IH * IW;
     const size_t wei_ic_stride_elems = OC * KD * KH * KW;
 
-    ensure_weights_packed(src);
+    ensure_weights_packed_f16(src);
     ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) {
         const size_t oc0 = oc_quad * 4;
         const size_t oc1 = oc0 + 1;
@@ -153,7 +205,7 @@ void JitDeconv3DExecutor::exec(const std::vector<MemoryCPtr>& src,
                             const size_t src_y_off = ihh * IW;
                             size_t s_base_row = n_base + src_z_off + src_y_off;
                             const size_t kw_count = static_cast<size_t>(kx_hi - kx_lo + 1);
-                            if (m_wei_packed_ready) {
+                            if (m_wei_packed_ready_f16) {
                                 const size_t s_base0 = s_base_row + static_cast<size_t>(txd - kx_lo);
                                 // Compute packed bases for ky_lo
                                 size_t pack_base_z0 = (oc0 * KD + static_cast<size_t>(kz)) * KH;
@@ -164,10 +216,10 @@ void JitDeconv3DExecutor::exec(const std::vector<MemoryCPtr>& src,
                                 size_t pack_base_y1 = has_oc1 ? (pack_base_z1 + static_cast<size_t>(ky_lo)) * KW : 0;
                                 size_t pack_base_y2 = has_oc2 ? (pack_base_z2 + static_cast<size_t>(ky_lo)) * KW : 0;
                                 size_t pack_base_y3 = has_oc3 ? (pack_base_z3 + static_cast<size_t>(ky_lo)) * KW : 0;
-                                const size_t pack_base0 = (pack_base_y0 + static_cast<size_t>(kx_lo)) * m_padded_IC;
-                                const size_t pack_base1 = has_oc1 ? (pack_base_y1 + static_cast<size_t>(kx_lo)) * m_padded_IC : 0;
-                                const size_t pack_base2 = has_oc2 ? (pack_base_y2 + static_cast<size_t>(kx_lo)) * m_padded_IC : 0;
-                                const size_t pack_base3 = has_oc3 ? (pack_base_y3 + static_cast<size_t>(kx_lo)) * m_padded_IC : 0;
+                                const size_t pack_base0 = (pack_base_y0 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16;
+                                const size_t pack_base1 = has_oc1 ? (pack_base_y1 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0;
+                                const size_t pack_base2 = has_oc2 ? (pack_base_y2 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0;
+                                const size_t pack_base3 = has_oc3 ?
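+                                // pack_base* index the transposed layout [OC][KD][KH][KW][IC_padded]
+                                // (deconv weights arrive as [IC][OC][KD][KH][KW]), so IC is the
+                                // innermost unit-stride axis the kernel walks.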
(pack_base_y3 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; jit_conv3d_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -181,14 +233,14 @@ void JitDeconv3DExecutor::exec(const std::vector& src, a.kh_cnt = kh_count; a.src_dx = sizeof(uint16_t); a.src_dy = IW * sizeof(uint16_t); - a.wei = m_wei_packed.data() + pack_base0; - if (has_oc1) a.wei2 = m_wei_packed.data() + pack_base1; + a.wei = m_wei_packed_f16.data() + pack_base0; + if (has_oc1) a.wei2 = m_wei_packed_f16.data() + pack_base1; // oc2/oc3 handled in a follow-up dual call a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_IC * sizeof(uint16_t); - a.wei_dy = KW * m_padded_IC * sizeof(uint16_t); - (*m_ip_kernel)(&a); + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); } else { // Generic ky+kx loops (not packed) for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { @@ -213,7 +265,7 @@ void JitDeconv3DExecutor::exec(const std::vector& src, if (has_oc1) a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - (*m_ip_kernel)(&a); + (*m_ip_kernel_f16)(&a); } // pair 1 if (has_oc2) { @@ -231,7 +283,7 @@ void JitDeconv3DExecutor::exec(const std::vector& src, if (has_oc3) a.wei2 = wei_p + w_base3; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - (*m_ip_kernel)(&a); + (*m_ip_kernel_f16)(&a); } } } @@ -271,12 +323,12 @@ void JitDeconv3DExecutor::exec(const std::vector& src, a.acc2 = has_oc1 ? &acc1 : nullptr; a.repeats = IC / 8; a.tail = IC % 8; - if (m_wei_packed_ready) { - const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC; - a.wei = m_wei_packed.data() + pack_base0; + if (m_wei_packed_ready_f16) { + const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + a.wei = m_wei_packed_f16.data() + pack_base0; if (has_oc1) { - const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC; - a.wei2 = m_wei_packed.data() + pack_base1; + const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + a.wei2 = m_wei_packed_f16.data() + pack_base1; } a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; @@ -286,7 +338,7 @@ void JitDeconv3DExecutor::exec(const std::vector& src, a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; } - (*m_ip_kernel)(&a); + (*m_ip_kernel_f16)(&a); } } } @@ -319,21 +371,225 @@ void JitDeconv3DExecutor::exec(const std::vector& src, }); } +void JitDeconv3DExecutor::exec_fp32(const std::vector& src, + const std::vector& dst) { + // NCDHW, f32 + const auto& srcDims = src[0]->getStaticDims(); + const auto& weiDims = src[1]->getStaticDims(); + const auto& dstDims = dst[0]->getStaticDims(); + + const size_t N = srcDims[0]; + const size_t IC = srcDims[1]; + const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; + + const size_t SD = deconvAttrs.stride.size() > 0 ? static_cast(deconvAttrs.stride[0]) : 1; + const size_t SH = deconvAttrs.stride.size() > 1 ? static_cast(deconvAttrs.stride[1]) : 1; + const size_t SW = deconvAttrs.stride.size() > 2 ? static_cast(deconvAttrs.stride[2]) : 1; + + const ptrdiff_t PD0 = deconvAttrs.paddingL.size() > 0 ? 
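+    // Deconvolution read-side mapping: input sample id contributes to output
+    // od = id * SD - PD0 + kz, so inverting gives id = (od + PD0 - kz) / SD,
+    // valid only when the division is exact and 0 <= id < ID; with SD == 1 this
+    // collapses to the contiguous tap ranges computed below.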
deconvAttrs.paddingL[0] : 0; + const ptrdiff_t PH0 = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL[1] : 0; + const ptrdiff_t PW0 = deconvAttrs.paddingL.size() > 2 ? deconvAttrs.paddingL[2] : 0; + + const float* src_p = reinterpret_cast(src[0]->getData()); + const float* wei_p = reinterpret_cast(src[1]->getData()); + float* dst_p = reinterpret_cast(dst[0]->getData()); + + auto idx_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * IC + c) * ID + z) * IH + y) * IW + x; + }; + auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { + return (((n * OC + c) * OD + z) * OH + y) * OW + x; + }; + auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { + return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + + const size_t src_c_stride_elems = ID * IH * IW; + const size_t wei_ic_stride_elems = OC * KD * KH * KW; + + ensure_weights_packed_f32(src); + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + const size_t oc0 = oc_quad * 4; + const size_t oc1 = std::min(oc0 + 1, OC); + const size_t oc2 = std::min(oc0 + 2, OC); + const size_t oc3 = std::min(oc0 + 3, OC); + const bool has_oc1 = oc1 < OC; + const bool has_oc2 = oc2 < OC; + const bool has_oc3 = oc3 < OC; + + for (size_t od = 0; od < OD; ++od) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow_ = 0; ow_ < OW; ++ow_) { + float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + + if (SD == 1 && SH == 1 && SW == 1) { + // contiguous tap range in each dimension + const ptrdiff_t tz = static_cast(od) + PD0; + const ptrdiff_t ty = static_cast(oh) + PH0; + const ptrdiff_t tx = static_cast(ow_) + PW0; + const ptrdiff_t kz_lo = std::max(0, tz - static_cast(ID) + 1); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tz); + const ptrdiff_t ky_lo = std::max(0, ty - static_cast(IH) + 1); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, ty); + const ptrdiff_t kx_lo = std::max(0, tx - static_cast(IW) + 1); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, tx); + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { + const size_t iz = static_cast(tz - kz); + const size_t ky_base = static_cast(ky_lo); + const size_t iy0 = static_cast(ty - ky_lo); + const size_t ix0 = static_cast(tx - kx_lo); + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy = static_cast(ty - ky); + const size_t ix = ix0; (void)iy0; (void)ky_base; + const size_t s_base = idx_src(n, 0, iz, iy, ix); + + // pair 0 + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? 
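+                                    // kw_cnt batches the whole contiguous kx range into a single kernel
+                                    // call, amortizing call overhead over up to KW taps per (kz, ky).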
&acc1 : nullptr; + a.repeats = IC / 4; + a.tail = IC % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_IC_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = IC / 4; + a.tail = IC % 4; + a.kw_cnt = kw_count; + a.src_dx = sizeof(float); + const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = m_padded_IC_f32 * sizeof(float); + (*m_ip_kernel_f32)(&a); + } + } + } + } + } else { + // generic stride path with modulus checks + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t iz_num = static_cast(od) + PD0 - static_cast(kz); + if (SD == 0) continue; + if (iz_num % static_cast(SD) != 0) continue; + const ptrdiff_t id = iz_num / static_cast(SD); + if (id < 0 || id >= static_cast(ID)) continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); + if (SH == 0) continue; + if (iy_num % static_cast(SH) != 0) continue; + const ptrdiff_t ihh = iy_num / static_cast(SH); + if (ihh < 0 || ihh >= static_cast(IH)) continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix_num = static_cast(ow_) + PW0 - static_cast(kx); + if (SW == 0) continue; + if (ix_num % static_cast(SW) != 0) continue; + const ptrdiff_t iww = ix_num / static_cast(SW); + if (iww < 0 || iww >= static_cast(IW)) continue; + + const size_t s_base0 = idx_src(n, 0, static_cast(id), static_cast(ihh), static_cast(iww)); + const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); + const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, kz, ky, kx) : 0; + + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? 
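+                            // Only taps where (o + pad - k) is divisible by the stride hit a real
+                            // input sample; the rest fall on the zeros of the upsampled grid and
+                            // are skipped by the modulus checks above.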
&acc1 : nullptr; + a.repeats = IC / 4; + a.tail = IC % 4; + if (m_wei_packed_ready_f32) { + const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + pack_base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + } else { + a.wei = wei_p + w_base0; + if (has_oc1) a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + } + (*m_ip_kernel_f32)(&a); + } + } + } + } + // Optional bias (support f32 or f16 input bias) + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const float* b = reinterpret_cast(src[2]->getData()); + acc0 += b[oc0]; if (has_oc1) acc1 += b[oc1]; if (has_oc2) acc2 += b[oc2]; if (has_oc3) acc3 += b[oc3]; + } else if (bprec == ov::element::f16) { + const uint16_t* b = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(b[oc0])); + if (has_oc1) acc1 += static_cast(ov::float16(b[oc1])); + if (has_oc2) acc2 += static_cast(ov::float16(b[oc2])); + if (has_oc3) acc3 += static_cast(ov::float16(b[oc3])); + } + } + + dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0; + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2; + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3; + } + } + } + }); +} + bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, const std::vector& srcDescs, const std::vector& dstDescs) const { - // Support 5D NCDHW, fp16 only for now + // Support 5D NCDHW, fp16 and fp32 if (srcDescs.size() < 2 || dstDescs.empty()) return false; if (srcDescs[0]->getShape().getRank() != 5 || srcDescs[1]->getShape().getRank() != 5 || dstDescs[0]->getShape().getRank() != 5) { return false; } - const auto prec = srcDescs[0]->getPrecision(); - if (!(prec == ov::element::f16 && srcDescs[1]->getPrecision() == ov::element::f16 && - dstDescs[0]->getPrecision() == ov::element::f16)) { - return false; - } - return true; + const auto s0 = srcDescs[0]->getPrecision(); + const auto s1 = srcDescs[1]->getPrecision(); + const auto d0 = dstDescs[0]->getPrecision(); + const bool fp16_ok = (s0 == ov::element::f16 && s1 == ov::element::f16 && d0 == ov::element::f16); + const bool fp32_ok = (s0 == ov::element::f32 && s1 == ov::element::f32 && d0 == ov::element::f32); + return fp16_ok || fp32_ok; } } // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index a7348f4eac39b7..09ea8fc608cbf0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -9,6 +9,7 @@ #include "nodes/executors/deconv.hpp" #include "nodes/executors/aarch64/jit_conv3d.hpp" +#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" namespace ov::intel_cpu { @@ -31,11 +32,23 @@ class JitDeconv3DExecutor : public DeconvExecutor { private: std::vector m_srcDescs; std::vector m_dstDescs; - std::unique_ptr m_ip_kernel; - std::vector m_wei_packed; - bool m_wei_packed_ready{false}; - size_t m_padded_IC{0}; - void ensure_weights_packed(const std::vector& src); + // kernels + std::unique_ptr 
m_ip_kernel_f16; + std::unique_ptr m_ip_kernel_f32; + bool m_is_fp32{false}; + + // packed weights + std::vector m_wei_packed_f16; + std::vector m_wei_packed_f32; + bool m_wei_packed_ready_f16{false}; + bool m_wei_packed_ready_f32{false}; + size_t m_padded_IC_f16{0}; + size_t m_padded_IC_f32{0}; + + void ensure_weights_packed_f16(const std::vector& src); + void ensure_weights_packed_f32(const std::vector& src); + void exec_fp16(const std::vector& src, const std::vector& dst); + void exec_fp32(const std::vector& src, const std::vector& dst); }; class AArch64JitDeconvExecutorBuilder : public DeconvExecutorBuilder { diff --git a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp index 1c088fe0478ae9..9d6c2b5355b207 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp @@ -105,12 +105,11 @@ template <> const std::vector>& getImplementations() { static const std::vector> convolutionImplementations { OV_CPU_INSTANCE_ARM64( - "convolution_jit_aarch64_3d_fp16_ncsp", ExecutorType::Jit, OperationType::Convolution, - // supports: prefer our AArch64 JIT whenever attrs/shapes permit + "convolution_jit_aarch64_3d_ncsp", ExecutorType::Jit, OperationType::Convolution, [](const ConvConfig& config, [[maybe_unused]] const MemoryFormatFilter& memoryFormatFilter) -> bool { return JitConv3DExecutor::supports(config); }, - // Ask for plain ncsp layouts to avoid being filtered out + // Request plain ncsp layouts CreateOptimalConfigDefault{{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}}, AcceptsAnyShape, CreateDefault{} From b9cee944991626bf6492493d63bd35fff17753a1 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Sun, 19 Oct 2025 22:40:51 +0200 Subject: [PATCH 04/20] Refactor AArch64 JIT 3D Deconvolution and Convolution Executors to use `parallel_for3d` for better clarity and efficiency. Adjust FP32 to FP16 conversions for consistency. --- .../nodes/executors/aarch64/jit_conv3d.cpp | 36 +++++++++---------- .../nodes/executors/aarch64/jit_deconv3d.cpp | 8 +++-- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index cca73a81a6ecff..3e9e2403cc3f82 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -36,9 +36,7 @@ void JitConv3DKernelF16::create_ker() { void JitConv3DKernelF16::generate() { using namespace Xbyak_aarch64; - - // Stable minimal kernel: dual-OC or single-OC accumulation over C in 8-lane blocks + tail. - // Avoid callee-saved registers and any in-kernel spatial loops to ensure ABI safety on macOS arm64. 
+ // Minimal stable kernel (dual-OC, in-kernel kx loop) { const XReg reg_args = abi_param1; // x0 @@ -458,7 +456,6 @@ void JitConv3DKernelF16::generate() { L(Ldone); ret(); } - return; // abi_param1 -> args pointer const XReg reg_args = abi_param1; @@ -516,9 +513,8 @@ void JitConv3DKernelF16::generate() { ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) - // Force single-ky iteration inside kernel to simplify and ensure stability + // Force single-ky iteration and disable quad-OC for stability on macOS arm64 mov(reg_kh_cnt, 1); - // Disable quad-OC path for stability (use dual or single) eor(reg_acc4, reg_acc4, reg_acc4); Label Lsingle, Lend_all; @@ -1209,7 +1205,7 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { // Prepare packed weights once ensure_weights_packed(memory); - ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + auto worker = [&](size_t n, size_t oc_quad, size_t od) { const size_t oc0 = oc_quad * 4; const size_t oc1 = oc0 + 1; const size_t oc2 = oc0 + 2; @@ -1217,12 +1213,11 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { const bool has_oc1 = oc1 < OC; const bool has_oc2 = oc2 < OC; const bool has_oc3 = oc3 < OC; - for (size_t od = 0; od < OD; ++od) { - const int64_t iz0 = static_cast(od * SD) - static_cast(PD0); - for (size_t oh = 0; oh < OH; ++oh) { - const int64_t iy0 = static_cast(oh * SH) - static_cast(PH0); - for (size_t ow = 0; ow < OW; ++ow) { - const int64_t ix0 = static_cast(ow * SW) - static_cast(PW0); + const int64_t iz0 = static_cast(od * SD) - static_cast(PD0); + for (size_t oh = 0; oh < OH; ++oh) { + const int64_t iy0 = static_cast(oh * SH) - static_cast(PH0); + for (size_t ow = 0; ow < OW; ++ow) { + const int64_t ix0 = static_cast(ow * SW) - static_cast(PW0); float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; @@ -1467,8 +1462,9 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits(); } } - } - }); + }; + + ov::parallel_for3d(N, (OC + 3) / 4, OD, worker); } void JitConv3DExecutor::execute(const MemoryArgs& memory) { @@ -1743,11 +1739,11 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { } } - // Store - dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; - if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; - if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; - if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; + // Store (convert FP32 accumulators to FP16 bits) + dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits(); + if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits(); + if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits(); + if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits(); } } } diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 9a7a107f89d6a1..dd9fd695d73416 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -169,7 +169,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const size_t wei_ic_stride_elems = OC * KD * KH * KW; ensure_weights_packed_f16(src); - ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { + auto worker = [&](size_t n, 
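+                   // OD joins the parallel partitioning (parallel_for3d over
+                   // N x OC-quads x OD), exposing more parallelism than
+                   // parallel_for2d when N and OC are small.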
size_t oc_quad, size_t od) { const size_t oc0 = oc_quad * 4; const size_t oc1 = oc0 + 1; const size_t oc2 = oc0 + 2; @@ -178,7 +178,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const bool has_oc2 = oc2 < OC; const bool has_oc3 = oc3 < OC; const size_t n_base = n * IC * ID * IH * IW; - for (size_t od = 0; od < OD; ++od) { + { for (size_t oh = 0; oh < OH; ++oh) { for (size_t ow_ = 0; ow_ < OW; ++ow_) { float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; @@ -368,7 +368,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, } } } - }); + }; + + ov::parallel_for3d(N, (OC + 3) / 4, OD, worker); } void JitDeconv3DExecutor::exec_fp32(const std::vector& src, From 6e7c4136d4bc135c16935cd8772a3596a7210f75 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Sun, 19 Oct 2025 23:03:47 +0200 Subject: [PATCH 05/20] Refactor AArch64 JIT 3D Deconvolution Executor for improved readability by reformatting code, removing unused helpers, and ensuring consistent formatting. --- src/plugins/intel_cpu/src/nodes/deconv.cpp | 59 +- .../nodes/executors/aarch64/jit_conv3d.cpp | 1763 +++++++++++------ .../nodes/executors/aarch64/jit_conv3d.hpp | 50 +- .../executors/aarch64/jit_conv3d_f32.cpp | 295 ++- .../executors/aarch64/jit_conv3d_f32.hpp | 12 +- .../nodes/executors/aarch64/jit_deconv3d.cpp | 263 ++- .../nodes/executors/aarch64/jit_deconv3d.hpp | 8 +- .../src/nodes/executors/deconv_list.cpp | 3 +- 8 files changed, 1670 insertions(+), 783 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index 941d255bca1e9c..658ab78bf0e959 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -1319,30 +1319,37 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { config.outConfs.resize(getOriginalOutputsNumber()); auto setDesc = [&](size_t port, bool isInput) { - const auto prec = isInput ? getOriginalInputPrecisionAtPort(port) - : getOriginalOutputPrecisionAtPort(port); + const auto prec = + isInput ? getOriginalInputPrecisionAtPort(port) : getOriginalOutputPrecisionAtPort(port); const auto& shp = isInput ? 
getInputShapeAtPort(port) : getOutputShapeAtPort(port); auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp); - if (isInput) config.inConfs[port].setMemDesc(d); - else config.outConfs[port].setMemDesc(d); + if (isInput) + config.inConfs[port].setMemDesc(d); + else + config.outConfs[port].setMemDesc(d); }; setDesc(0, true); setDesc(1, true); - for (size_t i = 2; i < getParentEdges().size(); ++i) setDesc(i, true); + for (size_t i = 2; i < getParentEdges().size(); ++i) + setDesc(i, true); setDesc(0, false); std::vector srcMemoryDescs; srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims())); - for (size_t i = 1; i < config.inConfs.size(); i++) srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); + for (size_t i = 1; i < config.inConfs.size(); i++) + srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); std::vector dstMemoryDescs; dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims())); - for (size_t i = 1; i < config.outConfs.size(); i++) dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); - - auto factory = std::make_shared( - deconvAttrs, srcMemoryDescs, dstMemoryDescs, - std::make_shared(context, getImplPriority())); + for (size_t i = 1; i < config.outConfs.size(); i++) + dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); + + auto factory = + std::make_shared(deconvAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority())); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory); - useACL = true; // reuse factory-based execution path + useACL = true; // reuse factory-based execution path return; } } @@ -1370,27 +1377,35 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { config.outConfs.resize(getOriginalOutputsNumber()); auto setDesc = [&](size_t port, const Shape& shape, bool isInput) { - const auto prec = isInput ? getOriginalInputPrecisionAtPort(port) - : getOriginalOutputPrecisionAtPort(port); + const auto prec = + isInput ? getOriginalInputPrecisionAtPort(port) : getOriginalOutputPrecisionAtPort(port); const auto& shp = isInput ? 
getInputShapeAtPort(port) : getOutputShapeAtPort(port); auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp); - if (isInput) config.inConfs[port].setMemDesc(d); else config.outConfs[port].setMemDesc(d); + if (isInput) + config.inConfs[port].setMemDesc(d); + else + config.outConfs[port].setMemDesc(d); }; setDesc(0, tmpInShape, true); setDesc(1, Shape(getInputShapeAtPort(1).getStaticDims()), true); - for (size_t i = 2; i < getParentEdges().size(); ++i) setDesc(i, Shape(getInputShapeAtPort(i).getStaticDims()), true); + for (size_t i = 2; i < getParentEdges().size(); ++i) + setDesc(i, Shape(getInputShapeAtPort(i).getStaticDims()), true); setDesc(0, tmpOutShape, false); std::vector srcMemoryDescs; srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims())); - for (size_t i = 1; i < config.inConfs.size(); i++) srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); + for (size_t i = 1; i < config.inConfs.size(); i++) + srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); std::vector dstMemoryDescs; dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims())); - for (size_t i = 1; i < config.outConfs.size(); i++) dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); - - auto factory = std::make_shared( - deconvAttrs, srcMemoryDescs, dstMemoryDescs, - std::make_shared(context, getImplPriority())); + for (size_t i = 1; i < config.outConfs.size(); i++) + dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); + + auto factory = + std::make_shared(deconvAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority())); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory); return; } diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 3e9e2403cc3f82..12a977e4045908 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -16,8 +16,8 @@ #include "nodes/executors/implementation_utils.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" -#include "utils/general_utils.h" #include "openvino/core/type/float16.hpp" +#include "utils/general_utils.h" // helper for jit_kernel_cast #include "utils/cpu_utils.hpp" // no direct NEON intrinsics are used here; we rely on Xbyak_aarch64 @@ -39,17 +39,17 @@ void JitConv3DKernelF16::generate() { // Minimal stable kernel (dual-OC, in-kernel kx loop) { - const XReg reg_args = abi_param1; // x0 + const XReg reg_args = abi_param1; // x0 // Load essential arguments (absolute offsets from jit_conv3d_call_args) - const XReg reg_src = x1; // const uint16_t* src - const XReg reg_wei = x2; // const uint16_t* wei - const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) - const XReg reg_reps = x4; // size_t repeats - const XReg reg_tail = x5; // size_t tail - const XReg reg_src_stride = x6; // size_t src_stride (bytes) - const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) - const XReg reg_acc = x8; // float* acc - const XReg reg_acc2 = x9; // float* acc2 (optional) + const XReg reg_src = x1; // const uint16_t* src + const XReg reg_wei = x2; // const uint16_t* wei + const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats + const XReg reg_tail = x5; // size_t tail + const XReg reg_src_stride = 
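+    // Register map (temporaries only; x18 is avoided as a platform register):
+    //   x0 args, x1 src, x2 wei, x3 wei2, x4 repeats, x5 tail, x6 src_stride,
+    //   x7 wei_stride, x8 acc, x9 acc2 (all loaded from the args struct below).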
x6; // size_t src_stride (bytes) + const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) + const XReg reg_acc = x8; // float* acc + const XReg reg_acc2 = x9; // float* acc2 (optional) ldr(reg_src, ptr(reg_args, 0)); ldr(reg_wei, ptr(reg_args, 8)); @@ -103,13 +103,20 @@ void JitConv3DKernelF16::generate() { cmp(reg_reps, 0); b(EQ, Ltail_prep_d_kx); // Load src lanes into v0 - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // Load weights for oc0/oc1 (vector fast path if stride==2) Label Lw_np_d, Lw_done_d2; @@ -121,21 +128,36 @@ void JitConv3DKernelF16::generate() { add(reg_wei2, reg_wei2, reg_wei_blk_stride2); b(Lw_done_d2); L(Lw_np_d); - ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[3], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[4], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[5], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[6], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); + 
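// A minimal host-side sketch of the dispatch above: the cmp(reg_wei_stride, 2)
// guard picks between two weight-load strategies. When consecutive FP16
// weights are 2 bytes apart they are contiguous, so one 128-bit ld1(VReg8H)
// vector load replaces eight lane loads; otherwise the kernel gathers lanes
// with strided loads. Function name is hypothetical, not from the patch.
#include <cstddef>
#include <cstdint>
#include <cstring>

void load_8_weights(uint16_t dst[8], const uint8_t* wei, size_t wei_stride_bytes) {
    if (wei_stride_bytes == sizeof(uint16_t)) {
        std::memcpy(dst, wei, 8 * sizeof(uint16_t));  // packed: single vector load
    } else {
        for (int lane = 0; lane < 8; ++lane)          // not packed: strided lane loads
            std::memcpy(&dst[lane], wei + lane * wei_stride_bytes, sizeof(uint16_t));
    }
}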
add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); L(Lw_done_d2); // MAC into accumulators fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -149,29 +171,67 @@ void JitConv3DKernelF16::generate() { eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); eor(VReg16B(2), VReg16B(2), VReg16B(2)); - cmp(reg_tail, 0); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 4); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 5); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 6); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 7); b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + 
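// A minimal scalar sketch of what one fmlal + fmlal2 pair computes: FMLAL
// widens and multiplies the low four FP16 lanes of both operands into four
// FP32 products, FMLAL2 does the same for the high four lanes, and both
// accumulate into the four FP32 lanes of the destination. Assumes an AArch64
// toolchain where the storage-only __fp16 type is available.
#include <cstddef>

void fmlal_pair_reference(float acc4[4], const __fp16 s[8], const __fp16 w[8]) {
    for (size_t i = 0; i < 4; ++i) {
        acc4[i] += static_cast<float>(s[i]) * static_cast<float>(w[i]);          // fmlal: lanes 0..3
        acc4[i] += static_cast<float>(s[i + 4]) * static_cast<float>(w[i + 4]);  // fmlal2: lanes 4..7
    }
}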
b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); L(Ltail_done_d_kx); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -188,8 +248,12 @@ void JitConv3DKernelF16::generate() { faddp(VReg2S(20), VReg2S(20), VReg2S(20)); faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); b(Ldone); // Dual-OC path: v20 (oc0), v21 (oc1) @@ -201,13 +265,20 @@ void JitConv3DKernelF16::generate() { cmp(reg_reps, 0); b(EQ, Ltail_prep_d); // Load src lanes (v0) - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2 Label Ldw_np_d, Ldw_done_d; @@ -219,21 +290,36 @@ void JitConv3DKernelF16::generate() { add(reg_wei2, reg_wei2, 16); b(Ldw_done_d); L(Ldw_np_d); - ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - 
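// The faddp/ldr/fadd/str epilogue above is a horizontal reduction: two
// pairwise adds fold the four FP32 partial sums into lane 0, and the result
// is added into the accumulator in memory. A scalar sketch (name hypothetical):
void reduce_and_store_reference(float* acc, const float v[4]) {
    // faddp .4s: [v0+v1, v2+v3, ...]; faddp .2s: (v0+v1) + (v2+v3)
    *acc += (v[0] + v[1]) + (v[2] + v[3]);  // ldr s0; fadd s0, s0, s20; str s0
}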
ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[3], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[4], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[5], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[6], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); L(Ldw_done_d); // MAC into v20/v21 fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -249,29 +335,67 @@ void JitConv3DKernelF16::generate() { eor(VReg16B(1), VReg16B(1), VReg16B(1)); eor(VReg16B(2), VReg16B(2), VReg16B(2)); // lanes 0..7 guarded by tail - cmp(reg_tail, 0); b(LE, Ltail_done_d); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_d); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_d); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_d); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 4); b(LE, Ltail_done_d); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); 
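// The tail path zeroes v0/v1/v2 with eor and then fills only `tail` lanes,
// each load guarded by cmp/b(LE); untouched lanes stay zero and contribute
// nothing to the MAC. A scalar sketch (steps here are in elements, while the
// kernel's strides are in bytes; assumes __fp16 as before):
#include <cstddef>

void tail_mac_reference(float& acc, const __fp16* src, const __fp16* wei,
                        size_t tail /* < 8 */, size_t src_step, size_t wei_step) {
    __fp16 s[8] = {}, w[8] = {};  // eor-cleared lanes
    for (size_t lane = 0; lane < tail; ++lane) {
        s[lane] = src[lane * src_step];  // guarded strided lane loads
        w[lane] = wei[lane * wei_step];
    }
    for (size_t lane = 0; lane < 8; ++lane)
        acc += static_cast<float>(s[lane]) * static_cast<float>(w[lane]);
}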
ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 5); b(LE, Ltail_done_d); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 6); b(LE, Ltail_done_d); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 7); b(LE, Ltail_done_d); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_d); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_d); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_d); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_d); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); L(Ltail_done_d); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -283,8 +407,12 @@ void JitConv3DKernelF16::generate() { faddp(VReg2S(20), VReg2S(20), VReg2S(20)); faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); b(Ldone); // Single-OC path @@ 
-316,13 +444,20 @@ void JitConv3DKernelF16::generate() { L(Lrep_s_kx); cmp(reg_reps, 0); b(EQ, Ltail_prep_s_kx); - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // weights (vector fast path if stride==2) Label Lw_np_s, Lw_done_s2; @@ -332,13 +467,20 @@ void JitConv3DKernelF16::generate() { add(reg_wei, reg_wei, s_wei_blk_stride2); b(Lw_done_s2); L(Lw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).h[7], ptr(reg_wei)); L(Lw_done_s2); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -348,29 +490,52 @@ void JitConv3DKernelF16::generate() { L(Ltail_prep_s_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[4], ptr(reg_src)); 
ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); L(Ltail_done_s_kx); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -382,20 +547,29 @@ void JitConv3DKernelF16::generate() { // reduce/store faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); eor(VReg16B(20), VReg16B(20), VReg16B(20)); Label Lrep_s, Ltail_prep_s, Ltail_done_s; L(Lrep_s); cmp(reg_reps, 0); b(EQ, Ltail_prep_s); // src lanes - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], 
ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // wei lanes — vector fast path if wei_stride==2 Label Ldw_np_s, Ldw_done_s; @@ -405,13 +579,20 @@ void JitConv3DKernelF16::generate() { add(reg_wei, reg_wei, 16); b(Ldw_done_s); L(Ldw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).h[7], ptr(reg_wei)); L(Ldw_done_s); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -423,35 +604,60 @@ void JitConv3DKernelF16::generate() { L(Ltail_prep_s); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); b(LE, Ltail_done_s); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_s); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_s); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_s); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); b(LE, Ltail_done_s); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); b(LE, Ltail_done_s); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); b(LE, Ltail_done_s); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); b(LE, Ltail_done_s); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + 
cmp(reg_tail, 2); + b(LE, Ltail_done_s); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_s); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_s); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_s); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_s); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); L(Ltail_done_s); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); L(Ldone); ret(); @@ -463,8 +669,8 @@ void JitConv3DKernelF16::generate() { const XReg reg_src = x1; const XReg reg_wei = x2; const XReg reg_wei2 = x10; - const XReg reg_reps = x3; // number of full 8-lane blocks - const XReg reg_tail = x9; // remaining channels (< 8) + const XReg reg_reps = x3; // number of full 8-lane blocks + const XReg reg_tail = x9; // remaining channels (< 8) const XReg reg_src_stride = x4; const XReg reg_wei_stride = x5; const XReg reg_acc = x6; @@ -483,35 +689,35 @@ void JitConv3DKernelF16::generate() { stp(XReg(29), XReg(30), pre_ptr(sp, -16)); // Load args - ldr(reg_src, ptr(reg_args)); // src - ldr(reg_wei, ptr(reg_args, 8)); // wei - ldr(reg_wei2, ptr(reg_args, 16)); // wei2 (optional) - ldr(reg_wei3, ptr(reg_args, 24)); // wei3 (optional) - ldr(reg_wei4, ptr(reg_args, 32)); // wei4 (optional) - ldr(reg_reps, ptr(reg_args, 40)); // repeats - ldr(reg_tail, ptr(reg_args, 48)); // tail (<= 8) - ldr(reg_src_stride, ptr(reg_args, 56)); // src_stride bytes - ldr(reg_wei_stride, ptr(reg_args, 64)); // wei_stride bytes + ldr(reg_src, ptr(reg_args)); // src + ldr(reg_wei, ptr(reg_args, 8)); // wei + ldr(reg_wei2, ptr(reg_args, 16)); // wei2 (optional) + ldr(reg_wei3, ptr(reg_args, 24)); // wei3 (optional) + ldr(reg_wei4, ptr(reg_args, 32)); // wei4 (optional) + ldr(reg_reps, ptr(reg_args, 40)); // repeats + ldr(reg_tail, ptr(reg_args, 48)); // tail (<= 8) + ldr(reg_src_stride, ptr(reg_args, 56)); // src_stride bytes + ldr(reg_wei_stride, ptr(reg_args, 64)); // wei_stride bytes const XReg reg_src_blk_stride = x7; const XReg reg_wei_blk_stride = x8; - ldr(reg_src_blk_stride, ptr(reg_args, 72)); // src_blk_stride bytes - ldr(reg_wei_blk_stride, ptr(reg_args, 80)); // wei_blk_stride bytes - ldr(reg_acc, ptr(reg_args, 88)); // acc (float*) - ldr(reg_acc2, ptr(reg_args, 96)); // acc2 (float* or 0) + ldr(reg_src_blk_stride, ptr(reg_args, 72)); // src_blk_stride bytes + ldr(reg_wei_blk_stride, ptr(reg_args, 80)); // wei_blk_stride bytes + ldr(reg_acc, ptr(reg_args, 88)); // acc (float*) + ldr(reg_acc2, ptr(reg_args, 96)); // acc2 (float* or 0) 
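// The absolute ldr offsets above and just below imply an 8-byte-packed
// argument block. The authoritative definition is jit_conv3d_call_args in
// jit_conv3d.hpp; the field names below are only illustrative.
#include <cstddef>
#include <cstdint>

struct jit_conv3d_call_args_sketch {
    const uint16_t* src;    // +0
    const uint16_t* wei;    // +8
    const uint16_t* wei2;   // +16  optional (nullptr if unused)
    const uint16_t* wei3;   // +24  optional
    const uint16_t* wei4;   // +32  optional
    size_t repeats;         // +40  full 8-lane channel blocks
    size_t tail;            // +48  remaining channels (< 8)
    size_t src_stride;      // +56  bytes between channels
    size_t wei_stride;      // +64  bytes between channels
    size_t src_blk_stride;  // +72  bytes per 8-channel block
    size_t wei_blk_stride;  // +80
    float* acc;             // +88
    float* acc2;            // +96  or nullptr
    size_t kw_cnt;          // +104 kw count; 0 disables the kx fast path
    size_t src_dx;          // +112 x-step in bytes
    size_t wei_dx;          // +120
    float* acc3;            // +128 or nullptr
    float* acc4;            // +136 or nullptr
    size_t kh_cnt;          // +144 kh count; 0 disables
    size_t src_dy;          // +152 y-step in bytes
    size_t wei_dy;          // +160
};
static_assert(offsetof(jit_conv3d_call_args_sketch, acc) == 88, "matches the ldr offset");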
const XReg reg_kw_cnt = x12; const XReg reg_src_dx = x13; const XReg reg_wei_dx = x14; - ldr(reg_kw_cnt, ptr(reg_args, 104)); // kw count (for stride=1 fast path); 0 -> disabled - ldr(reg_src_dx, ptr(reg_args, 112)); // src dx step in bytes (x dimension) - ldr(reg_wei_dx, ptr(reg_args, 120)); // wei dx step in bytes (x dimension) - ldr(reg_acc3, ptr(reg_args, 128)); // acc3 (float* or 0) - ldr(reg_acc4, ptr(reg_args, 136)); // acc4 (float* or 0) + ldr(reg_kw_cnt, ptr(reg_args, 104)); // kw count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dx, ptr(reg_args, 112)); // src dx step in bytes (x dimension) + ldr(reg_wei_dx, ptr(reg_args, 120)); // wei dx step in bytes (x dimension) + ldr(reg_acc3, ptr(reg_args, 128)); // acc3 (float* or 0) + ldr(reg_acc4, ptr(reg_args, 136)); // acc4 (float* or 0) const XReg reg_kh_cnt = x26; const XReg reg_src_dy = x27; const XReg reg_wei_dy = x28; - ldr(reg_kh_cnt, ptr(reg_args, 144)); // kh count (for stride=1 fast path); 0 -> disabled - ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) - ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) + ldr(reg_kh_cnt, ptr(reg_args, 144)); // kh count (for stride=1 fast path); 0 -> disabled + ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) + ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) // Force single-ky iteration and disable quad-OC for stability on macOS arm64 mov(reg_kh_cnt, 1); @@ -573,13 +779,20 @@ void JitConv3DKernelF16::generate() { cmp(reg_reps, 0); b(LE, Lq_after_loop); // Load 8 src lanes - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // Load 4 weight vectors ld1(VReg8H(1), ptr(reg_wei)); @@ -610,29 +823,97 @@ void JitConv3DKernelF16::generate() { eor(VReg16B(2), VReg16B(2), VReg16B(2)); eor(VReg16B(3), VReg16B(3), VReg16B(3)); eor(VReg16B(4), VReg16B(4), VReg16B(4)); - cmp(reg_tail, 0); b(LE, Lq_after_fill); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); ld1(VReg(3).h[0], ptr(reg_wei3)); ld1(VReg(4).h[0], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Lq_after_fill); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); ld1(VReg(3).h[1], ptr(reg_wei3)); ld1(VReg(4).h[1], 
ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Lq_after_fill); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); ld1(VReg(3).h[2], ptr(reg_wei3)); ld1(VReg(4).h[2], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Lq_after_fill); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); ld1(VReg(3).h[3], ptr(reg_wei3)); ld1(VReg(4).h[3], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 4); b(LE, Lq_after_fill); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); ld1(VReg(3).h[4], ptr(reg_wei3)); ld1(VReg(4).h[4], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 5); b(LE, Lq_after_fill); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); ld1(VReg(3).h[5], ptr(reg_wei3)); ld1(VReg(4).h[5], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 6); b(LE, Lq_after_fill); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); ld1(VReg(3).h[6], ptr(reg_wei3)); ld1(VReg(4).h[6], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - cmp(reg_tail, 7); b(LE, Lq_after_fill); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); ld1(VReg(3).h[7], ptr(reg_wei3)); ld1(VReg(4).h[7], ptr(reg_wei4)); + cmp(reg_tail, 0); + b(LE, Lq_after_fill); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + ld1(VReg(3).h[0], ptr(reg_wei3)); + ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Lq_after_fill); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + ld1(VReg(3).h[1], ptr(reg_wei3)); + ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Lq_after_fill); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + ld1(VReg(3).h[2], 
ptr(reg_wei3)); + ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Lq_after_fill); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + ld1(VReg(3).h[3], ptr(reg_wei3)); + ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Lq_after_fill); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + ld1(VReg(3).h[4], ptr(reg_wei3)); + ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Lq_after_fill); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + ld1(VReg(3).h[5], ptr(reg_wei3)); + ld1(VReg(4).h[5], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Lq_after_fill); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + ld1(VReg(3).h[6], ptr(reg_wei3)); + ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Lq_after_fill); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(3).h[7], ptr(reg_wei3)); + ld1(VReg(4).h[7], ptr(reg_wei4)); L(Lq_after_fill); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -653,10 +934,14 @@ void JitConv3DKernelF16::generate() { b(GT, Lq_kx_loop); // reduce/store 4 accumulators - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - faddp(VReg4S(22), VReg4S(22), VReg4S(22)); faddp(VReg2S(22), VReg2S(22), VReg2S(22)); - faddp(VReg4S(23), VReg4S(23), VReg4S(23)); faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); + faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); + faddp(VReg2S(23), VReg2S(23), VReg2S(23)); // advance bases for next ky and continue if any add(q_src_base, q_src_base, reg_src_dy); add(q_wei_base, q_wei_base, reg_wei_dy); @@ -667,10 +952,18 @@ void JitConv3DKernelF16::generate() { b(GT, Lq_ky_loop); // After ky loop: reduce/store - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), 
ptr(reg_acc2)); - ldr(SReg(2), ptr(reg_acc3)); fadd(SReg(2), SReg(2), SReg(22)); str(SReg(2), ptr(reg_acc3)); - ldr(SReg(3), ptr(reg_acc4)); fadd(SReg(3), SReg(3), SReg(23)); str(SReg(3), ptr(reg_acc4)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); + fadd(SReg(2), SReg(2), SReg(22)); + str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); + fadd(SReg(3), SReg(3), SReg(23)); + str(SReg(3), ptr(reg_acc4)); b(Lend_all); // Not-packed fallback for quad: do lane-wise loads for 4 outputs @@ -682,34 +975,110 @@ void JitConv3DKernelF16::generate() { eor(VReg16B(23), VReg16B(23), VReg16B(23)); // single block (not looped here) — rely on host collapsing work // lanes 0..7 - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); ld1(VReg(3).h[0], ptr(reg_wei3)); ld1(VReg(4).h[0], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); ld1(VReg(3).h[1], ptr(reg_wei3)); ld1(VReg(4).h[1], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); ld1(VReg(3).h[2], ptr(reg_wei3)); ld1(VReg(4).h[2], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); ld1(VReg(3).h[3], ptr(reg_wei3)); ld1(VReg(4).h[3], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); ld1(VReg(3).h[4], ptr(reg_wei3)); ld1(VReg(4).h[4], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); ld1(VReg(3).h[5], ptr(reg_wei3)); ld1(VReg(4).h[5], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); ld1(VReg(3).h[6], ptr(reg_wei3)); ld1(VReg(4).h[6], ptr(reg_wei4)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); add(reg_wei3, reg_wei3, reg_wei_stride); add(reg_wei4, reg_wei4, reg_wei_stride); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], 
ptr(reg_wei2)); ld1(VReg(3).h[7], ptr(reg_wei3)); ld1(VReg(4).h[7], ptr(reg_wei4)); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); - fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + ld1(VReg(3).h[0], ptr(reg_wei3)); + ld1(VReg(4).h[0], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + ld1(VReg(3).h[1], ptr(reg_wei3)); + ld1(VReg(4).h[1], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + ld1(VReg(3).h[2], ptr(reg_wei3)); + ld1(VReg(4).h[2], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + ld1(VReg(3).h[3], ptr(reg_wei3)); + ld1(VReg(4).h[3], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + ld1(VReg(3).h[4], ptr(reg_wei3)); + ld1(VReg(4).h[4], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + ld1(VReg(3).h[5], ptr(reg_wei3)); + ld1(VReg(4).h[5], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + ld1(VReg(3).h[6], ptr(reg_wei3)); + ld1(VReg(4).h[6], ptr(reg_wei4)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + add(reg_wei3, reg_wei3, reg_wei_stride); + add(reg_wei4, reg_wei4, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(3).h[7], ptr(reg_wei3)); + ld1(VReg(4).h[7], ptr(reg_wei4)); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal(VReg4S(22), VReg4H(0), VReg4H(3)); + 
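// Host-side picture of the in-kernel kx walk (Lq_kx_loop above): rather than
// one kernel invocation per kernel-x tap, the JIT advances src/wei by dx
// between taps itself, removing per-tap call overhead. The channel MAC is
// passed in as a stand-in for the fmlal/fmlal2 block; names hypothetical.
#include <cstddef>
#include <cstdint>
#include <functional>

void kx_loop_reference(float& acc, const uint8_t* src, const uint8_t* wei,
                       size_t kw_cnt, size_t src_dx, size_t wei_dx,
                       const std::function<void(float&, const uint8_t*, const uint8_t*)>& mac_channels) {
    for (size_t kx = 0; kx < kw_cnt; ++kx) {
        mac_channels(acc, src, wei);  // accumulate all channels for this tap
        src += src_dx;                // advance one x tap; strides are in bytes
        wei += wei_dx;
    }
}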
fmlal2(VReg4S(22), VReg4H(0), VReg4H(3)); + fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); + fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); // reduce/store - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - faddp(VReg4S(22), VReg4S(22), VReg4S(22)); faddp(VReg2S(22), VReg2S(22), VReg2S(22)); - faddp(VReg4S(23), VReg4S(23), VReg4S(23)); faddp(VReg2S(23), VReg2S(23), VReg2S(23)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); - ldr(SReg(2), ptr(reg_acc3)); fadd(SReg(2), SReg(2), SReg(22)); str(SReg(2), ptr(reg_acc3)); - ldr(SReg(3), ptr(reg_acc4)); fadd(SReg(3), SReg(3), SReg(23)); str(SReg(3), ptr(reg_acc4)); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + faddp(VReg4S(22), VReg4S(22), VReg4S(22)); + faddp(VReg2S(22), VReg2S(22), VReg2S(22)); + faddp(VReg4S(23), VReg4S(23), VReg4S(23)); + faddp(VReg2S(23), VReg2S(23), VReg2S(23)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + ldr(SReg(2), ptr(reg_acc3)); + fadd(SReg(2), SReg(2), SReg(22)); + str(SReg(2), ptr(reg_acc3)); + ldr(SReg(3), ptr(reg_acc4)); + fadd(SReg(3), SReg(3), SReg(23)); + str(SReg(3), ptr(reg_acc4)); b(Lend_all); } // ---------------- Dual-OC path ---------------- @@ -717,7 +1086,7 @@ void JitConv3DKernelF16::generate() { // Zero FP32 accumulators v20.4s and v21.4s eor(VReg16B(20), VReg16B(20), VReg16B(20)); eor(VReg16B(21), VReg16B(21), VReg16B(21)); - Label Lnp_d, Ld_after_fill, Ld_after_loop, Ld_kx_loop, Ld_after_kx; + Label Lnp_d, Ld_after_fill, Ld_after_loop, Ld_kx_loop, Ld_after_kx; // If not packed, fallback to non-packed path below (no kw-loop optimization) cmp(reg_wei_stride, 2); b(NE, Lnp_d); @@ -751,13 +1120,20 @@ void JitConv3DKernelF16::generate() { cmp(reg_reps, 0); b(LE, Ld_after_loop); // Load 8 src half lanes with strides between channels - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // Load 8 half weights for oc0 and oc1 as vectors ld1(VReg8H(1), ptr(reg_wei)); @@ -779,35 +1155,65 @@ void JitConv3DKernelF16::generate() { eor(VReg16B(2), VReg16B(2), VReg16B(2)); 
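// Why the dual-OC layout pays off: the eight src lanes in v0 are loaded once
// and multiplied against two weight streams (v1, v2) into two accumulators
// (v20, v21), so src-load traffic is amortized over two output channels.
// Scalar equivalent of one 8-lane step (assumes __fp16 as before):
#include <cstddef>

void dual_oc_step(float& acc0, float& acc1,
                  const __fp16 s[8], const __fp16 w0[8], const __fp16 w1[8]) {
    for (size_t lane = 0; lane < 8; ++lane) {
        const float x = static_cast<float>(s[lane]);  // one load, two uses
        acc0 += x * static_cast<float>(w0[lane]);     // v20 += v0 * v1
        acc1 += x * static_cast<float>(w1[lane]);     // v21 += v0 * v2
    }
}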
cmp(reg_tail, 0); b(LE, Ld_after_fill); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 1); b(LE, Ld_after_fill); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 2); b(LE, Ld_after_fill); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 3); b(LE, Ld_after_fill); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 4); b(LE, Ld_after_fill); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 5); b(LE, Ld_after_fill); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 6); b(LE, Ld_after_fill); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); cmp(reg_tail, 7); b(LE, Ld_after_fill); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], 
ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); L(Ld_after_fill); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -831,21 +1237,51 @@ void JitConv3DKernelF16::generate() { // Not-packed path: pairwise lane loads for both wei/wei2 (no kw-loop optimization) L(Lnp_d); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], 
ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); // Accumulate fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -878,7 +1314,8 @@ void JitConv3DKernelF16::generate() { // Zero FP32 accumulator v20.4s eor(VReg16B(20), VReg16B(20), VReg16B(20)); - Label Lloop2, Lloop, Lafter_loop, Lafter_fill, Lnp1, Lend1, Lnp2, Lend2, Lnot_packed, Lkx_loop_s, Lafter_kx_s, Ls_loop, Ls_after_loop, Ls_after_fill; + Label Lloop2, Lloop, Lafter_loop, Lafter_fill, Lnp1, Lend1, Lnp2, Lend2, Lnot_packed, Lkx_loop_s, Lafter_kx_s, + Ls_loop, Ls_after_loop, Ls_after_fill; // Unrolled-by-2 loop over full 8-lane channel blocks (default single position) b(Lloop); @@ -890,13 +1327,20 @@ void JitConv3DKernelF16::generate() { cmp(reg_wei_stride, 2); b(NE, Lnp1); // packed: src lanes (8), then vector-load weights - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg8H(1), ptr(reg_wei)); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -906,14 +1350,36 @@ void JitConv3DKernelF16::generate() { b(Lend1); L(Lnp1); // not packed: pairwise lanes - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], 
ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); add(reg_src, reg_src, reg_src_stride); @@ -923,13 +1389,20 @@ void JitConv3DKernelF16::generate() { // Second block cmp(reg_wei_stride, 2); b(NE, Lnp2); - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg8H(1), ptr(reg_wei)); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); @@ -938,14 +1411,36 @@ void JitConv3DKernelF16::generate() { add(reg_wei, reg_wei, reg_wei_blk_stride); b(Lend2); L(Lnp2); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], 
ptr(reg_wei)); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); add(reg_src, reg_src, reg_src_stride); @@ -967,13 +1462,20 @@ void JitConv3DKernelF16::generate() { b(NE, Lnot_packed); // Packed path: fill src lanes, vector-load wei - ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // load 8 half wei as one vector ld1(VReg8H(1), ptr(reg_wei)); @@ -988,14 +1490,36 @@ void JitConv3DKernelF16::generate() { // Not-packed path: fill src/wei lanes pairwise and accumulate L(Lnot_packed); - ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, 
reg_wei_stride);
-    ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride);
-    ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei));
+    ld1(VReg(0).h[0], ptr(reg_src));
+    ld1(VReg(1).h[0], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[1], ptr(reg_src));
+    ld1(VReg(1).h[1], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[2], ptr(reg_src));
+    ld1(VReg(1).h[2], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[3], ptr(reg_src));
+    ld1(VReg(1).h[3], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[4], ptr(reg_src));
+    ld1(VReg(1).h[4], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[5], ptr(reg_src));
+    ld1(VReg(1).h[5], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[6], ptr(reg_src));
+    ld1(VReg(1).h[6], ptr(reg_wei));
+    add(reg_src, reg_src, reg_src_stride);
+    add(reg_wei, reg_wei, reg_wei_stride);
+    ld1(VReg(0).h[7], ptr(reg_src));
+    ld1(VReg(1).h[7], ptr(reg_wei));
     // Widening MAC for both halves
     fmlal(VReg4S(20), VReg4H(0), VReg4H(1));
     fmlal2(VReg4S(20), VReg4H(0), VReg4H(1));
@@ -1095,14 +1619,15 @@ void JitConv3DKernelF16::generate() {
 static inline const uint16_t* ptr_f16(const MemoryPtr& m) {
     return reinterpret_cast<const uint16_t*>(m->getData());
 }
-static inline uint16_t* ptr_f16(MemoryPtr& m) {
+[[maybe_unused]] static inline uint16_t* ptr_f16(MemoryPtr& m) {
    return reinterpret_cast<uint16_t*>(m->getData());
 }
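[Editor's note] For reference, the contribution of one fmlal/fmlal2 pair above can be modelled in scalar C++: fmlal widens the low four f16 lanes, fmlal2 the high four, so together they add an 8-lane f16*f16 product onto four f32 accumulators that the trailing faddp pair later reduces. This is a minimal sketch, not part of the patch; the function name and signature are illustrative.

#include <cstddef>
#include "openvino/core/type/float16.hpp"

// Scalar model of the v20.4s update performed by one fmlal + fmlal2 pair,
// followed by the horizontal reduction done by the faddp instructions.
static float dot_f16_block(const ov::float16* src, const ov::float16* wei) {
    float lanes[4] = {0.f, 0.f, 0.f, 0.f};
    for (size_t i = 0; i < 4; ++i) {
        lanes[i] += static_cast<float>(src[i]) * static_cast<float>(wei[i]);          // fmlal (low half)
        lanes[i] += static_cast<float>(src[i + 4]) * static_cast<float>(wei[i + 4]);  // fmlal2 (high half)
    }
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];  // faddp, faddp
}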
 
 JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs,
                                      const MemoryArgs& memory,
                                      const ExecutorContext::CPtr& context)
-    : m_attrs(attrs), m_memory(memory) {
+    : m_attrs(attrs),
+      m_memory(memory) {
     (void)context;
     m_threadsNum = static_cast<size_t>(parallel_get_max_threads());
     // Decide precision from src tensor
@@ -1135,31 +1660,38 @@ JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs,
 bool JitConv3DExecutor::supports(const ConvConfig& cfg) {
     // Require 5D NCDHW, FP16/FP32 src/wei/dst, group=1, no dilation, stride 1 or 2
-    if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false;
-    if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) return false;
+    if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST))
+        return false;
+    if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST))
+        return false;
     const auto& s = cfg.descs.at(ARG_SRC)->getShape();
     const auto& w = cfg.descs.at(ARG_WEI)->getShape();
     const auto& d = cfg.descs.at(ARG_DST)->getShape();
-    if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) return false;
+    if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5)
+        return false;
     const auto sp = cfg.descs.at(ARG_SRC)->getPrecision();
     const auto wp = cfg.descs.at(ARG_WEI)->getPrecision();
     const auto dp = cfg.descs.at(ARG_DST)->getPrecision();
     const bool f16_ok = (sp == ov::element::f16 && wp == ov::element::f16 && dp == ov::element::f16);
     const bool f32_ok = (sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32);
-    if (!(f16_ok || f32_ok)) return false;
+    if (!(f16_ok || f32_ok))
+        return false;
     // group == 1: weights rank==5 (no groups)
-    if (w.getRank() != 5) return false;
+    if (w.getRank() != 5)
+        return false;
     // dilation == 0
     for (auto v : cfg.attrs.dilation) {
-        if (v != 0) return false;
+        if (v != 0)
+            return false;
     }
     // stride in [1,2] if set
     for (auto v : cfg.attrs.stride) {
-        if (!(v == 1 || v == 2)) return false;
+        if (!(v == 1 || v == 2))
+            return false;
     }
     return true;
 }
@@ -1199,7 +1731,7 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) {
         return (((n * OC + oc) * OD + z) * OH + y) * OW + x;
     };
     auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
-        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+        return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx;
     };
 
     // Prepare packed weights once
@@ -1219,36 +1751,111 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) {
                 for (size_t ow = 0; ow < OW; ++ow) {
                     const int64_t ix0 = static_cast<int64_t>(ow * SW) - static_cast<int64_t>(PW0);
-                    float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f;
-
-                    const size_t src_c_stride_elems = ID * IH * IW;
-                    const size_t wei_c_stride_elems = KD * KH * KW;
-
-                    if (SD == 1 && SH == 1 && SW == 1) {
-                        const ptrdiff_t kz_lo = std::max<ptrdiff_t>(0, -iz0);
-                        const ptrdiff_t kz_hi = std::min(static_cast<ptrdiff_t>(KD) - 1,
-                                                         static_cast<ptrdiff_t>(ID) - 1 - iz0);
-                        const ptrdiff_t ky_lo = std::max<ptrdiff_t>(0, -iy0);
-                        const ptrdiff_t ky_hi = std::min(static_cast<ptrdiff_t>(KH) - 1,
-                                                         static_cast<ptrdiff_t>(IH) - 1 - iy0);
-                        const ptrdiff_t kx_lo = std::max<ptrdiff_t>(0, -ix0);
-                        const ptrdiff_t kx_hi = std::min(static_cast<ptrdiff_t>(KW) - 1,
-                                                         static_cast<ptrdiff_t>(IW) - 1 - ix0);
-                        if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) {
-                            const size_t kw_count = static_cast<size_t>(kx_hi - kx_lo + 1);
-                            for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) {
-                                const size_t iz = static_cast<size_t>(iz0 + kz);
-                                const size_t kh_count = static_cast<size_t>(ky_hi - ky_lo + 1);
-                                const size_t iy = static_cast<size_t>(iy0 + ky_lo);
-                                const size_t ix = static_cast<size_t>(ix0 + kx_lo);
-                                const size_t s_base0 = index_src(n, 0, iz, iy, ix);
-
-                                if (m_wei_packed_ready) {
-                                    // Loop over ky in host; kernel handles kx via kw_cnt
-                                    for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) {
-                                        const size_t iy2 = static_cast<size_t>(iy0 + ky);
-                                        const size_t ix2 = static_cast<size_t>(ix0 + kx_lo);
-                                        const size_t s_base2 = index_src(n, 0, iz, iy2, ix2);
+                    float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f;
+
+                    const size_t src_c_stride_elems = ID * IH * IW;
+                    const size_t wei_c_stride_elems = KD * KH * KW;
+
+                    if (SD == 1 && SH == 1 && SW == 1) {
+                        const ptrdiff_t kz_lo = std::max<ptrdiff_t>(0, -iz0);
+                        const ptrdiff_t kz_hi =
+                            std::min(static_cast<ptrdiff_t>(KD) - 1, static_cast<ptrdiff_t>(ID) - 1 - iz0);
+                        const ptrdiff_t ky_lo = std::max<ptrdiff_t>(0, -iy0);
+                        const ptrdiff_t ky_hi =
+                            std::min(static_cast<ptrdiff_t>(KH) - 1, static_cast<ptrdiff_t>(IH) - 1 - iy0);
+                        const ptrdiff_t kx_lo = std::max<ptrdiff_t>(0, -ix0);
+                        const ptrdiff_t kx_hi =
+                            std::min(static_cast<ptrdiff_t>(KW) - 1, static_cast<ptrdiff_t>(IW) - 1 - ix0);
+                        if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) {
+                            const size_t kw_count = static_cast<size_t>(kx_hi - kx_lo + 1);
+                            for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) {
+                                const size_t iz = static_cast<size_t>(iz0 + kz);
+                                const size_t iy = static_cast<size_t>(iy0 + ky_lo);
+                                const size_t ix = static_cast<size_t>(ix0 + kx_lo);
+
+                                if (m_wei_packed_ready) {
+                                    // Loop over ky in host; kernel handles kx via kw_cnt
+                                    for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) {
+                                        const size_t iy2 = static_cast<size_t>(iy0 + ky);
+                                        const size_t ix2 = static_cast<size_t>(ix0 + kx_lo);
+                                        const size_t s_base2 = index_src(n, 0, iz, iy2, ix2);
+                                        jit_conv3d_call_args a{};
+                                        a.src = src_p + s_base2;
+                                        a.src_stride = src_c_stride_elems * sizeof(uint16_t);
+                                        a.src_blk_stride = a.src_stride * 8;
+                                        a.acc = &acc0;
+                                        a.acc2 = 
has_oc1 ? &acc1 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t pack_base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + + static_cast(kx_lo)) * + m_padded_C; + a.wei = m_wei_packed.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = + (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base1; + } + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_C * sizeof(uint16_t); + (*m_ip_kernel)(&a); + if (has_oc2) { + jit_conv3d_call_args a2{}; + a2.src = src_p + s_base2; + a2.src_stride = a.src_stride; + a2.src_blk_stride = a.src_blk_stride; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? &acc3 : nullptr; + a2.repeats = a.repeats; + a2.tail = a.tail; + a2.kw_cnt = a.kw_cnt; + a2.src_dx = a.src_dx; + const size_t pack_base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C; + a2.wei = m_wei_packed.data() + pack_base2; + if (has_oc3) { + const size_t pack_base3 = + (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C; + a2.wei2 = m_wei_packed.data() + pack_base3; + } + a2.wei_stride = a.wei_stride; + a2.wei_blk_stride = a.wei_blk_stride; + a2.wei_dx = a.wei_dx; + (*m_ip_kernel)(&a2); + } + } + } else { + // Non-packed: keep ky loop outside and issue dual calls + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy2 = static_cast(iy0 + ky); + const size_t ix2 = static_cast(ix0 + kx_lo); + const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); + const size_t w0_base = index_wei(oc0, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w1_base = has_oc1 ? index_wei(oc1, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; + // pair 0 + { jit_conv3d_call_args a{}; a.src = src_p + s_base2; a.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -1259,101 +1866,66 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { a.tail = C % 8; a.kw_cnt = kw_count; a.src_dx = sizeof(uint16_t); - const size_t pack_base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base1; - } - a.wei_stride = sizeof(uint16_t); + a.wei = wei_p + w0_base; + if (has_oc1) + a.wei2 = wei_p + w1_base; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_C * sizeof(uint16_t); + a.wei_dx = sizeof(uint16_t); (*m_ip_kernel)(&a); - if (has_oc2) { - jit_conv3d_call_args a2{}; - a2.src = src_p + s_base2; - a2.src_stride = a.src_stride; - a2.src_blk_stride = a.src_blk_stride; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? 
&acc3 : nullptr; - a2.repeats = a.repeats; - a2.tail = a.tail; - a2.kw_cnt = a.kw_cnt; - a2.src_dx = a.src_dx; - const size_t pack_base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a2.wei = m_wei_packed.data() + pack_base2; - if (has_oc3) { - const size_t pack_base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a2.wei2 = m_wei_packed.data() + pack_base3; - } - a2.wei_stride = a.wei_stride; - a2.wei_blk_stride = a.wei_blk_stride; - a2.wei_dx = a.wei_dx; - (*m_ip_kernel)(&a2); - } } - } else { - // Non-packed: keep ky loop outside and issue dual calls - for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t iy2 = static_cast(iy0 + ky); - const size_t ix2 = static_cast(ix0 + kx_lo); - const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); - const size_t w0_base = index_wei(oc0, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); - const size_t w1_base = has_oc1 ? index_wei(oc1, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; - // pair 0 - { - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.wei = wei_p + w0_base; - if (has_oc1) a.wei2 = wei_p + w1_base; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - (*m_ip_kernel)(&a); - } - if (has_oc2) { - const size_t w2_base = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); - const size_t w3_base = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.wei = wei_p + w2_base; - if (has_oc3) a.wei2 = wei_p + w3_base; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - (*m_ip_kernel)(&a); - } + if (has_oc2) { + const size_t w2_base = index_wei(oc2, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w3_base = has_oc3 ? index_wei(oc3, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base2; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? 
&acc3 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = wei_p + w2_base; + if (has_oc3) + a.wei2 = wei_p + w3_base; + a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel)(&a); } } } } - } else { - for (size_t kz = 0; kz < KD; ++kz) { + } + } else { + for (size_t kz = 0; kz < KD; ++kz) { const int64_t iz = iz0 + static_cast(kz); - if (iz < 0 || iz >= static_cast(ID)) continue; + if (iz < 0 || iz >= static_cast(ID)) + continue; for (size_t ky = 0; ky < KH; ++ky) { const int64_t iy = iy0 + static_cast(ky); - if (iy < 0 || iy >= static_cast(IH)) continue; + if (iy < 0 || iy >= static_cast(IH)) + continue; for (size_t kx = 0; kx < KW; ++kx) { const int64_t ix = ix0 + static_cast(kx); - if (ix < 0 || ix >= static_cast(IW)) continue; - const size_t s_base0 = index_src(n, 0, static_cast(iz), static_cast(iy), static_cast(ix)); + if (ix < 0 || ix >= static_cast(IW)) + continue; + const size_t s_base0 = index_src(n, + 0, + static_cast(iz), + static_cast(iy), + static_cast(ix)); const size_t w0_base = index_wei(oc0, 0, kz, ky, kx); const size_t w1_base = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; // pair 0 @@ -1361,7 +1933,9 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { jit_conv3d_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; // used logically, kernel advances by stride once after 8 lanes + a.src_blk_stride = + a.src_stride * + 8; // used logically, kernel advances by stride once after 8 lanes a.acc = &acc0; a.acc2 = has_oc1 ? &acc1 : nullptr; @@ -1371,16 +1945,18 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { a.wei = m_wei_packed.data() + pack_base0; a.repeats = C / 8; a.tail = C % 8; - a.wei_stride = sizeof(uint16_t); // contiguous halves - a.wei_blk_stride = a.wei_stride * 8; // logical + a.wei_stride = sizeof(uint16_t); // contiguous halves + a.wei_blk_stride = a.wei_stride * 8; // logical if (has_oc1) { - const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t pack_base1 = + (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; a.wei2 = m_wei_packed.data() + pack_base1; } (*m_ip_kernel)(&a); } else { a.wei = wei_p + w0_base; - if (has_oc1) a.wei2 = wei_p + w1_base; + if (has_oc1) + a.wei2 = wei_p + w1_base; a.repeats = C / 8; a.tail = C % 8; a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); @@ -1395,7 +1971,9 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { jit_conv3d_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; // used logically, kernel advances by stride once after 8 lanes + a.src_blk_stride = + a.src_stride * + 8; // used logically, kernel advances by stride once after 8 lanes a.acc = &acc2; a.acc2 = has_oc3 ? 
&acc3 : nullptr;
                                        if (m_wei_packed_ready) {
@@ -1403,16 +1981,18 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) {
                                 a.wei = m_wei_packed.data() + pack_base2;
                                 a.repeats = C / 8;
                                 a.tail = C % 8;
-                                a.wei_stride = sizeof(uint16_t);  // contiguous halves
-                                a.wei_blk_stride = a.wei_stride * 8;  // logical
+                                a.wei_stride = sizeof(uint16_t);      // contiguous halves
+                                a.wei_blk_stride = a.wei_stride * 8;  // logical
                                 if (has_oc3) {
-                                    const size_t pack_base3 = (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C;
+                                    const size_t pack_base3 =
+                                        (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C;
                                     a.wei2 = m_wei_packed.data() + pack_base3;
                                 }
                                 (*m_ip_kernel)(&a);
                             } else {
                                 a.wei = wei_p + w2_base;
-                                if (has_oc3) a.wei2 = wei_p + w3_base;
+                                if (has_oc3)
+                                    a.wei2 = wei_p + w3_base;
                                 a.repeats = C / 8;
                                 a.tail = C % 8;
                                 a.wei_stride = wei_c_stride_elems * sizeof(uint16_t);
@@ -1423,61 +2003,88 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) {
                             }
                         }
                     }
+                    }
 
+                    // Optional fused bias (disabled by default)
+                    if (m_apply_post_ops && m_attrs.withBias && memory.count(ARG_BIAS) && memory.at(ARG_BIAS)) {
+                        auto bia = memory.at(ARG_BIAS);
+                        const auto bprec = bia->getDescPtr()->getPrecision();
+                        if (bprec == ov::element::f32) {
+                            const float* b = reinterpret_cast<const float*>(bia->getData());
+                            acc0 += b[oc0];
+                            if (has_oc1)
+                                acc1 += b[oc1];
+                            if (has_oc2)
+                                acc2 += b[oc2];
+                            if (has_oc3)
+                                acc3 += b[oc3];
+                        } else if (bprec == ov::element::f16) {
+                            const uint16_t* b = reinterpret_cast<const uint16_t*>(bia->getData());
+                            acc0 += static_cast<float>(ov::float16(b[oc0]));
+                            if (has_oc1)
+                                acc1 += static_cast<float>(ov::float16(b[oc1]));
+                            if (has_oc2)
+                                acc2 += static_cast<float>(ov::float16(b[oc2]));
+                            if (has_oc3)
+                                acc3 += static_cast<float>(ov::float16(b[oc3]));
                         }
-                    // Optional fused bias (disabled by default)
-                    if (m_apply_post_ops && m_attrs.withBias && memory.count(ARG_BIAS) && memory.at(ARG_BIAS)) {
-                        auto bia = memory.at(ARG_BIAS);
-                        const auto bprec = bia->getDescPtr()->getPrecision();
-                        if (bprec == ov::element::f32) {
-                            const float* b = reinterpret_cast<const float*>(bia->getData());
-                            acc0 += b[oc0];
-                            if (has_oc1) acc1 += b[oc1];
-                            if (has_oc2) acc2 += b[oc2];
-                            if (has_oc3) acc3 += b[oc3];
-                        } else if (bprec == ov::element::f16) {
-                            const uint16_t* b = reinterpret_cast<const uint16_t*>(bia->getData());
-                            acc0 += static_cast<float>(ov::float16(b[oc0]));
-                            if (has_oc1) acc1 += static_cast<float>(ov::float16(b[oc1]));
-                            if (has_oc2) acc2 += static_cast<float>(ov::float16(b[oc2]));
-                            if (has_oc3) acc3 += static_cast<float>(ov::float16(b[oc3]));
-                        }
-                    }
+                    }
 
-                    // Optional fused PReLU (apply after bias) — disabled by default
-                    if (m_apply_post_ops && m_has_prelu && !m_prelu_slopes.empty()) {
-                        const auto slope_at = [&](size_t oc) -> float {
-                            return m_prelu_slopes.size() == 1 ? m_prelu_slopes[0]
-                                                              : m_prelu_slopes[std::min(oc, m_prelu_slopes.size() - 1)];
-                        };
-                        const float s0 = slope_at(oc0);
-                        if (acc0 < 0.f) acc0 *= s0;
-                        if (has_oc1) { const float s1 = slope_at(oc1); if (acc1 < 0.f) acc1 *= s1; }
-                        if (has_oc2) { const float s2 = slope_at(oc2); if (acc2 < 0.f) acc2 *= s2; }
-                        if (has_oc3) { const float s3 = slope_at(oc3); if (acc3 < 0.f) acc3 *= s3; }
+                    // Optional fused PReLU (apply after bias) — disabled by default
+                    if (m_apply_post_ops && m_has_prelu && !m_prelu_slopes.empty()) {
+                        const auto slope_at = [&](size_t oc) -> float {
+                            return m_prelu_slopes.size() == 1 ? 
m_prelu_slopes[0]
+                                                              : m_prelu_slopes[std::min(oc, m_prelu_slopes.size() - 1)];
+                        };
+                        const float s0 = slope_at(oc0);
+                        if (acc0 < 0.f)
+                            acc0 *= s0;
+                        if (has_oc1) {
+                            const float s1 = slope_at(oc1);
+                            if (acc1 < 0.f)
+                                acc1 *= s1;
+                        }
+                        if (has_oc2) {
+                            const float s2 = slope_at(oc2);
+                            if (acc2 < 0.f)
+                                acc2 *= s2;
+                        }
+                        if (has_oc3) {
+                            const float s3 = slope_at(oc3);
+                            if (acc3 < 0.f)
+                                acc3 *= s3;
                         }
-
-                    dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits();
-                    if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits();
-                    if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits();
-                    if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits();
                     }
+
+                    dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits();
+                    if (has_oc1)
+                        dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits();
+                    if (has_oc2)
+                        dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits();
+                    if (has_oc3)
+                        dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits();
                 }
+            }
         };
 
         ov::parallel_for3d(N, (OC + 3) / 4, OD, worker);
 }
 
 void JitConv3DExecutor::execute(const MemoryArgs& memory) {
-    if (m_is_fp32) run_naive_fp32(memory); else run_naive_fp16(memory);
+    if (m_is_fp32)
+        run_naive_fp32(memory);
+    else
+        run_naive_fp16(memory);
 }
 
 void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) {
-    if (m_wei_packed_ready_f32) return;
+    if (m_wei_packed_ready_f32)
+        return;
     auto src = memory.at(ARG_SRC);
     auto wei = memory.at(ARG_WEI);
     const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
     const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
-    if (srcDims.size() != 5 || weiDims.size() != 5) return;
+    if (srcDims.size() != 5 || weiDims.size() != 5)
+        return;
     const size_t C = srcDims[1];
     const size_t OC = weiDims[0];
     const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
@@ -1487,7 +2094,7 @@ void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) {
     const float* wsrc = reinterpret_cast<const float*>(wei->getData());
     auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
-        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+        return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx;
     };
     auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
         const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32;
@@ -1544,7 +2151,7 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
         return (((n * OC + c) * OD + z) * OH + y) * OW + x;
     };
     auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) {
-        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+        return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx;
     };
 
     const size_t src_c_stride_elems = ID * IH * IW;  // elements between channels
@@ -1572,14 +2179,14 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                     if (SD == 1 && SH == 1 && SW == 1) {
                         const ptrdiff_t kz_lo = std::max<ptrdiff_t>(0, -iz0);
-                        const ptrdiff_t kz_hi = std::min(static_cast<ptrdiff_t>(KD) - 1,
-                                                         static_cast<ptrdiff_t>(ID) - 1 - iz0);
+                        const ptrdiff_t kz_hi =
+                            std::min(static_cast<ptrdiff_t>(KD) - 1, static_cast<ptrdiff_t>(ID) - 1 - iz0);
                         const ptrdiff_t ky_lo = std::max<ptrdiff_t>(0, -iy0);
-                        const ptrdiff_t ky_hi = std::min(static_cast<ptrdiff_t>(KH) - 1,
-                                                         static_cast<ptrdiff_t>(IH) - 1 - iy0);
+                        const ptrdiff_t ky_hi =
+                            std::min(static_cast<ptrdiff_t>(KH) - 1, static_cast<ptrdiff_t>(IH) - 1 - iy0);
                         const ptrdiff_t kx_lo = std::max<ptrdiff_t>(0, -ix0);
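// Editor's note (derivation, not in the original patch): with stride 1 the
// input column for tap kx is ix = ix0 + kx.  Requiring 0 <= ix0 + kx <= IW - 1
// gives the clamp used below:
//     kx_lo = max(0, -ix0),   kx_hi = min(KW - 1, IW - 1 - ix0)
// e.g. ix0 = -2 (two columns of left padding) clips taps kx = 0 and 1, so the
// in-kernel kw loop starts at kx = 2.  The kz/ky bounds above follow the same
// pattern along D and H.
-                        const ptrdiff_t kx_hi = std::min(static_cast<ptrdiff_t>(KW) - 1,
-                                                         static_cast<ptrdiff_t>(IW) - 1 - 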
ix0); + const ptrdiff_t kx_hi = + std::min(static_cast(KW) - 1, static_cast(IW) - 1 - ix0); if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { const size_t kw_count = static_cast(kx_hi - kx_lo + 1); for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { @@ -1602,10 +2209,18 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + const size_t base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); @@ -1625,10 +2240,18 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + const size_t base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C_f32; + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_C_f32; a.wei2 = m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); @@ -1638,8 +2261,17 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { } } else { // generic path: non-packed weights - const size_t w0 = index_wei(oc0, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); - const size_t w1 = has_oc1 ? index_wei(oc1, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + const size_t w0 = index_wei(oc0, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w1 = has_oc1 ? index_wei(oc1, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; jit_conv3d_f32_call_args a{}; a.src = src_p + s_base; a.src_stride = src_c_stride_elems * sizeof(float); @@ -1651,14 +2283,24 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.kw_cnt = kw_count; a.src_dx = sizeof(float); a.wei = wei_p + w0; - if (has_oc1) a.wei2 = wei_p + w1; + if (has_oc1) + a.wei2 = wei_p + w1; a.wei_stride = wei_c_stride_elems * sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = sizeof(float); (*m_ip_kernel_f32)(&a); if (has_oc2) { - const size_t w2 = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); - const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + const size_t w2 = index_wei(oc2, + 0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w3 = has_oc3 ? 
index_wei(oc3,
+                                                                      0,
+                                                                      static_cast<size_t>(kz),
+                                                                      static_cast<size_t>(ky),
+                                                                      static_cast<size_t>(kx_lo))
+                                                           : 0;
                             jit_conv3d_f32_call_args a2{};
                             a2.src = src_p + s_base;
                             a2.src_stride = a.src_stride;
@@ -1670,7 +2312,8 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                             a2.kw_cnt = a.kw_cnt;
                             a2.src_dx = a.src_dx;
                             a2.wei = wei_p + w2;
-                            if (has_oc3) a2.wei2 = wei_p + w3;
+                            if (has_oc3)
+                                a2.wei2 = wei_p + w3;
                             a2.wei_stride = a.wei_stride;
                             a2.wei_blk_stride = a.wei_blk_stride;
                             a2.wei_dx = a.wei_dx;
@@ -1684,14 +2327,21 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                         // generic spatial stride path (host loops over all taps)
                         for (size_t kz = 0; kz < KD; ++kz) {
                             const ptrdiff_t iz = iz0 + static_cast<ptrdiff_t>(kz);
-                            if (iz < 0 || iz >= static_cast<ptrdiff_t>(ID)) continue;
+                            if (iz < 0 || iz >= static_cast<ptrdiff_t>(ID))
+                                continue;
                             for (size_t ky = 0; ky < KH; ++ky) {
                                 const ptrdiff_t iy = iy0 + static_cast<ptrdiff_t>(ky);
-                                if (iy < 0 || iy >= static_cast<ptrdiff_t>(IH)) continue;
+                                if (iy < 0 || iy >= static_cast<ptrdiff_t>(IH))
+                                    continue;
                                 for (size_t kx = 0; kx < KW; ++kx) {
                                     const ptrdiff_t ix = ix0 + static_cast<ptrdiff_t>(kx);
-                                    if (ix < 0 || ix >= static_cast<ptrdiff_t>(IW)) continue;
-                                    const size_t s_base = index_src(n, 0, static_cast<size_t>(iz), static_cast<size_t>(iy), static_cast<size_t>(ix));
+                                    if (ix < 0 || ix >= static_cast<ptrdiff_t>(IW))
+                                        continue;
+                                    const size_t s_base = index_src(n,
+                                                                    0,
+                                                                    static_cast<size_t>(iz),
+                                                                    static_cast<size_t>(iy),
+                                                                    static_cast<size_t>(ix));
                                     // pair 0
                                     {
                                         const size_t w0 = index_wei(oc0, 0, kz, ky, kx);
@@ -1707,7 +2357,8 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                                         a.kw_cnt = 1;
                                         a.src_dx = 0;
                                         a.wei = wei_p + w0;
-                                        if (has_oc1) a.wei2 = wei_p + w1;
+                                        if (has_oc1)
+                                            a.wei2 = wei_p + w1;
                                         a.wei_stride = wei_c_stride_elems * sizeof(float);
                                         a.wei_blk_stride = a.wei_stride * 4;
                                         a.wei_dx = 0;
@@ -1728,7 +2379,8 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                                         a.kw_cnt = 1;
                                         a.src_dx = 0;
                                         a.wei = wei_p + w2;
-                                        if (has_oc3) a.wei2 = wei_p + w3;
+                                        if (has_oc3)
+                                            a.wei2 = wei_p + w3;
                                         a.wei_stride = wei_c_stride_elems * sizeof(float);
                                         a.wei_blk_stride = a.wei_stride * 4;
                                         a.wei_dx = 0;
@@ -1741,9 +2393,12 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) {
                     // Store (convert FP32 accumulators to FP16 bits)
                     dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits();
-                    if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits();
-                    if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits();
-                    if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits();
+                    if (has_oc1)
+                        dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits();
+                    if (has_oc2)
+                        dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits();
+                    if (has_oc3)
+                        dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits();
                 }
             }
         }
@@ -1752,12 +2407,14 @@
 }  // namespace ov::intel_cpu
 
 void ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& memory) {
-    if (m_wei_packed_ready) return;
+    if (m_wei_packed_ready)
+        return;
     auto src = memory.at(ARG_SRC);
     auto wei = memory.at(ARG_WEI);
     const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
     const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
-    if (srcDims.size() != 5 || weiDims.size() != 5) return;
+    if (srcDims.size() != 5 || weiDims.size() != 5)
+        return;
     const size_t C = srcDims[1];
     const size_t OC = weiDims[0];
     const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
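[Editor's note] The call-args pattern repeated throughout run_naive_fp16() can be condensed into one helper. This is a hedged sketch, not part of the patch: `run_packed_tap` and all parameter names are illustrative, and `Kernel` stands for the generated callable invoked as `(*m_ip_kernel)(&a)` in the source. One call accumulates all C input channels and kw_count W-taps into up to two f32 accumulators, one per output channel.

#include <cstddef>
#include <cstdint>

template <typename Kernel>
static void run_packed_tap(Kernel& kernel,
                           const uint16_t* src_at,     // src_p + index_src(n, 0, iz, iy, ix)
                           const uint16_t* wei_pack0,  // packed block for oc0 at (kz, ky, kx_lo)
                           const uint16_t* wei_pack1,  // same for oc1, may be nullptr
                           size_t C, size_t ID, size_t IH, size_t IW,
                           size_t padded_C, size_t kw_count,
                           float* acc0, float* acc1) {
    jit_conv3d_call_args a{};
    a.src = src_at;
    a.src_stride = ID * IH * IW * sizeof(uint16_t);  // NCDHW channel pitch in bytes
    a.src_blk_stride = a.src_stride * 8;             // next 8-channel block
    a.repeats = C / 8;                               // full 8-lane blocks
    a.tail = C % 8;                                  // guarded lane-by-lane remainder
    a.kw_cnt = kw_count;                             // stride-1 fast path walks kx in-kernel
    a.src_dx = sizeof(uint16_t);                     // one f16 to the next kx tap
    a.wei = wei_pack0;
    a.wei2 = wei_pack1;                              // optional second output channel
    a.wei_stride = sizeof(uint16_t);                 // packed channels are contiguous
    a.wei_blk_stride = a.wei_stride * 8;
    a.wei_dx = padded_C * sizeof(uint16_t);          // next kx tap in the packed layout
    a.acc = acc0;
    a.acc2 = acc1;                                   // may be nullptr
    kernel(&a);
}

@@ -1768,7 +2425,7 @@ void 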
ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& m
     const uint16_t* wsrc = reinterpret_cast<const uint16_t*>(wei->getData());
     auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
-        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+        return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx;
     };
     auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
         const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C;
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
index 08eabf13c39e6a..fab4b7264c738f 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp
@@ -22,27 +22,27 @@ namespace ov::intel_cpu {
 
 struct jit_conv3d_call_args {
-    const uint16_t* src;    // f16 base ptr
-    const uint16_t* wei;    // f16 base ptr
-    const uint16_t* wei2;   // optional second oc f16 base ptr (can be null)
-    const uint16_t* wei3;   // optional third oc f16 base ptr (can be null)
-    const uint16_t* wei4;   // optional fourth oc f16 base ptr (can be null)
-    size_t repeats;         // number of full 8-channel blocks
-    size_t tail;            // remaining channels (< 8)
-    size_t src_stride;      // stride between channels in bytes
-    size_t wei_stride;      // stride between channels in bytes
-    size_t src_blk_stride;  // stride between successive 8-channel blocks in bytes
-    size_t wei_blk_stride;  // stride between successive 8-channel blocks in bytes
-    float* acc;             // f32 accumulator
-    float* acc2;            // optional second f32 accumulator (can be null)
-    size_t kw_cnt;          // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single
-    size_t src_dx;          // bytes to advance src base between successive kx taps
-    size_t wei_dx;          // bytes to advance weights base between successive kx taps
-    float* acc3;            // optional third f32 accumulator (can be null)
-    float* acc4;            // optional fourth f32 accumulator (can be null)
-    size_t kh_cnt;          // number of taps along H to iterate (stride=1 fast path); 0 or 1 -> single
-    size_t src_dy;          // bytes to advance src base between successive ky taps
-    size_t wei_dy;          // bytes to advance weights base between successive ky taps
+    const uint16_t* src;     // f16 base ptr
+    const uint16_t* wei;     // f16 base ptr
+    const uint16_t* wei2;    // optional second oc f16 base ptr (can be null)
+    const uint16_t* wei3;    // optional third oc f16 base ptr (can be null)
+    const uint16_t* wei4;    // optional fourth oc f16 base ptr (can be null)
+    size_t repeats;          // number of full 8-channel blocks
+    size_t tail;             // remaining channels (< 8)
+    size_t src_stride;       // stride between channels in bytes
+    size_t wei_stride;       // stride between channels in bytes
+    size_t src_blk_stride;   // stride between successive 8-channel blocks in bytes
+    size_t wei_blk_stride;   // stride between successive 8-channel blocks in bytes
+    float* acc;              // f32 accumulator
+    float* acc2;             // optional second f32 accumulator (can be null)
+    size_t kw_cnt;  // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single
+    size_t src_dx;  // bytes to advance src base between successive kx taps
+    size_t wei_dx;  // bytes to advance weights base between successive kx taps
+    float* acc3;             // optional third f32 accumulator (can be null)
+    float* acc4;             // optional fourth f32 accumulator (can be null)
+    size_t kh_cnt;  // number of taps along H to iterate (stride=1 fast path); 0 or 1 -> single
+    size_t src_dy;           // bytes to advance src 
base between successive ky taps + size_t wei_dy; // bytes to advance weights base between successive ky taps }; class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { @@ -73,11 +73,15 @@ class JitConv3DExecutor : public Executor { return true; } void execute(const MemoryArgs& memory) override; - void execute() override { execute(m_memory); } + void execute() override { + execute(m_memory); + } void exec([[maybe_unused]] const std::vector& src, [[maybe_unused]] const std::vector& dst) override {} - [[nodiscard]] impl_desc_type implType() const override { return impl_desc_type::jit_asimd; } + [[nodiscard]] impl_desc_type implType() const override { + return impl_desc_type::jit_asimd; + } static bool supports(const ConvConfig& cfg); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp index fbd31aa3ed1ccc..ac116c9e37e674 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -34,20 +34,20 @@ void JitConv3DKernelF32::generate() { const XReg reg_args = abi_param1; // x0 - const XReg reg_src = x1; // const float* src - const XReg reg_wei = x2; // const float* wei - const XReg reg_wei2 = x3; // const float* wei2 (optional) - const XReg reg_reps = x4; // size_t repeats (C/4) - const XReg reg_tail = x5; // size_t tail (C%4) - const XReg reg_src_stride = x6; // bytes between channels - const XReg reg_wei_stride = x7; // bytes between channels + const XReg reg_src = x1; // const float* src + const XReg reg_wei = x2; // const float* wei + const XReg reg_wei2 = x3; // const float* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats (C/4) + const XReg reg_tail = x5; // size_t tail (C%4) + const XReg reg_src_stride = x6; // bytes between channels + const XReg reg_wei_stride = x7; // bytes between channels const XReg reg_src_blk_stride = x8; // bytes between successive 4-ch blocks const XReg reg_wei_blk_stride = x9; // bytes between successive 4-ch blocks - const XReg reg_acc = x10; // float* acc - const XReg reg_acc2 = x11; // float* acc2 (optional) - const XReg reg_kw_cnt = x12; // taps along W - const XReg reg_src_dx = x13; // bytes to step src base per kx - const XReg reg_wei_dx = x14; // bytes to step wei base per kx + const XReg reg_acc = x10; // float* acc + const XReg reg_acc2 = x11; // float* acc2 (optional) + const XReg reg_kw_cnt = x12; // taps along W + const XReg reg_src_dx = x13; // bytes to step src base per kx + const XReg reg_wei_dx = x14; // bytes to step wei base per kx // Load args by struct offsets (see jit_conv3d_f32_call_args) ldr(reg_src, ptr(reg_args, 0)); @@ -68,7 +68,7 @@ void JitConv3DKernelF32::generate() { // Work registers for base pointers per kx const XReg q_src_base = x15; const XReg q_wei_base = x16; - const XReg q_wei2_base = x17; // avoid x18 + const XReg q_wei2_base = x17; // avoid x18 Label Lsingle, Ldone; Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; @@ -105,9 +105,12 @@ void JitConv3DKernelF32::generate() { cmp(reg_reps, 0); b(EQ, Ltail_prep_d_kx); // src lanes -> v0.s[0..3] - ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, 
reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[3], ptr(reg_src)); // wei lanes: vector fast path if stride==4 bytes Label Lw_np_d, Lw_done_d; @@ -119,13 +122,20 @@ void JitConv3DKernelF32::generate() { add(reg_wei2, reg_wei2, reg_wei_blk_stride); b(Lw_done_d); L(Lw_np_d); - ld1(VReg(1).s[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[0], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[1], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[2], ptr(reg_wei2)); add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + ld1(VReg(2).s[3], ptr(reg_wei2)); L(Lw_done_d); // MAC fmla(VReg4S(20), VReg4S(0), VReg4S(1)); @@ -138,17 +148,35 @@ void JitConv3DKernelF32::generate() { eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); eor(VReg16B(2), VReg16B(2), VReg16B(2)); - cmp(reg_tail, 0); b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[0], ptr(reg_src)); ld1(VReg(1).s[0], ptr(reg_wei)); ld1(VReg(2).s[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[1], ptr(reg_src)); ld1(VReg(1).s[1], ptr(reg_wei)); ld1(VReg(2).s[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[2], ptr(reg_src)); ld1(VReg(1).s[2], ptr(reg_wei)); ld1(VReg(2).s[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[3], ptr(reg_src)); ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[3], ptr(reg_src)); + ld1(VReg(1).s[3], ptr(reg_wei)); + 
ld1(VReg(2).s[3], ptr(reg_wei2)); L(Ltail_done_d_kx); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); fmla(VReg4S(21), VReg4S(0), VReg4S(2)); @@ -163,8 +191,12 @@ void JitConv3DKernelF32::generate() { faddp(VReg2S(20), VReg2S(20), VReg2S(20)); faddp(VReg4S(21), VReg4S(21), VReg4S(21)); faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); fadd(SReg(1), SReg(1), SReg(21)); str(SReg(1), ptr(reg_acc2)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); b(Ldone); // ---------------- Single-OC with in-kernel kx loop ---------------- @@ -186,9 +218,12 @@ void JitConv3DKernelF32::generate() { cmp(reg_reps, 0); b(EQ, Ltail_prep_s_kx); // src lanes - ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[1], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[3], ptr(reg_src)); // wei lanes: vector fast path if stride==4 Label Lw_np_s, Lw_done_s; @@ -198,9 +233,12 @@ void JitConv3DKernelF32::generate() { add(reg_wei, reg_wei, reg_wei_blk_stride); b(Lw_done_s); L(Lw_np_s); - ld1(VReg(1).s[0], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).s[1], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).s[3], ptr(reg_wei)); L(Lw_done_s); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); @@ -211,17 +249,28 @@ void JitConv3DKernelF32::generate() { L(Ltail_prep_s_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[0], ptr(reg_src)); ld1(VReg(1).s[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[1], ptr(reg_src)); ld1(VReg(1).s[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[2], ptr(reg_src)); ld1(VReg(1).s[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[3], ptr(reg_src)); ld1(VReg(1).s[3], ptr(reg_wei)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, 
reg_wei_stride);
+    cmp(reg_tail, 3);
+    b(LE, Ltail_done_s_kx);
+    ld1(VReg(0).s[3], ptr(reg_src));
+    ld1(VReg(1).s[3], ptr(reg_wei));
     L(Ltail_done_s_kx);
     fmla(VReg4S(20), VReg4S(0), VReg4S(1));
@@ -234,7 +283,9 @@ void JitConv3DKernelF32::generate() {
     // reduce and store
     faddp(VReg4S(20), VReg4S(20), VReg4S(20));
     faddp(VReg2S(20), VReg2S(20), VReg2S(20));
-    ldr(SReg(0), ptr(reg_acc)); fadd(SReg(0), SReg(0), SReg(20)); str(SReg(0), ptr(reg_acc));
+    ldr(SReg(0), ptr(reg_acc));
+    fadd(SReg(0), SReg(0), SReg(20));
+    str(SReg(0), ptr(reg_acc));
     b(Ldone);
 
     L(Ldone);
@@ -244,36 +295,50 @@ void JitConv3DKernelF32::generate() {
 // --------------------------- Executor (FP32) ---------------------------
 JitConv3DExecutorF32::JitConv3DExecutorF32(const ConvAttrs& attrs,
                                            const MemoryArgs& memory,
-                                           const ExecutorContext::CPtr& /*context*/) : m_attrs(attrs) {
+                                           const ExecutorContext::CPtr& /*context*/)
+    : m_attrs(attrs) {
     m_memory = memory;
     m_ip_kernel = std::make_unique<JitConv3DKernelF32>();
     m_ip_kernel->create_ker();
 }
 
 bool JitConv3DExecutorF32::supports(const ConvConfig& cfg) {
-    if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false;
-    if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) return false;
+    if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST))
+        return false;
+    if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST))
+        return false;
     const auto& s = cfg.descs.at(ARG_SRC)->getShape();
     const auto& w = cfg.descs.at(ARG_WEI)->getShape();
     const auto& d = cfg.descs.at(ARG_DST)->getShape();
-    if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) return false;
+    if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5)
+        return false;
     const auto sp = cfg.descs.at(ARG_SRC)->getPrecision();
     const auto wp = cfg.descs.at(ARG_WEI)->getPrecision();
     const auto dp = cfg.descs.at(ARG_DST)->getPrecision();
-    if (!(sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32)) return false;
-    if (w.getRank() != 5) return false;  // groups unsupported here
-    for (auto v : cfg.attrs.dilation) { if (v != 0) return false; }
-    for (auto v : cfg.attrs.stride) { if (!(v == 1 || v == 2)) return false; }
+    if (!(sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32))
+        return false;
+    if (w.getRank() != 5)
+        return false;  // groups unsupported here
+    for (auto v : cfg.attrs.dilation) {
+        if (v != 0)
+            return false;
+    }
+    for (auto v : cfg.attrs.stride) {
+        if (!(v == 1 || v == 2))
+            return false;
+    }
     return true;
 }
 
 void JitConv3DExecutorF32::ensure_weights_packed(const MemoryArgs& memory) {
-    if (m_wei_packed_ready) return;
+    if (m_wei_packed_ready)
+        return;
     auto src = memory.at(ARG_SRC);
     auto wei = memory.at(ARG_WEI);
     const auto& srcDims = src->getDescPtr()->getShape().getStaticDims();
     const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims();
-    if (srcDims.size() != 5 || weiDims.size() != 5) return;
+    if (srcDims.size() != 5 || weiDims.size() != 5)
+        return;
     const size_t C = srcDims[1];
     const size_t OC = weiDims[0];
     const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4];
@@ -283,7 +348,7 @@ void JitConv3DExecutorF32::ensure_weights_packed(const MemoryArgs& memory) {
     const float* wsrc = reinterpret_cast<const float*>(wei->getData());
     auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t {
-        return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx;
+        return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx;
     };
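    // Editor's note (descriptive only, not in the original patch): the packed
    // layout stores, for every weight tap (oc, kz, ky, kx), all C input
    // channels contiguously, zero-padded up to m_padded_C so the kernel can
    // always read full lane blocks.  Channels then advance by sizeof(float)
    // (wei_stride) and the next kx tap sits m_padded_C floats further on
    // (wei_dx), which is what enables the in-kernel kw loop.
    auto idx_wei_pack 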
= [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C; @@ -340,7 +405,7 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { return (((n * OC + c) * OD + z) * OH + y) * OW + x; }; auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) { - return ((((oc) * C + c) * KD + kz) * KH + ky) * KW + kx; + return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; }; const size_t src_c_stride_elems = ID * IH * IW; // elements between channels @@ -368,14 +433,14 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { if (SD == 1 && SH == 1 && SW == 1) { const ptrdiff_t kz_lo = std::max<ptrdiff_t>(0, -iz0); - const ptrdiff_t kz_hi = std::min(static_cast<ptrdiff_t>(KD) - 1, - static_cast<ptrdiff_t>(ID) - 1 - iz0); + const ptrdiff_t kz_hi = + std::min(static_cast<ptrdiff_t>(KD) - 1, static_cast<ptrdiff_t>(ID) - 1 - iz0); const ptrdiff_t ky_lo = std::max<ptrdiff_t>(0, -iy0); - const ptrdiff_t ky_hi = std::min(static_cast<ptrdiff_t>(KH) - 1, - static_cast<ptrdiff_t>(IH) - 1 - iy0); + const ptrdiff_t ky_hi = + std::min(static_cast<ptrdiff_t>(KH) - 1, static_cast<ptrdiff_t>(IH) - 1 - iy0); const ptrdiff_t kx_lo = std::max<ptrdiff_t>(0, -ix0); - const ptrdiff_t kx_hi = std::min(static_cast<ptrdiff_t>(KW) - 1, - static_cast<ptrdiff_t>(IW) - 1 - ix0); + const ptrdiff_t kx_hi = + std::min(static_cast<ptrdiff_t>(KW) - 1, static_cast<ptrdiff_t>(IW) - 1 - ix0); if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { const size_t kw_count = static_cast<size_t>(kx_hi - kx_lo + 1); for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { @@ -399,10 +464,18 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base0 = (((oc0 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_C; + const size_t base0 = + (((oc0 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_C; a.wei = m_wei_packed.data() + base0; if (has_oc1) { - const size_t base1 = (((oc1 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_C; + const size_t base1 = (((oc1 * KD + static_cast<size_t>(kz)) * KH + + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_C; a.wei2 = m_wei_packed.data() + base1; } a.wei_stride = sizeof(float); @@ -422,10 +495,18 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base2 = (((oc2 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_C; + const size_t base2 = + (((oc2 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_C; a.wei = m_wei_packed.data() + base2; if (has_oc3) { - const size_t base3 = (((oc3 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_C; + const size_t base3 = (((oc3 * KD + static_cast<size_t>(kz)) * KH + + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_C; a.wei2 = m_wei_packed.data() + base3; } a.wei_stride = sizeof(float); @@ -435,8 +516,17 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { } } else { // generic path: kx loop in kernel, but weights non-packed - const size_t w0 = index_wei(oc0, 0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx_lo)); - const size_t w1 = has_oc1 ? index_wei(oc1, 0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx_lo)) : 0; + const size_t w0 = index_wei(oc0, + 0, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx_lo)); + const size_t w1 = has_oc1 ? 
index_wei(oc1, + 0, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx_lo)) + : 0; jit_conv3d_f32_call_args a{}; a.src = src_p + s_base; a.src_stride = src_c_stride_elems * sizeof(float); @@ -448,15 +538,25 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a.kw_cnt = kw_count; a.src_dx = sizeof(float); a.wei = wei_p + w0; - if (has_oc1) a.wei2 = wei_p + w1; + if (has_oc1) + a.wei2 = wei_p + w1; a.wei_stride = wei_c_stride_elems * sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = sizeof(float); (*m_ip_kernel)(&a); if (has_oc2) { - const size_t w2 = index_wei(oc2, 0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx_lo)); - const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx_lo)) : 0; + const size_t w2 = index_wei(oc2, + 0, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx_lo)); + const size_t w3 = has_oc3 ? index_wei(oc3, + 0, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx_lo)) + : 0; jit_conv3d_f32_call_args a2{}; a2.src = src_p + s_base; a2.src_stride = a.src_stride; @@ -468,7 +568,8 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a2.kw_cnt = a.kw_cnt; a2.src_dx = a.src_dx; a2.wei = wei_p + w2; - if (has_oc3) a2.wei2 = wei_p + w3; + if (has_oc3) + a2.wei2 = wei_p + w3; a2.wei_stride = a.wei_stride; a2.wei_blk_stride = a.wei_blk_stride; a2.wei_dx = a.wei_dx; @@ -482,14 +583,21 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { // generic spatial stride path (host loops over all taps) for (size_t kz = 0; kz < KD; ++kz) { const ptrdiff_t iz = iz0 + static_cast<ptrdiff_t>(kz); - if (iz < 0 || iz >= static_cast<ptrdiff_t>(ID)) continue; + if (iz < 0 || iz >= static_cast<ptrdiff_t>(ID)) + continue; for (size_t ky = 0; ky < KH; ++ky) { const ptrdiff_t iy = iy0 + static_cast<ptrdiff_t>(ky); - if (iy < 0 || iy >= static_cast<ptrdiff_t>(IH)) continue; + if (iy < 0 || iy >= static_cast<ptrdiff_t>(IH)) + continue; for (size_t kx = 0; kx < KW; ++kx) { const ptrdiff_t ix = ix0 + static_cast<ptrdiff_t>(kx); - if (ix < 0 || ix >= static_cast<ptrdiff_t>(IW)) continue; - const size_t s_base = index_src(n, 0, static_cast<size_t>(iz), static_cast<size_t>(iy), static_cast<size_t>(ix)); + if (ix < 0 || ix >= static_cast<ptrdiff_t>(IW)) + continue; + const size_t s_base = index_src(n, + 0, + static_cast<size_t>(iz), + static_cast<size_t>(iy), + static_cast<size_t>(ix)); // pair 0 { const size_t w0 = index_wei(oc0, 0, kz, ky, kx); @@ -505,7 +613,8 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a.kw_cnt = 1; a.src_dx = 0; a.wei = wei_p + w0; - if (has_oc1) a.wei2 = wei_p + w1; + if (has_oc1) + a.wei2 = wei_p + w1; a.wei_stride = wei_c_stride_elems * sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; @@ -526,7 +635,8 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { a.kw_cnt = 1; a.src_dx = 0; a.wei = wei_p + w2; - if (has_oc3) a.wei2 = wei_p + w3; + if (has_oc3) + a.wei2 = wei_p + w3; a.wei_stride = wei_c_stride_elems * sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; @@ -539,9 +649,12 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { // Store dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; - if (has_oc1) dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; - if (has_oc2) dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; - if (has_oc3) dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; + if (has_oc1) + dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; + if (has_oc2) + dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; + if (has_oc3) + dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; } } }
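For readers tracing the host/kernel contract above: each kernel call accumulates one or two output-channel dot products over the channel dimension, optionally looping over kw_cnt kx taps in-kernel. The following scalar model is a sketch only, with field meanings inferred from the diff (all strides are in bytes; wei_blk_stride is omitted for brevity, and the struct name here is a stand-in for the real jit_conv3d_f32_call_args):

#include <cstddef>
#include <cstdint>

struct Conv3DF32Args {           // reduced stand-in for jit_conv3d_f32_call_args
    const float* src;            // input at the first contributing tap
    size_t src_stride;           // bytes between consecutive input channels
    const float* wei;            // weights for oc0 at the first kx tap
    const float* wei2;           // weights for oc1, or nullptr (single-OC mode)
    size_t wei_stride;           // bytes between consecutive weight channels
    size_t repeats;              // C / 4 full channel blocks
    size_t tail;                 // C % 4 leftover channels
    size_t kw_cnt;               // kx taps handled inside the kernel (0 acts as 1)
    size_t src_dx;               // bytes to advance src per kx tap
    size_t wei_dx;               // bytes to advance wei per kx tap
    float* acc;                  // accumulator for oc0 (read-modify-write)
    float* acc2;                 // accumulator for oc1, or nullptr
};

// Reference semantics of one invocation: acc (and acc2) += sum over taps/channels.
inline void conv3d_f32_ref(const Conv3DF32Args& a) {
    auto at = [](const float* base, size_t off) {
        return reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(base) + off);
    };
    const size_t channels = a.repeats * 4 + a.tail;
    const size_t taps = a.kw_cnt ? a.kw_cnt : 1;
    for (size_t kx = 0; kx < taps; ++kx) {
        for (size_t c = 0; c < channels; ++c) {
            const float s = *at(a.src, kx * a.src_dx + c * a.src_stride);
            *a.acc += s * *at(a.wei, kx * a.wei_dx + c * a.wei_stride);
            if (a.acc2) {
                *a.acc2 += s * *at(a.wei2, kx * a.wei_dx + c * a.wei_stride);
            }
        }
    }
}

The dual-OC form exists so that a single pass over the source lanes feeds two weight streams, halving the number of source loads per accumulated output.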
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp index 54cdc0b07b9ff2..88f2067a8f0163 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp @@ -43,7 +43,9 @@ class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { JitConv3DKernelF32() = default; void create_ker(); - inline void operator()(const jit_conv3d_f32_call_args* p) const { ker_(p); } + inline void operator()(const jit_conv3d_f32_call_args* p) const { + ker_(p); + } private: void generate() override; @@ -61,11 +63,15 @@ class JitConv3DExecutorF32 : public Executor { return true; } void execute(const MemoryArgs& memory) override; - void execute() override { execute(m_memory); } + void execute() override { + execute(m_memory); + } void exec([[maybe_unused]] const std::vector<MemoryCPtr>& src, [[maybe_unused]] const std::vector<MemoryPtr>& dst) override {} - [[nodiscard]] impl_desc_type implType() const override { return impl_desc_type::jit_asimd; } + [[nodiscard]] impl_desc_type implType() const override { + return impl_desc_type::jit_asimd; + } static bool supports(const ConvConfig& cfg);
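Both executors rely on the same packing idea used in ensure_weights_packed_f16/_f32 below and in the convolution executor above: for every (oc, kz, ky, kx) tap, the channel values are stored contiguously and zero-padded to the vector width, which is exactly what enables the kernels' wei_stride == sizeof(element) fast path with whole-vector loads. A sketch of the core loop as a standalone function (hypothetical name; the real code fills the m_wei_packed_f16 member using m_padded_IC_f16):

#include <cstddef>
#include <cstdint>
#include <vector>

// Repack deconv weights [IC, OC, KD, KH, KW] -> [OC, KD, KH, KW, padded_IC],
// padding the channel dimension to 8 fp16 lanes.
std::vector<uint16_t> pack_deconv_weights_f16(const uint16_t* wsrc,
                                              size_t IC, size_t OC,
                                              size_t KD, size_t KH, size_t KW) {
    const size_t padded_IC = (IC + 7) / 8 * 8;  // round channels up to 8 lanes
    std::vector<uint16_t> packed(OC * KD * KH * KW * padded_IC, 0);
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t kz = 0; kz < KD; ++kz)
            for (size_t ky = 0; ky < KH; ++ky)
                for (size_t kx = 0; kx < KW; ++kx)
                    for (size_t ic = 0; ic < IC; ++ic) {
                        // source layout [IC, OC, KD, KH, KW]
                        const size_t s = ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx;
                        // packed layout [OC, KD, KH, KW, padded_IC]
                        const size_t d = ((((oc)*KD + kz) * KH + ky) * KW + kx) * padded_IC + ic;
                        packed[d] = wsrc[s];
                    }
    return packed;
}

The pack trades one pass over the weights at first execution for strictly sequential weight reads in every subsequent kernel call.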
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index dd9fd695d73416..c929b3269e9f30 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -16,12 +16,7 @@ namespace ov::intel_cpu { -static inline const uint16_t* as_f16(const MemoryPtr& m) { - return reinterpret_cast<const uint16_t*>(m->getData()); -} -static inline uint16_t* as_f16(MemoryPtr& m) { - return reinterpret_cast<uint16_t*>(m->getData()); -} +// removed unused helpers bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, const std::vector<MemoryDescPtr>& srcDescs, @@ -44,10 +39,12 @@ bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, } void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector<MemoryPtr>& src) { - if (m_wei_packed_ready_f16) return; + if (m_wei_packed_ready_f16) + return; // src[1] holds weights for deconv with shape [IC, OC, KD, KH, KW] const auto& weiDims = src[1]->getStaticDims(); - if (weiDims.size() != 5) return; + if (weiDims.size() != 5) + return; const size_t IC = weiDims[0]; const size_t OC = weiDims[1]; const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; @@ -57,7 +54,7 @@ void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector<MemoryPtr>& src) { const uint16_t* wsrc = reinterpret_cast<const uint16_t*>(src[1]->getData()); auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; }; auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; @@ -81,9 +78,11 @@ void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector<MemoryPtr>& src) void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector<MemoryPtr>& src) { - if (m_wei_packed_ready_f32) return; + if (m_wei_packed_ready_f32) + return; const auto& weiDims = src[1]->getStaticDims(); - if (weiDims.size() != 5) return; + if (weiDims.size() != 5) + return; const size_t IC = weiDims[0]; const size_t OC = weiDims[1]; const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; @@ -93,7 +92,7 @@ void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector<MemoryPtr>& src) { const float* wsrc = reinterpret_cast<const float*>(src[1]->getData()); auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; }; auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; @@ -126,8 +125,7 @@ void JitDeconv3DExecutor::exec(const std::vector<MemoryPtr>& src, } } -void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, - const std::vector<MemoryPtr>& dst) { +void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, const std::vector<MemoryPtr>& dst) { // NCDHW, fp16: compute each output pixel (n, oc, od, oh, ow) as a sum over (ic, kz, ky, kx) const auto& srcDims = src[0]->getStaticDims(); const auto& weiDims = src[1]->getStaticDims(); @@ -161,7 +159,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, }; // weight [IC, OC, KD, KH, KW] auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { - return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; }; // Strides in elements @@ -210,16 +208,15 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, // Compute packed bases for ky_lo size_t pack_base_z0 = (oc0 * KD + static_cast<size_t>(kz)) * KH; size_t pack_base_z1 = has_oc1 ? (oc1 * KD + static_cast<size_t>(kz)) * KH : 0; - size_t pack_base_z2 = has_oc2 ? (oc2 * KD + static_cast<size_t>(kz)) * KH : 0; - size_t pack_base_z3 = has_oc3 ? (oc3 * KD + static_cast<size_t>(kz)) * KH : 0; + // oc2/oc3 computed in second dual call; no need for precomputed bases size_t pack_base_y0 = (pack_base_z0 + static_cast<size_t>(ky_lo)) * KW; - size_t pack_base_y1 = has_oc1 ? (pack_base_z1 + static_cast<size_t>(ky_lo)) * KW : 0; - size_t pack_base_y2 = has_oc2 ? (pack_base_z2 + static_cast<size_t>(ky_lo)) * KW : 0; - size_t pack_base_y3 = has_oc3 ? (pack_base_z3 + static_cast<size_t>(ky_lo)) * KW : 0; - const size_t pack_base0 = (pack_base_y0 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16; - const size_t pack_base1 = has_oc1 ? (pack_base_y1 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0; - const size_t pack_base2 = has_oc2 ? (pack_base_y2 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0; - const size_t pack_base3 = has_oc3 ? (pack_base_y3 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0; + size_t pack_base_y1 = + has_oc1 ? (pack_base_z1 + static_cast<size_t>(ky_lo)) * KW : 0; + // oc2/oc3 will be handled in the second dual call below + const size_t pack_base0 = + (pack_base_y0 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16; + const size_t pack_base1 = + has_oc1 ? (pack_base_y1 + static_cast<size_t>(kx_lo)) * m_padded_IC_f16 : 0; jit_conv3d_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -234,7 +231,8 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, a.src_dx = sizeof(uint16_t); a.src_dy = IW * sizeof(uint16_t); a.wei = m_wei_packed_f16.data() + pack_base0; - if (has_oc1) a.wei2 = m_wei_packed_f16.data() + pack_base1; + if (has_oc1) + a.wei2 = m_wei_packed_f16.data() + pack_base1; // oc2/oc3 handled in a follow-up dual call a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; @@ -259,10 +257,20 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, a.acc2 = has_oc1 ? &acc1 : nullptr; a.repeats = IC / 8; a.tail = IC % 8; - const size_t w_base0 = idx_wei(0, oc0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx)); - const size_t w_base1 = has_oc1 ? 
idx_wei(0, oc1, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx)) : 0; + const size_t w_base0 = idx_wei(0, + oc0, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx)); + const size_t w_base1 = has_oc1 ? idx_wei(0, + oc1, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx)) + : 0; a.wei = wei_p + w_base0; - if (has_oc1) a.wei2 = wei_p + w_base1; + if (has_oc1) + a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; (*m_ip_kernel_f16)(&a); @@ -277,10 +285,20 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, a.acc2 = has_oc3 ? &acc3 : nullptr; a.repeats = IC / 8; a.tail = IC % 8; - const size_t w_base2 = idx_wei(0, oc2, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx)); - const size_t w_base3 = has_oc3 ? idx_wei(0, oc3, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx)) : 0; + const size_t w_base2 = idx_wei(0, + oc2, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx)); + const size_t w_base3 = has_oc3 ? idx_wei(0, + oc3, + static_cast<size_t>(kz), + static_cast<size_t>(ky), + static_cast<size_t>(kx)) + : 0; a.wei = wei_p + w_base2; - if (has_oc3) a.wei2 = wei_p + w_base3; + if (has_oc3) + a.wei2 = wei_p + w_base3; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; (*m_ip_kernel_f16)(&a); @@ -294,24 +312,38 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, // Generic path (stride > 1): keep modulus checks for (size_t kz = 0; kz < KD; ++kz) { const ptrdiff_t iz_num = static_cast<ptrdiff_t>(od) + PD0 - static_cast<ptrdiff_t>(kz); - if (SD == 0) continue; - if (iz_num % static_cast<ptrdiff_t>(SD) != 0) continue; + if (SD == 0) + continue; + if (iz_num % static_cast<ptrdiff_t>(SD) != 0) + continue; const ptrdiff_t id = iz_num / static_cast<ptrdiff_t>(SD); - if (id < 0 || id >= static_cast<ptrdiff_t>(ID)) continue; + if (id < 0 || id >= static_cast<ptrdiff_t>(ID)) + continue; for (size_t ky = 0; ky < KH; ++ky) { const ptrdiff_t iy_num = static_cast<ptrdiff_t>(oh) + PH0 - static_cast<ptrdiff_t>(ky); - if (SH == 0) continue; - if (iy_num % static_cast<ptrdiff_t>(SH) != 0) continue; + if (SH == 0) + continue; + if (iy_num % static_cast<ptrdiff_t>(SH) != 0) + continue; const ptrdiff_t ihh = iy_num / static_cast<ptrdiff_t>(SH); - if (ihh < 0 || ihh >= static_cast<ptrdiff_t>(IH)) continue; + if (ihh < 0 || ihh >= static_cast<ptrdiff_t>(IH)) + continue; for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ix_num = static_cast<ptrdiff_t>(ow_) + PW0 - static_cast<ptrdiff_t>(kx); - if (SW == 0) continue; - if (ix_num % static_cast<ptrdiff_t>(SW) != 0) continue; + const ptrdiff_t ix_num = + static_cast<ptrdiff_t>(ow_) + PW0 - static_cast<ptrdiff_t>(kx); + if (SW == 0) + continue; + if (ix_num % static_cast<ptrdiff_t>(SW) != 0) + continue; const ptrdiff_t iww = ix_num / static_cast<ptrdiff_t>(SW); - if (iww < 0 || iww >= static_cast<ptrdiff_t>(IW)) continue; - - const size_t s_base0 = idx_src(n, 0, static_cast<size_t>(id), static_cast<size_t>(ihh), static_cast<size_t>(iww)); + if (iww < 0 || iww >= static_cast<ptrdiff_t>(IW)) + continue; + + const size_t s_base0 = idx_src(n, + 0, + static_cast<size_t>(id), + static_cast<size_t>(ihh), + static_cast<size_t>(iww)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); const size_t w_base1 = has_oc1 ? 
idx_wei(0, oc1, kz, ky, kx) : 0; @@ -324,17 +356,20 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, a.repeats = IC / 8; a.tail = IC % 8; if (m_wei_packed_ready_f16) { - const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t pack_base0 = + (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; a.wei = m_wei_packed_f16.data() + pack_base0; if (has_oc1) { - const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t pack_base1 = + (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; a.wei2 = m_wei_packed_f16.data() + pack_base1; } a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; } else { a.wei = wei_p + w_base0; - if (has_oc1) a.wei2 = wei_p + w_base1; + if (has_oc1) + a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; } @@ -349,22 +384,31 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, if (bprec == ov::element::f32) { const float* b = reinterpret_cast<const float*>(src[2]->getData()); acc0 += b[oc0]; - if (has_oc1) acc1 += b[oc1]; - if (has_oc2) acc2 += b[oc2]; - if (has_oc3) acc3 += b[oc3]; + if (has_oc1) + acc1 += b[oc1]; + if (has_oc2) + acc2 += b[oc2]; + if (has_oc3) + acc3 += b[oc3]; } else if (bprec == ov::element::f16) { const uint16_t* b = reinterpret_cast<const uint16_t*>(src[2]->getData()); acc0 += static_cast<float>(ov::float16(b[oc0])); - if (has_oc1) acc1 += static_cast<float>(ov::float16(b[oc1])); - if (has_oc2) acc2 += static_cast<float>(ov::float16(b[oc2])); - if (has_oc3) acc3 += static_cast<float>(ov::float16(b[oc3])); + if (has_oc1) + acc1 += static_cast<float>(ov::float16(b[oc1])); + if (has_oc2) + acc2 += static_cast<float>(ov::float16(b[oc2])); + if (has_oc3) + acc3 += static_cast<float>(ov::float16(b[oc3])); } } dst_p[idx_dst(n, oc0, od, oh, ow_)] = ov::float16(acc0).to_bits(); - if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1).to_bits(); - if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2).to_bits(); - if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3).to_bits(); + if (has_oc1) + dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1).to_bits(); + if (has_oc2) + dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2).to_bits(); + if (has_oc3) + dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3).to_bits(); } } } @@ -373,8 +417,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector<MemoryPtr>& src, ov::parallel_for3d(N, (OC + 3) / 4, OD, worker); } -void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, - const std::vector<MemoryPtr>& dst) { +void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, const std::vector<MemoryPtr>& dst) { // NCDHW, f32 const auto& srcDims = src[0]->getStaticDims(); const auto& weiDims = src[1]->getStaticDims(); @@ -406,7 +449,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, return (((n * OC + c) * OD + z) * OH + y) * OW + x; }; auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { - return ((((ic) * OC + oc) * KD + kz) * KH + ky) * KW + kx; + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; }; const size_t src_c_stride_elems = ID * IH * IW; @@ -447,7 +490,9 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, const size_t ix0 = static_cast<size_t>(tx - kx_lo); for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { const size_t iy = static_cast<size_t>(ty - ky); - const size_t ix = ix0; (void)iy0; (void)ky_base; + const size_t ix = ix0; + (void)iy0; + (void)ky_base; const size_t s_base = idx_src(n, 0, iz, iy, ix); // pair 0 @@ -462,10 
+507,18 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, a.tail = IC % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base0 = (((oc0 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_IC_f32; + const size_t base0 = + (((oc0 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = (((oc1 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_IC_f32; + const size_t base1 = + (((oc1 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_IC_f32; a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); @@ -485,10 +538,18 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, a.tail = IC % 4; a.kw_cnt = kw_count; a.src_dx = sizeof(float); - const size_t base2 = (((oc2 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_IC_f32; + const size_t base2 = + (((oc2 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = (((oc3 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * KW + static_cast<size_t>(kx_lo)) * m_padded_IC_f32; + const size_t base3 = + (((oc3 * KD + static_cast<size_t>(kz)) * KH + static_cast<size_t>(ky)) * + KW + + static_cast<size_t>(kx_lo)) * + m_padded_IC_f32; a.wei2 = m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); @@ -503,24 +564,38 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, // generic stride path with modulus checks for (size_t kz = 0; kz < KD; ++kz) { const ptrdiff_t iz_num = static_cast<ptrdiff_t>(od) + PD0 - static_cast<ptrdiff_t>(kz); - if (SD == 0) continue; - if (iz_num % static_cast<ptrdiff_t>(SD) != 0) continue; + if (SD == 0) + continue; + if (iz_num % static_cast<ptrdiff_t>(SD) != 0) + continue; const ptrdiff_t id = iz_num / static_cast<ptrdiff_t>(SD); - if (id < 0 || id >= static_cast<ptrdiff_t>(ID)) continue; + if (id < 0 || id >= static_cast<ptrdiff_t>(ID)) + continue; for (size_t ky = 0; ky < KH; ++ky) { const ptrdiff_t iy_num = static_cast<ptrdiff_t>(oh) + PH0 - static_cast<ptrdiff_t>(ky); - if (SH == 0) continue; - if (iy_num % static_cast<ptrdiff_t>(SH) != 0) continue; + if (SH == 0) + continue; + if (iy_num % static_cast<ptrdiff_t>(SH) != 0) + continue; const ptrdiff_t ihh = iy_num / static_cast<ptrdiff_t>(SH); - if (ihh < 0 || ihh >= static_cast<ptrdiff_t>(IH)) continue; + if (ihh < 0 || ihh >= static_cast<ptrdiff_t>(IH)) + continue; for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ix_num = static_cast<ptrdiff_t>(ow_) + PW0 - static_cast<ptrdiff_t>(kx); - if (SW == 0) continue; - if (ix_num % static_cast<ptrdiff_t>(SW) != 0) continue; + const ptrdiff_t ix_num = + static_cast<ptrdiff_t>(ow_) + PW0 - static_cast<ptrdiff_t>(kx); + if (SW == 0) + continue; + if (ix_num % static_cast<ptrdiff_t>(SW) != 0) + continue; const ptrdiff_t iww = ix_num / static_cast<ptrdiff_t>(SW); - if (iww < 0 || iww >= static_cast<ptrdiff_t>(IW)) continue; - - const size_t s_base0 = idx_src(n, 0, static_cast<size_t>(id), static_cast<size_t>(ihh), static_cast<size_t>(iww)); + if (iww < 0 || iww >= static_cast<ptrdiff_t>(IW)) + continue; + + const size_t s_base0 = idx_src(n, + 0, + static_cast<size_t>(id), + static_cast<size_t>(ihh), + static_cast<size_t>(iww)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); const size_t w_base1 = has_oc1 ? 
idx_wei(0, oc1, kz, ky, kx) : 0; @@ -533,17 +608,20 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, a.repeats = IC / 4; a.tail = IC % 4; if (m_wei_packed_ready_f32) { - const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pack_base0 = + (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + pack_base0; if (has_oc1) { - const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pack_base1 = + (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; a.wei2 = m_wei_packed_f32.data() + pack_base1; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; } else { a.wei = wei_p + w_base0; - if (has_oc1) a.wei2 = wei_p + w_base1; + if (has_oc1) + a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(float); a.wei_blk_stride = a.wei_stride * 4; } @@ -557,20 +635,32 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector<MemoryPtr>& src, const auto& bprec = src[2]->getPrecision(); if (bprec == ov::element::f32) { const float* b = reinterpret_cast<const float*>(src[2]->getData()); - acc0 += b[oc0]; if (has_oc1) acc1 += b[oc1]; if (has_oc2) acc2 += b[oc2]; if (has_oc3) acc3 += b[oc3]; + acc0 += b[oc0]; + if (has_oc1) + acc1 += b[oc1]; + if (has_oc2) + acc2 += b[oc2]; + if (has_oc3) + acc3 += b[oc3]; } else if (bprec == ov::element::f16) { const uint16_t* b = reinterpret_cast<const uint16_t*>(src[2]->getData()); acc0 += static_cast<float>(ov::float16(b[oc0])); - if (has_oc1) acc1 += static_cast<float>(ov::float16(b[oc1])); - if (has_oc2) acc2 += static_cast<float>(ov::float16(b[oc2])); - if (has_oc3) acc3 += static_cast<float>(ov::float16(b[oc3])); + if (has_oc1) + acc1 += static_cast<float>(ov::float16(b[oc1])); + if (has_oc2) + acc2 += static_cast<float>(ov::float16(b[oc2])); + if (has_oc3) + acc3 += static_cast<float>(ov::float16(b[oc3])); } } dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0; - if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; - if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2; - if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3; + if (has_oc1) + dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; + if (has_oc2) + dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2; + if (has_oc3) + dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3; } } } @@ -581,7 +671,8 @@ bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, const std::vector<MemoryDescPtr>& srcDescs, const std::vector<MemoryDescPtr>& dstDescs) const { // Support 5D NCDHW, fp16 and fp32 - if (srcDescs.size() < 2 || dstDescs.empty()) return false; + if (srcDescs.size() < 2 || dstDescs.empty()) + return false; if (srcDescs[0]->getShape().getRank() != 5 || srcDescs[1]->getShape().getRank() != 5 || dstDescs[0]->getShape().getRank() != 5) { return false;
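Both generic (stride > 1) paths above hinge on one relation: an input coordinate id contributes to output od through tap kz only when od + PD0 - kz is divisible by the stride and the quotient lies inside the input extent. A one-axis model of that check, under the same naming as the code:

#include <cstddef>

// Returns true and sets id when an input coordinate exists for this (od, kz).
inline bool deconv_src_coord(std::ptrdiff_t od, std::ptrdiff_t PD0, std::ptrdiff_t kz,
                             std::ptrdiff_t SD, std::ptrdiff_t ID, std::ptrdiff_t& id) {
    const std::ptrdiff_t num = od + PD0 - kz;
    if (SD == 0 || num % SD != 0)
        return false;          // tap does not land on the stride grid
    id = num / SD;
    return id >= 0 && id < ID; // and must fall inside the input volume
}

For example, with od = 5, PD0 = 1, SD = 2: tap kz = 2 gives num = 4, so id = 2 contributes, while kz = 1 gives num = 5 and is skipped. The SD == 1 fast paths avoid the division entirely, since every tap passes the modulus test there.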
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index 09ea8fc608cbf0..c59f3538a39073 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -7,9 +7,9 @@ #include #include -#include "nodes/executors/deconv.hpp" #include "nodes/executors/aarch64/jit_conv3d.hpp" #include "nodes/executors/aarch64/jit_conv3d_f32.hpp" +#include "nodes/executors/deconv.hpp" namespace ov::intel_cpu { @@ -27,7 +27,9 @@ class JitDeconv3DExecutor : public DeconvExecutor { const std::vector<MemoryPtr>& dst, const void* post_ops_data_) override; - [[nodiscard]] impl_desc_type getImplType() const override { return impl_desc_type::jit_asimd; } + [[nodiscard]] impl_desc_type getImplType() const override { + return impl_desc_type::jit_asimd; + } private: std::vector<MemoryDescPtr> m_srcDescs; @@ -39,7 +41,7 @@ class JitDeconv3DExecutor : public DeconvExecutor { // packed weights std::vector<uint16_t> m_wei_packed_f16; - std::vector<float> m_wei_packed_f32; + std::vector<float> m_wei_packed_f32; bool m_wei_packed_ready_f16{false}; bool m_wei_packed_ready_f32{false}; size_t m_padded_IC_f16{0}; diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp index b77d3d62fa1d5f..12b6b34737061c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp @@ -24,8 +24,7 @@ const std::vector<DeconvExecutorDesc>& getDeconvExecutorsList() { static std::vector<DeconvExecutorDesc> descs = { // Prefer ACL builder first for stability/perf; fallback to AArch64 JIT if ACL not supported OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared<AclDeconvExecutorBuilder>()) - OV_CPU_INSTANCE_ARM64(ExecutorType::Jit, std::make_shared<AArch64JitDeconvExecutorBuilder>()) - }; + OV_CPU_INSTANCE_ARM64(ExecutorType::Jit, std::make_shared<AArch64JitDeconvExecutorBuilder>())}; return descs; }
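A note on the fp16 kernel that the following commit reshapes: its inner loop is built on FMLAL/FMLAL2, which take the low/high four fp16 lanes of the 8-lane source and weight vectors, widen them to fp32, and accumulate into a 4-lane fp32 register (v20 for oc0, v21 for oc1), so the reduction runs at fp32 precision even though the data is fp16; FADDP then folds the four partial sums into the scalar added to acc. A scalar model of one 8-lane step (a sketch, not the kernel itself; plain floats stand in for the widened fp16 lanes):

#include <array>

using Vec8h = std::array<float, 8>;  // 8 fp16 lanes, shown here already widened
using Vec4s = std::array<float, 4>;  // fp32 accumulator (v20 / v21)

inline void fmlal_step(Vec4s& acc, const Vec8h& src, const Vec8h& wei) {
    for (int i = 0; i < 4; ++i)
        acc[i] += src[i] * wei[i];          // FMLAL: low halves
    for (int i = 0; i < 4; ++i)
        acc[i] += src[4 + i] * wei[4 + i];  // FMLAL2: high halves
}

inline float faddp_reduce(const Vec4s& v) {
    return (v[0] + v[1]) + (v[2] + v[3]);   // FADDP twice: pairwise, then final
}

This is also why the tail paths zero v0/v1/v2 before loading up to tail lanes: unused lanes multiply as zero and leave the accumulator unchanged.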
From b4da43411744071ec1175357d2f3ef95bd904fbb Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Sun, 19 Oct 2025 23:34:38 +0200 Subject: [PATCH 06/20] Refactor AArch64 JIT 3D Deconvolution and Convolution Executors for improved readability by adopting consistent naming, reformatting, and splitting large code blocks into helper functions. --- .../nodes/executors/aarch64/jit_conv3d.cpp | 1293 +++++++++-------- .../nodes/executors/aarch64/jit_conv3d.hpp | 4 + .../executors/aarch64/jit_conv3d_f32.cpp | 17 +- .../nodes/executors/aarch64/jit_deconv3d.cpp | 215 +-- 4 files changed, 774 insertions(+), 755 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 12a977e4045908..2a55a7ed834d7b 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -4,20 +4,22 @@ #include "nodes/executors/aarch64/jit_conv3d.hpp" +#include +#include +#include +#include + #include +#include #include #include -#include #include #include #include "cpu_memory.h" -#include "memory_desc/cpu_memory_desc.h" -#include "nodes/executors/implementation_utils.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/core/type/float16.hpp" -#include "utils/general_utils.h" // helper for jit_kernel_cast #include "utils/cpu_utils.hpp" // no direct NEON intrinsics are used here; we rely on Xbyak_aarch64 @@ -27,642 +29,642 @@ using namespace dnnl::impl::cpu::aarch64; namespace ov::intel_cpu { // --------------------------- JIT kernel (placeholder) --------------------------- -JitConv3DKernelF16::JitConv3DKernelF16() {} +JitConv3DKernelF16::JitConv3DKernelF16() = default; void JitConv3DKernelF16::create_ker() { jit_generator::create_kernel(); ker_ = jit_kernel_cast<decltype(ker_)>(jit_ker()); } -void JitConv3DKernelF16::generate() { +void JitConv3DKernelF16::gen_minimal_kernel() { using namespace Xbyak_aarch64; // Minimal stable kernel (dual-OC, in-kernel kx loop) + const XReg reg_args = abi_param1; // x0 + // Load essential arguments (absolute offsets from jit_conv3d_call_args) + const XReg reg_src = x1; // const uint16_t* src + const XReg reg_wei = x2; // const uint16_t* wei + const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) + const XReg reg_reps = x4; // size_t repeats + const XReg reg_tail = x5; // size_t tail + const XReg reg_src_stride = x6; // size_t src_stride (bytes) + const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) + const XReg reg_acc = x8; // float* acc + const XReg reg_acc2 = x9; // float* acc2 (optional) + + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, ptr(reg_args, 40)); + ldr(reg_tail, ptr(reg_args, 48)); + ldr(reg_src_stride, ptr(reg_args, 56)); + ldr(reg_wei_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 88)); + ldr(reg_acc2, ptr(reg_args, 96)); + + Label Lsingle, Ldone; + // Additional labels for kx-loop variants + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + cbz(reg_acc2, Lsingle); + // Jump to in-kernel kx loop dual-OC path (safe, call-clobbered only) + b(Ldual_kx); - { - const XReg reg_args = abi_param1; // x0 - // Load essential arguments (absolute offsets from jit_conv3d_call_args) - const XReg reg_src = x1; // const uint16_t* src - const XReg reg_wei = x2; // const uint16_t* wei - const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) - const XReg reg_reps = x4; // size_t repeats - const XReg reg_tail = x5; // size_t tail - const XReg reg_src_stride = x6; // size_t src_stride (bytes) - const XReg reg_wei_stride = x7; // size_t wei_stride (bytes) - const XReg reg_acc = x8; // float* acc - const XReg reg_acc2 = x9; // float* acc2 (optional) - - ldr(reg_src, ptr(reg_args, 0)); - ldr(reg_wei, ptr(reg_args, 8)); - ldr(reg_wei2, ptr(reg_args, 16)); - ldr(reg_reps, ptr(reg_args, 40)); - ldr(reg_tail, ptr(reg_args, 48)); - ldr(reg_src_stride, ptr(reg_args, 56)); - ldr(reg_wei_stride, ptr(reg_args, 64)); - ldr(reg_acc, ptr(reg_args, 88)); - ldr(reg_acc2, ptr(reg_args, 96)); - - Label Lsingle, Ldone; - // Additional labels for kx-loop variants - Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; - Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; - cbz(reg_acc2, Lsingle); - // Jump to in-kernel kx loop dual-OC path (safe, call-clobbered only) - b(Ldual_kx); - - // Dual-OC with in-kernel kx loop (v20 for oc0, v21 for oc1) - L(Ldual_kx); - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - eor(VReg16B(21), VReg16B(21), VReg16B(21)); - // Load kx-loop controls and set bases - const XReg reg_kw_cnt = x12; - const XReg reg_src_dx = x13; - const XReg reg_wei_dx = x14; - ldr(reg_kw_cnt, ptr(reg_args, 104)); - ldr(reg_src_dx, ptr(reg_args, 112)); - ldr(reg_wei_dx, ptr(reg_args, 120)); - const XReg q_src_base = x15; - const XReg q_wei_base = x16; - const XReg q_wei2_base = x17; - const XReg reg_wei_blk_stride2 = x10; - ldr(reg_wei_blk_stride2, ptr(reg_args, 80)); - mov(q_src_base, reg_src); - mov(q_wei_base, reg_wei); - mov(q_wei2_base, reg_wei2); - // Treat kw_cnt==0 as 1 - cbnz(reg_kw_cnt, Lkx_d); - mov(reg_kw_cnt, 1); - L(Lkx_d); - // Reset pointers and repeats for this kx - ldr(reg_reps, ptr(reg_args, 40)); - mov(reg_src, q_src_base); - mov(reg_wei, q_wei_base); - mov(reg_wei2, q_wei2_base); - // Channel repeats over 8-lane blocks - Label Lrep_d_kx; - L(Lrep_d_kx); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_d_kx); - // Load src lanes into v0 - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, 
reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // Load weights for oc0/oc1 (vector fast path if stride==2) - Label Lw_np_d, Lw_done_d2; - cmp(reg_wei_stride, 2); - b(NE, Lw_np_d); - ld1(VReg8H(1), ptr(reg_wei)); - ld1(VReg8H(2), ptr(reg_wei2)); - add(reg_wei, reg_wei, reg_wei_blk_stride2); - add(reg_wei2, reg_wei2, reg_wei_blk_stride2); - b(Lw_done_d2); - L(Lw_np_d); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); - L(Lw_done_d2); - // MAC into accumulators - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); - fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - sub(reg_reps, reg_reps, 1); - b(Lrep_d_kx); - // Tail handling per kx - L(Ltail_prep_d_kx); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - eor(VReg16B(2), VReg16B(2), VReg16B(2)); - cmp(reg_tail, 0); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, 
Ltail_done_d_kx); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); - L(Ltail_done_d_kx); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); - fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - // advance bases to next kx and continue - sub(reg_kw_cnt, reg_kw_cnt, 1); - add(q_src_base, q_src_base, reg_src_dx); - add(q_wei_base, q_wei_base, reg_wei_dx); - add(q_wei2_base, q_wei2_base, reg_wei_dx); - cbnz(reg_kw_cnt, Lkx_d); - // reduce and store accumulators - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - faddp(VReg4S(21), VReg4S(21), VReg4S(21)); - faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); - fadd(SReg(1), SReg(1), SReg(21)); - str(SReg(1), ptr(reg_acc2)); - b(Ldone); + // Dual-OC with in-kernel kx loop (v20 for oc0, v21 for oc1) + L(Ldual_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + // Load kx-loop controls and set bases + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_kw_cnt, ptr(reg_args, 104)); + ldr(reg_src_dx, ptr(reg_args, 112)); + ldr(reg_wei_dx, ptr(reg_args, 120)); + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; + const XReg reg_wei_blk_stride2 = x10; + ldr(reg_wei_blk_stride2, ptr(reg_args, 80)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + // Treat kw_cnt==0 as 1 + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + L(Lkx_d); + // Reset pointers and repeats for this kx + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + // Channel repeats over 8-lane blocks + Label Lrep_d_kx; + L(Lrep_d_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + // Load src lanes into v0 + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load weights for oc0/oc1 (vector fast path if stride==2) + Label Lw_np_d, Lw_done_d2; + cmp(reg_wei_stride, 2); + b(NE, Lw_np_d); + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_blk_stride2); + add(reg_wei2, reg_wei2, reg_wei_blk_stride2); + b(Lw_done_d2); + L(Lw_np_d); + ld1(VReg(1).h[0], ptr(reg_wei)); + 
add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Lw_done_d2); + // MAC into accumulators + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d_kx); + // Tail handling per kx + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_d_kx); + 
ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // advance bases to next kx and continue + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + // reduce and store accumulators + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); - // Dual-OC path: v20 (oc0), v21 (oc1) - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - eor(VReg16B(21), VReg16B(21), VReg16B(21)); + // Dual-OC path: v20 (oc0), v21 (oc1) + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); - Label Lrep_d, Ltail_prep_d, Ltail_done_d; - L(Lrep_d); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_d); - // Load src lanes (v0) - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2 - Label Ldw_np_d, Ldw_done_d; - cmp(reg_wei_stride, 2); - b(NE, Ldw_np_d); - ld1(VReg8H(1), ptr(reg_wei)); - ld1(VReg8H(2), ptr(reg_wei2)); - add(reg_wei, reg_wei, 16); - add(reg_wei2, reg_wei2, 16); - b(Ldw_done_d); - L(Ldw_np_d); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); - L(Ldw_done_d); - // MAC into v20/v21 - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); - 
fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - sub(reg_reps, reg_reps, 1); - b(Lrep_d); - - // Tail handling - L(Ltail_prep_d); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - eor(VReg16B(2), VReg16B(2), VReg16B(2)); - // lanes 0..7 guarded by tail - cmp(reg_tail, 0); - b(LE, Ltail_done_d); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_d); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_d); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_d); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_d); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_d); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_d); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_d); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); + Label Lrep_d, Ltail_prep_d, Ltail_done_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d); + // Load src lanes (v0) + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2 + Label Ldw_np_d, Ldw_done_d; + cmp(reg_wei_stride, 2); + b(NE, Ldw_np_d); + ld1(VReg8H(1), ptr(reg_wei)); + ld1(VReg8H(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, 16); + add(reg_wei2, reg_wei2, 16); + b(Ldw_done_d); + L(Ldw_np_d); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + 
add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ldw_done_d); + // MAC into v20/v21 + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); - L(Ltail_done_d); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); - fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - // horizontal add and store - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - faddp(VReg4S(21), VReg4S(21), VReg4S(21)); - faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); - fadd(SReg(1), SReg(1), SReg(21)); - str(SReg(1), ptr(reg_acc2)); - b(Ldone); - - // Single-OC path - L(Lsingle); - // Jump to in-kernel kx loop single-OC path - b(Lsingle_kx); - // Single-OC with in-kernel kx loop - L(Lsingle_kx); - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - const XReg s_kw_cnt = x12; - const XReg s_src_dx = x13; - const XReg s_wei_dx = x14; - ldr(s_kw_cnt, ptr(reg_args, 104)); - ldr(s_src_dx, ptr(reg_args, 112)); - ldr(s_wei_dx, ptr(reg_args, 120)); - const XReg s_src_base = x15; - const XReg s_wei_base = x16; - const XReg s_wei_blk_stride2 = x10; - ldr(s_wei_blk_stride2, ptr(reg_args, 80)); - mov(s_src_base, reg_src); - mov(s_wei_base, reg_wei); - cbnz(s_kw_cnt, Lkx_s); - mov(s_kw_cnt, 1); - Label Lrep_s_kx; - L(Lkx_s); - ldr(reg_reps, ptr(reg_args, 40)); - mov(reg_src, s_src_base); - mov(reg_wei, s_wei_base); - L(Lrep_s_kx); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_s_kx); - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // weights (vector fast path if stride==2) - Label Lw_np_s, Lw_done_s2; - cmp(reg_wei_stride, 2); - b(NE, Lw_np_s); - ld1(VReg8H(1), ptr(reg_wei)); - add(reg_wei, reg_wei, s_wei_blk_stride2); - b(Lw_done_s2); - L(Lw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], 
ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Lw_done_s2); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - sub(reg_reps, reg_reps, 1); - b(Lrep_s_kx); - L(Ltail_prep_s_kx); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Ltail_done_s_kx); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - // advance bases - sub(s_kw_cnt, s_kw_cnt, 1); - add(s_src_base, s_src_base, s_src_dx); - add(s_wei_base, s_wei_base, s_wei_dx); - cbnz(s_kw_cnt, Lkx_s); - // reduce/store - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - Label Lrep_s, Ltail_prep_s, Ltail_done_s; - L(Lrep_s); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_s); - // src lanes - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // wei lanes — vector fast path if wei_stride==2 - Label Ldw_np_s, Ldw_done_s; - cmp(reg_wei_stride, 2); - b(NE, Ldw_np_s); 
- ld1(VReg8H(1), ptr(reg_wei)); - add(reg_wei, reg_wei, 16); - b(Ldw_done_s); - L(Ldw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Ldw_done_s); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - sub(reg_reps, reg_reps, 1); - b(Lrep_s); + // Tail handling + L(Ltail_prep_d); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + // lanes 0..7 guarded by tail + cmp(reg_tail, 0); + b(LE, Ltail_done_d); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_d); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_d); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_d); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_d); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + ld1(VReg(2).h[7], ptr(reg_wei2)); - // Tail (single) - L(Ltail_prep_s); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); - b(LE, Ltail_done_s); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_s); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - 
add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_s); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_s); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_s); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_s); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_s); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_s); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Ltail_done_s); - fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); - fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); + L(Ltail_done_d); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); + fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); + // horizontal add and store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); - L(Ldone); - ret(); - } + // Single-OC path + L(Lsingle); + // Jump to in-kernel kx loop single-OC path + b(Lsingle_kx); + // Single-OC with in-kernel kx loop + L(Lsingle_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + const XReg s_kw_cnt = x12; + const XReg s_src_dx = x13; + const XReg s_wei_dx = x14; + ldr(s_kw_cnt, ptr(reg_args, 104)); + ldr(s_src_dx, ptr(reg_args, 112)); + ldr(s_wei_dx, ptr(reg_args, 120)); + const XReg s_src_base = x15; + const XReg s_wei_base = x16; + const XReg s_wei_blk_stride2 = x10; + ldr(s_wei_blk_stride2, ptr(reg_args, 80)); + mov(s_src_base, reg_src); + mov(s_wei_base, reg_wei); + cbnz(s_kw_cnt, Lkx_s); + mov(s_kw_cnt, 1); + Label Lrep_s_kx; + L(Lkx_s); + ldr(reg_reps, ptr(reg_args, 40)); + mov(reg_src, s_src_base); + mov(reg_wei, s_wei_base); + L(Lrep_s_kx); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // weights (vector fast path if stride==2) + Label Lw_np_s, Lw_done_s2; + cmp(reg_wei_stride, 2); + b(NE, 
Lw_np_s); + ld1(VReg8H(1), ptr(reg_wei)); + add(reg_wei, reg_wei, s_wei_blk_stride2); + b(Lw_done_s2); + L(Lw_np_s); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + L(Lw_done_s2); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s_kx); + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s_kx); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + // advance bases + sub(s_kw_cnt, s_kw_cnt, 1); + add(s_src_base, s_src_base, s_src_dx); + add(s_wei_base, s_wei_base, s_wei_dx); + cbnz(s_kw_cnt, Lkx_s); + // reduce/store + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + Label Lrep_s, Ltail_prep_s, Ltail_done_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s); + // src lanes + ld1(VReg(0).h[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[3], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[4], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[5], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[6], 
ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).h[7], ptr(reg_src)); + // wei lanes — vector fast path if wei_stride==2 + Label Ldw_np_s, Ldw_done_s; + cmp(reg_wei_stride, 2); + b(NE, Ldw_np_s); + ld1(VReg8H(1), ptr(reg_wei)); + add(reg_wei, reg_wei, 16); + b(Ldw_done_s); + L(Ldw_np_s); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ldw_done_s); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + // Tail (single) + L(Ltail_prep_s); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s); + ld1(VReg(0).h[0], ptr(reg_src)); + ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s); + ld1(VReg(0).h[1], ptr(reg_src)); + ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s); + ld1(VReg(0).h[2], ptr(reg_src)); + ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s); + ld1(VReg(0).h[3], ptr(reg_src)); + ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); + b(LE, Ltail_done_s); + ld1(VReg(0).h[4], ptr(reg_src)); + ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); + b(LE, Ltail_done_s); + ld1(VReg(0).h[5], ptr(reg_src)); + ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); + b(LE, Ltail_done_s); + ld1(VReg(0).h[6], ptr(reg_src)); + ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); + b(LE, Ltail_done_s); + ld1(VReg(0).h[7], ptr(reg_src)); + ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s); + fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); + fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + + L(Ldone); + ret(); +} +void JitConv3DKernelF16::gen_optimized_kernel() { + using namespace Xbyak_aarch64; // abi_param1 -> args pointer const XReg reg_args = abi_param1; // Use call-clobbered registers to avoid saving/restoring callee-saved regs @@ -1614,13 +1616,19 @@ void JitConv3DKernelF16::generate() { ret(); } +void JitConv3DKernelF16::generate() { + // Keep body small for clang-tidy readability-function-size + gen_minimal_kernel(); + gen_optimized_kernel(); +} + // --------------------------- Executor --------------------------- -static inline const uint16_t* ptr_f16(const MemoryPtr& m) { - return 
reinterpret_cast<const uint16_t*>(m->getData()); +[[maybe_unused]] static inline auto ptr_f16(const MemoryPtr& mem) -> const uint16_t* { + return reinterpret_cast<const uint16_t*>(mem->getData()); } -[[maybe_unused]] static inline uint16_t* ptr_f16(MemoryPtr& m) { - return reinterpret_cast<uint16_t*>(m->getData()); +[[maybe_unused]] static inline auto ptr_f16(MemoryPtr& mem) -> uint16_t* { + return reinterpret_cast<uint16_t*>(mem->getData()); } JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, @@ -1770,8 +1778,7 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { const size_t kw_count = static_cast<size_t>(kx_hi - kx_lo + 1); for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { const size_t iz = static_cast<size_t>(iz0 + kz); - const size_t iy = static_cast<size_t>(iy0 + ky_lo); - const size_t ix = static_cast<size_t>(ix0 + kx_lo); + // iy/ix for ky_lo/kx_lo not needed; use iy2/ix2 per ky below if (m_wei_packed_ready) { // Loop over ky in host; kernel handles kx via kw_cnt @@ -2009,7 +2016,7 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { auto bia = memory.at(ARG_BIAS); const auto bprec = bia->getDescPtr()->getPrecision(); if (bprec == ov::element::f32) { - const float* b = reinterpret_cast<const float*>(bia->getData()); + const auto* b = reinterpret_cast<const float*>(bia->getData()); acc0 += b[oc0]; if (has_oc1) acc1 += b[oc1]; @@ -2018,7 +2025,7 @@ if (has_oc3) acc3 += b[oc3]; } else if (bprec == ov::element::f16) { - const uint16_t* b = reinterpret_cast<const uint16_t*>(bia->getData()); + const auto* b = reinterpret_cast<const uint16_t*>(bia->getData()); acc0 += static_cast<float>(ov::float16(b[oc0])); if (has_oc1) acc1 += static_cast<float>(ov::float16(b[oc1])); @@ -2036,21 +2043,21 @@ : m_prelu_slopes[std::min(oc, m_prelu_slopes.size() - 1)]; }; const float s0 = slope_at(oc0); - if (acc0 < 0.f) + if (acc0 < 0.0F) acc0 *= s0; if (has_oc1) { const float s1 = slope_at(oc1); - if (acc1 < 0.f) + if (acc1 < 0.0F) acc1 *= s1; } if (has_oc2) { const float s2 = slope_at(oc2); - if (acc2 < 0.f) + if (acc2 < 0.0F) acc2 *= s2; } if (has_oc3) { const float s3 = slope_at(oc3); - if (acc3 < 0.f) + if (acc3 < 0.0F) acc3 *= s3; } } @@ -2091,7 +2098,7 @@ void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) { m_padded_C_f32 = (C + 3) / 4 * 4; const size_t total = OC * KD * KH * KW * m_padded_C_f32; m_wei_packed_f32.assign(total, 0.0f); - const float* wsrc = reinterpret_cast<const float*>(wei->getData()); + const auto* wsrc = reinterpret_cast<const float*>(wei->getData()); auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; @@ -2140,9 +2147,9 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0; const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ? 
m_attrs.paddingL[2] : 0; - const float* src_p = reinterpret_cast(src->getData()); - const float* wei_p = reinterpret_cast(wei->getData()); - float* dst_p = reinterpret_cast(dst->getData()); + const auto* src_p = reinterpret_cast(src->getData()); + const auto* wei_p = reinterpret_cast(wei->getData()); + auto* dst_p = reinterpret_cast(dst->getData()); auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { return (((n * C + c) * ID + z) * IH + y) * IW + x; @@ -2175,7 +2182,7 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { for (size_t ow = 0; ow < OW; ++ow) { const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; - float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; if (SD == 1 && SH == 1 && SW == 1) { const ptrdiff_t kz_lo = std::max(0, -iz0); @@ -2422,7 +2429,7 @@ void ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& m // Pack layout: [OC, KD, KH, KW, Ct] const size_t total = OC * KD * KH * KW * m_padded_C; m_wei_packed.assign(total, static_cast(0)); - const uint16_t* wsrc = reinterpret_cast(wei->getData()); + const auto* wsrc = reinterpret_cast(wei->getData()); auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp index fab4b7264c738f..88e8f0dc8d58cb 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -60,6 +60,10 @@ class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { private: void generate() override; + // Split large codegen into smaller helpers to satisfy clang-tidy limits + void gen_minimal_kernel(); + void gen_optimized_kernel(); + jit_fn ker_{nullptr}; }; diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp index ac116c9e37e674..7bb989c1f041e5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -4,6 +4,11 @@ #include "nodes/executors/aarch64/jit_conv3d_f32.hpp" +#include +#include +#include +#include + #include #include #include @@ -344,8 +349,8 @@ void JitConv3DExecutorF32::ensure_weights_packed(const MemoryArgs& memory) { const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; m_padded_C = (C + 3) / 4 * 4; const size_t total = OC * KD * KH * KW * m_padded_C; - m_wei_packed.assign(total, 0.0f); - const float* wsrc = reinterpret_cast(wei->getData()); + m_wei_packed.assign(total, 0.0F); + const auto* wsrc = reinterpret_cast(wei->getData()); auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; @@ -394,9 +399,9 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0; const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ? 
m_attrs.paddingL[2] : 0; - const float* src_p = reinterpret_cast(src->getData()); - const float* wei_p = reinterpret_cast(wei->getData()); - float* dst_p = reinterpret_cast(dst->getData()); + const auto* src_p = reinterpret_cast(src->getData()); + const auto* wei_p = reinterpret_cast(wei->getData()); + auto* dst_p = reinterpret_cast(dst->getData()); auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { return (((n * C + c) * ID + z) * IH + y) * IW + x; @@ -429,7 +434,7 @@ void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { for (size_t ow = 0; ow < OW; ++ow) { const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; - float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; if (SD == 1 && SH == 1 && SW == 1) { const ptrdiff_t kz_lo = std::max(0, -iz0); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index c929b3269e9f30..e43a0b4de4b4b2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -4,12 +4,13 @@ #include "nodes/executors/aarch64/jit_deconv3d.hpp" +#include #include #include #include #include "cpu_memory.h" -#include "memory_desc/cpu_memory_desc.h" +#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/core/type/float16.hpp" @@ -51,7 +52,7 @@ void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector(0)); - const uint16_t* wsrc = reinterpret_cast(src[1]->getData()); + const auto* wsrc = reinterpret_cast(src[1]->getData()); auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; @@ -88,8 +89,8 @@ void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector(src[1]->getData()); + m_wei_packed_f32.assign(total, 0.0F); + const auto* wsrc = reinterpret_cast(src[1]->getData()); auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; @@ -179,7 +180,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st { for (size_t oh = 0; oh < OH; ++oh) { for (size_t ow_ = 0; ow_ < OW; ++ow_) { - float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; if (SD == 1 && SH == 1 && SW == 1) { // Fast path: contiguous tap ranges, no modulus checks @@ -316,8 +317,8 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st continue; if (iz_num % static_cast(SD) != 0) continue; - const ptrdiff_t id = iz_num / static_cast(SD); - if (id < 0 || id >= static_cast(ID)) + const ptrdiff_t id_idx = iz_num / static_cast(SD); + if (id_idx < 0 || id_idx >= static_cast(ID)) continue; for (size_t ky = 0; ky < KH; ++ky) { const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); @@ -341,7 +342,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t s_base0 = idx_src(n, 0, - static_cast(id), + static_cast(id_idx), static_cast(ihh), static_cast(iww)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); @@ -382,23 +383,23 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { const auto& bprec = 
src[2]->getPrecision(); if (bprec == ov::element::f32) { - const float* b = reinterpret_cast(src[2]->getData()); - acc0 += b[oc0]; + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += bias_ptr[oc0]; if (has_oc1) - acc1 += b[oc1]; + acc1 += bias_ptr[oc1]; if (has_oc2) - acc2 += b[oc2]; + acc2 += bias_ptr[oc2]; if (has_oc3) - acc3 += b[oc3]; + acc3 += bias_ptr[oc3]; } else if (bprec == ov::element::f16) { - const uint16_t* b = reinterpret_cast(src[2]->getData()); - acc0 += static_cast(ov::float16(b[oc0])); + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(bias_ptr[oc0])); if (has_oc1) - acc1 += static_cast(ov::float16(b[oc1])); + acc1 += static_cast(ov::float16(bias_ptr[oc1])); if (has_oc2) - acc2 += static_cast(ov::float16(b[oc2])); + acc2 += static_cast(ov::float16(bias_ptr[oc2])); if (has_oc3) - acc3 += static_cast(ov::float16(b[oc3])); + acc3 += static_cast(ov::float16(bias_ptr[oc3])); } } @@ -438,9 +439,9 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const ptrdiff_t PH0 = deconvAttrs.paddingL.size() > 1 ? deconvAttrs.paddingL[1] : 0; const ptrdiff_t PW0 = deconvAttrs.paddingL.size() > 2 ? deconvAttrs.paddingL[2] : 0; - const float* src_p = reinterpret_cast(src[0]->getData()); - const float* wei_p = reinterpret_cast(src[1]->getData()); - float* dst_p = reinterpret_cast(dst[0]->getData()); + const auto* src_p = reinterpret_cast(src[0]->getData()); + const auto* wei_p = reinterpret_cast(src[1]->getData()); + auto* dst_p = reinterpret_cast(dst[0]->getData()); auto idx_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { return (((n * IC + c) * ID + z) * IH + y) * IW + x; @@ -468,94 +469,94 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st for (size_t od = 0; od < OD; ++od) { for (size_t oh = 0; oh < OH; ++oh) { for (size_t ow_ = 0; ow_ < OW; ++ow_) { - float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; + float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; if (SD == 1 && SH == 1 && SW == 1) { // contiguous tap range in each dimension - const ptrdiff_t tz = static_cast(od) + PD0; - const ptrdiff_t ty = static_cast(oh) + PH0; - const ptrdiff_t tx = static_cast(ow_) + PW0; - const ptrdiff_t kz_lo = std::max(0, tz - static_cast(ID) + 1); - const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tz); - const ptrdiff_t ky_lo = std::max(0, ty - static_cast(IH) + 1); - const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, ty); - const ptrdiff_t kx_lo = std::max(0, tx - static_cast(IW) + 1); - const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, tx); + const ptrdiff_t tz_pos = static_cast(od) + PD0; + const ptrdiff_t ty_pos = static_cast(oh) + PH0; + const ptrdiff_t tx_pos = static_cast(ow_) + PW0; + const ptrdiff_t kz_lo = std::max(0, tz_pos - static_cast(ID) + 1); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tz_pos); + const ptrdiff_t ky_lo = std::max(0, ty_pos - static_cast(IH) + 1); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, ty_pos); + const ptrdiff_t kx_lo = std::max(0, tx_pos - static_cast(IW) + 1); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, tx_pos); if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { - const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + const auto kw_count = static_cast(kx_hi - kx_lo + 1); for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { - const size_t iz = static_cast(tz - kz); - const size_t ky_base = static_cast(ky_lo); - const size_t iy0 = static_cast(ty - ky_lo); - const size_t ix0 = 
static_cast(tx - kx_lo); + const auto iz_idx = static_cast(tz_pos - kz); + const auto ky_base = static_cast(ky_lo); + const auto iy0 = static_cast(ty_pos - ky_lo); + const auto ix0 = static_cast(tx_pos - kx_lo); for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t iy = static_cast(ty - ky); - const size_t ix = ix0; + const size_t iy_idx = static_cast(ty_pos - ky); + const size_t ix_idx = ix0; (void)iy0; (void)ky_base; - const size_t s_base = idx_src(n, 0, iz, iy, ix); + const size_t s_base = idx_src(n, 0, iz_idx, iy_idx, ix_idx); // pair 0 { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = IC / 4; - a.tail = IC % 4; - a.kw_cnt = kw_count; - a.src_dx = sizeof(float); + jit_conv3d_f32_call_args args{}; + args.src = src_p + s_base; + args.src_stride = src_c_stride_elems * sizeof(float); + args.src_blk_stride = args.src_stride * 4; + args.acc = &acc0; + args.acc2 = has_oc1 ? &acc1 : nullptr; + args.repeats = IC / 4; + args.tail = IC % 4; + args.kw_cnt = kw_count; + args.src_dx = sizeof(float); const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - a.wei = m_wei_packed_f32.data() + base0; + args.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - a.wei2 = m_wei_packed_f32.data() + base1; + args.wei2 = m_wei_packed_f32.data() + base1; } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = m_padded_IC_f32 * sizeof(float); - (*m_ip_kernel_f32)(&a); + args.wei_stride = sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; + args.wei_dx = m_padded_IC_f32 * sizeof(float); + (*m_ip_kernel_f32)(&args); } // pair 1 if (has_oc2) { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = IC / 4; - a.tail = IC % 4; - a.kw_cnt = kw_count; - a.src_dx = sizeof(float); + jit_conv3d_f32_call_args args2{}; + args2.src = src_p + s_base; + args2.src_stride = src_c_stride_elems * sizeof(float); + args2.src_blk_stride = args2.src_stride * 4; + args2.acc = &acc2; + args2.acc2 = has_oc3 ? 
&acc3 : nullptr; + args2.repeats = IC / 4; + args2.tail = IC % 4; + args2.kw_cnt = kw_count; + args2.src_dx = sizeof(float); const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - a.wei = m_wei_packed_f32.data() + base2; + args2.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - a.wei2 = m_wei_packed_f32.data() + base3; + args2.wei2 = m_wei_packed_f32.data() + base3; } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = m_padded_IC_f32 * sizeof(float); - (*m_ip_kernel_f32)(&a); + args2.wei_stride = sizeof(float); + args2.wei_blk_stride = args2.wei_stride * 4; + args2.wei_dx = m_padded_IC_f32 * sizeof(float); + (*m_ip_kernel_f32)(&args2); } } } @@ -568,8 +569,8 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st continue; if (iz_num % static_cast(SD) != 0) continue; - const ptrdiff_t id = iz_num / static_cast(SD); - if (id < 0 || id >= static_cast(ID)) + const ptrdiff_t id_idx = iz_num / static_cast(SD); + if (id_idx < 0 || id_idx >= static_cast(ID)) continue; for (size_t ky = 0; ky < KH; ++ky) { const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); @@ -593,39 +594,39 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t s_base0 = idx_src(n, 0, - static_cast(id), + static_cast(id_idx), static_cast(ihh), static_cast(iww)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, kz, ky, kx) : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = IC / 4; - a.tail = IC % 4; + jit_conv3d_f32_call_args args{}; + args.src = src_p + s_base0; + args.src_stride = src_c_stride_elems * sizeof(float); + args.src_blk_stride = args.src_stride * 4; + args.acc = &acc0; + args.acc2 = has_oc1 ? 
&acc1 : nullptr; + args.repeats = IC / 4; + args.tail = IC % 4; if (m_wei_packed_ready_f32) { const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a.wei = m_wei_packed_f32.data() + pack_base0; + args.wei = m_wei_packed_f32.data() + pack_base0; if (has_oc1) { const size_t pack_base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a.wei2 = m_wei_packed_f32.data() + pack_base1; + args.wei2 = m_wei_packed_f32.data() + pack_base1; } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; + args.wei_stride = sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; } else { - a.wei = wei_p + w_base0; + args.wei = wei_p + w_base0; if (has_oc1) - a.wei2 = wei_p + w_base1; - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; + args.wei2 = wei_p + w_base1; + args.wei_stride = wei_ic_stride_elems * sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; } - (*m_ip_kernel_f32)(&a); + (*m_ip_kernel_f32)(&args); } } } @@ -634,23 +635,23 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { const auto& bprec = src[2]->getPrecision(); if (bprec == ov::element::f32) { - const float* b = reinterpret_cast(src[2]->getData()); - acc0 += b[oc0]; + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += bias_ptr[oc0]; if (has_oc1) - acc1 += b[oc1]; + acc1 += bias_ptr[oc1]; if (has_oc2) - acc2 += b[oc2]; + acc2 += bias_ptr[oc2]; if (has_oc3) - acc3 += b[oc3]; + acc3 += bias_ptr[oc3]; } else if (bprec == ov::element::f16) { - const uint16_t* b = reinterpret_cast(src[2]->getData()); - acc0 += static_cast(ov::float16(b[oc0])); + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0 += static_cast(ov::float16(bias_ptr[oc0])); if (has_oc1) - acc1 += static_cast(ov::float16(b[oc1])); + acc1 += static_cast(ov::float16(bias_ptr[oc1])); if (has_oc2) - acc2 += static_cast(ov::float16(b[oc2])); + acc2 += static_cast(ov::float16(bias_ptr[oc2])); if (has_oc3) - acc3 += static_cast(ov::float16(b[oc3])); + acc3 += static_cast(ov::float16(bias_ptr[oc3])); } } @@ -677,11 +678,13 @@ bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, dstDescs[0]->getShape().getRank() != 5) { return false; } - const auto s0 = srcDescs[0]->getPrecision(); - const auto s1 = srcDescs[1]->getPrecision(); - const auto d0 = dstDescs[0]->getPrecision(); - const bool fp16_ok = (s0 == ov::element::f16 && s1 == ov::element::f16 && d0 == ov::element::f16); - const bool fp32_ok = (s0 == ov::element::f32 && s1 == ov::element::f32 && d0 == ov::element::f32); + const auto src0_prec = srcDescs[0]->getPrecision(); + const auto src1_prec = srcDescs[1]->getPrecision(); + const auto dst0_prec = dstDescs[0]->getPrecision(); + const bool fp16_ok = + (src0_prec == ov::element::f16 && src1_prec == ov::element::f16 && dst0_prec == ov::element::f16); + const bool fp32_ok = + (src0_prec == ov::element::f32 && src1_prec == ov::element::f32 && dst0_prec == ov::element::f32); return fp16_ok || fp32_ok; } From 787e3f093e628476159d719f51873dd500785086 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Tue, 21 Oct 2025 18:29:00 +0200 Subject: [PATCH 07/20] Add early weight preparation and alternative packing for S=2 in AArch64 JIT 3D Deconvolution Executors to optimize initialization and grouped workloads. 
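With stride 2 along W, a given output column receives contributions only
from kernel taps kx of a single parity, since iw = (ow + PW0 - kx) / 2 must
be an integer. The alternative S=2 layout therefore stores, for each
(oc, kz, ky) row, all even-kx taps first and all odd-kx taps second, so the
taps that actually fire for one output read as one contiguous run. A minimal
sketch of the repack loop, assuming a hypothetical copy_tap() helper that
stores one IC-padded channel block at the given packed position:

    size_t pos = 0;
    for (size_t kx = 0; kx < KW; kx += 2, ++pos)  // even taps first
        copy_tap(oc, kz, ky, kx, /*packed_pos=*/pos);
    for (size_t kx = 1; kx < KW; kx += 2, ++pos)  // then odd taps
        copy_tap(oc, kz, ky, kx, /*packed_pos=*/pos);

The same scheme is applied per group for 6-D grouped weights, with the
per-group input-channel count padded to the vector width.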
--- src/plugins/intel_cpu/src/nodes/deconv.cpp | 17 + .../nodes/executors/aarch64/jit_conv3d.cpp | 18 +- .../nodes/executors/aarch64/jit_conv3d.hpp | 3 + .../executors/aarch64/jit_conv3d_f32.cpp | 9 + .../nodes/executors/aarch64/jit_deconv3d.cpp | 2238 +++++++++++++++-- .../nodes/executors/aarch64/jit_deconv3d.hpp | 10 + 6 files changed, 2051 insertions(+), 244 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index 658ab78bf0e959..cffcea2f6b45a1 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -39,6 +39,10 @@ #include "nodes/common/blocked_desc_creator.h" #include "nodes/common/dnnl_executor.h" #include "nodes/executors/deconv_list.hpp" +#include "utils/arch_macros.h" +#if defined(OPENVINO_ARCH_ARM64) +# include "nodes/executors/aarch64/jit_deconv3d.hpp" +#endif #include "nodes/executors/executor.hpp" #include "nodes/node_config.h" #include "onednn/dnnl.h" @@ -1134,6 +1138,19 @@ void Deconvolution::prepareParams() { execPtr = result.first; OPENVINO_ASSERT(execPtr, "Primitive descriptor was not found for node ", getName(), "."); +#if defined(OPENVINO_ARCH_ARM64) + // Early weight packing for AArch64 JIT to minimize first-inference latency + if (auto jitExec = std::dynamic_pointer_cast(execPtr)) { + std::vector srcMemories; + // src[0] = input, src[1] = weights, src[2] = bias (optional) + srcMemories.push_back(getSrcMemoryAtPort(0)); + srcMemories.push_back(getSrcMemoryAtPort(1)); + // Bias not needed for packing + jitExec->prepare_weights_early(srcMemories); + } +#endif + + primArgs[DNNL_ARG_SRC] = srcMemPtr->getPrimitive(); primArgs[DNNL_ARG_DST] = dstMemPtr->getPrimitive(); if (weightIsConst) { diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 2a55a7ed834d7b..3ee52439221675 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -721,8 +721,10 @@ void JitConv3DKernelF16::gen_optimized_kernel() { ldr(reg_src_dy, ptr(reg_args, 152)); // src dy step in bytes (y dimension) ldr(reg_wei_dy, ptr(reg_args, 160)); // wei dy step in bytes (y dimension) - // Force single-ky iteration and disable quad-OC for stability on macOS arm64 - mov(reg_kh_cnt, 1); + // Optionally force single-ky iteration for stability on certain platforms + if (m_force_single_kh_) { + mov(reg_kh_cnt, 1); + } eor(reg_acc4, reg_acc4, reg_acc4); Label Lsingle, Lend_all; @@ -1667,7 +1669,7 @@ JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, } bool JitConv3DExecutor::supports(const ConvConfig& cfg) { - // Require 5D NCDHW, FP16/FP32 src/wei/dst, group=1, no dilation, stride 1 or 2 + // Require 5D NCDHW, FP16 or FP32 src/wei/dst, group=1, no dilation, stride 1 or 2 if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) return false; if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) @@ -2398,14 +2400,14 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { } } - // Store (convert FP32 accumulators to FP16 bits) - dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits(); + // Store FP32 accumulators directly to FP32 destination + dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; if (has_oc1) - dst_p[index_dst(n, oc1, od, oh, ow)] = ov::float16(acc1).to_bits(); + dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; if (has_oc2) - 
dst_p[index_dst(n, oc2, od, oh, ow)] = ov::float16(acc2).to_bits(); + dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; if (has_oc3) - dst_p[index_dst(n, oc3, od, oh, ow)] = ov::float16(acc3).to_bits(); + dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; } } } diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp index 88e8f0dc8d58cb..0e4b7189850476 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -65,6 +65,9 @@ class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { void gen_optimized_kernel(); jit_fn ker_{nullptr}; + bool m_force_single_kh_{true}; +public: + void set_force_single_kh(bool v) { m_force_single_kh_ = v; } }; // AArch64 JIT Convolution (FP16) executor for 3D conv (NCDHW) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp index 7bb989c1f041e5..8b58711cc34cae 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -141,7 +141,12 @@ void JitConv3DKernelF32::generate() { add(reg_wei2, reg_wei2, reg_wei_stride); ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); + // advance to next 4-channel block for next repeat + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); L(Lw_done_d); + // advance src to next 4-channel block for next repeat + add(reg_src, reg_src, reg_src_stride); // MAC fmla(VReg4S(20), VReg4S(0), VReg4S(1)); fmla(VReg4S(21), VReg4S(0), VReg4S(2)); @@ -245,7 +250,11 @@ void JitConv3DKernelF32::generate() { ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).s[3], ptr(reg_wei)); + // advance to next 4-channel block for next repeat + add(reg_wei, reg_wei, reg_wei_stride); L(Lw_done_s); + // advance src to next 4-channel block for next repeat + add(reg_src, reg_src, reg_src_stride); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); sub(reg_reps, reg_reps, 1); b(Lrep_s); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index e43a0b4de4b4b2..8a0f06dceec12a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -34,6 +34,10 @@ bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, m_ip_kernel_f32->create_ker(); } else { m_ip_kernel_f16 = std::make_unique(); + // Allow enabling in-kernel ky loop via env for FP16 + if (std::getenv("OV_AARCH64_DECONV3D_KY_LOOP_F16") && std::string(std::getenv("OV_AARCH64_DECONV3D_KY_LOOP_F16")) == "1") { + m_ip_kernel_f16->set_force_single_kh(false); + } m_ip_kernel_f16->create_ker(); } return true; @@ -42,78 +46,305 @@ bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, void JitDeconv3DExecutor::ensure_weights_packed_f16(const std::vector& src) { if (m_wei_packed_ready_f16) return; - // src[1] holds weights for deconv with shape [IC, OC, KD, KH, KW] + // src[1] holds weights for deconv with shape: + // - no-group: [IC, OC, KD, KH, KW] + // - group: [G, ICg, OCg, KD, KH, KW] const auto& weiDims = src[1]->getStaticDims(); - if (weiDims.size() != 5) - return; - const size_t IC = weiDims[0]; - const size_t OC = weiDims[1]; - const size_t KD = weiDims[2], KH = weiDims[3], 
KW = weiDims[4]; - m_padded_IC_f16 = (IC + 7) / 8 * 8; - const size_t total = OC * KD * KH * KW * m_padded_IC_f16; - m_wei_packed_f16.assign(total, static_cast(0)); const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f16 = (IC + 7) / 8 * 8; + const size_t total = OC * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_f16.assign(total, static_cast(0)); - auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; - }; - auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; - const size_t blk = ic / 8; - const size_t lane = ic % 8; - return base + blk * 8 + lane; - }; + auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t blk = ic / 8; + const size_t lane = ic % 8; + return base + blk * 8 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_f16[idx_wei_pack(oc, ic, kz, ky, kx)] = + wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f16 = true; + return; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f16 = (ICg + 7) / 8 * 8; // per-group padding + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_f16.assign(total, static_cast(0)); - for (size_t oc = 0; oc < OC; ++oc) { - for (size_t kz = 0; kz < KD; ++kz) { - for (size_t ky = 0; ky < KH; ++ky) { - for (size_t kx = 0; kx < KW; ++kx) { - for (size_t ic = 0; ic < IC; ++ic) { - m_wei_packed_f16[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + auto idx_wei_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) -> size_t { + // layout [G, ICg, OCg, KD, KH, KW] + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + auto idx_wei_pack = [&](size_t oc_global, size_t icg, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t blk = icg / 8; + const size_t lane = icg % 8; + return base + blk * 8 + lane; + }; + + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_f16[idx_wei_pack(oc_global, icg, kz, ky, kx)] = + wsrc[idx_wei_src_g(g, icg, ocg, kz, ky, kx)]; + } + } } } } } + m_wei_packed_ready_f16 = true; + return; } - m_wei_packed_ready_f16 = true; } void JitDeconv3DExecutor::ensure_weights_packed_f32(const std::vector& src) { if (m_wei_packed_ready_f32) return; const 
auto& weiDims = src[1]->getStaticDims(); - if (weiDims.size() != 5) - return; - const size_t IC = weiDims[0]; - const size_t OC = weiDims[1]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; - m_padded_IC_f32 = (IC + 3) / 4 * 4; - const size_t total = OC * KD * KH * KW * m_padded_IC_f32; - m_wei_packed_f32.assign(total, 0.0F); const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f32 = (IC + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_f32.assign(total, 0.0F); - auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; - }; - auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - const size_t blk = ic / 4; - const size_t lane = ic % 4; - return base + blk * 4 + lane; - }; + auto idx_wei_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + auto idx_wei_pack = [&](size_t oc, size_t ic, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t blk = ic / 4; + const size_t lane = ic % 4; + return base + blk * 4 + lane; + }; + + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_f32[idx_wei_pack(oc, ic, kz, ky, kx)] = + wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_ready_f32 = true; + return; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f32 = (ICg + 3) / 4 * 4; + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_f32.assign(total, 0.0F); + + auto idx_wei_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) -> size_t { + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + auto idx_wei_pack = [&](size_t oc_global, size_t icg, size_t kz, size_t ky, size_t kx) -> size_t { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t blk = icg / 4; + const size_t lane = icg % 4; + return base + blk * 4 + lane; + }; + + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + for (size_t kx = 0; kx < KW; ++kx) { + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_f32[idx_wei_pack(oc_global, icg, kz, ky, kx)] = + wsrc[idx_wei_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + } + m_wei_packed_ready_f32 = true; + return; + } +} + +// Alternative even/odd packing for S=2 (FP16) +void JitDeconv3DExecutor::ensure_weights_packed_s2_f16(const std::vector& src) { + if (m_wei_packed_s2_ready_f16) + return; + const auto& weiDims = src[1]->getStaticDims(); + const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t 
IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f16 = (IC + 7) / 8 * 8; + const size_t total = OC * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_s2_f16.assign(total, static_cast(0)); + auto idx_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { + return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + }; + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + // evens + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_s2_f16[base + (ic / 8) * 8 + (ic % 8)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; + } + } + // odds + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t ic = 0; ic < IC; ++ic) { + m_wei_packed_s2_f16[base + (ic / 8) * 8 + (ic % 8)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_s2_ready_f16 = true; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f16 = (ICg + 7) / 8 * 8; + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f16; + m_wei_packed_s2_f16.assign(total, static_cast(0)); + auto idx_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) { + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_s2_f16[base + (icg / 8) * 8 + (icg % 8)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f16; + for (size_t icg = 0; icg < ICg; ++icg) { + m_wei_packed_s2_f16[base + (icg / 8) * 8 + (icg % 8)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + } + m_wei_packed_s2_ready_f16 = true; + } +} - for (size_t oc = 0; oc < OC; ++oc) { - for (size_t kz = 0; kz < KD; ++kz) { - for (size_t ky = 0; ky < KH; ++ky) { - for (size_t kx = 0; kx < KW; ++kx) { - for (size_t ic = 0; ic < IC; ++ic) { - m_wei_packed_f32[idx_wei_pack(oc, ic, kz, ky, kx)] = wsrc[idx_wei_src(ic, oc, kz, ky, kx)]; +// Alternative even/odd packing for S=2 (FP32) +void JitDeconv3DExecutor::ensure_weights_packed_s2_f32(const std::vector& src) { + if (m_wei_packed_s2_ready_f32) + return; + const auto& weiDims = src[1]->getStaticDims(); + const auto* wsrc = reinterpret_cast(src[1]->getData()); + if (weiDims.size() == 5) { + const size_t IC = weiDims[0]; + const size_t OC = weiDims[1]; + const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + m_padded_IC_f32 = (IC + 3) / 4 * 4; + const size_t total = OC * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_s2_f32.assign(total, 0.0F); + auto idx_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { + return ((((ic)*OC + oc) * KD + kz) * KH 
+ ky) * KW + kx; + }; + for (size_t oc = 0; oc < OC; ++oc) { + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; + for (size_t ic = 0; ic < IC; ++ic) + m_wei_packed_s2_f32[base + (ic / 4) * 4 + (ic % 4)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; + } + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; + for (size_t ic = 0; ic < IC; ++ic) + m_wei_packed_s2_f32[base + (ic / 4) * 4 + (ic % 4)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; } } } } + m_wei_packed_s2_ready_f32 = true; + } else if (weiDims.size() == 6) { + const size_t G = weiDims[0]; + const size_t ICg = weiDims[1]; + const size_t OCg = weiDims[2]; + const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; + const size_t OC_total = G * OCg; + m_padded_IC_f32 = (ICg + 3) / 4 * 4; + const size_t total = OC_total * KD * KH * KW * m_padded_IC_f32; + m_wei_packed_s2_f32.assign(total, 0.0F); + auto idx_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) { + return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); + }; + for (size_t g = 0; g < G; ++g) { + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + for (size_t kz = 0; kz < KD; ++kz) { + for (size_t ky = 0; ky < KH; ++ky) { + size_t pos = 0; + for (size_t kx = 0; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; + for (size_t icg = 0; icg < ICg; ++icg) + m_wei_packed_s2_f32[base + (icg / 4) * 4 + (icg % 4)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + for (size_t kx = 1; kx < KW; kx += 2, ++pos) { + const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; + for (size_t icg = 0; icg < ICg; ++icg) + m_wei_packed_s2_f32[base + (icg / 4) * 4 + (icg % 4)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; + } + } + } + } + } + m_wei_packed_s2_ready_f32 = true; } - m_wei_packed_ready_f32 = true; } void JitDeconv3DExecutor::exec(const std::vector& src, @@ -126,6 +357,50 @@ void JitDeconv3DExecutor::exec(const std::vector& src, } } +// helper toggles declared below + +static inline bool deconv3d_pack_enabled() { return true; } + +static inline bool deconv3d_fastpath_f16_enabled() { return true; } +static inline bool deconv3d_fastpath_f32_enabled() { return true; } + +static inline bool deconv3d_fastpath_f32_s2_enabled() { return true; } + +static inline bool deconv3d_kyloop_f16_enabled() { return false; } + +static inline bool deconv3d_s2_grouped_enabled() { return true; } + +static inline bool deconv3d_tile2_enabled() { return true; } + +static inline bool deconv3d_tile2_f32_enabled() { return false; } + +static inline bool deconv3d_pack_s2_enabled() { return true; } + +static inline bool deconv3d_prefetch_enabled() { return true; } + +static inline bool deconv3d_force_ref() { return false; } + +static inline bool deconv3d_s2_kxloop_f16_enabled() { return false; } + +void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { + // Pack weights (and S=2 even/odd layout) ahead of first exec to reduce cold-start latency + if (m_is_fp32) { + if (deconv3d_pack_enabled()) { + ensure_weights_packed_f32(src); + if (deconv3d_pack_s2_enabled()) { + ensure_weights_packed_s2_f32(src); + } + } + } else { + if (deconv3d_pack_enabled()) { + ensure_weights_packed_f16(src); + if 
(deconv3d_pack_s2_enabled()) { + ensure_weights_packed_s2_f16(src); + } + } + } +} + void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const std::vector& dst) { // NCDHW, fp16: compute each output pixel (n, oc, od, oh, ow) as a sum over (ic, kz, ky, kx) const auto& srcDims = src[0]->getStaticDims(); @@ -135,9 +410,14 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t N = srcDims[0]; const size_t IC = srcDims[1]; const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; - // Deconv weights layout: [IC, OC, KD, KH, KW] - const size_t OC = weiDims[1]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + // Use OC from dst to support grouped weights layout + const size_t OC = dstDims[1]; + // Weights: no-group [IC, OC, KD, KH, KW]; grouped [G, ICg, OCg, KD, KH, KW] + const bool grouped = weiDims.size() == 6; + [[maybe_unused]] const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? weiDims[1] : IC; + const size_t OCg = grouped ? weiDims[2] : OC; + const size_t KD = weiDims[grouped ? 3 : 2], KH = weiDims[grouped ? 4 : 3], KW = weiDims[grouped ? 5 : 4]; const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; const size_t SD = deconvAttrs.stride.size() > 0 ? static_cast(deconvAttrs.stride[0]) : 1; @@ -158,31 +438,117 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { return (((n * OC + c) * OD + z) * OH + y) * OW + x; }; - // weight [IC, OC, KD, KH, KW] - auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { - return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + // weight: no-group [IC, OC, KD, KH, KW]; grouped [G, ICg, OCg, KD, KH, KW] + auto idx_wei = [&](size_t ic_or_icg, size_t oc_global, size_t kz, size_t ky, size_t kx) { + if (!grouped) { + return ((((ic_or_icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + const size_t g = oc_global / OCg; + const size_t ocg = oc_global % OCg; + return ((((((g * ICg + ic_or_icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); }; // Strides in elements const size_t src_c_stride_elems = ID * IH * IW; - const size_t wei_ic_stride_elems = OC * KD * KH * KW; + const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; + + if (deconv3d_pack_enabled()) { + ensure_weights_packed_f16(src); + if (deconv3d_pack_s2_enabled()) { + ensure_weights_packed_s2_f16(src); + } + } + const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; + const ptrdiff_t OPH0 = deconvAttrs.outputPadding.size() > 1 ? deconvAttrs.outputPadding[1] : 0; + const ptrdiff_t OPW0 = deconvAttrs.outputPadding.size() > 2 ? deconvAttrs.outputPadding[2] : 0; + + // Effective dilations are stored as (dilation - 1) inside attrs; convert to actual factors + const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; + const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; + const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; + + // Reference-correct fallback (env: OV_AARCH64_DECONV3D_REF=1) + if (deconv3d_force_ref()) { + const bool grouped = weiDims.size() == 6; + const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? weiDims[1] : IC; + const size_t OCg = grouped ? 
weiDims[2] : OC; + std::fill_n(dst_p, N * OC * OD * OH * OW, static_cast(0)); + for (size_t n = 0; n < N; ++n) { + for (size_t g = 0; g < G; ++g) { + for (size_t id = 0; id < ID; ++id) { + for (size_t ih = 0; ih < IH; ++ih) { + for (size_t iw_ = 0; iw_ < IW; ++iw_) { + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t od = static_cast(id) * static_cast(SD) - PD0 + + static_cast(kz * dilD) + OPD0; + if (od < 0 || od >= static_cast(OD)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t oh = static_cast(ih) * static_cast(SH) - PH0 + + static_cast(ky * dilH) + OPH0; + if (oh < 0 || oh >= static_cast(OH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ow = static_cast(iw_) * static_cast(SW) - + PW0 + static_cast(kx * dilW) + OPW0; + if (ow < 0 || ow >= static_cast(OW)) + continue; + for (size_t icg = 0; icg < ICg; ++icg) { + const size_t ic_global = g * ICg + icg; + const float sval = static_cast( + ov::float16(src_p[idx_src(n, ic_global, id, ih, iw_)])); + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + size_t w_off; + if (grouped) { + // layout [G, ICg, OCg, KD, KH, KW] + w_off = + (((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW + + kx); + } else { + // layout [IC, OC, KD, KH, KW] + w_off = ((((icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + const float w = static_cast(ov::float16(wei_p[w_off])); + const size_t doff = idx_dst(n, + oc_global, + static_cast(od), + static_cast(oh), + static_cast(ow)); + const float accum = + sval * w + static_cast(ov::float16(dst_p[doff])); + dst_p[doff] = ov::float16(accum).to_bits(); + } + } + } + } + } + } + } + } + } + } + return; + } - ensure_weights_packed_f16(src); auto worker = [&](size_t n, size_t oc_quad, size_t od) { const size_t oc0 = oc_quad * 4; + const size_t g = OCg ? (oc0 / OCg) : 0; + const size_t ocg0 = OCg ? 
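The reference fallback above scatters every input pixel through every tap (od = id*SD - PD0 + kz*dilD + OPD0), while the JIT paths gather: for a fixed output coordinate they invert that relation per tap. A 1-D sketch of the inversion, with output padding omitted as in the generic gather path (gather_src_index is an illustrative helper):

    #include <cstddef>

    // Solve o = i*S - P + k*dil for the input index i; a tap contributes only
    // when the division is exact and i falls inside the input extent I.
    static bool gather_src_index(ptrdiff_t o, ptrdiff_t k, ptrdiff_t S, ptrdiff_t P,
                                 ptrdiff_t dil, ptrdiff_t I, ptrdiff_t& i_out) {
        const ptrdiff_t num = o + P - k * dil;
        if (num % S != 0)
            return false;  // this tap never lands on output position o
        i_out = num / S;
        return i_out >= 0 && i_out < I;
    }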
(oc0 % OCg) : oc0; const size_t oc1 = oc0 + 1; const size_t oc2 = oc0 + 2; const size_t oc3 = oc0 + 3; - const bool has_oc1 = oc1 < OC; - const bool has_oc2 = oc2 < OC; - const bool has_oc3 = oc3 < OC; + const bool has_oc1 = (ocg0 + 1) < OCg && oc1 < OC; + const bool has_oc2 = (ocg0 + 2) < OCg && oc2 < OC; + const bool has_oc3 = (ocg0 + 3) < OCg && oc3 < OC; const size_t n_base = n * IC * ID * IH * IW; { for (size_t oh = 0; oh < OH; ++oh) { for (size_t ow_ = 0; ow_ < OW; ++ow_) { float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; - if (SD == 1 && SH == 1 && SW == 1) { + if (deconv3d_fastpath_f16_enabled() && SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { // Fast path: contiguous tap ranges, no modulus checks const ptrdiff_t tzd = static_cast(od) + PD0; const ptrdiff_t tyd = static_cast(oh) + PH0; @@ -199,55 +565,225 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { const size_t id = static_cast(tzd - kz); const size_t src_z_off = id * IH * IW; - const size_t kh_count = static_cast(ky_hi - ky_lo + 1); - const size_t ihh = static_cast(tyd - ky_lo); - const size_t src_y_off = ihh * IW; - size_t s_base_row = n_base + src_z_off + src_y_off; - const size_t kw_count = static_cast(kx_hi - kx_lo + 1); + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + (void)kx_hi; + (void)kx_lo; if (m_wei_packed_ready_f16) { - const size_t s_base0 = s_base_row + static_cast(txd - kx_lo); - // Compute packed bases for ky_lo - size_t pack_base_z0 = (oc0 * KD + static_cast(kz)) * KH; - size_t pack_base_z1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; - // oc2/oc3 computed in second dual call; no need for precomputed bases - size_t pack_base_y0 = (pack_base_z0 + static_cast(ky_lo)) * KW; - size_t pack_base_y1 = - has_oc1 ? (pack_base_z1 + static_cast(ky_lo)) * KW : 0; - // oc2/oc3 will be handled in the second dual call below - const size_t pack_base0 = - (pack_base_y0 + static_cast(kx_lo)) * m_padded_IC_f16; - const size_t pack_base1 = - has_oc1 ? (pack_base_y1 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - // Compute only oc0/oc1 in this call; oc2/oc3 will be handled by a second dual call - a.repeats = IC / 8; - a.tail = IC % 8; - a.kw_cnt = kw_count; - a.kh_cnt = kh_count; - a.src_dx = sizeof(uint16_t); - a.src_dy = IW * sizeof(uint16_t); - a.wei = m_wei_packed_f16.data() + pack_base0; - if (has_oc1) - a.wei2 = m_wei_packed_f16.data() + pack_base1; - // oc2/oc3 handled in a follow-up dual call - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); - a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); - (*m_ip_kernel_f16)(&a); - } else { - // Generic ky+kx loops (not packed) + if (deconv3d_kyloop_f16_enabled()) { + // In-kernel ky + kx (packed weights): + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + const auto kh_count = static_cast(ky_hi - ky_lo + 1); + const size_t s_base_x0 = s_base_row + static_cast(tyd - ky_lo) * IW + static_cast(txd); + // Start from rightmost tap for positive src_dx + const size_t s_base0 = s_base_x0 - static_cast(kx_hi); + // Precompute packed bases at ky_lo + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? 
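For the unit-stride fast path the gather index degenerates to id = (od + PD0) - kz, so the taps that survive the bounds check form one contiguous interval and the per-tap divisibility test disappears. A sketch of the derivation behind kz_lo/kz_hi (the helper is illustrative; the same form applies to ky and kx):

    #include <algorithm>
    #include <cstddef>
    #include <utility>

    // All kz with 0 <= (t - kz) < ID, where t = od + PD0:
    //   t - (ID - 1) <= kz <= t, clamped to the kernel extent [0, KD).
    static std::pair<ptrdiff_t, ptrdiff_t> unit_stride_tap_range(ptrdiff_t t,
                                                                 ptrdiff_t ID,
                                                                 ptrdiff_t KD) {
        const ptrdiff_t lo = std::max<ptrdiff_t>(0, t - (ID - 1));
        const ptrdiff_t hi = std::min<ptrdiff_t>(KD - 1, t);
        return {lo, hi};  // empty interval (lo > hi) means no input reaches this pixel
    }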
(oc1 * KD + static_cast(kz)) * KH : 0; + const size_t py0 = (pz0 + static_cast(ky_lo)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky_lo)) * KW : 0; + const size_t base0 = (py0 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? (py1 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.kh_cnt = kh_count; + a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); + a.wei = m_wei_packed_f16.data() + base0; + if (has_oc1) + a.wei2 = m_wei_packed_f16.data() + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + const size_t pz2 = (oc2 * KD + static_cast(kz)) * KH; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + const size_t py2 = (pz2 + static_cast(ky_lo)) * KW; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky_lo)) * KW : 0; + const size_t base2 = (py2 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? (py3 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.kh_cnt = kh_count; + a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); + a.wei = m_wei_packed_f16.data() + base2; + if (has_oc3) + a.wei2 = m_wei_packed_f16.data() + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + continue; // handled full ky range + } for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t ihh2 = static_cast(tyd - ky); - size_t s_base_row2 = n_base + src_z_off + ihh2 * IW; - size_t iww = static_cast(txd - kx_lo); - for (ptrdiff_t kx = kx_lo; kx <= kx_hi; ++kx, ++iww) { - const size_t s_base0 = s_base_row2 + iww; + const size_t ihh = static_cast(tyd - ky); + const size_t s_base_x0 = s_base_row + ihh * IW + static_cast(txd); + // Precompute ky-dependent packed bases (no kx loop in-kernel) + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + // Start from rightmost tap to keep src_dx positive (+1 element per kx) + const size_t s_base0 = s_base_x0 - static_cast(kx_hi); + // Packed weights advance by padded_IC per kx; start from leftmost kx + const size_t base0 = (py0 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? (py1 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + // pair 0: oc0/oc1 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? 
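The negative src_dy in the calls above is deliberate: the gather index is ih = tyd - ky, so each kernel-row step walks the source one image row backwards. A tiny self-checking example of that sign flip (values are arbitrary):

    #include <cassert>
    #include <cstddef>

    int main() {
        const ptrdiff_t tyd = 5, IW = 10, ky = 2;
        // Flattened offset of source row ih = tyd - ky: bumping ky by one moves
        // the pointer back by exactly IW elements, hence src_dy = -IW * sizeof(elem).
        const ptrdiff_t off0 = (tyd - ky) * IW;
        const ptrdiff_t off1 = (tyd - (ky + 1)) * IW;
        assert(off1 - off0 == -IW);
        return 0;
    }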
&acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = m_wei_packed_f16.data() + base0; + if (has_oc1) + a.wei2 = m_wei_packed_f16.data() + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1: oc2/oc3 + if (has_oc2) { + const size_t pz2 = (oc2 * KD + static_cast(kz)) * KH; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + const size_t py2 = (pz2 + static_cast(ky)) * KW; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + const size_t base2 = (py2 + static_cast(kx_lo)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? (py3 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.wei = m_wei_packed_f16.data() + base2; + if (has_oc3) + a.wei2 = m_wei_packed_f16.data() + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } else { + // Raw weights fast-path + if (deconv3d_kyloop_f16_enabled()) { + // In-kernel ky + kx (raw weights): + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + const auto kh_count = static_cast(ky_hi - ky_lo + 1); + const size_t ih0 = static_cast(tyd - ky_lo); + const size_t s_base_row2 = + n_base + (g * ICg) * src_c_stride_elems + src_z_off + ih0 * IW; + const size_t s_base0 = s_base_row2 + static_cast(txd - kx_hi); + // pair 0 + { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.kh_cnt = kh_count; + a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); + const size_t w_base0 = idx_wei(0, + oc0, + static_cast(kz), + static_cast(ky_lo), + static_cast(kx_lo)); + const size_t w_base1 = has_oc1 ? idx_wei(0, + oc1, + static_cast(kz), + static_cast(ky_lo), + static_cast(kx_lo)) + : 0; + a.wei = wei_p + w_base0; + if (has_oc1) + a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + a.wei_dy = KW * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + a.kh_cnt = kh_count; + a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); + const size_t w_base2 = idx_wei(0, + oc2, + static_cast(kz), + static_cast(ky_lo), + static_cast(kx_lo)); + const size_t w_base3 = has_oc3 ? 
idx_wei(0, + oc3, + static_cast(kz), + static_cast(ky_lo), + static_cast(kx_lo)) + : 0; + a.wei = wei_p + w_base2; + if (has_oc3) + a.wei2 = wei_p + w_base3; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + a.wei_dy = KW * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } else { + // In-kernel kx only + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t ihh2 = static_cast(tyd - ky); + const size_t s_base_row2 = n_base + (g * ICg) * src_c_stride_elems + src_z_off + ihh2 * IW; + const auto kw_count = static_cast(kx_hi - kx_lo + 1); + const size_t s_base0 = s_base_row2 + static_cast(txd - kx_hi); // pair 0 { jit_conv3d_call_args a{}; @@ -256,24 +792,17 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.src_blk_stride = a.src_stride * 8; a.acc = &acc0; a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = IC / 8; - a.tail = IC % 8; - const size_t w_base0 = idx_wei(0, - oc0, - static_cast(kz), - static_cast(ky), - static_cast(kx)); - const size_t w_base1 = has_oc1 ? idx_wei(0, - oc1, - static_cast(kz), - static_cast(ky), - static_cast(kx)) - : 0; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; a.wei = wei_p + w_base0; - if (has_oc1) - a.wei2 = wei_p + w_base1; + if (has_oc1) a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); (*m_ip_kernel_f16)(&a); } // pair 1 @@ -284,24 +813,455 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.src_blk_stride = a.src_stride * 8; a.acc = &acc2; a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = IC / 8; - a.tail = IC % 8; - const size_t w_base2 = idx_wei(0, - oc2, - static_cast(kz), - static_cast(ky), - static_cast(kx)); - const size_t w_base3 = has_oc3 ? idx_wei(0, - oc3, - static_cast(kz), - static_cast(ky), - static_cast(kx)) - : 0; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w_base3 = has_oc3 ? 
idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; a.wei = wei_p + w_base2; - if (has_oc3) - a.wei2 = wei_p + w_base3; + if (has_oc3) a.wei2 = wei_p + w_base3; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } + } + } + } + } else if (deconv3d_fastpath_f16_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && m_wei_packed_ready_f16 && (!grouped || deconv3d_s2_grouped_enabled())) { + // Fast path S=2, dil=1 (packed weights): parity-filtered taps without modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + // X2 micro-tiling over output width for stride=2: compute (ow, ow+2) together when possible + if (deconv3d_tile2_enabled() && (ow_ + 2) < OW) { + float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ + float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 + + const ptrdiff_t txd1 = static_cast(ow_ + 2) + PW0; + const ptrdiff_t kx_lo1 = std::max(0, txd1 - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi1 = std::min(static_cast(KW) - 1, txd1); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + + + + // Even/odd S=2 packing selection + const bool use_s2_pack_tile2 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; + const uint16_t* wei_pack_ptr_tile2 = use_s2_pack_tile2 ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + auto pack_index_eo_tile2 = [&](size_t py, size_t kx) { + if (!use_s2_pack_tile2) return py + kx; + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? 
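Under stride 2 the gather index is id = (t - kz) / 2, so only taps sharing the parity of t contribute and the valid window widens to 2*(ID - 1). A sketch of the parity-filtered enumeration the S=2 fast paths use (for_each_parity_tap is an illustrative helper, not part of the patch):

    #include <algorithm>
    #include <cstddef>

    // Visit all kz with (t - kz) even and 0 <= (t - kz)/2 < ID:
    //   0 <= (t - kz)/2 <= ID - 1  =>  t - 2*(ID - 1) <= kz <= t
    template <typename F>
    static void for_each_parity_tap(ptrdiff_t t, ptrdiff_t ID, ptrdiff_t KD, F&& visit) {
        const ptrdiff_t kz_lo = std::max<ptrdiff_t>(0, t - 2 * ID + 2);
        const ptrdiff_t kz_hi = std::min<ptrdiff_t>(KD - 1, t);
        for (ptrdiff_t kz = kz_lo + ((t - kz_lo) & 1); kz <= kz_hi; kz += 2)
            visit(kz, (t - kz) / 2);  // second argument is the recovered input index
    }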
(even_count + (kx / 2)) : (kx / 2)); + }; + + // Pass A: kx subset for txd (ow_) + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw0 = static_cast((txd - kx) / 2); + if (iw0 >= IW) continue; + const size_t iw1 = iw0 + 1; // for ow_+2 + const size_t s_base0 = s_base_row + ih * IW + iw0; + // pair 0 for ow_ + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0a; + a.acc2 = has_oc1 ? &acc1a : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + if (iw1 < IW) { + const size_t s_base1 = s_base0 + 1; + // pair 0 for ow_+2 + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + } + // pair 1 (oc2/oc3), ow_ + if (has_oc2) { + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2a; + a.acc2 = has_oc3 ? &acc3a : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + // pair 1 for ow_+2 + if (has_oc2 && (iw1 < IW)) { + const size_t s_base1 = s_base0 + 1; + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2b; + a.acc2 = has_oc3 ? 
&acc3b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + } + + // Pass B: extra kx subset for txd1 (ow_+2) only (complement of Pass A) + for (ptrdiff_t kx = kx_lo1 + ((txd1 - kx_lo1) & 1); kx <= kx_hi1; kx += 2) { + const ptrdiff_t iw0_tmp = (txd - kx) / 2; // may be out-of-range or wrong parity for ow_ + const bool covered_in_A = (kx >= kx_lo && kx <= kx_hi && (((txd - kx) & 1) == 0) && (iw0_tmp >= 0 && iw0_tmp < static_cast(IW))); + if (covered_in_A) continue; // already accumulated in Pass A for ow_+2 + const ptrdiff_t iw1_tmp = (txd1 - kx) / 2; + if (iw1_tmp < 0 || iw1_tmp >= static_cast(IW)) continue; + const size_t iw1 = static_cast(iw1_tmp); + const size_t s_base1 = s_base_row + ih * IW + iw1; + // pair 0 for ow_+2 only + { + const size_t base0 = pack_index_eo_tile2(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_tile2(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_tile2 + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + if (has_oc2) { + const size_t base2 = pack_index_eo_tile2(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_tile2(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2b; + a.acc2 = has_oc3 ? 
&acc3b : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_tile2 + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_tile2 + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + } + } + } + } + + // Optional fused bias for both outputs (ow_, ow_+2) + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += bias_ptr[oc0]; + if (has_oc1) acc1a += bias_ptr[oc1]; + if (has_oc2) acc2a += bias_ptr[oc2]; + if (has_oc3) acc3a += bias_ptr[oc3]; + acc0b += bias_ptr[oc0]; + if (has_oc1) acc1b += bias_ptr[oc1]; + if (has_oc2) acc2b += bias_ptr[oc2]; + if (has_oc3) acc3b += bias_ptr[oc3]; + } else if (bprec == ov::element::f16) { + const auto* bias_ptr = reinterpret_cast(src[2]->getData()); + acc0a += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1a += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2a += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3a += static_cast(ov::float16(bias_ptr[oc3])); + acc0b += static_cast(ov::float16(bias_ptr[oc0])); + if (has_oc1) acc1b += static_cast(ov::float16(bias_ptr[oc1])); + if (has_oc2) acc2b += static_cast(ov::float16(bias_ptr[oc2])); + if (has_oc3) acc3b += static_cast(ov::float16(bias_ptr[oc3])); + } + } + + // Store both outputs + dst_p[idx_dst(n, oc0, od, oh, ow_)] = ov::float16(acc0a).to_bits(); + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = ov::float16(acc1a).to_bits(); + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = ov::float16(acc2a).to_bits(); + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = ov::float16(acc3a).to_bits(); + + const size_t ow2 = ow_ + 2; + dst_p[idx_dst(n, oc0, od, oh, ow2)] = ov::float16(acc0b).to_bits(); + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow2)] = ov::float16(acc1b).to_bits(); + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow2)] = ov::float16(acc2b).to_bits(); + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow2)] = ov::float16(acc3b).to_bits(); + + ow_ += 2; // skip next two positions (we computed ow and ow+2); for-loop ++ will advance to ow+3 + continue; + } + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + if (deconv3d_s2_kxloop_f16_enabled()) { + // In-kernel kx loop for parity taps: step weights by 2*IC and src by -1 in X + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? 
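The X2 micro-tile above leans on an index identity: stride-2 outputs two positions apart share a tap parity set, and every shared tap reads the source one element to the right for the farther output. A small self-checking sketch of the identity behind Pass A and Pass B:

    #include <cassert>
    #include <cstddef>

    int main() {
        // With t = ow + PW0 and stride 2: iw(ow, kx) = (t - kx) / 2 when t - kx is
        // even, and iw(ow + 2, kx) = iw(ow, kx) + 1, so Pass A reuses each tap of ow
        // for ow+2 by shifting the source base one element right; Pass B sweeps only
        // the taps whose parity/bounds admit ow+2 but not ow.
        const ptrdiff_t t = 7, kx = 3;
        assert((t - kx) % 2 == 0);
        assert((t + 2 - kx) / 2 == (t - kx) / 2 + 1);
        return 0;
    }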
(pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + const bool use_s2_pack = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; + const uint16_t* wei_pack_ptr = use_s2_pack ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + auto pack_index_eo = [&](size_t py, size_t kx) { + if (!use_s2_pack) return py + kx; + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); + }; + const ptrdiff_t kx_start = kx_lo + ((txd - kx_lo) & 1); + const size_t iw_start = static_cast((txd - kx_start) / 2); + if (iw_start >= IW) continue; + const size_t s_base0 = s_base_row + ih * IW + iw_start; + const size_t kw_count = static_cast((kx_hi - kx_start) / 2 + 1); + // pair 0 + { + const size_t base0 = pack_index_eo(py0, static_cast(kx_start)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo(py1, static_cast(kx_start)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = static_cast(-static_cast(sizeof(uint16_t))); + a.wei = wei_pack_ptr + base0; + if (has_oc1) a.wei2 = wei_pack_ptr + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = (use_s2_pack ? m_padded_IC_f16 : 2 * m_padded_IC_f16) * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + const size_t base2 = pack_index_eo(py2, static_cast(kx_start)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo(py3, static_cast(kx_start)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = kw_count; + a.src_dx = static_cast(-static_cast(sizeof(uint16_t))); + a.wei = wei_pack_ptr + base2; + if (has_oc3) a.wei2 = wei_pack_ptr + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = (use_s2_pack ? m_padded_IC_f16 : 2 * m_padded_IC_f16) * sizeof(uint16_t); + (*m_ip_kernel_f16)(&a); + } + } + } + } else { + // Per-tap parity stepping (original path) + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + const bool use_s2_pack_orig = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; + const uint16_t* wei_pack_ptr_orig = use_s2_pack_orig ? 
m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + auto pack_index_eo_orig = [&](size_t py, size_t kx) { + if (!use_s2_pack_orig) return py + kx; + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); + }; + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw = static_cast((txd - kx) / 2); + if (iw >= IW) continue; + const size_t s_base0 = s_base_row + ih * IW + iw; + // pair 0 + { + const size_t base0 = pack_index_eo_orig(py0, static_cast(kx)) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? pack_index_eo_orig(py1, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_orig + base0; + if (has_oc1) a.wei2 = wei_pack_ptr_orig + base1; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } + (*m_ip_kernel_f16)(&a); + } + // pair 1 + if (has_oc2) { + const size_t base2 = pack_index_eo_orig(py2, static_cast(kx)) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? pack_index_eo_orig(py3, static_cast(kx)) * m_padded_IC_f16 : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_pack_ptr_orig + base2; + if (has_oc3) a.wei2 = wei_pack_ptr_orig + base3; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + if (deconv3d_prefetch_enabled()) { + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); + } (*m_ip_kernel_f16)(&a); } } @@ -309,42 +1269,137 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } + } else if (deconv3d_fastpath_f16_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && !m_wei_packed_ready_f16 && (!grouped || deconv3d_s2_grouped_enabled())) { + // Fast path S=2, dil=1 (raw weights): parity-filtered taps without modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 
2); + if (ih >= IH) continue; + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw = static_cast((txd - kx) / 2); + if (iw >= IW) continue; + const size_t s_base0 = s_base_row + ih * IW + iw; + // pair 0 (oc0/oc1) + { + const size_t w_base0 = idx_wei(0, + oc0, + static_cast(kz), + static_cast(ky), + static_cast(kx)); + const size_t w_base1 = has_oc1 ? idx_wei(0, + oc1, + static_cast(kz), + static_cast(ky), + static_cast(kx)) + : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_p + w_base0; + if (has_oc1) + a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + (*m_ip_kernel_f16)(&a); + } + // pair 1 (oc2/oc3) + if (has_oc2) { + const size_t w_base2 = idx_wei(0, + oc2, + static_cast(kz), + static_cast(ky), + static_cast(kx)); + const size_t w_base3 = has_oc3 ? idx_wei(0, + oc3, + static_cast(kz), + static_cast(ky), + static_cast(kx)) + : 0; + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = wei_p + w_base2; + if (has_oc3) + a.wei2 = wei_p + w_base3; + a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; + (*m_ip_kernel_f16)(&a); + } + } + } + } + } } else { - // Generic path (stride > 1): keep modulus checks + // Generic path (stride/dilation): modulus checks for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t iz_num = static_cast(od) + PD0 - static_cast(kz); + const ptrdiff_t id_num = + static_cast(od) + PD0 - static_cast(kz * dilD); if (SD == 0) continue; - if (iz_num % static_cast(SD) != 0) + if (id_num % static_cast(SD) != 0) continue; - const ptrdiff_t id_idx = iz_num / static_cast(SD); + const ptrdiff_t id_idx = id_num / static_cast(SD); if (id_idx < 0 || id_idx >= static_cast(ID)) continue; for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); + const ptrdiff_t iy_num = + static_cast(oh) + PH0 - static_cast(ky * dilH); if (SH == 0) continue; if (iy_num % static_cast(SH) != 0) continue; - const ptrdiff_t ihh = iy_num / static_cast(SH); - if (ihh < 0 || ihh >= static_cast(IH)) + const ptrdiff_t ih_idx = iy_num / static_cast(SH); + if (ih_idx < 0 || ih_idx >= static_cast(IH)) continue; for (size_t kx = 0; kx < KW; ++kx) { const ptrdiff_t ix_num = - static_cast(ow_) + PW0 - static_cast(kx); + static_cast(ow_) + PW0 - static_cast(kx * dilW); if (SW == 0) continue; if (ix_num % static_cast(SW) != 0) continue; - const ptrdiff_t iww = ix_num / static_cast(SW); - if (iww < 0 || iww >= static_cast(IW)) + const ptrdiff_t iw_idx = ix_num / static_cast(SW); + if (iw_idx < 0 || iw_idx >= static_cast(IW)) continue; const size_t s_base0 = idx_src(n, - 0, + g * ICg, static_cast(id_idx), - static_cast(ihh), - static_cast(iww)); + static_cast(ih_idx), + static_cast(iw_idx)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); const size_t w_base1 = has_oc1 ? 
idx_wei(0, oc1, kz, ky, kx) : 0; @@ -354,8 +1409,10 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.src_blk_stride = a.src_stride * 8; a.acc = &acc0; a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = IC / 8; - a.tail = IC % 8; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; if (m_wei_packed_ready_f16) { const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; @@ -367,14 +1424,53 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; } else { a.wei = wei_p + w_base0; if (has_oc1) a.wei2 = wei_p + w_base1; a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = 0; } (*m_ip_kernel_f16)(&a); + + // second pair for oc2/oc3 + if (has_oc2) { + const size_t w_base2 = idx_wei(0, oc2, kz, ky, kx); + const size_t w_base3 = has_oc3 ? idx_wei(0, oc3, kz, ky, kx) : 0; + jit_conv3d_call_args a2{}; + a2.src = src_p + s_base0; + a2.src_stride = src_c_stride_elems * sizeof(uint16_t); + a2.src_blk_stride = a2.src_stride * 8; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? &acc3 : nullptr; + a2.repeats = ICg / 8; + a2.tail = ICg % 8; + a2.kw_cnt = 1; + a2.src_dx = 0; + if (m_wei_packed_ready_f16) { + const size_t pack_base2 = + (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + a2.wei = m_wei_packed_f16.data() + pack_base2; + if (has_oc3) { + const size_t pack_base3 = + (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + a2.wei2 = m_wei_packed_f16.data() + pack_base3; + } + a2.wei_stride = sizeof(uint16_t); + a2.wei_blk_stride = a2.wei_stride * 8; + a2.wei_dx = 0; + } else { + a2.wei = wei_p + w_base2; + if (has_oc3) + a2.wei2 = wei_p + w_base3; + a2.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); + a2.wei_blk_stride = a2.wei_stride * 8; + a2.wei_dx = 0; + } + (*m_ip_kernel_f16)(&a2); + } } } } @@ -427,8 +1523,12 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t N = srcDims[0]; const size_t IC = srcDims[1]; const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; - const size_t OC = weiDims[1]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; + const size_t OC = dstDims[1]; + const bool grouped = weiDims.size() == 6; + [[maybe_unused]] const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? weiDims[1] : IC; + const size_t OCg = grouped ? weiDims[2] : OC; + const size_t KD = weiDims[grouped ? 3 : 2], KH = weiDims[grouped ? 4 : 3], KW = weiDims[grouped ? 5 : 4]; const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; const size_t SD = deconvAttrs.stride.size() > 0 ? 
static_cast(deconvAttrs.stride[0]) : 1; @@ -449,29 +1549,110 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st auto idx_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { return (((n * OC + c) * OD + z) * OH + y) * OW + x; }; - auto idx_wei = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { - return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; + auto idx_wei = [&](size_t ic_or_icg, size_t oc_global, size_t kz, size_t ky, size_t kx) { + if (!grouped) { + return ((((ic_or_icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + const size_t g = oc_global / OCg; + const size_t ocg = oc_global % OCg; + return ((((((g * ICg + ic_or_icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); }; const size_t src_c_stride_elems = ID * IH * IW; - const size_t wei_ic_stride_elems = OC * KD * KH * KW; + const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; + + if (deconv3d_pack_enabled()) { + ensure_weights_packed_f32(src); + if (deconv3d_pack_s2_enabled()) { + ensure_weights_packed_s2_f32(src); + } + } + // Output padding and dilations + const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; + const ptrdiff_t OPH0 = deconvAttrs.outputPadding.size() > 1 ? deconvAttrs.outputPadding[1] : 0; + const ptrdiff_t OPW0 = deconvAttrs.outputPadding.size() > 2 ? deconvAttrs.outputPadding[2] : 0; + const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; + const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; + const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; + const bool dbg_check = []() { + const char* e = std::getenv("OV_AARCH64_DECONV3D_CHECK"); + return (e && e[0] == '1' && e[1] == '\0'); + }(); + // Reference-correct fallback (env: OV_AARCH64_DECONV3D_REF=1) + if (deconv3d_force_ref()) { + const bool grouped = weiDims.size() == 6; + const size_t G = grouped ? weiDims[0] : 1; + const size_t ICg = grouped ? weiDims[1] : IC; + const size_t OCg = grouped ? 
weiDims[2] : OC; + std::fill_n(dst_p, N * OC * OD * OH * OW, 0.0F); + for (size_t n = 0; n < N; ++n) { + for (size_t g = 0; g < G; ++g) { + for (size_t id = 0; id < ID; ++id) { + for (size_t ih = 0; ih < IH; ++ih) { + for (size_t iw_ = 0; iw_ < IW; ++iw_) { + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t od = static_cast(id) * static_cast(SD) - PD0 + + static_cast(kz * dilD) + OPD0; + if (od < 0 || od >= static_cast(OD)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t oh = static_cast(ih) * static_cast(SH) - PH0 + + static_cast(ky * dilH) + OPH0; + if (oh < 0 || oh >= static_cast(OH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ow = static_cast(iw_) * static_cast(SW) - + PW0 + static_cast(kx * dilW) + OPW0; + if (ow < 0 || ow >= static_cast(OW)) + continue; + for (size_t icg = 0; icg < ICg; ++icg) { + const size_t ic_global = g * ICg + icg; + const float sval = src_p[idx_src(n, ic_global, id, ih, iw_)]; + for (size_t ocg = 0; ocg < OCg; ++ocg) { + const size_t oc_global = g * OCg + ocg; + size_t w_off; + if (grouped) { + w_off = + (((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW + + kx); + } else { + w_off = ((((icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; + } + dst_p[idx_dst(n, + oc_global, + static_cast(od), + static_cast(oh), + static_cast(ow))] += sval * wei_p[w_off]; + } + } + } + } + } + } + } + } + } + } + return; + } - ensure_weights_packed_f32(src); ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { const size_t oc0 = oc_quad * 4; + const size_t g = OCg ? (oc0 / OCg) : 0; + const size_t ocg0 = OCg ? (oc0 % OCg) : oc0; const size_t oc1 = std::min(oc0 + 1, OC); const size_t oc2 = std::min(oc0 + 2, OC); const size_t oc3 = std::min(oc0 + 3, OC); - const bool has_oc1 = oc1 < OC; - const bool has_oc2 = oc2 < OC; - const bool has_oc3 = oc3 < OC; + const bool has_oc1 = (ocg0 + 1) < OCg && oc1 < OC; + const bool has_oc2 = (ocg0 + 2) < OCg && oc2 < OC; + const bool has_oc3 = (ocg0 + 3) < OCg && oc3 < OC; for (size_t od = 0; od < OD; ++od) { for (size_t oh = 0; oh < OH; ++oh) { for (size_t ow_ = 0; ow_ < OW; ++ow_) { float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; - if (SD == 1 && SH == 1 && SW == 1) { + if (deconv3d_fastpath_f32_enabled() && SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { // contiguous tap range in each dimension const ptrdiff_t tz_pos = static_cast(od) + PD0; const ptrdiff_t ty_pos = static_cast(oh) + PH0; @@ -494,7 +1675,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t ix_idx = ix0; (void)iy0; (void)ky_base; - const size_t s_base = idx_src(n, 0, iz_idx, iy_idx, ix_idx); + const size_t s_base = idx_src(n, g * ICg, iz_idx, iy_idx, ix_idx); // pair 0 { @@ -504,27 +1685,47 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args.src_blk_stride = args.src_stride * 4; args.acc = &acc0; args.acc2 = has_oc1 ? 
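Every fast path funnels into the same inner-product kernel, so it helps to pin down what one invocation is assumed to compute. A scalar model of a jit_conv3d_f32_call_args call (strides are element counts here instead of bytes, and the struct below is a simplified stand-in for the real ABI, not a copy of it):

    #include <cstddef>

    // Dual-channel inner product: kw_cnt taps, each over repeats*4 + tail input
    // channels, accumulated into acc (and acc2 when a second weight set is given).
    struct F32ArgsModel {
        const float* src;  ptrdiff_t src_stride, src_dx;  // per-channel / per-tap steps
        const float* wei;  const float* wei2;
        ptrdiff_t wei_stride, wei_dx;
        size_t repeats, tail, kw_cnt;
        float* acc;  float* acc2;
    };

    static void run_f32_args_model(const F32ArgsModel& a) {
        const size_t channels = a.repeats * 4 + a.tail;
        for (size_t tap = 0; tap < a.kw_cnt; ++tap) {
            const float* s = a.src + static_cast<ptrdiff_t>(tap) * a.src_dx;
            const float* w = a.wei + static_cast<ptrdiff_t>(tap) * a.wei_dx;
            const float* w2 = a.wei2 ? a.wei2 + static_cast<ptrdiff_t>(tap) * a.wei_dx : nullptr;
            for (size_t c = 0; c < channels; ++c) {
                const float x = s[static_cast<ptrdiff_t>(c) * a.src_stride];
                *a.acc += x * w[static_cast<ptrdiff_t>(c) * a.wei_stride];
                if (w2)
                    *a.acc2 += x * w2[static_cast<ptrdiff_t>(c) * a.wei_stride];
            }
        }
    }

The f16 kernel follows the same contract with 8-lane channel blocks (repeats*8 + tail).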
&acc1 : nullptr; - args.repeats = IC / 4; - args.tail = IC % 4; + args.repeats = ICg / 4; + args.tail = ICg % 4; args.kw_cnt = kw_count; - args.src_dx = sizeof(float); - const size_t base0 = - (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_IC_f32; - args.wei = m_wei_packed_f32.data() + base0; - if (has_oc1) { - const size_t base1 = - (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * + args.src_dx = static_cast(-static_cast(sizeof(float))); + if (m_wei_packed_ready_f32) { + const size_t base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - args.wei2 = m_wei_packed_f32.data() + base1; + args.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args.wei2 = m_wei_packed_f32.data() + base1; + } + args.wei_stride = sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; + args.wei_dx = m_padded_IC_f32 * sizeof(float); + } else { + const size_t w_base0 = idx_wei(0, + oc0, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w_base1 = has_oc1 ? idx_wei(0, + oc1, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; + args.wei = wei_p + w_base0; + if (has_oc1) + args.wei2 = wei_p + w_base1; + args.wei_stride = wei_ic_stride_elems * sizeof(float); + args.wei_blk_stride = args.wei_stride * 4; + args.wei_dx = sizeof(float); } - args.wei_stride = sizeof(float); - args.wei_blk_stride = args.wei_stride * 4; - args.wei_dx = m_padded_IC_f32 * sizeof(float); (*m_ip_kernel_f32)(&args); } // pair 1 @@ -535,98 +1736,595 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args2.src_blk_stride = args2.src_stride * 4; args2.acc = &acc2; args2.acc2 = has_oc3 ? &acc3 : nullptr; - args2.repeats = IC / 4; - args2.tail = IC % 4; + args2.repeats = ICg / 4; + args2.tail = ICg % 4; args2.kw_cnt = kw_count; - args2.src_dx = sizeof(float); - const size_t base2 = - (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_IC_f32; - args2.wei = m_wei_packed_f32.data() + base2; - if (has_oc3) { - const size_t base3 = - (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * + args2.src_dx = static_cast(-static_cast(sizeof(float))); + if (m_wei_packed_ready_f32) { + const size_t base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_IC_f32; - args2.wei2 = m_wei_packed_f32.data() + base3; + args2.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + + static_cast(ky)) * + KW + + static_cast(kx_lo)) * + m_padded_IC_f32; + args2.wei2 = m_wei_packed_f32.data() + base3; + } + args2.wei_stride = sizeof(float); + args2.wei_blk_stride = args2.wei_stride * 4; + args2.wei_dx = m_padded_IC_f32 * sizeof(float); + } else { + const size_t w_base2 = idx_wei(0, + oc2, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)); + const size_t w_base3 = has_oc3 ? 
idx_wei(0, + oc3, + static_cast(kz), + static_cast(ky), + static_cast(kx_lo)) + : 0; + args2.wei = wei_p + w_base2; + if (has_oc3) + args2.wei2 = wei_p + w_base3; + args2.wei_stride = wei_ic_stride_elems * sizeof(float); + args2.wei_blk_stride = args2.wei_stride * 4; + args2.wei_dx = sizeof(float); } - args2.wei_stride = sizeof(float); - args2.wei_blk_stride = args2.wei_stride * 4; - args2.wei_dx = m_padded_IC_f32 * sizeof(float); (*m_ip_kernel_f32)(&args2); } } } } + } else if (deconv3d_fastpath_f32_s2_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && (!grouped || deconv3d_s2_grouped_enabled())) { + // Fast path S=2, dil=1 (packed weights preferred): parity-filtered taps without modulus checks + const ptrdiff_t tzd = static_cast(od) + PD0; + const ptrdiff_t tyd = static_cast(oh) + PH0; + const ptrdiff_t txd = static_cast(ow_) + PW0; + + const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); + const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); + const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); + const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); + const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); + + // X2 micro-tiling over output width for stride=2: compute (ow_, ow_+2) together (FP32 uses its own gate) + if (deconv3d_tile2_f32_enabled() && (ow_ + 2) < OW) { + float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ + float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 + const ptrdiff_t txd1 = static_cast(ow_ + 2) + PW0; + const ptrdiff_t kx_lo1 = std::max(0, txd1 - static_cast(IW * 2) + 2); + const ptrdiff_t kx_hi1 = std::min(static_cast(KW) - 1, txd1); + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t src_z_off = id * IH * IW; + const size_t src_cg0 = g * ICg; + size_t s_base_row = n * IC * ID * IH * IW + src_cg0 * src_c_stride_elems + src_z_off; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + + // FP32 S=2 even/odd packing support (tile2 branch) + const bool use_s2_pack_tile2_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; + const float* wei_pack_ptr_tile2_f32 = use_s2_pack_tile2_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); + auto pack_index_eo_tile2_f32 = [&](size_t py, size_t kx) { + if (!use_s2_pack_tile2_f32) return py + kx; + const size_t even_count = (KW + 1) / 2; + return py + ((kx & 1) ? 
(even_count + (kx / 2)) : (kx / 2)); + }; + + // Pass A: main kx set valid for ow_ + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw0 = static_cast((txd - kx) / 2); + if (iw0 >= IW) continue; + const size_t iw1 = iw0 + 1; // for ow_+2 + const size_t s_base0 = s_base_row + ih * IW + iw0; + // pair 0 for ow_ + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0a; + a.acc2 = has_oc1 ? &acc1a : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_tile2_f32 + base0; + if (has_oc1) { + const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_tile2_f32 + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base0; + if (has_oc1) { + const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base1; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + // For ow_+2 if in-bounds + if (iw1 < IW) { + const size_t s_base1 = s_base0 + 1; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0b; + a.acc2 = has_oc1 ? &acc1b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_tile2_f32 + base0; + if (has_oc1) { + const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_tile2_f32 + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base0; + if (has_oc1) { + const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base1; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + // pair 1 (oc2/oc3) for ow_ + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2a; + a.acc2 = has_oc3 ? 
&acc3a : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_tile2_f32 + base2; + if (has_oc3) { + const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_tile2_f32 + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base2; + if (has_oc3) { + const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base3; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + // pair 1 for ow_+2 + if (has_oc2 && (iw1 < IW)) { + const size_t s_base1_b = s_base0 + 1; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1_b; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2b; + a.acc2 = has_oc3 ? &acc3b : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_tile2_f32 + base2; + if (has_oc3) { + const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_tile2_f32 + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base2; + if (has_oc3) { + const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base3; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + } + + // Pass B: extra kx set valid only for ow_+2 (parity/bounds complement of Pass A) + for (ptrdiff_t kx = kx_lo1 + ((txd1 - kx_lo1) & 1); kx <= kx_hi1; kx += 2) { + const ptrdiff_t iw0_tmp = (txd - kx) / 2; + const bool covered_in_A = (kx >= kx_lo && kx <= kx_hi && (((txd - kx) & 1) == 0) && + (iw0_tmp >= 0 && iw0_tmp < static_cast(IW))); + if (covered_in_A) continue; + const ptrdiff_t iw1_tmp = (txd1 - kx) / 2; + if (iw1_tmp < 0 || iw1_tmp >= static_cast(IW)) continue; + const size_t iw1 = static_cast(iw1_tmp); + const size_t s_base1 = s_base_row + ih * IW + iw1; + // pair 0 for ow_+2 only + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base1; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0b; + a.acc2 = has_oc1 ? 
+                        // Pass B: extra kx taps valid only for ow_+2 (parity/boundary complement of Pass A)
+                        for (ptrdiff_t kx = kx_lo1 + ((txd1 - kx_lo1) & 1); kx <= kx_hi1; kx += 2) {
+                            const ptrdiff_t iw0_tmp = (txd - kx) / 2;
+                            const bool covered_in_A = (kx >= kx_lo && kx <= kx_hi && (((txd - kx) & 1) == 0) &&
+                                                       (iw0_tmp >= 0 && iw0_tmp < static_cast<ptrdiff_t>(IW)));
+                            if (covered_in_A) continue;
+                            const ptrdiff_t iw1_tmp = (txd1 - kx) / 2;
+                            if (iw1_tmp < 0 || iw1_tmp >= static_cast<ptrdiff_t>(IW)) continue;
+                            const size_t iw1 = static_cast<size_t>(iw1_tmp);
+                            const size_t s_base1 = s_base_row + ih * IW + iw1;
+                            // pair 0 for ow_+2 only
+                            {
+                                jit_conv3d_f32_call_args a{};
+                                a.src = src_p + s_base1;
+                                a.src_stride = src_c_stride_elems * sizeof(float);
+                                a.src_blk_stride = a.src_stride * 4;
+                                a.acc = &acc0b;
+                                a.acc2 = has_oc1 ? &acc1b : nullptr;
+                                a.repeats = ICg / 4;
+                                a.tail = ICg % 4;
+                                a.kw_cnt = 1;
+                                a.src_dx = 0;
+                                if (m_wei_packed_ready_f32) {
+                                    const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast<size_t>(kx)) * m_padded_IC_f32;
+                                    a.wei = wei_pack_ptr_tile2_f32 + base0;
+                                    if (has_oc1) {
+                                        const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast<size_t>(kx)) * m_padded_IC_f32;
+                                        a.wei2 = wei_pack_ptr_tile2_f32 + base1;
+                                    }
+                                    a.wei_stride = sizeof(float);
+                                    a.wei_blk_stride = a.wei_stride * 4;
+                                    a.wei_dx = 0;
+                                } else {
+                                    const size_t w_base0 = idx_wei(0, oc0, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx));
+                                    a.wei = wei_p + w_base0;
+                                    if (has_oc1) {
+                                        const size_t w_base1 = idx_wei(0, oc1, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx));
+                                        a.wei2 = wei_p + w_base1;
+                                    }
+                                    a.wei_stride = wei_ic_stride_elems * sizeof(float);
+                                    a.wei_blk_stride = a.wei_stride * 4;
+                                    a.wei_dx = 0;
+                                }
+                                (*m_ip_kernel_f32)(&a);
+                            }
+                            if (has_oc2) {
+                                jit_conv3d_f32_call_args a{};
+                                a.src = src_p + s_base1;
+                                a.src_stride = src_c_stride_elems * sizeof(float);
+                                a.src_blk_stride = a.src_stride * 4;
+                                a.acc = &acc2b;
+                                a.acc2 = has_oc3 ? &acc3b : nullptr;
+                                a.repeats = ICg / 4;
+                                a.tail = ICg % 4;
+                                a.kw_cnt = 1;
+                                a.src_dx = 0;
+                                if (m_wei_packed_ready_f32) {
+                                    const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast<size_t>(kx)) * m_padded_IC_f32;
+                                    a.wei = wei_pack_ptr_tile2_f32 + base2;
+                                    if (has_oc3) {
+                                        const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast<size_t>(kx)) * m_padded_IC_f32;
+                                        a.wei2 = wei_pack_ptr_tile2_f32 + base3;
+                                    }
+                                    a.wei_stride = sizeof(float);
+                                    a.wei_blk_stride = a.wei_stride * 4;
+                                    a.wei_dx = 0;
+                                } else {
+                                    const size_t w_base2 = idx_wei(0, oc2, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx));
+                                    a.wei = wei_p + w_base2;
+                                    if (has_oc3) {
+                                        const size_t w_base3 = idx_wei(0, oc3, static_cast<size_t>(kz), static_cast<size_t>(ky), static_cast<size_t>(kx));
+                                        a.wei2 = wei_p + w_base3;
+                                    }
+                                    a.wei_stride = wei_ic_stride_elems * sizeof(float);
+                                    a.wei_blk_stride = a.wei_stride * 4;
+                                    a.wei_dx = 0;
+                                }
+                                (*m_ip_kernel_f32)(&a);
+                            }
+                        }
+                    }
+                }
+
+                // Optional fused bias for both outputs
+                if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData() != nullptr) {
+                    const auto& bprec = src[2]->getPrecision();
+                    if (bprec == ov::element::f32) {
+                        const auto* bias_ptr = reinterpret_cast<const float*>(src[2]->getData());
+                        acc0a += bias_ptr[oc0];
+                        if (has_oc1) acc1a += bias_ptr[oc1];
+                        if (has_oc2) acc2a += bias_ptr[oc2];
+                        if (has_oc3) acc3a += bias_ptr[oc3];
+                        acc0b += bias_ptr[oc0];
+                        if (has_oc1) acc1b += bias_ptr[oc1];
+                        if (has_oc2) acc2b += bias_ptr[oc2];
+                        if (has_oc3) acc3b += bias_ptr[oc3];
+                    } else if (bprec == ov::element::f16) {
+                        const auto* bias_ptr = reinterpret_cast<const ov::float16*>(src[2]->getData());
+                        acc0a += static_cast<float>(ov::float16(bias_ptr[oc0]));
+                        if (has_oc1) acc1a += static_cast<float>(ov::float16(bias_ptr[oc1]));
+                        if (has_oc2) acc2a += static_cast<float>(ov::float16(bias_ptr[oc2]));
+                        if (has_oc3) acc3a += static_cast<float>(ov::float16(bias_ptr[oc3]));
+                        acc0b += static_cast<float>(ov::float16(bias_ptr[oc0]));
+                        if (has_oc1) acc1b += static_cast<float>(ov::float16(bias_ptr[oc1]));
+                        if (has_oc2) acc2b += static_cast<float>(ov::float16(bias_ptr[oc2]));
+                        if (has_oc3) acc3b += static_cast<float>(ov::float16(bias_ptr[oc3]));
+                    }
+                }
+
+                // Store results for both outputs
+                dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0a;
+                if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1a;
+                if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow_)] = acc2a;
+                if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow_)] = acc3a;
+
+                const size_t ow2 = ow_ + 2;
+                dst_p[idx_dst(n, oc0, od, oh, ow2)]
= acc0b; + if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow2)] = acc1b; + if (has_oc2) dst_p[idx_dst(n, oc2, od, oh, ow2)] = acc2b; + if (has_oc3) dst_p[idx_dst(n, oc3, od, oh, ow2)] = acc3b; + + ow_ += 2; // skip next two positions; for-loop ++ will advance to ow_+3 + continue; + } + + if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { + for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { + const size_t id = static_cast((tzd - kz) / 2); + if (id >= ID) continue; + const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; + const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; + const size_t pz2 = has_oc2 ? (oc2 * KD + static_cast(kz)) * KH : 0; + const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; + for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { + const size_t ih = static_cast((tyd - ky) / 2); + if (ih >= IH) continue; + const size_t py0 = (pz0 + static_cast(ky)) * KW; + const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; + const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; + const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; + for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { + const size_t iw = static_cast((txd - kx) / 2); + if (iw >= IW) continue; + // Base source offset for this (id, ih, iw) + const size_t s_base0 = idx_src(n, g * ICg, id, ih, iw); + // pair 0 (oc0, oc1) + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + // FP32 S=2 even/odd packing selection (non-tile2) + const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; + const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); + auto pack_index_eo_f32 = [&](size_t py, size_t kx_) { + if (!use_s2_pack_f32) return py + kx_; + const size_t even_count = (KW + 1) / 2; + return py + ((kx_ & 1) ? (even_count + (kx_ / 2)) : (kx_ / 2)); + }; + const size_t base0 = pack_index_eo_f32(py0, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_f32 + base0; + if (has_oc1) { + const size_t base1 = pack_index_eo_f32(py1, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_f32 + base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base0; + if (has_oc1) { + const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base1; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + // pair 1 (oc2, oc3) + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc2; + a.acc2 = has_oc3 ? &acc3 : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; + const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? 
m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); + auto pack_index_eo_f32 = [&](size_t py, size_t kx_) { + if (!use_s2_pack_f32) return py + kx_; + const size_t even_count = (KW + 1) / 2; + return py + ((kx_ & 1) ? (even_count + (kx_ / 2)) : (kx_ / 2)); + }; + const size_t base2 = pack_index_eo_f32(py2, static_cast(kx)) * m_padded_IC_f32; + a.wei = wei_pack_ptr_f32 + base2; + if (has_oc3) { + const size_t base3 = pack_index_eo_f32(py3, static_cast(kx)) * m_padded_IC_f32; + a.wei2 = wei_pack_ptr_f32 + base3; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } else { + const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei = wei_p + w_base2; + if (has_oc3) { + const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); + a.wei2 = wei_p + w_base3; + } + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + } + (*m_ip_kernel_f32)(&a); + } + } + } + } + } } else { - // generic stride path with modulus checks + // Generic path (stride/dilation): modulus checks for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t iz_num = static_cast(od) + PD0 - static_cast(kz); + const ptrdiff_t id_num = + static_cast(od) + PD0 - static_cast(kz * dilD); if (SD == 0) continue; - if (iz_num % static_cast(SD) != 0) + if (id_num % static_cast(SD) != 0) continue; - const ptrdiff_t id_idx = iz_num / static_cast(SD); + const ptrdiff_t id_idx = id_num / static_cast(SD); if (id_idx < 0 || id_idx >= static_cast(ID)) continue; for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t iy_num = static_cast(oh) + PH0 - static_cast(ky); + const ptrdiff_t iy_num = + static_cast(oh) + PH0 - static_cast(ky * dilH); if (SH == 0) continue; if (iy_num % static_cast(SH) != 0) continue; - const ptrdiff_t ihh = iy_num / static_cast(SH); - if (ihh < 0 || ihh >= static_cast(IH)) + const ptrdiff_t ih_idx = iy_num / static_cast(SH); + if (ih_idx < 0 || ih_idx >= static_cast(IH)) continue; for (size_t kx = 0; kx < KW; ++kx) { const ptrdiff_t ix_num = - static_cast(ow_) + PW0 - static_cast(kx); + static_cast(ow_) + PW0 - static_cast(kx * dilW); if (SW == 0) continue; if (ix_num % static_cast(SW) != 0) continue; - const ptrdiff_t iww = ix_num / static_cast(SW); - if (iww < 0 || iww >= static_cast(IW)) + const ptrdiff_t iw_idx = ix_num / static_cast(SW); + if (iw_idx < 0 || iw_idx >= static_cast(IW)) continue; const size_t s_base0 = idx_src(n, - 0, + g * ICg, static_cast(id_idx), - static_cast(ihh), - static_cast(iww)); + static_cast(ih_idx), + static_cast(iw_idx)); const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, kz, ky, kx) : 0; - jit_conv3d_f32_call_args args{}; - args.src = src_p + s_base0; - args.src_stride = src_c_stride_elems * sizeof(float); - args.src_blk_stride = args.src_stride * 4; - args.acc = &acc0; - args.acc2 = has_oc1 ? 
&acc1 : nullptr; - args.repeats = IC / 4; - args.tail = IC % 4; - if (m_wei_packed_ready_f32) { - const size_t pack_base0 = - (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - args.wei = m_wei_packed_f32.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - args.wei2 = m_wei_packed_f32.data() + pack_base1; + // pair 0 + { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(float); + a.src_blk_stride = a.src_stride * 4; + a.acc = &acc0; + a.acc2 = has_oc1 ? &acc1 : nullptr; + a.repeats = ICg / 4; + a.tail = ICg % 4; + a.kw_cnt = 1; + a.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t pack_base0 = + (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = + (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + pack_base1; + } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + } else { + a.wei = wei_p + w_base0; + if (has_oc1) + a.wei2 = wei_p + w_base1; + a.wei_stride = wei_ic_stride_elems * sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; } - args.wei_stride = sizeof(float); - args.wei_blk_stride = args.wei_stride * 4; - } else { - args.wei = wei_p + w_base0; - if (has_oc1) - args.wei2 = wei_p + w_base1; - args.wei_stride = wei_ic_stride_elems * sizeof(float); - args.wei_blk_stride = args.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + const size_t w_base2 = idx_wei(0, oc2, kz, ky, kx); + const size_t w_base3 = has_oc3 ? idx_wei(0, oc3, kz, ky, kx) : 0; + jit_conv3d_f32_call_args a2{}; + a2.src = src_p + s_base0; + a2.src_stride = src_c_stride_elems * sizeof(float); + a2.src_blk_stride = a2.src_stride * 4; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? 
&acc3 : nullptr; + a2.repeats = ICg / 4; + a2.tail = ICg % 4; + a2.kw_cnt = 1; + a2.src_dx = 0; + if (m_wei_packed_ready_f32) { + const size_t pack_base2 = + (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a2.wei = m_wei_packed_f32.data() + pack_base2; + if (has_oc3) { + const size_t pack_base3 = + (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + a2.wei2 = m_wei_packed_f32.data() + pack_base3; + } + a2.wei_stride = sizeof(float); + a2.wei_blk_stride = a2.wei_stride * 4; + } else { + a2.wei = wei_p + w_base2; + if (has_oc3) + a2.wei2 = wei_p + w_base3; + a2.wei_stride = wei_ic_stride_elems * sizeof(float); + a2.wei_blk_stride = a2.wei_stride * 4; + } + a2.wei_dx = 0; + (*m_ip_kernel_f32)(&a2); } - (*m_ip_kernel_f32)(&args); } } } @@ -655,6 +2353,72 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } } + if (dbg_check && n == 0 && od < 2 && oh < 2 && ow_ < 2) { + auto ref_acc = [&](size_t ocg_idx) { + float r = 0.0f; + // Reference: output-driven accumulation, same mapping as generic + for (size_t kz = 0; kz < KD; ++kz) { + const ptrdiff_t id_num = static_cast(od) + PD0 - + static_cast(kz * dilD); + if (id_num % static_cast(SD) != 0) + continue; + const ptrdiff_t id_idx = id_num / static_cast(SD); + if (id_idx < 0 || id_idx >= static_cast(ID)) + continue; + for (size_t ky = 0; ky < KH; ++ky) { + const ptrdiff_t iy_num = static_cast(oh) + PH0 - + static_cast(ky * dilH); + if (iy_num % static_cast(SH) != 0) + continue; + const ptrdiff_t ih_idx = iy_num / static_cast(SH); + if (ih_idx < 0 || ih_idx >= static_cast(IH)) + continue; + for (size_t kx = 0; kx < KW; ++kx) { + const ptrdiff_t ix_num = static_cast(ow_) + PW0 - + static_cast(kx * dilW); + if (ix_num % static_cast(SW) != 0) + continue; + const ptrdiff_t iw_idx = ix_num / static_cast(SW); + if (iw_idx < 0 || iw_idx >= static_cast(IW)) + continue; + const size_t s_off = idx_src(n, + g * ICg, + static_cast(id_idx), + static_cast(ih_idx), + static_cast(iw_idx)); + for (size_t icg = 0; icg < ICg; ++icg) { + const size_t w_off = idx_wei(icg, g * OCg + ocg_idx, kz, ky, kx); + r += src_p[s_off + icg * src_c_stride_elems] * wei_p[w_off]; + } + } + } + } + if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData()) { + const auto& bprec = src[2]->getPrecision(); + if (bprec == ov::element::f32) { + r += reinterpret_cast(src[2]->getData())[g * OCg + ocg_idx]; + } else if (bprec == ov::element::f16) { + r += static_cast(ov::float16( + reinterpret_cast(src[2]->getData())[g * OCg + ocg_idx])); + } + } + return r; + }; + const float r0 = ref_acc(ocg0 + 0); + if (std::fabs(r0 - acc0) > 1e-3f) { + fprintf(stderr, + "[DECONV3D-CHK] mismatch at (n=%zu,g=%zu,oc=%zu,od=%zu,oh=%zu,ow=%zu): ref=%f got=%f\n", + n, + g, + oc0, + od, + oh, + ow_, + r0, + acc0); + } + } + dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0; if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; @@ -674,8 +2438,10 @@ bool AArch64JitDeconvExecutorBuilder::isSupported(const DeconvAttrs& attrs, // Support 5D NCDHW, fp16 and fp32 if (srcDescs.size() < 2 || dstDescs.empty()) return false; - if (srcDescs[0]->getShape().getRank() != 5 || srcDescs[1]->getShape().getRank() != 5 || - dstDescs[0]->getShape().getRank() != 5) { + const auto src0_rank = srcDescs[0]->getShape().getRank(); + const auto wei_rank = srcDescs[1]->getShape().getRank(); + const auto dst0_rank = dstDescs[0]->getShape().getRank(); + if (src0_rank != 5 || (wei_rank != 5 && wei_rank != 6) || dst0_rank != 5) { return false; } const auto src0_prec = 
srcDescs[0]->getPrecision(); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index c59f3538a39073..4d6dc361c406b3 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -31,6 +31,9 @@ class JitDeconv3DExecutor : public DeconvExecutor { return impl_desc_type::jit_asimd; } + // Early weight preparation to avoid first-inference overhead + void prepare_weights_early(const std::vector& src); + private: std::vector m_srcDescs; std::vector m_dstDescs; @@ -42,13 +45,20 @@ class JitDeconv3DExecutor : public DeconvExecutor { // packed weights std::vector m_wei_packed_f16; std::vector m_wei_packed_f32; + // alternative packing for S=2 (even/odd taps) + std::vector m_wei_packed_s2_f16; + std::vector m_wei_packed_s2_f32; bool m_wei_packed_ready_f16{false}; bool m_wei_packed_ready_f32{false}; + bool m_wei_packed_s2_ready_f16{false}; + bool m_wei_packed_s2_ready_f32{false}; size_t m_padded_IC_f16{0}; size_t m_padded_IC_f32{0}; void ensure_weights_packed_f16(const std::vector& src); void ensure_weights_packed_f32(const std::vector& src); + void ensure_weights_packed_s2_f16(const std::vector& src); + void ensure_weights_packed_s2_f32(const std::vector& src); void exec_fp16(const std::vector& src, const std::vector& dst); void exec_fp32(const std::vector& src, const std::vector& dst); }; From 09483f193beb9e4a64d2d16efaf59efe3cbe184d Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Tue, 21 Oct 2025 18:51:14 +0200 Subject: [PATCH 08/20] Refactor AArch64 JIT 3D Deconvolution Executor by introducing `pack_index_eo_idx` helper to replace redundant lambda functions, improving code reuse and readability. --- .../nodes/executors/aarch64/jit_deconv3d.cpp | 55 ++++++++----------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 8a0f06dceec12a..d4c03c5f11d74e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -382,6 +382,14 @@ static inline bool deconv3d_force_ref() { return false; } static inline bool deconv3d_s2_kxloop_f16_enabled() { return false; } +// Common helper for even/odd S=2 packing index +static inline size_t pack_index_eo_idx(size_t KW_param, size_t py, size_t kx, bool use_pack) { + if (!use_pack) + return py + kx; + const size_t even_count = (KW_param + 1) / 2; + return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); +} + void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { // Pack weights (and S=2 even/odd layout) ahead of first exec to reduce cold-start latency if (m_is_fp32) { @@ -1825,11 +1833,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st // FP32 S=2 even/odd packing support (tile2 branch) const bool use_s2_pack_tile2_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; const float* wei_pack_ptr_tile2_f32 = use_s2_pack_tile2_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - auto pack_index_eo_tile2_f32 = [&](size_t py, size_t kx) { - if (!use_s2_pack_tile2_f32) return py + kx; - const size_t even_count = (KW + 1) / 2; - return py + ((kx & 1) ? 
(even_count + (kx / 2)) : (kx / 2)); - }; // Pass A: main kx set valid for ow_ for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { @@ -1850,10 +1853,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast(kx)) * m_padded_IC_f32; + const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast(kx)) * m_padded_IC_f32; + const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base1; } a.wei_stride = sizeof(float); @@ -1886,10 +1889,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast(kx)) * m_padded_IC_f32; + const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast(kx)) * m_padded_IC_f32; + const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base1; } a.wei_stride = sizeof(float); @@ -1921,10 +1924,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast(kx)) * m_padded_IC_f32; + const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast(kx)) * m_padded_IC_f32; + const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base3; } a.wei_stride = sizeof(float); @@ -1957,10 +1960,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast(kx)) * m_padded_IC_f32; + const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast(kx)) * m_padded_IC_f32; + const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base3; } a.wei_stride = sizeof(float); @@ -2004,10 +2007,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_tile2_f32(py0, static_cast(kx)) * m_padded_IC_f32; + const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_tile2_f32(py1, static_cast(kx)) * m_padded_IC_f32; + const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base1; } a.wei_stride = sizeof(float); @@ -2038,10 +2041,10 @@ void 
JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_tile2_f32(py2, static_cast(kx)) * m_padded_IC_f32; + const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_tile2_f32 + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_tile2_f32(py3, static_cast(kx)) * m_padded_IC_f32; + const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_tile2_f32 + base3; } a.wei_stride = sizeof(float); @@ -2143,15 +2146,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st // FP32 S=2 even/odd packing selection (non-tile2) const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - auto pack_index_eo_f32 = [&](size_t py, size_t kx_) { - if (!use_s2_pack_f32) return py + kx_; - const size_t even_count = (KW + 1) / 2; - return py + ((kx_ & 1) ? (even_count + (kx_ / 2)) : (kx_ / 2)); - }; - const size_t base0 = pack_index_eo_f32(py0, static_cast(kx)) * m_padded_IC_f32; + const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_f32 + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_f32(py1, static_cast(kx)) * m_padded_IC_f32; + const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_f32 + base1; } a.wei_stride = sizeof(float); @@ -2185,15 +2183,10 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st if (m_wei_packed_ready_f32) { const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - auto pack_index_eo_f32 = [&](size_t py, size_t kx_) { - if (!use_s2_pack_f32) return py + kx_; - const size_t even_count = (KW + 1) / 2; - return py + ((kx_ & 1) ? (even_count + (kx_ / 2)) : (kx_ / 2)); - }; - const size_t base2 = pack_index_eo_f32(py2, static_cast(kx)) * m_padded_IC_f32; + const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; a.wei = wei_pack_ptr_f32 + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_f32(py3, static_cast(kx)) * m_padded_IC_f32; + const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; a.wei2 = wei_pack_ptr_f32 + base3; } a.wei_stride = sizeof(float); From 61a4af788442218c93dc05b8a3a0ae3056c9ef78 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 14:36:30 +0200 Subject: [PATCH 09/20] Remove unused environment variables, redundant code paths, and obsolete logic in AArch64 JIT 3D Deconvolution Executor for improved maintainability and clarity. 
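
Note on the S=2 even/odd packing this series keeps for FP16: for stride-2
deconvolution only the taps whose parity matches the output column contribute,
so the packed layout groups even-kx taps ahead of odd-kx taps and addresses
them through pack_index_eo_idx. A minimal standalone sketch of that index
mapping (KW = 5 and py = 0 are arbitrary illustration values; pack_index_eo
here mirrors the in-tree helper, it is not the executor code itself):

    #include <cstddef>
    #include <cstdio>

    // Mirrors pack_index_eo_idx: even taps occupy slots [0, ceil(KW/2)),
    // odd taps follow them; use_pack=false keeps the plain linear layout.
    static size_t pack_index_eo(size_t KW, size_t py, size_t kx, bool use_pack) {
        if (!use_pack)
            return py + kx;
        const size_t even_count = (KW + 1) / 2;
        return py + ((kx & 1) ? (even_count + kx / 2) : (kx / 2));
    }

    int main() {
        // KW = 5: kx 0,2,4 -> slots 0,1,2; kx 1,3 -> slots 3,4
        for (size_t kx = 0; kx < 5; ++kx)
            std::printf("kx=%zu -> slot %zu\n", kx, pack_index_eo(5, 0, kx, true));
        return 0;
    }

With this ordering, a parity-filtered walk over kx (kx += 2) touches
consecutive packed rows, which is the access locality the S=2 fast path is
built around.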
--- .../nodes/executors/aarch64/jit_deconv3d.cpp | 296 +++--------------- 1 file changed, 35 insertions(+), 261 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index d4c03c5f11d74e..f437f559a4347e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -34,10 +34,6 @@ bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, m_ip_kernel_f32->create_ker(); } else { m_ip_kernel_f16 = std::make_unique(); - // Allow enabling in-kernel ky loop via env for FP16 - if (std::getenv("OV_AARCH64_DECONV3D_KY_LOOP_F16") && std::string(std::getenv("OV_AARCH64_DECONV3D_KY_LOOP_F16")) == "1") { - m_ip_kernel_f16->set_force_single_kh(false); - } m_ip_kernel_f16->create_ker(); } return true; @@ -391,13 +387,9 @@ static inline size_t pack_index_eo_idx(size_t KW_param, size_t py, size_t kx, bo } void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { - // Pack weights (and S=2 even/odd layout) ahead of first exec to reduce cold-start latency if (m_is_fp32) { if (deconv3d_pack_enabled()) { ensure_weights_packed_f32(src); - if (deconv3d_pack_s2_enabled()) { - ensure_weights_packed_s2_f32(src); - } } } else { if (deconv3d_pack_enabled()) { @@ -466,80 +458,12 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st ensure_weights_packed_s2_f16(src); } } - const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; - const ptrdiff_t OPH0 = deconvAttrs.outputPadding.size() > 1 ? deconvAttrs.outputPadding[1] : 0; - const ptrdiff_t OPW0 = deconvAttrs.outputPadding.size() > 2 ? deconvAttrs.outputPadding[2] : 0; // Effective dilations are stored as (dilation - 1) inside attrs; convert to actual factors const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; - // Reference-correct fallback (env: OV_AARCH64_DECONV3D_REF=1) - if (deconv3d_force_ref()) { - const bool grouped = weiDims.size() == 6; - const size_t G = grouped ? weiDims[0] : 1; - const size_t ICg = grouped ? weiDims[1] : IC; - const size_t OCg = grouped ? 
weiDims[2] : OC; - std::fill_n(dst_p, N * OC * OD * OH * OW, static_cast(0)); - for (size_t n = 0; n < N; ++n) { - for (size_t g = 0; g < G; ++g) { - for (size_t id = 0; id < ID; ++id) { - for (size_t ih = 0; ih < IH; ++ih) { - for (size_t iw_ = 0; iw_ < IW; ++iw_) { - for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t od = static_cast(id) * static_cast(SD) - PD0 + - static_cast(kz * dilD) + OPD0; - if (od < 0 || od >= static_cast(OD)) - continue; - for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t oh = static_cast(ih) * static_cast(SH) - PH0 + - static_cast(ky * dilH) + OPH0; - if (oh < 0 || oh >= static_cast(OH)) - continue; - for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ow = static_cast(iw_) * static_cast(SW) - - PW0 + static_cast(kx * dilW) + OPW0; - if (ow < 0 || ow >= static_cast(OW)) - continue; - for (size_t icg = 0; icg < ICg; ++icg) { - const size_t ic_global = g * ICg + icg; - const float sval = static_cast( - ov::float16(src_p[idx_src(n, ic_global, id, ih, iw_)])); - for (size_t ocg = 0; ocg < OCg; ++ocg) { - const size_t oc_global = g * OCg + ocg; - size_t w_off; - if (grouped) { - // layout [G, ICg, OCg, KD, KH, KW] - w_off = - (((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW + - kx); - } else { - // layout [IC, OC, KD, KH, KW] - w_off = ((((icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; - } - const float w = static_cast(ov::float16(wei_p[w_off])); - const size_t doff = idx_dst(n, - oc_global, - static_cast(od), - static_cast(oh), - static_cast(ow)); - const float accum = - sval * w + static_cast(ov::float16(dst_p[doff])); - dst_p[doff] = ov::float16(accum).to_bits(); - } - } - } - } - } - } - } - } - } - } - return; - } - auto worker = [&](size_t n, size_t oc_quad, size_t od) { const size_t oc0 = oc_quad * 4; const size_t g = OCg ? (oc0 / OCg) : 0; @@ -578,74 +502,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st (void)kx_hi; (void)kx_lo; if (m_wei_packed_ready_f16) { - if (deconv3d_kyloop_f16_enabled()) { - // In-kernel ky + kx (packed weights): - const auto kw_count = static_cast(kx_hi - kx_lo + 1); - const auto kh_count = static_cast(ky_hi - ky_lo + 1); - const size_t s_base_x0 = s_base_row + static_cast(tyd - ky_lo) * IW + static_cast(txd); - // Start from rightmost tap for positive src_dx - const size_t s_base0 = s_base_x0 - static_cast(kx_hi); - // Precompute packed bases at ky_lo - const size_t pz0 = (oc0 * KD + static_cast(kz)) * KH; - const size_t pz1 = has_oc1 ? (oc1 * KD + static_cast(kz)) * KH : 0; - const size_t py0 = (pz0 + static_cast(ky_lo)) * KW; - const size_t py1 = has_oc1 ? (pz1 + static_cast(ky_lo)) * KW : 0; - const size_t base0 = (py0 + static_cast(kx_lo)) * m_padded_IC_f16; - const size_t base1 = has_oc1 ? (py1 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; - // pair 0 - { - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? 
&acc1 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.kh_cnt = kh_count; - a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); - a.wei = m_wei_packed_f16.data() + base0; - if (has_oc1) - a.wei2 = m_wei_packed_f16.data() + base1; - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); - a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); - (*m_ip_kernel_f16)(&a); - } - // pair 1 - if (has_oc2) { - const size_t pz2 = (oc2 * KD + static_cast(kz)) * KH; - const size_t pz3 = has_oc3 ? (oc3 * KD + static_cast(kz)) * KH : 0; - const size_t py2 = (pz2 + static_cast(ky_lo)) * KW; - const size_t py3 = has_oc3 ? (pz3 + static_cast(ky_lo)) * KW : 0; - const size_t base2 = (py2 + static_cast(kx_lo)) * m_padded_IC_f16; - const size_t base3 = has_oc3 ? (py3 + static_cast(kx_lo)) * m_padded_IC_f16 : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.kh_cnt = kh_count; - a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); - a.wei = m_wei_packed_f16.data() + base2; - if (has_oc3) - a.wei2 = m_wei_packed_f16.data() + base3; - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); - a.wei_dy = KW * m_padded_IC_f16 * sizeof(uint16_t); - (*m_ip_kernel_f16)(&a); - } - continue; // handled full ky range - } for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { const size_t ihh = static_cast(tyd - ky); const size_t s_base_x0 = s_base_row + ihh * IW + static_cast(txd); @@ -1571,9 +1427,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st if (deconv3d_pack_enabled()) { ensure_weights_packed_f32(src); - if (deconv3d_pack_s2_enabled()) { - ensure_weights_packed_s2_f32(src); - } } // Output padding and dilations const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; @@ -1582,11 +1435,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; - const bool dbg_check = []() { - const char* e = std::getenv("OV_AARCH64_DECONV3D_CHECK"); - return (e && e[0] == '1' && e[1] == '\0'); - }(); - // Reference-correct fallback (env: OV_AARCH64_DECONV3D_REF=1) + if (deconv3d_force_ref()) { const bool grouped = weiDims.size() == 6; const size_t G = grouped ? 
weiDims[0] : 1; @@ -1803,8 +1652,8 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); - // X2 micro-tiling over output width for stride=2: compute (ow_, ow_+2) together (FP32 uses its own gate) - if (deconv3d_tile2_f32_enabled() && (ow_ + 2) < OW) { + // X2 micro-tiling over output width for stride=2: compute (ow_, ow_+2) together (disabled for FP32) + if (false && (ow_ + 2) < OW) { float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 const ptrdiff_t txd1 = static_cast(ow_ + 2) + PW0; @@ -1830,10 +1679,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; - // FP32 S=2 even/odd packing support (tile2 branch) - const bool use_s2_pack_tile2_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; - const float* wei_pack_ptr_tile2_f32 = use_s2_pack_tile2_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - // Pass A: main kx set valid for ow_ for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { const size_t iw0 = static_cast((txd - kx) / 2); @@ -1853,11 +1698,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base0; + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base1; + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -1889,11 +1734,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base0; + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base1; + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -1924,11 +1769,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base2; + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base3; + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = 
m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -1960,11 +1805,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base2; + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base3; + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -2007,11 +1852,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base0; + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base1; + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -2041,11 +1886,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_tile2_f32 + base2; + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_tile2_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_tile2_f32 + base3; + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -2143,14 +1988,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - // FP32 S=2 even/odd packing selection (non-tile2) - const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; - const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? 
m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - const size_t base0 = pack_index_eo_idx(KW, py0, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_f32 + base0; + const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { - const size_t base1 = pack_index_eo_idx(KW, py1, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_f32 + base1; + const size_t base1 = (py1 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base1; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -2181,13 +2023,11 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.kw_cnt = 1; a.src_dx = 0; if (m_wei_packed_ready_f32) { - const bool use_s2_pack_f32 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f32; - const float* wei_pack_ptr_f32 = use_s2_pack_f32 ? m_wei_packed_s2_f32.data() : m_wei_packed_f32.data(); - const size_t base2 = pack_index_eo_idx(KW, py2, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; - a.wei = wei_pack_ptr_f32 + base2; + const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; + a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { - const size_t base3 = pack_index_eo_idx(KW, py3, static_cast(kx), use_s2_pack_f32) * m_padded_IC_f32; - a.wei2 = wei_pack_ptr_f32 + base3; + const size_t base3 = (py3 + static_cast(kx)) * m_padded_IC_f32; + a.wei2 = m_wei_packed_f32.data() + base3; } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; @@ -2346,72 +2186,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } } - if (dbg_check && n == 0 && od < 2 && oh < 2 && ow_ < 2) { - auto ref_acc = [&](size_t ocg_idx) { - float r = 0.0f; - // Reference: output-driven accumulation, same mapping as generic - for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t id_num = static_cast(od) + PD0 - - static_cast(kz * dilD); - if (id_num % static_cast(SD) != 0) - continue; - const ptrdiff_t id_idx = id_num / static_cast(SD); - if (id_idx < 0 || id_idx >= static_cast(ID)) - continue; - for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t iy_num = static_cast(oh) + PH0 - - static_cast(ky * dilH); - if (iy_num % static_cast(SH) != 0) - continue; - const ptrdiff_t ih_idx = iy_num / static_cast(SH); - if (ih_idx < 0 || ih_idx >= static_cast(IH)) - continue; - for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ix_num = static_cast(ow_) + PW0 - - static_cast(kx * dilW); - if (ix_num % static_cast(SW) != 0) - continue; - const ptrdiff_t iw_idx = ix_num / static_cast(SW); - if (iw_idx < 0 || iw_idx >= static_cast(IW)) - continue; - const size_t s_off = idx_src(n, - g * ICg, - static_cast(id_idx), - static_cast(ih_idx), - static_cast(iw_idx)); - for (size_t icg = 0; icg < ICg; ++icg) { - const size_t w_off = idx_wei(icg, g * OCg + ocg_idx, kz, ky, kx); - r += src_p[s_off + icg * src_c_stride_elems] * wei_p[w_off]; - } - } - } - } - if (deconvAttrs.withBiasesParam && src.size() > 2 && src[2] && src[2]->getData()) { - const auto& bprec = src[2]->getPrecision(); - if (bprec == ov::element::f32) { - r += reinterpret_cast(src[2]->getData())[g * OCg + ocg_idx]; - } else if (bprec == ov::element::f16) { - r += static_cast(ov::float16( - reinterpret_cast(src[2]->getData())[g * OCg + ocg_idx])); - } - } - return r; - }; - const float r0 = ref_acc(ocg0 + 0); - if (std::fabs(r0 - acc0) > 1e-3f) { - fprintf(stderr, - "[DECONV3D-CHK] mismatch at (n=%zu,g=%zu,oc=%zu,od=%zu,oh=%zu,ow=%zu): ref=%f got=%f\n", 
- n, - g, - oc0, - od, - oh, - ow_, - r0, - acc0); - } - } - dst_p[idx_dst(n, oc0, od, oh, ow_)] = acc0; if (has_oc1) dst_p[idx_dst(n, oc1, od, oh, ow_)] = acc1; From c8d4547b958e3c8aef0e604cc51a60bf547c6545 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 15:07:05 +0200 Subject: [PATCH 10/20] Remove unused helper functions, redundant conditions, and raw weight paths in AArch64 JIT 3D Deconvolution Executor for cleaner and more maintainable code. --- .../nodes/executors/aarch64/jit_deconv3d.cpp | 343 ++++-------------- 1 file changed, 73 insertions(+), 270 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index f437f559a4347e..6533b14f7ca907 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -353,51 +353,14 @@ void JitDeconv3DExecutor::exec(const std::vector& src, } } -// helper toggles declared below - -static inline bool deconv3d_pack_enabled() { return true; } - -static inline bool deconv3d_fastpath_f16_enabled() { return true; } -static inline bool deconv3d_fastpath_f32_enabled() { return true; } - -static inline bool deconv3d_fastpath_f32_s2_enabled() { return true; } - -static inline bool deconv3d_kyloop_f16_enabled() { return false; } - -static inline bool deconv3d_s2_grouped_enabled() { return true; } - -static inline bool deconv3d_tile2_enabled() { return true; } - -static inline bool deconv3d_tile2_f32_enabled() { return false; } - -static inline bool deconv3d_pack_s2_enabled() { return true; } - -static inline bool deconv3d_prefetch_enabled() { return true; } - -static inline bool deconv3d_force_ref() { return false; } - -static inline bool deconv3d_s2_kxloop_f16_enabled() { return false; } - -// Common helper for even/odd S=2 packing index -static inline size_t pack_index_eo_idx(size_t KW_param, size_t py, size_t kx, bool use_pack) { - if (!use_pack) - return py + kx; - const size_t even_count = (KW_param + 1) / 2; - return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); -} +// (no additional helpers) void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { if (m_is_fp32) { - if (deconv3d_pack_enabled()) { - ensure_weights_packed_f32(src); - } + ensure_weights_packed_f32(src); } else { - if (deconv3d_pack_enabled()) { - ensure_weights_packed_f16(src); - if (deconv3d_pack_s2_enabled()) { - ensure_weights_packed_s2_f16(src); - } - } + ensure_weights_packed_f16(src); + ensure_weights_packed_s2_f16(src); } } @@ -452,12 +415,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t src_c_stride_elems = ID * IH * IW; const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; - if (deconv3d_pack_enabled()) { - ensure_weights_packed_f16(src); - if (deconv3d_pack_s2_enabled()) { - ensure_weights_packed_s2_f16(src); - } - } + // Always prepare packed weights (standard + S=2 even/odd) + ensure_weights_packed_f16(src); + ensure_weights_packed_s2_f16(src); // Effective dilations are stored as (dilation - 1) inside attrs; convert to actual factors const size_t dilD = deconvAttrs.dilation.size() > 0 ? 
static_cast(deconvAttrs.dilation[0]) + 1 : 1; @@ -480,7 +440,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st for (size_t ow_ = 0; ow_ < OW; ++ow_) { float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; - if (deconv3d_fastpath_f16_enabled() && SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { + if (SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { // Fast path: contiguous tap ranges, no modulus checks const ptrdiff_t tzd = static_cast(od) + PD0; const ptrdiff_t tyd = static_cast(oh) + PH0; @@ -564,8 +524,8 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } else { - // Raw weights fast-path - if (deconv3d_kyloop_f16_enabled()) { + // Raw weights fast-path (removed in product mode) + if (false) { // In-kernel ky + kx (raw weights): const auto kw_count = static_cast(kx_hi - kx_lo + 1); const auto kh_count = static_cast(ky_hi - ky_lo + 1); @@ -695,7 +655,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } - } else if (deconv3d_fastpath_f16_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && m_wei_packed_ready_f16 && (!grouped || deconv3d_s2_grouped_enabled())) { + } else if (SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1) { // Fast path S=2, dil=1 (packed weights): parity-filtered taps without modulus checks const ptrdiff_t tzd = static_cast(od) + PD0; const ptrdiff_t tyd = static_cast(oh) + PH0; @@ -709,7 +669,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); // X2 micro-tiling over output width for stride=2: compute (ow, ow+2) together when possible - if (deconv3d_tile2_enabled() && (ow_ + 2) < OW) { + if ((ow_ + 2) < OW) { float acc0a = 0.0F, acc1a = 0.0F, acc2a = 0.0F, acc3a = 0.0F; // for ow_ float acc0b = 0.0F, acc1b = 0.0F, acc2b = 0.0F, acc3b = 0.0F; // for ow_+2 @@ -739,10 +699,8 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st // Even/odd S=2 packing selection - const bool use_s2_pack_tile2 = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; - const uint16_t* wei_pack_ptr_tile2 = use_s2_pack_tile2 ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + const uint16_t* wei_pack_ptr_tile2 = m_wei_packed_s2_f16.data(); auto pack_index_eo_tile2 = [&](size_t py, size_t kx) { - if (!use_s2_pack_tile2) return py + kx; const size_t even_count = (KW + 1) / 2; return py + ((kx & 1) ? 
(even_count + (kx / 2)) : (kx / 2)); }; @@ -772,11 +730,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } if (iw1 < IW) { @@ -800,11 +756,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } } @@ -827,11 +781,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } // pair 1 for ow_+2 @@ -854,11 +806,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } } @@ -891,11 +841,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } if (has_oc2) { @@ -916,11 +864,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } } @@ -971,7 +917,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { - if (deconv3d_s2_kxloop_f16_enabled()) { + if (false) { // In-kernel kx loop for parity taps: step weights by 2*IC and src by -1 in X for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { const size_t id = static_cast((tzd - kz) / 2); @@ -990,10 +936,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; const size_t py2 = has_oc2 ? 
(pz2 + static_cast(ky)) * KW : 0; const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; - const bool use_s2_pack = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; - const uint16_t* wei_pack_ptr = use_s2_pack ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + const bool use_s2_pack = true; + const uint16_t* wei_pack_ptr = m_wei_packed_s2_f16.data(); auto pack_index_eo = [&](size_t py, size_t kx) { - if (!use_s2_pack) return py + kx; const size_t even_count = (KW + 1) / 2; return py + ((kx & 1) ? (even_count + (kx / 2)) : (kx / 2)); }; @@ -1065,7 +1010,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; - const bool use_s2_pack_orig = deconv3d_pack_s2_enabled() && m_wei_packed_s2_ready_f16; + const bool use_s2_pack_orig = true; const uint16_t* wei_pack_ptr_orig = use_s2_pack_orig ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); auto pack_index_eo_orig = [&](size_t py, size_t kx) { if (!use_s2_pack_orig) return py + kx; @@ -1095,11 +1040,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } // pair 1 @@ -1121,11 +1064,9 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - if (deconv3d_prefetch_enabled()) { - __builtin_prefetch(a.src + 64); - __builtin_prefetch(a.wei + 64); - if (a.wei2) __builtin_prefetch(a.wei2 + 64); - } + __builtin_prefetch(a.src + 64); + __builtin_prefetch(a.wei + 64); + if (a.wei2) __builtin_prefetch(a.wei2 + 64); (*m_ip_kernel_f16)(&a); } } @@ -1133,7 +1074,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } - } else if (deconv3d_fastpath_f16_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && !m_wei_packed_ready_f16 && (!grouped || deconv3d_s2_grouped_enabled())) { + } else if (false) { // Fast path S=2, dil=1 (raw weights): parity-filtered taps without modulus checks const ptrdiff_t tzd = static_cast(od) + PD0; const ptrdiff_t tyd = static_cast(oh) + PH0; @@ -1264,8 +1205,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st static_cast(id_idx), static_cast(ih_idx), static_cast(iw_idx)); - const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); - const size_t w_base1 = has_oc1 ? 
idx_wei(0, oc1, kz, ky, kx) : 0; + // raw weight indices removed in product mode jit_conv3d_call_args a{}; a.src = src_p + s_base0; @@ -1277,7 +1217,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.tail = ICg % 8; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f16) { + { const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; a.wei = m_wei_packed_f16.data() + pack_base0; @@ -1289,20 +1229,12 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - } else { - a.wei = wei_p + w_base0; - if (has_oc1) - a.wei2 = wei_p + w_base1; - a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = 0; } (*m_ip_kernel_f16)(&a); // second pair for oc2/oc3 if (has_oc2) { - const size_t w_base2 = idx_wei(0, oc2, kz, ky, kx); - const size_t w_base3 = has_oc3 ? idx_wei(0, oc3, kz, ky, kx) : 0; + // raw weight indices removed in product mode jit_conv3d_call_args a2{}; a2.src = src_p + s_base0; a2.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -1313,7 +1245,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a2.tail = ICg % 8; a2.kw_cnt = 1; a2.src_dx = 0; - if (m_wei_packed_ready_f16) { + { const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; a2.wei = m_wei_packed_f16.data() + pack_base2; @@ -1325,13 +1257,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st a2.wei_stride = sizeof(uint16_t); a2.wei_blk_stride = a2.wei_stride * 8; a2.wei_dx = 0; - } else { - a2.wei = wei_p + w_base2; - if (has_oc3) - a2.wei2 = wei_p + w_base3; - a2.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a2.wei_blk_stride = a2.wei_stride * 8; - a2.wei_dx = 0; } (*m_ip_kernel_f16)(&a2); } @@ -1425,9 +1350,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t src_c_stride_elems = ID * IH * IW; const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; - if (deconv3d_pack_enabled()) { - ensure_weights_packed_f32(src); - } + ensure_weights_packed_f32(src); // Output padding and dilations const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; const ptrdiff_t OPH0 = deconvAttrs.outputPadding.size() > 1 ? deconvAttrs.outputPadding[1] : 0; @@ -1436,7 +1359,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; - if (deconv3d_force_ref()) { + if (false) { const bool grouped = weiDims.size() == 6; const size_t G = grouped ? weiDims[0] : 1; const size_t ICg = grouped ? 
weiDims[1] : IC; @@ -1509,7 +1432,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st for (size_t ow_ = 0; ow_ < OW; ++ow_) { float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; - if (deconv3d_fastpath_f32_enabled() && SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { + if (SD == 1 && SH == 1 && SW == 1 && dilD == 1 && dilH == 1 && dilW == 1) { // contiguous tap range in each dimension const ptrdiff_t tz_pos = static_cast(od) + PD0; const ptrdiff_t ty_pos = static_cast(oh) + PH0; @@ -1546,7 +1469,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args.tail = ICg % 4; args.kw_cnt = kw_count; args.src_dx = static_cast(-static_cast(sizeof(float))); - if (m_wei_packed_ready_f32) { + if (true) { const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + @@ -1564,25 +1487,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args.wei_stride = sizeof(float); args.wei_blk_stride = args.wei_stride * 4; args.wei_dx = m_padded_IC_f32 * sizeof(float); - } else { - const size_t w_base0 = idx_wei(0, - oc0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w_base1 = has_oc1 ? idx_wei(0, - oc1, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - args.wei = wei_p + w_base0; - if (has_oc1) - args.wei2 = wei_p + w_base1; - args.wei_stride = wei_ic_stride_elems * sizeof(float); - args.wei_blk_stride = args.wei_stride * 4; - args.wei_dx = sizeof(float); - } + } else { /* unreachable: raw weights path removed */ } (*m_ip_kernel_f32)(&args); } // pair 1 @@ -1597,7 +1502,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args2.tail = ICg % 4; args2.kw_cnt = kw_count; args2.src_dx = static_cast(-static_cast(sizeof(float))); - if (m_wei_packed_ready_f32) { + if (true) { const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + @@ -1615,31 +1520,13 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st args2.wei_stride = sizeof(float); args2.wei_blk_stride = args2.wei_stride * 4; args2.wei_dx = m_padded_IC_f32 * sizeof(float); - } else { - const size_t w_base2 = idx_wei(0, - oc2, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w_base3 = has_oc3 ? 
idx_wei(0, - oc3, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - args2.wei = wei_p + w_base2; - if (has_oc3) - args2.wei2 = wei_p + w_base3; - args2.wei_stride = wei_ic_stride_elems * sizeof(float); - args2.wei_blk_stride = args2.wei_stride * 4; - args2.wei_dx = sizeof(float); - } + } else { /* unreachable: raw weights path removed */ } (*m_ip_kernel_f32)(&args2); } } } } - } else if (deconv3d_fastpath_f32_s2_enabled() && SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1 && (!grouped || deconv3d_s2_grouped_enabled())) { + } else if (SD == 2 && SH == 2 && SW == 2 && dilD == 1 && dilH == 1 && dilW == 1) { // Fast path S=2, dil=1 (packed weights preferred): parity-filtered taps without modulus checks const ptrdiff_t tzd = static_cast(od) + PD0; const ptrdiff_t tyd = static_cast(oh) + PH0; @@ -1697,7 +1584,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { @@ -1707,17 +1594,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base0; - if (has_oc1) { - const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base1; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } // For ow_+2 if in-bounds @@ -1733,7 +1610,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { @@ -1768,7 +1645,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { @@ -1778,17 +1655,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base2; - if (has_oc3) { - const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base3; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } // pair 1 for ow_+2 @@ -1814,17 +1681,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base2; - if (has_oc3) { - const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base3; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = 
a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } } @@ -1851,7 +1708,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { @@ -1861,17 +1718,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base0; - if (has_oc1) { - const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base1; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } if (has_oc2) { @@ -1895,17 +1742,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base2; - if (has_oc3) { - const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base3; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } } @@ -1987,7 +1824,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base0 = (py0 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base0; if (has_oc1) { @@ -1997,17 +1834,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base0 = idx_wei(0, oc0, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base0; - if (has_oc1) { - const size_t w_base1 = idx_wei(0, oc1, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base1; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } // pair 1 (oc2, oc3) @@ -2022,7 +1849,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t base2 = (py2 + static_cast(kx)) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + base2; if (has_oc3) { @@ -2032,17 +1859,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; - } else { - const size_t w_base2 = idx_wei(0, oc2, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei = wei_p + w_base2; - if (has_oc3) { - const size_t w_base3 = idx_wei(0, oc3, static_cast(kz), static_cast(ky), static_cast(kx)); - a.wei2 = wei_p + w_base3; - } - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - } + } else { /* unreachable */ } (*m_ip_kernel_f32)(&a); } } @@ -2087,8 +1904,7 @@ void 
JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st static_cast(id_idx), static_cast(ih_idx), static_cast(iw_idx)); - const size_t w_base0 = idx_wei(0, oc0, kz, ky, kx); - const size_t w_base1 = has_oc1 ? idx_wei(0, oc1, kz, ky, kx) : 0; + // raw weight indices removed // pair 0 { @@ -2102,7 +1918,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; a.wei = m_wei_packed_f32.data() + pack_base0; @@ -2113,20 +1929,13 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } a.wei_stride = sizeof(float); a.wei_blk_stride = a.wei_stride * 4; - } else { - a.wei = wei_p + w_base0; - if (has_oc1) - a.wei2 = wei_p + w_base1; - a.wei_stride = wei_ic_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - } + } else { /* unreachable */ } a.wei_dx = 0; (*m_ip_kernel_f32)(&a); } // pair 1 if (has_oc2) { - const size_t w_base2 = idx_wei(0, oc2, kz, ky, kx); - const size_t w_base3 = has_oc3 ? idx_wei(0, oc3, kz, ky, kx) : 0; + // raw weight indices removed jit_conv3d_f32_call_args a2{}; a2.src = src_p + s_base0; a2.src_stride = src_c_stride_elems * sizeof(float); @@ -2137,7 +1946,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st a2.tail = ICg % 4; a2.kw_cnt = 1; a2.src_dx = 0; - if (m_wei_packed_ready_f32) { + if (true) { const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; a2.wei = m_wei_packed_f32.data() + pack_base2; @@ -2148,13 +1957,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } a2.wei_stride = sizeof(float); a2.wei_blk_stride = a2.wei_stride * 4; - } else { - a2.wei = wei_p + w_base2; - if (has_oc3) - a2.wei2 = wei_p + w_base3; - a2.wei_stride = wei_ic_stride_elems * sizeof(float); - a2.wei_blk_stride = a2.wei_stride * 4; - } + } else { /* unreachable */ } a2.wei_dx = 0; (*m_ip_kernel_f32)(&a2); } From 49ef0c78c3e261fcd09b161b3bdc0baa56846d24 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 18:27:22 +0200 Subject: [PATCH 11/20] Introduce early weight preparation in AArch64 JIT 3D Convolution and Deconvolution Executors, remove redundant FP32 paths, and streamline packed weight logic for reduced initialization latency and improved maintainability. 
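Note on the early-packing entry point: it is deliberately defensive. Condensed from the jit_conv3d.cpp hunk below (same names as in the hunk, bodies abbreviated), the guard is:

    void JitConv3DExecutor::prepare_weights_early(const MemoryArgs& memory) {
        // Pack only when both tensors exist and their shapes are static;
        // dynamic-shape graphs fall back to lazy packing on first exec().
        auto src_it = memory.find(ARG_SRC);
        auto wei_it = memory.find(ARG_WEI);
        if (src_it == memory.end() || wei_it == memory.end() ||
            !src_it->second || !wei_it->second)
            return;
        if (!src_it->second->getDescPtr()->getShape().isStatic() ||
            !wei_it->second->getDescPtr()->getShape().isStatic())
            return;
        if (m_is_fp32)
            ensure_weights_packed_f32(memory);
        else
            ensure_weights_packed(memory);
    }

Calling this from the constructor moves the one-time repacking cost out of the first inference; repeated calls stay cheap because the ensure_weights_packed*() helpers return immediately once their ready flag is set.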
--- src/plugins/intel_cpu/src/nodes/deconv.cpp | 23 +- .../nodes/executors/aarch64/jit_conv3d.cpp | 381 ++++++------------ .../nodes/executors/aarch64/jit_conv3d.hpp | 3 + .../nodes/executors/aarch64/jit_deconv3d.cpp | 75 +--- .../nodes/executors/aarch64/jit_deconv3d.hpp | 29 +- .../src/nodes/executors/deconv_list.hpp | 44 ++ 6 files changed, 211 insertions(+), 344 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index cffcea2f6b45a1..b1db674d8c1455 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -983,10 +983,12 @@ void Deconvolution::prepareParams() { dstMemoryDescs.push_back(getChildEdgeAt(i)->getMemory().getDescWithType()); } - execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutor(deconvAttrs, - srcMemoryDescs, - dstMemoryDescs, - *attr); + // Build executor with constructor-time early packing (for JIT on ARM64); falls back to regular path + std::vector srcMemoriesEarly; + srcMemoriesEarly.push_back(getSrcMemoryAtPort(0)); + srcMemoriesEarly.push_back(getSrcMemoryAtPort(1)); + execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutorWithMem( + deconvAttrs, srcMemoryDescs, dstMemoryDescs, *attr, srcMemoriesEarly); selected_pd->setImplementationType(execPtrDeconvACL->getImplType()); return; } @@ -1138,19 +1140,6 @@ void Deconvolution::prepareParams() { execPtr = result.first; OPENVINO_ASSERT(execPtr, "Primitive descriptor was not found for node ", getName(), "."); -#if defined(OPENVINO_ARCH_ARM64) - // Early weight packing for AArch64 JIT to minimize first-inference latency - if (auto jitExec = std::dynamic_pointer_cast(execPtr)) { - std::vector srcMemories; - // src[0] = input, src[1] = weights, src[2] = bias (optional) - srcMemories.push_back(getSrcMemoryAtPort(0)); - srcMemories.push_back(getSrcMemoryAtPort(1)); - // Bias not needed for packing - jitExec->prepare_weights_early(srcMemories); - } -#endif - - primArgs[DNNL_ARG_SRC] = srcMemPtr->getPrimitive(); primArgs[DNNL_ARG_DST] = dstMemPtr->getPrimitive(); if (weightIsConst) { diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 3ee52439221675..f341f3dcb9e369 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -1666,6 +1666,9 @@ JitConv3DExecutor::JitConv3DExecutor(const ConvAttrs& attrs, } } } + + // Early weight packing (only if shapes are static). Kept inside executor per policy. 
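+    // If shapes are still dynamic at construction time this is a no-op; the
+    // exec path re-runs ensure_weights_packed()/ensure_weights_packed_f32()
+    // once real dimensions are known, so packing is deferred, not skipped.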
+ prepare_weights_early(m_memory); } bool JitConv3DExecutor::supports(const ConvConfig& cfg) { @@ -1706,6 +1709,23 @@ bool JitConv3DExecutor::supports(const ConvConfig& cfg) { return true; } +void JitConv3DExecutor::prepare_weights_early(const MemoryArgs& memory) { + // Guard: only when shapes are static + auto src_it = memory.find(ARG_SRC); + auto wei_it = memory.find(ARG_WEI); + if (src_it == memory.end() || wei_it == memory.end() || !src_it->second || !wei_it->second) + return; + const auto& s = src_it->second->getDescPtr()->getShape(); + const auto& w = wei_it->second->getDescPtr()->getShape(); + if (!s.isStatic() || !w.isStatic()) + return; + if (m_is_fp32) { + ensure_weights_packed_f32(memory); + } else { + ensure_weights_packed(memory); + } +} + void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { // NCDHW auto src = memory.at(ARG_SRC); @@ -1731,7 +1751,6 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { const size_t PW0 = m_attrs.paddingL.size() > 2 ? static_cast(m_attrs.paddingL[2]) : 0; const uint16_t* src_p = ptr_f16(src); - const uint16_t* wei_p = ptr_f16(wei); uint16_t* dst_p = ptr_f16(dst); auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) -> size_t { @@ -1740,9 +1759,6 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { auto index_dst = [&](size_t n, size_t oc, size_t z, size_t y, size_t x) -> size_t { return (((n * OC + oc) * OD + z) * OH + y) * OW + x; }; - auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; - }; // Prepare packed weights once ensure_weights_packed(memory); @@ -1764,7 +1780,6 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f; const size_t src_c_stride_elems = ID * IH * IW; - const size_t wei_c_stride_elems = KD * KH * KW; if (SD == 1 && SH == 1 && SW == 1) { const ptrdiff_t kz_lo = std::max(0, -iz0); @@ -1782,137 +1797,64 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { const size_t iz = static_cast(iz0 + kz); // iy/ix for ky_lo/kx_lo not needed; use iy2/ix2 per ky below - if (m_wei_packed_ready) { - // Loop over ky in host; kernel handles kx via kw_cnt - for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t iy2 = static_cast(iy0 + ky); - const size_t ix2 = static_cast(ix0 + kx_lo); - const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - const size_t pack_base0 = - (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + + // Loop over ky in host; kernel handles kx via kw_cnt (always packed) + for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { + const size_t iy2 = static_cast(iy0 + ky); + const size_t ix2 = static_cast(ix0 + kx_lo); + const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); + jit_conv3d_call_args a{}; + a.src = src_p + s_base2; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = &acc0; + a.acc2 = has_oc1 ? 
&acc1 : nullptr; + a.repeats = C / 8; + a.tail = C % 8; + a.kw_cnt = kw_count; + a.src_dx = sizeof(uint16_t); + const size_t pack_base0 = + (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + + static_cast(kx_lo)) * + m_padded_C; + a.wei = m_wei_packed.data() + pack_base0; + if (has_oc1) { + const size_t pack_base1 = + (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base1; - } - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_C * sizeof(uint16_t); - (*m_ip_kernel)(&a); - if (has_oc2) { - jit_conv3d_call_args a2{}; - a2.src = src_p + s_base2; - a2.src_stride = a.src_stride; - a2.src_blk_stride = a.src_blk_stride; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? &acc3 : nullptr; - a2.repeats = a.repeats; - a2.tail = a.tail; - a2.kw_cnt = a.kw_cnt; - a2.src_dx = a.src_dx; - const size_t pack_base2 = - (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + + a.wei2 = m_wei_packed.data() + pack_base1; + } + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + a.wei_dx = m_padded_C * sizeof(uint16_t); + (*m_ip_kernel)(&a); + if (has_oc2) { + jit_conv3d_call_args a2{}; + a2.src = src_p + s_base2; + a2.src_stride = a.src_stride; + a2.src_blk_stride = a.src_blk_stride; + a2.acc = &acc2; + a2.acc2 = has_oc3 ? &acc3 : nullptr; + a2.repeats = a.repeats; + a2.tail = a.tail; + a2.kw_cnt = a.kw_cnt; + a2.src_dx = a.src_dx; + const size_t pack_base2 = + (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + + static_cast(kx_lo)) * + m_padded_C; + a2.wei = m_wei_packed.data() + pack_base2; + if (has_oc3) { + const size_t pack_base3 = + (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; - a2.wei = m_wei_packed.data() + pack_base2; - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a2.wei2 = m_wei_packed.data() + pack_base3; - } - a2.wei_stride = a.wei_stride; - a2.wei_blk_stride = a.wei_blk_stride; - a2.wei_dx = a.wei_dx; - (*m_ip_kernel)(&a2); - } - } - } else { - // Non-packed: keep ky loop outside and issue dual calls - for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t iy2 = static_cast(iy0 + ky); - const size_t ix2 = static_cast(ix0 + kx_lo); - const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); - const size_t w0_base = index_wei(oc0, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w1_base = has_oc1 ? index_wei(oc1, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - // pair 0 - { - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? 
&acc1 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.wei = wei_p + w0_base; - if (has_oc1) - a.wei2 = wei_p + w1_base; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - (*m_ip_kernel)(&a); - } - if (has_oc2) { - const size_t w2_base = index_wei(oc2, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w3_base = has_oc3 ? index_wei(oc3, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.wei = wei_p + w2_base; - if (has_oc3) - a.wei2 = wei_p + w3_base; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - (*m_ip_kernel)(&a); + a2.wei2 = m_wei_packed.data() + pack_base3; } + a2.wei_stride = a.wei_stride; + a2.wei_blk_stride = a.wei_blk_stride; + a2.wei_dx = a.wei_dx; + (*m_ip_kernel)(&a2); } } } @@ -1935,8 +1877,7 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { static_cast(iz), static_cast(iy), static_cast(ix)); - const size_t w0_base = index_wei(oc0, 0, kz, ky, kx); - const size_t w1_base = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; + // raw base indices removed in product mode // pair 0 { jit_conv3d_call_args a{}; @@ -1948,35 +1889,23 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { a.acc = &acc0; a.acc2 = has_oc1 ? &acc1 : nullptr; - if (m_wei_packed_ready) { - // packed index: ((((oc*KD + kz)*KH + ky)*KW + kx)*paddedC) - const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base0; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = sizeof(uint16_t); // contiguous halves - a.wei_blk_stride = a.wei_stride * 8; // logical - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base1; - } - (*m_ip_kernel)(&a); - } else { - a.wei = wei_p + w0_base; - if (has_oc1) - a.wei2 = wei_p + w1_base; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - (*m_ip_kernel)(&a); + // packed index: ((((oc*KD + kz)*KH + ky)*KW + kx)*paddedC) + const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei = m_wei_packed.data() + pack_base0; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = sizeof(uint16_t); // contiguous halves + a.wei_blk_stride = a.wei_stride * 8; // logical + if (has_oc1) { + const size_t pack_base1 = + (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base1; } + (*m_ip_kernel)(&a); } // pair 1 if (has_oc2) { - const size_t w2_base = index_wei(oc2, 0, kz, ky, kx); - const size_t w3_base = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; + // raw base indices removed in product mode jit_conv3d_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -1985,84 +1914,24 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { 8; // used logically, kernel advances by stride once after 8 lanes a.acc = &acc2; a.acc2 = has_oc3 ? 
&acc3 : nullptr; - if (m_wei_packed_ready) { - const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base2; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = sizeof(uint16_t); // contiguous halves - a.wei_blk_stride = a.wei_stride * 8; // logical - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base3; - } - (*m_ip_kernel)(&a); - } else { - a.wei = wei_p + w2_base; - if (has_oc3) - a.wei2 = wei_p + w3_base; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = wei_c_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - (*m_ip_kernel)(&a); + const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei = m_wei_packed.data() + pack_base2; + a.repeats = C / 8; + a.tail = C % 8; + a.wei_stride = sizeof(uint16_t); + a.wei_blk_stride = a.wei_stride * 8; + if (has_oc3) { + const size_t pack_base3 = + (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + a.wei2 = m_wei_packed.data() + pack_base3; } + (*m_ip_kernel)(&a); } } } } } - // Optional fused bias (disabled by default) - if (m_apply_post_ops && m_attrs.withBias && memory.count(ARG_BIAS) && memory.at(ARG_BIAS)) { - auto bia = memory.at(ARG_BIAS); - const auto bprec = bia->getDescPtr()->getPrecision(); - if (bprec == ov::element::f32) { - const auto* b = reinterpret_cast(bia->getData()); - acc0 += b[oc0]; - if (has_oc1) - acc1 += b[oc1]; - if (has_oc2) - acc2 += b[oc2]; - if (has_oc3) - acc3 += b[oc3]; - } else if (bprec == ov::element::f16) { - const auto* b = reinterpret_cast(bia->getData()); - acc0 += static_cast(ov::float16(b[oc0])); - if (has_oc1) - acc1 += static_cast(ov::float16(b[oc1])); - if (has_oc2) - acc2 += static_cast(ov::float16(b[oc2])); - if (has_oc3) - acc3 += static_cast(ov::float16(b[oc3])); - } - } - - // Optional fused PReLU (apply after bias) — disabled by default - if (m_apply_post_ops && m_has_prelu && !m_prelu_slopes.empty()) { - const auto slope_at = [&](size_t oc) -> float { - return m_prelu_slopes.size() == 1 ? m_prelu_slopes[0] - : m_prelu_slopes[std::min(oc, m_prelu_slopes.size() - 1)]; - }; - const float s0 = slope_at(oc0); - if (acc0 < 0.0F) - acc0 *= s0; - if (has_oc1) { - const float s1 = slope_at(oc1); - if (acc1 < 0.0F) - acc1 *= s1; - } - if (has_oc2) { - const float s2 = slope_at(oc2); - if (acc2 < 0.0F) - acc2 *= s2; - } - if (has_oc3) { - const float s3 = slope_at(oc3); - if (acc3 < 0.0F) - acc3 *= s3; - } - } + // No optional post-ops in product mode dst_p[index_dst(n, oc0, od, oh, ow)] = ov::float16(acc0).to_bits(); if (has_oc1) @@ -2124,6 +1993,7 @@ void JitConv3DExecutor::ensure_weights_packed_f32(const MemoryArgs& memory) { } } m_wei_packed_ready_f32 = true; + // no global cache store } void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { @@ -2353,10 +2223,8 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { static_cast(ix)); // pair 0 { - const size_t w0 = index_wei(oc0, 0, kz, ky, kx); - const size_t w1 = has_oc1 ? 
index_wei(oc1, 0, kz, ky, kx) : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; a.src_stride = src_c_stride_elems * sizeof(float); a.src_blk_stride = a.src_stride * 4; a.acc = &acc0; @@ -2365,20 +2233,21 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = 1; a.src_dx = 0; - a.wei = wei_p + w0; - if (has_oc1) - a.wei2 = wei_p + w1; - a.wei_stride = wei_c_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel_f32)(&a); + const size_t base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base0; + if (has_oc1) { + const size_t base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base1; } - // pair 1 - if (has_oc2) { - const size_t w2 = index_wei(oc2, 0, kz, ky, kx); - const size_t w3 = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + } + // pair 1 + if (has_oc2) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; a.src_stride = src_c_stride_elems * sizeof(float); a.src_blk_stride = a.src_stride * 4; a.acc = &acc2; @@ -2387,14 +2256,17 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.tail = C % 4; a.kw_cnt = 1; a.src_dx = 0; - a.wei = wei_p + w2; - if (has_oc3) - a.wei2 = wei_p + w3; - a.wei_stride = wei_c_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel_f32)(&a); + const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + a.wei = m_wei_packed_f32.data() + base2; + if (has_oc3) { + const size_t base3 = (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + a.wei2 = m_wei_packed_f32.data() + base3; } + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + } } } } @@ -2455,4 +2327,5 @@ void ov::intel_cpu::JitConv3DExecutor::ensure_weights_packed(const MemoryArgs& m } } m_wei_packed_ready = true; + // no global cache store } diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp index 0e4b7189850476..fd39ed1423477c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -92,6 +92,9 @@ class JitConv3DExecutor : public Executor { static bool supports(const ConvConfig& cfg); + // Early weight preparation to reduce first-inference latency (product mode, no flags) + void prepare_weights_early(const MemoryArgs& memory); + private: // Simple reference fallback (parallelized) using FP16 data; correctness-first void run_naive_fp16(const MemoryArgs& memory); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 6533b14f7ca907..6c70a039d66138 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -273,75 +273,6 @@ void JitDeconv3DExecutor::ensure_weights_packed_s2_f16(const std::vector& src) { - if (m_wei_packed_s2_ready_f32) - return; - const auto& weiDims = src[1]->getStaticDims(); - const auto* wsrc = reinterpret_cast(src[1]->getData()); 
- if (weiDims.size() == 5) { - const size_t IC = weiDims[0]; - const size_t OC = weiDims[1]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; - m_padded_IC_f32 = (IC + 3) / 4 * 4; - const size_t total = OC * KD * KH * KW * m_padded_IC_f32; - m_wei_packed_s2_f32.assign(total, 0.0F); - auto idx_src = [&](size_t ic, size_t oc, size_t kz, size_t ky, size_t kx) { - return ((((ic)*OC + oc) * KD + kz) * KH + ky) * KW + kx; - }; - for (size_t oc = 0; oc < OC; ++oc) { - for (size_t kz = 0; kz < KD; ++kz) { - for (size_t ky = 0; ky < KH; ++ky) { - size_t pos = 0; - for (size_t kx = 0; kx < KW; kx += 2, ++pos) { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; - for (size_t ic = 0; ic < IC; ++ic) - m_wei_packed_s2_f32[base + (ic / 4) * 4 + (ic % 4)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; - } - for (size_t kx = 1; kx < KW; kx += 2, ++pos) { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; - for (size_t ic = 0; ic < IC; ++ic) - m_wei_packed_s2_f32[base + (ic / 4) * 4 + (ic % 4)] = wsrc[idx_src(ic, oc, kz, ky, kx)]; - } - } - } - } - m_wei_packed_s2_ready_f32 = true; - } else if (weiDims.size() == 6) { - const size_t G = weiDims[0]; - const size_t ICg = weiDims[1]; - const size_t OCg = weiDims[2]; - const size_t KD = weiDims[3], KH = weiDims[4], KW = weiDims[5]; - const size_t OC_total = G * OCg; - m_padded_IC_f32 = (ICg + 3) / 4 * 4; - const size_t total = OC_total * KD * KH * KW * m_padded_IC_f32; - m_wei_packed_s2_f32.assign(total, 0.0F); - auto idx_src_g = [&](size_t g, size_t icg, size_t ocg, size_t kz, size_t ky, size_t kx) { - return ((((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW) + kx); - }; - for (size_t g = 0; g < G; ++g) { - for (size_t ocg = 0; ocg < OCg; ++ocg) { - const size_t oc_global = g * OCg + ocg; - for (size_t kz = 0; kz < KD; ++kz) { - for (size_t ky = 0; ky < KH; ++ky) { - size_t pos = 0; - for (size_t kx = 0; kx < KW; kx += 2, ++pos) { - const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; - for (size_t icg = 0; icg < ICg; ++icg) - m_wei_packed_s2_f32[base + (icg / 4) * 4 + (icg % 4)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; - } - for (size_t kx = 1; kx < KW; kx += 2, ++pos) { - const size_t base = (((oc_global * KD + kz) * KH + ky) * KW + pos) * m_padded_IC_f32; - for (size_t icg = 0; icg < ICg; ++icg) - m_wei_packed_s2_f32[base + (icg / 4) * 4 + (icg % 4)] = wsrc[idx_src_g(g, icg, ocg, kz, ky, kx)]; - } - } - } - } - } - m_wei_packed_s2_ready_f32 = true; - } -} void JitDeconv3DExecutor::exec(const std::vector& src, const std::vector& dst, @@ -356,6 +287,12 @@ void JitDeconv3DExecutor::exec(const std::vector& src, // (no additional helpers) void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { + if (src.size() < 2 || !src[0] || !src[1] || !src[0]->getDescPtr() || !src[1]->getDescPtr()) + return; + const auto& s = src[0]->getDescPtr()->getShape(); + const auto& w = src[1]->getDescPtr()->getShape(); + if (!s.isStatic() || !w.isStatic()) + return; if (m_is_fp32) { ensure_weights_packed_f32(src); } else { diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index 4d6dc361c406b3..83ad32e00b99c0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -16,6 +16,25 @@ namespace ov::intel_cpu { class JitDeconv3DExecutor : public DeconvExecutor { 
public: explicit JitDeconv3DExecutor(ExecutorContext::CPtr context) : DeconvExecutor(std::move(context)) {} + // Constructor with early weights preparation (product mode): + // expects src[0]=input, src[1]=weights; guards dynamic shapes internally + JitDeconv3DExecutor(const std::vector& src, ExecutorContext::CPtr context) + : DeconvExecutor(std::move(context)) { + // Derive precision from src[0] if available + if (!src.empty() && src[0] && src[0]->getDescPtr()) { + const auto prec = src[0]->getDescPtr()->getPrecision(); + m_is_fp32 = (prec == ov::element::f32); + } + if (m_is_fp32) { + m_ip_kernel_f32 = std::make_unique(); + m_ip_kernel_f32->create_ker(); + } else { + m_ip_kernel_f16 = std::make_unique(); + m_ip_kernel_f16->create_ker(); + } + // Early pack (static shapes only) + prepare_weights_early(src); + } ~JitDeconv3DExecutor() override = default; bool init(const DeconvAttrs& deconvAttrs, @@ -45,20 +64,17 @@ class JitDeconv3DExecutor : public DeconvExecutor { // packed weights std::vector m_wei_packed_f16; std::vector m_wei_packed_f32; - // alternative packing for S=2 (even/odd taps) + // alternative packing for S=2 (even/odd taps) — FP16 only std::vector m_wei_packed_s2_f16; - std::vector m_wei_packed_s2_f32; bool m_wei_packed_ready_f16{false}; bool m_wei_packed_ready_f32{false}; bool m_wei_packed_s2_ready_f16{false}; - bool m_wei_packed_s2_ready_f32{false}; size_t m_padded_IC_f16{0}; size_t m_padded_IC_f32{0}; void ensure_weights_packed_f16(const std::vector& src); void ensure_weights_packed_f32(const std::vector& src); void ensure_weights_packed_s2_f16(const std::vector& src); - void ensure_weights_packed_s2_f32(const std::vector& src); void exec_fp16(const std::vector& src, const std::vector& dst); void exec_fp32(const std::vector& src, const std::vector& dst); }; @@ -71,6 +87,11 @@ class AArch64JitDeconvExecutorBuilder : public DeconvExecutorBuilder { [[nodiscard]] DeconvExecutorPtr makeExecutor(ExecutorContext::CPtr context) const override { return std::make_shared(context); } + // Helper to create executor with early packing in constructor + [[nodiscard]] DeconvExecutorPtr makeExecutorWithMem(ExecutorContext::CPtr context, + const std::vector& src) const { + return std::make_shared(src, context); + } }; } // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp index cfa7eb0c8549d8..753c404b7b8d6f 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp @@ -15,6 +15,9 @@ #if defined(OV_CPU_WITH_ACL) # include "acl/acl_deconv.hpp" #endif +#if defined(OPENVINO_ARCH_ARM64) +# include "nodes/executors/aarch64/jit_deconv3d.hpp" +#endif namespace ov::intel_cpu { @@ -69,6 +72,47 @@ class DeconvExecutorFactory : public ExecutorFactoryLegacy { OPENVINO_THROW("DeconvExecutorFactory: Supported executor is not found"); } + // ARM64 helper: build executor and allow constructor-time early packing using input/weight memories + DeconvExecutorPtr makeExecutorWithMem(const DeconvAttrs& deconvAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr, + const std::vector& srcMemories) { + auto build = [&](const DeconvExecutorDesc* desc) -> DeconvExecutorPtr { + // If this is our AArch64 JIT builder, construct with memories to trigger early packing in ctor +#if defined(OPENVINO_ARCH_ARM64) + if (auto jitBuilder = std::dynamic_pointer_cast(desc->builder)) { + auto executor = 
jitBuilder->makeExecutorWithMem(context, srcMemories); + if (executor->init(deconvAttrs, srcDescs, dstDescs, attr)) { + return executor; + } + } +#endif + // Fallback to regular path + auto executor = desc->builder->makeExecutor(context); + if (executor->init(deconvAttrs, srcDescs, dstDescs, attr)) { + return executor; + } + DeconvExecutorPtr ptr = nullptr; + return ptr; + }; + + if (chosenDesc) { + if (auto executor = build(chosenDesc)) { + return executor; + } + } + + for (const auto& sd : supportedDescs) { + if (auto executor = build(&sd)) { + chosenDesc = &sd; + return executor; + } + } + + OPENVINO_THROW("DeconvExecutorFactory: Supported executor is not found (with memories)"); + } + private: std::vector supportedDescs; const DeconvExecutorDesc* chosenDesc = nullptr; From ca892ff5ca1ab9584e866cb0b9518d3c6196fbe3 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 18:54:09 +0200 Subject: [PATCH 12/20] Remove unused fast paths, redundant logic, and obsolete conditions in AArch64 JIT 3D Deconvolution Executor to improve maintainability, clarity, and consistency. --- .../nodes/executors/aarch64/jit_deconv3d.cpp | 255 +----------------- 1 file changed, 9 insertions(+), 246 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 6c70a039d66138..43c55df0cde0d0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -284,7 +284,6 @@ void JitDeconv3DExecutor::exec(const std::vector& src, } } -// (no additional helpers) void JitDeconv3DExecutor::prepare_weights_early(const std::vector& src) { if (src.size() < 2 || !src[0] || !src[1] || !src[0]->getDescPtr() || !src[1]->getDescPtr()) @@ -461,84 +460,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } else { - // Raw weights fast-path (removed in product mode) - if (false) { - // In-kernel ky + kx (raw weights): - const auto kw_count = static_cast(kx_hi - kx_lo + 1); - const auto kh_count = static_cast(ky_hi - ky_lo + 1); - const size_t ih0 = static_cast(tyd - ky_lo); - const size_t s_base_row2 = - n_base + (g * ICg) * src_c_stride_elems + src_z_off + ih0 * IW; - const size_t s_base0 = s_base_row2 + static_cast(txd - kx_hi); - // pair 0 - { - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.kh_cnt = kh_count; - a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); - const size_t w_base0 = idx_wei(0, - oc0, - static_cast(kz), - static_cast(ky_lo), - static_cast(kx_lo)); - const size_t w_base1 = has_oc1 ? idx_wei(0, - oc1, - static_cast(kz), - static_cast(ky_lo), - static_cast(kx_lo)) - : 0; - a.wei = wei_p + w_base0; - if (has_oc1) - a.wei2 = wei_p + w_base1; - a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - a.wei_dy = KW * sizeof(uint16_t); - (*m_ip_kernel_f16)(&a); - } - // pair 1 - if (has_oc2) { - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc2; - a.acc2 = has_oc3 ? 
&acc3 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - a.kh_cnt = kh_count; - a.src_dy = static_cast(-static_cast(IW * sizeof(uint16_t))); - const size_t w_base2 = idx_wei(0, - oc2, - static_cast(kz), - static_cast(ky_lo), - static_cast(kx_lo)); - const size_t w_base3 = has_oc3 ? idx_wei(0, - oc3, - static_cast(kz), - static_cast(ky_lo), - static_cast(kx_lo)) - : 0; - a.wei = wei_p + w_base2; - if (has_oc3) - a.wei2 = wei_p + w_base3; - a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = sizeof(uint16_t); - a.wei_dy = KW * sizeof(uint16_t); - (*m_ip_kernel_f16)(&a); - } - } else { + { // In-kernel kx only for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { const size_t ihh2 = static_cast(tyd - ky); @@ -854,8 +776,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { - if (false) { - // In-kernel kx loop for parity taps: step weights by 2*IC and src by -1 in X for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { const size_t id = static_cast((tzd - kz) / 2); if (id >= ID) continue; @@ -873,7 +793,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; - const bool use_s2_pack = true; const uint16_t* wei_pack_ptr = m_wei_packed_s2_f16.data(); auto pack_index_eo = [&](size_t py, size_t kx) { const size_t even_count = (KW + 1) / 2; @@ -902,7 +821,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st if (has_oc1) a.wei2 = wei_pack_ptr + base1; a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = (use_s2_pack ? m_padded_IC_f16 : 2 * m_padded_IC_f16) * sizeof(uint16_t); + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); (*m_ip_kernel_f16)(&a); } // pair 1 @@ -923,13 +842,13 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st if (has_oc3) a.wei2 = wei_pack_ptr + base3; a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = (use_s2_pack ? m_padded_IC_f16 : 2 * m_padded_IC_f16) * sizeof(uint16_t); + a.wei_dx = m_padded_IC_f16 * sizeof(uint16_t); (*m_ip_kernel_f16)(&a); } } } - } else { - // Per-tap parity stepping (original path) + { + // Per-tap parity stepping for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { const size_t id = static_cast((tzd - kz) / 2); if (id >= ID) continue; @@ -947,10 +866,8 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st const size_t py1 = has_oc1 ? (pz1 + static_cast(ky)) * KW : 0; const size_t py2 = has_oc2 ? (pz2 + static_cast(ky)) * KW : 0; const size_t py3 = has_oc3 ? (pz3 + static_cast(ky)) * KW : 0; - const bool use_s2_pack_orig = true; - const uint16_t* wei_pack_ptr_orig = use_s2_pack_orig ? m_wei_packed_s2_f16.data() : m_wei_packed_f16.data(); + const uint16_t* wei_pack_ptr_orig = m_wei_packed_s2_f16.data(); auto pack_index_eo_orig = [&](size_t py, size_t kx) { - if (!use_s2_pack_orig) return py + kx; const size_t even_count = (KW + 1) / 2; return py + ((kx & 1) ? 
(even_count + (kx / 2)) : (kx / 2)); }; @@ -1011,99 +928,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } } - } else if (false) { - // Fast path S=2, dil=1 (raw weights): parity-filtered taps without modulus checks - const ptrdiff_t tzd = static_cast(od) + PD0; - const ptrdiff_t tyd = static_cast(oh) + PH0; - const ptrdiff_t txd = static_cast(ow_) + PW0; - - const ptrdiff_t kz_lo = std::max(0, tzd - static_cast(ID * 2) + 2); - const ptrdiff_t kz_hi = std::min(static_cast(KD) - 1, tzd); - const ptrdiff_t ky_lo = std::max(0, tyd - static_cast(IH * 2) + 2); - const ptrdiff_t ky_hi = std::min(static_cast(KH) - 1, tyd); - const ptrdiff_t kx_lo = std::max(0, txd - static_cast(IW * 2) + 2); - const ptrdiff_t kx_hi = std::min(static_cast(KW) - 1, txd); - - if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { - for (ptrdiff_t kz = kz_lo + ((tzd - kz_lo) & 1); kz <= kz_hi; kz += 2) { - const size_t id = static_cast((tzd - kz) / 2); - if (id >= ID) continue; - const size_t src_z_off = id * IH * IW; - const size_t src_cg0 = g * ICg; - size_t s_base_row = n_base + src_cg0 * src_c_stride_elems + src_z_off; - for (ptrdiff_t ky = ky_lo + ((tyd - ky_lo) & 1); ky <= ky_hi; ky += 2) { - const size_t ih = static_cast((tyd - ky) / 2); - if (ih >= IH) continue; - for (ptrdiff_t kx = kx_lo + ((txd - kx_lo) & 1); kx <= kx_hi; kx += 2) { - const size_t iw = static_cast((txd - kx) / 2); - if (iw >= IW) continue; - const size_t s_base0 = s_base_row + ih * IW + iw; - // pair 0 (oc0/oc1) - { - const size_t w_base0 = idx_wei(0, - oc0, - static_cast(kz), - static_cast(ky), - static_cast(kx)); - const size_t w_base1 = has_oc1 ? idx_wei(0, - oc1, - static_cast(kz), - static_cast(ky), - static_cast(kx)) - : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = 1; - a.src_dx = 0; - a.wei = wei_p + w_base0; - if (has_oc1) - a.wei2 = wei_p + w_base1; - a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = 0; - (*m_ip_kernel_f16)(&a); - } - // pair 1 (oc2/oc3) - if (has_oc2) { - const size_t w_base2 = idx_wei(0, - oc2, - static_cast(kz), - static_cast(ky), - static_cast(kx)); - const size_t w_base3 = has_oc3 ? idx_wei(0, - oc3, - static_cast(kz), - static_cast(ky), - static_cast(kx)) - : 0; - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc2; - a.acc2 = has_oc3 ? 
&acc3 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = 1; - a.src_dx = 0; - a.wei = wei_p + w_base2; - if (has_oc3) - a.wei2 = wei_p + w_base3; - a.wei_stride = wei_ic_stride_elems * sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = 0; - (*m_ip_kernel_f16)(&a); - } - } - } - } - } + } else { // Generic path (stride/dilation): modulus checks for (size_t kz = 0; kz < KD; ++kz) { @@ -1142,7 +967,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st static_cast(id_idx), static_cast(ih_idx), static_cast(iw_idx)); - // raw weight indices removed in product mode jit_conv3d_call_args a{}; a.src = src_p + s_base0; @@ -1171,7 +995,6 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st // second pair for oc2/oc3 if (has_oc2) { - // raw weight indices removed in product mode jit_conv3d_call_args a2{}; a2.src = src_p + s_base0; a2.src_stride = src_c_stride_elems * sizeof(uint16_t); @@ -1288,70 +1111,12 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st const size_t wei_ic_stride_elems = (grouped ? OCg : OC) * KD * KH * KW; ensure_weights_packed_f32(src); - // Output padding and dilations - const ptrdiff_t OPD0 = deconvAttrs.outputPadding.size() > 0 ? deconvAttrs.outputPadding[0] : 0; - const ptrdiff_t OPH0 = deconvAttrs.outputPadding.size() > 1 ? deconvAttrs.outputPadding[1] : 0; - const ptrdiff_t OPW0 = deconvAttrs.outputPadding.size() > 2 ? deconvAttrs.outputPadding[2] : 0; + // Output dilations const size_t dilD = deconvAttrs.dilation.size() > 0 ? static_cast(deconvAttrs.dilation[0]) + 1 : 1; const size_t dilH = deconvAttrs.dilation.size() > 1 ? static_cast(deconvAttrs.dilation[1]) + 1 : 1; const size_t dilW = deconvAttrs.dilation.size() > 2 ? static_cast(deconvAttrs.dilation[2]) + 1 : 1; - if (false) { - const bool grouped = weiDims.size() == 6; - const size_t G = grouped ? weiDims[0] : 1; - const size_t ICg = grouped ? weiDims[1] : IC; - const size_t OCg = grouped ? 
weiDims[2] : OC; - std::fill_n(dst_p, N * OC * OD * OH * OW, 0.0F); - for (size_t n = 0; n < N; ++n) { - for (size_t g = 0; g < G; ++g) { - for (size_t id = 0; id < ID; ++id) { - for (size_t ih = 0; ih < IH; ++ih) { - for (size_t iw_ = 0; iw_ < IW; ++iw_) { - for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t od = static_cast(id) * static_cast(SD) - PD0 + - static_cast(kz * dilD) + OPD0; - if (od < 0 || od >= static_cast(OD)) - continue; - for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t oh = static_cast(ih) * static_cast(SH) - PH0 + - static_cast(ky * dilH) + OPH0; - if (oh < 0 || oh >= static_cast(OH)) - continue; - for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ow = static_cast(iw_) * static_cast(SW) - - PW0 + static_cast(kx * dilW) + OPW0; - if (ow < 0 || ow >= static_cast(OW)) - continue; - for (size_t icg = 0; icg < ICg; ++icg) { - const size_t ic_global = g * ICg + icg; - const float sval = src_p[idx_src(n, ic_global, id, ih, iw_)]; - for (size_t ocg = 0; ocg < OCg; ++ocg) { - const size_t oc_global = g * OCg + ocg; - size_t w_off; - if (grouped) { - w_off = - (((((g * ICg + icg) * OCg + ocg) * KD + kz) * KH + ky) * KW + - kx); - } else { - w_off = ((((icg)*OC + oc_global) * KD + kz) * KH + ky) * KW + kx; - } - dst_p[idx_dst(n, - oc_global, - static_cast(od), - static_cast(oh), - static_cast(ow))] += sval * wei_p[w_off]; - } - } - } - } - } - } - } - } - } - } - return; - } + ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { const size_t oc0 = oc_quad * 4; @@ -1841,7 +1606,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st static_cast(id_idx), static_cast(ih_idx), static_cast(iw_idx)); - // raw weight indices removed // pair 0 { @@ -1872,7 +1636,6 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } // pair 1 if (has_oc2) { - // raw weight indices removed jit_conv3d_f32_call_args a2{}; a2.src = src_p + s_base0; a2.src_stride = src_c_stride_elems * sizeof(float); From 3653f2979b3ec7e8a39bbb90b3cbb3df8496aaa7 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 19:07:36 +0200 Subject: [PATCH 13/20] Simplify and clean up AArch64 JIT 3D Convolution and Deconvolution Executors by removing redundant comments, obsolete logic, and unused code to improve clarity and maintainability. 
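With the raw-weights branches gone, every weight access goes through the dense packed layout [OC][KD][KH][KW][paddedC] (the S=2 even/odd variant additionally reorders kx so that parity-stepped taps become contiguous). As a reading aid, the offset arithmetic that the sources spell out inline is equivalent to the following hypothetical helper; padded_C is the channel count rounded up to the SIMD lane width, 8 for FP16 and 4 for FP32:

    // Offset of the first padded channel of tap (oc, kz, ky, kx).
    // Channels of one tap are contiguous and zero-padded, so the kernel
    // walks them with wei_stride == sizeof(element) and steps between
    // adjacent taps with wei_dx == padded_C * sizeof(element).
    static inline size_t packed_wei_off(size_t oc, size_t kz, size_t ky,
                                        size_t kx, size_t KD, size_t KH,
                                        size_t KW, size_t padded_C) {
        return (((oc * KD + kz) * KH + ky) * KW + kx) * padded_C;
    }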
--- .../nodes/executors/aarch64/jit_conv3d.cpp | 23 ------- .../nodes/executors/aarch64/jit_conv3d.hpp | 18 +---- .../executors/aarch64/jit_conv3d_f32.cpp | 65 +++++-------------- .../executors/aarch64/jit_conv3d_f32.hpp | 30 ++++----- .../nodes/executors/aarch64/jit_deconv3d.hpp | 9 +-- 5 files changed, 34 insertions(+), 111 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index f341f3dcb9e369..a1db952e46f752 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -20,15 +20,12 @@ #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/core/type/float16.hpp" -// helper for jit_kernel_cast #include "utils/cpu_utils.hpp" -// no direct NEON intrinsics are used here; we rely on Xbyak_aarch64 using namespace dnnl::impl::cpu::aarch64; namespace ov::intel_cpu { -// --------------------------- JIT kernel (placeholder) --------------------------- JitConv3DKernelF16::JitConv3DKernelF16() = default; void JitConv3DKernelF16::create_ker() { @@ -38,9 +35,7 @@ void JitConv3DKernelF16::create_ker() { void JitConv3DKernelF16::gen_minimal_kernel() { using namespace Xbyak_aarch64; - // Minimal stable kernel (dual-OC, in-kernel kx loop) const XReg reg_args = abi_param1; // x0 - // Load essential arguments (absolute offsets from jit_conv3d_call_args) const XReg reg_src = x1; // const uint16_t* src const XReg reg_wei = x2; // const uint16_t* wei const XReg reg_wei2 = x3; // const uint16_t* wei2 (optional) @@ -62,18 +57,14 @@ void JitConv3DKernelF16::gen_minimal_kernel() { ldr(reg_acc2, ptr(reg_args, 96)); Label Lsingle, Ldone; - // Additional labels for kx-loop variants Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; cbz(reg_acc2, Lsingle); - // Jump to in-kernel kx loop dual-OC path (safe, call-clobbered only) b(Ldual_kx); - // Dual-OC with in-kernel kx loop (v20 for oc0, v21 for oc1) L(Ldual_kx); eor(VReg16B(20), VReg16B(20), VReg16B(20)); eor(VReg16B(21), VReg16B(21), VReg16B(21)); - // Load kx-loop controls and set bases const XReg reg_kw_cnt = x12; const XReg reg_src_dx = x13; const XReg reg_wei_dx = x14; @@ -88,21 +79,17 @@ void JitConv3DKernelF16::gen_minimal_kernel() { mov(q_src_base, reg_src); mov(q_wei_base, reg_wei); mov(q_wei2_base, reg_wei2); - // Treat kw_cnt==0 as 1 cbnz(reg_kw_cnt, Lkx_d); mov(reg_kw_cnt, 1); L(Lkx_d); - // Reset pointers and repeats for this kx ldr(reg_reps, ptr(reg_args, 40)); mov(reg_src, q_src_base); mov(reg_wei, q_wei_base); mov(reg_wei2, q_wei2_base); - // Channel repeats over 8-lane blocks Label Lrep_d_kx; L(Lrep_d_kx); cmp(reg_reps, 0); b(EQ, Ltail_prep_d_kx); - // Load src lanes into v0 ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[1], ptr(reg_src)); @@ -159,14 +146,12 @@ void JitConv3DKernelF16::gen_minimal_kernel() { ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); L(Lw_done_d2); - // MAC into accumulators fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); sub(reg_reps, reg_reps, 1); b(Lrep_d_kx); - // Tail handling per kx L(Ltail_prep_d_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); @@ -237,7 +222,6 @@ void 
JitConv3DKernelF16::gen_minimal_kernel() { fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); fmlal2(VReg4S(21), VReg4H(0), VReg4H(2)); - // advance bases to next kx and continue sub(reg_kw_cnt, reg_kw_cnt, 1); add(q_src_base, q_src_base, reg_src_dx); add(q_wei_base, q_wei_base, reg_wei_dx); @@ -264,7 +248,6 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Lrep_d); cmp(reg_reps, 0); b(EQ, Ltail_prep_d); - // Load src lanes (v0) ld1(VReg(0).h[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[1], ptr(reg_src)); @@ -329,7 +312,6 @@ void JitConv3DKernelF16::gen_minimal_kernel() { sub(reg_reps, reg_reps, 1); b(Lrep_d); - // Tail handling L(Ltail_prep_d); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); @@ -417,7 +399,6 @@ void JitConv3DKernelF16::gen_minimal_kernel() { // Single-OC path L(Lsingle); - // Jump to in-kernel kx loop single-OC path b(Lsingle_kx); // Single-OC with in-kernel kx loop L(Lsingle_kx); @@ -539,7 +520,6 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Ltail_done_s_kx); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); - // advance bases sub(s_kw_cnt, s_kw_cnt, 1); add(s_src_base, s_src_base, s_src_dx); add(s_wei_base, s_wei_base, s_wei_dx); @@ -928,7 +908,6 @@ void JitConv3DKernelF16::gen_optimized_kernel() { fmlal(VReg4S(23), VReg4H(0), VReg4H(4)); fmlal2(VReg4S(23), VReg4H(0), VReg4H(4)); - // advance bases for next kx add(q_src_base, q_src_base, reg_src_dx); add(q_wei_base, q_wei_base, reg_wei_dx); add(q_wei2_base, q_wei2_base, reg_wei_dx); @@ -946,7 +925,6 @@ void JitConv3DKernelF16::gen_optimized_kernel() { faddp(VReg2S(22), VReg2S(22), VReg2S(22)); faddp(VReg4S(23), VReg4S(23), VReg4S(23)); faddp(VReg2S(23), VReg2S(23), VReg2S(23)); - // advance bases for next ky and continue if any add(q_src_base, q_src_base, reg_src_dy); add(q_wei_base, q_wei_base, reg_wei_dy); add(q_wei2_base, q_wei2_base, reg_wei_dy); @@ -1624,7 +1602,6 @@ void JitConv3DKernelF16::generate() { gen_optimized_kernel(); } -// --------------------------- Executor --------------------------- [[maybe_unused]] static inline auto ptr_f16(const MemoryPtr& mem) -> const uint16_t* { return reinterpret_cast(mem->getData()); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp index fd39ed1423477c..d3f94bac4984e1 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -13,10 +13,7 @@ #include "nodes/executors/executor.hpp" #include "nodes/executors/memory_arguments.hpp" #include "onednn/iml_type_mapper.h" - -// Xbyak AArch64 JIT #include -// FP32 kernel #include "nodes/executors/aarch64/jit_conv3d_f32.hpp" namespace ov::intel_cpu { @@ -60,7 +57,6 @@ class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { private: void generate() override; - // Split large codegen into smaller helpers to satisfy clang-tidy limits void gen_minimal_kernel(); void gen_optimized_kernel(); @@ -70,7 +66,6 @@ class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { void set_force_single_kh(bool v) { m_force_single_kh_ = v; } }; -// AArch64 JIT Convolution (FP16) executor for 3D conv (NCDHW) class JitConv3DExecutor : public Executor { public: JitConv3DExecutor(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); @@ -92,20 +87,15 @@ class 
JitConv3DExecutor : public Executor { static bool supports(const ConvConfig& cfg); - // Early weight preparation to reduce first-inference latency (product mode, no flags) void prepare_weights_early(const MemoryArgs& memory); private: - // Simple reference fallback (parallelized) using FP16 data; correctness-first void run_naive_fp16(const MemoryArgs& memory); void ensure_weights_packed(const MemoryArgs& memory); - // FP32 path void run_naive_fp32(const MemoryArgs& memory); void ensure_weights_packed_f32(const MemoryArgs& memory); - // Minimal inner-product kernel (fp16 x fp16 -> f32 accumulation) std::unique_ptr m_ip_kernel; - // Minimal inner-product kernel (fp32 x fp32 -> f32 accumulation) std::unique_ptr m_ip_kernel_f32; ConvAttrs m_attrs; @@ -113,21 +103,15 @@ class JitConv3DExecutor : public Executor { size_t m_threadsNum{0}; bool m_is_fp32{false}; - // Packed weights: layout [OC, KD, KH, KW, Ct] where Ct is 8-lane channel tiles std::vector m_wei_packed; bool m_wei_packed_ready{false}; size_t m_padded_C{0}; - // FP32 packed weights: [OC, KD, KH, KW, Ct=4] std::vector m_wei_packed_f32; bool m_wei_packed_ready_f32{false}; size_t m_padded_C_f32{0}; - // Optional fused PReLU (per-tensor or per-channel). Extracted from attrs.postOps. bool m_has_prelu{false}; - std::vector m_prelu_slopes; // size 1 (per-tensor) or OC (per-channel) - - // Gate executor-side post-ops (bias, PReLU). Disabled per user request for measurements. - bool m_apply_post_ops{false}; + std::vector m_prelu_slopes; }; using JitConv3DExecutorPtr = std::shared_ptr; diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp index 8b58711cc34cae..b40711225e333a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -22,14 +22,12 @@ #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" #include "utils/general_utils.h" -// helper for jit_kernel_cast #include "utils/cpu_utils.hpp" using namespace dnnl::impl::cpu::aarch64; namespace ov::intel_cpu { -// --------------------------- JIT kernel (FP32) --------------------------- void JitConv3DKernelF32::create_ker() { jit_generator::create_kernel(); ker_ = jit_kernel_cast(jit_ker()); @@ -37,24 +35,23 @@ void JitConv3DKernelF32::create_ker() { void JitConv3DKernelF32::generate() { using namespace Xbyak_aarch64; - const XReg reg_args = abi_param1; // x0 - - const XReg reg_src = x1; // const float* src - const XReg reg_wei = x2; // const float* wei - const XReg reg_wei2 = x3; // const float* wei2 (optional) - const XReg reg_reps = x4; // size_t repeats (C/4) - const XReg reg_tail = x5; // size_t tail (C%4) - const XReg reg_src_stride = x6; // bytes between channels - const XReg reg_wei_stride = x7; // bytes between channels - const XReg reg_src_blk_stride = x8; // bytes between successive 4-ch blocks - const XReg reg_wei_blk_stride = x9; // bytes between successive 4-ch blocks - const XReg reg_acc = x10; // float* acc - const XReg reg_acc2 = x11; // float* acc2 (optional) - const XReg reg_kw_cnt = x12; // taps along W - const XReg reg_src_dx = x13; // bytes to step src base per kx - const XReg reg_wei_dx = x14; // bytes to step wei base per kx - - // Load args by struct offsets (see jit_conv3d_f32_call_args) + const XReg reg_args = abi_param1; + + const XReg reg_src = x1; + const XReg reg_wei = x2; + const XReg reg_wei2 = x3; + const XReg reg_reps = x4; + 
const XReg reg_tail = x5; + const XReg reg_src_stride = x6; + const XReg reg_wei_stride = x7; + const XReg reg_src_blk_stride = x8; + const XReg reg_wei_blk_stride = x9; + const XReg reg_acc = x10; + const XReg reg_acc2 = x11; + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + ldr(reg_src, ptr(reg_args, 0)); ldr(reg_wei, ptr(reg_args, 8)); ldr(reg_wei2, ptr(reg_args, 16)); @@ -70,10 +67,9 @@ void JitConv3DKernelF32::generate() { ldr(reg_src_dx, ptr(reg_args, 96)); ldr(reg_wei_dx, ptr(reg_args, 104)); - // Work registers for base pointers per kx const XReg q_src_base = x15; const XReg q_wei_base = x16; - const XReg q_wei2_base = x17; // avoid x18 + const XReg q_wei2_base = x17; Label Lsingle, Ldone; Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; @@ -82,34 +78,26 @@ void JitConv3DKernelF32::generate() { cbz(reg_acc2, Lsingle); b(Ldual_kx); - // ---------------- Dual-OC with in-kernel kx loop ---------------- L(Ldual_kx); - // accumulators v20 (oc0), v21 (oc1) eor(VReg16B(20), VReg16B(20), VReg16B(20)); eor(VReg16B(21), VReg16B(21), VReg16B(21)); - // Save bases mov(q_src_base, reg_src); mov(q_wei_base, reg_wei); mov(q_wei2_base, reg_wei2); - // Treat kw_cnt==0 as 1 cbnz(reg_kw_cnt, Lkx_d); mov(reg_kw_cnt, 1); - // kx loop L(Lkx_d); - // Reset per-kx pointers and repeats ldr(reg_reps, ptr(reg_args, 24)); mov(reg_src, q_src_base); mov(reg_wei, q_wei_base); mov(reg_wei2, q_wei2_base); - // repeats loop over channel tiles of 4 Label Lrep_d; L(Lrep_d); cmp(reg_reps, 0); b(EQ, Ltail_prep_d_kx); - // src lanes -> v0.s[0..3] ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[1], ptr(reg_src)); @@ -117,7 +105,6 @@ void JitConv3DKernelF32::generate() { ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[3], ptr(reg_src)); - // wei lanes: vector fast path if stride==4 bytes Label Lw_np_d, Lw_done_d; cmp(reg_wei_stride, 4); b(NE, Lw_np_d); @@ -141,19 +128,15 @@ void JitConv3DKernelF32::generate() { add(reg_wei2, reg_wei2, reg_wei_stride); ld1(VReg(1).s[3], ptr(reg_wei)); ld1(VReg(2).s[3], ptr(reg_wei2)); - // advance to next 4-channel block for next repeat add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); L(Lw_done_d); - // advance src to next 4-channel block for next repeat add(reg_src, reg_src, reg_src_stride); - // MAC fmla(VReg4S(20), VReg4S(0), VReg4S(1)); fmla(VReg4S(21), VReg4S(0), VReg4S(2)); sub(reg_reps, reg_reps, 1); b(Lrep_d); - // Tail per kx L(Ltail_prep_d_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); @@ -190,13 +173,11 @@ void JitConv3DKernelF32::generate() { L(Ltail_done_d_kx); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); fmla(VReg4S(21), VReg4S(0), VReg4S(2)); - // advance bases to next kx sub(reg_kw_cnt, reg_kw_cnt, 1); add(q_src_base, q_src_base, reg_src_dx); add(q_wei_base, q_wei_base, reg_wei_dx); add(q_wei2_base, q_wei2_base, reg_wei_dx); cbnz(reg_kw_cnt, Lkx_d); - // horizontal add and store faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); faddp(VReg4S(21), VReg4S(21), VReg4S(21)); @@ -209,7 +190,6 @@ void JitConv3DKernelF32::generate() { str(SReg(1), ptr(reg_acc2)); b(Ldone); - // ---------------- Single-OC with in-kernel kx loop ---------------- L(Lsingle); eor(VReg16B(20), VReg16B(20), VReg16B(20)); mov(q_src_base, reg_src); @@ -218,7 +198,6 @@ void JitConv3DKernelF32::generate() { mov(reg_kw_cnt, 1); L(Lsingle_kx); - // Reset per-kx pointers and repeats 
ldr(reg_reps, ptr(reg_args, 24)); mov(reg_src, q_src_base); mov(reg_wei, q_wei_base); @@ -227,7 +206,6 @@ void JitConv3DKernelF32::generate() { L(Lrep_s); cmp(reg_reps, 0); b(EQ, Ltail_prep_s_kx); - // src lanes ld1(VReg(0).s[0], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[1], ptr(reg_src)); @@ -235,7 +213,6 @@ void JitConv3DKernelF32::generate() { ld1(VReg(0).s[2], ptr(reg_src)); add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).s[3], ptr(reg_src)); - // wei lanes: vector fast path if stride==4 Label Lw_np_s, Lw_done_s; cmp(reg_wei_stride, 4); b(NE, Lw_np_s); @@ -250,16 +227,13 @@ void JitConv3DKernelF32::generate() { ld1(VReg(1).s[2], ptr(reg_wei)); add(reg_wei, reg_wei, reg_wei_stride); ld1(VReg(1).s[3], ptr(reg_wei)); - // advance to next 4-channel block for next repeat add(reg_wei, reg_wei, reg_wei_stride); L(Lw_done_s); - // advance src to next 4-channel block for next repeat add(reg_src, reg_src, reg_src_stride); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); sub(reg_reps, reg_reps, 1); b(Lrep_s); - // Tail single L(Ltail_prep_s_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); @@ -288,13 +262,11 @@ void JitConv3DKernelF32::generate() { L(Ltail_done_s_kx); fmla(VReg4S(20), VReg4S(0), VReg4S(1)); - // advance to next kx sub(reg_kw_cnt, reg_kw_cnt, 1); add(q_src_base, q_src_base, reg_src_dx); add(q_wei_base, q_wei_base, reg_wei_dx); cbnz(reg_kw_cnt, Lsingle_kx); - // reduce and store faddp(VReg4S(20), VReg4S(20), VReg4S(20)); faddp(VReg2S(20), VReg2S(20), VReg2S(20)); ldr(SReg(0), ptr(reg_acc)); @@ -306,7 +278,6 @@ void JitConv3DKernelF32::generate() { ret(); } -// --------------------------- Executor (FP32) --------------------------- JitConv3DExecutorF32::JitConv3DExecutorF32(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& /*context*/) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp index 88f2067a8f0163..935a1741bf12c6 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp @@ -13,26 +13,25 @@ #include "nodes/executors/executor.hpp" #include "nodes/executors/memory_arguments.hpp" -// Xbyak AArch64 JIT #include namespace ov::intel_cpu { struct jit_conv3d_f32_call_args { - const float* src; // f32 base ptr - const float* wei; // f32 base ptr (oc0) - const float* wei2; // optional second oc f32 base ptr (can be null) - size_t repeats; // number of full 4-channel blocks - size_t tail; // remaining channels (< 4) - size_t src_stride; // stride between channels in bytes - size_t wei_stride; // stride between channels in bytes - size_t src_blk_stride; // stride between successive 4-channel blocks in bytes - size_t wei_blk_stride; // stride between successive 4-channel blocks in bytes - float* acc; // f32 accumulator - float* acc2; // optional second f32 accumulator (can be null) - size_t kw_cnt; // number of taps along W to iterate (stride=1 fast path); 0 or 1 -> single - size_t src_dx; // bytes to advance src base between successive kx taps - size_t wei_dx; // bytes to advance weights base between successive kx taps + const float* src; + const float* wei; + const float* wei2; + size_t repeats; + size_t tail; + size_t src_stride; + size_t wei_stride; + size_t src_blk_stride; + size_t wei_blk_stride; + float* acc; + float* acc2; + size_t kw_cnt; + size_t src_dx; + size_t wei_dx; }; class 
JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { @@ -53,7 +52,6 @@ class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { jit_fn ker_{nullptr}; }; -// AArch64 JIT Convolution (FP32) executor for 3D conv (NCDHW) class JitConv3DExecutorF32 : public Executor { public: JitConv3DExecutorF32(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index 83ad32e00b99c0..9c3e4c4e268f7c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -16,11 +16,9 @@ namespace ov::intel_cpu { class JitDeconv3DExecutor : public DeconvExecutor { public: explicit JitDeconv3DExecutor(ExecutorContext::CPtr context) : DeconvExecutor(std::move(context)) {} - // Constructor with early weights preparation (product mode): - // expects src[0]=input, src[1]=weights; guards dynamic shapes internally + // Early weight preparation in ctor (src[0]=input, src[1]=weights; guards dynamic shapes) JitDeconv3DExecutor(const std::vector& src, ExecutorContext::CPtr context) : DeconvExecutor(std::move(context)) { - // Derive precision from src[0] if available if (!src.empty() && src[0] && src[0]->getDescPtr()) { const auto prec = src[0]->getDescPtr()->getPrecision(); m_is_fp32 = (prec == ov::element::f32); @@ -32,7 +30,6 @@ class JitDeconv3DExecutor : public DeconvExecutor { m_ip_kernel_f16 = std::make_unique(); m_ip_kernel_f16->create_ker(); } - // Early pack (static shapes only) prepare_weights_early(src); } ~JitDeconv3DExecutor() override = default; @@ -50,21 +47,17 @@ class JitDeconv3DExecutor : public DeconvExecutor { return impl_desc_type::jit_asimd; } - // Early weight preparation to avoid first-inference overhead void prepare_weights_early(const std::vector& src); private: std::vector m_srcDescs; std::vector m_dstDescs; - // kernels std::unique_ptr m_ip_kernel_f16; std::unique_ptr m_ip_kernel_f32; bool m_is_fp32{false}; - // packed weights std::vector m_wei_packed_f16; std::vector m_wei_packed_f32; - // alternative packing for S=2 (even/odd taps) — FP16 only std::vector m_wei_packed_s2_f16; bool m_wei_packed_ready_f16{false}; bool m_wei_packed_ready_f32{false}; From f822a7ba21330be8aa0cb43fe0f25e5f5f9bea9b Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 19:16:36 +0200 Subject: [PATCH 14/20] Refactor AArch64 JIT 3D Convolution and Deconvolution Executors by consolidating repetitive kernel invocation logic into reusable lambda functions, improving code clarity, maintainability, and reducing duplication. 
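
For reviewers, the consolidated shape in scalar form: a minimal sketch, assuming a simplified call-args struct and a stub in place of the generated NEON kernel (accumulate_four_oc and invoke_kernel are illustrative names, not symbols from this patch). One run_pair lambda captures the fields shared by both output-channel pairs and varies only the accumulators and the packed-weight base offsets:

    #include <cstddef>
    #include <cstdint>

    // Simplified stand-in for jit_conv3d_call_args (FP16 path, 8-lane tiles).
    struct call_args {
        const uint16_t* src;
        const uint16_t* wei;
        const uint16_t* wei2;  // weights for an optional second output channel
        size_t repeats;        // full 8-channel blocks
        size_t tail;           // remaining channels (< 8)
        float* acc;
        float* acc2;           // optional second accumulator
    };

    // Stub for (*m_ip_kernel)(&a); the real call runs generated NEON code.
    static void invoke_kernel(const call_args&) {}

    // One lambda replaces the two copy-pasted "pair 0" / "pair 1" blocks.
    void accumulate_four_oc(const uint16_t* src, const uint16_t* wei_packed, size_t C,
                            size_t base0, size_t base1, size_t base2, size_t base3,
                            float* acc0, float* acc1, float* acc2, float* acc3) {
        auto run_pair = [&](float* acc_a, float* acc_b, size_t b0, size_t b1) {
            call_args a{};
            a.src = src;
            a.repeats = C / 8;
            a.tail = C % 8;
            a.acc = acc_a;
            a.acc2 = acc_b;  // nullptr routes the kernel to its single-OC path
            a.wei = wei_packed + b0;
            if (acc_b) a.wei2 = wei_packed + b1;
            invoke_kernel(a);
        };
        run_pair(acc0, acc1, base0, base1);            // oc0 / oc1
        if (acc2) run_pair(acc2, acc3, base2, base3);  // oc2 / oc3
    }

Because a null acc2 selects the kernel's single-OC path, odd output-channel counts need no separate call sites.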
--- .../nodes/executors/aarch64/jit_conv3d.cpp | 88 ++++-------- .../nodes/executors/aarch64/jit_deconv3d.cpp | 135 ++++++------------ 2 files changed, 71 insertions(+), 152 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index a1db952e46f752..2d02dc7452163b 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -2146,33 +2146,13 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { a.wei_dx = sizeof(float); (*m_ip_kernel_f32)(&a); if (has_oc2) { - const size_t w2 = index_wei(oc2, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w3 = has_oc3 ? index_wei(oc3, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - jit_conv3d_f32_call_args a2{}; - a2.src = src_p + s_base; - a2.src_stride = a.src_stride; - a2.src_blk_stride = a.src_blk_stride; + const size_t w2 = index_wei(oc2, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)); + const size_t w3 = has_oc3 ? index_wei(oc3, 0, static_cast(kz), static_cast(ky), static_cast(kx_lo)) : 0; + jit_conv3d_f32_call_args a2{a}; a2.acc = &acc2; a2.acc2 = has_oc3 ? &acc3 : nullptr; - a2.repeats = a.repeats; - a2.tail = a.tail; - a2.kw_cnt = a.kw_cnt; - a2.src_dx = a.src_dx; a2.wei = wei_p + w2; - if (has_oc3) - a2.wei2 = wei_p + w3; - a2.wei_stride = a.wei_stride; - a2.wei_blk_stride = a.wei_blk_stride; - a2.wei_dx = a.wei_dx; + if (has_oc3) a2.wei2 = wei_p + w3; (*m_ip_kernel_f32)(&a2); } } @@ -2198,52 +2178,36 @@ void JitConv3DExecutor::run_naive_fp32(const MemoryArgs& memory) { static_cast(iz), static_cast(iy), static_cast(ix)); - // pair 0 - { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; + auto run_pair_f32 = [&](float* acc, float* acc2, const float* w0, const float* w1) { + jit_conv3d_f32_call_args a{}; + a.src = src_p + s_base; a.src_stride = src_c_stride_elems * sizeof(float); a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; + a.acc = acc; + a.acc2 = acc2; a.repeats = C / 4; a.tail = C % 4; a.kw_cnt = 1; a.src_dx = 0; + a.wei = w0; + if (w1) a.wei2 = w1; + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; + a.wei_dx = 0; + (*m_ip_kernel_f32)(&a); + }; const size_t base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; - a.wei = m_wei_packed_f32.data() + base0; - if (has_oc1) { - const size_t base1 = (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; - a.wei2 = m_wei_packed_f32.data() + base1; + const size_t base1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32 : 0; + run_pair_f32(&acc0, has_oc1 ? &acc1 : nullptr, + m_wei_packed_f32.data() + base0, + has_oc1 ? m_wei_packed_f32.data() + base1 : nullptr); + if (has_oc2) { + const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; + const size_t base3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32 : 0; + run_pair_f32(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f32.data() + base2, + has_oc3 ? m_wei_packed_f32.data() + base3 : nullptr); } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel_f32)(&a); - } - // pair 1 - if (has_oc2) { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc2; - a.acc2 = has_oc3 ? 
&acc3 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = 1; - a.src_dx = 0; - const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; - a.wei = m_wei_packed_f32.data() + base2; - if (has_oc3) { - const size_t base3 = (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C_f32; - a.wei2 = m_wei_packed_f32.data() + base3; - } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel_f32)(&a); - } } } } diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 43c55df0cde0d0..62a2beff9e516a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -930,7 +930,7 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st } } else { - // Generic path (stride/dilation): modulus checks + // Generic path (stride/dilation) for (size_t kz = 0; kz < KD; ++kz) { const ptrdiff_t id_num = static_cast(od) + PD0 - static_cast(kz * dilD); @@ -968,57 +968,36 @@ void JitDeconv3DExecutor::exec_fp16(const std::vector& src, const st static_cast(ih_idx), static_cast(iw_idx)); - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = ICg / 8; - a.tail = ICg % 8; - a.kw_cnt = 1; - a.src_dx = 0; - { - const size_t pack_base0 = - (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; - a.wei = m_wei_packed_f16.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; - a.wei2 = m_wei_packed_f16.data() + pack_base1; - } + auto run_pair = [&](float* acc, float* acc2, const uint16_t* w0, const uint16_t* w1) { + jit_conv3d_call_args a{}; + a.src = src_p + s_base0; + a.src_stride = src_c_stride_elems * sizeof(uint16_t); + a.src_blk_stride = a.src_stride * 8; + a.acc = acc; + a.acc2 = acc2; + a.repeats = ICg / 8; + a.tail = ICg % 8; + a.kw_cnt = 1; + a.src_dx = 0; + a.wei = w0; + if (w1) a.wei2 = w1; a.wei_stride = sizeof(uint16_t); a.wei_blk_stride = a.wei_stride * 8; a.wei_dx = 0; - } - (*m_ip_kernel_f16)(&a); + (*m_ip_kernel_f16)(&a); + }; + const size_t base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t base1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16 : 0; + run_pair(&acc0, has_oc1 ? &acc1 : nullptr, + m_wei_packed_f16.data() + base0, + has_oc1 ? m_wei_packed_f16.data() + base1 : nullptr); - // second pair for oc2/oc3 if (has_oc2) { - jit_conv3d_call_args a2{}; - a2.src = src_p + s_base0; - a2.src_stride = src_c_stride_elems * sizeof(uint16_t); - a2.src_blk_stride = a2.src_stride * 8; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? &acc3 : nullptr; - a2.repeats = ICg / 8; - a2.tail = ICg % 8; - a2.kw_cnt = 1; - a2.src_dx = 0; - { - const size_t pack_base2 = - (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; - a2.wei = m_wei_packed_f16.data() + pack_base2; - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; - a2.wei2 = m_wei_packed_f16.data() + pack_base3; - } - a2.wei_stride = sizeof(uint16_t); - a2.wei_blk_stride = a2.wei_stride * 8; - a2.wei_dx = 0; - } - (*m_ip_kernel_f16)(&a2); + const size_t base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16; + const size_t base3 = has_oc3 ? 
(((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f16 : 0; + run_pair(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f16.data() + base2, + has_oc3 ? m_wei_packed_f16.data() + base3 : nullptr); } } } @@ -1569,7 +1548,7 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st } } } else { - // Generic path (stride/dilation): modulus checks + // Generic path (stride/dilation) for (size_t kz = 0; kz < KD; ++kz) { const ptrdiff_t id_num = static_cast(od) + PD0 - static_cast(kz * dilD); @@ -1607,59 +1586,35 @@ void JitDeconv3DExecutor::exec_fp32(const std::vector& src, const st static_cast(ih_idx), static_cast(iw_idx)); - // pair 0 - { + auto run_pair_f32 = [&](float* acc, float* acc2, const float* w0, const float* w1) { jit_conv3d_f32_call_args a{}; a.src = src_p + s_base0; a.src_stride = src_c_stride_elems * sizeof(float); a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; + a.acc = acc; + a.acc2 = acc2; a.repeats = ICg / 4; a.tail = ICg % 4; a.kw_cnt = 1; a.src_dx = 0; - if (true) { - const size_t pack_base0 = - (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a.wei = m_wei_packed_f32.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a.wei2 = m_wei_packed_f32.data() + pack_base1; - } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - } else { /* unreachable */ } + a.wei = w0; + if (w1) a.wei2 = w1; + a.wei_stride = sizeof(float); + a.wei_blk_stride = a.wei_stride * 4; a.wei_dx = 0; (*m_ip_kernel_f32)(&a); - } - // pair 1 + }; + const size_t pb0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pb1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32 : 0; + run_pair_f32(&acc0, has_oc1 ? &acc1 : nullptr, + m_wei_packed_f32.data() + pb0, + has_oc1 ? m_wei_packed_f32.data() + pb1 : nullptr); if (has_oc2) { - jit_conv3d_f32_call_args a2{}; - a2.src = src_p + s_base0; - a2.src_stride = src_c_stride_elems * sizeof(float); - a2.src_blk_stride = a2.src_stride * 4; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? &acc3 : nullptr; - a2.repeats = ICg / 4; - a2.tail = ICg % 4; - a2.kw_cnt = 1; - a2.src_dx = 0; - if (true) { - const size_t pack_base2 = - (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a2.wei = m_wei_packed_f32.data() + pack_base2; - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; - a2.wei2 = m_wei_packed_f32.data() + pack_base3; - } - a2.wei_stride = sizeof(float); - a2.wei_blk_stride = a2.wei_stride * 4; - } else { /* unreachable */ } - a2.wei_dx = 0; - (*m_ip_kernel_f32)(&a2); + const size_t pb2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32; + const size_t pb3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_IC_f32 : 0; + run_pair_f32(&acc2, has_oc3 ? &acc3 : nullptr, + m_wei_packed_f32.data() + pb2, + has_oc3 ? m_wei_packed_f32.data() + pb3 : nullptr); } } } From d8ed2c3fd03c1c078eec48f0282dd37f839700f8 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 19:19:58 +0200 Subject: [PATCH 15/20] Refactor AArch64 JIT 3D Convolution Executor by consolidating repetitive kernel invocation logic into reusable lambda functions, reducing code duplication and improving maintainability. 
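
What one run_pair invocation asks the kernel to compute, modeled as scalar C++ for clarity (a sketch using float data for brevity; the real kernel loads FP16 lanes and accumulates in FP32 via fmlal/fmlal2, and dot_over_taps is an illustrative name). kw_cnt taps along W are folded into a single call; for each tap the kernel rewinds the channel walk, then advances the source base by src_dx (one element) and the packed-weight base by wei_dx (one paddedC-sized tap):

    #include <cstddef>

    float dot_over_taps(const float* src, const float* wei, size_t channels,
                        size_t kw_cnt, size_t src_c_stride_elems, size_t paddedC) {
        float acc = 0.0F;
        for (size_t kx = 0; kx < kw_cnt; ++kx) {
            const float* s = src + kx;            // src_dx: one element per tap
            const float* w = wei + kx * paddedC;  // wei_dx: one packed tap per kx
            for (size_t c = 0; c < channels; ++c) {
                acc += s[c * src_c_stride_elems] * w[c];  // channel walk, rewound per tap
            }
        }
        return acc;
    }

run_pair2 in the second hunk is the degenerate case: it leaves kw_cnt unset, and the kernel treats 0 as a single tap, which is exactly what the generic stride path needs.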
--- .../nodes/executors/aarch64/jit_conv3d.cpp | 147 ++++++------------ 1 file changed, 48 insertions(+), 99 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 2d02dc7452163b..14eb4872291a2b 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -1779,59 +1779,31 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { const size_t iy2 = static_cast(iy0 + ky); const size_t ix2 = static_cast(ix0 + kx_lo); const size_t s_base2 = index_src(n, 0, iz, iy2, ix2); - jit_conv3d_call_args a{}; - a.src = src_p + s_base2; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = a.src_stride * 8; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 8; - a.tail = C % 8; - a.kw_cnt = kw_count; - a.src_dx = sizeof(uint16_t); - const size_t pack_base0 = - (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei = m_wei_packed.data() + pack_base0; - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base1; - } - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - a.wei_dx = m_padded_C * sizeof(uint16_t); - (*m_ip_kernel)(&a); + auto run_pair = [&](float* acc, float* acc2, size_t base0, size_t base1) { + jit_conv3d_call_args aa{}; + aa.src = src_p + s_base2; + aa.src_stride = src_c_stride_elems * sizeof(uint16_t); + aa.src_blk_stride = aa.src_stride * 8; + aa.acc = acc; + aa.acc2 = acc2; + aa.repeats = C / 8; + aa.tail = C % 8; + aa.kw_cnt = kw_count; + aa.src_dx = sizeof(uint16_t); + aa.wei = m_wei_packed.data() + base0; + if (acc2) aa.wei2 = m_wei_packed.data() + base1; + aa.wei_stride = sizeof(uint16_t); + aa.wei_blk_stride = aa.wei_stride * 8; + aa.wei_dx = m_padded_C * sizeof(uint16_t); + (*m_ip_kernel)(&aa); + }; + const size_t base0 = (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + const size_t base1 = has_oc1 ? (((oc1 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C : 0; + run_pair(&acc0, has_oc1 ? &acc1 : nullptr, base0, base1); if (has_oc2) { - jit_conv3d_call_args a2{}; - a2.src = src_p + s_base2; - a2.src_stride = a.src_stride; - a2.src_blk_stride = a.src_blk_stride; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? &acc3 : nullptr; - a2.repeats = a.repeats; - a2.tail = a.tail; - a2.kw_cnt = a.kw_cnt; - a2.src_dx = a.src_dx; - const size_t pack_base2 = - (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + - static_cast(kx_lo)) * - m_padded_C; - a2.wei = m_wei_packed.data() + pack_base2; - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + - static_cast(kx_lo)) * - m_padded_C; - a2.wei2 = m_wei_packed.data() + pack_base3; - } - a2.wei_stride = a.wei_stride; - a2.wei_blk_stride = a.wei_blk_stride; - a2.wei_dx = a.wei_dx; - (*m_ip_kernel)(&a2); + const size_t base2 = (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C; + const size_t base3 = has_oc3 ? (((oc3 * KD + static_cast(kz)) * KH + static_cast(ky)) * KW + static_cast(kx_lo)) * m_padded_C : 0; + run_pair(&acc2, has_oc3 ? 
&acc3 : nullptr, base2, base3); } } } @@ -1854,55 +1826,32 @@ void JitConv3DExecutor::run_naive_fp16(const MemoryArgs& memory) { static_cast(iz), static_cast(iy), static_cast(ix)); - // raw base indices removed in product mode - // pair 0 - { - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = - a.src_stride * - 8; // used logically, kernel advances by stride once after 8 lanes - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - - // packed index: ((((oc*KD + kz)*KH + ky)*KW + kx)*paddedC) - const size_t pack_base0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base0; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = sizeof(uint16_t); // contiguous halves - a.wei_blk_stride = a.wei_stride * 8; // logical - if (has_oc1) { - const size_t pack_base1 = - (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base1; + auto run_pair2 = [&](float* acc, float* acc2, size_t base0, size_t base1) { + jit_conv3d_call_args aa{}; + aa.src = src_p + s_base0; + aa.src_stride = src_c_stride_elems * sizeof(uint16_t); + aa.src_blk_stride = aa.src_stride * 8; + aa.acc = acc; + aa.acc2 = acc2; + const size_t pack_base0 = base0; + aa.wei = m_wei_packed.data() + pack_base0; + aa.repeats = C / 8; + aa.tail = C % 8; + aa.wei_stride = sizeof(uint16_t); + aa.wei_blk_stride = aa.wei_stride * 8; + if (acc2) { + const size_t pack_base1 = base1; + aa.wei2 = m_wei_packed.data() + pack_base1; } - (*m_ip_kernel)(&a); - } - // pair 1 + (*m_ip_kernel)(&aa); + }; + const size_t b0 = (((oc0 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t b1 = has_oc1 ? (((oc1 * KD + kz) * KH + ky) * KW + kx) * m_padded_C : 0; + run_pair2(&acc0, has_oc1 ? &acc1 : nullptr, b0, b1); if (has_oc2) { - // raw base indices removed in product mode - jit_conv3d_call_args a{}; - a.src = src_p + s_base0; - a.src_stride = src_c_stride_elems * sizeof(uint16_t); - a.src_blk_stride = - a.src_stride * - 8; // used logically, kernel advances by stride once after 8 lanes - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - const size_t pack_base2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei = m_wei_packed.data() + pack_base2; - a.repeats = C / 8; - a.tail = C % 8; - a.wei_stride = sizeof(uint16_t); - a.wei_blk_stride = a.wei_stride * 8; - if (has_oc3) { - const size_t pack_base3 = - (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - a.wei2 = m_wei_packed.data() + pack_base3; - } - (*m_ip_kernel)(&a); + const size_t b2 = (((oc2 * KD + kz) * KH + ky) * KW + kx) * m_padded_C; + const size_t b3 = has_oc3 ? (((oc3 * KD + kz) * KH + ky) * KW + kx) * m_padded_C : 0; + run_pair2(&acc2, has_oc3 ? &acc3 : nullptr, b2, b3); } } } From 7e0d96014876e9cab40815b8ea9305cbb7c8636a Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Wed, 22 Oct 2025 20:46:23 +0200 Subject: [PATCH 16/20] Remove JitConv3DExecutorF32 implementation, associated helper functions, and redundant includes to simplify and clean up AArch64 JIT 3D Convolution code. 
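
For reference, the weight re-layout the removed executor performed on first use, shown as a free function under the same indexing scheme (a sketch; pack_weights is an illustrative name, and the surviving FP16 executor keeps the identical scheme with 8-lane tiles). Source weights [OC][C][KD][KH][KW] become [OC][KD][KH][KW][paddedC], so every spatial tap exposes its channels contiguously, zero-padded to the tile width:

    #include <cstddef>
    #include <vector>

    std::vector<float> pack_weights(const float* wsrc, size_t OC, size_t C,
                                    size_t KD, size_t KH, size_t KW) {
        const size_t paddedC = (C + 3) / 4 * 4;  // round C up to 4-lane tiles
        std::vector<float> packed(OC * KD * KH * KW * paddedC, 0.0F);
        for (size_t oc = 0; oc < OC; ++oc)
            for (size_t kz = 0; kz < KD; ++kz)
                for (size_t ky = 0; ky < KH; ++ky)
                    for (size_t kx = 0; kx < KW; ++kx)
                        for (size_t c = 0; c < C; ++c) {
                            const size_t s = (((oc * C + c) * KD + kz) * KH + ky) * KW + kx;
                            const size_t d = (((oc * KD + kz) * KH + ky) * KW + kx) * paddedC + c;
                            packed[d] = wsrc[s];  // tail lanes stay zero
                        }
        return packed;
    }

Since lanes within a tile are contiguous, the destination index reduces to base + c; the padding shows up only in the allocation size and the zeroed tail lanes, which lets the kernel always read full blocks.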
--- .../executors/aarch64/jit_conv3d_f32.cpp | 385 +----------------- .../executors/aarch64/jit_conv3d_f32.hpp | 41 -- 2 files changed, 2 insertions(+), 424 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp index b40711225e333a..42fe0b23ef37cb 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp @@ -9,20 +9,10 @@ #include #include -#include #include #include -#include -#include -#include #include "cpu_memory.h" -#include "memory_desc/cpu_memory_desc.h" -#include "nodes/executors/implementation_utils.hpp" -#include "openvino/core/parallel.hpp" -#include "openvino/core/type/element_type.hpp" -#include "utils/general_utils.h" -#include "utils/cpu_utils.hpp" using namespace dnnl::impl::cpu::aarch64; @@ -30,8 +20,9 @@ namespace ov::intel_cpu { void JitConv3DKernelF32::create_ker() { jit_generator::create_kernel(); - ker_ = jit_kernel_cast(jit_ker()); + ker_ = reinterpret_cast(const_cast(jit_ker())); } + void JitConv3DKernelF32::generate() { using namespace Xbyak_aarch64; @@ -278,376 +269,4 @@ void JitConv3DKernelF32::generate() { ret(); } -JitConv3DExecutorF32::JitConv3DExecutorF32(const ConvAttrs& attrs, - const MemoryArgs& memory, - const ExecutorContext::CPtr& /*context*/) - : m_attrs(attrs) { - m_memory = memory; - m_ip_kernel = std::make_unique(); - m_ip_kernel->create_ker(); -} - -bool JitConv3DExecutorF32::supports(const ConvConfig& cfg) { - if (!cfg.descs.count(ARG_SRC) || !cfg.descs.count(ARG_WEI) || !cfg.descs.count(ARG_DST)) - return false; - if (!cfg.descs.at(ARG_SRC) || !cfg.descs.at(ARG_WEI) || !cfg.descs.at(ARG_DST)) - return false; - const auto& s = cfg.descs.at(ARG_SRC)->getShape(); - const auto& w = cfg.descs.at(ARG_WEI)->getShape(); - const auto& d = cfg.descs.at(ARG_DST)->getShape(); - if (s.getRank() != 5 || w.getRank() < 5 || d.getRank() != 5) - return false; - const auto sp = cfg.descs.at(ARG_SRC)->getPrecision(); - const auto wp = cfg.descs.at(ARG_WEI)->getPrecision(); - const auto dp = cfg.descs.at(ARG_DST)->getPrecision(); - if (!(sp == ov::element::f32 && wp == ov::element::f32 && dp == ov::element::f32)) - return false; - if (w.getRank() != 5) - return false; // groups unsupported here - for (auto v : cfg.attrs.dilation) { - if (v != 0) - return false; - } - for (auto v : cfg.attrs.stride) { - if (!(v == 1 || v == 2)) - return false; - } - return true; -} - -void JitConv3DExecutorF32::ensure_weights_packed(const MemoryArgs& memory) { - if (m_wei_packed_ready) - return; - auto src = memory.at(ARG_SRC); - auto wei = memory.at(ARG_WEI); - const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); - const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); - if (srcDims.size() != 5 || weiDims.size() != 5) - return; - const size_t C = srcDims[1]; - const size_t OC = weiDims[0]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; - m_padded_C = (C + 3) / 4 * 4; - const size_t total = OC * KD * KH * KW * m_padded_C; - m_wei_packed.assign(total, 0.0F); - const auto* wsrc = reinterpret_cast(wei->getData()); - - auto idx_wei_src = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { - return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; - }; - auto idx_wei_pack = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) -> size_t { - const size_t base = (((oc * KD + kz) * KH + ky) * KW + kx) * m_padded_C; - 
const size_t blk = c / 4; - const size_t lane = c % 4; - return base + blk * 4 + lane; - }; - - for (size_t oc = 0; oc < OC; ++oc) { - for (size_t kz = 0; kz < KD; ++kz) { - for (size_t ky = 0; ky < KH; ++ky) { - for (size_t kx = 0; kx < KW; ++kx) { - for (size_t c = 0; c < C; ++c) { - m_wei_packed[idx_wei_pack(oc, c, kz, ky, kx)] = wsrc[idx_wei_src(oc, c, kz, ky, kx)]; - } - } - } - } - } - m_wei_packed_ready = true; -} - -void JitConv3DExecutorF32::run_naive_fp32(const MemoryArgs& memory) { - auto src = memory.at(ARG_SRC); - auto wei = memory.at(ARG_WEI); - auto dst = memory.at(ARG_DST); - const auto& srcDims = src->getDescPtr()->getShape().getStaticDims(); - const auto& weiDims = wei->getDescPtr()->getShape().getStaticDims(); - const auto& dstDims = dst->getDescPtr()->getShape().getStaticDims(); - - const size_t N = srcDims[0]; - const size_t C = srcDims[1]; - const size_t ID = srcDims[2], IH = srcDims[3], IW = srcDims[4]; - const size_t OC = weiDims[0]; - const size_t KD = weiDims[2], KH = weiDims[3], KW = weiDims[4]; - const size_t OD = dstDims[2], OH = dstDims[3], OW = dstDims[4]; - - const size_t SD = m_attrs.stride.size() > 0 ? m_attrs.stride[0] : 1; - const size_t SH = m_attrs.stride.size() > 1 ? m_attrs.stride[1] : 1; - const size_t SW = m_attrs.stride.size() > 2 ? m_attrs.stride[2] : 1; - - const ptrdiff_t PD0 = m_attrs.paddingL.size() > 0 ? m_attrs.paddingL[0] : 0; - const ptrdiff_t PH0 = m_attrs.paddingL.size() > 1 ? m_attrs.paddingL[1] : 0; - const ptrdiff_t PW0 = m_attrs.paddingL.size() > 2 ? m_attrs.paddingL[2] : 0; - - const auto* src_p = reinterpret_cast(src->getData()); - const auto* wei_p = reinterpret_cast(wei->getData()); - auto* dst_p = reinterpret_cast(dst->getData()); - - auto index_src = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { - return (((n * C + c) * ID + z) * IH + y) * IW + x; - }; - auto index_dst = [&](size_t n, size_t c, size_t z, size_t y, size_t x) { - return (((n * OC + c) * OD + z) * OH + y) * OW + x; - }; - auto index_wei = [&](size_t oc, size_t c, size_t kz, size_t ky, size_t kx) { - return ((((oc)*C + c) * KD + kz) * KH + ky) * KW + kx; - }; - - const size_t src_c_stride_elems = ID * IH * IW; // elements between channels - const size_t wei_c_stride_elems = KD * KH * KW; // elements between weight channels - - ensure_weights_packed(memory); - - ov::parallel_for2d(N, (OC + 3) / 4, [&](size_t n, size_t oc_quad) { - const size_t oc0 = oc_quad * 4; - const size_t oc1 = std::min(oc0 + 1, OC); - const size_t oc2 = std::min(oc0 + 2, OC); - const size_t oc3 = std::min(oc0 + 3, OC); - const bool has_oc1 = oc1 < OC; - const bool has_oc2 = oc2 < OC; - const bool has_oc3 = oc3 < OC; - - for (size_t od = 0; od < OD; ++od) { - const ptrdiff_t iz0 = static_cast(od) * static_cast(SD) - PD0; - for (size_t oh = 0; oh < OH; ++oh) { - const ptrdiff_t iy0 = static_cast(oh) * static_cast(SH) - PH0; - for (size_t ow = 0; ow < OW; ++ow) { - const ptrdiff_t ix0 = static_cast(ow) * static_cast(SW) - PW0; - - float acc0 = 0.0F, acc1 = 0.0F, acc2 = 0.0F, acc3 = 0.0F; - - if (SD == 1 && SH == 1 && SW == 1) { - const ptrdiff_t kz_lo = std::max(0, -iz0); - const ptrdiff_t kz_hi = - std::min(static_cast(KD) - 1, static_cast(ID) - 1 - iz0); - const ptrdiff_t ky_lo = std::max(0, -iy0); - const ptrdiff_t ky_hi = - std::min(static_cast(KH) - 1, static_cast(IH) - 1 - iy0); - const ptrdiff_t kx_lo = std::max(0, -ix0); - const ptrdiff_t kx_hi = - std::min(static_cast(KW) - 1, static_cast(IW) - 1 - ix0); - if (kz_lo <= kz_hi && ky_lo <= ky_hi && kx_lo <= kx_hi) { - const 
size_t kw_count = static_cast(kx_hi - kx_lo + 1); - for (ptrdiff_t kz = kz_lo; kz <= kz_hi; ++kz) { - const size_t iz = static_cast(iz0 + kz); - for (ptrdiff_t ky = ky_lo; ky <= ky_hi; ++ky) { - const size_t iy = static_cast(iy0 + ky); - const size_t ix = static_cast(ix0 + kx_lo); - const size_t s_base = index_src(n, 0, iz, iy, ix); - - if (m_wei_packed_ready) { - // dual pairs: (oc0,oc1), (oc2,oc3) - // pair 0 - { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = kw_count; - a.src_dx = sizeof(float); - const size_t base0 = - (((oc0 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei = m_wei_packed.data() + base0; - if (has_oc1) { - const size_t base1 = (((oc1 * KD + static_cast(kz)) * KH + - static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei2 = m_wei_packed.data() + base1; - } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = m_padded_C * sizeof(float); - (*m_ip_kernel)(&a); - } - // pair 1 - if (has_oc2) { - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc2; - a.acc2 = has_oc3 ? &acc3 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = kw_count; - a.src_dx = sizeof(float); - const size_t base2 = - (((oc2 * KD + static_cast(kz)) * KH + static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei = m_wei_packed.data() + base2; - if (has_oc3) { - const size_t base3 = (((oc3 * KD + static_cast(kz)) * KH + - static_cast(ky)) * - KW + - static_cast(kx_lo)) * - m_padded_C; - a.wei2 = m_wei_packed.data() + base3; - } - a.wei_stride = sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = m_padded_C * sizeof(float); - (*m_ip_kernel)(&a); - } - } else { - // generic path: kx loop in kernel, but weights non-packed - const size_t w0 = index_wei(oc0, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w1 = has_oc1 ? index_wei(oc1, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = kw_count; - a.src_dx = sizeof(float); - a.wei = wei_p + w0; - if (has_oc1) - a.wei2 = wei_p + w1; - a.wei_stride = wei_c_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = sizeof(float); - (*m_ip_kernel)(&a); - - if (has_oc2) { - const size_t w2 = index_wei(oc2, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)); - const size_t w3 = has_oc3 ? index_wei(oc3, - 0, - static_cast(kz), - static_cast(ky), - static_cast(kx_lo)) - : 0; - jit_conv3d_f32_call_args a2{}; - a2.src = src_p + s_base; - a2.src_stride = a.src_stride; - a2.src_blk_stride = a.src_blk_stride; - a2.acc = &acc2; - a2.acc2 = has_oc3 ? 
&acc3 : nullptr; - a2.repeats = a.repeats; - a2.tail = a.tail; - a2.kw_cnt = a.kw_cnt; - a2.src_dx = a.src_dx; - a2.wei = wei_p + w2; - if (has_oc3) - a2.wei2 = wei_p + w3; - a2.wei_stride = a.wei_stride; - a2.wei_blk_stride = a.wei_blk_stride; - a2.wei_dx = a.wei_dx; - (*m_ip_kernel)(&a2); - } - } - } - } - } - } else { - // generic spatial stride path (host loops over all taps) - for (size_t kz = 0; kz < KD; ++kz) { - const ptrdiff_t iz = iz0 + static_cast(kz); - if (iz < 0 || iz >= static_cast(ID)) - continue; - for (size_t ky = 0; ky < KH; ++ky) { - const ptrdiff_t iy = iy0 + static_cast(ky); - if (iy < 0 || iy >= static_cast(IH)) - continue; - for (size_t kx = 0; kx < KW; ++kx) { - const ptrdiff_t ix = ix0 + static_cast(kx); - if (ix < 0 || ix >= static_cast(IW)) - continue; - const size_t s_base = index_src(n, - 0, - static_cast(iz), - static_cast(iy), - static_cast(ix)); - // pair 0 - { - const size_t w0 = index_wei(oc0, 0, kz, ky, kx); - const size_t w1 = has_oc1 ? index_wei(oc1, 0, kz, ky, kx) : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc0; - a.acc2 = has_oc1 ? &acc1 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = 1; - a.src_dx = 0; - a.wei = wei_p + w0; - if (has_oc1) - a.wei2 = wei_p + w1; - a.wei_stride = wei_c_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel)(&a); - } - // pair 1 - if (has_oc2) { - const size_t w2 = index_wei(oc2, 0, kz, ky, kx); - const size_t w3 = has_oc3 ? index_wei(oc3, 0, kz, ky, kx) : 0; - jit_conv3d_f32_call_args a{}; - a.src = src_p + s_base; - a.src_stride = src_c_stride_elems * sizeof(float); - a.src_blk_stride = a.src_stride * 4; - a.acc = &acc2; - a.acc2 = has_oc3 ? 
&acc3 : nullptr; - a.repeats = C / 4; - a.tail = C % 4; - a.kw_cnt = 1; - a.src_dx = 0; - a.wei = wei_p + w2; - if (has_oc3) - a.wei2 = wei_p + w3; - a.wei_stride = wei_c_stride_elems * sizeof(float); - a.wei_blk_stride = a.wei_stride * 4; - a.wei_dx = 0; - (*m_ip_kernel)(&a); - } - } - } - } - } - - // Store - dst_p[index_dst(n, oc0, od, oh, ow)] = acc0; - if (has_oc1) - dst_p[index_dst(n, oc1, od, oh, ow)] = acc1; - if (has_oc2) - dst_p[index_dst(n, oc2, od, oh, ow)] = acc2; - if (has_oc3) - dst_p[index_dst(n, oc3, od, oh, ow)] = acc3; - } - } - } - }); -} - -void JitConv3DExecutorF32::execute(const MemoryArgs& memory) { - run_naive_fp32(memory); -} - } // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp index 935a1741bf12c6..ee9b6fc4df03af 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp @@ -9,10 +9,6 @@ #include #include -#include "nodes/executors/convolution_config.hpp" -#include "nodes/executors/executor.hpp" -#include "nodes/executors/memory_arguments.hpp" - #include namespace ov::intel_cpu { @@ -52,41 +48,4 @@ class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { jit_fn ker_{nullptr}; }; -class JitConv3DExecutorF32 : public Executor { -public: - JitConv3DExecutorF32(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); - - bool update(const MemoryArgs& memory) override { - m_memory = memory; - return true; - } - void execute(const MemoryArgs& memory) override; - void execute() override { - execute(m_memory); - } - void exec([[maybe_unused]] const std::vector& src, - [[maybe_unused]] const std::vector& dst) override {} - - [[nodiscard]] impl_desc_type implType() const override { - return impl_desc_type::jit_asimd; - } - - static bool supports(const ConvConfig& cfg); - -private: - void run_naive_fp32(const MemoryArgs& memory); - void ensure_weights_packed(const MemoryArgs& memory); - - std::unique_ptr m_ip_kernel; - - ConvAttrs m_attrs; - MemoryArgs m_memory; - - std::vector m_wei_packed; // [OC, KD, KH, KW, Ct=4] - bool m_wei_packed_ready{false}; - size_t m_padded_C{0}; -}; - -using JitConv3DExecutorF32Ptr = std::shared_ptr; - } // namespace ov::intel_cpu From 7b7588d2b19424d4fe72b6c4420ad15d80edc5ea Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Thu, 23 Oct 2025 14:01:42 +0200 Subject: [PATCH 17/20] Remove JitConv3DKernelF32 implementation, associated helper functions, and unnecessary includes to streamline and clean up AArch64 JIT 3D Convolution code. 
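
Note that the FP32 kernel itself is not dropped here: generate() moves into jit_conv3d.cpp and the jit_conv3d_f32_call_args struct plus the JitConv3DKernelF32 declaration move into jit_conv3d.hpp; only the standalone jit_conv3d_f32.{cpp,hpp} files are deleted. Because the relocated generate() still loads its arguments at fixed displacements (ldr at 0, 8, ..., 104), the struct's field order is an ABI contract with the generated code. A compile-time check along these lines (illustrative, not part of the patch) would document that dependency:

    #include <cstddef>

    struct jit_conv3d_f32_call_args {  // field order mirrors the ldr displacements
        const float* src;              // ldr(reg_src, ptr(reg_args, 0))
        const float* wei;              // 8
        const float* wei2;             // 16
        size_t repeats;                // 24 (reloaded per kx iteration)
        size_t tail;                   // 32
        size_t src_stride;             // 40
        size_t wei_stride;             // 48
        size_t src_blk_stride;         // 56
        size_t wei_blk_stride;         // 64
        float* acc;                    // 72
        float* acc2;                   // 80 (nullptr selects the single-OC path)
        size_t kw_cnt;                 // 88
        size_t src_dx;                 // 96
        size_t wei_dx;                 // 104
    };
    static_assert(offsetof(jit_conv3d_f32_call_args, repeats) == 24, "kernel reload offset");
    static_assert(offsetof(jit_conv3d_f32_call_args, acc) == 72, "accumulator offset");
    static_assert(offsetof(jit_conv3d_f32_call_args, wei_dx) == 104, "kx step offset");

On LP64 AArch64 every field is 8 bytes, so the offsets follow directly from declaration order; any reordering of the struct must be mirrored in the ldr displacements.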
--- .../nodes/executors/aarch64/jit_conv3d.cpp | 251 ++++++++++++++++ .../nodes/executors/aarch64/jit_conv3d.hpp | 36 ++- .../executors/aarch64/jit_conv3d_f32.cpp | 272 ------------------ .../executors/aarch64/jit_conv3d_f32.hpp | 51 ---- .../nodes/executors/aarch64/jit_deconv3d.cpp | 1 - .../nodes/executors/aarch64/jit_deconv3d.hpp | 1 - 6 files changed, 286 insertions(+), 326 deletions(-) delete mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp delete mode 100644 src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 14eb4872291a2b..097a80836c28d9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -1602,6 +1602,257 @@ void JitConv3DKernelF16::generate() { gen_optimized_kernel(); } +void JitConv3DKernelF32::create_ker() { + jit_generator::create_kernel(); + ker_ = reinterpret_cast(const_cast(jit_ker())); +} + +void JitConv3DKernelF32::generate() { + using namespace Xbyak_aarch64; + + const XReg reg_args = abi_param1; + + const XReg reg_src = x1; + const XReg reg_wei = x2; + const XReg reg_wei2 = x3; + const XReg reg_reps = x4; + const XReg reg_tail = x5; + const XReg reg_src_stride = x6; + const XReg reg_wei_stride = x7; + const XReg reg_src_blk_stride = x8; + const XReg reg_wei_blk_stride = x9; + const XReg reg_acc = x10; + const XReg reg_acc2 = x11; + const XReg reg_kw_cnt = x12; + const XReg reg_src_dx = x13; + const XReg reg_wei_dx = x14; + + ldr(reg_src, ptr(reg_args, 0)); + ldr(reg_wei, ptr(reg_args, 8)); + ldr(reg_wei2, ptr(reg_args, 16)); + ldr(reg_reps, ptr(reg_args, 24)); + ldr(reg_tail, ptr(reg_args, 32)); + ldr(reg_src_stride, ptr(reg_args, 40)); + ldr(reg_wei_stride, ptr(reg_args, 48)); + ldr(reg_src_blk_stride, ptr(reg_args, 56)); + ldr(reg_wei_blk_stride, ptr(reg_args, 64)); + ldr(reg_acc, ptr(reg_args, 72)); + ldr(reg_acc2, ptr(reg_args, 80)); + ldr(reg_kw_cnt, ptr(reg_args, 88)); + ldr(reg_src_dx, ptr(reg_args, 96)); + ldr(reg_wei_dx, ptr(reg_args, 104)); + + const XReg q_src_base = x15; + const XReg q_wei_base = x16; + const XReg q_wei2_base = x17; + + Label Lsingle, Ldone; + Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; + Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; + + cbz(reg_acc2, Lsingle); + b(Ldual_kx); + + L(Ldual_kx); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + eor(VReg16B(21), VReg16B(21), VReg16B(21)); + + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + mov(q_wei2_base, reg_wei2); + cbnz(reg_kw_cnt, Lkx_d); + mov(reg_kw_cnt, 1); + + L(Lkx_d); + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + mov(reg_wei2, q_wei2_base); + + Label Lrep_d; + L(Lrep_d); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + Label Lw_np_d, Lw_done_d; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_d); + ld1(VReg4S(1), ptr(reg_wei)); + ld1(VReg4S(2), ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + add(reg_wei2, reg_wei2, reg_wei_blk_stride); + b(Lw_done_d); + L(Lw_np_d); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + 
ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_wei2, reg_wei2, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + ld1(VReg(2).s[3], ptr(reg_wei2)); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + L(Lw_done_d); + add(reg_src, reg_src, reg_src_stride); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + sub(reg_reps, reg_reps, 1); + b(Lrep_d); + + L(Ltail_prep_d_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + eor(VReg16B(2), VReg16B(2), VReg16B(2)); + cmp(reg_tail, 0); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + ld1(VReg(2).s[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + ld1(VReg(2).s[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + ld1(VReg(2).s[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_d_kx); + ld1(VReg(0).s[3], ptr(reg_src)); + ld1(VReg(1).s[3], ptr(reg_wei)); + ld1(VReg(2).s[3], ptr(reg_wei2)); + L(Ltail_done_d_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + fmla(VReg4S(21), VReg4S(0), VReg4S(2)); + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + add(q_wei2_base, q_wei2_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lkx_d); + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + faddp(VReg4S(21), VReg4S(21), VReg4S(21)); + faddp(VReg2S(21), VReg2S(21), VReg2S(21)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + ldr(SReg(1), ptr(reg_acc2)); + fadd(SReg(1), SReg(1), SReg(21)); + str(SReg(1), ptr(reg_acc2)); + b(Ldone); + + L(Lsingle); + eor(VReg16B(20), VReg16B(20), VReg16B(20)); + mov(q_src_base, reg_src); + mov(q_wei_base, reg_wei); + cbnz(reg_kw_cnt, Lsingle_kx); + mov(reg_kw_cnt, 1); + + L(Lsingle_kx); + ldr(reg_reps, ptr(reg_args, 24)); + mov(reg_src, q_src_base); + mov(reg_wei, q_wei_base); + + Label Lrep_s; + L(Lrep_s); + cmp(reg_reps, 0); + b(EQ, Ltail_prep_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[1], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[2], ptr(reg_src)); + add(reg_src, reg_src, reg_src_stride); + ld1(VReg(0).s[3], ptr(reg_src)); + Label Lw_np_s, Lw_done_s; + cmp(reg_wei_stride, 4); + b(NE, Lw_np_s); + ld1(VReg4S(1), ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_blk_stride); + b(Lw_done_s); + L(Lw_np_s); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_wei, 
reg_wei, reg_wei_stride); + ld1(VReg(1).s[3], ptr(reg_wei)); + add(reg_wei, reg_wei, reg_wei_stride); + L(Lw_done_s); + add(reg_src, reg_src, reg_src_stride); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + sub(reg_reps, reg_reps, 1); + b(Lrep_s); + + L(Ltail_prep_s_kx); + eor(VReg16B(0), VReg16B(0), VReg16B(0)); + eor(VReg16B(1), VReg16B(1), VReg16B(1)); + cmp(reg_tail, 0); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[0], ptr(reg_src)); + ld1(VReg(1).s[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[1], ptr(reg_src)); + ld1(VReg(1).s[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[2], ptr(reg_src)); + ld1(VReg(1).s[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); + add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); + b(LE, Ltail_done_s_kx); + ld1(VReg(0).s[3], ptr(reg_src)); + ld1(VReg(1).s[3], ptr(reg_wei)); + L(Ltail_done_s_kx); + fmla(VReg4S(20), VReg4S(0), VReg4S(1)); + + sub(reg_kw_cnt, reg_kw_cnt, 1); + add(q_src_base, q_src_base, reg_src_dx); + add(q_wei_base, q_wei_base, reg_wei_dx); + cbnz(reg_kw_cnt, Lsingle_kx); + + faddp(VReg4S(20), VReg4S(20), VReg4S(20)); + faddp(VReg2S(20), VReg2S(20), VReg2S(20)); + ldr(SReg(0), ptr(reg_acc)); + fadd(SReg(0), SReg(0), SReg(20)); + str(SReg(0), ptr(reg_acc)); + b(Ldone); + + L(Ldone); + ret(); +} + [[maybe_unused]] static inline auto ptr_f16(const MemoryPtr& mem) -> const uint16_t* { return reinterpret_cast(mem->getData()); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp index d3f94bac4984e1..d9100f7d780e90 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.hpp @@ -14,7 +14,6 @@ #include "nodes/executors/memory_arguments.hpp" #include "onednn/iml_type_mapper.h" #include -#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" namespace ov::intel_cpu { @@ -42,6 +41,23 @@ struct jit_conv3d_call_args { size_t wei_dy; // bytes to advance weights base between successive ky taps }; +struct jit_conv3d_f32_call_args { + const float* src; + const float* wei; + const float* wei2; + size_t repeats; + size_t tail; + size_t src_stride; + size_t wei_stride; + size_t src_blk_stride; + size_t wei_blk_stride; + float* acc; + float* acc2; + size_t kw_cnt; + size_t src_dx; + size_t wei_dx; +}; + class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { public: DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF16) @@ -66,6 +82,24 @@ class JitConv3DKernelF16 : public dnnl::impl::cpu::aarch64::jit_generator { void set_force_single_kh(bool v) { m_force_single_kh_ = v; } }; +class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF32) + using jit_fn = void (*)(const jit_conv3d_f32_call_args*); + + JitConv3DKernelF32() = default; + + void create_ker(); + inline void operator()(const jit_conv3d_f32_call_args* p) const { + ker_(p); + } + +private: + void generate() override; + + jit_fn ker_{nullptr}; +}; + class JitConv3DExecutor : public Executor { public: JitConv3DExecutor(const ConvAttrs& attrs, const MemoryArgs& memory, const ExecutorContext::CPtr& context); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp 
b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp deleted file mode 100644 index 42fe0b23ef37cb..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.cpp +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" - -#include -#include -#include -#include - -#include -#include - -#include "cpu_memory.h" - -using namespace dnnl::impl::cpu::aarch64; - -namespace ov::intel_cpu { - -void JitConv3DKernelF32::create_ker() { - jit_generator::create_kernel(); - ker_ = reinterpret_cast(const_cast(jit_ker())); -} - -void JitConv3DKernelF32::generate() { - using namespace Xbyak_aarch64; - - const XReg reg_args = abi_param1; - - const XReg reg_src = x1; - const XReg reg_wei = x2; - const XReg reg_wei2 = x3; - const XReg reg_reps = x4; - const XReg reg_tail = x5; - const XReg reg_src_stride = x6; - const XReg reg_wei_stride = x7; - const XReg reg_src_blk_stride = x8; - const XReg reg_wei_blk_stride = x9; - const XReg reg_acc = x10; - const XReg reg_acc2 = x11; - const XReg reg_kw_cnt = x12; - const XReg reg_src_dx = x13; - const XReg reg_wei_dx = x14; - - ldr(reg_src, ptr(reg_args, 0)); - ldr(reg_wei, ptr(reg_args, 8)); - ldr(reg_wei2, ptr(reg_args, 16)); - ldr(reg_reps, ptr(reg_args, 24)); - ldr(reg_tail, ptr(reg_args, 32)); - ldr(reg_src_stride, ptr(reg_args, 40)); - ldr(reg_wei_stride, ptr(reg_args, 48)); - ldr(reg_src_blk_stride, ptr(reg_args, 56)); - ldr(reg_wei_blk_stride, ptr(reg_args, 64)); - ldr(reg_acc, ptr(reg_args, 72)); - ldr(reg_acc2, ptr(reg_args, 80)); - ldr(reg_kw_cnt, ptr(reg_args, 88)); - ldr(reg_src_dx, ptr(reg_args, 96)); - ldr(reg_wei_dx, ptr(reg_args, 104)); - - const XReg q_src_base = x15; - const XReg q_wei_base = x16; - const XReg q_wei2_base = x17; - - Label Lsingle, Ldone; - Label Ldual_kx, Lkx_d, Ltail_prep_d_kx, Ltail_done_d_kx; - Label Lsingle_kx, Lkx_s, Ltail_prep_s_kx, Ltail_done_s_kx; - - cbz(reg_acc2, Lsingle); - b(Ldual_kx); - - L(Ldual_kx); - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - eor(VReg16B(21), VReg16B(21), VReg16B(21)); - - mov(q_src_base, reg_src); - mov(q_wei_base, reg_wei); - mov(q_wei2_base, reg_wei2); - cbnz(reg_kw_cnt, Lkx_d); - mov(reg_kw_cnt, 1); - - L(Lkx_d); - ldr(reg_reps, ptr(reg_args, 24)); - mov(reg_src, q_src_base); - mov(reg_wei, q_wei_base); - mov(reg_wei2, q_wei2_base); - - Label Lrep_d; - L(Lrep_d); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_d_kx); - ld1(VReg(0).s[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[3], ptr(reg_src)); - Label Lw_np_d, Lw_done_d; - cmp(reg_wei_stride, 4); - b(NE, Lw_np_d); - ld1(VReg4S(1), ptr(reg_wei)); - ld1(VReg4S(2), ptr(reg_wei2)); - add(reg_wei, reg_wei, reg_wei_blk_stride); - add(reg_wei2, reg_wei2, reg_wei_blk_stride); - b(Lw_done_d); - L(Lw_np_d); - ld1(VReg(1).s[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[0], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[1], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).s[2], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).s[3], ptr(reg_wei)); - ld1(VReg(2).s[3], 
ptr(reg_wei2)); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - L(Lw_done_d); - add(reg_src, reg_src, reg_src_stride); - fmla(VReg4S(20), VReg4S(0), VReg4S(1)); - fmla(VReg4S(21), VReg4S(0), VReg4S(2)); - sub(reg_reps, reg_reps, 1); - b(Lrep_d); - - L(Ltail_prep_d_kx); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - eor(VReg16B(2), VReg16B(2), VReg16B(2)); - cmp(reg_tail, 0); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[0], ptr(reg_src)); - ld1(VReg(1).s[0], ptr(reg_wei)); - ld1(VReg(2).s[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[1], ptr(reg_src)); - ld1(VReg(1).s[1], ptr(reg_wei)); - ld1(VReg(2).s[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[2], ptr(reg_src)); - ld1(VReg(1).s[2], ptr(reg_wei)); - ld1(VReg(2).s[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_d_kx); - ld1(VReg(0).s[3], ptr(reg_src)); - ld1(VReg(1).s[3], ptr(reg_wei)); - ld1(VReg(2).s[3], ptr(reg_wei2)); - L(Ltail_done_d_kx); - fmla(VReg4S(20), VReg4S(0), VReg4S(1)); - fmla(VReg4S(21), VReg4S(0), VReg4S(2)); - sub(reg_kw_cnt, reg_kw_cnt, 1); - add(q_src_base, q_src_base, reg_src_dx); - add(q_wei_base, q_wei_base, reg_wei_dx); - add(q_wei2_base, q_wei2_base, reg_wei_dx); - cbnz(reg_kw_cnt, Lkx_d); - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - faddp(VReg4S(21), VReg4S(21), VReg4S(21)); - faddp(VReg2S(21), VReg2S(21), VReg2S(21)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); - ldr(SReg(1), ptr(reg_acc2)); - fadd(SReg(1), SReg(1), SReg(21)); - str(SReg(1), ptr(reg_acc2)); - b(Ldone); - - L(Lsingle); - eor(VReg16B(20), VReg16B(20), VReg16B(20)); - mov(q_src_base, reg_src); - mov(q_wei_base, reg_wei); - cbnz(reg_kw_cnt, Lsingle_kx); - mov(reg_kw_cnt, 1); - - L(Lsingle_kx); - ldr(reg_reps, ptr(reg_args, 24)); - mov(reg_src, q_src_base); - mov(reg_wei, q_wei_base); - - Label Lrep_s; - L(Lrep_s); - cmp(reg_reps, 0); - b(EQ, Ltail_prep_s_kx); - ld1(VReg(0).s[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).s[3], ptr(reg_src)); - Label Lw_np_s, Lw_done_s; - cmp(reg_wei_stride, 4); - b(NE, Lw_np_s); - ld1(VReg4S(1), ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_blk_stride); - b(Lw_done_s); - L(Lw_np_s); - ld1(VReg(1).s[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).s[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).s[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).s[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - L(Lw_done_s); - add(reg_src, reg_src, reg_src_stride); - fmla(VReg4S(20), VReg4S(0), VReg4S(1)); - sub(reg_reps, reg_reps, 1); - b(Lrep_s); - - L(Ltail_prep_s_kx); - eor(VReg16B(0), VReg16B(0), VReg16B(0)); - eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[0], ptr(reg_src)); - 
ld1(VReg(1).s[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[1], ptr(reg_src)); - ld1(VReg(1).s[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[2], ptr(reg_src)); - ld1(VReg(1).s[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).s[3], ptr(reg_src)); - ld1(VReg(1).s[3], ptr(reg_wei)); - L(Ltail_done_s_kx); - fmla(VReg4S(20), VReg4S(0), VReg4S(1)); - - sub(reg_kw_cnt, reg_kw_cnt, 1); - add(q_src_base, q_src_base, reg_src_dx); - add(q_wei_base, q_wei_base, reg_wei_dx); - cbnz(reg_kw_cnt, Lsingle_kx); - - faddp(VReg4S(20), VReg4S(20), VReg4S(20)); - faddp(VReg2S(20), VReg2S(20), VReg2S(20)); - ldr(SReg(0), ptr(reg_acc)); - fadd(SReg(0), SReg(0), SReg(20)); - str(SReg(0), ptr(reg_acc)); - b(Ldone); - - L(Ldone); - ret(); -} - -} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp deleted file mode 100644 index ee9b6fc4df03af..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d_f32.hpp +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include - -#include - -namespace ov::intel_cpu { - -struct jit_conv3d_f32_call_args { - const float* src; - const float* wei; - const float* wei2; - size_t repeats; - size_t tail; - size_t src_stride; - size_t wei_stride; - size_t src_blk_stride; - size_t wei_blk_stride; - float* acc; - float* acc2; - size_t kw_cnt; - size_t src_dx; - size_t wei_dx; -}; - -class JitConv3DKernelF32 : public dnnl::impl::cpu::aarch64::jit_generator { -public: - DECLARE_CPU_JIT_AUX_FUNCTIONS(JitConv3DKernelF32) - using jit_fn = void (*)(const jit_conv3d_f32_call_args*); - - JitConv3DKernelF32() = default; - - void create_ker(); - inline void operator()(const jit_conv3d_f32_call_args* p) const { - ker_(p); - } - -private: - void generate() override; - - jit_fn ker_{nullptr}; -}; - -} // namespace ov::intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 62a2beff9e516a..4648f6f764c847 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -10,7 +10,6 @@ #include #include "cpu_memory.h" -#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/core/type/float16.hpp" diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp index 9c3e4c4e268f7c..b9d961f5693d48 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.hpp @@ -8,7 +8,6 @@ #include #include "nodes/executors/aarch64/jit_conv3d.hpp" -#include "nodes/executors/aarch64/jit_conv3d_f32.hpp" #include "nodes/executors/deconv.hpp" namespace ov::intel_cpu { From 99d361f57a727b336ba9c8ebe25e69e088fef42d Mon Sep 17 00:00:00 2001 From: 
Alexander Nesterov Date: Thu, 23 Oct 2025 15:16:15 +0200 Subject: [PATCH 18/20] Refactor AArch64 JIT 3D Convolution and Deconvolution Executors by introducing reusable helper functions to consolidate repetitive load patterns, reducing code duplication and improving maintainability. --- .../nodes/executors/aarch64/jit_conv3d.cpp | 188 +++++++----------- .../nodes/executors/aarch64/jit_deconv3d.cpp | 2 - 2 files changed, 69 insertions(+), 121 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 097a80836c28d9..3a2689c25cf7ce 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -81,6 +81,46 @@ void JitConv3DKernelF16::gen_minimal_kernel() { mov(q_wei2_base, reg_wei2); cbnz(reg_kw_cnt, Lkx_d); mov(reg_kw_cnt, 1); + + // helpers to emit identical load patterns without runtime cost + auto emit_src8 = [&](const XReg& src, const XReg& src_stride) { + ld1(VReg(0).h[0], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[1], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[2], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[3], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[4], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[5], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[6], ptr(src)); add(src, src, src_stride); + ld1(VReg(0).h[7], ptr(src)); + }; + auto emit_wei8_pair = [&](const XReg& wei, const XReg& wei2, const XReg& wei_stride, const XReg& wei_blk_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + ld1(VReg8H(2), ptr(wei2)); + add(wei, wei, wei_blk_stride); + add(wei2, wei2, wei_blk_stride); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[0], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[1], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[2], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[3], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[4], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[5], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(2).h[6], ptr(wei2)); add(wei2, wei2, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + ld1(VReg(2).h[7], ptr(wei2)); + L(Lw_done); + }; L(Lkx_d); ldr(reg_reps, ptr(reg_args, 40)); mov(reg_src, q_src_base); @@ -90,62 +130,8 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Lrep_d_kx); cmp(reg_reps, 0); b(EQ, Ltail_prep_d_kx); - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // Load weights for oc0/oc1 (vector fast 
path if stride==2) - Label Lw_np_d, Lw_done_d2; - cmp(reg_wei_stride, 2); - b(NE, Lw_np_d); - ld1(VReg8H(1), ptr(reg_wei)); - ld1(VReg8H(2), ptr(reg_wei2)); - add(reg_wei, reg_wei, reg_wei_blk_stride2); - add(reg_wei2, reg_wei2, reg_wei_blk_stride2); - b(Lw_done_d2); - L(Lw_np_d); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_wei2, reg_wei2, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); - L(Lw_done_d2); + emit_src8(reg_src, reg_src_stride); + emit_wei8_pair(reg_wei, reg_wei2, reg_wei_stride, reg_wei_blk_stride2); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); @@ -317,69 +303,33 @@ void JitConv3DKernelF16::gen_minimal_kernel() { eor(VReg16B(1), VReg16B(1), VReg16B(1)); eor(VReg16B(2), VReg16B(2), VReg16B(2)); // lanes 0..7 guarded by tail - cmp(reg_tail, 0); - b(LE, Ltail_done_d); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - ld1(VReg(2).h[0], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_d); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - ld1(VReg(2).h[1], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_d); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - ld1(VReg(2).h[2], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_d); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - ld1(VReg(2).h[3], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_d); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - ld1(VReg(2).h[4], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_d); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - ld1(VReg(2).h[5], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, 
Ltail_done_d); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - ld1(VReg(2).h[6], ptr(reg_wei2)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - add(reg_wei2, reg_wei2, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_d); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - ld1(VReg(2).h[7], ptr(reg_wei2)); - - L(Ltail_done_d); + { + Label Ltail_done_d; + cmp(reg_tail, 0); b(LE, Ltail_done_d); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); ld1(VReg(2).h[0], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_d); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); ld1(VReg(2).h[1], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_d); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); ld1(VReg(2).h[2], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_d); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); ld1(VReg(2).h[3], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_d); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); ld1(VReg(2).h[4], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_d); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); ld1(VReg(2).h[5], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_d); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); ld1(VReg(2).h[6], ptr(reg_wei2)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); add(reg_wei2, reg_wei2, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_d); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); ld1(VReg(2).h[7], ptr(reg_wei2)); + L(Ltail_done_d); + } fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal(VReg4S(21), VReg4H(0), VReg4H(2)); diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp index 4648f6f764c847..7c864786b3dc3c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_deconv3d.cpp @@ -16,8 +16,6 @@ namespace ov::intel_cpu { -// removed unused helpers - bool JitDeconv3DExecutor::init(const DeconvAttrs& attrs, const std::vector& srcDescs, const std::vector& dstDescs, From c69e4a177ec9b54ff897f5051a4327f269fff890 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Thu, 23 Oct 2025 15:20:08 +0200 Subject: [PATCH 19/20] Refactor AArch64 JIT 3D Convolution Executor by introducing reusable helpers for repeated load patterns, reducing code duplication and improving maintainability. 
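For readers skimming the diff below: the helpers are ordinary C++ lambdas defined inside the generator body, so they run at code-emission time and simply stamp out the same instruction sequence at each call site; the emitted machine code matches the hand-unrolled version, with no call overhead at run time. A minimal sketch of the pattern (the name emit_src4 and the 4-lane width are illustrative, not the exact helpers in this patch; it assumes the Xbyak_aarch64 names already in scope in this file):

    // Generation-time helper: each invocation emits a fresh copy of
    // the load-and-advance sequence into the JIT code buffer.
    auto emit_src4 = [&](const XReg& src, const XReg& stride) {
        for (int lane = 0; lane < 4; ++lane) {
            ld1(VReg(0).h[lane], ptr(src));  // load one fp16 lane
            if (lane != 3) {
                add(src, src, stride);       // step to the next element
            }
        }
    };
    emit_src4(reg_src, reg_src_stride);      // emits 4 loads + 3 adds here

Because the lambdas capture the generator by reference, any Label declared inside them is a fresh C++ object per invocation, so the stride==2 fast-path branches in helpers such as emit_wei8_pair stay local to each expansion.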
--- .../nodes/executors/aarch64/jit_conv3d.cpp | 266 +++++++----------- 1 file changed, 94 insertions(+), 172 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp index 3a2689c25cf7ce..785058714af73e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_conv3d.cpp @@ -121,6 +121,42 @@ void JitConv3DKernelF16::gen_minimal_kernel() { ld1(VReg(2).h[7], ptr(wei2)); L(Lw_done); }; + auto emit_wei8_single_blk = [&](const XReg& wei, const XReg& wei_stride, const XReg& wei_blk_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + add(wei, wei, wei_blk_stride); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + L(Lw_done); + }; + auto emit_wei8_single16 = [&](const XReg& wei, const XReg& wei_stride) { + Label Lw_np, Lw_done; + cmp(wei_stride, 2); + b(NE, Lw_np); + ld1(VReg8H(1), ptr(wei)); + add(wei, wei, 16); + b(Lw_done); + L(Lw_np); + ld1(VReg(1).h[0], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[1], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[2], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[3], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[4], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[5], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[6], ptr(wei)); add(wei, wei, wei_stride); + ld1(VReg(1).h[7], ptr(wei)); + L(Lw_done); + }; L(Lkx_d); ldr(reg_reps, ptr(reg_args, 40)); mov(reg_src, q_src_base); @@ -234,21 +270,7 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Lrep_d); cmp(reg_reps, 0); b(EQ, Ltail_prep_d); - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); + emit_src8(reg_src, reg_src_stride); // Load wei lanes for oc0 (v1) and oc1 (v2) — vector fast path if wei_stride==2 Label Ldw_np_d, Ldw_done_d; cmp(reg_wei_stride, 2); @@ -391,29 +413,7 @@ void JitConv3DKernelF16::gen_minimal_kernel() { add(reg_src, reg_src, reg_src_stride); ld1(VReg(0).h[7], ptr(reg_src)); // weights (vector fast path if stride==2) - Label Lw_np_s, Lw_done_s2; - cmp(reg_wei_stride, 2); - b(NE, Lw_np_s); - ld1(VReg8H(1), ptr(reg_wei)); - add(reg_wei, reg_wei, s_wei_blk_stride2); - b(Lw_done_s2); - L(Lw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - 
ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Lw_done_s2); + emit_wei8_single_blk(reg_wei, reg_wei_stride, s_wei_blk_stride2); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); sub(reg_reps, reg_reps, 1); @@ -421,53 +421,33 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Ltail_prep_s_kx); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_s_kx); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Ltail_done_s_kx); + { + Label Ltail_done_s_kx; + cmp(reg_tail, 0); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s_kx); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, 
Ltail_done_s_kx); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s_kx); + } fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); sub(s_kw_cnt, s_kw_cnt, 1); @@ -485,46 +465,8 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Lrep_s); cmp(reg_reps, 0); b(EQ, Ltail_prep_s); - // src lanes - ld1(VReg(0).h[0], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[1], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[2], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[3], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[4], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[5], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[6], ptr(reg_src)); - add(reg_src, reg_src, reg_src_stride); - ld1(VReg(0).h[7], ptr(reg_src)); - // wei lanes — vector fast path if wei_stride==2 - Label Ldw_np_s, Ldw_done_s; - cmp(reg_wei_stride, 2); - b(NE, Ldw_np_s); - ld1(VReg8H(1), ptr(reg_wei)); - add(reg_wei, reg_wei, 16); - b(Ldw_done_s); - L(Ldw_np_s); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_wei, reg_wei, reg_wei_stride); - ld1(VReg(1).h[7], ptr(reg_wei)); - L(Ldw_done_s); + emit_src8(reg_src, reg_src_stride); + emit_wei8_single16(reg_wei, reg_wei_stride); fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); sub(reg_reps, reg_reps, 1); @@ -534,53 +476,33 @@ void JitConv3DKernelF16::gen_minimal_kernel() { L(Ltail_prep_s); eor(VReg16B(0), VReg16B(0), VReg16B(0)); eor(VReg16B(1), VReg16B(1), VReg16B(1)); - cmp(reg_tail, 0); - b(LE, Ltail_done_s); - ld1(VReg(0).h[0], ptr(reg_src)); - ld1(VReg(1).h[0], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 1); - b(LE, Ltail_done_s); - ld1(VReg(0).h[1], ptr(reg_src)); - ld1(VReg(1).h[1], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 2); - b(LE, Ltail_done_s); - ld1(VReg(0).h[2], ptr(reg_src)); - ld1(VReg(1).h[2], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 3); - b(LE, Ltail_done_s); - ld1(VReg(0).h[3], ptr(reg_src)); - ld1(VReg(1).h[3], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 4); - b(LE, Ltail_done_s); - ld1(VReg(0).h[4], ptr(reg_src)); - ld1(VReg(1).h[4], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 5); - b(LE, Ltail_done_s); - ld1(VReg(0).h[5], ptr(reg_src)); - ld1(VReg(1).h[5], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 6); - b(LE, Ltail_done_s); - ld1(VReg(0).h[6], ptr(reg_src)); - ld1(VReg(1).h[6], ptr(reg_wei)); - add(reg_src, reg_src, reg_src_stride); - add(reg_wei, reg_wei, reg_wei_stride); - cmp(reg_tail, 7); - b(LE, Ltail_done_s); - ld1(VReg(0).h[7], ptr(reg_src)); - ld1(VReg(1).h[7], 
ptr(reg_wei)); - L(Ltail_done_s); + { + Label Ltail_done_s; + cmp(reg_tail, 0); b(LE, Ltail_done_s); + ld1(VReg(0).h[0], ptr(reg_src)); ld1(VReg(1).h[0], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 1); b(LE, Ltail_done_s); + ld1(VReg(0).h[1], ptr(reg_src)); ld1(VReg(1).h[1], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 2); b(LE, Ltail_done_s); + ld1(VReg(0).h[2], ptr(reg_src)); ld1(VReg(1).h[2], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 3); b(LE, Ltail_done_s); + ld1(VReg(0).h[3], ptr(reg_src)); ld1(VReg(1).h[3], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 4); b(LE, Ltail_done_s); + ld1(VReg(0).h[4], ptr(reg_src)); ld1(VReg(1).h[4], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 5); b(LE, Ltail_done_s); + ld1(VReg(0).h[5], ptr(reg_src)); ld1(VReg(1).h[5], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 6); b(LE, Ltail_done_s); + ld1(VReg(0).h[6], ptr(reg_src)); ld1(VReg(1).h[6], ptr(reg_wei)); + add(reg_src, reg_src, reg_src_stride); add(reg_wei, reg_wei, reg_wei_stride); + cmp(reg_tail, 7); b(LE, Ltail_done_s); + ld1(VReg(0).h[7], ptr(reg_src)); ld1(VReg(1).h[7], ptr(reg_wei)); + L(Ltail_done_s); + } fmlal(VReg4S(20), VReg4H(0), VReg4H(1)); fmlal2(VReg4S(20), VReg4H(0), VReg4H(1)); faddp(VReg4S(20), VReg4S(20), VReg4S(20)); From 1cba4b1621cc93dc284ee994c80ab8a17db15045 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Thu, 23 Oct 2025 16:37:20 +0200 Subject: [PATCH 20/20] Refactor Deconvolution node by replacing `execPtrDeconvACL` with factory-based executor `execPtrFactory`, consolidating execution logic, and removing obsolete `useACL` flag and related redundant paths. 
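The net effect on the node is that every factory-built executor (ACL or AArch64 JIT) is now reached through a single member, and prepareParams() degrades gracefully when the factory declines. A distilled sketch of the new flow follows; it reuses names visible in this diff, and the DeconvExecutorFactory template argument is an assumption on my part (the rendered diff drops template arguments), so treat it as illustrative rather than a drop-in excerpt:

    // Prefer a factory-built executor; on any failure fall through
    // to the oneDNN descriptor path below.
    try {
        // assumed factory type; the diff's template argument is not shown
        if (auto factory = selected_pd->getExecutorFactoryAs<DeconvExecutorFactory>()) {
            if (auto exec = factory->makeExecutorWithMem(deconvAttrs, srcMemoryDescs,
                                                         dstMemoryDescs, *attr,
                                                         srcMemoriesEarly)) {
                execPtrFactory = exec;  // reused later by execute()
                selected_pd->setImplementationType(execPtrFactory->getImplType());
                return;                 // factory path selected
            }
        }
    } catch (...) {
        // factory not applicable for this shape/precision combination
    }
    // ... oneDNN primitive creation continues here ...

The catch-all mirrors the patch as written: it keeps dynamic shapes working when a JIT/ACL builder rejects a configuration, at the cost of also hiding unrelated failures.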
--- src/plugins/intel_cpu/src/nodes/deconv.cpp | 105 +++++------------- src/plugins/intel_cpu/src/nodes/deconv.h | 4 +- .../executors/convolution_implementations.cpp | 14 +++ 3 files changed, 44 insertions(+), 79 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index b1db674d8c1455..fd5c2bdbaaac7e 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -638,8 +638,8 @@ void Deconvolution::getSupportedDescriptors() { return AclDeconvExecutorBuilder::customIsSupported(deconvAttrs, srcMemoryDescs, dstMemoryDescs); }; - useACL = checkDesc(LayoutType::nspc) || checkDesc(LayoutType::ncsp); - if (useACL) { + + if (checkDesc(LayoutType::nspc) || checkDesc(LayoutType::ncsp)) { return; } #endif @@ -792,22 +792,18 @@ VectorDims Deconvolution::shapeInferInternal(const VectorDims& inDims, std::vect } void Deconvolution::execute(const dnnl::stream& strm) { - if (useACL) { + if (execPtrFactory) { std::vector srcMemory; - for (size_t i = 0; i < getOriginalInputsNumber(); i++) { + for (size_t i = 0; i < getOriginalInputsNumber(); i++) srcMemory.push_back(getSrcMemoryAtPort(i)); - } std::vector dstMemory; - for (size_t i = 0; i < getOriginalOutputsNumber(); i++) { + for (size_t i = 0; i < getOriginalOutputsNumber(); i++) dstMemory.push_back(getDstMemoryAtPort(i)); - } - // TODO: need to pass post ops data - execPtrDeconvACL->exec(srcMemory, dstMemory, nullptr); + execPtrFactory->exec(srcMemory, dstMemory, nullptr); return; } CPU_NODE_ASSERT(execPtr, "executor is not compiled"); - execPtr->exec(primArgs, strm); if (externOutShape) { @@ -969,7 +965,9 @@ void Deconvolution::prepareParams() { auto* selected_pd = getSelectedPrimitiveDescriptor(); CPU_NODE_ASSERT(selected_pd, "Preferable primitive descriptor is not set."); - if (useACL) { + // Minimal integration: always try factory path (ACL/JIT) with early-packing ctor; + // fall back to oneDNN path if factory does not provide an executor. + { if (isDynamicNode()) { initPaddingR(getParentEdgeAt(0)->getMemory().getDescPtr()->getShape(), getChildEdgeAt(0)->getMemory().getDescPtr()->getShape()); @@ -983,14 +981,24 @@ void Deconvolution::prepareParams() { dstMemoryDescs.push_back(getChildEdgeAt(i)->getMemory().getDescWithType()); } - // Build executor with constructor-time early packing (for JIT on ARM64); falls back to regular path std::vector srcMemoriesEarly; - srcMemoriesEarly.push_back(getSrcMemoryAtPort(0)); - srcMemoriesEarly.push_back(getSrcMemoryAtPort(1)); - execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutorWithMem( - deconvAttrs, srcMemoryDescs, dstMemoryDescs, *attr, srcMemoriesEarly); - selected_pd->setImplementationType(execPtrDeconvACL->getImplType()); - return; + for (size_t i = 0; i < getOriginalInputsNumber(); i++) { + srcMemoriesEarly.push_back(getSrcMemoryAtPort(i)); + } + + try { + auto factory = selected_pd->getExecutorFactoryAs(); + if (factory) { + auto exec = factory->makeExecutorWithMem(deconvAttrs, srcMemoryDescs, dstMemoryDescs, *attr, srcMemoriesEarly); + if (exec) { + execPtrFactory = exec; + selected_pd->setImplementationType(execPtrFactory->getImplType()); + return; + } + } + } catch (...) 
{ + // Fallback to oneDNN path when factory isn't applicable + } } auto inMemoryDesc = getParentEdgeAt(0)->getMemory().getDescWithType(); auto outMemoryDesc = getChildEdgeAt(0)->getMemory().getDescWithType(); @@ -1355,70 +1363,13 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { dstMemoryDescs, std::make_shared(context, getImplPriority())); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory); - useACL = true; // reuse factory-based execution path return; } } #endif - // If ACL path is not selected, try AArch64 JIT factory for 5D FP16/FP32 - if (!useACL) { -#if defined(OPENVINO_ARCH_ARM64) - const auto rank = getInputShapeAtPort(0).getRank(); - const bool is5D = (rank == 5); - const bool fp16_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f16 && - getOriginalInputPrecisionAtPort(1) == ov::element::f16 && - getOriginalOutputPrecisionAtPort(0) == ov::element::f16; - const bool fp32_ok = getOriginalInputPrecisionAtPort(0) == ov::element::f32 && - getOriginalInputPrecisionAtPort(1) == ov::element::f32 && - getOriginalOutputPrecisionAtPort(0) == ov::element::f32; - if (is5D && (fp16_ok || fp32_ok)) { - auto [inDims, outDims] = makeDummyInOutShape(); - auto tmpInShape = Shape(inDims); - auto tmpOutShape = Shape(outDims); - initPaddingR(tmpInShape, tmpOutShape); - - const auto& creatorsMap = BlockedDescCreator::getCommonCreators(); - NodeConfig config; - config.inConfs.resize(getParentEdges().size()); - config.outConfs.resize(getOriginalOutputsNumber()); - - auto setDesc = [&](size_t port, const Shape& shape, bool isInput) { - const auto prec = - isInput ? getOriginalInputPrecisionAtPort(port) : getOriginalOutputPrecisionAtPort(port); - const auto& shp = isInput ? getInputShapeAtPort(port) : getOutputShapeAtPort(port); - auto d = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(prec, shp); - if (isInput) - config.inConfs[port].setMemDesc(d); - else - config.outConfs[port].setMemDesc(d); - }; - setDesc(0, tmpInShape, true); - setDesc(1, Shape(getInputShapeAtPort(1).getStaticDims()), true); - for (size_t i = 2; i < getParentEdges().size(); ++i) - setDesc(i, Shape(getInputShapeAtPort(i).getStaticDims()), true); - setDesc(0, tmpOutShape, false); - - std::vector srcMemoryDescs; - srcMemoryDescs.push_back(config.inConfs[0].getMemDesc()->cloneWithNewDims(tmpInShape.getDims())); - for (size_t i = 1; i < config.inConfs.size(); i++) - srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()->clone()); - std::vector dstMemoryDescs; - dstMemoryDescs.push_back(config.outConfs[0].getMemDesc()->cloneWithNewDims(tmpOutShape.getDims())); - for (size_t i = 1; i < config.outConfs.size(); i++) - dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); - auto factory = - std::make_shared(deconvAttrs, - srcMemoryDescs, - dstMemoryDescs, - std::make_shared(context, getImplPriority())); - supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::jit_asimd, factory); - return; - } -#endif - Node::initSupportedPrimitiveDescriptors(); - return; - } + Node::initSupportedPrimitiveDescriptors(); + return; auto [inDims, outDims] = makeDummyInOutShape(); auto tmpInShape = Shape(inDims); diff --git a/src/plugins/intel_cpu/src/nodes/deconv.h b/src/plugins/intel_cpu/src/nodes/deconv.h index 2dda217f287845..1612a1c4e359dd 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.h +++ b/src/plugins/intel_cpu/src/nodes/deconv.h @@ -73,7 +73,8 @@ class Deconvolution : public Node { AttrPtr initPrimitiveAttr() override; AttrPtr makePrimitiveAttr(const VectorDims& 
dims); std::vector getAvailableFormatsForDims(const Shape& dims) const override; - std::shared_ptr execPtrDeconvACL = nullptr; + // Factory-based executor (JIT/ACL), created via DeconvExecutorFactory + std::shared_ptr execPtrFactory = nullptr; private: using executorPtr = std::shared_ptr; @@ -101,7 +102,6 @@ class Deconvolution : public Node { VectorDims dnnlCompatibleWeiDims; VectorDims expectedBiasDims; - bool useACL = false; DeconvAttrs deconvAttrs; Shape inShape, outShape; diff --git a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp index 9d6c2b5355b207..50c66c2f03b89e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convolution_implementations.cpp @@ -258,6 +258,20 @@ const std::vector>& getImplementations() { AcceptsAnyShape, CreateDnnlDefault{} ) + OV_CPU_INSTANCE_DNNL_X64( + "convolution_dnnl_nspc_nspc_unconditional_x64", ExecutorType::Dnnl, OperationType::Convolution, + // supports + [](const ConvConfig& config, const MemoryFormatFilter& memoryFormatFilter) -> bool { + // Unconditionally allow nspc path for x64 for shapes where backup may decline. + VERIFY(MatchesMemoryFormatFilter(config.descs, LayoutConfig{LayoutType::nspc, LayoutType::ncsp, LayoutType::nspc, LayoutType::nspc}, + memoryFormatFilter, dnnlConvolutionMappingNotation), MEMORY_FORMAT_MISMATCH); + VERIFY(!isQuantized(config), UNSUPPORTED_SRC_PRECISIONS); + return true; + }, + CreateOptimalConfigDefault{{LayoutType::nspc, LayoutType::ncsp, LayoutType::nspc, LayoutType::nspc}}, + AcceptsAnyShape, + CreateDnnlDefault{} + ) OV_CPU_INSTANCE_ACL( "convolution_dnnl_nspc_nspc_unconditional_acl", ExecutorType::Dnnl, OperationType::Convolution, // supports