Commit 382e61b

Remove unnecessary code & formatting changes
1 parent ccbaa87 commit 382e61b

2 files changed: 28 additions, 124 deletions


onnxruntime/core/mlas/lib/sconv.h

Lines changed: 4 additions & 102 deletions
@@ -35,105 +35,7 @@ Module Name:
 // Define the convolution kernel flags.
 //
 
-#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001
-#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002
-#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004
-#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008
-
-//
-// Define the prototypes of the NEON convolution kernels.
-//
-
-#if defined(__aarch64__) || defined(_M_ARM64)
-
-extern "C" {
-
-void
-MLASCALL
-MlasConvNchwFloatKernelNeon(
-    const float* Input,
-    const float* Filter,
-    float* Output,
-    size_t StrideWidth,
-    size_t DilationWidth,
-    size_t FilterCount,
-    size_t InputStride,
-    size_t FilterStride,
-    size_t OutputStride,
-    size_t KernelHeight,
-    size_t KernelWidth,
-    const float* InputBase,
-    size_t InputWidth,
-    size_t DilatedInputWidth,
-    size_t OutputCountLeftPad,
-    size_t OutputCount,
-    size_t OutputCountRightPad,
-    const float* Bias,
-    unsigned KernelFlags
-    );
-
-void
-MLASCALL
-MlasConvNchwcFloatKernelNeon(
-    const float* Input,
-    const float* Filter,
-    float* Output,
-    size_t StrideWidth,
-    size_t DilationWidth,
-    size_t FilterCount,
-    size_t InputStride,
-    size_t FilterStride,
-    size_t OutputStride,
-    size_t KernelHeight,
-    size_t KernelWidth,
-    const float* InputBase,
-    size_t InputWidth,
-    size_t DilatedInputWidth,
-    size_t OutputCountLeftPad,
-    size_t OutputCount,
-    size_t OutputCountRightPad,
-    const float* Bias,
-    unsigned KernelFlags
-    );
-
-void
-MLASCALL
-MlasConvDepthwiseFloatKernelNeon(
-    const float* Input,
-    const float* Filter,
-    float* Output,
-    size_t StrideWidth,
-    size_t DilationWidth,
-    size_t InputStride,
-    size_t KernelHeight,
-    size_t KernelWidth,
-    const float* InputBase,
-    size_t InputWidth,
-    size_t DilatedInputWidth,
-    size_t OutputCountLeftPad,
-    size_t OutputCount,
-    size_t OutputCountRightPad,
-    const float* Bias,
-    unsigned KernelFlags
-    );
-
-void
-MLASCALL
-MlasConvPointwiseFloatKernelNeon(
-    const float* Input,
-    const float* Filter,
-    float* Output,
-    size_t StrideWidth,
-    size_t InputChannels,
-    size_t FilterCount,
-    size_t InputStride,
-    size_t FilterStride,
-    size_t OutputStride,
-    size_t OutputCount,
-    const float* Bias,
-    unsigned KernelFlags
-    );
-
-}
-
-#endif
+#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001
+#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002
+#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004
+#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008
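With this change the header keeps only the four kernel flag definitions; the block of NEON prototypes is gone. The flags are single-bit masks, so callers can OR them together and each kernel tests the bits independently, as the sconv_kernel_neon.cpp hunks below show. A minimal usage sketch (the local names here are illustrative, not from the commit):

    // Illustrative sketch, not part of the commit: combine flags with |, test with &.
    unsigned KernelFlags = MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION |
                           MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION;

    // Each flag occupies its own bit, so the tests are independent.
    const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;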

onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp

Lines changed: 24 additions & 22 deletions
@@ -21,7 +21,6 @@ Module Name:
 #include <algorithm>
 #include <cstddef>
 
-#include "arm_neon.h"
 #include "mlasi.h"
 
 void
@@ -53,7 +52,7 @@ void
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
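The recurring edit in this file replaces raw NEON intrinsics (vdupq_n_f32, vaddq_f32) with the portable MLAS vector helpers, which is also why the direct #include "arm_neon.h" in the first hunk becomes unnecessary: mlasi.h already pulls in whatever the target needs. A sketch of what those helpers plausibly look like on an AArch64 build (an assumption about mlasi.h internals, not code from this commit):

    // Hypothetical shape of the mlasi.h helpers on NEON targets.
    #include <arm_neon.h>

    // Splat one scalar across all four float lanes.
    inline float32x4_t MlasBroadcastFloat32x4(float Value) {
        return vdupq_n_f32(Value);
    }

    // Lane-wise single-precision addition.
    inline float32x4_t MlasAddFloat32x4(float32x4_t Vector1, float32x4_t Vector2) {
        return vaddq_f32(Vector1, Vector2);
    }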
@@ -78,12 +77,12 @@ void
             if (AccumulateOutput) {
                 Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
             }
 
             if (BiasAddition) {
                 const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
-                Accumulator = vaddq_f32(Accumulator, BiasVector);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
 
             for (size_t kh = 0; kh < KernelHeight; kh++) {
@@ -101,7 +100,7 @@ void
                     input_value = 0.0f;
                 }
 
-                const float32x4_t InputVector = vdupq_n_f32(input_value);
+                const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
 
                 size_t kernel_base_pos = kh * KernelWidth + kw;
 
@@ -153,7 +152,7 @@ void
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -178,12 +177,12 @@ void
             if (AccumulateOutput) {
                 Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
             }
 
             if (BiasAddition) {
                 const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]);
-                Accumulator = vaddq_f32(Accumulator, BiasVector);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
 
             for (size_t kh = 0; kh < KernelHeight; kh++) {
@@ -203,7 +202,7 @@ void
                     input_value = 0.0f;
                 }
 
-                const float32x4_t InputVector = vdupq_n_f32(input_value);
+                const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
 
                 size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) +
                                          kw * (BlockSize * BlockSize) +
@@ -259,7 +258,7 @@ void
     const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -279,10 +278,11 @@ void
 
         if (AccumulateOutput) {
            Accumulator = MlasLoadFloat32x4(&Output[output_idx * BlockSize]);
-        } else if (BiasAddition) {
-            Accumulator = MlasLoadFloat32x4(Bias);
         } else {
-            Accumulator = vdupq_n_f32(0.0f);
+            Accumulator = MlasBroadcastFloat32x4(0.0f);
+        }
+        if (BiasAddition) {
+            Accumulator = MlasAddFloat32x4(Accumulator, MlasLoadFloat32x4(Bias));
         }
 
         for (size_t kh = 0; kh < KernelHeight; kh++) {
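This hunk changes behavior, not just spelling. In the old else-if chain, setting both MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT and MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION meant the bias branch was never taken; splitting bias addition into its own if applies the bias on top of whatever base value was chosen. A scalar model of the two control flows (illustrative only, not code from the commit):

    // Old: bias ignored whenever AccumulateOutput is set.
    float old_base = AccumulateOutput ? prev_output
                   : BiasAddition    ? bias
                   : 0.0f;

    // New: bias is added regardless of how the base was chosen.
    float new_base = AccumulateOutput ? prev_output : 0.0f;
    if (BiasAddition) {
        new_base += bias;
    }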
@@ -361,25 +361,27 @@ void
     const size_t OutputStrideElements = OutputStride / sizeof(float);
 
     const size_t BlockSize = MlasNchwcGetBlockSize();
-    const float32x4_t ZeroVector = vdupq_n_f32(0.0f);
+    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
 
-    for (size_t i = 0; i < OutputCount; i++) {
+    for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) {
         for (size_t f = 0; f < FilterCount; f++) {
             const float* filter = Filter + f * FilterStrideElements;
             float* output = Output + f * OutputStrideElements;
             float32x4_t Accumulator;
             if (AccumulateOutput) {
-                Accumulator = MlasLoadFloat32x4(&output[i * BlockSize]);
-            } else if (BiasAddition) {
-                Accumulator = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+                Accumulator = MlasLoadFloat32x4(&output[output_idx * BlockSize]);
             } else {
-                Accumulator = vdupq_n_f32(0.0f);
+                Accumulator = MlasBroadcastFloat32x4(0.0f);
+            }
+            if (BiasAddition) {
+                const float32x4_t BiasVector = MlasLoadFloat32x4(&Bias[f * BlockSize]);
+                Accumulator = MlasAddFloat32x4(Accumulator, BiasVector);
             }
             for (size_t c = 0; c < InputChannels; c++) {
-                const float* input_ptr = Input + c * InputStrideElements + i * StrideWidthElements;
+                const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements;
                 for (size_t input_b = 0; input_b < BlockSize; input_b++) {
                     const float input_value = input_ptr[input_b];
-                    const float32x4_t InputVector = vdupq_n_f32(input_value);
+                    const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value);
                     const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize;
                     const float32x4_t FilterVector = MlasLoadFloat32x4(filter_ptr);
                     Accumulator = MlasMultiplyAddFloat32x4(InputVector, FilterVector, Accumulator);
@@ -388,7 +390,7 @@ void
             if (ReluActivation) {
                 Accumulator = MlasMaximumFloat32x4(Accumulator, ZeroVector);
             }
-            MlasStoreFloat32x4(&output[i * BlockSize], Accumulator);
+            MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator);
         }
     }
 }
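Besides renaming i to output_idx for clarity, the pointwise hunks expose the NCHWc filter addressing this kernel relies on: each of the BlockSize scalars in an input channel block selects a contiguous row of BlockSize filter values at offset (c * BlockSize + input_b) * BlockSize. A scalar restatement of that inner loop, assuming the layout reads as shown in the diff (a sketch, not the shipped kernel):

    // Sketch: acc[j] accumulates one output block for a single output position.
    for (size_t c = 0; c < InputChannels; c++) {
        const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements;
        for (size_t b = 0; b < BlockSize; b++) {
            const float x = input_ptr[b];  // one scalar of the input block
            const float* row = filter + (c * BlockSize + b) * BlockSize;
            for (size_t j = 0; j < BlockSize; j++) {
                acc[j] += x * row[j];      // the float32x4 multiply-add, unrolled
            }
        }
    }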
