Skip to content

Commit 9eb1fd2

Browse files
committed
make root precalc constexpr
1 parent 349bbdc commit 9eb1fd2

File tree

3 files changed

+75
-86
lines changed

3 files changed

+75
-86
lines changed

cp-algo/math/cvector.hpp

Lines changed: 63 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,69 @@ namespace cp_algo::math::fft {
1414
static constexpr size_t flen = vftype::size();
1515

1616
struct cvector {
17-
static constexpr size_t pre_roots = 1 << 15;
17+
static constexpr size_t pre_roots = 1 << 16;
18+
// Compile-time lookup table of unit-modulus points used by root().
// For every power of two n below pre_roots and every k in [0, n), the slot
// n + k holds polar(1, pi * k / n). Slot 0 is never written by the loops and
// keeps the value-initialized point from `res = {}`.
static constexpr std::array<point, pre_roots> roots = []() {
19+
std::array<point, pre_roots> res = {};
20+
for(size_t n = 1; n < res.size(); n *= 2) {
21+
for(size_t k = 0; k < n; k++) {
22+
res[n + k] = polar(1., std::numbers::pi / ftype(n) * ftype(k));
23+
}
24+
}
25+
return res;
26+
}();
27+
// Compile-time table of the integer arguments that eval_arg(i) returns for
// i < pre_roots. Entry 0 stays 0; every other entry reuses the entry of i >> 1
// and ORs in the lowest bit of i moved to the position of i's highest set bit
// (std::bit_width(i) - 1). The result is the bits of i reversed within its
// bit_width(i)-bit representation (e.g. 4 -> 1, 6 -> 3, 5 -> 5).
static constexpr std::array<size_t, pre_roots> eval_args = []() {
28+
std::array<size_t, pre_roots> res = {};
29+
for(size_t i = 1; i < pre_roots; i++) {
30+
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
31+
}
32+
return res;
33+
}();
34+
// Compile-time table of the points returned by eval_point(n) for n < pre_roots.
// Entry 0 is 1; entry n (n >= 1) is polar(1, pi * eval_args[n] / (2 * bit_floor(n))),
// i.e. the same angle that eval_point() derives at run time for larger n via
// root(2 * std::bit_floor(n), eval_arg(n)).
static constexpr std::array<point, pre_roots> evalp = []() {
35+
std::array<point, pre_roots> res = {};
36+
res[0] = 1;
37+
for(size_t n = 1; n < pre_roots; n++) {
38+
res[n] = polar(1., std::numbers::pi * ftype(eval_args[n]) / ftype(2 * std::bit_floor(n)));
39+
}
40+
return res;
41+
}();
42+
// Returns the integer argument associated with index n.
// Indices below pre_roots are answered from the precomputed eval_args table;
// larger indices apply the same recurrence as the table (value for n / 2,
// OR-ed with the lowest bit of n moved to bit std::bit_width(n) - 1),
// recursing until the index drops into the table range.
static size_t eval_arg(size_t n) {
43+
if(n < pre_roots) {
44+
return eval_args[n];
45+
} else {
46+
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
47+
}
48+
}
49+
// Returns the point polar(1, pi * k / n).
// For n < pre_roots the value is read from roots[n + k]; otherwise it is
// computed directly.
// NOTE(review): the roots table is only filled for power-of-two n with k < n,
// so the table branch presumably relies on callers passing such (n, k) —
// confirm against call sites.
static auto root(size_t n, size_t k) {
50+
if(n < pre_roots) {
51+
return roots[n + k];
52+
} else {
53+
return polar(1., std::numbers::pi / (ftype)n * (ftype)k);
54+
}
55+
}
56+
// Returns the evaluation point associated with index n.
// Indices below pre_roots come from the precomputed evalp table; larger
// indices are computed as root(2 * std::bit_floor(n), eval_arg(n)).
static point eval_point(size_t n) {
57+
if(n < pre_roots) {
58+
return evalp[n];
59+
} else {
60+
return root(2 * std::bit_floor(n), eval_arg(n));
61+
}
62+
}
63+
// Invokes callback(i, root(n, i)) for i = 0 .. m - 1, in increasing order.
// When n < pre_roots every value is fetched through root() (a table lookup).
// Otherwise the value is recomputed through root() only when i is a multiple
// of 32 and advanced in between by multiplying with arg = root(n, 1);
// presumably the periodic recomputation bounds accumulated rounding error.
// `cur` is always assigned before first use because i == 0 takes the root() branch.
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
64+
point cur;
65+
point arg = root(n, 1);
66+
for(size_t i = 0; i < m; i++) {
67+
if(i % 32 == 0 || n < pre_roots) {
68+
cur = root(n, i);
69+
} else {
70+
cur *= arg;
71+
}
72+
callback(i, cur);
73+
}
74+
}
75+
// Invokes callback(i, eval_point(i)) for i = 0 .. n - 1, in increasing order.
static void exec_on_evals(size_t n, auto &&callback) {
76+
for(size_t i = 0; i < n; i++) {
77+
callback(i, eval_point(i));
78+
}
79+
}
1880
std::vector<vftype> x, y;
1981
cvector(size_t n) {
2082
n = std::max(flen, std::bit_ceil(n));
@@ -80,57 +142,6 @@ namespace cp_algo::math::fft {
80142
}
81143
checkpoint("dot");
82144
}
83-
static const cvector roots, evalp;
84-
static std::array<size_t, pre_roots> eval_args;
85-
86-
template<bool precalc = false>
87-
static size_t eval_arg(size_t n) {
88-
if(n < pre_roots && !precalc) {
89-
return eval_args[n];
90-
} else if(n == 0) {
91-
return 0;
92-
} else {
93-
return eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1);
94-
}
95-
}
96-
template< bool precalc = false>
97-
static auto root(size_t n, size_t k) {
98-
if(n < pre_roots && !precalc) {
99-
return roots.get(n + k);
100-
} else {
101-
return polar(1., std::numbers::pi / (ftype)n * (ftype)k);
102-
}
103-
}
104-
template< bool precalc = false>
105-
static point eval_point(size_t n) {
106-
if(n < pre_roots && !precalc) {
107-
return evalp.get(n);
108-
} else if(n == 0) {
109-
return 1;
110-
} else {
111-
size_t N = std::bit_floor(n);
112-
return root(2 * N, eval_arg(n));
113-
}
114-
}
115-
116-
template<bool precalc = false>
117-
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
118-
point cur;
119-
point arg = root<precalc>(n, 1);
120-
for(size_t i = 0; i < m; i++) {
121-
if(precalc || i % 32 == 0 || n < pre_roots) {
122-
cur = root<precalc>(n, i);
123-
} else {
124-
cur *= arg;
125-
}
126-
callback(i, cur);
127-
}
128-
}
129-
static void exec_on_evals(size_t n, auto &&callback) {
130-
for(size_t i = 0; i < n; i++) {
131-
callback(i, eval_point(i));
132-
}
133-
}
134145

