Skip to content

Commit 8690e44

Browse files
committed
Improve randomizer
1 parent 9d865a9 commit 8690e44

File tree

2 files changed

+80
-21
lines changed

2 files changed

+80
-21
lines changed

cp-algo/math/cvector.hpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
namespace stdx = std::experimental;
1010
namespace cp_algo::math::fft {
1111
using ftype = double;
12-
static constexpr size_t bytes = 32;
13-
static constexpr size_t flen = bytes / sizeof(ftype);
12+
static constexpr size_t flen = 4;
13+
static constexpr size_t bytes = flen * sizeof(ftype);
1414
using point = complex<ftype>;
1515
using vftype [[gnu::vector_size(bytes)]] = ftype;
1616
using vpoint = complex<vftype>;
@@ -65,7 +65,7 @@ namespace cp_algo::math::fft {
6565
return sz;
6666
}
6767
static size_t eval_arg(size_t n) {
68-
if(n < pre_roots) {
68+
if(n < pre_evals) {
6969
return eval_args[n];
7070
} else {
7171
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
@@ -74,13 +74,17 @@ namespace cp_algo::math::fft {
7474
static auto root(size_t n, size_t k) {
7575
if(n < pre_roots) {
7676
return roots[n + k];
77+
} else if (k % 2 == 0) {
78+
return root(n / 2, k / 2);
7779
} else {
7880
return polar(1., std::numbers::pi / (ftype)n * (ftype)k);
7981
}
8082
}
8183
static point eval_point(size_t n) {
82-
if(n < pre_roots) {
83-
return evalp[n];
84+
if(n % 2) {
85+
return eval_point(n - 1) * point(0, 1);
86+
} else if(n / 2 < pre_evals) {
87+
return evalp[n / 2];
8488
} else {
8589
return root(2 * std::bit_floor(n), eval_arg(n));
8690
}
@@ -203,7 +207,8 @@ namespace cp_algo::math::fft {
203207
}
204208
checkpoint("fft");
205209
}
206-
static constexpr size_t pre_roots = 1 << 16;
210+
static constexpr size_t pre_roots = 1 << 14;
211+
static constexpr size_t pre_evals = 1 << 16;
207212
static constexpr std::array<point, pre_roots> roots = []() {
208213
std::array<point, pre_roots> res = {};
209214
for(size_t n = 1; n < res.size(); n *= 2) {
@@ -213,18 +218,18 @@ namespace cp_algo::math::fft {
213218
}
214219
return res;
215220
}();
216-
static constexpr std::array<size_t, pre_roots> eval_args = []() {
217-
std::array<size_t, pre_roots> res = {};
218-
for(size_t i = 1; i < pre_roots; i++) {
221+
static constexpr std::array<size_t, pre_evals> eval_args = []() {
222+
std::array<size_t, pre_evals> res = {};
223+
for(size_t i = 1; i < pre_evals; i++) {
219224
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
220225
}
221226
return res;
222227
}();
223-
static constexpr std::array<point, pre_roots> evalp = []() {
224-
std::array<point, pre_roots> res = {};
228+
static constexpr std::array<point, pre_evals> evalp = []() {
229+
std::array<point, pre_evals> res = {};
225230
res[0] = 1;
226-
for(size_t n = 1; n < pre_roots; n++) {
227-
res[n] = polar(1., std::numbers::pi * ftype(eval_args[n]) / ftype(2 * std::bit_floor(n)));
231+
for(size_t n = 1; n < pre_evals; n++) {
232+
res[n] = polar(1., std::numbers::pi * ftype(eval_args[n]) / ftype(4 * std::bit_floor(n)));
228233
}
229234
return res;
230235
}();

cp-algo/math/fft.hpp

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,34 @@
22
#define CP_ALGO_MATH_FFT_HPP
33
#include "../number_theory/modint.hpp"
44
#include "../util/checkpoint.hpp"
5+
#include "../random/rng.hpp"
56
#include "cvector.hpp"
6-
#include <ranges>
77
#include <iostream>
8+
#include <ranges>
89
namespace cp_algo::math::fft {
910
template<modint_type base>
1011
struct dft {
1112
int split;
1213
cvector A, B;
13-
14+
static base factor, ifactor;
15+
static bool init;
16+
1417
dft(auto const& a, size_t n): A(n), B(n) {
18+
if(!init) {
19+
factor = 1 + random::rng() % (base::mod() - 1);
20+
ifactor = base(1) / factor;
21+
init = true;
22+
}
1523
split = int(std::sqrt(base::mod())) + 1;
24+
base cur = 1;
1625
cvector::exec_on_roots(2 * n, std::min(n, size(a)), [&](size_t i, auto rt) {
1726
auto splt = [&](size_t i) {
27+
#ifdef CP_ALGO_FFT_RANDOMIZER
28+
auto ai = ftype(i < size(a) ? (a[i] * cur).rem() : 0);
29+
cur *= factor;
30+
#else
1831
auto ai = ftype(i < size(a) ? a[i].rem() : 0);
32+
#endif
1933
auto rem = std::remainder(ai, split);
2034
auto quo = (ai - rem) / split;
2135
return std::pair{rem, quo};
@@ -32,7 +46,7 @@ namespace cp_algo::math::fft {
3246
}
3347
}
3448

35-
void mul(auto &&C, auto const& D, auto &res, size_t k) {
49+
void mul(auto &&C, auto const& D, auto &res, size_t k, [[maybe_unused]] base ifactor) {
3650
assert(A.size() == C.size());
3751
size_t n = A.size();
3852
if(!n) {
@@ -73,6 +87,8 @@ namespace cp_algo::math::fft {
7387
B.ifft();
7488
C.ifft();
7589
auto splitsplit = (base(split) * split).rem();
90+
base cur = 1;
91+
base step = bpow(ifactor, n);
7692
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
7793
rt = conj(rt);
7894
auto Ai = A.get(i) * rt;
@@ -82,21 +98,28 @@ namespace cp_algo::math::fft {
8298
int64_t A1 = llround(real(Ci));
8399
int64_t A2 = llround(real(Bi));
84100
res[i] = A0 + A1 * split + A2 * splitsplit;
101+
#ifdef CP_ALGO_FFT_RANDOMIZER
102+
res[i] *= cur;
103+
#endif
85104
if(n + i >= k) {
86105
return;
87106
}
88107
int64_t B0 = llround(imag(Ai));
89108
int64_t B1 = llround(imag(Ci));
90109
int64_t B2 = llround(imag(Bi));
91110
res[n + i] = B0 + B1 * split + B2 * splitsplit;
111+
#ifdef CP_ALGO_FFT_RANDOMIZER
112+
res[n + i] *= cur * step;
113+
cur *= ifactor;
114+
#endif
92115
});
93116
checkpoint("recover mod");
94117
}
95118
void mul_inplace(auto &&B, auto& res, size_t k) {
96-
mul(B.A, B.B, res, k);
119+
mul(B.A, B.B, res, k, ifactor * B.ifactor);
97120
}
98121
void mul(auto const& B, auto& res, size_t k) {
99-
mul(cvector(B.A), B.B, res, k);
122+
mul(cvector(B.A), B.B, res, k, ifactor * B.ifactor);
100123
}
101124
std::vector<base> operator *= (dft &B) {
102125
std::vector<base> res(2 * A.size());
@@ -111,9 +134,12 @@ namespace cp_algo::math::fft {
111134
auto operator * (dft const& B) const {
112135
return dft(*this) *= B;
113136
}
114-
137+
115138
point operator [](int i) const {return A.get(i);}
116139
};
140+
template<modint_type base> base dft<base>::factor = 1;
141+
template<modint_type base> base dft<base>::ifactor = 1;
142+
template<modint_type base> bool dft<base>::init = false;
117143

118144
void mul_slow(auto &a, auto const& b, size_t k) {
119145
if(empty(a) || empty(b)) {
@@ -155,8 +181,36 @@ namespace cp_algo::math::fft {
155181
}
156182
}
157183
void mul(auto &a, auto const& b) {
158-
if(size(a)) {
159-
mul_truncate(a, b, size(a) + size(b) - 1);
184+
size_t N = size(a) + size(b) - 1;
185+
if(std::max(size(a), size(b)) > (1 << 23)) {
186+
// do karatsuba to save memory
187+
auto n = (std::max(size(a), size(b)) + 1) / 2;
188+
auto a0 = to<std::vector>(a | std::views::take(n));
189+
auto a1 = to<std::vector>(a | std::views::drop(n));
190+
auto b0 = to<std::vector>(b | std::views::take(n));
191+
auto b1 = to<std::vector>(b | std::views::drop(n));
192+
a0.resize(n); a1.resize(n);
193+
b0.resize(n); b1.resize(n);
194+
auto a01 = to<std::vector>(std::views::zip_transform(std::plus{}, a0, a1));
195+
auto b01 = to<std::vector>(std::views::zip_transform(std::plus{}, b0, b1));
196+
mul(a0, b0);
197+
mul(a1, b1);
198+
mul(a01, b01);
199+
a.assign(4 * n, 0);
200+
for(auto [i, ai]: a0 | std::views::enumerate) {
201+
a[i] += ai;
202+
a[i + n] -= ai;
203+
}
204+
for(auto [i, ai]: a1 | std::views::enumerate) {
205+
a[i + n] -= ai;
206+
a[i + 2 * n] += ai;
207+
}
208+
for(auto [i, ai]: a01 | std::views::enumerate) {
209+
a[i + n] += ai;
210+
}
211+
a.resize(N);
212+
} else if(size(a)) {
213+
mul_truncate(a, b, N);
160214
}
161215
}
162216
}

0 commit comments

Comments
 (0)