
Commit 8f71901

Use MLAS intrinsics for MlasConvNchwcFloatKernelNeon
1 parent 4dea681 commit 8f71901

File tree

1 file changed: +31 −35 lines changed

onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp

Lines changed: 31 additions & 35 deletions
@@ -106,6 +106,7 @@ void
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
+    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
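This first hunk hoists a reusable zero vector out of the loops. vdupq_n_f32 broadcasts one scalar into all four lanes of a 128-bit NEON register; the whole rewrite leans on the NCHWc block size being four floats, so one float32x4_t covers exactly one block. A minimal sketch of that equivalence (the helper name below is hypothetical, not part of the commit):

#include <arm_neon.h>

// Illustrative only: vdupq_n_f32 builds {0, 0, 0, 0} once, standing in for
// the per-element zeroing loop the old scalar code ran. This is a drop-in
// replacement only when BlockSize == 4, i.e. one NCHWc block fits exactly
// in one NEON register -- an assumption of this sketch, not verified here.
float32x4_t MakeZeroBlock() {
    return vdupq_n_f32(0.0f);  // all four lanes set to 0.0f
}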
@@ -125,20 +126,17 @@ void
             const float* filter = Filter + filterSetBlock * FilterStrideElements;
             float* output = Output + filterSetBlock * OutputStrideElements;
 
-            float accumulator[BlockSize];
+            float32x4_t Accumulator;
 
-            for (size_t i = 0; i < BlockSize; i++) {
-                accumulator[i] = 0.0f;
-            }
             if (AccumulateOutput) {
-                for (size_t i = 0; i < BlockSize; i++) {
-                    accumulator[i] = output[output_idx * BlockSize + i];
-                }
+                Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
+            } else {
+                Accumulator = vdupq_n_f32(0.0f);
             }
+
             if (BiasAddition) {
-                for (size_t i = 0; i < BlockSize; i++) {
-                    accumulator[i] += Bias[filterSetBlock * BlockSize + i];
-                }
+                const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
+                Accumulator = vaddq_f32(Accumulator, BiasVector);
             }
 
             for (size_t kh = 0; kh < KernelHeight; kh++) {
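For readers less familiar with the MLAS vector helpers, this hunk collapses three scalar loops into one 128-bit load (or one broadcast of zero) plus one vector add. A plain-NEON sketch of the same setup, assuming MlasLoadFloat32x4 lowers to vld1q_f32 on NEON builds (the function name and parameters below are illustrative, not from the commit):

#include <arm_neon.h>

// Illustrative rewrite of the accumulator setup, assuming MlasLoadFloat32x4
// is a thin wrapper over vld1q_f32.
float32x4_t InitAccumulator(const float* output_block,  // 4 existing outputs
                            const float* bias_block,    // 4 bias values
                            bool accumulate_output,
                            bool bias_addition) {
    // Seed from the existing output block, or from zeros.
    float32x4_t acc = accumulate_output ? vld1q_f32(output_block)
                                        : vdupq_n_f32(0.0f);
    if (bias_addition) {
        acc = vaddq_f32(acc, vld1q_f32(bias_block));  // add bias, 4 lanes at once
    }
    return acc;
}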
@@ -147,37 +145,35 @@ void
                         kh * DilatedInputWidthElements + kw * DilationWidthElements;
 
                 for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) {
-                    for (size_t ic = 0; ic < BlockSize; ic++) {
-                        size_t kernel_pos = kh * (KernelWidth * BlockSize * BlockSize) +
-                                            kw * (BlockSize * BlockSize) +
-                                            filterBlock * (BlockSize) +
-                                            ic;
-
-                        const float* input_element = input_base + filterBlock;
-                        const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
-                        const float* input_row_end = input_row_start + InputWidthElements;
-
-                        float input_value;
-                        if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
-                            input_value = *input_element;
-                        } else {
-                            input_value = 0.0f;
-                        }
-
-                        float filter_value = filter[kernel_pos];
-                        accumulator[ic] += input_value * filter_value;
+                    const float* input_element = input_base + filterBlock;
+                    const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
+                    const float* input_row_end = input_row_start + InputWidthElements;
+
+                    float input_value;
+                    if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) {
+                        input_value = *input_element;
+                    } else {
+                        input_value = 0.0f;
                     }
-                }
 
+                    const float32x4_t InputVector = vdupq_n_f32(input_value);
+
+                    size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
+                                             kw * (BlockSize * BlockSize) +
+                                             filterBlock * BlockSize;
+
+                    const float32x4_t FilterVector = MlasLoadFloat32x4(&filter[kernel_base_pos]);
+
+                    Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
+                }
             }
         }
 
-        for (size_t i = 0; i < BlockSize; i++) {
-            if (ReluActivation && accumulator[i] < 0.0f) {
-                accumulator[i] = 0.0f;
-            }
-            output[output_idx * BlockSize + i] = accumulator[i];
+        if (ReluActivation) {
+            Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
         }
+
+        MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator);
     }
 }
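The inner-loop change is the heart of the commit: instead of iterating over the four output channels of a block, one input activation is broadcast across a register and fused-multiply-added against four consecutive filter taps, and the ReLU/store epilogue becomes a vector max plus a single 128-bit store. A plain-NEON sketch of both steps, assuming MlasMultiplyAddFloat32x4(a, b, c) computes a * b + c (vfmaq_f32 on AArch64) and MlasMaximumFloat32x4/MlasStoreFloat32x4 map to vmaxq_f32/vst1q_f32 (the names and signatures below are illustrative, not from the commit):

#include <arm_neon.h>

// One tap of the new inner loop: broadcast the scalar input activation and
// fused-multiply-add it against four consecutive filter weights, replacing
// the old per-channel `ic` loop. Sketch only; assumes the MLAS helper
// semantics described above.
float32x4_t AccumulateTap(float input_value, const float* filter_taps,
                          float32x4_t acc) {
    const float32x4_t input_vec = vdupq_n_f32(input_value);  // splat scalar
    const float32x4_t filter_vec = vld1q_f32(filter_taps);   // 4 weights
    return vfmaq_f32(acc, input_vec, filter_vec);            // acc += in * w
}

// Vectorized epilogue: clamp negatives in one instruction, then write the
// whole block with one 128-bit store instead of four scalar writes.
void StoreBlock(float* output_block, float32x4_t acc, bool relu) {
    if (relu) {
        acc = vmaxq_f32(acc, vdupq_n_f32(0.0f));  // ReLU across all lanes
    }
    vst1q_f32(output_block, acc);
}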
