Skip to content

Commit 639c1aa

Browse files
committed
*improve AMX-BF16 optimizations of class SynetConvolution16bNhwcGemm (part 1: Convolution16bNhwcGemm_32x32).
1 parent d3a4b53 commit 639c1aa

3 files changed

Lines changed: 35 additions & 30 deletions

File tree

docs/2025.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ <h5>Improving</h5>
4848
<li>SSE4.1 optimizations of class ResizerBf16Bilinear.</li>
4949
<li>SSE4.1, AVX2, AVX-512BW optimizations of class ResizerFloatBilinear.</li>
5050
<li>AMX-BF16 optimizations of class SynetConvolution16bNchwGemm.</li>
51+
<li>AMX-BF16 optimizations of class SynetConvolution16bNhwcGemm.</li>
5152
<</ul>
5253

5354
<h4>Test framework</h4>

src/Simd/SimdAmxBf16SynetConvolution16bNhwcGemm.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -228,33 +228,16 @@ namespace Simd
228228

229229
//-----------------------------------------------------------------------------------------
230230

231-
template<Term16bType term, SimdConvolutionActivationType type> void Convolution16bNhwcGemm_32x32(const uint16_t* src0, const ConvParam& p, const AlgParam& a,
231+
template<Term16bType term, SimdConvolutionActivationType type, int cfg> void Convolution16bNhwcGemm_32x32(const uint16_t* src0, const ConvParam& p, const AlgParam& a,
232232
size_t srcC, size_t dstS, size_t dstC, int zero, const uint16_t* weight0, const __m512* bias, const __m512* params, float* buf, uint8_t* dst)
233233
{
234234
int dB = (int)a.dB, dD = int(p.dstC * a.elem), dS = (int)a.bufK, strideB = dB * 4, strideW = 64;
235235
int stepS = a.reorderType ? 512 : 32, strideS = a.reorderType ? 64 : dS * 2;
236236
const uint16_t* src1 = src0 + 16 * dS;
237237
const uint16_t* weight1 = weight0 + a.bufK * F;
238238

239-
TileConf conf;
240-
conf.rows[0] = 16;
241-
conf.rows[1] = 16;
242-
conf.rows[2] = uint8_t(dstS - 16);
243-
conf.rows[3] = uint8_t(dstS - 16);
244-
conf.rows[4] = 16;
245-
conf.rows[5] = uint8_t(dstS - 16);
246-
conf.rows[6] = 16;
247-
conf.rows[7] = 16;
248-
conf.colsb[0] = 64;
249-
conf.colsb[1] = uint16_t((dstC - 16) * 4);
250-
conf.colsb[2] = 64;
251-
conf.colsb[3] = uint16_t((dstC - 16) * 4);
252-
conf.colsb[4] = 64;
253-
conf.colsb[5] = 64;
254-
conf.colsb[6] = 64;
255-
conf.colsb[7] = uint16_t((dstC - 16) * 4);
256-
_tile_loadconfig(&conf);
257-
239+
if (cfg)
240+
SetTileConf2x2(dstS, dstC);
258241
if (zero)
259242
{
260243
_tile_zero(0);
@@ -269,17 +252,30 @@ namespace Simd
269252
_tile_stream_loadd(2, buf + 16 * dB + 0, strideB);
270253
_tile_stream_loadd(3, buf + 16 * dB + F, strideB);
271254
}
272-
for (size_t sc = 0; sc < srcC; sc += 32, src0 += stepS, src1 += stepS)
255+
256+
size_t srcC32 = srcC - 32, sc = 0;
257+
_tile_stream_loadd(4, src0, strideS);
258+
_tile_loadd(6, weight0 + sc * 16, strideW);
259+
for (; sc < srcC32; src1 += stepS)
273260
{
274-
_tile_stream_loadd(4, src0, strideS);
275-
_tile_loadd(6, weight0 + sc * 16, strideW);
276-
_tile_dpbf16ps(0, 4, 6);
277261
_tile_loadd(7, weight1 + sc * 16, strideW);
278-
_tile_dpbf16ps(1, 4, 7);
279262
_tile_stream_loadd(5, src1, strideS);
263+
_tile_dpbf16ps(0, 4, 6);
264+
_tile_dpbf16ps(1, 4, 7);
265+
src0 += stepS;
266+
_tile_stream_loadd(4, src0, strideS);
280267
_tile_dpbf16ps(2, 5, 6);
268+
sc += 32;
269+
_tile_loadd(6, weight0 + sc * 16, strideW);
281270
_tile_dpbf16ps(3, 5, 7);
282271
}
272+
_tile_loadd(7, weight1 + sc * 16, strideW);
273+
_tile_stream_loadd(5, src1, strideS);
274+
_tile_dpbf16ps(0, 4, 6);
275+
_tile_dpbf16ps(1, 4, 7);
276+
_tile_dpbf16ps(2, 5, 6);
277+
_tile_dpbf16ps(3, 5, 7);
278+
283279
_tile_stored(0, buf + 0, strideB);
284280
_tile_stored(1, buf + F, strideB);
285281
_tile_stored(2, buf + 16 * dB + 0, strideB);
@@ -470,8 +466,8 @@ namespace Simd
470466
{
471467
size_t n = 32, n1 = dstH * p.dstW, nn = AlignLoAny(n1, n), m = n1 - nn, dW = a.bufK * DF;
472468
size_t dB = a.macroK < a.bufK ? a.dB : 0, dD = p.dstC * a.elem, dS = a.bufK;
473-
Convolution16bNhwcGemmPtr body_2 = Convolution16bNhwcGemm_32x32<term, type>;
474-
Convolution16bNhwcGemmPtr tail_2 = m > 16 ? Convolution16bNhwcGemm_32x32<term, type> : Convolution16bNhwcGemm_16x32<term, type>;
469+
Convolution16bNhwcGemmPtr body_2 = Convolution16bNhwcGemm_32x32<term, type, 0>;
470+
Convolution16bNhwcGemmPtr tail_2 = m > 16 ? Convolution16bNhwcGemm_32x32<term, type, 1> : Convolution16bNhwcGemm_16x32<term, type>;
475471
Convolution16bNhwcGemmPtr body_1 = Convolution16bNhwcGemm_32x16<term, type>;
476472
Convolution16bNhwcGemmPtr tail_1 = m > 16 ? Convolution16bNhwcGemm_32x16<term, type> : Convolution16bNhwcGemm_16x16<term, type>;
477473

@@ -482,6 +478,7 @@ namespace Simd
482478
type == SimdConvolutionActivationHardSigmoid)
483479
_params[1] = _mm512_set1_ps(params[1]);
484480

481+
SetTileConfFull();
485482
for (size_t dc = 0; dc < dstC; dc += DF)
486483
{
487484
size_t dC = Simd::Min(DF, dstC - dc);
@@ -498,6 +495,8 @@ namespace Simd
498495
size_t i = 0;
499496
if (dC > F)
500497
{
498+
if(m)
499+
SetTileConfFull();
501500
for (; i < nn; i += n, s += n * dS, b += n * dB, d += n * dD)
502501
body_2(s, p, a, srcC, n, dC, zero, weight, _bias, _params, b, d);
503502
if (m)

src/Test/TestSynetConvolution16b.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,21 @@ namespace Test
316316
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 32, 192, 256, 256, _1, _1, _1, _0, _0, 1, aRe, tT, f32, f32), c, f1, f2);
317317
#endif
318318
#if 1
319+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 3, 3, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
320+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 5, 5, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
319321
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 13, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
322+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 16, 16, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
323+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 383, 13, 13, 1150, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
324+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 14, 1155, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
320325
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 13, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, f32, f32), c, f1, f2);
321-
322-
#if 1
326+
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 1024, 13, 13, 1152, _1, _1, _1, _0, _0, 1, aRe, tT, b16, b16), c, f1, f2);
327+
#endif
328+
#if 0
323329
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 16, 16, 1152, _1, _1, _1, _0, _0, 1, aRe, tF, b16, b16), c, f1, f2);
324330
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 14, 1150, _1, _1, _1, _0, _0, 1, aRe, tF, b16, b16), c, f1, f2);
325331
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 13, 1152, _1, _1, _1, _0, _0, 1, aRe, tF, f32, f32), c, f1, f2);
326332
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 383, 13, 13, 1155, _1, _1, _1, _0, _0, 1, aRe, tF, b16, b16), c, f1, f2);
327333
#endif
328-
#endif
329334
#else
330335
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 384, 13, 13, 1152, _1, _1, _1, _0, _0, 1, aRe, tF, b16, b16), c, f1, f2);
331336
#endif

0 commit comments

Comments
 (0)