Skip to content

Commit b586f2f

Browse files
committed
Support arthimetic operator for fp32 simd
Signed-off-by: LHT129 <[email protected]>
1 parent 29c0959 commit b586f2f

File tree

9 files changed

+385
-7
lines changed

9 files changed

+385
-7
lines changed

.circleci/fresh_ci_cache.commit

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d2491912b3b18b2d8745cd7468001a91eab89692
1+
29c0959541dec4e5fae74fa7653fc2b1831cfb31

src/simd/avx.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,69 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) {
292292
#endif
293293
}
294294

295+
void
296+
FP32Add(const float* x, const float* y, float* z, uint64_t dim) {
297+
#if defined(ENABLE_AVX)
298+
if (dim < 8) {
299+
return sse::FP32Add(x, y, z, dim);
300+
}
301+
int i = 0;
302+
for (; i + 7 < dim; i += 8) {
303+
__m256 a = _mm256_loadu_ps(x + i);
304+
__m256 b = _mm256_loadu_ps(y + i);
305+
__m256 c = _mm256_add_ps(a, b);
306+
_mm256_storeu_ps(z + i, c);
307+
}
308+
if (i < dim) {
309+
sse::FP32Add(x + i, y + i, z + i, dim - i);
310+
}
311+
#else
312+
sse::FP32Add(x, y, z, dim);
313+
#endif
314+
}
315+
316+
void
317+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim) {
318+
#if defined(ENABLE_AVX)
319+
if (dim < 8) {
320+
return sse::FP32Mul(x, y, z, dim);
321+
}
322+
int i = 0;
323+
for (; i + 7 < dim; i += 8) {
324+
__m256 a = _mm256_loadu_ps(x + i);
325+
__m256 b = _mm256_loadu_ps(y + i);
326+
__m256 c = _mm256_mul_ps(a, b);
327+
_mm256_storeu_ps(z + i, c);
328+
}
329+
if (i < dim) {
330+
sse::FP32Mul(x + i, y + i, z + i, dim - i);
331+
}
332+
#else
333+
sse::FP32Mul(x, y, z, dim);
334+
#endif
335+
}
336+
337+
void
338+
FP32Div(const float* x, const float* y, float* z, uint64_t dim) {
339+
#if defined(ENABLE_AVX)
340+
if (dim < 8) {
341+
return sse::FP32Div(x, y, z, dim);
342+
}
343+
int i = 0;
344+
for (; i + 7 < dim; i += 8) {
345+
__m256 a = _mm256_loadu_ps(x + i);
346+
__m256 b = _mm256_loadu_ps(y + i);
347+
__m256 c = _mm256_div_ps(a, b);
348+
_mm256_storeu_ps(z + i, c);
349+
}
350+
if (i < dim) {
351+
sse::FP32Div(x + i, y + i, z + i, dim - i);
352+
}
353+
#else
354+
sse::FP32Div(x, y, z, dim);
355+
#endif
356+
}
357+
295358
#if defined(ENABLE_AVX)
296359
__inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) {
297360
return _mm256_set_epi16(data[7],

src/simd/avx2.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,69 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) {
286286
#endif
287287
}
288288

289+
void
290+
FP32Add(const float* x, const float* y, float* z, uint64_t dim) {
291+
#if defined(ENABLE_AVX2)
292+
if (dim < 8) {
293+
return sse::FP32Add(x, y, z, dim);
294+
}
295+
int i = 0;
296+
for (; i + 7 < dim; i += 8) {
297+
__m256 a = _mm256_loadu_ps(x + i);
298+
__m256 b = _mm256_loadu_ps(y + i);
299+
__m256 c = _mm256_add_ps(a, b);
300+
_mm256_storeu_ps(z + i, c);
301+
}
302+
if (i < dim) {
303+
sse::FP32Add(x + i, y + i, z + i, dim - i);
304+
}
305+
#else
306+
return sse::FP32Add(x, y, z, dim);
307+
#endif
308+
}
309+
310+
void
311+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim) {
312+
#if defined(ENABLE_AVX2)
313+
if (dim < 8) {
314+
return sse::FP32Mul(x, y, z, dim);
315+
}
316+
int i = 0;
317+
for (; i + 7 < dim; i += 8) {
318+
__m256 a = _mm256_loadu_ps(x + i);
319+
__m256 b = _mm256_loadu_ps(y + i);
320+
__m256 c = _mm256_mul_ps(a, b);
321+
_mm256_storeu_ps(z + i, c);
322+
}
323+
if (i < dim) {
324+
sse::FP32Mul(x + i, y + i, z + i, dim - i);
325+
}
326+
#else
327+
return sse::FP32Mul(x, y, z, dim);
328+
#endif
329+
}
330+
331+
void
332+
FP32Div(const float* x, const float* y, float* z, uint64_t dim) {
333+
#if defined(ENABLE_AVX2)
334+
if (dim < 8) {
335+
return sse::FP32Div(x, y, z, dim);
336+
}
337+
int i = 0;
338+
for (; i + 7 < dim; i += 8) {
339+
__m256 a = _mm256_loadu_ps(x + i);
340+
__m256 b = _mm256_loadu_ps(y + i);
341+
__m256 c = _mm256_div_ps(a, b);
342+
_mm256_storeu_ps(z + i, c);
343+
}
344+
if (i < dim) {
345+
sse::FP32Div(x + i, y + i, z + i, dim - i);
346+
}
347+
#else
348+
return sse::FP32Div(x, y, z, dim);
349+
#endif
350+
}
351+
289352
#if defined(ENABLE_AVX2)
290353
__inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) {
291354
__m128i bf16 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data));

