Skip to content

Commit 75b80ae

Browse files
committed
Tighter memory usage, less constexpr stuff for faster compilation in subset convolution
1 parent 2232586 commit 75b80ae

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@
1111
#include <cstring>
1212
CP_ALGO_SIMD_PRAGMA_PUSH
1313
namespace cp_algo::math {
14-
const size_t max_logn = 20;
14+
#ifndef CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN
15+
#define CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN 20
16+
#endif
17+
const size_t max_logn = CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN;
1518

1619
enum transform_dir { forw, inv };
1720

1821
template<auto N, transform_dir direction>
1922
inline void xor_transform(auto &&a) {
20-
[[gnu::assume(N <= 1 << 30)]];
21-
if constexpr (N <= 32) {
23+
if constexpr (N >> max_logn) {
24+
throw std::runtime_error("N too large for xor_transform");
25+
} else if constexpr (N <= 32) {
2226
for (size_t i = 1; i < N; i *= 2) {
2327
for (size_t j = 0; j < N; j += 2 * i) {
2428
for (size_t k = j; k < j + i; k++) {
@@ -96,7 +100,7 @@ namespace cp_algo::math {
96100
constexpr size_t K = 4;
97101
N = std::max(N, 2 * K);
98102
const size_t n = std::bit_width(N) - 1;
99-
const size_t T = std::min<size_t>(n - 3, 2);
103+
const size_t T = std::min<size_t>(n - 3, 3);
100104
const size_t bottoms = 1 << (n - T - 1);
101105
const auto M = std::size(first_input);
102106

@@ -205,33 +209,35 @@ namespace cp_algo::math {
205209
template<typename base>
206210
big_vector<base> subset_convolution(std::span<base> f, std::span<base> g) {
207211
big_vector<base> outpa;
208-
with_bit_floor(std::size(f), [&]<auto N>() {
209-
constexpr size_t lgn = std::bit_width(N) - 1;
210-
[[gnu::assume(lgn <= max_logn)]];
211-
outpa = on_rank_vectors([](auto &a, auto const& b) {
212-
std::decay_t<decltype(a)> res = {};
213-
const auto mod = base::mod();
214-
const auto imod = math::inv2(-mod);
215-
const auto r4 = u64x4() + uint64_t(-1) % mod + 1;
216-
auto add = [&](size_t i) {
217-
if constexpr (lgn) for(size_t j = 0; i + j + 1 < lgn; j++) {
218-
res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j]));
212+
constexpr size_t lgn = max_logn;
213+
outpa = on_rank_vectors([](auto &a, auto const& b) {
214+
std::decay_t<decltype(a)> res = {};
215+
const auto mod = base::mod();
216+
const auto imod = math::inv2(-mod);
217+
const auto r4 = u64x4() + uint64_t(-1) % mod + 1;
218+
auto add = [&](size_t i) {
219+
if constexpr (lgn) for(size_t j = 0; i + j + 1 < lgn; j++) {
220+
res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j]));
221+
}
222+
if constexpr (lgn >= 20) if (i == 15) {
223+
for(size_t k = 0; k < lgn; k++) {
224+
res[k] = res[k] >= base::modmod8() ? res[k] - base::modmod8() : res[k];
219225
}
220-
};
221-
if constexpr (lgn) for(size_t i = 0; i < lgn; i++) { add(i); }
222-
if constexpr (lgn) if constexpr (lgn) for(size_t k = 0; k < lgn; k++) {
223-
res[k] = montgomery_reduce(res[k], mod, imod);
224-
res[k] = montgomery_mul(res[k], r4, mod, imod);
225-
a[k] = res[k] >= mod ? res[k] - mod : res[k];
226226
}
227-
}, f, g);
228-
229-
outpa[0] = f[0] * g[0];
230-
for(size_t i = 1; i < std::size(f); i++) {
231-
outpa[i] += f[i] * g[0] + f[0] * g[i];
227+
};
228+
if constexpr (lgn) for(size_t i = 0; i < lgn; i++) { add(i); }
229+
if constexpr (lgn) if constexpr (lgn) for(size_t k = 0; k < lgn; k++) {
230+
res[k] = montgomery_reduce(res[k], mod, imod);
231+
res[k] = montgomery_mul(res[k], r4, mod, imod);
232+
a[k] = res[k] >= mod ? res[k] - mod : res[k];
232233
}
233-
checkpoint("fix 0");
234-
});
234+
}, f, g);
235+
236+
outpa[0] = f[0] * g[0];
237+
for(size_t i = 1; i < std::size(f); i++) {
238+
outpa[i] += f[i] * g[0] + f[0] * g[i];
239+
}
240+
checkpoint("fix 0");
235241
return outpa;
236242
}
237243

0 commit comments

Comments
 (0)