2
2
#define CP_ALGO_MATH_FFT_HPP
3
3
#include " ../number_theory/modint.hpp"
4
4
#include " ../util/checkpoint.hpp"
5
+ #include " ../random/rng.hpp"
5
6
#include " cvector.hpp"
6
- #include < ranges>
7
7
#include < iostream>
8
+ #include < ranges>
8
9
namespace cp_algo ::math::fft {
9
10
template <modint_type base>
10
11
struct dft {
11
12
int split;
12
13
cvector A, B;
13
-
14
+ static base factor, ifactor;
15
+ static bool init;
16
+
14
17
dft (auto const & a, size_t n): A(n), B(n) {
18
+ if (!init) {
19
+ factor = 1 + random::rng () % (base::mod () - 1 );
20
+ ifactor = base (1 ) / factor;
21
+ init = true ;
22
+ }
15
23
split = int (std::sqrt (base::mod ())) + 1 ;
24
+ base cur = 1 ;
16
25
cvector::exec_on_roots (2 * n, std::min (n, size (a)), [&](size_t i, auto rt) {
17
26
auto splt = [&](size_t i) {
27
+ #ifdef CP_ALGO_FFT_RANDOMIZER
28
+ auto ai = ftype (i < size (a) ? (a[i] * cur).rem () : 0 );
29
+ cur *= factor;
30
+ #else
18
31
auto ai = ftype (i < size (a) ? a[i].rem () : 0 );
32
+ #endif
19
33
auto rem = std::remainder (ai, split);
20
34
auto quo = (ai - rem) / split;
21
35
return std::pair{rem, quo};
@@ -32,7 +46,7 @@ namespace cp_algo::math::fft {
32
46
}
33
47
}
34
48
35
- void mul (auto &&C, auto const & D, auto &res, size_t k) {
49
+ void mul (auto &&C, auto const & D, auto &res, size_t k, [[maybe_unused]] base ifactor ) {
36
50
assert (A.size () == C.size ());
37
51
size_t n = A.size ();
38
52
if (!n) {
@@ -73,6 +87,8 @@ namespace cp_algo::math::fft {
73
87
B.ifft ();
74
88
C.ifft ();
75
89
auto splitsplit = (base (split) * split).rem ();
90
+ base cur = 1 ;
91
+ base step = bpow (ifactor, n);
76
92
cvector::exec_on_roots (2 * n, std::min (n, k), [&](size_t i, point rt) {
77
93
rt = conj (rt);
78
94
auto Ai = A.get (i) * rt;
@@ -82,21 +98,28 @@ namespace cp_algo::math::fft {
82
98
int64_t A1 = llround (real (Ci));
83
99
int64_t A2 = llround (real (Bi));
84
100
res[i] = A0 + A1 * split + A2 * splitsplit;
101
+ #ifdef CP_ALGO_FFT_RANDOMIZER
102
+ res[i] *= cur;
103
+ #endif
85
104
if (n + i >= k) {
86
105
return ;
87
106
}
88
107
int64_t B0 = llround (imag (Ai));
89
108
int64_t B1 = llround (imag (Ci));
90
109
int64_t B2 = llround (imag (Bi));
91
110
res[n + i] = B0 + B1 * split + B2 * splitsplit;
111
+ #ifdef CP_ALGO_FFT_RANDOMIZER
112
+ res[n + i] *= cur * step;
113
+ cur *= ifactor;
114
+ #endif
92
115
});
93
116
checkpoint (" recover mod" );
94
117
}
95
118
// Multiplies this transform by `B`, writing the first `k` result
// coefficients into `res`.
//
// `B.A` and `B.B` are handed to the five-argument `mul` directly, which
// runs its inverse transform on the first of them, so `B` is left in a
// modified state after the call ("in place").
//
// NOTE(review): `ifactor` is a static member, so when `B` is a `dft` over
// the same `base` the value passed below is `ifactor` squared - confirm
// this matches the scaling applied in the constructor when
// CP_ALGO_FFT_RANDOMIZER is defined.
void mul_inplace(auto &&B, auto& res, size_t k) {
    auto const unscale = ifactor * B.ifactor;
    mul(B.A, B.B, res, k, unscale);
}
98
121
// Multiplies this transform by `B`, writing the first `k` result
// coefficients into `res`, without modifying `B`.
//
// `B.A` is copied into a temporary `cvector` before being passed on,
// because the five-argument `mul` inverse-transforms that argument;
// `B.B` is only taken by const reference.
//
// NOTE(review): `ifactor` is a static member, so when `B` is a `dft` over
// the same `base` the value passed below is `ifactor` squared - confirm
// this matches the scaling applied in the constructor when
// CP_ALGO_FFT_RANDOMIZER is defined.
void mul(auto const& B, auto& res, size_t k) {
    auto const unscale = ifactor * B.ifactor;
    mul(cvector(B.A), B.B, res, k, unscale);
}
101
124
std::vector<base> operator *= (dft &B) {
102
125
std::vector<base> res (2 * A.size ());
@@ -111,9 +134,12 @@ namespace cp_algo::math::fft {
111
134
// Returns the product of this transform with `B` without modifying
// either operand: a copy of `*this` absorbs the side effects of the
// compound multiplication, whose result is returned.
auto operator * (dft const& B) const {
    dft lhs(*this);
    return lhs *= B;
}
114
-
137
+
115
138
// Read-only access to the i-th point stored in the `A` buffer.
point operator [](int i) const {
    return A.get(i);
}
116
139
};
140
+ template <modint_type base> base dft<base>::factor = 1 ;
141
+ template <modint_type base> base dft<base>::ifactor = 1 ;
142
+ template <modint_type base> bool dft<base>::init = false ;
117
143
118
144
void mul_slow (auto &a, auto const & b, size_t k) {
119
145
if (empty (a) || empty (b)) {
@@ -155,8 +181,36 @@ namespace cp_algo::math::fft {
155
181
}
156
182
}
157
183
void mul (auto &a, auto const & b) {
158
- if (size (a)) {
159
- mul_truncate (a, b, size (a) + size (b) - 1 );
184
+ size_t N = size (a) + size (b) - 1 ;
185
+ if (std::max (size (a), size (b)) > (1 << 23 )) {
186
+ // do karatsuba to save memory
187
+ auto n = (std::max (size (a), size (b)) + 1 ) / 2 ;
188
+ auto a0 = to<std::vector>(a | std::views::take (n));
189
+ auto a1 = to<std::vector>(a | std::views::drop (n));
190
+ auto b0 = to<std::vector>(b | std::views::take (n));
191
+ auto b1 = to<std::vector>(b | std::views::drop (n));
192
+ a0.resize (n); a1.resize (n);
193
+ b0.resize (n); b1.resize (n);
194
+ auto a01 = to<std::vector>(std::views::zip_transform (std::plus{}, a0, a1));
195
+ auto b01 = to<std::vector>(std::views::zip_transform (std::plus{}, b0, b1));
196
+ mul (a0, b0);
197
+ mul (a1, b1);
198
+ mul (a01, b01);
199
+ a.assign (4 * n, 0 );
200
+ for (auto [i, ai]: a0 | std::views::enumerate) {
201
+ a[i] += ai;
202
+ a[i + n] -= ai;
203
+ }
204
+ for (auto [i, ai]: a1 | std::views::enumerate) {
205
+ a[i + n] -= ai;
206
+ a[i + 2 * n] += ai;
207
+ }
208
+ for (auto [i, ai]: a01 | std::views::enumerate) {
209
+ a[i + n] += ai;
210
+ }
211
+ a.resize (N);
212
+ } else if (size (a)) {
213
+ mul_truncate (a, b, N);
160
214
}
161
215
}
162
216
}
0 commit comments