@@ -71,10 +71,8 @@ namespace cp_algo::math {
7171 // Create tuple of input references once
7272 auto input_tuple = std::forward_as_tuple (inputs...);
7373 auto const & first_input = std::get<0 >(input_tuple);
74-
75- auto out = first_input;
7674 using base = std::decay_t <decltype (first_input[0 ])>;
77- std::ranges::fill (out, base ( 0 ));
75+ big_vector<base> out ( std::size (first_input ));
7876
7977 auto N = std::size (first_input);
8078 constexpr size_t K = 4 ;
@@ -180,7 +178,7 @@ namespace cp_algo::math {
180178 }
181179
182180 template <typename base>
183- auto subset_convolution (auto const & inpa, auto const & inpb) {
181+ big_vector<base> subset_convolution (auto const & inpa, auto const & inpb) {
184182 auto outpa = on_rank_vectors ([](auto &a, auto const & b) {
185183 std::decay_t <decltype (a)> res = {};
186184 const auto mod = base::mod ();
@@ -212,36 +210,15 @@ namespace cp_algo::math {
212210 }
213211
214212 template <typename base>
215- auto subset_exp (auto const & inpa) {
216- auto outpa = on_rank_vectors ([](auto &p) {
217- std::decay_t <decltype (p)> q = {};
218- const auto mod = base::mod ();
219- const auto imod = math::inv2 (-mod);
220- const auto r2 = uint32_t (-1 ) % mod + 1 ;
221- const auto r4 = uint64_t (-1 ) % mod + 1 ;
222- static const auto invs = [&]() {
223- std::array<uint64_t , logn> invs;
224- for (size_t i = 0 ; i < logn; i++) {
225- invs[i] = (base (i + 1 ).inv () * base (r4)).getr ();
226- }
227- return invs;
228- }();
229- for (size_t k = 0 ; k < logn; k++) {
230- q[k] = p[k] = montgomery_mul (p[k], u64x4 () + (k + 1 ) * r2 % mod, mod, imod);
231- for (size_t i = 0 ; i < k; i++) {
232- q[k] += (u64x4)_mm256_mul_epu32 ((__m256i)p[i], (__m256i)q[k - i - 1 ]);
233- if (i == logn / 2 ) {
234- q[k] = q[k] >= base::modmod8 () ? q[k] - base::modmod8 () : q[k];
235- }
236- }
237- q[k] = montgomery_reduce (q[k], mod, imod);
238- q[k] = montgomery_mul (q[k], u64x4 () + invs[k], mod, imod);
239- q[k] = q[k] >= mod ? q[k] - mod : q[k];
240- }
241- p = q;
242- }, inpa);
243- outpa[0 ] = base (1 );
244- return outpa;
213+ big_vector<base> subset_exp (auto const & inpa) {
214+ if (size (inpa) == 1 ) {
215+ return big_vector<base>{1 };
216+ }
217+ size_t N = std::size (inpa);
218+ auto out0 = subset_exp<base>(std::span (inpa).first (N / 2 ));
219+ auto out1 = subset_convolution<base>(out0, std::span (inpa).last (N / 2 ));
220+ out0.insert (end (out0), begin (out1), end (out1));
221+ return out0;
245222 }
246223}
247224#pragma GCC pop_options
0 commit comments