@@ -492,7 +492,7 @@ namespace Simd
492492 }
493493
494494 template <Term16bType term, SimdConvolutionActivationType type, int prev> void Convolution16bNhwcGemm_32x32b (const uint16_t * src0, const ConvParam& p, const AlgParam& a,
495- size_t dstS, size_t dstC, const uint16_t * weight0, const __m512* bias, const __m512* params, float * buf0, float * buf1, uint8_t * dst1, int last )
495+ size_t dstS, size_t dstC, const uint16_t * weight0, const __m512* bias, const __m512* params, float * buf0, float * buf1, uint8_t * dst1)
496496 {
497497 int dB = (int )a.microD , dD = int (p.dstC * a.elem ), dS = (int )a.bufK , strideB = dB * 4 , dW = (int )a.microD , strideW = dW * 4 ;
498498 int stepS = 32 , strideS = dS * 2 ;
@@ -556,62 +556,39 @@ namespace Simd
556556 _tile_stored (2 , buf1 + 16 * dB + 0 , strideB);
557557 _tile_dpbf16ps (3 , 5 , 7 );
558558 _tile_stored (3 , buf1 + 16 * dB + F, strideB);
559- if (last)
560- {
561- if (term == Term16bLast16b)
562- {
563- __mmask32 tailD = TailMask32 (dstC);
564- size_t ds = 0 , dstS8 = dstS & (~7 );
565- for (; ds < dstS8; ds += 8 )
566- Apply16b2x8<type, 1 >(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
567- for (; ds < dstS; ++ds)
568- Apply16b2<type, 1 >(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
569- }
570- if (term == Term16bLast32f)
571- {
572- __mmask16 tailD = TailMask16 (dstC - F);
573- size_t ds = 0 ;
574- for (; ds < dstS; ++ds)
575- Apply32f2<type, 1 >(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
576- }
577- }
578559 }
579560
580561 template <Term16bType term, SimdConvolutionActivationType type, int prev> void Convolution16bNhwcGemm_N32x32b (const uint16_t * src0, const ConvParam& p, const AlgParam& a,
581562 size_t dstS, size_t dstC, const uint16_t * weight0, const __m512* bias, const __m512* params, float * buf, uint8_t * dst)
582563 {
583564 int dB = (int )a.microD , dD = int (p.dstC * a.elem ), dW = (int )a.microD , dS = (int )a.bufK ;
584- // const uint16_t* src1 = src0 + 16 * dS;
585565 float * buf0 = buf, * buf1 = buf + 32 * dB;
586566
587567 size_t ds = 0 ;
588- Convolution16bNhwcGemm_32x32b<term, type, 0 >(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst, dstS <= 32 ), ds += 32 ;
568+ Convolution16bNhwcGemm_32x32b<term, type, 0 >(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst), ds += 32 ;
589569 for (; ds < dstS; ds += 32 )
590570 {
591571 Swap (buf0, buf1);
592- // Convolution16bNhwcGemm_32x32b<term, type, 0>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, 1);
593- Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS);
572+ Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD);
573+ }
574+ uint8_t * dst1 = dst + (ds - 32 ) * dD;
575+ dstS -= ds - 32 ;
576+ if (term == Term16bLast16b)
577+ {
578+ __mmask32 tailD = TailMask32 (dstC);
579+ size_t ds = 0 , dstS8 = dstS & (~7 );
580+ for (; ds < dstS8; ds += 8 )
581+ Apply16b2x8<type, 1 >(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
582+ for (; ds < dstS; ++ds)
583+ Apply16b2<type, 1 >(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
584+ }
585+ if (term == Term16bLast32f)
586+ {
587+ __mmask16 tailD = TailMask16 (dstC - F);
588+ size_t ds = 0 ;
589+ for (; ds < dstS; ++ds)
590+ Apply32f2<type, 1 >(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
594591 }
595-
596- // if (last)
597- // {
598- // if (last && term == Term16bLast16b)
599- // {
600- // __mmask32 tailD = TailMask32(dstC);
601- // size_t ds = 0, dstS8 = dstS & (~7);
602- // for (; ds < dstS8; ds += 8)
603- // Apply16b2x8<type, 1>(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
604- // for (; ds < dstS; ++ds)
605- // Apply16b2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
606- // }
607- // if (last && term == Term16bLast32f)
608- // {
609- // __mmask16 tailD = TailMask16(dstC - F);
610- // size_t ds = 0;
611- // for (; ds < dstS; ++ds)
612- // Apply32f2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
613- // }
614- // }
615592 }
616593
617594 template <Term16bType term, SimdConvolutionActivationType type, int stream, int flush> void Convolution16bNhwcGemm_32x16 (const uint16_t * src0, const ConvParam& p, const AlgParam& a,
0 commit comments