@@ -482,13 +482,120 @@ namespace Simd
482482 }
483483 }
484484
485+ // -------------------------------------------------------------------------------------------------
486+
487+ static SIMD_INLINE bool Preferable_k3p1d1s1w6 (const ConvParam& p)
488+ {
489+ return p.IsKernel (3 ) && p.IsPad (1 ) && p.IsStride (1 ) && p.IsDilation (1 ) &&
490+ (p.srcW >= 6 && (p.srcW % 6 == 0 || p.srcW % 6 >= 4 ));
491+ }
492+
// Depthwise convolution worker for the configuration asserted below (p.IsKernel(3), p.IsPad(1),
// p.IsStride(1), p.IsDilation(1), p.srcW >= 6). Each pass of the innermost dx loop accumulates six
// horizontally adjacent outputs (d0..d5) for one group of F channels and hands them to Save1.
//
// Template parameters:
//   T    - element type the source buffer is reinterpreted as (src8 is cast to const T*).
//   term - forwarded to Save1.
//   type - activation kind; forwarded to Save1 and used here to decide which params entries are read.
// Parameters:
//   src8       - source data, read through LoadSrc.
//   p, a       - convolution and algorithm descriptors (sizes, buffer heights, element sizes).
//   maC        - channel count to process (copied into dstC).
//   yBeg, yEnd - half-open range [yBeg, yEnd) of output rows to compute.
//   weight     - weights read as 9 consecutive F-float vectors per channel group (3 rows x 3 taps).
//   bias       - bias values, F floats loaded per channel group.
//   params     - activation parameters: params[0], optionally params[1], or per-channel for Prelu.
//   dst        - destination buffer, written through Save1.
// NOTE(review): the local srcW declared below is never read in this function.
493+ template <typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w6 (const uint8_t * src8,
494+ const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float * weight, const float * bias, const float * params, uint8_t * dst)
495+ {
496+ assert (p.IsKernel (3 ) && p.IsPad (1 ) && p.IsStride (1 ) && p.IsDilation (1 ) && p.srcW >= 6 );
497+ const T* src = (T*)src8;
498+ size_t srcH = p.srcH , srcW = p.srcW ;
// Source addressing: sM is ANDed with the row index (it is all ones when a.bufH[1] == 0, so the row
// is then used as is); sX / sY are the steps per column / per row in T elements; sD is the step to
// the next F-channel group.
// NOTE(review): presumably a.bufH[1] is a power of two when non-zero so that (sy & sM) wraps rows
// inside a cyclic buffer - confirm against the code that fills AlgParam.
499+ size_t sM = (a.bufH [1 ] - 1 ), sD = a.bufH [1 ] ? a.bufH [1 ] * p.srcW * F : F, sX = a.bufH [1 ] ? F : p.srcC , sY = sX * p.srcW , dstC = maC;
// Destination addressing in bytes (dst is uint8_t*): dX / dY are the steps per output column / row,
// dy0 is the row subtracted from dy before addressing, dD is the step to the next F-channel group.
500+ size_t dX = (a.bufH [2 ] ? a.maC * 2 : p.dstC * a.elem [1 ]), dY = p.dstW * dX, dy0 = a.bufH [2 ] ? yBeg : 0 , dD = a.bufH [2 ] ? F * 2 : F * a.elem [1 ];
// wD: weights consumed per channel group; endW: first column of the last six-wide block.
501+ size_t wD = 9 * F, dstCF = AlignLo (dstC, F), dstW = p.dstW , endW = dstW - 6 ;
// With a.bufH[2] set, the channel loop bound is dstC rounded up by AlignHi(dstC, DF).
502+ size_t dstCe = a.bufH [2 ] ? AlignHi (dstC, DF) : dstC;
503+
504+ __m512 s0, s1, w0, w1, w2, d0, d1, d2, d3, d4, d5;
505+
506+ __m512 _params[2 ], _bias[1 ];
507+ _params[0 ] = _mm512_set1_ps (params[0 ]);
// params[1] is only read (and _params[1] only initialised) for these three activation kinds.
508+ if (type == SimdConvolutionActivationRestrictRange ||
509+ type == SimdConvolutionActivationHswish ||
510+ type == SimdConvolutionActivationHardSigmoid)
511+ _params[1 ] = _mm512_set1_ps (params[1 ]);
// Outer loop: one group of F channels per iteration.
512+ for (size_t dc = 0 ; dc < dstCe; dc += F)
513+ {
// Unmasked load of F bias values starting at channel dc.
514+ _bias[0 ] = _mm512_loadu_ps (bias + dc);
// Prelu replaces the broadcast params[0] with F values loaded from params + dc.
515+ if (type == ::SimdConvolutionActivationPrelu)
516+ _params[0 ] = _mm512_loadu_ps (params + dc);
// tailS masks weight and source loads for this group; tailC is the mask given to Save1.
// NOTE(review): presumably TailMask16/TailMask32 enable only the lanes that are still valid - confirm.
517+ __mmask16 tailS = TailMask16 (dstC - dc);
518+ __mmask32 tailC = (dc == dstCF && a.bufH [2 ]) ? TailMask32 (dstCe - dstCF) : tailS;
519+ for (size_t dy = yBeg; dy < yEnd; ++dy)
520+ {
// dx advances by 6 but is clamped so that the final block starts exactly at endW; if the previous
// block ended past endW, the final block overlaps it and those columns are computed and stored again.
521+ for (size_t dx = 0 ;; dx += Min<size_t >(6 , endW - dx))
522+ {
523+ d0 = _mm512_setzero_ps ();
524+ d1 = _mm512_setzero_ps ();
525+ d2 = _mm512_setzero_ps ();
526+ d3 = _mm512_setzero_ps ();
527+ d4 = _mm512_setzero_ps ();
528+ d5 = _mm512_setzero_ps ();
// Horizontal padding: the column left of the first block (dx - 1 at dx == 0) and the column right of
// the last block (dx + 6 at dx == endW) are loaded with an all-zero mask.
// NOTE(review): presumably LoadSrc yields zeros for masked-off lanes - confirm.
529+ __mmask16 tailS0 = dx == 0 ? 0 : tailS;
530+ __mmask16 tailS1 = dx == endW ? 0 : tailS;
531+ for (size_t ky = 0 ; ky < 3 ; ++ky)
532+ {
// sy is unsigned: for dy == 0, ky == 0 it wraps to a huge value and fails the sy < srcH test, so
// rows above and below the image are skipped by the same comparison.
533+ size_t sy = dy + ky - 1 ;
// ps addresses column dx - 1 of row sy; at dx == 0 the offset wraps, and the only load at
// ps + 0 * sX then uses the zero mask tailS0.
534+ const T* ps = src + (sy & sM ) * sY + (dx - 1 ) * sX ;
535+ const float * pw = weight + ky * 3 * F;
536+ if (sy < srcH)
537+ {
// Eight source columns (dx - 1 .. dx + 6) are loaded once each, alternating s0 / s1; the column at
// ps + k * sX feeds output d(k) with w0, d(k-1) with w1 and d(k-2) with w2, where those outputs exist.
538+ w0 = _mm512_maskz_loadu_ps (tailS, pw + 0 * F);
539+ s0 = LoadSrc (ps + 0 * sX , tailS0);
540+ d0 = _mm512_fmadd_ps (s0, w0, d0);
541+
542+ w1 = _mm512_maskz_loadu_ps (tailS, pw + 1 * F);
543+ s1 = LoadSrc (ps + 1 * sX , tailS);
544+ d0 = _mm512_fmadd_ps (s1, w1, d0);
545+ d1 = _mm512_fmadd_ps (s1, w0, d1);
546+
547+ s0 = LoadSrc (ps + 2 * sX , tailS);
548+ w2 = _mm512_maskz_loadu_ps (tailS, pw + 2 * F);
549+ d0 = _mm512_fmadd_ps (s0, w2, d0);
550+ d1 = _mm512_fmadd_ps (s0, w1, d1);
551+ d2 = _mm512_fmadd_ps (s0, w0, d2);
552+
553+ s1 = LoadSrc (ps + 3 * sX , tailS);
554+ d1 = _mm512_fmadd_ps (s1, w2, d1);
555+ d2 = _mm512_fmadd_ps (s1, w1, d2);
556+ d3 = _mm512_fmadd_ps (s1, w0, d3);
557+
558+ s0 = LoadSrc (ps + 4 * sX , tailS);
559+ d2 = _mm512_fmadd_ps (s0, w2, d2);
560+ d3 = _mm512_fmadd_ps (s0, w1, d3);
561+ d4 = _mm512_fmadd_ps (s0, w0, d4);
562+
563+ s1 = LoadSrc (ps + 5 * sX , tailS);
564+ d3 = _mm512_fmadd_ps (s1, w2, d3);
565+ d4 = _mm512_fmadd_ps (s1, w1, d4);
566+ d5 = _mm512_fmadd_ps (s1, w0, d5);
567+
568+ s0 = LoadSrc (ps + 6 * sX , tailS);
569+ d4 = _mm512_fmadd_ps (s0, w2, d4);
570+ d5 = _mm512_fmadd_ps (s0, w1, d5);
571+
572+ s1 = LoadSrc (ps + 7 * sX , tailS1);
573+ d5 = _mm512_fmadd_ps (s1, w2, d5);
574+ }
575+ }
// Hand the six accumulators to Save1 at consecutive output columns dx .. dx + 5 of row dy - dy0.
576+ uint8_t * pd = dst + (dy - dy0) * dY + dx * dX;
577+ Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
578+ Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
579+ Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
580+ Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
581+ Save1<term, type>(pd + 4 * dX, dD, d4, _bias, _params, tailC);
582+ Save1<term, type>(pd + 5 * dX, dD, d5, _bias, _params, tailC);
// The block that starts at endW is the last one in the row.
583+ if (dx == endW)
584+ break ;
585+ }
586+ }
// Advance source, destination and weights to the next F-channel group.
587+ src += sD ;
588+ dst += dD;
589+ weight += wD;
590+ }
591+ }
485592
486593 // -------------------------------------------------------------------------------------------------
487594
// Reports whether the eight-outputs-per-step 3x3 depthwise kernel should be selected for 'p':
// p.IsKernel(3), p.IsPad(1), p.IsStride(1) and p.IsDilation(1) must all hold, p.srcW must be at
// least 8, and p.srcW modulo 8 must be either 0 or not less than 6.
// The '-' / '+' pair below is a diff: the new line only drops a commented-out extra heuristic, so the
// evaluated condition is the same in both versions.
488595 static SIMD_INLINE bool Preferable_k3p1d1s1w8 (const ConvParam& p)
489596 {
490597 return p.IsKernel (3 ) && p.IsPad (1 ) && p.IsStride (1 ) && p.IsDilation (1 ) &&
491- (p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6 )/* && AlignHiAny(p.srcW, 8) < AlignHiAny(p.srcW, 6) * 1.2 */ );
598+ (p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6 ));
492599 }
493600
494601 template <typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w8 (const uint8_t * src8,
@@ -614,6 +721,11 @@ namespace Simd
614721 depthwise = DepthwiseConvolution_k3p1d1s1w8<T, term, type>;
615722 return true ;
616723 }
724+ else if (Preferable_k3p1d1s1w6 (p))
725+ {
726+ depthwise = DepthwiseConvolution_k3p1d1s1w6<T, term, type>;
727+ return true ;
728+ }
617729 else if (IsKernel (p, 3 ) && IsDilation (p, 1 ) && IsStride (p, 1 ))
618730 {
619731 depthwise = DepthwiseConvolution3x3_V2<T, term, type>;
0 commit comments