Skip to content

Commit 256a7ee

Browse files
committed
+add AMX-BF16 optimizations of class SynetConvolution16bNhwcGemmV1 (part 3).
1 parent 227f9ab commit 256a7ee

3 files changed

Lines changed: 26 additions & 17 deletions

File tree

src/Simd/SimdAmxBf16SynetConvolution16bNhwcGemmV1.cpp

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -434,7 +434,6 @@ namespace Simd
434434
_tile_zero(1);
435435
_tile_zero(2);
436436
_tile_zero(3);
437-
438437
int srcC32 = (int)a.bufK - 32, sc = 0;
439438
if(stream)
440439
_tile_stream_loadd(4, src0, strideS);
@@ -504,7 +503,6 @@ namespace Simd
504503
_tile_zero(1);
505504
_tile_zero(2);
506505
_tile_zero(3);
507-
508506
int srcC32 = (int)a.bufK - 32, sc = 0;
509507
_tile_stream_loadd(4, src0, strideS);
510508
_tile_loadd(6, weight0 + sc * dW, strideW);
@@ -513,14 +511,19 @@ namespace Simd
513511
__mmask32 tailD = TailMask32(dstC);
514512
for (uint8_t* dst0 = dst1 - 32 * dD; dst0 < dst1; src1 += stepS, dst0 += prev * dD, buf0 += prev * dB)
515513
{
514+
_tile_loadd(7, weight1 + sc * dW, strideW);
516515
if (term == Term16bLast16b)
517516
{
518-
for (size_t ds = 0; ds < prev; ++ds)
517+
for (size_t ds = 0; ds < prev / 2; ++ds)
519518
Apply16b2<type, 1>(dst0 + ds * dD, buf0 + ds * dB, bias, params, tailD);
520-
}
521-
_tile_loadd(7, weight1 + sc * dW, strideW);
522-
_tile_stream_loadd(5, src1, strideS);
519+
}
523520
_tile_dpbf16ps(0, 4, 6);
521+
_tile_stream_loadd(5, src1, strideS);
522+
if (term == Term16bLast16b)
523+
{
524+
for (size_t ds = prev / 2; ds < prev; ++ds)
525+
Apply16b2<type, 1>(dst0 + ds * dD, buf0 + ds * dB, bias, params, tailD);
526+
}
524527
_tile_dpbf16ps(1, 4, 7);
525528
src0 += stepS;
526529
_tile_stream_loadd(4, src0, strideS);
@@ -533,8 +536,8 @@ namespace Simd
533536
for (; sc < srcC32; src1 += stepS)
534537
{
535538
_tile_loadd(7, weight1 + sc * dW, strideW);
536-
_tile_stream_loadd(5, src1, strideS);
537539
_tile_dpbf16ps(0, 4, 6);
540+
_tile_stream_loadd(5, src1, strideS);
538541
_tile_dpbf16ps(1, 4, 7);
539542
src0 += stepS;
540543
_tile_stream_loadd(4, src0, strideS);
@@ -553,7 +556,6 @@ namespace Simd
553556
_tile_stored(2, buf1 + 16 * dB + 0, strideB);
554557
_tile_dpbf16ps(3, 5, 7);
555558
_tile_stored(3, buf1 + 16 * dB + F, strideB);
556-
std::cout << " last " << last << " prev " << prev << " buf0 " << buf0 << " buf1 " << buf1 << std::endl;
557559
if (last)
558560
{
559561
if (term == Term16bLast16b)
@@ -584,11 +586,11 @@ namespace Simd
584586

585587
size_t ds = 0;
586588
Convolution16bNhwcGemm_32x32b<term, type, 0>(src0, p, a, dstS, dstC, weight0, bias, params, buf0, buf1, dst, dstS <= 32), ds += 32;
587-
Swap(buf0, buf1);
588589
for (; ds < dstS; ds += 32)
589590
{
591+
Swap(buf0, buf1);
590592
//Convolution16bNhwcGemm_32x32b<term, type, 0>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, 1);
591-
Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS ? 1 : 0);
593+
Convolution16bNhwcGemm_32x32b<term, type, prev>(src0 + ds * dS, p, a, dstS - ds, dstC, weight0, bias, params, buf0, buf1, dst + ds * dD, ds + 32 >= dstS);
592594
}
593595

594596
//if (last)
@@ -802,7 +804,14 @@ namespace Simd
802804
if (dC > F)
803805
{
804806
for (; i < nn8; i += n * 8)
805-
Convolution16bNhwcGemm_N32x32b<term, type, 4>(s + i * dS, p, a, n * 8, dC, weight, _bias, _params, buf, d + i * dD);
807+
{
808+
if (dS > 512)
809+
Convolution16bNhwcGemm_N32x32b<term, type, 2>(s + i * dS, p, a, n * 8, dC, weight, _bias, _params, buf, d + i * dD);
810+
else if (dS > 256)
811+
Convolution16bNhwcGemm_N32x32b<term, type, 4>(s + i * dS, p, a, n * 8, dC, weight, _bias, _params, buf, d + i * dD);
812+
else
813+
Convolution16bNhwcGemm_N32x32b<term, type, 8>(s + i * dS, p, a, n * 8, dC, weight, _bias, _params, buf, d + i * dD);
814+
}
806815
for (; i < nn; i += n)
807816
body_2(s + i * dS, p, a, n, dC, weight, _bias, _params, buf, d + i * dD);
808817
if (m)

src/Simd/SimdBaseSynetConvolution16bNhwcGemmV1.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,6 @@ namespace Simd
139139
src += _stepS;
140140
dst += _stepD;
141141
}
142-
//std::cout << SimdVersion() << std::endl;
143142
}
144143

