Skip to content

Commit 2452284

Browse files
committed
+add AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w6 for class SynetMergedConvolution16b.
1 parent e7dd15c commit 2452284

File tree

4 files changed

+117
-4
lines changed

4 files changed

+117
-4
lines changed

docs/2024.html

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ <h5>New features</h5>
7171
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w6 for framework SynetMergedConvolution32f.</li>
7272
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w8 for framework SynetMergedConvolution32f.</li>
7373
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w8 for class SynetMergedConvolution16b.</li>
74-
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
7574
<li>Base implementation of function Yuv444pToRgbaV2.</li>
7675
</ul>
7776
<h5>Improving</h5>

docs/2025.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ <h5>New features</h5>
4444
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetTiledScale2D32f.</li>
4545
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w6 for class SynetMergedConvolution16b.</li>
4646
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w4 for class SynetMergedConvolution16b.</li>
47+
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
48+
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w6 for class SynetMergedConvolution16b.</li>
4749
</ul>
4850
<h5>Improving</h5>
4951
<ul>

src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise3x3.cpp

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,13 +482,120 @@ namespace Simd
482482
}
483483
}
484484

485+
//-------------------------------------------------------------------------------------------------
486+
487+
static SIMD_INLINE bool Preferable_k3p1d1s1w6(const ConvParam& p)
488+
{
489+
return p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) &&
490+
(p.srcW >= 6 && (p.srcW % 6 == 0 || p.srcW % 6 >= 4));
491+
}
492+
493+
template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w6(const uint8_t* src8,
494+
const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
495+
{
496+
assert(p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) && p.srcW >= 6);
497+
const T* src = (T*)src8;
498+
size_t srcH = p.srcH, srcW = p.srcW;
499+
size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
500+
size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
501+
size_t wD = 9 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 6;
502+
size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;
503+
504+
__m512 s0, s1, w0, w1, w2, d0, d1, d2, d3, d4, d5;
505+
506+
__m512 _params[2], _bias[1];
507+
_params[0] = _mm512_set1_ps(params[0]);
508+
if (type == SimdConvolutionActivationRestrictRange ||
509+
type == SimdConvolutionActivationHswish ||
510+
type == SimdConvolutionActivationHardSigmoid)
511+
_params[1] = _mm512_set1_ps(params[1]);
512+
for (size_t dc = 0; dc < dstCe; dc += F)
513+
{
514+
_bias[0] = _mm512_loadu_ps(bias + dc);
515+
if (type == ::SimdConvolutionActivationPrelu)
516+
_params[0] = _mm512_loadu_ps(params + dc);
517+
__mmask16 tailS = TailMask16(dstC - dc);
518+
__mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
519+
for (size_t dy = yBeg; dy < yEnd; ++dy)
520+
{
521+
for (size_t dx = 0;; dx += Min<size_t>(6, endW - dx))
522+
{
523+
d0 = _mm512_setzero_ps();
524+
d1 = _mm512_setzero_ps();
525+
d2 = _mm512_setzero_ps();
526+
d3 = _mm512_setzero_ps();
527+
d4 = _mm512_setzero_ps();
528+
d5 = _mm512_setzero_ps();
529+
__mmask16 tailS0 = dx == 0 ? 0 : tailS;
530+
__mmask16 tailS1 = dx == endW ? 0 : tailS;
531+
for (size_t ky = 0; ky < 3; ++ky)
532+
{
533+
size_t sy = dy + ky - 1;
534+
const T* ps = src + (sy & sM) * sY + (dx - 1) * sX;
535+
const float* pw = weight + ky * 3 * F;
536+
if (sy < srcH)
537+
{
538+
w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
539+
s0 = LoadSrc(ps + 0 * sX, tailS0);
540+
d0 = _mm512_fmadd_ps(s0, w0, d0);
541+
542+
w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
543+
s1 = LoadSrc(ps + 1 * sX, tailS);
544+
d0 = _mm512_fmadd_ps(s1, w1, d0);
545+
d1 = _mm512_fmadd_ps(s1, w0, d1);
546+
547+
s0 = LoadSrc(ps + 2 * sX, tailS);
548+
w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
549+
d0 = _mm512_fmadd_ps(s0, w2, d0);
550+
d1 = _mm512_fmadd_ps(s0, w1, d1);
551+
d2 = _mm512_fmadd_ps(s0, w0, d2);
552+
553+
s1 = LoadSrc(ps + 3 * sX, tailS);
554+
d1 = _mm512_fmadd_ps(s1, w2, d1);
555+
d2 = _mm512_fmadd_ps(s1, w1, d2);
556+
d3 = _mm512_fmadd_ps(s1, w0, d3);
557+
558+
s0 = LoadSrc(ps + 4 * sX, tailS);
559+
d2 = _mm512_fmadd_ps(s0, w2, d2);
560+
d3 = _mm512_fmadd_ps(s0, w1, d3);
561+
d4 = _mm512_fmadd_ps(s0, w0, d4);
562+
563+
s1 = LoadSrc(ps + 5 * sX, tailS);
564+
d3 = _mm512_fmadd_ps(s1, w2, d3);
565+
d4 = _mm512_fmadd_ps(s1, w1, d4);
566+
d5 = _mm512_fmadd_ps(s1, w0, d5);
567+
568+
s0 = LoadSrc(ps + 6 * sX, tailS);
569+
d4 = _mm512_fmadd_ps(s0, w2, d4);
570+
d5 = _mm512_fmadd_ps(s0, w1, d5);
571+
572+
s1 = LoadSrc(ps + 7 * sX, tailS1);
573+
d5 = _mm512_fmadd_ps(s1, w2, d5);
574+
}
575+
}
576+
uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
577+
Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
578+
Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
579+
Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
580+
Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
581+
Save1<term, type>(pd + 4 * dX, dD, d4, _bias, _params, tailC);
582+
Save1<term, type>(pd + 5 * dX, dD, d5, _bias, _params, tailC);
583+
if (dx == endW)
584+
break;
585+
}
586+
}
587+
src += sD;
588+
dst += dD;
589+
weight += wD;
590+
}
591+
}
485592

