Skip to content

Commit 29c0959

Browse files
authored
add bit operator simd implement (#756)
Signed-off-by: LHT129 <[email protected]>
1 parent 4098faf commit 29c0959

File tree

10 files changed

+796
-0
lines changed

10 files changed

+796
-0
lines changed

src/simd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set (SIMD_SRCS
77
simd.cpp
88
simd_status.cpp
99
basic_func.cpp
10+
bit_simd.cpp
1011
fp32_simd.cpp
1112
fp16_simd.cpp
1213
bf16_simd.cpp

src/simd/avx.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,4 +768,99 @@ PQFastScanLookUp32(const uint8_t* lookup_table,
768768
#endif
769769
}
770770

771+
void
772+
BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
773+
#if defined(ENABLE_AVX)
774+
if (num_byte == 0) {
775+
return;
776+
}
777+
if (num_byte < 32) {
778+
return sse::BitAnd(x, y, num_byte, result);
779+
}
780+
int64_t i = 0;
781+
for (; i + 31 < num_byte; i += 32) {
782+
__m256 x_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(x + i));
783+
__m256 y_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(y + i));
784+
__m256 z_vec = _mm256_and_ps(x_vec, y_vec);
785+
_mm256_storeu_ps(reinterpret_cast<float*>(result + i), z_vec);
786+
}
787+
if (i < num_byte) {
788+
sse::BitAnd(x + i, y + i, num_byte - i, result + i);
789+
}
790+
#else
791+
return sse::BitAnd(x, y, num_byte, result);
792+
#endif
793+
}
794+
795+
void
796+
BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
797+
#if defined(ENABLE_AVX)
798+
if (num_byte == 0) {
799+
return;
800+
}
801+
if (num_byte < 32) {
802+
return sse::BitOr(x, y, num_byte, result);
803+
}
804+
int64_t i = 0;
805+
for (; i + 31 < num_byte; i += 32) {
806+
__m256 x_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(x + i));
807+
__m256 y_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(y + i));
808+
__m256 z_vec = _mm256_or_ps(x_vec, y_vec);
809+
_mm256_storeu_ps(reinterpret_cast<float*>(result + i), z_vec);
810+
}
811+
if (i < num_byte) {
812+
sse::BitOr(x + i, y + i, num_byte - i, result + i);
813+
}
814+
#else
815+
return sse::BitOr(x, y, num_byte, result);
816+
#endif
817+
}
818+
819+
void
820+
BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
821+
#if defined(ENABLE_AVX)
822+
if (num_byte == 0) {
823+
return;
824+
}
825+
if (num_byte < 32) {
826+
return sse::BitXor(x, y, num_byte, result);
827+
}
828+
int64_t i = 0;
829+
for (; i + 31 < num_byte; i += 32) {
830+
__m256 x_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(x + i));
831+
__m256 y_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(y + i));
832+
__m256 z_vec = _mm256_xor_ps(x_vec, y_vec);
833+
_mm256_storeu_ps(reinterpret_cast<float*>(result + i), z_vec);
834+
}
835+
if (i < num_byte) {
836+
sse::BitXor(x + i, y + i, num_byte - i, result + i);
837+
}
838+
#else
839+
return sse::BitXor(x, y, num_byte, result);
840+
#endif
841+
}
842+
843+
void
844+
BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) {
845+
#if defined(ENABLE_AVX)
846+
if (num_byte == 0) {
847+
return;
848+
}
849+
if (num_byte < 32) {
850+
return sse::BitNot(x, num_byte, result);
851+
}
852+
int64_t i = 0;
853+
__m256 all_one = _mm256_castsi256_ps(_mm256_set1_epi32(-1));
854+
for (; i + 31 < num_byte; i += 32) {
855+
__m256 x_vec = _mm256_loadu_ps(reinterpret_cast<const float*>(x + i));
856+
__m256 z_vec = _mm256_xor_ps(x_vec, all_one);
857+
_mm256_storeu_ps(reinterpret_cast<float*>(result + i), z_vec);
858+
}
859+
if (i < num_byte) {
860+
sse::BitNot(x + i, num_byte - i, result + i);
861+
}
862+
#else
863+
return sse::BitNot(x, num_byte, result);
864+
#endif
865+
}
771866
} // namespace vsag::avx

src/simd/avx2.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,4 +824,98 @@ PQFastScanLookUp32(const uint8_t* lookup_table,
824824
#endif
825825
}
826826