src/simd/avx512.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,69 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) {
329329
#endif
330330
}
331331

332+
void
333+
FP32Add(const float* x, const float* y, float* z, uint64_t dim) {
334+
#if defined(ENABLE_AVX512)
335+
if (dim < 16) {
336+
return avx2::FP32Add(x, y, z, dim);
337+
}
338+
uint64_t i = 0;
339+
for (; i + 15 < dim; i += 16) {
340+
__m512 x_vec = _mm512_loadu_ps(x + i);
341+
__m512 y_vec = _mm512_loadu_ps(y + i);
342+
__m512 sum_vec = _mm512_add_ps(x_vec, y_vec);
343+
_mm512_storeu_ps(z + i, sum_vec);
344+
}
345+
if (dim > i) {
346+
avx2::FP32Add(x + i, y + i, z + i, dim - i);
347+
}
348+
#else
349+
return avx2::FP32Add(x, y, z, dim);
350+
#endif
351+
}
352+
353+
void
354+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim) {
355+
#if defined(ENABLE_AVX512)
356+
if (dim < 16) {
357+
return avx2::FP32Mul(x, y, z, dim);
358+
}
359+
uint64_t i = 0;
360+
for (; i + 15 < dim; i += 16) {
361+
__m512 x_vec = _mm512_loadu_ps(x + i);
362+
__m512 y_vec = _mm512_loadu_ps(y + i);
363+
__m512 mul_vec = _mm512_mul_ps(x_vec, y_vec);
364+
_mm512_storeu_ps(z + i, mul_vec);
365+
}
366+
if (dim > i) {
367+
avx2::FP32Mul(x + i, y + i, z + i, dim - i);
368+
}
369+
#else
370+
return avx2::FP32Mul(x, y, z, dim);
371+
#endif
372+
}
373+
374+
void
375+
FP32Div(const float* x, const float* y, float* z, uint64_t dim) {
376+
#if defined(ENABLE_AVX512)
377+
if (dim < 16) {
378+
return avx2::FP32Div(x, y, z, dim);
379+
}
380+
uint64_t i = 0;
381+
for (; i + 15 < dim; i += 16) {
382+
__m512 x_vec = _mm512_loadu_ps(x + i);
383+
__m512 y_vec = _mm512_loadu_ps(y + i);
384+
__m512 div_vec = _mm512_div_ps(x_vec, y_vec);
385+
_mm512_storeu_ps(z + i, div_vec);
386+
}
387+
if (dim > i) {
388+
avx2::FP32Div(x + i, y + i, z + i, dim - i);
389+
}
390+
#else
391+
return avx2::FP32Div(x, y, z, dim);
392+
#endif
393+
}
394+
332395
#if defined(ENABLE_AVX512)
333396
__inline __m512i __attribute__((__always_inline__)) load_16_short(const uint16_t* data) {
334397
__m256i bf16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data));

