Skip to content

Commit 5969058

Browse files
committed
Add support for AVX-VNNI and update SIMD implementation
1 parent 51dfac4 commit 5969058

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

CMakeLists.txt

+12-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ if(NOT KIWI_CPU_ARCH)
3838
set(KIWI_CPU_ARCH "${KIWI_CPU_ARCH}" PARENT_SCOPE)
3939
endif()
4040

41+
set( AVX_VNNI_SUPPORTED (KIWI_USE_CPUINFO AND
42+
(MSVC OR
43+
(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR
44+
(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
45+
)
46+
))
47+
4148
if(APPLE)
4249
set(CMAKE_OSX_ARCHITECTURES "${KIWI_CPU_ARCH}")
4350
endif()
@@ -120,6 +127,11 @@ if(KIWI_USE_CPUINFO)
120127
)
121128
endif()
122129

130+
if (AVX_VNNI_SUPPORTED)
131+
message(STATUS "AVX-VNNI is supported")
132+
set ( ADDITIONAL_FLAGS ${ADDITIONAL_FLAGS} "-DKIWI_AVX_VNNI_SUPPORTED" )
133+
endif()
134+
123135
if(MSVC)
124136
set ( CMAKE_C_FLAGS_DEBUG "-DDEBUG -DC_FLAGS -Zi -Od /utf-8 /bigobj" )
125137
set ( CMAKE_CXX_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}" )
@@ -161,7 +173,6 @@ if (KIWI_CPU_ARCH MATCHES "x86_64")
161173
src/archImpl/avx512vnni.cpp
162174
)
163175
# If AVX-VNNI is supported (MSVC, GCC 11+ or Clang 11+)
164-
set ( AVX_VNNI_SUPPORTED (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11)))
165176
if (AVX_VNNI_SUPPORTED)
166177
set( CORE_SRCS
167178
${CORE_SRCS}

src/ArchAvailable.h

+8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ namespace kiwi
1414
#if CPUINFO_ARCH_X86_64
1515
static_cast<std::ptrdiff_t>(ArchType::avx512vnni),
1616
static_cast<std::ptrdiff_t>(ArchType::avx512bw),
17+
#ifdef KIWI_AVX_VNNI_SUPPORTED
1718
static_cast<std::ptrdiff_t>(ArchType::avx_vnni),
19+
#endif
1820
static_cast<std::ptrdiff_t>(ArchType::avx2),
1921
static_cast<std::ptrdiff_t>(ArchType::sse4_1),
2022
#endif
@@ -28,7 +30,9 @@ namespace kiwi
2830
#ifdef KIWI_ARCH_X86_64
2931
static_cast<std::ptrdiff_t>(ArchType::avx512vnni),
3032
static_cast<std::ptrdiff_t>(ArchType::avx512bw),
33+
#ifdef KIWI_AVX_VNNI_SUPPORTED
3134
static_cast<std::ptrdiff_t>(ArchType::avx_vnni),
35+
#endif
3236
static_cast<std::ptrdiff_t>(ArchType::avx2),
3337
static_cast<std::ptrdiff_t>(ArchType::sse4_1),
3438
#endif
@@ -48,7 +52,9 @@ namespace kiwi
4852
#if CPUINFO_ARCH_X86_64
4953
static_cast<std::ptrdiff_t>(ArchType::avx512vnni),
5054
static_cast<std::ptrdiff_t>(ArchType::avx512bw),
55+
#ifdef KIWI_AVX_VNNI_SUPPORTED
5156
static_cast<std::ptrdiff_t>(ArchType::avx_vnni),
57+
#endif
5258
static_cast<std::ptrdiff_t>(ArchType::avx2),
5359
static_cast<std::ptrdiff_t>(ArchType::sse4_1)
5460
#endif
@@ -59,7 +65,9 @@ namespace kiwi
5965
#ifdef KIWI_ARCH_X86_64
6066
static_cast<std::ptrdiff_t>(ArchType::avx512vnni),
6167
static_cast<std::ptrdiff_t>(ArchType::avx512bw),
68+
#ifdef KIWI_AVX_VNNI_SUPPORTED
6269
static_cast<std::ptrdiff_t>(ArchType::avx_vnni),
70+
#endif
6371
static_cast<std::ptrdiff_t>(ArchType::avx2),
6472
static_cast<std::ptrdiff_t>(ArchType::sse4_1)
6573
#endif

src/SIMD.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ namespace kiwi
600600
// reduce sum of eight int32_t to one int32_t
601601
__m256i sum = _mm256_hadd_epi32(acc, acc);
602602
sum = _mm256_hadd_epi32(sum, sum);
603-
return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
603+
return _mm_cvtsi128_si32(_mm256_castsi256_si128(sum)) + _mm256_extract_epi32(sum, 4);
604604
}
605605
};
606606

0 commit comments

Comments
 (0)