8
8
namespace stdx = std::experimental;
9
9
namespace cp_algo ::math::fft {
10
10
using ftype = double ;
11
+ static constexpr size_t bytes = 32 ;
12
+ static constexpr size_t flen = bytes / sizeof (ftype);
11
13
using point = complex<ftype>;
12
- using vftype = stdx::native_simd< ftype> ;
14
+ using vftype [[gnu::vector_size(bytes)]] = ftype;
13
15
using vpoint = complex<vftype>;
14
- static constexpr size_t flen = vftype::size() ;
16
+ static constexpr vftype fz = {} ;
15
17
16
18
struct cvector {
17
19
std::vector<vftype> x, y;
@@ -117,15 +119,14 @@ namespace cp_algo::math::fft {
117
119
rt = -rt;
118
120
}
119
121
auto [Bvx, Bvy] = B.vget (k);
120
- auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (real (rt), imag (rt));
122
+ auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (fz + real (rt), fz + imag (rt));
121
123
auto [Ax, Ay] = A.vget (k);
122
124
vftype Bx[2 ] = {Brvx, Bvx}, By[2 ] = {Brvy, Bvy};
123
- vpoint res = {0 , 0 };
125
+ vpoint res = {fz, fz };
124
126
for (size_t i = 0 ; i < flen; i++) {
125
- vftype Bsx, Bsy;
126
- Bsx.copy_from ((ftype*)Bx + flen - i, stdx::element_aligned);
127
- Bsy.copy_from ((ftype*)By + flen - i, stdx::element_aligned);
128
- res += vpoint (Ax[i], Ay[i]) * vpoint{Bsx, Bsy};
127
+ auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128
+ auto Bsy = (vftype*)((ftype*)By + flen - i);
129
+ res += vpoint (fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
129
130
}
130
131
return res;
131
132
}
@@ -144,7 +145,7 @@ namespace cp_algo::math::fft {
144
145
if (4 * i <= n) { // radix-4
145
146
exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
146
147
k *= 4 * i;
147
- vpoint v1 = {real (rt), - imag (rt)};
148
+ vpoint v1 = {fz + real (rt), fz - imag (rt)};
148
149
vpoint v2 = v1 * v1;
149
150
vpoint v3 = v1 * v2;
150
151
for (size_t j = k; j < k + i; j += flen) {
@@ -154,15 +155,15 @@ namespace cp_algo::math::fft {
154
155
auto D = get<vpoint>(j + 3 * i);
155
156
set (j , (A + B + C + D));
156
157
set (j + 2 * i, (A + B - C - D) * v2);
157
- set (j + i, (A - B - vpoint (0 , 1 ) * (C - D)) * v1);
158
- set (j + 3 * i, (A - B + vpoint (0 , 1 ) * (C - D)) * v3);
158
+ set (j + i, (A - B - vpoint (fz, fz + 1 ) * (C - D)) * v1);
159
+ set (j + 3 * i, (A - B + vpoint (fz, fz + 1 ) * (C - D)) * v3);
159
160
}
160
161
});
161
162
i *= 2 ;
162
163
} else { // radix-2 fallback
163
164
exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
164
165
k *= 2 * i;
165
- vpoint cvrt = {real (rt), - imag (rt)};
166
+ vpoint cvrt = {fz + real (rt), fz - imag (rt)};
166
167
for (size_t j = k; j < k + i; j += flen) {
167
168
auto A = get<vpoint>(j) + get<vpoint>(j + i);
168
169
auto B = get<vpoint>(j) - get<vpoint>(j + i);
@@ -174,7 +175,7 @@ namespace cp_algo::math::fft {
174
175
}
175
176
checkpoint (" ifft" );
176
177
for (size_t k = 0 ; k < n; k += flen) {
177
- set (k, get<vpoint>(k) /= (ftype)(n / flen));
178
+ set (k, get<vpoint>(k) /= fz + (ftype)(n / flen));
178
179
}
179
180
}
180
181
void fft () {
@@ -184,7 +185,7 @@ namespace cp_algo::math::fft {
184
185
i /= 2 ;
185
186
exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
186
187
k *= 4 * i;
187
- vpoint v1 = {real (rt), imag (rt)};
188
+ vpoint v1 = {fz + real (rt), fz + imag (rt)};
188
189
vpoint v2 = v1 * v1;
189
190
vpoint v3 = v1 * v2;
190
191
for (size_t j = k; j < k + i; j += flen) {
@@ -194,14 +195,14 @@ namespace cp_algo::math::fft {
194
195
auto D = get<vpoint>(j + 3 * i) * v3;
195
196
set (j , (A + C) + (B + D));
196
197
set (j + i, (A + C) - (B + D));
197
- set (j + 2 * i, (A - C) + vpoint (0 , 1 ) * (B - D));
198
- set (j + 3 * i, (A - C) - vpoint (0 , 1 ) * (B - D));
198
+ set (j + 2 * i, (A - C) + vpoint (fz, fz + 1 ) * (B - D));
199
+ set (j + 3 * i, (A - C) - vpoint (fz, fz + 1 ) * (B - D));
199
200
}
200
201
});
201
202
} else { // radix-2 fallback
202
203
exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
203
204
k *= 2 * i;
204
- vpoint vrt = {real (rt), imag (rt)};
205
+ vpoint vrt = {fz + real (rt), fz + imag (rt)};
205
206
for (size_t j = k; j < k + i; j += flen) {
206
207
auto t = get<vpoint>(j + i) * vrt;
207
208
set (j + i, get<vpoint>(j) - t);
0 commit comments