Skip to content

Commit 0ae96a7

Browse files
committed
Better subset exp
1 parent 2bb08b9 commit 0ae96a7

File tree

1 file changed

+11
-34
lines changed

1 file changed

+11
-34
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 11 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -71,10 +71,8 @@ namespace cp_algo::math {
7171
// Create tuple of input references once
7272
auto input_tuple = std::forward_as_tuple(inputs...);
7373
auto const& first_input = std::get<0>(input_tuple);
74-
75-
auto out = first_input;
7674
using base = std::decay_t<decltype(first_input[0])>;
77-
std::ranges::fill(out, base(0));
75+
big_vector<base> out(std::size(first_input));
7876

7977
auto N = std::size(first_input);
8078
constexpr size_t K = 4;
@@ -180,7 +178,7 @@ namespace cp_algo::math {
180178
}
181179

182180
template<typename base>
183-
auto subset_convolution(auto const& inpa, auto const& inpb) {
181+
big_vector<base> subset_convolution(auto const& inpa, auto const& inpb) {
184182
auto outpa = on_rank_vectors([](auto &a, auto const& b) {
185183
std::decay_t<decltype(a)> res = {};
186184
const auto mod = base::mod();
@@ -212,36 +210,15 @@ namespace cp_algo::math {
212210
}
213211

// Diff hunk for subset_exp<base>. Lines whose marker ends in '-' belong to the
// removed implementation; lines whose marker ends in '+' belong to its
// replacement. The template header above and the closing brace at the end are
// unchanged context shared by both versions.
214212
template<typename base>
// --- Removed version -------------------------------------------------------
// Runs the lambda below through on_rank_vectors over inpa, then overwrites
// element 0 of the result with base(1) and returns it.
215-
auto subset_exp(auto const& inpa) {
216-
auto outpa = on_rank_vectors([](auto &p) {
// q is the working result; it is copied back into p at the end of the lambda.
217-
std::decay_t<decltype(p)> q = {};
218-
const auto mod = base::mod();
219-
const auto imod = math::inv2(-mod);
// r2 = ((2^32 - 1) mod `mod`) + 1 and r4 = ((2^64 - 1) mod `mod`) + 1;
// presumably the Montgomery factors R and R^2 for R = 2^32 (TODO confirm).
220-
const auto r2 = uint32_t(-1) % mod + 1;
221-
const auto r4 = uint64_t(-1) % mod + 1;
// Function-local static: the table is built once, on the first invocation.
// invs[i] holds the raw representation of (i + 1)^{-1} * r4.
222-
static const auto invs = [&]() {
223-
std::array<uint64_t, logn> invs;
224-
for(size_t i = 0; i < logn; i++) {
225-
invs[i] = (base(i + 1).inv() * base(r4)).getr();
226-
}
227-
return invs;
228-
}();
229-
for(size_t k = 0; k < logn; k++) {
// p[k] is replaced by montgomery_mul(p[k], (k + 1) * r2 % mod) and q[k] starts
// from that same value; presumably this scales p[k] by (k + 1) (TODO confirm
// against montgomery_mul).
230-
q[k] = p[k] = montgomery_mul(p[k], u64x4() + (k + 1) * r2 % mod, mod, imod);
// Accumulate the unreduced lane-wise products p[i] * q[k - i - 1] for i < k.
231-
for(size_t i = 0; i < k; i++) {
232-
q[k] += (u64x4)_mm256_mul_epu32((__m256i)p[i], (__m256i)q[k - i - 1]);
// One conditional subtraction of base::modmod8() at i == logn / 2; presumably
// keeps the 64-bit accumulator from overflowing (TODO confirm the bound).
233-
if (i == logn / 2) {
234-
q[k] = q[k] >= base::modmod8() ? q[k] - base::modmod8() : q[k];
235-
}
236-
}
// Reduce the accumulator, multiply by the precomputed invs[k], then bring the
// value below mod with a final conditional subtraction.
237-
q[k] = montgomery_reduce(q[k], mod, imod);
238-
q[k] = montgomery_mul(q[k], u64x4() + invs[k], mod, imod);
239-
q[k] = q[k] >= mod ? q[k] - mod : q[k];
240-
}
241-
p = q;
242-
}, inpa);
243-
outpa[0] = base(1);
244-
return outpa;
// --- Replacement version ---------------------------------------------------
// Recursive halving: a size-1 input yields {1}; otherwise the result for the
// first N / 2 elements is computed recursively, subset-convolved with the last
// N / 2 elements of inpa, and the convolution is appended to it.
// NOTE(review): an empty input has size 0 != 1 and N / 2 == 0, so the recursion
// repeats on an empty span with no base case. For odd N > 1, first(N / 2) and
// last(N / 2) together skip the middle element. Presumably callers always pass
// a power-of-two length >= 1 -- TODO confirm.
213+
big_vector<base> subset_exp(auto const& inpa) {
// Base case: the value of inpa[0] is not read; the result is always {1}.
214+
if (size(inpa) == 1) {
215+
return big_vector<base>{1};
216+
}
217+
size_t N = std::size(inpa);
// out0: result of subset_exp on the first half of inpa.
218+
auto out0 = subset_exp<base>(std::span(inpa).first(N / 2));
// out1: subset convolution of out0 with the second half of inpa.
219+
auto out1 = subset_convolution<base>(out0, std::span(inpa).last(N / 2));
// The returned vector is out0 followed by out1.
220+
out0.insert(end(out0), begin(out1), end(out1));
221+
return out0;
245222
}
246223
}
247224
#pragma GCC pop_options

0 commit comments

Comments
 (0)