@@ -434,7 +434,6 @@ namespace Simd
434434 _tile_zero (1 );
435435 _tile_zero (2 );
436436 _tile_zero (3 );
437-
438437 int srcC32 = (int )a.bufK - 32 , sc = 0 ;
439438 if (stream)
440439 _tile_stream_loadd (4 , src0, strideS);
@@ -504,7 +503,6 @@ namespace Simd
504503 _tile_zero (1 );
505504 _tile_zero (2 );
506505 _tile_zero (3 );
507-
508506 int srcC32 = (int )a.bufK - 32 , sc = 0 ;
509507 _tile_stream_loadd (4 , src0, strideS);
510508 _tile_loadd (6 , weight0 + sc * dW, strideW);
@@ -513,14 +511,19 @@ namespace Simd
513511 __mmask32 tailD = TailMask32 (dstC);
514512 for (uint8_t * dst0 = dst1 - 32 * dD; dst0 < dst1; src1 += stepS, dst0 += prev * dD, buf0 += prev * dB)
515513 {
514+ _tile_loadd (7 , weight1 + sc * dW, strideW);
516515 if (term == Term16bLast16b)
517516 {
518- for (size_t ds = 0 ; ds < prev; ++ds)
517+ for (size_t ds = 0 ; ds < prev / 2 ; ++ds)
519518 Apply16b2<type, 1 >(dst0 + ds * dD, buf0 + ds * dB, bias, params, tailD);
520- }
521- _tile_loadd (7 , weight1 + sc * dW, strideW);
522- _tile_stream_loadd (5 , src1, strideS);
519+ }
523520 _tile_dpbf16ps (0 , 4 , 6 );
521+ _tile_stream_loadd (5 , src1, strideS);
522+ if (term == Term16bLast16b)
523+ {
524+ for (size_t ds = prev / 2 ; ds < prev; ++ds)
525+ Apply16b2<type, 1 >(dst0 + ds * dD, buf0 + ds * dB, bias, params, tailD);
526+ }
524527 _tile_dpbf16ps (1 , 4 , 7 );
525528 src0 += stepS;
526529 _tile_stream_loadd (4 , src0, strideS);
@@ -533,8 +536,8 @@ namespace Simd
533536 for (; sc < srcC32; src1 += stepS)
534537 {
535538 _tile_loadd (7 , weight1 + sc * dW, strideW);
536- _tile_stream_loadd (5 , src1, strideS);
537539 _tile_dpbf16ps (0 , 4 , 6 );
540+ _tile_stream_loadd (5 , src1, strideS);
538541 _tile_dpbf16ps (1 , 4 , 7 );
539542 src0 += stepS;
540543 _tile_stream_loadd (4 , src0, strideS);
@@ -553,7 +556,6 @@ namespace Simd
553556 _tile_stored (2 , buf1 + 16 * dB + 0 , strideB);
554557 _tile_dpbf16ps (3 , 5 , 7 );
555558 _tile_stored (3 , buf1 + 16 * dB + F, strideB);
556- std::cout << " last " << last << " prev " << prev << " buf0 " << buf0 << " buf1 " << buf1 << std::endl;
557559 if (last)
558560 {
559561 if (term == Term16bLast16b)
@@ -584,11 +586,11 @@ namespace Simd
584586
585587 size_t ds = 0 ;
586588 Convolution16bNhwcGemm_32x32b<term, type, 0 >(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst, dstS <= 32 ), ds += 32 ;
587- Swap (buf0, buf1);
588589 for (; ds < dstS; ds += 32 )
589590 {
591+ Swap (buf0, buf1);
590592 // Convolution16bNhwcGemm_32x32b<term, type, 0>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, 1);
591- Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS ? 1 : 0 );
593+ Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS);
592594 }
593595
594596 // if (last)
@@ -802,7 +804,14 @@ namespace Simd
802804 if (dC > F)
803805 {
804806 for (; i < nn8; i += n * 8 )
805- Convolution16bNhwcGemm_N32x32b<term, type, 4 >(s + i * dS, p, a, n * 8 , dC, weight, _bias, _params, buf, d + i * dD);
807+ {
808+ if (dS > 512 )
809+ Convolution16bNhwcGemm_N32x32b<term, type, 2 >(s + i * dS, p, a, n * 8 , dC, weight, _bias, _params, buf, d + i * dD);
810+ else if (dS > 256 )
811+ Convolution16bNhwcGemm_N32x32b<term, type, 4 >(s + i * dS, p, a, n * 8 , dC, weight, _bias, _params, buf, d + i * dD);
812+ else
813+ Convolution16bNhwcGemm_N32x32b<term, type, 8 >(s + i * dS, p, a, n * 8 , dC, weight, _bias, _params, buf, d + i * dD);
814+ }
806815 for (; i < nn; i += n)
807816 body_2 (s + i * dS, p, a, n, dC, weight, _bias, _params, buf, d + i * dD);
808817 if (m)
0 commit comments