@@ -24,52 +24,6 @@ Module Name:

 constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;

-inline void MLASCALL
-MlasRowDot(const float* Arow, const float* Brow, float* out, int index, size_t len)
-{
-    float32x4_t acc4 = MlasBroadcastFloat32x4(0.0f);
-
-    size_t i = 0;
-    for (; i + 8 <= len; i += 8) {
-        float32x4_t a0 = MlasLoadFloat32x4(Arow + i);
-        float32x4_t a1 = MlasLoadFloat32x4(Arow + i + 4);
-        float32x4_t b0 = MlasLoadFloat32x4(Brow + i);
-        float32x4_t b1 = MlasLoadFloat32x4(Brow + i + 4);
-
-        bfloat16x8_t a_bf16 = vcvtq_low_bf16_f32(a0);
-        a_bf16 = vcvtq_high_bf16_f32(a_bf16, a1);
-
-        bfloat16x8_t b_bf16 = vcvtq_low_bf16_f32(b0);
-        b_bf16 = vcvtq_high_bf16_f32(b_bf16, b1);
-
-        acc4 = vbfdotq_f32(acc4, a_bf16, b_bf16);
-    }
-
-    float sum = vaddvq_f32(acc4);
-
-    for (; i < len; i++)
-        sum += Arow[i] * Brow[i];
-
-    out[index] = sum;
-}
-
-inline void MLASCALL
-MlasSBDotRowWise(const float* A, const float* B, size_t len, float* out)
-{
-    float tmpA[len];
-    float tmpB[len];
-
-    for (size_t r = 0; r < 16; r++) {
-        for (size_t j = 0; j < len; j++)
-            tmpA[j] = A[j * 16 + r];
-
-        for (size_t j = 0; j < len; j++)
-            tmpB[j] = B[j * 16 + r];
-
-        MlasRowDot(tmpA, tmpB, out, r, len);
-    }
-}
-
 void
 MLASCALL
 MlasConvDepthwiseBf16KernelNeon(
@@ -91,97 +45,251 @@ void
     unsigned KernelFlags
     )
 {
-    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
-    const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
-    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)));
-    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
+    const bool AccumulateOutput = KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT;
+    const bool BiasAddition = KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION;
+    const bool ReluActivation = KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION;

     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
     const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+    const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+    const size_t KernelSize = KernelHeight * KernelWidth;

     MLAS_UNREFERENCED_PARAMETER(InputStride);

+    // Depthwise: 16 independent channels, each doing [TotalOutputCount][KernelSize] x [KernelSize][1]
+    // Batch all 16 channels into one MlasSBGemmBatch call
+
+    std::vector<float> im2col_buffer(BlockSize * TotalOutputCount * KernelSize);
+    std::vector<float> filter_cols(BlockSize * KernelSize);
+    std::vector<float> output_buffer(BlockSize * TotalOutputCount);
+
+    // Prepare filter columns: transpose [KernelSize][16] -> 16 separate [KernelSize] vectors
+    for (size_t c = 0; c < BlockSize; c++) {
+        for (size_t k = 0; k < KernelSize; k++) {
+            filter_cols[c * KernelSize + k] = Filter[k * BlockSize + c];
+        }
+    }
+
+    // im2col for all channels: [c][out_idx][kpos]
+    for (size_t c = 0; c < BlockSize; c++) {
+        for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+            for (size_t kpos = 0; kpos < KernelSize; kpos++) {
+                size_t kh = kpos / KernelWidth;
+                size_t kw = kpos % KernelWidth;
+                const float* input_ptr = Input + out_idx * StrideWidthElements +
+                    kh * DilatedInputWidthElements + kw * DilationWidthElements + c;
+                const float* row_start = InputBase + kh * DilatedInputWidthElements;
+                const float* row_end = row_start + InputWidthElements;
+                im2col_buffer[c * TotalOutputCount * KernelSize + out_idx * KernelSize + kpos] =
+                    (input_ptr >= row_start && input_ptr < row_end) ? *input_ptr : 0.0f;
+            }
+        }
+    }
+
+    // Batched SBGEMM: 16 independent GEMMs, each M=TotalOutputCount, N=1, K=KernelSize
+    MLAS_SBGEMM_DATA_PARAMS params[16];
+    for (size_t c = 0; c < BlockSize; c++) {
+        params[c].A = &im2col_buffer[c * TotalOutputCount * KernelSize];
+        params[c].B = &filter_cols[c * KernelSize];
+        params[c].C = &output_buffer[c * TotalOutputCount];
+        params[c].lda = KernelSize;
+        params[c].ldb = 1;
+        params[c].ldc = 1;
+        params[c].Bias = nullptr;
+        params[c].AIsfp32 = true;
+        params[c].BIsfp32 = true;
+        params[c].ZeroMode = true;
+        params[c].OutputProcessor = nullptr;
+    }
+    MlasSBGemmBatch(TotalOutputCount, 1, KernelSize, BlockSize, params, nullptr);
+
+    // Scatter results back to output and apply post-processing
+    for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+        float* output_ptr = &Output[out_idx * BlockSize];
+        for (size_t c = 0; c < BlockSize; c++) {
+            float val = output_buffer[c * TotalOutputCount + out_idx];
+            if (AccumulateOutput) val += output_ptr[c];
+            if (BiasAddition) val += Bias[c];
+            if (ReluActivation && val < 0) val = 0;
+            output_ptr[c] = val;
+        }
+    }
+}
+
+//
+// BF16 NCHW/NCHWc Convolution Kernel using im2col + SBGEMM.
+// NCHW: 1 input channel per kernel position, single GEMM with K=KernelSize
+// NCHWc: BlockSize input channels per kernel position, loop over kpos with K=BlockSize
+//
+template <bool IsNchwcFormat>
+void MLASCALL
+MlasConvBf16KernelNeonImpl(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t DilationWidth,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t KernelHeight,
+    size_t KernelWidth,
+    const float* InputBase,
+    size_t InputWidth,
+    size_t DilatedInputWidth,
+    size_t OutputCountLeftPad,
+    size_t OutputCount,
+    size_t OutputCountRightPad,
+    const float* Bias,
+    unsigned KernelFlags
+    )
+{
+    const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
+    const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
+    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
+
+    const size_t StrideWidthElements = StrideWidth / sizeof(float);
+    const size_t DilationWidthElements = DilationWidth / sizeof(float);
+    const size_t FilterStrideElements = FilterStride / sizeof(float);
+    const size_t OutputStrideElements = OutputStride / sizeof(float);
     const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+
+    MLAS_UNREFERENCED_PARAMETER(InputStride);

     const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
-
-    const size_t MaxKernelPositions = KernelHeight * KernelWidth;
-    float tmpInput[MaxKernelPositions * BlockSize];
-
-    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
-        float* output_ptr = &Output[output_idx * BlockSize];
-
-        float32x4_t OldOutput0 = MlasLoadFloat32x4(output_ptr);
-        float32x4_t OldOutput1 = MlasLoadFloat32x4(output_ptr + 4);
-        float32x4_t OldOutput2 = MlasLoadFloat32x4(output_ptr + 8);
-        float32x4_t OldOutput3 = MlasLoadFloat32x4(output_ptr + 12);
-
-        for (size_t kernel_pos = 0; kernel_pos < MaxKernelPositions; kernel_pos++) {
-            size_t kh = kernel_pos / KernelWidth;
-            size_t kw = kernel_pos % KernelWidth;
-            const float* input_base = Input + output_idx * StrideWidthElements +
+    const size_t KernelSize = KernelHeight * KernelWidth;
+
+    std::vector<float> im2col_buffer(TotalOutputCount * (IsNchwcFormat ? BlockSize : KernelSize));
+
+    if (BiasAddition && AccumulateOutput) {
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+            const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]);
+            const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]);
+            const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]);
+            for (size_t i = 0; i < TotalOutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12])));
+            }
+        }
+    }
+
+    MLAS_SBGEMM_DATA_PARAMS gemm_params[16];
+    const size_t K = IsNchwcFormat ? BlockSize : KernelSize;
+
+    // Helper lambda for im2col extraction at a kernel position
+    auto extractIm2Col = [&](size_t kpos, float* col_base, size_t col_stride) {
+        size_t kh = kpos / KernelWidth;
+        size_t kw = kpos % KernelWidth;
+        const float* row_start = InputBase + kh * DilatedInputWidthElements;
+        const float* row_end = row_start + InputWidthElements;
+
+        for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+            const float* input_base = Input + out_idx * StrideWidthElements +
                 kh * DilatedInputWidthElements + kw * DilationWidthElements;
-            const float* row_start = InputBase + kh * DilatedInputWidthElements;
-            const float* row_end = row_start + InputWidthElements;
-
-            bool valid = (input_base >= row_start) && (input_base + 15 < row_end);
-            const float* safe_ptr = input_base;
-            safe_ptr = (input_base < row_start) ? row_start : safe_ptr;
-            safe_ptr = (input_base + 15 >= row_end) ? row_start : safe_ptr;
-
-            float32x4_t validMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(valid ? -1 : 0));
-
-            float32x4_t loaded0 = MlasLoadFloat32x4(safe_ptr);
-            float32x4_t loaded1 = MlasLoadFloat32x4(safe_ptr + 4);
-            float32x4_t loaded2 = MlasLoadFloat32x4(safe_ptr + 8);
-            float32x4_t loaded3 = MlasLoadFloat32x4(safe_ptr + 12);
-
-            float32x4_t data0 = MlasBlendFloat32x4(ZeroVector, loaded0, validMask);
-            float32x4_t data1 = MlasBlendFloat32x4(ZeroVector, loaded1, validMask);
-            float32x4_t data2 = MlasBlendFloat32x4(ZeroVector, loaded2, validMask);
-            float32x4_t data3 = MlasBlendFloat32x4(ZeroVector, loaded3, validMask);
-
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize], data0);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 4], data1);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 8], data2);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 12], data3);
+            float* col_ptr = col_base + out_idx * col_stride;
+
+            if constexpr (IsNchwcFormat) {
+                for (size_t ic = 0; ic < BlockSize; ic++) {
+                    const float* ie = input_base + ic;
+                    col_ptr[ic] = (ie >= row_start && ie < row_end) ? *ie : 0.0f;
+                }
+            } else {
+                col_ptr[kpos] = (input_base >= row_start && input_base < row_end) ? *input_base : 0.0f;
+            }
         }
+    };

-        MlasSBDotRowWise(tmpInput, Filter, MaxKernelPositions, output_ptr);
-
-        float32x4_t Accumulator0 = MlasLoadFloat32x4(output_ptr);
-        float32x4_t Accumulator1 = MlasLoadFloat32x4(output_ptr + 4);
-        float32x4_t Accumulator2 = MlasLoadFloat32x4(output_ptr + 8);
-        float32x4_t Accumulator3 = MlasLoadFloat32x4(output_ptr + 12);
-
-        Accumulator0 = MlasAddFloat32x4(Accumulator0, MlasAndFloat32x4(OldOutput0, AccumulateMask));
-        Accumulator1 = MlasAddFloat32x4(Accumulator1, MlasAndFloat32x4(OldOutput1, AccumulateMask));
-        Accumulator2 = MlasAddFloat32x4(Accumulator2, MlasAndFloat32x4(OldOutput2, AccumulateMask));
-        Accumulator3 = MlasAddFloat32x4(Accumulator3, MlasAndFloat32x4(OldOutput3, AccumulateMask));
-
-        Accumulator0 = MlasAddFloat32x4(Accumulator0, MlasAndFloat32x4(MlasLoadFloat32x4(Bias), BiasMask));
-        Accumulator1 = MlasAddFloat32x4(Accumulator1, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 4), BiasMask));
-        Accumulator2 = MlasAddFloat32x4(Accumulator2, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 8), BiasMask));
-        Accumulator3 = MlasAddFloat32x4(Accumulator3, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 12), BiasMask));
-
-        float32x4_t Relu0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector);
-        float32x4_t Relu1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector);
-        float32x4_t Relu2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector);
-        float32x4_t Relu3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector);
-
-        Accumulator0 = MlasBlendFloat32x4(Accumulator0, Relu0, ReluMask);
-        Accumulator1 = MlasBlendFloat32x4(Accumulator1, Relu1, ReluMask);
-        Accumulator2 = MlasBlendFloat32x4(Accumulator2, Relu2, ReluMask);
-        Accumulator3 = MlasBlendFloat32x4(Accumulator3, Relu3, ReluMask);
-
-        MlasStoreFloat32x4(output_ptr, Accumulator0);
-        MlasStoreFloat32x4(output_ptr + 4, Accumulator1);
-        MlasStoreFloat32x4(output_ptr + 8, Accumulator2);
-        MlasStoreFloat32x4(output_ptr + 12, Accumulator3);
+    // Helper lambda to setup GEMM params
+    auto setupGemmParams = [&](size_t filter_offset, bool zeroMode) {
+        size_t idx = 0;
+        for (size_t f = 0; f < FilterCount; f++) {
+            gemm_params[idx].A = im2col_buffer.data();
+            gemm_params[idx].B = Filter + f * FilterStrideElements + filter_offset;
+            gemm_params[idx].C = Output + f * OutputStrideElements;
+            gemm_params[idx].lda = K;
+            gemm_params[idx].ldb = BlockSize;
+            gemm_params[idx].ldc = BlockSize;
+            gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
+            gemm_params[idx].AIsfp32 = true;
+            gemm_params[idx].BIsfp32 = true;
+            gemm_params[idx].ZeroMode = zeroMode;
+            gemm_params[idx].OutputProcessor = nullptr;
+            idx++;
+        }
+        return idx;
+    };
+
+    const size_t numGemmCalls = IsNchwcFormat ? KernelSize : 1;
+    for (size_t g = 0; g < numGemmCalls; g++) {
+        if constexpr (IsNchwcFormat) {
+            extractIm2Col(g, im2col_buffer.data(), BlockSize);
+        } else {
+            for (size_t kpos = 0; kpos < KernelSize; kpos++) {
+                extractIm2Col(kpos, im2col_buffer.data(), KernelSize);
+            }
+        }
+        size_t kh = g / KernelWidth, kw = g % KernelWidth;
+        size_t filter_offset = IsNchwcFormat ? kh * (KernelWidth * BlockSize * BlockSize) + kw * (BlockSize * BlockSize) : 0;
+        size_t idx = setupGemmParams(filter_offset, (g == 0) && !AccumulateOutput);
+        MlasSBGemmBatch(TotalOutputCount, BlockSize, K, idx, gemm_params, nullptr);
+    }
+
+    if (ReluActivation) {
+        const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            for (size_t i = 0; i < TotalOutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector));
+            }
+        }
     }
 }

+void MLASCALL MlasConvNchwcBf16KernelNeon(
+    const float* Input, const float* Filter, float* Output,
+    size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
+    size_t InputStride, size_t FilterStride, size_t OutputStride,
+    size_t KernelHeight, size_t KernelWidth, const float* InputBase,
+    size_t InputWidth, size_t DilatedInputWidth,
+    size_t OutputCountLeftPad, size_t OutputCount, size_t OutputCountRightPad,
+    const float* Bias, unsigned KernelFlags)
+{
+    MlasConvBf16KernelNeonImpl<true>(Input, Filter, Output, StrideWidth, DilationWidth,
+        FilterCount, InputStride, FilterStride, OutputStride, KernelHeight, KernelWidth,
+        InputBase, InputWidth, DilatedInputWidth, OutputCountLeftPad, OutputCount,
+        OutputCountRightPad, Bias, KernelFlags);
+}
+
+void MLASCALL MlasConvNchwBf16KernelNeon(
+    const float* Input, const float* Filter, float* Output,
+    size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
+    size_t InputStride, size_t FilterStride, size_t OutputStride,
+    size_t KernelHeight, size_t KernelWidth, const float* InputBase,
+    size_t InputWidth, size_t DilatedInputWidth,
+    size_t OutputCountLeftPad, size_t OutputCount, size_t OutputCountRightPad,
+    const float* Bias, unsigned KernelFlags)
+{
+    MlasConvBf16KernelNeonImpl<false>(Input, Filter, Output, StrideWidth, DilationWidth,
+        FilterCount, InputStride, FilterStride, OutputStride, KernelHeight, KernelWidth,
+        InputBase, InputWidth, DilatedInputWidth, OutputCountLeftPad, OutputCount,
+        OutputCountRightPad, Bias, KernelFlags);
+}
+
 //
 // BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
 //
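
For reference, the depthwise path above reduces each of the BlockSize channels to a tiny GEMM: A is the channel's [TotalOutputCount x KernelSize] im2col matrix, B is its [KernelSize x 1] filter column, and C is the [TotalOutputCount x 1] output column, with all BlockSize GEMMs handed to MlasSBGemmBatch in one call. The sketch below is illustrative only and not part of the diff: a plain fp32 reference of that per-channel product (the function name is hypothetical, and it ignores the bf16 rounding the SBGEMM path applies).

#include <cstddef>
#include <vector>

// Hypothetical reference helper, not part of the MLAS change above: computes one
// channel's C = A * B, where A is the row-major [TotalOutputCount x KernelSize]
// im2col matrix and B is [KernelSize x 1], the per-channel shape passed to
// MlasSBGemmBatch by the depthwise kernel.
static void ReferenceDepthwiseChannel(
    const std::vector<float>& im2col,   // TotalOutputCount * KernelSize elements
    const std::vector<float>& filter,   // KernelSize elements
    std::vector<float>& out,            // TotalOutputCount elements
    size_t TotalOutputCount,
    size_t KernelSize)
{
    for (size_t m = 0; m < TotalOutputCount; m++) {
        float acc = 0.0f;
        for (size_t k = 0; k < KernelSize; k++) {
            acc += im2col[m * KernelSize + k] * filter[k];
        }
        out[m] = acc;   // overwrite rather than accumulate, matching ZeroMode = true above
    }
}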