135146
void ifft() {
136147
size_t n = size();
@@ -167,28 +178,5 @@ namespace cp_algo::math::fft {
167178
checkpoint("fft");
168179
}
169180
};
170-
std::array<size_t, cvector::pre_roots> cvector::eval_args = []() {
171-
std::array<size_t, pre_roots> res = {};
172-
for(size_t i = 1; i < pre_roots; i++) {
173-
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
174-
}
175-
return res;
176-
}();
177-
const cvector cvector::roots = []() {
178-
cvector res(pre_roots);
179-
for(size_t n = 1; n < res.size(); n *= 2) {
180-
cvector::exec_on_roots<true>(n, n, [&](size_t k, auto rt) {
181-
res.set(n + k, rt);
182-
});
183-
}
184-
return res;
185-
}();
186-
const cvector cvector::evalp = []() {
187-
cvector res(pre_roots);
188-
for(size_t n = 0; n < res.size(); n++) {
189-
res.set(n, cvector::eval_point<true>(n));
190-
}
191-
return res;
192-
}();
193181
}
194182
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/util/complex.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace cp_algo {
3131
T imag() const {return y;}
3232
T& real() {return x;}
3333
T& imag() {return y;}
34-
static complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
34+
// Builds the complex number with modulus r and angle theta: (r * cos(theta), r * sin(theta)).
// NOTE(review): being usable in a constant expression depends on cos/sin being
// constant-evaluable for T on the target toolchain — confirm.
static constexpr complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
3535
auto operator <=> (complex const& t) const = default;
3636
};
3737
template<typename T>
@@ -43,7 +43,10 @@ namespace cp_algo {
4343
template<typename T> T& imag(complex<T> &x) {return x.imag();}
4444
template<typename T> T real(complex<T> const& x) {return x.real();}
4545
template<typename T> T imag(complex<T> const& x) {return x.imag();}
46-
template<typename T> complex<T> polar(T r, T theta) {return complex<T>::polar(r, theta);}
46+
// Free-function form of complex<T>::polar: forwards (r, theta) unchanged and
// returns the resulting complex<T>.
template<typename T>
47+
constexpr complex<T> polar(T r, T theta) {
48+
return complex<T>::polar(r, theta);
49+
}
4750
template<typename T>
4851
std::ostream& operator << (std::ostream &out, complex<T> x) {
4952
return out << x.real() << ' ' << x.imag();

verify/poly/wildcard.test.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// @brief Wildcard Pattern Matching
22
#define PROBLEM "https://judge.yosupo.jp/problem/wildcard_pattern_matching"
33
#pragma GCC optimize("Ofast,unroll-loops")
4+
#define CP_ALGO_CHECKPOINT
45
#include "cp-algo/math/cvector.hpp"
56
#include "cp-algo/random/rng.hpp"
67
#include <bits/stdc++.h>
@@ -36,19 +37,16 @@ string matches(string const& A, string const& B, char wild = '*') {
3637
project[1][i] = conj(project[0][i]);
3738
}
3839
}
39-
array ST = {&A, &B};
4040
vector<cvector> P;
4141
P.emplace_back(size(A));
4242
P.emplace_back(size(A));
43-
for(int i: {0, 1}) {
44-
size_t N = ST[i]->size();
45-
for(size_t k = 0; k < N; k++) {
46-
char c = ST[i]->at(k);
47-
size_t idx = i ? N - k - 1 : k;
48-
point val = c == wild ? 0 : project[i][c - 'a'];
49-
P[i].set(idx, val);
50-
}
43+
// Fill P[0] from A in order: wildcard positions get 0, every other character
// gets its projection. The conditional operator is used (rather than
// multiplying by (c != wild)) so project[...] is never indexed for a wildcard:
// with the default wild = '*', c - 'a' is negative there and the lookup would
// read out of bounds.
for(auto [i, c]: A | views::enumerate) {
44+
P[0].set(i, c == wild ? 0 : project[0][c - 'a']);
45+
}
46+
// Same for P[1], filled from B traversed in reverse.
for(auto [i, c]: B | views::reverse | views::enumerate) {
47+
P[1].set(i, c == wild ? 0 : project[1][c - 'a']);
5148
}
49+
cp_algo::checkpoint("cvector fill");
5250
semicorr(P[0], P[1]);
5351
string ans(size(A) - size(B) + 1, '0');
5452
for(size_t j = 0; j < size(ans); j++) {

0 commit comments

Comments
 (0)