1
1
#ifndef CP_ALGO_MATH_CVECTOR_HPP
2
2
#define CP_ALGO_MATH_CVECTOR_HPP
3
- #include < algorithm>
4
- #include < cassert>
5
- #include < complex>
6
- #include < vector>
3
+ #include " ../util/complex.hpp"
4
+ #include " ../util/checkpoint.hpp"
5
+ #include < experimental/simd>
7
6
#include < ranges>
8
7
namespace cp_algo ::math::fft {
9
8
using ftype = double ;
10
- static constexpr size_t bytes = 32 ;
11
- static constexpr size_t flen = bytes / sizeof (ftype);
12
- using point = std::complex<ftype>;
13
- using vftype [[gnu::vector_size(bytes)]] = ftype;
14
- using vpoint = std::complex<vftype>;
9
+ using point = complex<ftype>;
10
+ using vftype = std::experimental::native_simd<ftype>;
11
+ using vpoint = complex<vftype>;
12
+ static constexpr size_t flen = vftype::size();
15
13
16
- #define WITH_IV (...) \
17
- [&]<size_t ... i>(std::index_sequence<i...>) { \
18
- return __VA_ARGS__; \
19
- }(std::make_index_sequence<flen>());
20
-
21
- template <typename ft>
22
- constexpr ft to_ft (auto x) {
23
- return ft{} + x;
24
- }
25
- template <typename pt>
26
- constexpr pt to_pt (point r) {
27
- using ft = std::conditional_t <std::is_same_v<point, pt>, ftype, vftype>;
28
- return {to_ft<ft>(r.real ()), to_ft<ft>(r.imag ())};
29
- }
30
14
struct cvector {
31
- static constexpr size_t pre_roots = 1 << 17 ;
15
+ static constexpr size_t pre_roots = 1 << 15 ;
32
16
std::vector<vftype> x, y;
33
17
cvector (size_t n) {
34
18
n = std::max (flen, std::bit_ceil (n));
35
19
x.resize (n / flen);
36
20
y.resize (n / flen);
21
+ checkpoint (" cvector create" );
37
22
}
38
23
template <class pt = point>
39
24
void set (size_t k, pt t) {
@@ -60,132 +45,147 @@ namespace cp_algo::math::fft {
60
45
size_t size () const {
61
46
return flen * std::size (x);
62
47
}
48
+
49
+
50
+ static auto dot_block (size_t k, cvector const & A, cvector const & B) {
51
+ auto rt = eval_point (k / flen / 2 );
52
+ if (k / flen % 2 ) {
53
+ rt = -rt;
54
+ }
55
+ auto [Bvx, Bvy] = B.vget (k);
56
+ auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (real (rt), imag (rt));
57
+ auto [Ax, Ay] = A.vget (k);
58
+ ftype Bx[2 * flen], By[2 * flen];
59
+ Bvx.copy_to (Bx + flen, std::experimental::vector_aligned);
60
+ Bvy.copy_to (By + flen, std::experimental::vector_aligned);
61
+ Brvx.copy_to (Bx, std::experimental::vector_aligned);
62
+ Brvy.copy_to (By, std::experimental::vector_aligned);
63
+ vpoint res = {0 , 0 };
64
+ for (size_t i = 0 ; i < flen; i++) {
65
+ vftype Bsx, Bsy;
66
+ Bsx.copy_from (Bx + flen - i, std::experimental::element_aligned);
67
+ Bsy.copy_from (By + flen - i, std::experimental::element_aligned);
68
+ res += vpoint (Ax[i], Ay[i]) * vpoint (Bsx, Bsy);
69
+ }
70
+ return res;
71
+ }
72
+
63
73
void dot (cvector const & t) {
64
- size_t n = size ();
74
+ size_t n = this -> size ();
65
75
for (size_t k = 0 ; k < n; k += flen) {
66
- set (k, get<vpoint>(k) * t. get <vpoint>(k ));
76
+ set (k, dot_block (k, * this , t ));
67
77
}
78
+ checkpoint (" dot" );
68
79
}
69
- static const cvector roots;
70
- template <class pt = point>
71
- static pt root (size_t n, size_t k) {
72
- if (n < pre_roots) {
73
- return roots.get <pt>(n + k);
80
+ static const cvector roots, evalp;
81
+ static std::array<size_t , pre_roots> eval_args;
82
+
83
+ template <bool precalc = false >
84
+ static size_t eval_arg (size_t n) {
85
+ if (n < pre_roots && !precalc) {
86
+ return eval_args[n];
87
+ } else if (n == 0 ) {
88
+ return 0 ;
74
89
} else {
75
- auto arg = std::numbers::pi / ftype (n);
76
- if constexpr (std::is_same_v<pt, point>) {
77
- return {cos (ftype (k) * arg), sin (ftype (k) * arg)};
78
- } else {
79
- return WITH_IV (pt{vftype{cos (ftype (k + i) * arg)...},
80
- vftype{sin (ftype (k + i) * arg)...}});
81
- }
90
+ return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
82
91
}
83
92
}
84
- template <class pt = point>
93
+ template < bool precalc = false >
94
+ static auto root (size_t n, size_t k) {
95
+ if (n < pre_roots && !precalc) {
96
+ return roots.get (n + k);
97
+ } else {
98
+ return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
99
+ }
100
+ }
101
+ template < bool precalc = false >
102
+ static point eval_point (size_t n) {
103
+ if (n < pre_roots && !precalc) {
104
+ return evalp.get (n);
105
+ } else if (n == 0 ) {
106
+ return 1 ;
107
+ } else {
108
+ size_t N = std::bit_floor (n);
109
+ return root (2 * N, eval_arg (n));
110
+ }
111
+ }
112
+
113
+ template <bool precalc = false >
85
114
static void exec_on_roots (size_t n, size_t m, auto &&callback) {
86
- size_t step = sizeof (pt) / sizeof (point);
87
- pt cur;
88
- pt arg = to_pt<pt>(root<point>(n, step));
89
- for (size_t i = 0 ; i < m; i += step) {
90
- if (i % 64 == 0 || n < pre_roots) {
91
- cur = root<pt>(n, i);
115
+ point cur;
116
+ point arg = root<precalc>(n, 1 );
117
+ for (size_t i = 0 ; i < m; i++) {
118
+ if (precalc || i % 32 == 0 || n < pre_roots) {
119
+ cur = root<precalc>(n, i);
92
120
} else {
93
121
cur *= arg;
94
122
}
95
123
callback (i, cur);
96
124
}
97
125
}
126
+ static void exec_on_evals (size_t n, auto &&callback) {
127
+ for (size_t i = 0 ; i < n; i++) {
128
+ callback (i, eval_point (i));
129
+ }
130
+ }
98
131
99
132
void ifft () {
100
133
size_t n = size ();
101
- for (size_t i = 1 ; i < n; i *= 2 ) {
102
- for (size_t j = 0 ; j < n; j += 2 * i) {
103
- auto butterfly = [&]<class pt >(size_t k, pt rt) {
104
- k += j;
105
- auto t = get<pt>(k + i) * conj (rt);
106
- set (k + i, get<pt>(k) - t);
107
- set (k, get<pt>(k) + t);
108
- };
109
- if (2 * i <= flen) {
110
- exec_on_roots (i, i, butterfly);
111
- } else {
112
- exec_on_roots<vpoint>(i, i, butterfly);
134
+ for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
135
+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
136
+ k *= 2 * i;
137
+ vpoint vrt = {real (rt), imag (rt)};
138
+ for (size_t j = k; j < k + i; j += flen) {
139
+ auto A = get<vpoint>(j) + get<vpoint>(j + i);
140
+ auto B = get<vpoint>(j) - get<vpoint>(j + i);
141
+ set (j, A);
142
+ set (j + i, B * conj (vrt));
113
143
}
114
- }
144
+ });
115
145
}
146
+ checkpoint (" ifft" );
116
147
for (size_t k = 0 ; k < n; k += flen) {
117
- set (k, get<vpoint>(k) /= to_pt<vpoint> (ftype (n) ));
148
+ set (k, get<vpoint>(k) /= (ftype)(n / flen ));
118
149
}
119
150
}
120
151
void fft () {
121
152
size_t n = size ();
122
- for (size_t i = n / 2 ; i >= 1 ; i /= 2 ) {
123
- for (size_t j = 0 ; j < n; j += 2 * i) {
124
- auto butterfly = [&]<class pt >(size_t k, pt rt) {
125
- k += j;
126
- auto A = get<pt>(k) + get<pt>(k + i);
127
- auto B = get<pt>(k) - get<pt>(k + i);
128
- set (k, A);
129
- set (k + i, B * rt);
130
- };
131
- if (2 * i <= flen) {
132
- exec_on_roots (i, i, butterfly);
133
- } else {
134
- exec_on_roots<vpoint>(i, i, butterfly);
153
+ for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
154
+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
155
+ k *= 2 * i;
156
+ vpoint vrt = {real (rt), imag (rt)};
157
+ for (size_t j = k; j < k + i; j += flen) {
158
+ auto t = get<vpoint>(j + i) * vrt;
159
+ set (j + i, get<vpoint>(j) - t);
160
+ set (j, get<vpoint>(j) + t);
135
161
}
136
- }
162
+ });
137
163
}
164
+ checkpoint (" fft" );
138
165
}
139
166
};
167
+ std::array<size_t , cvector::pre_roots> cvector::eval_args = []() {
168
+ std::array<size_t , pre_roots> res = {};
169
+ for (size_t i = 1 ; i < pre_roots; i++) {
170
+ res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
171
+ }
172
+ return res;
173
+ }();
140
174
const cvector cvector::roots = []() {
141
175
cvector res (pre_roots);
142
176
for (size_t n = 1 ; n < res.size (); n *= 2 ) {
143
- auto base = std::polar (1 ., std::numbers::pi / ftype (n));
144
- point cur = 1 ;
145
- for (size_t k = 0 ; k < n; k++) {
146
- if ((k & 15 ) == 0 ) {
147
- cur = std::polar (1 ., std::numbers::pi * ftype (k) / ftype (n));
148
- }
149
- res.set (n + k, cur);
150
- cur *= base;
151
- }
177
+ cvector::exec_on_roots<true >(n, n, [&](size_t k, auto rt) {
178
+ res.set (n + k, rt);
179
+ });
152
180
}
153
181
return res;
154
182
}();
155
-
156
- template <typename base>
157
- struct dft {
158
- cvector A;
159
-
160
- dft (std::vector<base> const & a, size_t n): A(n) {
161
- for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
162
- A.set (i, a[i]);
163
- }
164
- if (n) {
165
- A.fft ();
166
- }
167
- }
168
-
169
- std::vector<base> operator *= (dft const & B) {
170
- assert (A.size () == B.A .size ());
171
- size_t n = A.size ();
172
- if (!n) {
173
- return std::vector<base>();
174
- }
175
- A.dot (B.A );
176
- A.ifft ();
177
- std::vector<base> res (n);
178
- for (size_t k = 0 ; k < n; k++) {
179
- res[k] = A.get (k);
180
- }
181
- return res;
182
- }
183
-
184
- auto operator * (dft const & B) const {
185
- return dft (*this ) *= B;
183
+ const cvector cvector::evalp = []() {
184
+ cvector res (pre_roots);
185
+ for (size_t n = 0 ; n < res.size (); n++) {
186
+ res.set (n, cvector::eval_point<true >(n));
186
187
}
187
-
188
- point operator [](int i) const {return A.get (i);}
189
- };
188
+ return res;
189
+ }();
190
190
}
191
191
#endif // CP_ALGO_MATH_CVECTOR_HPP
0 commit comments