@@ -19,10 +19,10 @@ namespace cp_algo::math {
1919 inline void xor_transform (auto &&a) {
2020 [[gnu::assume (N <= 1 << 30 )]];
2121 if constexpr (N <= 32 ) {
22- for (size_t i = 1 ; i < N; i *= 2 ) {
23- for (size_t j = 0 ; j < N; j += 2 * i) {
24- for (size_t k = j; k < j + i; k++) {
25- for (size_t z = 0 ; z < max_logn; z++) {
22+ for (size_t i = 1 ; i < N; i *= 2 ) {
23+ for (size_t j = 0 ; j < N; j += 2 * i) {
24+ for (size_t k = j; k < j + i; k++) {
25+ for (size_t z = 0 ; z < max_logn; z++) {
2626 auto x = a[k][z] + a[k + i][z];
2727 auto y = a[k][z] - a[k + i][z];
2828 a[k][z] = x;
@@ -32,18 +32,38 @@ namespace cp_algo::math {
3232 }
3333 }
3434 } else {
35- constexpr auto half = N / 2 ;
36- xor_transform<half, direction>(&a[0 ]);
37- xor_transform<half, direction>(&a[half]);
38- for (size_t i = 0 ; i < half; i++) {
35+ auto add = [&](auto &a, auto &b) __attribute__ ((always_inline)) {
36+ auto x = a + b, y = a - b;
37+ a = x, b = y;
38+ };
39+ constexpr auto quar = N / 4 ;
40+
41+ for (size_t i = 0 ; i < (size_t )quar; i++) {
42+ auto x0 = a[i + (size_t )quar * 0 ];
43+ auto x1 = a[i + (size_t )quar * 1 ];
44+ auto x2 = a[i + (size_t )quar * 2 ];
45+ auto x3 = a[i + (size_t )quar * 3 ];
46+
47+ #pragma GCC unroll max_logn
48+ for (size_t z = 0 ; z < max_logn; z++) {
49+ add (x0[z], x2[z]);
50+ add (x1[z], x3[z]);
51+ }
3952 #pragma GCC unroll max_logn
40- for (size_t z = 0 ; z < max_logn; z++) {
41- auto x = a[i][z] + a[i + half][z];
42- auto y = a[i][z] - a[i + half][z];
43- a[i][z] = x;
44- a[i + half][z] = y;
53+ for (size_t z = 0 ; z < max_logn; z++) {
54+ add (x0[z], x1[z]);
55+ add (x2[z], x3[z]);
4556 }
57+
58+ a[i + (size_t )quar * 0 ] = x0;
59+ a[i + (size_t )quar * 1 ] = x1;
60+ a[i + (size_t )quar * 2 ] = x2;
61+ a[i + (size_t )quar * 3 ] = x3;
4662 }
63+ xor_transform<quar, direction>(&a[quar * 0 ]);
64+ xor_transform<quar, direction>(&a[quar * 1 ]);
65+ xor_transform<quar, direction>(&a[quar * 2 ]);
66+ xor_transform<quar, direction>(&a[quar * 3 ]);
4767 }
4868 }
4969
@@ -183,9 +203,9 @@ namespace cp_algo::math {
183203 }
184204
185205 template <typename base>
186- big_vector<base> subset_convolution (std::span<base> inpa , std::span<base> inpb ) {
206+ big_vector<base> subset_convolution (std::span<base> f , std::span<base> g ) {
187207 big_vector<base> outpa;
188- with_bit_floor (std::size (inpa ), [&]<auto N>() {
208+ with_bit_floor (std::size (f ), [&]<auto N>() {
189209 constexpr size_t lgn = std::bit_width (N) - 1 ;
190210 [[gnu::assume (lgn <= max_logn)]];
191211 outpa = on_rank_vectors ([](auto &a, auto const & b) {
@@ -204,56 +224,56 @@ namespace cp_algo::math {
204224 res[k] = montgomery_mul (res[k], r4, mod, imod);
205225 a[k] = res[k] >= mod ? res[k] - mod : res[k];
206226 }
207- }, inpa, inpb );
227+ }, f, g );
208228
209- outpa[0 ] = inpa [0 ] * inpb [0 ];
210- for (size_t i = 1 ; i < std::size (inpa ); i++) {
211- outpa[i] += inpa [i] * inpb [0 ] + inpa [0 ] * inpb [i];
229+ outpa[0 ] = f [0 ] * g [0 ];
230+ for (size_t i = 1 ; i < std::size (f ); i++) {
231+ outpa[i] += f [i] * g [0 ] + f [0 ] * g [i];
212232 }
213233 checkpoint (" fix 0" );
214234 });
215235 return outpa;
216236 }
217237
218238 template <typename base>
219- big_vector<base> subset_exp (std::span<base> inpa ) {
220- if (size (inpa ) == 1 ) {
239+ big_vector<base> subset_exp (std::span<base> g ) {
240+ if (size (g ) == 1 ) {
221241 return big_vector<base>{1 };
222242 }
223- size_t N = std::size (inpa );
224- auto out0 = subset_exp (std::span (inpa ).first (N / 2 ));
225- auto out1 = subset_convolution<base>(out0, std::span (inpa ).last (N / 2 ));
243+ size_t N = std::size (g );
244+ auto out0 = subset_exp (std::span (g ).first (N / 2 ));
245+ auto out1 = subset_convolution<base>(out0, std::span (g ).last (N / 2 ));
226246 out0.insert (end (out0), begin (out1), end (out1));
227247 cp_algo::checkpoint (" extend out" );
228248 return out0;
229249 }
230250
231251 template <typename base>
232- big_vector<big_vector<base>> subset_compose (big_vector<std::span<base>> fd, std::span<base> inpa) {
233- if (size (inpa) == 1 ) {
234- big_vector<big_vector<base>> res (size (fd), {base (0 )});
235- big_vector<base> pw (size (fd[0 ]), 1 );
236- for (size_t i = 1 ; i < size (fd[0 ]); i++) {
237- pw[i] = pw[i - 1 ] * inpa[0 ];
252+ big_vector<big_vector<base>> subset_compose (std::span<base> f, std::span<base> g, size_t n) {
253+ if (size (g) == 1 ) {
254+ size_t M = size (f);
255+ big_vector res (n, big_vector<base>{0 });
256+ big_vector<base> pw (M+1 );
257+ pw[0 ] = 1 ;
258+ for (size_t j = 1 ; j < M; j++) {
259+ pw[j] = pw[j - 1 ] * g[0 ];
238260 }
239- for (size_t i = 0 ; i < size (fd); i++) {
240- for (size_t j = 0 ; j < size (fd[i]); j++) {
241- res[i][0 ] += pw[j] * fd[i][j];
261+ for (size_t i = 0 ; i < n; i++) {
262+ for (size_t j = 0 ; j < M; j++) {
263+ res[i][0 ] += pw[j] * f[j];
264+ }
265+ for (size_t j = M; j > i; j--) {
266+ pw[j] = pw[j - 1 ] * base (j);
242267 }
268+ pw[i] = 0 ;
243269 }
244270 cp_algo::checkpoint (" base case" );
245271 return res;
246272 }
247- size_t N = std::size (inpa);
248- big_vector<base> fdk (size (fd[0 ]));
249- for (size_t i = 0 ; i + 1 < size (fdk); i++) {
250- fdk[i] = fd.back ()[i + 1 ] * base (i + 1 );
251- }
252- fd.push_back (fdk);
253- cp_algo::checkpoint (" fdk" );
254- auto deeper = subset_compose (fd, std::span (inpa).first (N / 2 ));
255- for (size_t i = 0 ; i + 1 < size (fd); i++) {
256- auto next = subset_convolution<base>(deeper[i + 1 ], std::span (inpa).last (N / 2 ));
273+ size_t N = std::size (g);
274+ auto deeper = subset_compose (f, std::span (g).first (N / 2 ), n + 1 );
275+ for (size_t i = 0 ; i + 1 < size (deeper); i++) {
276+ auto next = subset_convolution<base>(deeper[i + 1 ], std::span (g).last (N / 2 ));
257277 deeper[i].insert (end (deeper[i]), begin (next), end (next));
258278 }
259279 deeper.pop_back ();
@@ -262,8 +282,59 @@ namespace cp_algo::math {
262282 }
263283
264284 template <typename base>
265- big_vector<base> subset_compose (std::span<base> f, std::span<base> inpa) {
266- return subset_compose (big_vector{f}, inpa)[0 ];
285+ big_vector<base> subset_compose (std::span<base> f, std::span<base> g) {
286+ return subset_compose (f, g, 1 )[0 ];
287+ }
288+
289+ // Transpose of f -> f * g = h
290+ template <typename base>
291+ big_vector<base> subset_conv_transpose (std::span<base> h, std::span<base> g) {
292+ std::ranges::reverse (h);
293+ auto res = subset_convolution<base>(h, g);
294+ std::ranges::reverse (h);
295+ std::ranges::reverse (res);
296+ return res;
297+ }
298+
299+ template <typename base>
300+ big_vector<base> subset_power_projection (big_vector<big_vector<base>> &&fg, std::span<base> g, size_t M) {
301+ if (size (g) == 1 ) {
302+ size_t n = size (fg);
303+ big_vector<base> res (M);
304+ big_vector<base> pw (M+1 );
305+ pw[0 ] = 1 ;
306+ for (size_t j = 1 ; j < M; j++) {
307+ pw[j] = pw[j - 1 ] * g[0 ];
308+ }
309+ for (size_t i = 0 ; i < size (fg); i++) {
310+ for (size_t j = 0 ; j < M; j++) {
311+ res[j] += pw[j] * fg[i][0 ];
312+ }
313+ for (size_t j = M; j > i; j--) {
314+ pw[j] = pw[j - 1 ] * base (j);
315+ }
316+ pw[i] = 0 ;
317+ }
318+ cp_algo::checkpoint (" base case" );
319+ return res;
320+ }
321+ size_t N = std::size (g);
322+ fg.emplace_back (N / 2 );
323+ for (auto && [i, h]: fg | std::views::enumerate | std::views::reverse | std::views::drop (1 )) {
324+ auto prev = subset_conv_transpose<base>(std::span (h).last (N / 2 ), std::span (g).last (N / 2 ));
325+ for (size_t j = 0 ; j < N / 2 ; j++) {
326+ fg[i + 1 ][j] += prev[j];
327+ }
328+ fg[i + 1 ].resize (N / 2 );
329+ }
330+ fg[0 ].resize (N / 2 );
331+ cp_algo::checkpoint (" decombine" );
332+ return subset_power_projection (std::move (fg), std::span (g).first (N / 2 ), M);
333+ }
334+
335+ template <typename base>
336+ big_vector<base> subset_power_projection (std::span<base> g, std::span<base> w, size_t M) {
337+ return subset_power_projection ({{begin (w), end (w)}}, g, M);
267338 }
268339}
269340#pragma GCC pop_options
0 commit comments