Skip to content

Commit 0b656a5

Browse files
committed
Subset compose + test
1 parent 2f6bae2 commit 0b656a5

File tree

2 files changed

+122
-34
lines changed

2 files changed

+122
-34
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 80 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@
1111
#include <cstring>
#include <utility>
1212
CP_ALGO_SIMD_PRAGMA_PUSH
1313
namespace cp_algo::math {
14-
const size_t logn = 20;
14+
const size_t max_logn = 20;
1515

1616
enum transform_dir { forw, inv };
1717

1818
template<auto N, transform_dir direction>
1919
inline void or_transform(auto &&a) {
20-
[[gnu::assume(N <= 1ull << 30)]];
20+
[[gnu::assume(N <= 1 << 30)]];
2121
if constexpr (N <= 32) {
2222
for(size_t i = 1; i < N; i *= 2) {
2323
for(size_t j = 0; j < N; j += 2 * i) {
2424
for(size_t k = j; k < j + i; k++) {
25-
for(size_t z = 0; z < logn; z++) {
25+
for(size_t z = 0; z < max_logn; z++) {
2626
if constexpr (direction == forw) {
2727
a[k + i][z] += a[k][z];
2828
} else {
@@ -37,8 +37,8 @@ namespace cp_algo::math {
3737
or_transform<half, direction>(&a[0]);
3838
or_transform<half, direction>(&a[half]);
3939
for (size_t i = 0; i < half; i++) {
40-
#pragma GCC unroll logn
41-
for(size_t z = 0; z < logn; z++) {
40+
#pragma GCC unroll max_logn
41+
for(size_t z = 0; z < max_logn; z++) {
4242
if constexpr (direction == forw) {
4343
a[i + half][z] += a[i][z];
4444
} else {
@@ -85,7 +85,7 @@ namespace cp_algo::math {
8585
// Create array buffers for each input
8686
auto create_buffers = [bottoms]<typename... Args>(const Args&...) {
8787
return std::make_tuple(
88-
big_vector<std::array<typename std::decay_t<Args>::value_type, logn>>(bottoms)...
88+
big_vector<std::array<typename std::decay_t<Args>::value_type, max_logn>>(bottoms)...
8989
);
9090
};
9191
auto buffers = std::apply(create_buffers, input_tuple);
@@ -130,8 +130,8 @@ namespace cp_algo::math {
130130
for(size_t i = 0; i < bottoms; i += K) {
131131
std::apply([&](auto&... bufs) {
132132
auto extract_one = [&](auto& buf) {
133-
std::array<u64x4, logn> aa;
134-
for(size_t j = 0; j < logn; j++) {
133+
std::array<u64x4, max_logn> aa;
134+
for(size_t j = 0; j < max_logn; j++) {
135135
for(size_t z = 0; z < K; z++) {
136136
aa[j][z] = buf[i + z][j].getr();
137137
}
@@ -145,7 +145,7 @@ namespace cp_algo::math {
145145
// Write results back: only first array needs to be written
146146
auto& first_buf = std::get<0>(std::forward_as_tuple(bufs...));
147147
const auto& first_aa = std::get<0>(aa_tuple);
148-
for(size_t j = 0; j < logn; j++) {
148+
for(size_t j = 0; j < max_logn; j++) {
149149
for(size_t z = 0; z < K; z++) {
150150
first_buf[i + z][j].setr((uint32_t)first_aa[j][z]);
151151
}
@@ -179,33 +179,40 @@ namespace cp_algo::math {
179179

180180
// Multiplies the two set power series `inpa` and `inpb` through on_rank_vectors and returns
// the product as a big_vector<base>.
// NOTE(review): this span is a rendered diff. A line following an `NNN-` marker belongs to the
// removed version (fixed `logn` bound); a line following an `NNN+` marker belongs to the added
// version, which dispatches on the input size via with_bit_floor and bounds every loop by
// lgn = bit_width(N) - 1 instead.
template<typename base>
181181
big_vector<base> subset_convolution(std::span<base> inpa, std::span<base> inpb) {
182-
auto outpa = on_rank_vectors([](auto &a, auto const& b) {
183-
std::decay_t<decltype(a)> res = {};
184-
const auto mod = base::mod();
185-
const auto imod = math::inv2(-mod);
186-
const auto r4 = u64x4() + uint64_t(-1) % mod + 1;
187-
for(size_t i = 0; i < logn; i++) {
188-
for(size_t j = 0; i + j + 1 < logn; j++) {
189-
res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j]));
190-
}
191-
if (i == logn / 2) {
192-
for(size_t k = logn - 2; k < logn; k++) {
193-
res[k] = res[k] >= base::modmod8() ? res[k] - base::modmod8() : res[k];
182+
big_vector<base> outpa;
183+
// with_bit_floor hands the size over as the compile-time constant N, so lgn is constexpr.
with_bit_floor(std::size(inpa), [&]<auto N>() {
184+
constexpr size_t lgn = std::bit_width(N) - 1;
185+
[[gnu::assume(lgn <= max_logn)]];
186+
outpa = on_rank_vectors([](auto &a, auto const& b) {
187+
std::decay_t<decltype(a)> res = {};
188+
const auto mod = base::mod();
189+
const auto imod = math::inv2(-mod);
190+
const auto modmod8 = base::modmod8();
191+
const auto r4 = u64x4() + uint64_t(-1) % mod + 1;
192+
// add(i) accumulates a[i] * b[j] into res[i + j + 1] for every j with i + j + 1 < lgn.
auto add = [&](size_t i) {
193+
if constexpr (lgn) for(size_t j = 0; i + j + 1 < lgn; j++) {
194+
res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j]));
194195
}
196+
};
197+
if constexpr (lgn) for(size_t i = 0; i < lgn / 2; i++) { add(i); }
198+
// Between the two halves of the accumulation the top two entries are conditionally reduced
// by modmod8 -- presumably to keep the 64-bit accumulators from overflowing; TODO confirm.
if constexpr (lgn >= 20) {
199+
res[lgn - 1] = res[lgn - 1] >= modmod8 ? res[lgn - 1] - modmod8 : res[lgn - 1];
200+
res[lgn - 2] = res[lgn - 2] >= modmod8 ? res[lgn - 2] - modmod8 : res[lgn - 2];
195201
}
202+
if constexpr (lgn) for(size_t i = lgn / 2; i < lgn; i++) { add(i); }
203+
// NOTE(review): `if constexpr (lgn)` is written twice on the next line; the second is redundant.
// The loop reduces each accumulator and stores the result back into a[k].
if constexpr (lgn) if constexpr (lgn) for(size_t k = 0; k < lgn; k++) {
204+
res[k] = montgomery_reduce(res[k], mod, imod);
205+
res[k] = montgomery_mul(res[k], r4, mod, imod);
206+
a[k] = res[k] >= mod ? res[k] - mod : res[k];
207+
}
208+
}, inpa, inpb);
209+
210+
// The terms involving inpa[0] / inpb[0] are patched in explicitly here -- presumably because
// on_rank_vectors leaves them out of its result; TODO confirm against its definition.
outpa[0] = inpa[0] * inpb[0];
211+
for(size_t i = 1; i < std::size(inpa); i++) {
212+
outpa[i] += inpa[i] * inpb[0] + inpa[0] * inpb[i];
196213
}
197-
for(size_t k = 0; k < logn; k++) {
198-
res[k] = montgomery_reduce(res[k], mod, imod);
199-
res[k] = montgomery_mul(res[k], r4, mod, imod);
200-
a[k] = res[k] >= mod ? res[k] - mod : res[k];
201-
}
202-
}, inpa, inpb);
203-
204-
outpa[0] = inpa[0] * inpb[0];
205-
for(size_t i = 1; i < std::size(inpa); i++) {
206-
outpa[i] += inpa[i] * inpb[0] + inpa[0] * inpb[i];
207-
}
208-
checkpoint("fix 0");
214+
checkpoint("fix 0");
215+
});
209216
return outpa;
210217
}
211218

@@ -215,11 +222,50 @@ namespace cp_algo::math {
215222
return big_vector<base>{1};
216223
}
217224
size_t N = std::size(inpa);
218-
auto out0 = subset_exp<base>(std::span(inpa).first(N / 2));
225+
auto out0 = subset_exp(std::span(inpa).first(N / 2));
219226
auto out1 = subset_convolution<base>(out0, std::span(inpa).last(N / 2));
220227
out0.insert(end(out0), begin(out1), end(out1));
228+
cp_algo::checkpoint("extend out");
221229
return out0;
222230
}
231+
232+
template<typename base>
233+
big_vector<big_vector<base>> subset_compose(big_vector<std::span<base>> fd, std::span<base> inpa) {
234+
if (size(inpa) == 1) {
235+
big_vector<big_vector<base>> res(size(fd), {base(0)});
236+
big_vector<base> pw(size(fd[0]), 1);
237+
for (size_t i = 1; i < size(fd[0]); i++) {
238+
pw[i] = pw[i - 1] * inpa[0];
239+
}
240+
for (size_t i = 0; i < size(fd); i++) {
241+
for (size_t j = 0; j < size(fd[i]); j++) {
242+
res[i][0] += pw[j] * fd[i][j];
243+
}
244+
}
245+
cp_algo::checkpoint("base case");
246+
return res;
247+
}
248+
size_t N = std::size(inpa);
249+
big_vector<base> fdk(size(fd[0]));
250+
for (size_t i = 0; i + 1 < size(fdk); i++) {
251+
fdk[i] = fd.back()[i + 1] * base(i + 1);
252+
}
253+
fd.push_back(fdk);
254+
cp_algo::checkpoint("fdk");
255+
auto deeper = subset_compose(fd, std::span(inpa).first(N / 2));
256+
for(size_t i = 0; i + 1 < size(fd); i++) {
257+
auto next = subset_convolution<base>(deeper[i + 1], std::span(inpa).last(N / 2));
258+
deeper[i].insert(end(deeper[i]), begin(next), end(next));
259+
}
260+
deeper.pop_back();
261+
cp_algo::checkpoint("combine");
262+
return deeper;
263+
}
264+
265+
template<typename base>
266+
big_vector<base> subset_compose(std::span<base> f, std::span<base> inpa) {
267+
return subset_compose(big_vector{f}, inpa)[0];
268+
}
223269
}
224270
#pragma GCC pop_options
225271
#endif // CP_ALGO_MATH_SUBSET_CONVOLUTION_HPP
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// @brief Polynomial Composite Set Power Series
2+
#define PROBLEM "https://judge.yosupo.jp/problem/polynomial_composite_set_power_series"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
#include <iostream>
7+
#include "blazingio/blazingio.min.hpp"
8+
#define CP_ALGO_CHECKPOINT
9+
#include "cp-algo/number_theory/modint.hpp"
10+
#include "cp-algo/math/subset_convolution.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
15+
// Modulus all arithmetic in this test is performed under.
const int mod = 998244353;
16+
// Modular integer type used for every coefficient read and written below.
using base = cp_algo::math::modint<mod>;
17+
18+
void solve() {
19+
size_t M, n;
20+
cin >> M >> n;
21+
size_t N = 1 << n;
22+
cp_algo::big_vector<base> f(M);
23+
for(auto &it: f) {cin >> it;}
24+
cp_algo::big_vector<base> a(N);
25+
for(auto &it: a) {cin >> it;}
26+
cp_algo::checkpoint("read");
27+
auto c = cp_algo::math::subset_compose<base>(f, a);
28+
for(auto &it: c) {cout << it << ' ';}
29+
cp_algo::checkpoint("write");
30+
cp_algo::checkpoint<1>();
31+
}
32+
33+
signed main() {
34+
//freopen("input.txt", "r", stdin);
35+
ios::sync_with_stdio(0);
36+
cin.tie(0);
37+
int t;
38+
t = 1;// cin >> t;
39+
while(t--) {
40+
solve();
41+
}
42+
}

0 commit comments

Comments
 (0)