Skip to content

Commit 362e385

Browse files
committed
subset power projection + test
1 parent dcd6932 commit 362e385

File tree

2 files changed

+158
-45
lines changed

2 files changed

+158
-45
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 116 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ namespace cp_algo::math {
1919
inline void xor_transform(auto &&a) {
2020
[[gnu::assume(N <= 1 << 30)]];
2121
if constexpr (N <= 32) {
22-
for(size_t i = 1; i < N; i *= 2) {
23-
for(size_t j = 0; j < N; j += 2 * i) {
24-
for(size_t k = j; k < j + i; k++) {
25-
for(size_t z = 0; z < max_logn; z++) {
22+
for (size_t i = 1; i < N; i *= 2) {
23+
for (size_t j = 0; j < N; j += 2 * i) {
24+
for (size_t k = j; k < j + i; k++) {
25+
for (size_t z = 0; z < max_logn; z++) {
2626
auto x = a[k][z] + a[k + i][z];
2727
auto y = a[k][z] - a[k + i][z];
2828
a[k][z] = x;
@@ -32,18 +32,38 @@ namespace cp_algo::math {
3232
}
3333
}
3434
} else {
35-
constexpr auto half = N / 2;
36-
xor_transform<half, direction>(&a[0]);
37-
xor_transform<half, direction>(&a[half]);
38-
for (size_t i = 0; i < half; i++) {
35+
auto add = [&](auto &a, auto &b) __attribute__((always_inline)) {
36+
auto x = a + b, y = a - b;
37+
a = x, b = y;
38+
};
39+
constexpr auto quar = N / 4;
40+
41+
for (size_t i = 0; i < (size_t)quar; i++) {
42+
auto x0 = a[i + (size_t)quar * 0];
43+
auto x1 = a[i + (size_t)quar * 1];
44+
auto x2 = a[i + (size_t)quar * 2];
45+
auto x3 = a[i + (size_t)quar * 3];
46+
47+
#pragma GCC unroll max_logn
48+
for (size_t z = 0; z < max_logn; z++) {
49+
add(x0[z], x2[z]);
50+
add(x1[z], x3[z]);
51+
}
3952
#pragma GCC unroll max_logn
40-
for(size_t z = 0; z < max_logn; z++) {
41-
auto x = a[i][z] + a[i + half][z];
42-
auto y = a[i][z] - a[i + half][z];
43-
a[i][z] = x;
44-
a[i + half][z] = y;
53+
for (size_t z = 0; z < max_logn; z++) {
54+
add(x0[z], x1[z]);
55+
add(x2[z], x3[z]);
4556
}
57+
58+
a[i + (size_t)quar * 0] = x0;
59+
a[i + (size_t)quar * 1] = x1;
60+
a[i + (size_t)quar * 2] = x2;
61+
a[i + (size_t)quar * 3] = x3;
4662
}
63+
xor_transform<quar, direction>(&a[quar * 0]);
64+
xor_transform<quar, direction>(&a[quar * 1]);
65+
xor_transform<quar, direction>(&a[quar * 2]);
66+
xor_transform<quar, direction>(&a[quar * 3]);
4767
}
4868
}
4969

@@ -183,9 +203,9 @@ namespace cp_algo::math {
183203
}
184204

185205
template<typename base>
186-
big_vector<base> subset_convolution(std::span<base> inpa, std::span<base> inpb) {
206+
big_vector<base> subset_convolution(std::span<base> f, std::span<base> g) {
187207
big_vector<base> outpa;
188-
with_bit_floor(std::size(inpa), [&]<auto N>() {
208+
with_bit_floor(std::size(f), [&]<auto N>() {
189209
constexpr size_t lgn = std::bit_width(N) - 1;
190210
[[gnu::assume(lgn <= max_logn)]];
191211
outpa = on_rank_vectors([](auto &a, auto const& b) {
@@ -204,56 +224,56 @@ namespace cp_algo::math {
204224
res[k] = montgomery_mul(res[k], r4, mod, imod);
205225
a[k] = res[k] >= mod ? res[k] - mod : res[k];
206226
}
207-
}, inpa, inpb);
227+
}, f, g);
208228

209-
outpa[0] = inpa[0] * inpb[0];
210-
for(size_t i = 1; i < std::size(inpa); i++) {
211-
outpa[i] += inpa[i] * inpb[0] + inpa[0] * inpb[i];
229+
outpa[0] = f[0] * g[0];
230+
for(size_t i = 1; i < std::size(f); i++) {
231+
outpa[i] += f[i] * g[0] + f[0] * g[i];
212232
}
213233
checkpoint("fix 0");
214234
});
215235
return outpa;
216236
}
217237

218238
template<typename base>
219-
big_vector<base> subset_exp(std::span<base> inpa) {
220-
if (size(inpa) == 1) {
239+
big_vector<base> subset_exp(std::span<base> g) {
240+
if (size(g) == 1) {
221241
return big_vector<base>{1};
222242
}
223-
size_t N = std::size(inpa);
224-
auto out0 = subset_exp(std::span(inpa).first(N / 2));
225-
auto out1 = subset_convolution<base>(out0, std::span(inpa).last(N / 2));
243+
size_t N = std::size(g);
244+
auto out0 = subset_exp(std::span(g).first(N / 2));
245+
auto out1 = subset_convolution<base>(out0, std::span(g).last(N / 2));
226246
out0.insert(end(out0), begin(out1), end(out1));
227247
cp_algo::checkpoint("extend out");
228248
return out0;
229249
}
230250

231251
template<typename base>
232-
big_vector<big_vector<base>> subset_compose(big_vector<std::span<base>> fd, std::span<base> inpa) {
233-
if (size(inpa) == 1) {
234-
big_vector<big_vector<base>> res(size(fd), {base(0)});
235-
big_vector<base> pw(size(fd[0]), 1);
236-
for (size_t i = 1; i < size(fd[0]); i++) {
237-
pw[i] = pw[i - 1] * inpa[0];
252+
big_vector<big_vector<base>> subset_compose(std::span<base> f, std::span<base> g, size_t n) {
253+
if (size(g) == 1) {
254+
size_t M = size(f);
255+
big_vector res(n, big_vector<base>{0});
256+
big_vector<base> pw(M+1);
257+
pw[0] = 1;
258+
for (size_t j = 1; j < M; j++) {
259+
pw[j] = pw[j - 1] * g[0];
238260
}
239-
for (size_t i = 0; i < size(fd); i++) {
240-
for (size_t j = 0; j < size(fd[i]); j++) {
241-
res[i][0] += pw[j] * fd[i][j];
261+
for (size_t i = 0; i < n; i++) {
262+
for (size_t j = 0; j < M; j++) {
263+
res[i][0] += pw[j] * f[j];
264+
}
265+
for (size_t j = M; j > i; j--) {
266+
pw[j] = pw[j - 1] * base(j);
242267
}
268+
pw[i] = 0;
243269
}
244270
cp_algo::checkpoint("base case");
245271
return res;
246272
}
247-
size_t N = std::size(inpa);
248-
big_vector<base> fdk(size(fd[0]));
249-
for (size_t i = 0; i + 1 < size(fdk); i++) {
250-
fdk[i] = fd.back()[i + 1] * base(i + 1);
251-
}
252-
fd.push_back(fdk);
253-
cp_algo::checkpoint("fdk");
254-
auto deeper = subset_compose(fd, std::span(inpa).first(N / 2));
255-
for(size_t i = 0; i + 1 < size(fd); i++) {
256-
auto next = subset_convolution<base>(deeper[i + 1], std::span(inpa).last(N / 2));
273+
size_t N = std::size(g);
274+
auto deeper = subset_compose(f, std::span(g).first(N / 2), n + 1);
275+
for(size_t i = 0; i + 1 < size(deeper); i++) {
276+
auto next = subset_convolution<base>(deeper[i + 1], std::span(g).last(N / 2));
257277
deeper[i].insert(end(deeper[i]), begin(next), end(next));
258278
}
259279
deeper.pop_back();
@@ -262,8 +282,59 @@ namespace cp_algo::math {
262282
}
263283

264284
template<typename base>
265-
big_vector<base> subset_compose(std::span<base> f, std::span<base> inpa) {
266-
return subset_compose(big_vector{f}, inpa)[0];
285+
big_vector<base> subset_compose(std::span<base> f, std::span<base> g) {
286+
return subset_compose(f, g, 1)[0];
287+
}
288+
289+
// Transpose of f -> f * g = h
290+
template<typename base>
291+
big_vector<base> subset_conv_transpose(std::span<base> h, std::span<base> g) {
292+
std::ranges::reverse(h);
293+
auto res = subset_convolution<base>(h, g);
294+
std::ranges::reverse(h);
295+
std::ranges::reverse(res);
296+
return res;
297+
}
298+
299+
template<typename base>
300+
big_vector<base> subset_power_projection(big_vector<big_vector<base>> &&fg, std::span<base> g, size_t M) {
301+
if (size(g) == 1) {
302+
size_t n = size(fg);
303+
big_vector<base> res(M);
304+
big_vector<base> pw(M+1);
305+
pw[0] = 1;
306+
for (size_t j = 1; j < M; j++) {
307+
pw[j] = pw[j - 1] * g[0];
308+
}
309+
for (size_t i = 0; i < size(fg); i++) {
310+
for (size_t j = 0; j < M; j++) {
311+
res[j] += pw[j] * fg[i][0];
312+
}
313+
for (size_t j = M; j > i; j--) {
314+
pw[j] = pw[j - 1] * base(j);
315+
}
316+
pw[i] = 0;
317+
}
318+
cp_algo::checkpoint("base case");
319+
return res;
320+
}
321+
size_t N = std::size(g);
322+
fg.emplace_back(N / 2);
323+
for(auto&& [i, h]: fg | std::views::enumerate | std::views::reverse | std::views::drop(1)) {
324+
auto prev = subset_conv_transpose<base>(std::span(h).last(N / 2), std::span(g).last(N / 2));
325+
for (size_t j = 0; j < N / 2; j++) {
326+
fg[i + 1][j] += prev[j];
327+
}
328+
fg[i + 1].resize(N / 2);
329+
}
330+
fg[0].resize(N / 2);
331+
cp_algo::checkpoint("decombine");
332+
return subset_power_projection(std::move(fg), std::span(g).first(N / 2), M);
333+
}
334+
335+
template<typename base>
336+
big_vector<base> subset_power_projection(std::span<base> g, std::span<base> w, size_t M) {
337+
return subset_power_projection({{begin(w), end(w)}}, g, M);
267338
}
268339
}
269340
#pragma GCC pop_options
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// @brief Power Projection of Set Power Series
2+
#define PROBLEM "https://judge.yosupo.jp/problem/power_projection_of_set_power_series"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
#include <iostream>
7+
#include "blazingio/blazingio.min.hpp"
8+
#define CP_ALGO_CHECKPOINT
9+
#include "cp-algo/number_theory/modint.hpp"
10+
#include "cp-algo/math/subset_convolution.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
15+
const int mod = 998244353;
16+
using base = cp_algo::math::modint<mod>;
17+
18+
void solve() {
19+
uint32_t n, M;
20+
cin >> n >> M;
21+
size_t N = 1 << n;
22+
cp_algo::big_vector<base> g(N);
23+
for(auto &it: g) {cin >> it;}
24+
cp_algo::big_vector<base> w(N);
25+
for(auto &it: w) {cin >> it;}
26+
cp_algo::checkpoint("read");
27+
auto c = cp_algo::math::subset_power_projection<base>(g, w, M);
28+
for(auto &it: c) {cout << it << ' ';}
29+
cp_algo::checkpoint("write");
30+
cp_algo::checkpoint<1>();
31+
}
32+
33+
signed main() {
34+
//freopen("input.txt", "r", stdin);
35+
ios::sync_with_stdio(0);
36+
cin.tie(0);
37+
int t;
38+
t = 1;// cin >> t;
39+
while(t--) {
40+
solve();
41+
}
42+
}

0 commit comments

Comments
 (0)