Skip to content

Commit dc87cf5

Browse files
committed
improve fft, add "checkpoints"
1 parent b4a5edd commit dc87cf5

File tree

5 files changed

+188
-124
lines changed

5 files changed

+188
-124
lines changed

cp-algo/math/cvector.hpp

Lines changed: 116 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,24 @@
11
#ifndef CP_ALGO_MATH_CVECTOR_HPP
22
#define CP_ALGO_MATH_CVECTOR_HPP
3-
#include <algorithm>
4-
#include <cassert>
5-
#include <complex>
6-
#include <vector>
3+
#include "../util/complex.hpp"
4+
#include "../util/checkpoint.hpp"
5+
#include <experimental/simd>
76
#include <ranges>
87
namespace cp_algo::math::fft {
98
using ftype = double;
10-
static constexpr size_t bytes = 32;
11-
static constexpr size_t flen = bytes / sizeof(ftype);
12-
using point = std::complex<ftype>;
13-
using vftype [[gnu::vector_size(bytes)]] = ftype;
14-
using vpoint = std::complex<vftype>;
9+
using point = complex<ftype>;
10+
using vftype = std::experimental::native_simd<ftype>;
11+
using vpoint = complex<vftype>;
12+
static constexpr size_t flen = vftype::size();
1513

16-
#define WITH_IV(...) \
17-
[&]<size_t ... i>(std::index_sequence<i...>) { \
18-
return __VA_ARGS__; \
19-
}(std::make_index_sequence<flen>());
20-
21-
template<typename ft>
22-
constexpr ft to_ft(auto x) {
23-
return ft{} + x;
24-
}
25-
template<typename pt>
26-
constexpr pt to_pt(point r) {
27-
using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>;
28-
return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())};
29-
}
3014
struct cvector {
31-
static constexpr size_t pre_roots = 1 << 17;
15+
static constexpr size_t pre_roots = 1 << 15;
3216
std::vector<vftype> x, y;
3317
cvector(size_t n) {
3418
n = std::max(flen, std::bit_ceil(n));
3519
x.resize(n / flen);
3620
y.resize(n / flen);
21+
checkpoint("cvector create");
3722
}
3823
template<class pt = point>
3924
void set(size_t k, pt t) {
@@ -60,132 +45,147 @@ namespace cp_algo::math::fft {
6045
size_t size() const {
6146
return flen * std::size(x);
6247
}
48+
49+
50+
static auto dot_block(size_t k, cvector const& A, cvector const& B) {
51+
auto rt = eval_point(k / flen / 2);
52+
if(k / flen % 2) {
53+
rt = -rt;
54+
}
55+
auto [Bvx, Bvy] = B.vget(k);
56+
auto [Brvx, Brvy] = vpoint(Bvx, Bvy) * vpoint(real(rt), imag(rt));
57+
auto [Ax, Ay] = A.vget(k);
58+
ftype Bx[2 * flen], By[2 * flen];
59+
Bvx.copy_to(Bx + flen, std::experimental::vector_aligned);
60+
Bvy.copy_to(By + flen, std::experimental::vector_aligned);
61+
Brvx.copy_to(Bx, std::experimental::vector_aligned);
62+
Brvy.copy_to(By, std::experimental::vector_aligned);
63+
vpoint res = {0, 0};
64+
for(size_t i = 0; i < flen; i++) {
65+
vftype Bsx, Bsy;
66+
Bsx.copy_from(Bx + flen - i, std::experimental::element_aligned);
67+
Bsy.copy_from(By + flen - i, std::experimental::element_aligned);
68+
res += vpoint(Ax[i], Ay[i]) * vpoint(Bsx, Bsy);
69+
}
70+
return res;
71+
}
72+
6373
void dot(cvector const& t) {
64-
size_t n = size();
74+
size_t n = this->size();
6575
for(size_t k = 0; k < n; k += flen) {
66-
set(k, get<vpoint>(k) * t.get<vpoint>(k));
76+
set(k, dot_block(k, *this, t));
6777
}
78+
checkpoint("dot");
6879
}
69-
static const cvector roots;
70-
template<class pt = point>
71-
static pt root(size_t n, size_t k) {
72-
if(n < pre_roots) {
73-
return roots.get<pt>(n + k);
80+
static const cvector roots, evalp;
81+
static std::array<size_t, pre_roots> eval_args;
82+
83+
template<bool precalc = false>
84+
static size_t eval_arg(size_t n) {
85+
if(n < pre_roots && !precalc) {
86+
return eval_args[n];
87+
} else if(n == 0) {
88+
return 0;
7489
} else {
75-
auto arg = std::numbers::pi / ftype(n);
76-
if constexpr(std::is_same_v<pt, point>) {
77-
return {cos(ftype(k) * arg), sin(ftype(k) * arg)};
78-
} else {
79-
return WITH_IV(pt{vftype{cos(ftype(k + i) * arg)...},
80-
vftype{sin(ftype(k + i) * arg)...}});
81-
}
90+
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
8291
}
8392
}
84-
template<class pt = point>
93+
template< bool precalc = false>
94+
static auto root(size_t n, size_t k) {
95+
if(n < pre_roots && !precalc) {
96+
return roots.get(n + k);
97+
} else {
98+
return polar(1., std::numbers::pi / (ftype)n * (ftype)k);
99+
}
100+
}
101+
template< bool precalc = false>
102+
static point eval_point(size_t n) {
103+
if(n < pre_roots && !precalc) {
104+
return evalp.get(n);
105+
} else if(n == 0) {
106+
return 1;
107+
} else {
108+
size_t N = std::bit_floor(n);
109+
return root(2 * N, eval_arg(n));
110+
}
111+
}
112+
113+
template<bool precalc = false>
85114
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
86-
size_t step = sizeof(pt) / sizeof(point);
87-
pt cur;
88-
pt arg = to_pt<pt>(root<point>(n, step));
89-
for(size_t i = 0; i < m; i += step) {
90-
if(i % 64 == 0 || n < pre_roots) {
91-
cur = root<pt>(n, i);
115+
point cur;
116+
point arg = root<precalc>(n, 1);
117+
for(size_t i = 0; i < m; i++) {
118+
if(precalc || i % 32 == 0 || n < pre_roots) {
119+
cur = root<precalc>(n, i);
92120
} else {
93121
cur *= arg;
94122
}
95123
callback(i, cur);
96124
}
97125
}
126+
static void exec_on_evals(size_t n, auto &&callback) {
127+
for(size_t i = 0; i < n; i++) {
128+
callback(i, eval_point(i));
129+
}
130+
}
98131

99132
void ifft() {
100133
size_t n = size();
101-
for(size_t i = 1; i < n; i *= 2) {
102-
for(size_t j = 0; j < n; j += 2 * i) {
103-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
104-
k += j;
105-
auto t = get<pt>(k + i) * conj(rt);
106-
set(k + i, get<pt>(k) - t);
107-
set(k, get<pt>(k) + t);
108-
};
109-
if(2 * i <= flen) {
110-
exec_on_roots(i, i, butterfly);
111-
} else {
112-
exec_on_roots<vpoint>(i, i, butterfly);
134+
for(size_t i = flen; i <= n / 2; i *= 2) {
135+
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
136+
k *= 2 * i;
137+
vpoint vrt = {real(rt), imag(rt)};
138+
for(size_t j = k; j < k + i; j += flen) {
139+
auto A = get<vpoint>(j) + get<vpoint>(j + i);
140+
auto B = get<vpoint>(j) - get<vpoint>(j + i);
141+
set(j, A);
142+
set(j + i, B * conj(vrt));
113143
}
114-
}
144+
});
115145
}
146+
checkpoint("ifft");
116147
for(size_t k = 0; k < n; k += flen) {
117-
set(k, get<vpoint>(k) /= to_pt<vpoint>(ftype(n)));
148+
set(k, get<vpoint>(k) /= (ftype)(n / flen));
118149
}
119150
}
120151
void fft() {
121152
size_t n = size();
122-
for(size_t i = n / 2; i >= 1; i /= 2) {
123-
for(size_t j = 0; j < n; j += 2 * i) {
124-
auto butterfly = [&]<class pt>(size_t k, pt rt) {
125-
k += j;
126-
auto A = get<pt>(k) + get<pt>(k + i);
127-
auto B = get<pt>(k) - get<pt>(k + i);
128-
set(k, A);
129-
set(k + i, B * rt);
130-
};
131-
if(2 * i <= flen) {
132-
exec_on_roots(i, i, butterfly);
133-
} else {
134-
exec_on_roots<vpoint>(i, i, butterfly);
153+
for(size_t i = n / 2; i >= flen; i /= 2) {
154+
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
155+
k *= 2 * i;
156+
vpoint vrt = {real(rt), imag(rt)};
157+
for(size_t j = k; j < k + i; j += flen) {
158+
auto t = get<vpoint>(j + i) * vrt;
159+
set(j + i, get<vpoint>(j) - t);
160+
set(j, get<vpoint>(j) + t);
135161
}
136-
}
162+
});
137163
}
164+
checkpoint("fft");
138165
}
139166
};
167+
std::array<size_t, cvector::pre_roots> cvector::eval_args = []() {
168+
std::array<size_t, pre_roots> res = {};
169+
for(size_t i = 1; i < pre_roots; i++) {
170+
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
171+
}
172+
return res;
173+
}();
140174
const cvector cvector::roots = []() {
141175
cvector res(pre_roots);
142176
for(size_t n = 1; n < res.size(); n *= 2) {
143-
auto base = std::polar(1., std::numbers::pi / ftype(n));
144-
point cur = 1;
145-
for(size_t k = 0; k < n; k++) {
146-
if((k & 15) == 0) {
147-
cur = std::polar(1., std::numbers::pi * ftype(k) / ftype(n));
148-
}
149-
res.set(n + k, cur);
150-
cur *= base;
151-
}
177+
cvector::exec_on_roots<true>(n, n, [&](size_t k, auto rt) {
178+
res.set(n + k, rt);
179+
});
152180
}
153181
return res;
154182
}();
155-
156-
template<typename base>
157-
struct dft {
158-
cvector A;
159-
160-
dft(std::vector<base> const& a, size_t n): A(n) {
161-
for(size_t i = 0; i < std::min(n, a.size()); i++) {
162-
A.set(i, a[i]);
163-
}
164-
if(n) {
165-
A.fft();
166-
}
167-
}
168-
169-
std::vector<base> operator *= (dft const& B) {
170-
assert(A.size() == B.A.size());
171-
size_t n = A.size();
172-
if(!n) {
173-
return std::vector<base>();
174-
}
175-
A.dot(B.A);
176-
A.ifft();
177-
std::vector<base> res(n);
178-
for(size_t k = 0; k < n; k++) {
179-
res[k] = A.get(k);
180-
}
181-
return res;
182-
}
183-
184-
auto operator * (dft const& B) const {
185-
return dft(*this) *= B;
183+
const cvector cvector::evalp = []() {
184+
cvector res(pre_roots);
185+
for(size_t n = 0; n < res.size(); n++) {
186+
res.set(n, cvector::eval_point<true>(n));
186187
}
187-
188-
point operator [](int i) const {return A.get(i);}
189-
};
188+
return res;
189+
}();
190190
}
191191
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/math/fft.hpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#ifndef CP_ALGO_MATH_FFT_HPP
22
#define CP_ALGO_MATH_FFT_HPP
33
#include "../number_theory/modint.hpp"
4+
#include "../util/checkpoint.hpp"
45
#include "cvector.hpp"
6+
#include <ranges>
7+
#include <iostream>
58
namespace cp_algo::math::fft {
69
template<modint_type base>
7-
struct dft<base> {
10+
struct dft {
811
int split;
912
cvector A, B;
1013

@@ -18,6 +21,7 @@ namespace cp_algo::math::fft {
1821
B.set(ti, B.get(ti) + quo * rt);
1922

2023
});
24+
checkpoint("dft init");
2125
if(n) {
2226
A.fft();
2327
B.fft();
@@ -31,12 +35,47 @@ namespace cp_algo::math::fft {
3135
res = {};
3236
return;
3337
}
34-
for(size_t i = 0; i < n; i += flen) {
35-
auto tmp = A.vget(i) * D.vget(i) + B.vget(i) * C.vget(i);
36-
A.set(i, A.vget(i) * C.vget(i));
37-
B.set(i, B.vget(i) * D.vget(i));
38-
C.set(i, tmp);
38+
for(size_t k = 0; k < n; k += flen) {
39+
auto rt = cvector::eval_point(k / flen / 2);
40+
if(k / flen % 2) {
41+
rt = -rt;
42+
}
43+
auto [Ax, Ay] = A.vget(k);
44+
auto [Bx, By] = B.vget(k);
45+
auto [Cvx, Cvy] = C.vget(k);
46+
auto [Dvx, Dvy] = D.vget(k);
47+
auto [Crvx, Crvy] = vpoint(Cvx, Cvy) * vpoint(real(rt), imag(rt));
48+
auto [Drvx, Drvy] = vpoint(Dvx, Dvy) * vpoint(real(rt), imag(rt));
49+
ftype Cx[2 * flen], Cy[2 * flen];
50+
ftype Dx[2 * flen], Dy[2 * flen];
51+
Cvx.copy_to(Cx + flen, std::experimental::vector_aligned);
52+
Cvy.copy_to(Cy + flen, std::experimental::vector_aligned);
53+
Dvx.copy_to(Dx + flen, std::experimental::vector_aligned);
54+
Dvy.copy_to(Dy + flen, std::experimental::vector_aligned);
55+
Crvx.copy_to(Cx, std::experimental::vector_aligned);
56+
Crvy.copy_to(Cy, std::experimental::vector_aligned);
57+
Drvx.copy_to(Dx, std::experimental::vector_aligned);
58+
Drvy.copy_to(Dy, std::experimental::vector_aligned);
59+
vpoint AC, AD, BC, BD;
60+
AC = AD = BC = BD = {0, 0};
61+
for(size_t i = 0; i < flen; i++) {
62+
vftype Csx, Csy, Dsx, Dsy;
63+
Csx.copy_from(Cx + flen - i, std::experimental::element_aligned);
64+
Csy.copy_from(Cy + flen - i, std::experimental::element_aligned);
65+
Dsx.copy_from(Dx + flen - i, std::experimental::element_aligned);
66+
Dsy.copy_from(Dy + flen - i, std::experimental::element_aligned);
67+
vpoint As = {Ax[i], Ay[i]}, Bs = {Bx[i], By[i]};
68+
vpoint Cs = {Csx, Csy}, Ds = {Dsx, Dsy};
69+
AC += As * Cs;
70+
AD += As * Ds;
71+
BC += Bs * Cs;
72+
BD += Bs * Ds;
73+
}
74+
A.set(k, AC);
75+
C.set(k, AD + BC);
76+
B.set(k, BD);
3977
}
78+
checkpoint("dot");
4079
A.ifft();
4180
B.ifft();
4281
C.ifft();
@@ -58,6 +97,7 @@ namespace cp_algo::math::fft {
5897
int64_t B2 = llround(imag(Bi));
5998
res[n + i] = B0 + B1 * split + B2 * splitsplit;
6099
});
100+
checkpoint("recover mod");
61101
}
62102
void mul_inplace(auto &&B, auto& res, size_t k) {
63103
mul(B.A, B.B, res, k);

cp-algo/util/checkpoint.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#ifndef CP_ALGO_UTIL_CHECKPOINT_HPP
2+
#define CP_ALGO_UTIL_CHECKPOINT_HPP
3+
#include <iostream>
4+
#include <chrono>
5+
#include <string>
6+
namespace cp_algo {
7+
void checkpoint(std::string const& msg = "") {
8+
static double last = 0;
9+
double now = (double)clock() / CLOCKS_PER_SEC;
10+
double delta = now - last;
11+
last = now;
12+
if(msg.size()) {
13+
std::cerr << msg << ": " << delta * 1000 << " ms\n";
14+
}
15+
}
16+
}
17+
#endif // CP_ALGO_UTIL_CHECKPOINT_HPP

0 commit comments

Comments
 (0)