1111#include < cstring>
1212CP_ALGO_SIMD_PRAGMA_PUSH
1313namespace cp_algo ::math {
14- const size_t logn = 20 ;
14+ const size_t max_logn = 20 ;
1515
1616 enum transform_dir { forw, inv };
1717
1818 template <auto N, transform_dir direction>
1919 inline void or_transform (auto &&a) {
20- [[gnu::assume (N <= 1ull << 30 )]];
20+ [[gnu::assume (N <= 1 << 30 )]];
2121 if constexpr (N <= 32 ) {
2222 for (size_t i = 1 ; i < N; i *= 2 ) {
2323 for (size_t j = 0 ; j < N; j += 2 * i) {
2424 for (size_t k = j; k < j + i; k++) {
25- for (size_t z = 0 ; z < logn ; z++) {
25+ for (size_t z = 0 ; z < max_logn ; z++) {
2626 if constexpr (direction == forw) {
2727 a[k + i][z] += a[k][z];
2828 } else {
@@ -37,8 +37,8 @@ namespace cp_algo::math {
3737 or_transform<half, direction>(&a[0 ]);
3838 or_transform<half, direction>(&a[half]);
3939 for (size_t i = 0 ; i < half; i++) {
40- #pragma GCC unroll logn
41- for (size_t z = 0 ; z < logn ; z++) {
40+ #pragma GCC unroll max_logn
41+ for (size_t z = 0 ; z < max_logn ; z++) {
4242 if constexpr (direction == forw) {
4343 a[i + half][z] += a[i][z];
4444 } else {
@@ -85,7 +85,7 @@ namespace cp_algo::math {
8585 // Create array buffers for each input
8686 auto create_buffers = [bottoms]<typename ... Args>(const Args&...) {
8787 return std::make_tuple (
88- big_vector<std::array<typename std::decay_t <Args>::value_type, logn >>(bottoms)...
88+ big_vector<std::array<typename std::decay_t <Args>::value_type, max_logn >>(bottoms)...
8989 );
9090 };
9191 auto buffers = std::apply (create_buffers, input_tuple);
@@ -130,8 +130,8 @@ namespace cp_algo::math {
130130 for (size_t i = 0 ; i < bottoms; i += K) {
131131 std::apply ([&](auto &... bufs) {
132132 auto extract_one = [&](auto & buf) {
133- std::array<u64x4, logn > aa;
134- for (size_t j = 0 ; j < logn ; j++) {
133+ std::array<u64x4, max_logn > aa;
134+ for (size_t j = 0 ; j < max_logn ; j++) {
135135 for (size_t z = 0 ; z < K; z++) {
136136 aa[j][z] = buf[i + z][j].getr ();
137137 }
@@ -145,7 +145,7 @@ namespace cp_algo::math {
145145 // Write results back: only first array needs to be written
146146 auto & first_buf = std::get<0 >(std::forward_as_tuple (bufs...));
147147 const auto & first_aa = std::get<0 >(aa_tuple);
148- for (size_t j = 0 ; j < logn ; j++) {
148+ for (size_t j = 0 ; j < max_logn ; j++) {
149149 for (size_t z = 0 ; z < K; z++) {
150150 first_buf[i + z][j].setr ((uint32_t )first_aa[j][z]);
151151 }
@@ -179,33 +179,40 @@ namespace cp_algo::math {
179179
180180 template <typename base>
181181 big_vector<base> subset_convolution (std::span<base> inpa, std::span<base> inpb) {
182- auto outpa = on_rank_vectors ([](auto &a, auto const & b) {
183- std::decay_t <decltype (a)> res = {};
184- const auto mod = base::mod ();
185- const auto imod = math::inv2 (-mod);
186- const auto r4 = u64x4 () + uint64_t (-1 ) % mod + 1 ;
187- for (size_t i = 0 ; i < logn; i++) {
188- for (size_t j = 0 ; i + j + 1 < logn; j++) {
189- res[i + j + 1 ] += (u64x4)_mm256_mul_epu32 (__m256i (a[i]), __m256i (b[j]));
190- }
191- if (i == logn / 2 ) {
192- for (size_t k = logn - 2 ; k < logn; k++) {
193- res[k] = res[k] >= base::modmod8 () ? res[k] - base::modmod8 () : res[k];
182+ big_vector<base> outpa;
183+ with_bit_floor (std::size (inpa), [&]<auto N>() {
184+ constexpr size_t lgn = std::bit_width (N) - 1 ;
185+ [[gnu::assume (lgn <= max_logn)]];
186+ outpa = on_rank_vectors ([](auto &a, auto const & b) {
187+ std::decay_t <decltype (a)> res = {};
188+ const auto mod = base::mod ();
189+ const auto imod = math::inv2 (-mod);
190+ const auto modmod8 = base::modmod8 ();
191+ const auto r4 = u64x4 () + uint64_t (-1 ) % mod + 1 ;
192+ auto add = [&](size_t i) {
193+ if constexpr (lgn) for (size_t j = 0 ; i + j + 1 < lgn; j++) {
194+ res[i + j + 1 ] += (u64x4)_mm256_mul_epu32 (__m256i (a[i]), __m256i (b[j]));
194195 }
196+ };
197+ if constexpr (lgn) for (size_t i = 0 ; i < lgn / 2 ; i++) { add (i); }
198+ if constexpr (lgn >= 20 ) {
199+ res[lgn - 1 ] = res[lgn - 1 ] >= modmod8 ? res[lgn - 1 ] - modmod8 : res[lgn - 1 ];
200+ res[lgn - 2 ] = res[lgn - 2 ] >= modmod8 ? res[lgn - 2 ] - modmod8 : res[lgn - 2 ];
195201 }
202+ if constexpr (lgn) for (size_t i = lgn / 2 ; i < lgn; i++) { add (i); }
203+ if constexpr (lgn) if constexpr (lgn) for (size_t k = 0 ; k < lgn; k++) {
204+ res[k] = montgomery_reduce (res[k], mod, imod);
205+ res[k] = montgomery_mul (res[k], r4, mod, imod);
206+ a[k] = res[k] >= mod ? res[k] - mod : res[k];
207+ }
208+ }, inpa, inpb);
209+
210+ outpa[0 ] = inpa[0 ] * inpb[0 ];
211+ for (size_t i = 1 ; i < std::size (inpa); i++) {
212+ outpa[i] += inpa[i] * inpb[0 ] + inpa[0 ] * inpb[i];
196213 }
197- for (size_t k = 0 ; k < logn; k++) {
198- res[k] = montgomery_reduce (res[k], mod, imod);
199- res[k] = montgomery_mul (res[k], r4, mod, imod);
200- a[k] = res[k] >= mod ? res[k] - mod : res[k];
201- }
202- }, inpa, inpb);
203-
204- outpa[0 ] = inpa[0 ] * inpb[0 ];
205- for (size_t i = 1 ; i < std::size (inpa); i++) {
206- outpa[i] += inpa[i] * inpb[0 ] + inpa[0 ] * inpb[i];
207- }
208- checkpoint (" fix 0" );
214+ checkpoint (" fix 0" );
215+ });
209216 return outpa;
210217 }
211218
@@ -215,11 +222,50 @@ namespace cp_algo::math {
215222 return big_vector<base>{1 };
216223 }
217224 size_t N = std::size (inpa);
218- auto out0 = subset_exp<base> (std::span (inpa).first (N / 2 ));
225+ auto out0 = subset_exp (std::span (inpa).first (N / 2 ));
219226 auto out1 = subset_convolution<base>(out0, std::span (inpa).last (N / 2 ));
220227 out0.insert (end (out0), begin (out1), end (out1));
228+ cp_algo::checkpoint (" extend out" );
221229 return out0;
222230 }
231+
232+ template <typename base>
233+ big_vector<big_vector<base>> subset_compose (big_vector<std::span<base>> fd, std::span<base> inpa) {
234+ if (size (inpa) == 1 ) {
235+ big_vector<big_vector<base>> res (size (fd), {base (0 )});
236+ big_vector<base> pw (size (fd[0 ]), 1 );
237+ for (size_t i = 1 ; i < size (fd[0 ]); i++) {
238+ pw[i] = pw[i - 1 ] * inpa[0 ];
239+ }
240+ for (size_t i = 0 ; i < size (fd); i++) {
241+ for (size_t j = 0 ; j < size (fd[i]); j++) {
242+ res[i][0 ] += pw[j] * fd[i][j];
243+ }
244+ }
245+ cp_algo::checkpoint (" base case" );
246+ return res;
247+ }
248+ size_t N = std::size (inpa);
249+ big_vector<base> fdk (size (fd[0 ]));
250+ for (size_t i = 0 ; i + 1 < size (fdk); i++) {
251+ fdk[i] = fd.back ()[i + 1 ] * base (i + 1 );
252+ }
253+ fd.push_back (fdk);
254+ cp_algo::checkpoint (" fdk" );
255+ auto deeper = subset_compose (fd, std::span (inpa).first (N / 2 ));
256+ for (size_t i = 0 ; i + 1 < size (fd); i++) {
257+ auto next = subset_convolution<base>(deeper[i + 1 ], std::span (inpa).last (N / 2 ));
258+ deeper[i].insert (end (deeper[i]), begin (next), end (next));
259+ }
260+ deeper.pop_back ();
261+ cp_algo::checkpoint (" combine" );
262+ return deeper;
263+ }
264+
265+ template <typename base>
266+ big_vector<base> subset_compose (std::span<base> f, std::span<base> inpa) {
267+ return subset_compose (big_vector{f}, inpa)[0 ];
268+ }
223269}
224270#pragma GCC pop_options
225271#endif // CP_ALGO_MATH_SUBSET_CONVOLUTION_HPP
0 commit comments