|
7 | 7 | #include <arcticdb/util/vector_common.hpp>
|
8 | 8 |
|
9 | 9 | namespace arcticdb {
|
10 |
| -#ifndef _WIN32 |
11 | 10 |
|
12 |
| -template<typename T> |
13 |
| -class FloatMinFinder { |
| 11 | +#if HAS_VECTOR_EXTENSIONS |
| 12 | + |
| 13 | +template<typename T, typename Comparator> |
| 14 | +class FloatExtremumFinder { |
14 | 15 | static_assert(is_supported_float<T>::value, "Type must be float or double");
|
15 | 16 | static_assert(std::is_floating_point_v<T>, "Type must be floating point");
|
16 |
| - |
17 | 17 | public:
|
18 | 18 | static T find(const T* data, size_t n) {
|
| 19 | + if (n == 0) |
| 20 | + return Comparator::identity(); |
19 | 21 | using vec_t = vector_type<T>;
|
20 |
| - |
21 |
| - vec_t vmin; |
22 |
| - for(size_t i = 0; i < sizeof(vec_t)/sizeof(T); i++) { |
23 |
| - reinterpret_cast<T*>(&vmin)[i] = std::numeric_limits<T>::infinity(); |
24 |
| - } |
| 22 | + constexpr size_t lane_count = sizeof(vec_t) / sizeof(T); |
| 23 | + vec_t vext; |
| 24 | + for (size_t i = 0; i < lane_count; i++) |
| 25 | + reinterpret_cast<T*>(&vext)[i] = Comparator::identity(); |
25 | 26 |
|
26 | 27 | const vec_t* vdata = reinterpret_cast<const vec_t*>(data);
|
27 |
| - const size_t elements_per_vector = sizeof(vec_t) / sizeof(T); |
28 |
| - const size_t vlen = n / elements_per_vector; |
29 |
| - |
30 |
| - for(size_t i = 0; i < vlen; i++) { |
| 28 | + size_t vlen = n / lane_count; |
| 29 | + for (size_t i = 0; i < vlen; i++) { |
31 | 30 | vec_t v = vdata[i];
|
32 |
| - vmin = (v < vmin) ? v : vmin; |
| 31 | + if constexpr (Comparator::is_min) |
| 32 | + vext = (v < vext) ? v : vext; |
| 33 | + else |
| 34 | + vext = (v > vext) ? v : vext; |
33 | 35 | }
|
34 |
| - |
35 |
| - T min_val = std::numeric_limits<T>::infinity(); |
36 |
| - const T* min_arr = reinterpret_cast<const T*>(&vmin); |
37 |
| - for(size_t i = 0; i < elements_per_vector; i++) { |
38 |
| - if (min_arr[i] == min_arr[i]) { // Not NaN |
39 |
| - min_val = std::min(min_val, min_arr[i]); |
40 |
| - } |
| 36 | + T result = Comparator::identity(); |
| 37 | + const T* lanes = reinterpret_cast<const T*>(&vext); |
| 38 | + for (size_t i = 0; i < lane_count; i++) { |
| 39 | + if (lanes[i] == lanes[i]) |
| 40 | + result = Comparator::compare(lanes[i], result); |
41 | 41 | }
|
42 |
| - |
43 |
| - const T* remain = data + (vlen * elements_per_vector); |
44 |
| - for(size_t i = 0; i < n % elements_per_vector; i++) { |
45 |
| - if (remain[i] == remain[i]) { // Not NaN |
46 |
| - min_val = std::min(min_val, remain[i]); |
47 |
| - } |
| 42 | + const T* remain = data + (vlen * lane_count); |
| 43 | + size_t remain_count = n % lane_count; |
| 44 | + for (size_t i = 0; i < remain_count; i++) { |
| 45 | + if (remain[i] == remain[i]) |
| 46 | + result = Comparator::compare(remain[i], result); |
48 | 47 | }
|
49 |
| - |
50 |
| - return min_val; |
| 48 | + return result; |
51 | 49 | }
|
52 | 50 | };
|
53 | 51 |
|
54 | 52 | template<typename T>
|
55 |
| -class FloatMaxFinder { |
56 |
| - static_assert(is_supported_float<T>::value, "Type must be float or double"); |
57 |
| - static_assert(std::is_floating_point_v<T>, "Type must be floating point"); |
58 |
| - |
59 |
| -public: |
60 |
| - static T find(const T* data, size_t n) { |
61 |
| - using vec_t = vector_type<T>; |
62 |
| - |
63 |
| - vec_t vmax; |
64 |
| - for(size_t i = 0; i < sizeof(vec_t)/sizeof(T); i++) { |
65 |
| - reinterpret_cast<T*>(&vmax)[i] = -std::numeric_limits<T>::infinity(); |
66 |
| - } |
67 |
| - |
68 |
| - const vec_t* vdata = reinterpret_cast<const vec_t*>(data); |
69 |
| - const size_t elements_per_vector = sizeof(vec_t) / sizeof(T); |
70 |
| - const size_t vlen = n / elements_per_vector; |
71 |
| - |
72 |
| - for(size_t i = 0; i < vlen; i++) { |
73 |
| - vec_t v = vdata[i]; |
74 |
| - vmax = (v > vmax) ? v : vmax; |
75 |
| - } |
76 |
| - |
77 |
| - T max_val = -std::numeric_limits<T>::infinity(); |
78 |
| - const T* max_arr = reinterpret_cast<const T*>(&vmax); |
79 |
| - for(size_t i = 0; i < elements_per_vector; i++) { |
80 |
| - if (max_arr[i] == max_arr[i]) { // Not NaN |
81 |
| - max_val = std::max(max_val, max_arr[i]); |
82 |
| - } |
83 |
| - } |
84 |
| - |
85 |
| - const T* remain = data + (vlen * elements_per_vector); |
86 |
| - for(size_t i = 0; i < n % elements_per_vector; i++) { |
87 |
| - if (remain[i] == remain[i]) { // Not NaN |
88 |
| - max_val = std::max(max_val, remain[i]); |
89 |
| - } |
90 |
| - } |
| 53 | +struct FloatMinComparator { |
| 54 | + static constexpr bool is_min = true; |
| 55 | + static T identity() { return std::numeric_limits<T>::infinity(); } |
| 56 | + static T compare(T a, T b) { return std::min(a, b); } |
| 57 | +}; |
91 | 58 |
|
92 |
| - return max_val; |
93 |
| - } |
| 59 | +template<typename T> |
| 60 | +struct FloatMaxComparator { |
| 61 | + static constexpr bool is_min = false; |
| 62 | + static T identity() { return -std::numeric_limits<T>::infinity(); } |
| 63 | + static T compare(T a, T b) { return std::max(a, b); } |
94 | 64 | };
|
95 | 65 |
|
96 | 66 | template<typename T>
|
97 |
| -T find_float_min(const T *data, size_t n) { |
98 |
| - return FloatMinFinder<T>::find(data, n); |
| 67 | +T find_float_min(const T* data, size_t n) { |
| 68 | + return FloatExtremumFinder<T, FloatMinComparator<T>>::find(data, n); |
99 | 69 | }
|
100 | 70 |
|
101 | 71 | template<typename T>
|
102 |
| -T find_float_max(const T *data, size_t n) { |
103 |
| - return FloatMaxFinder<T>::find(data, n); |
| 72 | +T find_float_max(const T* data, size_t n) { |
| 73 | + return FloatExtremumFinder<T, FloatMaxComparator<T>>::find(data, n); |
104 | 74 | }
|
105 | 75 |
|
106 | 76 | #else
|
107 | 77 |
|
108 | 78 | template<typename T>
|
109 |
| -typename std::enable_if<std::is_integral<T>::value, T>::type |
| 79 | +typename std::enable_if<std::is_floating_point<T>::value, T>::type |
110 | 80 | find_float_min(const T *data, size_t n) {
|
111 | 81 | return *std::min_element(data, data + n);
|
112 | 82 | }
|
|
0 commit comments