
Commit 6cbbd25

Extend MlasSBGemmBatch to accept ZeroMode
1 parent 8d76b80 commit 6cbbd25

3 files changed: +61 -48 lines changed

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
@@ -1955,6 +1955,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
     const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
     bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
     bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
+    bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
 };

 /**
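For illustration only (not part of this commit), below is a minimal sketch of how a caller could use the new flag to split one GEMM into two K slices and accumulate the second slice into the same C buffer. The helper name, buffer layout, and sizes are assumptions; the MlasSBGemmBatch argument order mirrors the call sites later in this diff, the nullptr threadpool is assumed to keep the batch sequential, and the fp32-to-bf16 rounding the real path performs is ignored.

#include "mlas.h"

// Hypothetical helper: C = A*B (+ Bias) with K split into two slices.
// The second slice sets ZeroMode = false so it accumulates into C; the bias is
// supplied only on the ZeroMode slice because SBGEMM applies it only then.
static void SBGemmTwoSlice(size_t M, size_t N, size_t K,
                           const float* A, const float* B, float* C,
                           const float* Bias)
{
    MLAS_SBGEMM_DATA_PARAMS params[2];

    for (size_t s = 0; s < 2; s++) {
        params[s].A = A + s * (K / 2);              // next K/2 columns of A (row-major, lda = K)
        params[s].B = B + s * (K / 2) * N;          // next K/2 rows of B (row-major, ldb = N)
        params[s].C = C;                            // both slices target the same C
        params[s].lda = K;
        params[s].ldb = N;
        params[s].ldc = N;
        params[s].Bias = (s == 0) ? Bias : nullptr; // bias only where ZeroMode is true
        params[s].AIsfp32 = true;
        params[s].BIsfp32 = true;
        params[s].ZeroMode = (s == 0);              // first slice: C = A*B, second: C += A*B
        params[s].OutputProcessor = nullptr;
    }

    // Two M x N x (K/2) GEMMs over the same C (assumes K is even).
    MlasSBGemmBatch(M, N, K / 2, 2, params, nullptr);
}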

onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp

Lines changed: 53 additions & 42 deletions
@@ -183,76 +183,87 @@ void
 }

 //
-// BF16 Pointwise Convolution Kernel
+// BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
 //
 void MLASCALL
 MlasConvPointwiseBf16KernelNeon(
     const float* Input,
     const float* Filter,
     float* Output,
     size_t StrideWidth,
-    size_t InputChannels, /* numChannels/BlockSize = 16/16 = 1 */
+    size_t InputChannels,
     size_t FilterCount,
-    size_t /*InputStride*/,
+    size_t InputStride,
     size_t FilterStride,
     size_t OutputStride,
     size_t OutputCount,
     const float* Bias,
     unsigned KernelFlags
 )
 {
+    const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
     const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
+    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;

     const size_t StrideWidthElements = StrideWidth / sizeof(float);
+    const size_t InputStrideElements = InputStride / sizeof(float);
     const size_t FilterStrideElements = FilterStride / sizeof(float);
     const size_t OutputStrideElements = OutputStride / sizeof(float);

-    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
-    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
+    // SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false),
+    // pre-add bias to existing output before the GEMM operations.
+    if (BiasAddition && AccumulateOutput) {
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+            const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]);
+            const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]);
+            const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]);
+            for (size_t i = 0; i < OutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12])));
+            }
+        }
+    }

-    std::vector<MLAS_SBGEMM_DATA_PARAMS> gemm_params(FilterCount);
+    // Build SBGEMM params for all (filter, input_channel) combinations.
+    // FilterCount <= 4, InputChannels <= 8, so max 32 elements.
+    // Bias is set on all elements but SBGEMM only uses it when ZeroMode=true.
+    MLAS_SBGEMM_DATA_PARAMS gemm_params[32];

+    size_t idx = 0;
     for (size_t f = 0; f < FilterCount; f++) {
         const float* filter = Filter + f * FilterStrideElements;
         float* output = Output + f * OutputStrideElements;
-
-        gemm_params[f].A = Input;
-        gemm_params[f].B = filter;
-        gemm_params[f].C = output;
-        gemm_params[f].lda = StrideWidthElements;
-        gemm_params[f].ldb = BlockSize;
-        gemm_params[f].ldc = BlockSize;
-        gemm_params[f].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
-        gemm_params[f].AIsfp32 = true;
-        gemm_params[f].BIsfp32 = true;
-        gemm_params[f].OutputProcessor = nullptr;
+        for (size_t ic = 0; ic < InputChannels; ic++, idx++) {
+            gemm_params[idx].A = Input + ic * InputStrideElements;
+            gemm_params[idx].B = filter + ic * BlockSize * BlockSize;
+            gemm_params[idx].C = output;
+            gemm_params[idx].lda = StrideWidthElements;
+            gemm_params[idx].ldb = BlockSize;
+            gemm_params[idx].ldc = BlockSize;
+            gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
+            gemm_params[idx].AIsfp32 = true;
+            gemm_params[idx].BIsfp32 = true;
+            gemm_params[idx].ZeroMode = (ic == 0) && !AccumulateOutput;
+            gemm_params[idx].OutputProcessor = nullptr;
+        }
     }

-    MlasSBGemmBatch(OutputCount, BlockSize, InputChannels * BlockSize, FilterCount, gemm_params.data(), nullptr);
-
-    for (size_t f = 0; f < FilterCount; f++) {
-        float* output = Output + f * OutputStrideElements;
-
-        for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) {
-            float32x4_t Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
-            float32x4_t Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]);
-            float32x4_t Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]);
-            float32x4_t Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]);
-
-            float32x4_t Relu0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector);
-            float32x4_t Relu1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector);
-            float32x4_t Relu2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector);
-            float32x4_t Relu3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector);
-
-            Accumulator0 = MlasBlendFloat32x4(Accumulator0, Relu0, ReluMask);
-            Accumulator1 = MlasBlendFloat32x4(Accumulator1, Relu1, ReluMask);
-            Accumulator2 = MlasBlendFloat32x4(Accumulator2, Relu2, ReluMask);
-            Accumulator3 = MlasBlendFloat32x4(Accumulator3, Relu3, ReluMask);
-
-            MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0);
-            MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1);
-            MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2);
-            MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3);
+    MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr);
+
+    if (ReluActivation) {
+        const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            for (size_t i = 0; i < OutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector));
+            }
         }
     }
 }
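As a reading aid (not part of the commit), here is a scalar reference for what the batched SBGEMM calls above compute for one filter block, ignoring the fp32-to-bf16 rounding of the real kernels. The function name is hypothetical; the indexing mirrors the gemm_params set up above (A = Input + ic * InputStrideElements with lda = StrideWidthElements, B = filter + ic * BlockSize * BlockSize with ldb = BlockSize, C = output with ldc = BlockSize).

// Hypothetical scalar reference for one filter block. ZeroMode is true only for
// ic == 0 when the output is not being accumulated, so later input-channel blocks
// add into the same output; bias is folded into the ZeroMode GEMM, or pre-added
// to the existing output when AccumulateOutput is set. ReLU is applied afterwards.
void PointwiseConvBlockReference(const float* Input, const float* Filter, const float* Bias,
                                 float* output, size_t InputChannels, size_t OutputCount,
                                 size_t BlockSize, size_t StrideWidthElements,
                                 size_t InputStrideElements, bool AccumulateOutput,
                                 bool BiasAddition, bool ReluActivation)
{
    for (size_t i = 0; i < OutputCount; i++) {
        for (size_t n = 0; n < BlockSize; n++) {
            float acc = AccumulateOutput ? output[i * BlockSize + n] : 0.0f;
            if (BiasAddition) {
                acc += Bias[n];  // per-filter bias slice (Bias + f * BlockSize in the kernel)
            }
            for (size_t ic = 0; ic < InputChannels; ic++) {
                const float* a = Input + ic * InputStrideElements + i * StrideWidthElements;
                const float* b = Filter + ic * BlockSize * BlockSize;
                for (size_t k = 0; k < BlockSize; k++) {
                    acc += a[k] * b[k * BlockSize + n];
                }
            }
            output[i * BlockSize + n] = ReluActivation ? (acc > 0.0f ? acc : 0.0f) : acc;
        }
    }
}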

onnxruntime/core/mlas/lib/sbgemm.h

Lines changed: 7 additions & 6 deletions
@@ -112,7 +112,7 @@ MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK,

 template <typename KernelType>
 MLAS_FORCEINLINE void
-MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor)
+MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
 {
     constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides;
     size_t PackedStrideN = Strides.N;
@@ -131,7 +131,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size
     //
     size_t CountK;
     for (size_t k = 0; k < K; k += CountK) {
-        bool ZeroMode = (k == 0);
+        bool ZeroMode = (k == 0) && InitialZeroMode;
         CountK = std::min(K - k, PackedStrideK);

         const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN;
@@ -148,7 +148,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size

 template <typename KernelType>
 void
-MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor)
+MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
 {
     //
     // Compute the strides to step through slices of the input matrices.
@@ -201,7 +201,7 @@ MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_
             const float* pbias =
                 ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart

-            bool ZeroMode = (k == 0);
+            bool ZeroMode = (k == 0) && InitialZeroMode;
             MlasSBGemmKernel<KernelType>(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode);
         }
         if (PostProcessor != nullptr) {
@@ -249,16 +249,17 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN,
     const float* A = (const float*)DataParams->A + RangeStartM * lda;
     float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
     const float* bias = DataParams->Bias;
+    const bool zeroMode = DataParams->ZeroMode;

     if (!DataParams->BIsfp32) {
         MlasSBGemmPackedOperation<KernelType>(
             RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A,
-            lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor
+            lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode
         );
     } else {
         const size_t ldb = DataParams->ldb;
         const float* B = (const float*)DataParams->B + RangeStartN;
-        MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor);
+        MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode);
     }
 }

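For orientation (an assumption-laden sketch, not from the source), the per-slice effect of the new InitialZeroMode argument can be modeled with a scalar stand-in for MlasSBGemmKernel; the real kernels operate on repacked bf16 panels, so layouts and rounding differ.

// Hypothetical scalar model of one K-slice kernel call. With ZeroMode true the
// slice overwrites C and applies the per-column bias; with ZeroMode false it
// accumulates, and the call site above passes a nullptr bias in that case.
void KernelModel(size_t M, size_t N, size_t CountK,
                 const float* A, size_t lda, const float* B, size_t ldb,
                 float* C, size_t ldc, const float* Bias, bool ZeroMode)
{
    for (size_t m = 0; m < M; m++) {
        for (size_t n = 0; n < N; n++) {
            float acc = ZeroMode ? (Bias != nullptr ? Bias[n] : 0.0f) : C[m * ldc + n];
            for (size_t k = 0; k < CountK; k++) {
                acc += A[m * lda + k] * B[k * ldb + n];
            }
            C[m * ldc + n] = acc;
        }
    }
}

// With this commit, the K loop therefore yields:
//   k == 0, InitialZeroMode == true   ->  C  = A0*B0 (+ Bias)
//   k == 0, InitialZeroMode == false  ->  C += A0*B0
//   k  > 0                            ->  C += Ak*Bk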
