
Commit 6bc6693

add 64 bit scatter/gather options to the Avx2 operator class
1 parent 8286bb3 commit 6bc6693

File tree

1 file changed: +65 -0 lines changed

shared/libebm/compute/avx2_ebm/avx2_32.cpp (+65 -0)
@@ -106,6 +106,16 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
       return Avx2_32_Int(_mm256_and_si256(m_data, other.m_data));
    }
 
+   friend inline Avx2_32_Int PermuteForInterleaf(const Avx2_32_Int& val) noexcept {
+      // this function permutes the values into positions that the Interleaf function expects
+      // but for any SIMD implementation the positions can be variable as long as they work together
+
+      // TODO: we might be able to move this operation to where we store the packed indexes so that
+      // it doesn't need to execute in the tight loop
+
+      return Avx2_32_Int(_mm256_permutevar8x32_epi32(val.m_data, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7)));
+   }
+
 private:
    inline Avx2_32_Int(const TPack& data) noexcept : m_data(data) {}
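
For reference, _mm256_permutevar8x32_epi32(v, idx) produces out[k] = v[idx[k]], so the constant above swaps elements 2,3 with elements 4,5 across the two 128-bit lanes. A minimal scalar sketch of that semantics (the function name and arrays here are illustrative, not part of the commit):

#include <cstdint>

// scalar model of the permute above: out[k] = v[idx[k]]
static void PermuteForInterleafScalar(const int32_t* v, int32_t* out) {
   const int idx[8] = {0, 1, 4, 5, 2, 3, 6, 7};   // same order as the _mm256_setr_epi32 constant
   for(int k = 0; k < 8; ++k) {
      out[k] = v[idx[k]];
   }
}

// {i0,i1,i2,i3,i4,i5,i6,i7} becomes {i0,i1,i4,i5,i2,i3,i6,i7}, which matches the
// per-lane pair order that the Interleaf function (third hunk below) produces

This lines the packed indexes up with the order in which DoubleStore later consumes them.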

@@ -219,6 +229,22 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       return Avx2_32_Float(_mm256_i32gather_ps(a, i.m_data, 1 << cShift));
    }
 
+   template<int cShift>
+   inline static void DoubleLoad(const T* const a,
+         const Avx2_32_Int& i,
+         Avx2_32_Float& ret1,
+         Avx2_32_Float& ret2) noexcept {
+      // i is treated as signed, so we should only use the lower 31 bits, otherwise we'll read from memory before a
+      static_assert(
+            0 == cShift || 1 == cShift || 2 == cShift || 3 == cShift, "_mm256_i32gather_pd only allows scale factors of 1, 2, 4, or 8");
+      const __m128i i1 = _mm256_extracti128_si256(i.m_data, 0);
+      // we're purposely using the 64-bit double version of this gather because we want to fetch the
+      // gradient and hessian together in one operation
+      ret1 = Avx2_32_Float(_mm256_castpd_ps(_mm256_i32gather_pd(reinterpret_cast<const double*>(a), i1, 1 << cShift)));
+      const __m128i i2 = _mm256_extracti128_si256(i.m_data, 1);
+      ret2 = Avx2_32_Float(_mm256_castpd_ps(_mm256_i32gather_pd(reinterpret_cast<const double*>(a), i2, 1 << cShift)));
+   }
+
    template<int cShift = k_cTypeShift>
    inline void Store(T* const a, const TInt& i) const noexcept {
       alignas(k_cAlignment) TInt::T ints[k_cSIMDPack];
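
The gather above fetches each gradient/hessian pair as a single 64-bit element: the low four indexes feed the first gather and the high four feed the second. A scalar equivalent, under the assumption that each 32-bit index addresses an adjacent (gradient, hessian) float pair (the names here are mine, not the commit's):

#include <cstddef>
#include <cstdint>
#include <cstring>

// illustrative scalar equivalent of DoubleLoad: each index selects one 64-bit
// (gradient, hessian) unit at byte offset index << cShift from the base pointer
template<int cShift>
static void DoubleLoadScalar(const float* a, const int32_t* idx,
      float* out1 /* 8 floats */, float* out2 /* 8 floats */) {
   for(int k = 0; k < 8; ++k) {
      // idx[k] must be non-negative (the "lower 31 bits" note above)
      const char* p = reinterpret_cast<const char*>(a) + (static_cast<size_t>(idx[k]) << cShift);
      float* out = k < 4 ? out1 + 2 * k : out2 + 2 * (k - 4);
      std::memcpy(out, p, 2 * sizeof(float));   // gradient then hessian, as one 64-bit read
   }
}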
@@ -242,6 +268,45 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       *IndexByte(a, static_cast<size_t>(ints[7]) << cShift) = floats[7];
    }
 
+   template<int cShift>
+   inline static void DoubleStore(T* const a,
+         const TInt& i,
+         const Avx2_32_Float& val1,
+         const Avx2_32_Float& val2) noexcept {
+      // i is treated as signed, so we should only use the lower 31 bits, otherwise we'll write to memory before a
+
+      alignas(k_cAlignment) TInt::T ints[k_cSIMDPack];
+      alignas(k_cAlignment) uint64_t floats1[k_cSIMDPack >> 1];
+      alignas(k_cAlignment) uint64_t floats2[k_cSIMDPack >> 1];
+
+      i.Store(ints);
+      val1.Store(reinterpret_cast<T*>(floats1));
+      val2.Store(reinterpret_cast<T*>(floats2));
+
+      // if we shifted ints[] without converting to size_t first the compiler could not
+      // use the built-in index scaling because ints could be 32 bits and shifting
+      // left would chop off some bits, but when converted to size_t first
+      // that isn't an issue, so the compiler can optimize the shift away and incorporate
+      // it into the store assembly instruction
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[0]) << cShift) = floats1[0];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[1]) << cShift) = floats1[1];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[2]) << cShift) = floats1[2];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[3]) << cShift) = floats1[3];
+
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[4]) << cShift) = floats2[0];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[5]) << cShift) = floats2[1];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[6]) << cShift) = floats2[2];
+      *IndexByte(reinterpret_cast<uint64_t*>(a), static_cast<size_t>(ints[7]) << cShift) = floats2[3];
+   }
+
+   inline static void Interleaf(Avx2_32_Float& val0, Avx2_32_Float& val1) noexcept {
+      // this function permutes the values into positions that the PermuteForInterleaf function expects
+      // but for any SIMD implementation the positions can be variable as long as they work together
+      __m256 temp = _mm256_unpacklo_ps(val0.m_data, val1.m_data);
+      val1 = Avx2_32_Float(_mm256_unpackhi_ps(val0.m_data, val1.m_data));
+      val0 = Avx2_32_Float(temp);
+   }
+
    template<typename TFunc>
    friend inline Avx2_32_Float ApplyFunc(const TFunc& func, const Avx2_32_Float& val) noexcept {
       alignas(k_cAlignment) T aTemp[k_cSIMDPack];
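
The three additions are built to compose: PermuteForInterleaf reorders the indexes, Interleaf reorders the values with per-lane unpacks, and DoubleStore scatters the resulting 64-bit pairs. A small self-contained check (my own sketch, not part of the commit) that models the lane semantics of _mm256_unpacklo_ps / _mm256_unpackhi_ps and confirms the permutation constant keeps each (gradient, hessian) pair matched with its index:

#include <cassert>

// scalar model of _mm256_unpacklo_ps / _mm256_unpackhi_ps: within each
// 128-bit lane of four floats, interleave the low (or high) two elements
static void UnpackScalar(const float* a, const float* b, float* lo, float* hi) {
   for(int lane = 0; lane < 2; ++lane) {
      const int s = lane * 4;
      lo[s + 0] = a[s + 0]; lo[s + 1] = b[s + 0];
      lo[s + 2] = a[s + 1]; lo[s + 3] = b[s + 1];
      hi[s + 0] = a[s + 2]; hi[s + 1] = b[s + 2];
      hi[s + 2] = a[s + 3]; hi[s + 3] = b[s + 3];
   }
}

int main() {
   const int perm[8] = {0, 1, 4, 5, 2, 3, 6, 7};   // PermuteForInterleaf's constant
   float g[8], h[8], lo[8], hi[8];
   for(int k = 0; k < 8; ++k) { g[k] = float(k); h[k] = float(k) + 0.5f; }

   UnpackScalar(g, h, lo, hi);   // Interleaf: lo and hi each hold four (g, h) pairs

   // DoubleStore writes pair j of its first value to ints[j] and pair j of its
   // second value to ints[4 + j], where ints = PermuteForInterleaf(indexes),
   // i.e. ints[j] = indexes[perm[j]]; so pair j must carry element perm[j]
   for(int j = 0; j < 4; ++j) {
      assert(lo[2 * j] == g[perm[j]] && lo[2 * j + 1] == h[perm[j]]);
      assert(hi[2 * j] == g[perm[4 + j]] && hi[2 * j + 1] == h[perm[4 + j]]);
   }
   return 0;
}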
