
Commit b63aac1

Hacks for NCHWC Conv
1 parent 8eb4f5a commit b63aac1

3 files changed: +233 additions, -121 deletions

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 2 additions & 0 deletions

@@ -964,7 +964,9 @@ extern "C" {
 #endif
 #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
 MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon;
+MLAS_CONV_FLOAT_KERNEL MlasConvNchwBf16KernelNeon;
 MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
+MLAS_CONV_FLOAT_KERNEL MlasConvNchwcBf16KernelNeon;
 MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
 MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseBf16KernelNeon;
 MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
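
The new BF16 kernels are declared with the same MLAS_CONV_FLOAT_KERNEL prototype as the existing float kernels, which is what lets platform.cpp below drop them into the same dispatch slots without any interface change. As a minimal sketch, assuming the typedef takes the usual MLAS function-type form (the typedef itself is not part of this diff; the parameter list is inferred from the MlasConvNchwcBf16KernelNeon definition in sbconv_kernel_neon.cpp):

// Sketch of the assumed prototype; the real typedef lives elsewhere in mlasi.h.
typedef void (MLASCALL MLAS_CONV_FLOAT_KERNEL)(
    const float* Input, const float* Filter, float* Output,
    size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
    size_t InputStride, size_t FilterStride, size_t OutputStride,
    size_t KernelHeight, size_t KernelWidth, const float* InputBase,
    size_t InputWidth, size_t DilatedInputWidth,
    size_t OutputCountLeftPad, size_t OutputCount, size_t OutputCountRightPad,
    const float* Bias, unsigned KernelFlags);

// A declaration such as the ones added above then pins the new kernel to
// exactly that signature:
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcBf16KernelNeon;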

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 2 additions & 0 deletions

@@ -567,7 +567,9 @@ Return Value:

 #if defined(MLAS_USE_ARM_NEON_NCHWC)
 this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
+this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
 this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
+this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
 this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
 // this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;
 this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
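
Note the effect of the two added lines: each overwrites the assignment just above it, so on any build compiled with MLAS_USE_ARM_NEON_NCHWC the NCHW and NCHWc slots always end up pointing at the BF16 kernels, while the depthwise BF16 override is left commented out and depthwise keeps the float kernel. That unconditional override is the "hack" of the commit title. A minimal sketch of what a guarded version could look like, where HasArmBf16Support() is a purely hypothetical capability check and not part of this commit:

// Hypothetical sketch only: gate the BF16 kernel on a capability check
// instead of overwriting the float kernel unconditionally.
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
if (HasArmBf16Support()) {  // hypothetical helper, not in this diff
    this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
}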

onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp

Lines changed: 229 additions & 121 deletions

@@ -24,52 +24,6 @@ Module Name:

 constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;

-inline void MLASCALL
-MlasRowDot(const float* Arow, const float* Brow, float* out, int index, size_t len)
-{
-    float32x4_t acc4 = MlasBroadcastFloat32x4(0.f);
-
-    size_t i = 0;
-    for (; i + 8 <= len; i += 8) {
-        float32x4_t a0 = MlasLoadFloat32x4(Arow + i);
-        float32x4_t a1 = MlasLoadFloat32x4(Arow + i + 4);
-        float32x4_t b0 = MlasLoadFloat32x4(Brow + i);
-        float32x4_t b1 = MlasLoadFloat32x4(Brow + i + 4);
-
-        bfloat16x8_t a_bf16 = vcvtq_low_bf16_f32(a0);
-        a_bf16 = vcvtq_high_bf16_f32(a_bf16, a1);
-
-        bfloat16x8_t b_bf16 = vcvtq_low_bf16_f32(b0);
-        b_bf16 = vcvtq_high_bf16_f32(b_bf16, b1);
-
-        acc4 = vbfdotq_f32(acc4, a_bf16, b_bf16);
-    }
-
-    float sum = vaddvq_f32(acc4);
-
-    for (; i < len; i++)
-        sum += Arow[i] * Brow[i];
-
-    out[index] = sum;
-}
-
-inline void MLASCALL
-MlasSBDotRowWise(const float* A, const float* B, size_t len, float* out)
-{
-    float tmpA[len];
-    float tmpB[len];
-
-    for (size_t r = 0; r < 16; r++) {
-        for (size_t j = 0; j < len; j++)
-            tmpA[j] = A[j * 16 + r];
-
-        for (size_t j = 0; j < len; j++)
-            tmpB[j] = B[j * 16 + r];
-
-        MlasRowDot(tmpA, tmpB, out, r, len);
-    }
-}
-
 void
 MLASCALL
 MlasConvDepthwiseBf16KernelNeon(
@@ -91,97 +45,251 @@ void
     unsigned KernelFlags
 )
 {
-    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
-    const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
-    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)));
-    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
+    const bool AccumulateOutput = KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT;
+    const bool BiasAddition = KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION;
+    const bool ReluActivation = KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION;

     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
     const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+    const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+    const size_t KernelSize = KernelHeight * KernelWidth;

     MLAS_UNREFERENCED_PARAMETER(InputStride);

+    // Depthwise: 16 independent channels, each doing [TotalOutputCount][KernelSize] x [KernelSize][1]
+    // Batch all 16 channels into one MlasSBGemmBatch call
+
+    std::vector<float> im2col_buffer(BlockSize * TotalOutputCount * KernelSize);
+    std::vector<float> filter_cols(BlockSize * KernelSize);
+    std::vector<float> output_buffer(BlockSize * TotalOutputCount);
+
+    // Prepare filter columns: transpose [KernelSize][16] -> 16 separate [KernelSize] vectors
+    for (size_t c = 0; c < BlockSize; c++) {
+        for (size_t k = 0; k < KernelSize; k++) {
+            filter_cols[c * KernelSize + k] = Filter[k * BlockSize + c];
+        }
+    }
+
+    // im2col for all channels: [c][out_idx][kpos]
+    for (size_t c = 0; c < BlockSize; c++) {
+        for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+            for (size_t kpos = 0; kpos < KernelSize; kpos++) {
+                size_t kh = kpos / KernelWidth;
+                size_t kw = kpos % KernelWidth;
+                const float* input_ptr = Input + out_idx * StrideWidthElements +
+                    kh * DilatedInputWidthElements + kw * DilationWidthElements + c;
+                const float* row_start = InputBase + kh * DilatedInputWidthElements;
+                const float* row_end = row_start + InputWidthElements;
+                im2col_buffer[c * TotalOutputCount * KernelSize + out_idx * KernelSize + kpos] =
+                    (input_ptr >= row_start && input_ptr < row_end) ? *input_ptr : 0.0f;
+            }
+        }
+    }
+
+    // Batched SBGEMM: 16 independent GEMMs, each M=TotalOutputCount, N=1, K=KernelSize
+    MLAS_SBGEMM_DATA_PARAMS params[16];
+    for (size_t c = 0; c < BlockSize; c++) {
+        params[c].A = &im2col_buffer[c * TotalOutputCount * KernelSize];
+        params[c].B = &filter_cols[c * KernelSize];
+        params[c].C = &output_buffer[c * TotalOutputCount];
+        params[c].lda = KernelSize;
+        params[c].ldb = 1;
+        params[c].ldc = 1;
+        params[c].Bias = nullptr;
+        params[c].AIsfp32 = true;
+        params[c].BIsfp32 = true;
+        params[c].ZeroMode = true;
+        params[c].OutputProcessor = nullptr;
+    }
+    MlasSBGemmBatch(TotalOutputCount, 1, KernelSize, BlockSize, params, nullptr);
+
+    // Scatter results back to output and apply post-processing
+    for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+        float* output_ptr = &Output[out_idx * BlockSize];
+        for (size_t c = 0; c < BlockSize; c++) {
+            float val = output_buffer[c * TotalOutputCount + out_idx];
+            if (AccumulateOutput) val += output_ptr[c];
+            if (BiasAddition) val += Bias[c];
+            if (ReluActivation && val < 0) val = 0;
+            output_ptr[c] = val;
+        }
+    }
+}
+
+//
+// BF16 NCHW/NCHWc Convolution Kernel using im2col + SBGEMM.
+// NCHW: 1 input channel per kernel position, single GEMM with K=KernelSize
+// NCHWc: BlockSize input channels per kernel position, loop over kpos with K=BlockSize
+//
+template <bool IsNchwcFormat>
+void MLASCALL
+MlasConvBf16KernelNeonImpl(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t DilationWidth,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t KernelHeight,
+    size_t KernelWidth,
+    const float* InputBase,
+    size_t InputWidth,
+    size_t DilatedInputWidth,
+    size_t OutputCountLeftPad,
+    size_t OutputCount,
+    size_t OutputCountRightPad,
+    const float* Bias,
+    unsigned KernelFlags
+)
+{
+    const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
+    const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
+    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
+
+    const size_t StrideWidthElements = StrideWidth / sizeof(float);
+    const size_t DilationWidthElements = DilationWidth / sizeof(float);
+    const size_t FilterStrideElements = FilterStride / sizeof(float);
+    const size_t OutputStrideElements = OutputStride / sizeof(float);
     const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+
+    MLAS_UNREFERENCED_PARAMETER(InputStride);

     const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
-
-    const size_t MaxKernelPositions = KernelHeight * KernelWidth;
-    float tmpInput[MaxKernelPositions * BlockSize];
-
-    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
-        float* output_ptr = &Output[output_idx * BlockSize];
-
-        float32x4_t OldOutput0 = MlasLoadFloat32x4(output_ptr);
-        float32x4_t OldOutput1 = MlasLoadFloat32x4(output_ptr + 4);
-        float32x4_t OldOutput2 = MlasLoadFloat32x4(output_ptr + 8);
-        float32x4_t OldOutput3 = MlasLoadFloat32x4(output_ptr + 12);
-
-        for (size_t kernel_pos = 0; kernel_pos < MaxKernelPositions; kernel_pos++) {
-            size_t kh = kernel_pos / KernelWidth;
-            size_t kw = kernel_pos % KernelWidth;
-            const float* input_base = Input + output_idx * StrideWidthElements +
+    const size_t KernelSize = KernelHeight * KernelWidth;
+
+    std::vector<float> im2col_buffer(TotalOutputCount * (IsNchwcFormat ? BlockSize : KernelSize));
+
+    if (BiasAddition && AccumulateOutput) {
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+            const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]);
+            const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]);
+            const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]);
+            for (size_t i = 0; i < TotalOutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8])));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12])));
+            }
+        }
+    }
+
+    MLAS_SBGEMM_DATA_PARAMS gemm_params[16];
+    const size_t K = IsNchwcFormat ? BlockSize : KernelSize;
+
+    // Helper lambda for im2col extraction at a kernel position
+    auto extractIm2Col = [&](size_t kpos, float* col_base, size_t col_stride) {
+        size_t kh = kpos / KernelWidth;
+        size_t kw = kpos % KernelWidth;
+        const float* row_start = InputBase + kh * DilatedInputWidthElements;
+        const float* row_end = row_start + InputWidthElements;
+
+        for (size_t out_idx = 0; out_idx < TotalOutputCount; out_idx++) {
+            const float* input_base = Input + out_idx * StrideWidthElements +
                 kh * DilatedInputWidthElements + kw * DilationWidthElements;
-            const float* row_start = InputBase + kh * DilatedInputWidthElements;
-            const float* row_end = row_start + InputWidthElements;
-
-            bool valid = (input_base >= row_start) && (input_base + 15 < row_end);
-            const float* safe_ptr = input_base;
-            safe_ptr = (input_base < row_start) ? row_start : safe_ptr;
-            safe_ptr = (input_base + 15 >= row_end) ? row_start : safe_ptr;
-
-            float32x4_t validMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(valid ? -1 : 0));
-
-            float32x4_t loaded0 = MlasLoadFloat32x4(safe_ptr);
-            float32x4_t loaded1 = MlasLoadFloat32x4(safe_ptr + 4);
-            float32x4_t loaded2 = MlasLoadFloat32x4(safe_ptr + 8);
-            float32x4_t loaded3 = MlasLoadFloat32x4(safe_ptr + 12);
-
-            float32x4_t data0 = MlasBlendFloat32x4(ZeroVector, loaded0, validMask);
-            float32x4_t data1 = MlasBlendFloat32x4(ZeroVector, loaded1, validMask);
-            float32x4_t data2 = MlasBlendFloat32x4(ZeroVector, loaded2, validMask);
-            float32x4_t data3 = MlasBlendFloat32x4(ZeroVector, loaded3, validMask);
-
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize], data0);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 4], data1);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 8], data2);
-            MlasStoreFloat32x4(&tmpInput[kernel_pos * BlockSize + 12], data3);
+            float* col_ptr = col_base + out_idx * col_stride;
+
+            if constexpr (IsNchwcFormat) {
+                for (size_t ic = 0; ic < BlockSize; ic++) {
+                    const float* ie = input_base + ic;
+                    col_ptr[ic] = (ie >= row_start && ie < row_end) ? *ie : 0.0f;
+                }
+            } else {
+                col_ptr[kpos] = (input_base >= row_start && input_base < row_end) ? *input_base : 0.0f;
+            }
         }
+    };

-        MlasSBDotRowWise(tmpInput, Filter, MaxKernelPositions, output_ptr);
-
-        float32x4_t Accumulator0 = MlasLoadFloat32x4(output_ptr);
-        float32x4_t Accumulator1 = MlasLoadFloat32x4(output_ptr + 4);
-        float32x4_t Accumulator2 = MlasLoadFloat32x4(output_ptr + 8);
-        float32x4_t Accumulator3 = MlasLoadFloat32x4(output_ptr + 12);
-
-        Accumulator0 = MlasAddFloat32x4(Accumulator0, MlasAndFloat32x4(OldOutput0, AccumulateMask));
-        Accumulator1 = MlasAddFloat32x4(Accumulator1, MlasAndFloat32x4(OldOutput1, AccumulateMask));
-        Accumulator2 = MlasAddFloat32x4(Accumulator2, MlasAndFloat32x4(OldOutput2, AccumulateMask));
-        Accumulator3 = MlasAddFloat32x4(Accumulator3, MlasAndFloat32x4(OldOutput3, AccumulateMask));
-
-        Accumulator0 = MlasAddFloat32x4(Accumulator0, MlasAndFloat32x4(MlasLoadFloat32x4(Bias), BiasMask));
-        Accumulator1 = MlasAddFloat32x4(Accumulator1, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 4), BiasMask));
-        Accumulator2 = MlasAddFloat32x4(Accumulator2, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 8), BiasMask));
-        Accumulator3 = MlasAddFloat32x4(Accumulator3, MlasAndFloat32x4(MlasLoadFloat32x4(Bias + 12), BiasMask));
-
-        float32x4_t Relu0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector);
-        float32x4_t Relu1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector);
-        float32x4_t Relu2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector);
-        float32x4_t Relu3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector);
-
-        Accumulator0 = MlasBlendFloat32x4(Accumulator0, Relu0, ReluMask);
-        Accumulator1 = MlasBlendFloat32x4(Accumulator1, Relu1, ReluMask);
-        Accumulator2 = MlasBlendFloat32x4(Accumulator2, Relu2, ReluMask);
-        Accumulator3 = MlasBlendFloat32x4(Accumulator3, Relu3, ReluMask);
-
-        MlasStoreFloat32x4(output_ptr, Accumulator0);
-        MlasStoreFloat32x4(output_ptr + 4, Accumulator1);
-        MlasStoreFloat32x4(output_ptr + 8, Accumulator2);
-        MlasStoreFloat32x4(output_ptr + 12, Accumulator3);
+    // Helper lambda to set up GEMM params
+    auto setupGemmParams = [&](size_t filter_offset, bool zeroMode) {
+        size_t idx = 0;
+        for (size_t f = 0; f < FilterCount; f++) {
+            gemm_params[idx].A = im2col_buffer.data();
+            gemm_params[idx].B = Filter + f * FilterStrideElements + filter_offset;
+            gemm_params[idx].C = Output + f * OutputStrideElements;
+            gemm_params[idx].lda = K;
+            gemm_params[idx].ldb = BlockSize;
+            gemm_params[idx].ldc = BlockSize;
+            gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
+            gemm_params[idx].AIsfp32 = true;
+            gemm_params[idx].BIsfp32 = true;
+            gemm_params[idx].ZeroMode = zeroMode;
+            gemm_params[idx].OutputProcessor = nullptr;
+            idx++;
+        }
+        return idx;
+    };
+
+    const size_t numGemmCalls = IsNchwcFormat ? KernelSize : 1;
+    for (size_t g = 0; g < numGemmCalls; g++) {
+        if constexpr (IsNchwcFormat) {
+            extractIm2Col(g, im2col_buffer.data(), BlockSize);
+        } else {
+            for (size_t kpos = 0; kpos < KernelSize; kpos++) {
+                extractIm2Col(kpos, im2col_buffer.data(), KernelSize);
+            }
+        }
+        size_t kh = g / KernelWidth, kw = g % KernelWidth;
+        size_t filter_offset = IsNchwcFormat ? kh * (KernelWidth * BlockSize * BlockSize) + kw * (BlockSize * BlockSize) : 0;
+        size_t idx = setupGemmParams(filter_offset, (g == 0) && !AccumulateOutput);
+        MlasSBGemmBatch(TotalOutputCount, BlockSize, K, idx, gemm_params, nullptr);
+    }
+
+    if (ReluActivation) {
+        const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
+        for (size_t f = 0; f < FilterCount; f++) {
+            float* output = Output + f * OutputStrideElements;
+            for (size_t i = 0; i < TotalOutputCount; i++) {
+                MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector));
+                MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector));
+            }
+        }
     }
 }

+void MLASCALL MlasConvNchwcBf16KernelNeon(
+    const float* Input, const float* Filter, float* Output,
+    size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
+    size_t InputStride, size_t FilterStride, size_t OutputStride,
+    size_t KernelHeight, size_t KernelWidth, const float* InputBase,
+    size_t InputWidth, size_t DilatedInputWidth,
+    size_t OutputCountLeftPad, size_t OutputCount, size_t OutputCountRightPad,
+    const float* Bias, unsigned KernelFlags)
+{
+    MlasConvBf16KernelNeonImpl<true>(Input, Filter, Output, StrideWidth, DilationWidth,
+        FilterCount, InputStride, FilterStride, OutputStride, KernelHeight, KernelWidth,
+        InputBase, InputWidth, DilatedInputWidth, OutputCountLeftPad, OutputCount,
+        OutputCountRightPad, Bias, KernelFlags);
+}
+
+void MLASCALL MlasConvNchwBf16KernelNeon(
+    const float* Input, const float* Filter, float* Output,
+    size_t StrideWidth, size_t DilationWidth, size_t FilterCount,
+    size_t InputStride, size_t FilterStride, size_t OutputStride,
+    size_t KernelHeight, size_t KernelWidth, const float* InputBase,
+    size_t InputWidth, size_t DilatedInputWidth,
+    size_t OutputCountLeftPad, size_t OutputCount, size_t OutputCountRightPad,
+    const float* Bias, unsigned KernelFlags)
+{
+    MlasConvBf16KernelNeonImpl<false>(Input, Filter, Output, StrideWidth, DilationWidth,
+        FilterCount, InputStride, FilterStride, OutputStride, KernelHeight, KernelWidth,
+        InputBase, InputWidth, DilatedInputWidth, OutputCountLeftPad, OutputCount,
+        OutputCountRightPad, Bias, KernelFlags);
+}
+
 //
 // BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
 //
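
For reference, the reshaping the new MlasConvDepthwiseBf16KernelNeon performs per channel is the standard im2col trick: gather the KernelSize input taps of every output position into one row, and the whole channel's convolution collapses into a single [TotalOutputCount x KernelSize] by [KernelSize x 1] product, which is the M, N, K shape handed to MlasSBGemmBatch once per channel. A minimal standalone sketch of that equivalence in plain scalar C++ (sizes, data, and names here are illustrative only and not taken from the commit):

#include <cstddef>
#include <cstdio>
#include <vector>

// One channel, 1-D case: direct convolution vs. im2col followed by a
// matrix-vector product. Everything here is made up for illustration.
int main() {
    const size_t input_w = 8, kernel_w = 3, stride = 1;
    const size_t output_w = (input_w - kernel_w) / stride + 1;  // no padding

    std::vector<float> input(input_w), filter(kernel_w);
    for (size_t i = 0; i < input_w; i++) input[i] = float(i + 1);
    for (size_t k = 0; k < kernel_w; k++) filter[k] = 0.5f * float(k + 1);

    // Direct convolution: one dot product per output position.
    std::vector<float> direct(output_w, 0.0f);
    for (size_t o = 0; o < output_w; o++)
        for (size_t k = 0; k < kernel_w; k++)
            direct[o] += input[o * stride + k] * filter[k];

    // im2col: copy the kernel_w taps of each output position into a row,
    // producing an [output_w x kernel_w] matrix.
    std::vector<float> im2col(output_w * kernel_w);
    for (size_t o = 0; o < output_w; o++)
        for (size_t k = 0; k < kernel_w; k++)
            im2col[o * kernel_w + k] = input[o * stride + k];

    // GEMM with M=output_w, N=1, K=kernel_w: the im2col matrix (lda=kernel_w)
    // times the filter column (ldb=1), matching the per-channel shape the
    // depthwise kernel batches into MlasSBGemmBatch.
    std::vector<float> gemm(output_w, 0.0f);
    for (size_t o = 0; o < output_w; o++)
        for (size_t k = 0; k < kernel_w; k++)
            gemm[o] += im2col[o * kernel_w + k] * filter[k];

    for (size_t o = 0; o < output_w; o++)
        std::printf("out[%zu]: direct=%g gemm=%g\n", o, direct[o], gemm[o]);
    return 0;
}

Both loops produce identical values, which is the property the kernel relies on; the real code additionally zero-fills taps that fall outside [row_start, row_end) to handle padding.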

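The NCHWc path in MlasConvBf16KernelNeonImpl applies the same idea one kernel position at a time: each position contributes a [TotalOutputCount x BlockSize] im2col slab multiplied by the [BlockSize x BlockSize] filter block at filter_offset, accumulated into the output (ZeroMode is requested only for the first position when the kernel is not accumulating into existing output). A scalar sketch of that accumulation structure, with all names illustrative rather than taken from MLAS:

#include <algorithm>
#include <cstddef>
#include <vector>

// Reference loop for the NCHWc scheme: one GEMM per kernel position, each
// accumulating C[output_count x block_size] += A_g[output_count x block_size]
// * B_g[block_size x block_size].
void nchwc_conv_reference(
    const std::vector<std::vector<float>>& im2col_per_position,  // [kernel_size][output_count * block_size]
    const std::vector<float>& filter,     // [kernel_size * block_size * block_size], position-major
    std::vector<float>& output,           // [output_count * block_size]
    size_t kernel_size, size_t output_count, size_t block_size)
{
    std::fill(output.begin(), output.end(), 0.0f);  // plays the role of ZeroMode on the first call
    for (size_t g = 0; g < kernel_size; g++) {
        const float* a = im2col_per_position[g].data();         // A_g
        const float* b = &filter[g * block_size * block_size];  // B_g
        for (size_t o = 0; o < output_count; o++)        // M
            for (size_t n = 0; n < block_size; n++)      // N
                for (size_t k = 0; k < block_size; k++)  // K
                    output[o * block_size + n] += a[o * block_size + k] * b[k * block_size + n];
    }
}

The NCHW instantiation (IsNchwcFormat == false) is the degenerate case: a single input channel per tap, so the whole kernel fits in one im2col matrix with K = KernelSize and only one GEMM call is issued.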