Skip to content

Commit f539d6c

Browse files
committed
+add AMX-BF16 optimizations of class SynetConvolution16bNhwcGemmV1 (part 4).
1 parent 256a7ee commit f539d6c

1 file changed

Lines changed: 21 additions & 44 deletions

File tree

src/Simd/SimdAmxBf16SynetConvolution16bNhwcGemmV1.cpp

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ namespace Simd
492492
}
493493

494494
template<Term16bType term, SimdConvolutionActivationType type, int prev> void Convolution16bNhwcGemm_32x32b(const uint16_t* src0, const ConvParam& p, const AlgParam& a,
495-
size_t dstS, size_t dstC, const uint16_t* weight0, const __m512* bias, const __m512* params, float* buf0, float* buf1, uint8_t* dst1, int last)
495+
size_t dstS, size_t dstC, const uint16_t* weight0, const __m512* bias, const __m512* params, float* buf0, float* buf1, uint8_t* dst1)
496496
{
497497
int dB = (int)a.microD, dD = int(p.dstC * a.elem), dS = (int)a.bufK, strideB = dB * 4, dW = (int)a.microD, strideW = dW * 4;
498498
int stepS = 32, strideS = dS * 2;
@@ -556,62 +556,39 @@ namespace Simd
556556
_tile_stored(2, buf1 + 16 * dB + 0, strideB);
557557
_tile_dpbf16ps(3, 5, 7);
558558
_tile_stored(3, buf1 + 16 * dB + F, strideB);
559-
if (last)
560-
{
561-
if (term == Term16bLast16b)
562-
{
563-
__mmask32 tailD = TailMask32(dstC);
564-
size_t ds = 0, dstS8 = dstS & (~7);
565-
for (; ds < dstS8; ds += 8)
566-
Apply16b2x8<type, 1>(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
567-
for (; ds < dstS; ++ds)
568-
Apply16b2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
569-
}
570-
if (term == Term16bLast32f)
571-
{
572-
__mmask16 tailD = TailMask16(dstC - F);
573-
size_t ds = 0;
574-
for (; ds < dstS; ++ds)
575-
Apply32f2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
576-
}
577-
}
578559
}
579560

580561
template<Term16bType term, SimdConvolutionActivationType type, int prev> void Convolution16bNhwcGemm_N32x32b(const uint16_t* src0, const ConvParam& p, const AlgParam& a,
581562
size_t dstS, size_t dstC, const uint16_t* weight0, const __m512* bias, const __m512* params, float* buf, uint8_t* dst)
582563
{
583564
int dB = (int)a.microD, dD = int(p.dstC * a.elem), dW = (int)a.microD, dS = (int)a.bufK;
584-
//const uint16_t* src1 = src0 + 16 * dS;
585565
float* buf0 = buf, * buf1 = buf + 32 * dB;
586566

587567
size_t ds = 0;
588-
Convolution16bNhwcGemm_32x32b<term, type, 0>(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst, dstS <= 32), ds += 32;
568+
Convolution16bNhwcGemm_32x32b<term, type, 0>(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst), ds += 32;
589569
for (; ds < dstS; ds += 32)
590570
{
591571
Swap(buf0, buf1);
592-
//Convolution16bNhwcGemm_32x32b<term, type, 0>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, 1);
593-
Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS);
572+
Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD);
573+
}
574+
uint8_t* dst1 = dst + (ds - 32) * dD;
575+
dstS -= ds - 32;
576+
if (term == Term16bLast16b)
577+
{
578+
__mmask32 tailD = TailMask32(dstC);
579+
size_t ds = 0, dstS8 = dstS & (~7);
580+
for (; ds < dstS8; ds += 8)
581+
Apply16b2x8<type, 1>(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
582+
for (; ds < dstS; ++ds)
583+
Apply16b2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
584+
}
585+
if (term == Term16bLast32f)
586+
{
587+
__mmask16 tailD = TailMask16(dstC - F);
588+
size_t ds = 0;
589+
for (; ds < dstS; ++ds)
590+
Apply32f2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
594591
}
595-
596-
//if (last)
597-
//{
598-
// if (last && term == Term16bLast16b)
599-
// {
600-
// __mmask32 tailD = TailMask32(dstC);
601-
// size_t ds = 0, dstS8 = dstS & (~7);
602-
// for (; ds < dstS8; ds += 8)
603-
// Apply16b2x8<type, 1>(dst1 + ds * dD, dD, buf1 + ds * dB, dB, bias, params, tailD);
604-
// for (; ds < dstS; ++ds)
605-
// Apply16b2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
606-
// }
607-
// if (last && term == Term16bLast32f)
608-
// {
609-
// __mmask16 tailD = TailMask16(dstC - F);
610-
// size_t ds = 0;
611-
// for (; ds < dstS; ++ds)
612-
// Apply32f2<type, 1>(dst1 + ds * dD, buf1 + ds * dB, bias, params, tailD);
613-
// }
614-
//}
615592
}
616593

617594
template<Term16bType term, SimdConvolutionActivationType type, int stream, int flush> void Convolution16bNhwcGemm_32x16(const uint16_t* src0, const ConvParam& p, const AlgParam& a,

0 commit comments

Comments
 (0)