@@ -14,7 +14,69 @@ namespace cp_algo::math::fft {
14
14
static constexpr size_t flen = vftype::size();
15
15
16
16
struct cvector {
17
- static constexpr size_t pre_roots = 1 << 15 ;
17
+ static constexpr size_t pre_roots = 1 << 16 ;
18
+ static constexpr std::array<point, pre_roots> roots = []() {
19
+ std::array<point, pre_roots> res = {};
20
+ for (size_t n = 1 ; n < res.size (); n *= 2 ) {
21
+ for (size_t k = 0 ; k < n; k++) {
22
+ res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
23
+ }
24
+ }
25
+ return res;
26
+ }();
27
+ static constexpr std::array<size_t , pre_roots> eval_args = []() {
28
+ std::array<size_t , pre_roots> res = {};
29
+ for (size_t i = 1 ; i < pre_roots; i++) {
30
+ res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
31
+ }
32
+ return res;
33
+ }();
34
+ static constexpr std::array<point, pre_roots> evalp = []() {
35
+ std::array<point, pre_roots> res = {};
36
+ res[0 ] = 1 ;
37
+ for (size_t n = 1 ; n < pre_roots; n++) {
38
+ res[n] = polar (1 ., std::numbers::pi * ftype (eval_args[n]) / ftype (2 * std::bit_floor (n)));
39
+ }
40
+ return res;
41
+ }();
42
+ static size_t eval_arg (size_t n) {
43
+ if (n < pre_roots) {
44
+ return eval_args[n];
45
+ } else {
46
+ return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
47
+ }
48
+ }
49
+ static auto root (size_t n, size_t k) {
50
+ if (n < pre_roots) {
51
+ return roots[n + k];
52
+ } else {
53
+ return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
54
+ }
55
+ }
56
+ static point eval_point (size_t n) {
57
+ if (n < pre_roots) {
58
+ return evalp[n];
59
+ } else {
60
+ return root (2 * std::bit_floor (n), eval_arg (n));
61
+ }
62
+ }
63
+ static void exec_on_roots (size_t n, size_t m, auto &&callback) {
64
+ point cur;
65
+ point arg = root (n, 1 );
66
+ for (size_t i = 0 ; i < m; i++) {
67
+ if (i % 32 == 0 || n < pre_roots) {
68
+ cur = root (n, i);
69
+ } else {
70
+ cur *= arg;
71
+ }
72
+ callback (i, cur);
73
+ }
74
+ }
75
+ static void exec_on_evals (size_t n, auto &&callback) {
76
+ for (size_t i = 0 ; i < n; i++) {
77
+ callback (i, eval_point (i));
78
+ }
79
+ }
18
80
std::vector<vftype> x, y;
19
81
cvector (size_t n) {
20
82
n = std::max (flen, std::bit_ceil (n));
@@ -80,57 +142,6 @@ namespace cp_algo::math::fft {
80
142
}
81
143
checkpoint (" dot" );
82
144
}
83
- static const cvector roots, evalp;
84
- static std::array<size_t , pre_roots> eval_args;
85
-
86
- template <bool precalc = false >
87
- static size_t eval_arg (size_t n) {
88
- if (n < pre_roots && !precalc) {
89
- return eval_args[n];
90
- } else if (n == 0 ) {
91
- return 0 ;
92
- } else {
93
- return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
94
- }
95
- }
96
- template < bool precalc = false >
97
- static auto root (size_t n, size_t k) {
98
- if (n < pre_roots && !precalc) {
99
- return roots.get (n + k);
100
- } else {
101
- return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
102
- }
103
- }
104
- template < bool precalc = false >
105
- static point eval_point (size_t n) {
106
- if (n < pre_roots && !precalc) {
107
- return evalp.get (n);
108
- } else if (n == 0 ) {
109
- return 1 ;
110
- } else {
111
- size_t N = std::bit_floor (n);
112
- return root (2 * N, eval_arg (n));
113
- }
114
- }
115
-
116
- template <bool precalc = false >
117
- static void exec_on_roots (size_t n, size_t m, auto &&callback) {
118
- point cur;
119
- point arg = root<precalc>(n, 1 );
120
- for (size_t i = 0 ; i < m; i++) {
121
- if (precalc || i % 32 == 0 || n < pre_roots) {
122
- cur = root<precalc>(n, i);
123
- } else {
124
- cur *= arg;
125
- }
126
- callback (i, cur);
127
- }
128
- }
129
- static void exec_on_evals (size_t n, auto &&callback) {
130
- for (size_t i = 0 ; i < n; i++) {
131
- callback (i, eval_point (i));
132
- }
133
- }
134
145
135
146
void ifft () {
136
147
size_t n = size ();
@@ -167,28 +178,5 @@ namespace cp_algo::math::fft {
167
178
checkpoint (" fft" );
168
179
}
169
180
};
170
- std::array<size_t , cvector::pre_roots> cvector::eval_args = []() {
171
- std::array<size_t , pre_roots> res = {};
172
- for (size_t i = 1 ; i < pre_roots; i++) {
173
- res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
174
- }
175
- return res;
176
- }();
177
- const cvector cvector::roots = []() {
178
- cvector res (pre_roots);
179
- for (size_t n = 1 ; n < res.size (); n *= 2 ) {
180
- cvector::exec_on_roots<true >(n, n, [&](size_t k, auto rt) {
181
- res.set (n + k, rt);
182
- });
183
- }
184
- return res;
185
- }();
186
- const cvector cvector::evalp = []() {
187
- cvector res (pre_roots);
188
- for (size_t n = 0 ; n < res.size (); n++) {
189
- res.set (n, cvector::eval_point<true >(n));
190
- }
191
- return res;
192
- }();
193
181
}
194
182
#endif // CP_ALGO_MATH_CVECTOR_HPP
0 commit comments