145144
void SynetConvolution16bNhwcGemmV1::Forward(const uint8_t* src, uint16_t* buf, float* sum, uint8_t* dst)
@@ -181,7 +180,7 @@ namespace Simd
181180

182181
bool SynetConvolution16bNhwcGemmV1::Preferable(const ConvParam& p)
183182
{
184-
return p.trans != 0 && p.group == 1 && Simd::Aligned(p.dstW * p.dstH, 32) && Simd::Aligned(p.dstC, 32) && p.srcC >= 256 && p.srcC <= 512 && 0;
183+
return p.trans != 0 && p.group == 1 && Simd::Aligned(p.dstW * p.dstH, 32) && Simd::Aligned(p.dstC, 32) && p.srcC >= 256 && p.srcC <= 512 && 1;
185184
}
186185
}
187186
#endif

src/Test/TestSynetConvolution16b.cpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -79,12 +79,12 @@ namespace Test
7979
TEST_LOG_SS(Info, "Test [" << f1.desc << " & " << f2.desc << "].");
8080

8181
const SimdConvolutionParameters& c = p.conv;
82-
srand(1);
82+
srand(0);
8383
Tensor32f weight(p.WeightShape());
8484
FillRandom(weight.Data(), weight.Size(), -10.0, 10.0f);
8585

8686
Tensor32f bias({ c.dstC });
87-
FillRandom(bias.Data(), bias.Size(), -1.0, 1.0f);
87+
FillRandom(bias.Data(), bias.Size(), -0.0, 0.0f);
8888

8989
Tensor32f params({ c.dstC });
9090
FillRandom(params.Data(), params.Size(), 0.0f, 1.0f);
@@ -109,7 +109,7 @@ namespace Test
109109

110110
Tensor32f src32f(p.SrcShape(), p.conv.srcF), dst32f1(p.DstShape(), p.conv.dstF), dst32f2(p.DstShape(), p.conv.dstF), buf32f;
111111
Tensor16u src16u(p.SrcShape(), p.conv.srcF), dst16u1(p.DstShape(), p.conv.dstF), dst16u2(p.DstShape(), p.conv.dstF), buf16u;
112-
FillRandom(src32f.Data(), src32f.Size(), -1.0, -1.0f);
112+
FillRandom(src32f.Data(), src32f.Size(), -1.0, 1.0f);
113113

114114
SimdFloat32ToBFloat16(src32f.Data(), src32f.Size(), src16u.Data());
115115

@@ -435,8 +435,9 @@ namespace Test
435435
#if 1
436436
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 512, 512, 64, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
437437
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 128, 256, 256, 128, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
438-
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 4, 128, 256, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
438+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 288, 128, 128, 256, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
439439
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 512, 64, 64, 512, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
440+
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 768, 64, 64, 512, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
440441
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 1024, 32, 32, 1024, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
441442
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 2048, 16, 16, 2048, _1, _1, _1, _0, _0, 1, aId, tT, b16, b16), c, f1, f2);
442443
#endif

0 commit comments

Comments
 (0)