@@ -183,76 +183,87 @@ void
183183}
184184
185185//
186- // BF16 Pointwise Convolution Kernel
186+ // BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
187187//
188188void MLASCALL
189189MlasConvPointwiseBf16KernelNeon (
190190 const float * Input,
191191 const float * Filter,
192192 float * Output,
193193 size_t StrideWidth,
194- size_t InputChannels, /* numChannels/BlockSize = 16/16 = 1 */
194+ size_t InputChannels,
195195 size_t FilterCount,
196- size_t /* InputStride*/ ,
196+ size_t InputStride,
197197 size_t FilterStride,
198198 size_t OutputStride,
199199 size_t OutputCount,
200200 const float * Bias,
201201 unsigned KernelFlags
202202)
203203{
204+ const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0 ;
204205 const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0 ;
206+ const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0 ;
205207
206208 const size_t StrideWidthElements = StrideWidth / sizeof (float );
209+ const size_t InputStrideElements = InputStride / sizeof (float );
207210 const size_t FilterStrideElements = FilterStride / sizeof (float );
208211 const size_t OutputStrideElements = OutputStride / sizeof (float );
209212
210- const float32x4_t ZeroVector = MlasBroadcastFloat32x4 (0 .0f );
211- const float32x4_t ReluMask = vreinterpretq_f32_s32 (MlasBroadcastInt32x4 (-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
213+ // SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false),
214+ // pre-add bias to existing output before the GEMM operations.
215+ if (BiasAddition && AccumulateOutput) {
216+ for (size_t f = 0 ; f < FilterCount; f++) {
217+ float * output = Output + f * OutputStrideElements;
218+ const float32x4_t b0 = MlasLoadFloat32x4 (&Bias[f * BlockSize]);
219+ const float32x4_t b1 = MlasLoadFloat32x4 (&Bias[f * BlockSize + 4 ]);
220+ const float32x4_t b2 = MlasLoadFloat32x4 (&Bias[f * BlockSize + 8 ]);
221+ const float32x4_t b3 = MlasLoadFloat32x4 (&Bias[f * BlockSize + 12 ]);
222+ for (size_t i = 0 ; i < OutputCount; i++) {
223+ MlasStoreFloat32x4 (&output[i * BlockSize], MlasAddFloat32x4 (b0, MlasLoadFloat32x4 (&output[i * BlockSize])));
224+ MlasStoreFloat32x4 (&output[i * BlockSize + 4 ], MlasAddFloat32x4 (b1, MlasLoadFloat32x4 (&output[i * BlockSize + 4 ])));
225+ MlasStoreFloat32x4 (&output[i * BlockSize + 8 ], MlasAddFloat32x4 (b2, MlasLoadFloat32x4 (&output[i * BlockSize + 8 ])));
226+ MlasStoreFloat32x4 (&output[i * BlockSize + 12 ], MlasAddFloat32x4 (b3, MlasLoadFloat32x4 (&output[i * BlockSize + 12 ])));
227+ }
228+ }
229+ }
212230
213- std::vector<MLAS_SBGEMM_DATA_PARAMS> gemm_params (FilterCount);
231+ // Build SBGEMM params for all (filter, input_channel) combinations.
232+ // FilterCount <= 4, InputChannels <= 8, so max 32 elements.
233+ // Bias is set on all elements but SBGEMM only uses it when ZeroMode=true.
234+ MLAS_SBGEMM_DATA_PARAMS gemm_params[32 ];
214235
236+ size_t idx = 0 ;
215237 for (size_t f = 0 ; f < FilterCount; f++) {
216238 const float * filter = Filter + f * FilterStrideElements;
217239 float * output = Output + f * OutputStrideElements;
218-
219- gemm_params[f].A = Input;
220- gemm_params[f].B = filter;
221- gemm_params[f].C = output;
222- gemm_params[f].lda = StrideWidthElements;
223- gemm_params[f].ldb = BlockSize;
224- gemm_params[f].ldc = BlockSize;
225- gemm_params[f].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr ;
226- gemm_params[f].AIsfp32 = true ;
227- gemm_params[f].BIsfp32 = true ;
228- gemm_params[f].OutputProcessor = nullptr ;
240+ for (size_t ic = 0 ; ic < InputChannels; ic++, idx++) {
241+ gemm_params[idx].A = Input + ic * InputStrideElements;
242+ gemm_params[idx].B = filter + ic * BlockSize * BlockSize;
243+ gemm_params[idx].C = output;
244+ gemm_params[idx].lda = StrideWidthElements;
245+ gemm_params[idx].ldb = BlockSize;
246+ gemm_params[idx].ldc = BlockSize;
247+ gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr ;
248+ gemm_params[idx].AIsfp32 = true ;
249+ gemm_params[idx].BIsfp32 = true ;
250+ gemm_params[idx].ZeroMode = (ic == 0 ) && !AccumulateOutput;
251+ gemm_params[idx].OutputProcessor = nullptr ;
252+ }
229253 }
230254
231- MlasSBGemmBatch (OutputCount, BlockSize, InputChannels * BlockSize, FilterCount, gemm_params.data (), nullptr );
232-
233- for (size_t f = 0 ; f < FilterCount; f++) {
234- float * output = Output + f * OutputStrideElements;
235-
236- for (size_t output_idx = 0 ; output_idx < OutputCount; output_idx++) {
237- float32x4_t Accumulator0 = MlasLoadFloat32x4 (&output[output_idx * BlockSize]);
238- float32x4_t Accumulator1 = MlasLoadFloat32x4 (&output[output_idx * BlockSize + 4 ]);
239- float32x4_t Accumulator2 = MlasLoadFloat32x4 (&output[output_idx * BlockSize + 8 ]);
240- float32x4_t Accumulator3 = MlasLoadFloat32x4 (&output[output_idx * BlockSize + 12 ]);
241-
242- float32x4_t Relu0 = MlasMaximumFloat32x4 (Accumulator0, ZeroVector);
243- float32x4_t Relu1 = MlasMaximumFloat32x4 (Accumulator1, ZeroVector);
244- float32x4_t Relu2 = MlasMaximumFloat32x4 (Accumulator2, ZeroVector);
245- float32x4_t Relu3 = MlasMaximumFloat32x4 (Accumulator3, ZeroVector);
246-
247- Accumulator0 = MlasBlendFloat32x4 (Accumulator0, Relu0, ReluMask);
248- Accumulator1 = MlasBlendFloat32x4 (Accumulator1, Relu1, ReluMask);
249- Accumulator2 = MlasBlendFloat32x4 (Accumulator2, Relu2, ReluMask);
250- Accumulator3 = MlasBlendFloat32x4 (Accumulator3, Relu3, ReluMask);
251-
252- MlasStoreFloat32x4 (&output[output_idx * BlockSize], Accumulator0);
253- MlasStoreFloat32x4 (&output[output_idx * BlockSize + 4 ], Accumulator1);
254- MlasStoreFloat32x4 (&output[output_idx * BlockSize + 8 ], Accumulator2);
255- MlasStoreFloat32x4 (&output[output_idx * BlockSize + 12 ], Accumulator3);
255+ MlasSBGemmBatch (OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr );
256+
257+ if (ReluActivation) {
258+ const float32x4_t ZeroVector = MlasBroadcastFloat32x4 (0 .0f );
259+ for (size_t f = 0 ; f < FilterCount; f++) {
260+ float * output = Output + f * OutputStrideElements;
261+ for (size_t i = 0 ; i < OutputCount; i++) {
262+ MlasStoreFloat32x4 (&output[i * BlockSize], MlasMaximumFloat32x4 (MlasLoadFloat32x4 (&output[i * BlockSize]), ZeroVector));
263+ MlasStoreFloat32x4 (&output[i * BlockSize + 4 ], MlasMaximumFloat32x4 (MlasLoadFloat32x4 (&output[i * BlockSize + 4 ]), ZeroVector));
264+ MlasStoreFloat32x4 (&output[i * BlockSize + 8 ], MlasMaximumFloat32x4 (MlasLoadFloat32x4 (&output[i * BlockSize + 8 ]), ZeroVector));
265+ MlasStoreFloat32x4 (&output[i * BlockSize + 12 ], MlasMaximumFloat32x4 (MlasLoadFloat32x4 (&output[i * BlockSize + 12 ]), ZeroVector));
266+ }
256267 }
257268 }
258269}
0 commit comments