@@ -13,67 +13,40 @@ namespace cp_algo::math::fft {
13
13
// Scalar complex number over the floating type ftype.
using point = complex<ftype>;
14
14
// SIMD vector of ftype lanes, `bytes` bytes wide (GCC/Clang vector_size extension).
using vftype [[gnu::vector_size(bytes)]] = ftype;
15
15
// Complex number whose real and imaginary parts are each a whole SIMD vector.
using vpoint = complex<vftype>;
16
- static constexpr vftype fz = {};
16
// Value-initialized vector; this diff renames it from fz to vz. The code in the
// visible hunks uses `vz + scalar` to build a vector from a scalar.
+ static constexpr vftype vz = {};
17
// vpoint with real part vz and imaginary part vz + 1. It replaces the inline
// `vpoint (fz, fz + 1)` expressions removed elsewhere in this diff.
+ static constexpr vpoint vi = {vz, vz + 1 };
17
18
18
19
struct cvector {
19
- std::vector<vftype> x, y ;
20
+ std::vector<vpoint> r ;
20
21
// Constructs storage for n scalar entries.
// n is first rounded up to a power of two (std::bit_ceil) and clamped to be at
// least flen, then n / flen vector slots are allocated.
// Diff note: the `-` lines resize the old separate x / y arrays; the `+` line
// resizes the single array r of vpoint that replaces them.
cvector (size_t n) {
21
22
n = std::max (flen, std::bit_ceil (n));
22
- x.resize (n / flen);
23
- y.resize (n / flen);
23
+ r.resize (n / flen);
24
24
// checkpoint is defined outside this view; presumably a profiling/timing marker.
checkpoint (" cvector create" );
25
25
}
26
+
27
// Returns the vector slot that contains scalar index k, i.e. r[k / flen].
// The non-const overload returns a reference so the slot can be modified in place.
+ vpoint& at (size_t k) {return r[k / flen];}
28
// The const overload returns the slot by value.
+ vpoint at (size_t k) const {return r[k / flen];}
26
29
// Stores t at scalar index k.
//   pt == point : writes only lane k % flen of slot k / flen (real and imaginary
//                 parts separately).
//   otherwise   : assigns t to the whole slot containing k.
// Diff note: `-` lines are the old x / y implementation, `+` lines the new one on r.
template <class pt = point>
27
30
void set (size_t k, pt t) {
28
31
if constexpr (std::is_same_v<pt, point>) {
29
- x [k / flen][k % flen] = real (t);
30
- y [k / flen][k % flen] = imag (t);
32
// real (...) / imag (...) are used as lvalues here, so they must return
// references to the vector components of the slot.
+ real (r [k / flen]) [k % flen] = real (t);
33
+ imag (r [k / flen]) [k % flen] = imag (t);
31
34
} else {
32
- x[k / flen] = real (t);
33
- y[k / flen] = imag (t);
35
+ at (k) = t;
34
36
}
35
37
}
36
38
// Reads the value at scalar index k.
//   pt == point : returns the single complex scalar taken from lane k % flen of
//                 slot k / flen.
//   otherwise   : returns the whole slot containing k (a copy, via the const at).
// Diff note: `-` lines are the old x / y implementation, `+` lines the new one on r.
template <class pt = point>
37
39
pt get (size_t k) const {
38
40
if constexpr (std::is_same_v<pt, point>) {
39
- return {x [k / flen][k % flen], y [k / flen][k % flen]};
41
+ return {real (r [k / flen]) [k % flen], imag (r [k / flen]) [k % flen]};
40
42
} else {
41
- return {x[k / flen], y[k / flen]} ;
43
+ return at (k) ;
42
44
}
43
45
}
44
// Removed by this diff: thin wrapper that returned get<vpoint>(k).
// The call sites visible in later hunks switch from vget (k) to at (k).
- vpoint vget (size_t k) const {
45
- return get<vpoint>(k);
46
- }
47
46
48
47
// Number of scalar entries held: flen lanes per slot times the number of slots.
// Diff note: the slot count is now taken from r instead of the removed array x.
size_t size () const {
49
- return flen * std::size (x );
48
+ return flen * std::size (r );
50
49
}
51
-
52
// Removed here; the same definitions are re-added unchanged near the end of the
// struct later in this diff.
// pre_roots is the size (1 << 16 entries) of the compile-time lookup tables.
- static constexpr size_t pre_roots = 1 << 16 ;
53
// Table built at compile time: for every power of two n < pre_roots and every
// k in [0, n), entry n + k is polar (1, pi * k / n). Entry 0 stays value-initialized.
- static constexpr std::array<point, pre_roots> roots = []() {
54
- std::array<point, pre_roots> res = {};
55
- for (size_t n = 1 ; n < res.size (); n *= 2 ) {
56
- for (size_t k = 0 ; k < n; k++) {
57
- res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
58
- }
59
- }
60
- return res;
61
- }();
62
// Removed here; re-added unchanged near the end of the struct later in this diff.
// Table of bit-reversed indices: entry i is i with its bit_width (i) significant
// bits reversed (e.g. 4 -> 1, 5 -> 5, 6 -> 3); entry 0 is 0.
- static constexpr std::array<size_t , pre_roots> eval_args = []() {
63
- std::array<size_t , pre_roots> res = {};
64
- for (size_t i = 1 ; i < pre_roots; i++) {
65
// `<<` binds tighter than `|`: the low bit of i is moved to the top bit
// position and OR-ed with the already computed entry for i >> 1.
- res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
66
- }
67
- return res;
68
- }();
69
// Removed here; re-added unchanged near the end of the struct later in this diff.
// Table of unit-modulus points: entry 0 is 1, and entry n (n >= 1) is
// polar (1, pi * eval_args[n] / (2 * bit_floor (n))).
- static constexpr std::array<point, pre_roots> evalp = []() {
70
- std::array<point, pre_roots> res = {};
71
- res[0 ] = 1 ;
72
- for (size_t n = 1 ; n < pre_roots; n++) {
73
- res[n] = polar (1 ., std::numbers::pi * ftype (eval_args[n]) / ftype (2 * std::bit_floor (n)));
74
- }
75
- return res;
76
- }();
77
50
static size_t eval_arg (size_t n) {
78
51
if (n < pre_roots) {
79
52
return eval_args[n];
@@ -118,15 +91,17 @@ namespace cp_algo::math::fft {
118
91
if (k / flen % 2 ) {
119
92
rt = -rt;
120
93
}
121
- auto [Bvx, Bvy] = B.vget (k);
122
- auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (fz + real (rt), fz + imag (rt));
123
- auto [Ax, Ay] = A.vget (k);
124
- vftype Bx[2 ] = {Brvx, Bvx}, By[2 ] = {Brvy, Bvy};
125
- vpoint res = {fz, fz};
94
+ auto [Ax, Ay] = A.at (k);
95
+ auto Bv = B.at (k);
96
+ vpoint res = {vz, vz};
126
97
for (size_t i = 0 ; i < flen; i++) {
127
- auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128
- auto Bsy = (vftype*)((ftype*)By + flen - i);
129
- res += vpoint (fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
98
+ res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
99
+ real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
100
+ imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
101
+ auto x = real (Bv)[0 ];
102
+ auto y = imag (Bv)[0 ];
103
+ real (Bv)[0 ] = x * real (rt) - y * imag (rt);
104
+ imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
130
105
}
131
106
return res;
132
107
}
@@ -145,37 +120,36 @@ namespace cp_algo::math::fft {
145
120
if (4 * i <= n) { // radix-4
146
121
exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
147
122
k *= 4 * i;
148
- vpoint v1 = {fz + real (rt), fz - imag (rt)};
123
+ vpoint v1 = {vz + real (rt), vz - imag (rt)};
149
124
vpoint v2 = v1 * v1;
150
125
vpoint v3 = v1 * v2;
151
126
for (size_t j = k; j < k + i; j += flen) {
152
- auto A = get<vpoint> (j);
153
- auto B = get<vpoint> (j + i);
154
- auto C = get<vpoint> (j + 2 * i);
155
- auto D = get<vpoint> (j + 3 * i);
156
- set (j , (A + B + C + D) );
157
- set (j + 2 * i, (A + B - C - D) * v2) ;
158
- set (j + i, (A - B - vpoint (fz, fz + 1 ) * (C - D)) * v1) ;
159
- set (j + 3 * i, (A - B + vpoint (fz, fz + 1 ) * (C - D)) * v3) ;
127
+ auto A = at (j);
128
+ auto B = at (j + i);
129
+ auto C = at (j + 2 * i);
130
+ auto D = at (j + 3 * i);
131
+ at (j) = (A + B + C + D);
132
+ at (j + 2 * i) = (A + B - C - D) * v2;
133
+ at (j + i) = (A - B - vi * (C - D)) * v1;
134
+ at (j + 3 * i) = (A - B + vi * (C - D)) * v3;
160
135
}
161
136
});
162
137
i *= 2 ;
163
138
} else { // radix-2 fallback
164
139
exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
165
140
k *= 2 * i;
166
- vpoint cvrt = {fz + real (rt), fz - imag (rt)};
141
+ vpoint cvrt = {vz + real (rt), vz - imag (rt)};
167
142
for (size_t j = k; j < k + i; j += flen) {
168
- auto A = get<vpoint>(j) + get<vpoint>(j + i);
169
- auto B = get<vpoint>(j) - get<vpoint>(j + i);
170
- set (j, A);
171
- set (j + i, B * cvrt);
143
+ auto B = at (j) - at (j + i);
144
+ at (j) += at (j + i);
145
+ at (j + i) = B * cvrt;
172
146
}
173
147
});
174
148
}
175
149
}
176
150
checkpoint (" ifft" );
177
151
for (size_t k = 0 ; k < n; k += flen) {
178
- set (k, get<vpoint>(k) /= fz + (ftype)(n / flen));
152
+ set (k, get<vpoint>(k) /= vz + (ftype)(n / flen));
179
153
}
180
154
}
181
155
void fft () {
@@ -185,34 +159,59 @@ namespace cp_algo::math::fft {
185
159
i /= 2 ;
186
160
exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
187
161
k *= 4 * i;
188
- vpoint v1 = {fz + real (rt), fz + imag (rt)};
162
+ vpoint v1 = {vz + real (rt), vz + imag (rt)};
189
163
vpoint v2 = v1 * v1;
190
164
vpoint v3 = v1 * v2;
191
165
for (size_t j = k; j < k + i; j += flen) {
192
- auto A = get<vpoint> (j);
193
- auto B = get<vpoint> (j + i) * v1;
194
- auto C = get<vpoint> (j + 2 * i) * v2;
195
- auto D = get<vpoint> (j + 3 * i) * v3;
196
- set (j , (A + C) + (B + D) );
197
- set (j + i, (A + C) - (B + D) );
198
- set (j + 2 * i, (A - C) + vpoint (fz, fz + 1 ) * (B - D) );
199
- set (j + 3 * i, (A - C) - vpoint (fz, fz + 1 ) * (B - D) );
166
+ auto A = at (j);
167
+ auto B = at (j + i) * v1;
168
+ auto C = at (j + 2 * i) * v2;
169
+ auto D = at (j + 3 * i) * v3;
170
+ at (j) = (A + C) + (B + D);
171
+ at (j + i) = (A + C) - (B + D);
172
+ at (j + 2 * i) = (A - C) + vi * (B - D);
173
+ at (j + 3 * i) = (A - C) - vi * (B - D);
200
174
}
201
175
});
202
176
} else { // radix-2 fallback
203
177
exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
204
178
k *= 2 * i;
205
- vpoint vrt = {fz + real (rt), fz + imag (rt)};
179
+ vpoint vrt = {vz + real (rt), vz + imag (rt)};
206
180
for (size_t j = k; j < k + i; j += flen) {
207
- auto t = get<vpoint> (j + i) * vrt;
208
- set (j + i, get<vpoint> (j) - t) ;
209
- set (j, get<vpoint>(j ) + t) ;
181
+ auto t = at (j + i) * vrt;
182
+ at (j + i) = at (j) - t;
183
+ at (j ) += t ;
210
184
}
211
185
});
212
186
}
213
187
}
214
188
checkpoint (" fft" );
215
189
}
190
// Number of entries in each compile-time lookup table below (1 << 16).
+ static constexpr size_t pre_roots = 1 << 16 ;
191
// Table built at compile time by an immediately invoked lambda.
// For every power of two n < pre_roots and every k in [0, n), entry n + k holds
// polar (1, pi * k / n). Entry 0 is never written and stays value-initialized.
// polar is declared outside this view; presumably it builds a complex number
// from magnitude and angle.
+ static constexpr std::array<point, pre_roots> roots = []() {
192
+ std::array<point, pre_roots> res = {};
193
// n runs over powers of two, so the blocks [n, 2n) tile indices 1 .. pre_roots - 1.
+ for (size_t n = 1 ; n < res.size (); n *= 2 ) {
194
+ for (size_t k = 0 ; k < n; k++) {
195
+ res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
196
+ }
197
+ }
198
+ return res;
199
+ }();
200
// Compile-time table of bit-reversed indices.
// Entry i is i with its bit_width (i) significant bits reversed
// (e.g. 1 -> 1, 2 -> 1, 4 -> 1, 5 -> 5, 6 -> 3); entry 0 is 0.
// eval_arg (n) returns eval_args[n] for n < pre_roots.
+ static constexpr std::array<size_t , pre_roots> eval_args = []() {
201
+ std::array<size_t , pre_roots> res = {};
202
+ for (size_t i = 1 ; i < pre_roots; i++) {
203
// `<<` binds tighter than `|`: the low bit of i is moved to the top bit
// position (bit_width (i) - 1) and OR-ed with the entry for i >> 1, which
// already holds the reversal of the remaining higher bits.
+ res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
204
+ }
205
+ return res;
206
+ }();
207
// Compile-time table of points derived from eval_args.
// Entry 0 is 1; entry n (n >= 1) is polar (1, pi * eval_args[n] / (2 * bit_floor (n))),
// where bit_floor (n) is the largest power of two not exceeding n.
// polar is declared outside this view; presumably it builds a complex number
// from magnitude and angle.
+ static constexpr std::array<point, pre_roots> evalp = []() {
208
+ std::array<point, pre_roots> res = {};
209
// Index 0 is set explicitly; the loop below starts at n = 1.
+ res[0 ] = 1 ;
210
+ for (size_t n = 1 ; n < pre_roots; n++) {
211
+ res[n] = polar (1 ., std::numbers::pi * ftype (eval_args[n]) / ftype (2 * std::bit_floor (n)));
212
+ }
213
+ return res;
214
+ }();
216
215
};
217
216
}
218
217
#endif // CP_ALGO_MATH_CVECTOR_HPP
0 commit comments