src/simd/fp32_simd.cpp

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ GetFP32ComputeL2SqrBatch4() {
111111
}
112112
FP32ComputeBatch4Type FP32ComputeL2SqrBatch4 = GetFP32ComputeL2SqrBatch4();
113113

114-
static FP32SubType
114+
static FP32ArithmeticType
115115
GetFP32Sub() {
116116
if (SimdStatus::SupportAVX512()) {
117117
#if defined(ENABLE_AVX512)
@@ -132,5 +132,74 @@ GetFP32Sub() {
132132
}
133133
return generic::FP32Sub;
134134
}
135-
FP32SubType FP32Sub = GetFP32Sub();
135+
FP32ArithmeticType FP32Sub = GetFP32Sub();
136+
137+
static FP32ArithmeticType
138+
GetFP32Add() {
139+
if (SimdStatus::SupportAVX512()) {
140+
#if defined(ENABLE_AVX512)
141+
return avx512::FP32Add;
142+
#endif
143+
} else if (SimdStatus::SupportAVX2()) {
144+
#if defined(ENABLE_AVX2)
145+
return avx2::FP32Add;
146+
#endif
147+
} else if (SimdStatus::SupportAVX()) {
148+
#if defined(ENABLE_AVX)
149+
return avx::FP32Add;
150+
#endif
151+
} else if (SimdStatus::SupportSSE()) {
152+
#if defined(ENABLE_SSE)
153+
return sse::FP32Add;
154+
#endif
155+
}
156+
return generic::FP32Add;
157+
}
158+
FP32ArithmeticType FP32Add = GetFP32Add();
159+
160+
static FP32ArithmeticType
161+
GetFP32Mul() {
162+
if (SimdStatus::SupportAVX512()) {
163+
#if defined(ENABLE_AVX512)
164+
return avx512::FP32Mul;
165+
#endif
166+
} else if (SimdStatus::SupportAVX2()) {
167+
#if defined(ENABLE_AVX2)
168+
return avx2::FP32Mul;
169+
#endif
170+
} else if (SimdStatus::SupportAVX()) {
171+
#if defined(ENABLE_AVX)
172+
return avx::FP32Mul;
173+
#endif
174+
} else if (SimdStatus::SupportSSE()) {
175+
#if defined(ENABLE_SSE)
176+
return sse::FP32Mul;
177+
#endif
178+
}
179+
return generic::FP32Mul;
180+
}
181+
FP32ArithmeticType FP32Mul = GetFP32Mul();
182+
183+
static FP32ArithmeticType
184+
GetFP32Div() {
185+
if (SimdStatus::SupportAVX512()) {
186+
#if defined(ENABLE_AVX512)
187+
return avx512::FP32Div;
188+
#endif
189+
} else if (SimdStatus::SupportAVX2()) {
190+
#if defined(ENABLE_AVX2)
191+
return avx2::FP32Div;
192+
#endif
193+
} else if (SimdStatus::SupportAVX()) {
194+
#if defined(ENABLE_AVX)
195+
return avx::FP32Div;
196+
#endif
197+
} else if (SimdStatus::SupportSSE()) {
198+
#if defined(ENABLE_SSE)
199+
return sse::FP32Div;
200+
#endif
201+
}
202+
return generic::FP32Div;
203+
}
204+
FP32ArithmeticType FP32Div = GetFP32Div();
136205
} // namespace vsag

src/simd/fp32_simd.h

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ FP32ComputeL2SqrBatch4(const float* query,
4949
float& result4);
5050
void
5151
FP32Sub(const float* x, const float* y, float* z, uint64_t dim);
52+
void
53+
FP32Add(const float* x, const float* y, float* z, uint64_t dim);
54+
void
55+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim);
56+
void
57+
FP32Div(const float* x, const float* y, float* z, uint64_t dim);
5258
} // namespace generic
5359