827+
void
828+
BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
829+
#if defined(ENABLE_AVX2)
830+
if (num_byte == 0) {
831+
return;
832+
}
833+
if (num_byte < 32) {
834+
return sse::BitAnd(x, y, num_byte, result);
835+
}
836+
int64_t i = 0;
837+
for (; i + 31 < num_byte; i += 32) {
838+
__m256i x_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + i));
839+
__m256i y_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(y + i));
840+
__m256i z_vec = _mm256_and_si256(x_vec, y_vec);
841+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec);
842+
}
843+
if (i < num_byte) {
844+
sse::BitAnd(x + i, y + i, num_byte - i, result + i);
845+
}
846+
#else
847+
return sse::BitAnd(x, y, num_byte, result);
848+
#endif
849+
}
850+
851+
void
852+
BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
853+
#if defined(ENABLE_AVX2)
854+
if (num_byte == 0) {
855+
return;
856+
}
857+
if (num_byte < 32) {
858+
return sse::BitOr(x, y, num_byte, result);
859+
}
860+
int64_t i = 0;
861+
for (; i + 31 < num_byte; i += 32) {
862+
__m256i x_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + i));
863+
__m256i y_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(y + i));
864+
__m256i z_vec = _mm256_or_si256(x_vec, y_vec);
865+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec);
866+
}
867+
if (i < num_byte) {
868+
sse::BitOr(x + i, y + i, num_byte - i, result + i);
869+
}
870+
#else
871+
return sse::BitOr(x, y, num_byte, result);
872+
#endif
873+
}
874+
875+
void
876+
BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
877+
#if defined(ENABLE_AVX2)
878+
if (num_byte == 0) {
879+
return;
880+
}
881+
if (num_byte < 32) {
882+
return sse::BitXor(x, y, num_byte, result);
883+
}
884+
int64_t i = 0;
885+
for (; i + 31 < num_byte; i += 32) {
886+
__m256i x_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + i));
887+
__m256i y_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(y + i));
888+
__m256i z_vec = _mm256_xor_si256(x_vec, y_vec);
889+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec);
890+
}
891+
if (i < num_byte) {
892+
sse::BitXor(x + i, y + i, num_byte - i, result + i);
893+
}
894+
#else
895+
return sse::BitXor(x, y, num_byte, result);
896+
#endif
897+
}
898+
899+
void
900+
BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) {
901+
#if defined(ENABLE_AVX2)
902+
if (num_byte == 0) {
903+
return;
904+
}
905+
if (num_byte < 32) {
906+
return sse::BitNot(x, num_byte, result);
907+
}
908+
int64_t i = 0;
909+
for (; i + 31 < num_byte; i += 32) {
910+
__m256i x_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + i));
911+
__m256i z_vec = _mm256_xor_si256(x_vec, _mm256_set1_epi8(0xFF));
912+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec);
913+
}
914+
if (i < num_byte) {
915+
sse::BitNot(x + i, num_byte - i, result + i);
916+
}
917+
#else
918+
return sse::BitNot(x, num_byte, result);
919+
#endif
920+
}
827921
} // namespace vsag::avx2

src/simd/avx512.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,4 +894,98 @@ PQFastScanLookUp32(const uint8_t* lookup_table,
894894
#endif
895895
}
896896

897+
void
898+
BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
899+
#if defined(ENABLE_AVX512)
900+
if (num_byte == 0) {
901+
return;
902+
}
903+
if (num_byte < 64) {
904+
return avx2::BitAnd(x, y, num_byte, result);
905+
}
906+
int64_t i = 0;
907+
for (; i + 63 < num_byte; i += 64) {
908+
__m512i x_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(x + i));
909+
__m512i y_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(y + i));
910+
__m512i z_vec = _mm512_and_si512(x_vec, y_vec);
911+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec);
912+
}
913+
if (i < num_byte) {
914+
avx2::BitAnd(x + i, y + i, num_byte - i, result + i);
915+
}
916+
#else
917+
return avx2::BitAnd(x, y, num_byte, result);
918+
#endif
919+
}
920+
921+
void
922+
BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
923+
#if defined(ENABLE_AVX512)
924+
if (num_byte == 0) {
925+
return;
926+
}
927+
if (num_byte < 64) {
928+
return avx2::BitOr(x, y, num_byte, result);
929+
}
930+
int64_t i = 0;
931+
for (; i + 63 < num_byte; i += 64) {
932+
__m512i x_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(x + i));
933+
__m512i y_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(y + i));
934+
__m512i z_vec = _mm512_or_si512(x_vec, y_vec);
935+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec);
936+
}
937+
if (i < num_byte) {
938+
avx2::BitOr(x + i, y + i, num_byte - i, result + i);
939+
}
940+
#else
941+
return avx2::BitOr(x, y, num_byte, result);
942+
#endif
943+
}
944+
945+
void
946+
BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) {
947+
#if defined(ENABLE_AVX512)
948+
if (num_byte == 0) {
949+
return;
950+
}
951+
if (num_byte < 64) {
952+
return avx2::BitXor(x, y, num_byte, result);
953+
}
954+
int64_t i = 0;
955+
for (; i + 63 < num_byte; i += 64) {
956+
__m512i x_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(x + i));
957+
__m512i y_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(y + i));
958+
__m512i z_vec = _mm512_xor_si512(x_vec, y_vec);
959+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec);
960+
}
961+
if (i < num_byte) {
962+
avx2::BitXor(x + i, y + i, num_byte - i, result + i);
963+
}
964+
#else
965+
return avx2::BitXor(x, y, num_byte, result);
966+
#endif
967+
}
968+
969+
void
970+
BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) {
971+
#if defined(ENABLE_AVX512)
972+
if (num_byte == 0) {
973+
return;
974+
}
975+
if (num_byte < 64) {
976+
return avx2::BitNot(x, num_byte, result);
977+
}
978+
int64_t i = 0;
979+
for (; i + 63 < num_byte; i += 64) {
980+
__m512i x_vec = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(x + i));
981+
__m512i z_vec = _mm512_xor_si512(x_vec, _mm512_set1_epi8(0xFF));
982+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec);
983+
}
984+
if (i < num_byte) {
985+
avx2::BitNot(x + i, num_byte - i, result + i);
986+
}
987+
#else
988+
return avx2::BitNot(x, num_byte, result);
989+
#endif
990+
}
897991
} // namespace vsag::avx512

0 commit comments

Comments
 (0)