Skip to content

Commit 8b49214

Browse files
committed
*fix bugs: errors in the Base implementation and in the AMX-BF16 optimizations of class SynetInnerProduct16bGemmNN (part 2).
1 parent 5f460b8 commit 8b49214

3 files changed: 6 additions and 4 deletions.

src/Simd/SimdAmxBf16SynetInnerProduct16bGemmNN.cpp

+2 -1
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
#include "Simd/SimdBFloat16.h"
3333
#include "Simd/SimdCopy.h"
3434
#include "Simd/SimdTile.h"
35+
#include "Simd/SimdLog.h"
3536

3637
namespace Simd
3738
{
@@ -291,7 +292,7 @@ namespace Simd
291292
size_t M, size_t N, size_t K, int update, const uint16_t* B, float* C, int post, const float* bias, uint8_t* dst)
292293
{
293294
size_t m = 32, m1 = M, mm = AlignLo(m1, m), t = m1 - mm;
294-
size_t dA = a.aK, dB = a.bK * DF, dC = (a.macroK < a.aK || a.macroN != a.aN || a.macroM != a.aM) ? a.cN : 0, dD = p.N * a.eC;
295+
size_t dA = a.aK, dB = a.bK * DF, dC = (a.macroK < a.aK || a.macroN != a.aN || a.macroM != a.aM || C == (float*)dst) ? a.cN : 0, dD = p.N * a.eC;
295296
__m512 _bias[2];
296297
if (mm)
297298
{

src/Simd/SimdBaseSynetInnerProduct16bGemmNN.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ namespace Simd
129129
a.aM = AlignHi(p.M, a.microM);
130130
a.macroK = Simd::RestrictRange(AlignLo(L1 / a.microN / 2, a.microK), a.microK, a.aK);
131131
a.macroN = Simd::RestrictRange(AlignLo(L3 / a.macroK / 2, a.microN), a.microN, a.aN);
132-
a.macroM = Simd::RestrictRange(L2 / a.macroK / 2, a.microM, a.aM);
132+
a.macroM = Simd::RestrictRange(AlignLo(L2 / a.macroK / 2, a.microM), a.microM, a.aM);
133133
a.eA = p.typeA == SimdTensorData32f ? 4 : 2;
134134
a.eB = p.typeB == SimdTensorData32f ? 4 : 2;
135135
a.eC = p.typeC == SimdTensorData32f ? 4 : 2;

src/Test/TestSynetInnerProduct16b.cpp

+3 -2
Original file line number | Diff line number | Diff line change
@@ -198,15 +198,16 @@ namespace Test
198198
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 127, 128, b16, b16, f32, f, f, t), f1, f2);
199199
#endif
200200
#if 1
201+
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(64, 512, 512, f32, f32, f32, f, t, t), f1, f2);
202+
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(64, 608, 608, f32, f32, f32, f, t, t), f1, f2);
201203
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(1824, 64, 608, f32, f32, f32, f, t, t), f1, f2);
202204
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(64, 1824, 608, f32, f32, f32, f, t, t), f1, f2);
203205
#endif
204206
#if 0
205207
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(3333, 3333, 3333, b16, b16, f32, f, f, t), f1, f2);
206208
#endif
207209
#else
208-
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(1824, 64, 608, f32, f32, f32, f, t, t), f1, f2);
209-
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(64, 1824, 608, f32, f32, f32, f, t, t), f1, f2);
210+
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(64, 512, 512, f32, f32, f32, f, t, t), f1, f2);
210211
#endif
211212

212213
return result;

0 commit comments

Comments
 (0)