5460
namespace sse {
@@ -80,6 +86,12 @@ FP32ComputeL2SqrBatch4(const float* query,
8086
float& result4);
8187
void
8288
FP32Sub(const float* x, const float* y, float* z, uint64_t dim);
89+
void
90+
FP32Add(const float* x, const float* y, float* z, uint64_t dim);
91+
void
92+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim);
93+
void
94+
FP32Div(const float* x, const float* y, float* z, uint64_t dim);
8395
} // namespace sse
8496

8597
namespace avx {
@@ -111,6 +123,12 @@ FP32ComputeL2SqrBatch4(const float* query,
111123
float& result4);
112124
void
113125
FP32Sub(const float* x, const float* y, float* z, uint64_t dim);
126+
void
127+
FP32Add(const float* x, const float* y, float* z, uint64_t dim);
128+
void
129+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim);
130+
void
131+
FP32Div(const float* x, const float* y, float* z, uint64_t dim);
114132
} // namespace avx
115133

116134
namespace avx2 {
@@ -142,6 +160,12 @@ FP32ComputeL2SqrBatch4(const float* query,
142160
float& result4);
143161
void
144162
FP32Sub(const float* x, const float* y, float* z, uint64_t dim);
163+
void
164+
FP32Add(const float* x, const float* y, float* z, uint64_t dim);
165+
void
166+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim);
167+
void
168+
FP32Div(const float* x, const float* y, float* z, uint64_t dim);
145169
} // namespace avx2
146170

147171
namespace avx512 {
@@ -173,6 +197,12 @@ FP32ComputeL2SqrBatch4(const float* query,
173197
float& result4);
174198
void
175199
FP32Sub(const float* x, const float* y, float* z, uint64_t dim);
200+
void
201+
FP32Add(const float* x, const float* y, float* z, uint64_t dim);
202+
void
203+
FP32Mul(const float* x, const float* y, float* z, uint64_t dim);
204+
void
205+
FP32Div(const float* x, const float* y, float* z, uint64_t dim);
176206
} // namespace avx512
177207

178208
using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim);
@@ -192,6 +222,9 @@ using FP32ComputeBatch4Type = void (*)(const float* query,
192222
extern FP32ComputeBatch4Type FP32ComputeIPBatch4;
193223
extern FP32ComputeBatch4Type FP32ComputeL2SqrBatch4;
194224

195-
using FP32SubType = void (*)(const float* x, const float* y, float* z, uint64_t dim);
196-
extern FP32SubType FP32Sub;
225+
using FP32ArithmeticType = void (*)(const float* x, const float* y, float* z, uint64_t dim);
226+
extern FP32ArithmeticType FP32Sub;
227+
extern FP32ArithmeticType FP32Add;
228+
extern FP32ArithmeticType FP32Mul;
229+
extern FP32ArithmeticType FP32Div;
197230
} // namespace vsag

src/simd/fp32_simd_test.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using namespace vsag;
4545
} \
4646
};
4747

48-
#define TEST_FP32_SUB_ACCURACY(Func) \
48+
#define TEST_FP32_ARTHIMETIC_ACCURACY(Func) \
4949
{ \
5050
std::vector<float> gt(dim, 0.0F); \
5151
generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, gt.data(), dim); \
@@ -176,7 +176,10 @@ TEST_CASE("FP32 SIMD Compute", "[ut][simd]") {
176176
for (uint64_t i = 0; i < count; ++i) {
177177
TEST_FP32_COMPUTE_ACCURACY(FP32ComputeIP);
178178
TEST_FP32_COMPUTE_ACCURACY(FP32ComputeL2Sqr);
179-
TEST_FP32_SUB_ACCURACY(FP32Sub);
179+
TEST_FP32_ARTHIMETIC_ACCURACY(FP32Sub);
180+
TEST_FP32_ARTHIMETIC_ACCURACY(FP32Add);
181+
TEST_FP32_ARTHIMETIC_ACCURACY(FP32Mul);
182+
TEST_FP32_ARTHIMETIC_ACCURACY(FP32Div);
180183
}
181184
for (uint64_t i = 0; i < count; i += 4) {
182185
TEST_FP32_COMPUTE_ACCURACY_BATCH4(FP32ComputeIP, FP32ComputeIPBatch4);

0 commit comments

Comments
 (0)