486593
//-------------------------------------------------------------------------------------------------
487594

488595
static SIMD_INLINE bool Preferable_k3p1d1s1w8(const ConvParam& p)
489596
{
490597
return p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) &&
491-
(p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6)/*&& AlignHiAny(p.srcW, 8) < AlignHiAny(p.srcW, 6) * 1.2*/);
598+
(p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6));
492599
}
493600

494601
template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w8(const uint8_t* src8,
@@ -614,6 +721,11 @@ namespace Simd
614721
depthwise = DepthwiseConvolution_k3p1d1s1w8<T, term, type>;
615722
return true;
616723
}
724+
else if (Preferable_k3p1d1s1w6(p))
725+
{
726+
depthwise = DepthwiseConvolution_k3p1d1s1w6<T, term, type>;
727+
return true;
728+
}
617729
else if (IsKernel(p, 3) && IsDilation(p, 1) && IsStride(p, 1))
618730
{
619731
depthwise = DepthwiseConvolution3x3_V2<T, term, type>;

src/Test/TestSynetMergedConvolution16b.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ namespace Test
282282
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(2, 512, 127, 127), Cnv(a0, 1, 1, 1024), Cnv(a1, 3, 1), Cnv(a2, 1, 1, 512), t, f32, f32), f1, f2);
283283
#endif
284284
#if 1
285-
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 6, 6), Cnv(aSw, 1, 1, 512), Cnv(aId, 3, 1), b16, f32), f1, f2);
286-
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 4, 8), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, f32), f1, f2);
285+
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 8, 6), Cnv(aSw, 1, 1, 512), Cnv(aId, 3, 1), b16, f32), f1, f2);
286+
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 6, 8), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, f32), f1, f2);
287287
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 12, 12), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, b16), f1, f2);
288288
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 128, 24, 24), Cnv(aSw, 1, 1, 128), Cnv(aSw, 3, 1), b16, b16), f1, f2);
289289
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 64, 48, 48), Cnv(aSw, 1, 1, 64), Cnv(aSw, 3, 1), b16, b16), f1, f2);

0 commit comments

Comments
 (0)