Skip to content

Commit 759583f

Browse files
Add a NEON kernel for depthwise convolution
1 parent 4ab3a6e commit 759583f

File tree

1 file changed

+38
-65
lines changed

1 file changed

+38
-65
lines changed

onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp

Lines changed: 38 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -317,90 +317,65 @@ void
317317
unsigned KernelFlags
318318
)
319319
{
320-
// Mark unused parameters
321-
(void)InputBase;
322-
(void)InputWidth;
323-
324320
const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
325321
const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
326322
const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
327323

324+
const size_t BlockSize = MlasNchwcGetBlockSize();
328325
const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
329326

330-
// Convert byte strides to element strides
331327
const size_t StrideWidthElements = StrideWidth / sizeof(float);
332328
const size_t DilationWidthElements = DilationWidth / sizeof(float);
333329
const size_t InputStrideElements = InputStride / sizeof(float);
334330
const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
335331

336-
// Mark unused variables in Depthwise kernel
337332
(void)InputStrideElements;
338333

339-
// For depthwise convolution, process 4 channels at a time
340-
const float bias = BiasAddition ? Bias[0] : 0.0f;
341-
const float32x4_t BiasVector = vdupq_n_f32(bias);
334+
const size_t InputWidthElements = InputWidth / sizeof(float);
342335

343-
const float* input_ptr = Input;
344-
size_t OutputIndex = 0;
336+
const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
345337

346-
// Handle left padding
347-
for (size_t i = 0; i < OutputCountLeftPad; i++) {
348-
float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
349-
350-
for (size_t kh = 0; kh < KernelHeight; kh++) {
351-
for (size_t kw = 0; kw < KernelWidth; kw++) {
352-
const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
353-
const float filter_value = Filter[kh * KernelWidth + kw];
354-
const float32x4_t FilterVector = vdupq_n_f32(filter_value);
355-
const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
338+
for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
339+
bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount);
356340

357-
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
358-
}
359-
}
341+
float32x4_t Accumulator;
360342

361-
if (ReluActivation) {
362-
Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
343+
if (AccumulateOutput) {
344+
Accumulator = MlasLoadFloat32x4(&Output[output_idx * BlockSize]);
345+
} else if (BiasAddition) {
346+
Accumulator = MlasLoadFloat32x4(Bias);
347+
} else {
348+
Accumulator = vdupq_n_f32(0.0f);
363349
}
364350

365-
MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
366-
OutputIndex += 4;
367-
input_ptr += StrideWidthElements;
368-
}
369-
370-
// Handle main output region
371-
for (size_t i = 0; i < OutputCount; i++) {
372-
float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
373-
374351
for (size_t kh = 0; kh < KernelHeight; kh++) {
375352
for (size_t kw = 0; kw < KernelWidth; kw++) {
376-
const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
377-
const float filter_value = Filter[kh * KernelWidth + kw];
378-
const float32x4_t FilterVector = vdupq_n_f32(filter_value);
379-
const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
380-
381-
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
382-
}
383-
}
384-
385-
if (ReluActivation) {
386-
Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
387-
}
388-
389-
MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
390-
OutputIndex += 4;
391-
input_ptr += StrideWidthElements;
392-
}
393-
394-
// Handle right padding
395-
for (size_t i = 0; i < OutputCountRightPad; i++) {
396-
float32x4_t Accumulator = AccumulateOutput ? MlasLoadFloat32x4(&Output[OutputIndex]) : BiasVector;
353+
size_t kernel_pos = kh * KernelWidth + kw;
354+
355+
const float* input_base = Input + output_idx * StrideWidthElements +
356+
kh * DilatedInputWidthElements + kw * DilationWidthElements;
357+
358+
float32x4_t InputVector;
359+
360+
if (is_main_region) {
361+
InputVector = MlasLoadFloat32x4(input_base);
362+
} else {
363+
float input_values[4];
364+
for (size_t i = 0; i < BlockSize; i++) {
365+
const float* input_element = input_base + i;
366+
const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
367+
const float* input_row_end = input_row_start + InputWidthElements;
368+
369+
if (input_element >= input_row_start && input_element < input_row_end) {
370+
input_values[i] = *input_element;
371+
} else {
372+
input_values[i] = 0.0f;
373+
}
374+
}
375+
InputVector = MlasLoadFloat32x4(input_values);
376+
}
397377

398-
for (size_t kh = 0; kh < KernelHeight; kh++) {
399-
for (size_t kw = 0; kw < KernelWidth; kw++) {
400-
const float* input_element = input_ptr + kh * DilatedInputWidthElements + kw * DilationWidthElements;
401-
const float filter_value = Filter[kh * KernelWidth + kw];
402-
const float32x4_t FilterVector = vdupq_n_f32(filter_value);
403-
const float32x4_t InputVector = MlasLoadFloat32x4(input_element);
378+
const float32x4_t FilterVector = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]);
404379

405380
Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
406381
}
@@ -410,9 +385,7 @@ void
410385
Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
411386
}
412387

413-
MlasStoreFloat32x4(&Output[OutputIndex], Accumulator);
414-
OutputIndex += 4;
415-
input_ptr += StrideWidthElements;
388+
MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator);
416389
}
417390
}
418391

0 commit comments

Comments
 (0)