|
11 | 11 | #include <cstring> |
12 | 12 | CP_ALGO_SIMD_PRAGMA_PUSH |
13 | 13 | namespace cp_algo::math { |
14 | | - const size_t max_logn = 20; |
| 14 | +#ifndef CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN |
| 15 | +#define CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN 20 |
| 16 | +#endif |
| 17 | + const size_t max_logn = CP_ALGO_SUBSET_CONVOLUTION_MAX_LOGN; |
15 | 18 |
|
16 | 19 | enum transform_dir { forw, inv }; |
17 | 20 |
|
18 | 21 | template<auto N, transform_dir direction> |
19 | 22 | inline void xor_transform(auto &&a) { |
20 | | - [[gnu::assume(N <= 1 << 30)]]; |
21 | | - if constexpr (N <= 32) { |
| 23 | + if constexpr (N >> max_logn) { |
| 24 | + throw std::runtime_error("N too large for xor_transform"); |
| 25 | + } else if constexpr (N <= 32) { |
22 | 26 | for (size_t i = 1; i < N; i *= 2) { |
23 | 27 | for (size_t j = 0; j < N; j += 2 * i) { |
24 | 28 | for (size_t k = j; k < j + i; k++) { |
@@ -96,7 +100,7 @@ namespace cp_algo::math { |
96 | 100 | constexpr size_t K = 4; |
97 | 101 | N = std::max(N, 2 * K); |
98 | 102 | const size_t n = std::bit_width(N) - 1; |
99 | | - const size_t T = std::min<size_t>(n - 3, 2); |
| 103 | + const size_t T = std::min<size_t>(n - 3, 3); |
100 | 104 | const size_t bottoms = 1 << (n - T - 1); |
101 | 105 | const auto M = std::size(first_input); |
102 | 106 |
|
@@ -205,33 +209,35 @@ namespace cp_algo::math { |
205 | 209 | template<typename base> |
206 | 210 | big_vector<base> subset_convolution(std::span<base> f, std::span<base> g) { |
207 | 211 | big_vector<base> outpa; |
208 | | - with_bit_floor(std::size(f), [&]<auto N>() { |
209 | | - constexpr size_t lgn = std::bit_width(N) - 1; |
210 | | - [[gnu::assume(lgn <= max_logn)]]; |
211 | | - outpa = on_rank_vectors([](auto &a, auto const& b) { |
212 | | - std::decay_t<decltype(a)> res = {}; |
213 | | - const auto mod = base::mod(); |
214 | | - const auto imod = math::inv2(-mod); |
215 | | - const auto r4 = u64x4() + uint64_t(-1) % mod + 1; |
216 | | - auto add = [&](size_t i) { |
217 | | - if constexpr (lgn) for(size_t j = 0; i + j + 1 < lgn; j++) { |
218 | | - res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j])); |
| 212 | + constexpr size_t lgn = max_logn; |
| 213 | + outpa = on_rank_vectors([](auto &a, auto const& b) { |
| 214 | + std::decay_t<decltype(a)> res = {}; |
| 215 | + const auto mod = base::mod(); |
| 216 | + const auto imod = math::inv2(-mod); |
| 217 | + const auto r4 = u64x4() + uint64_t(-1) % mod + 1; |
| 218 | + auto add = [&](size_t i) { |
| 219 | + if constexpr (lgn) for(size_t j = 0; i + j + 1 < lgn; j++) { |
| 220 | + res[i + j + 1] += (u64x4)_mm256_mul_epu32(__m256i(a[i]), __m256i(b[j])); |
| 221 | + } |
| 222 | + if constexpr (lgn >= 20) if (i == 15) { |
| 223 | + for(size_t k = 0; k < lgn; k++) { |
| 224 | + res[k] = res[k] >= base::modmod8() ? res[k] - base::modmod8() : res[k]; |
219 | 225 | } |
220 | | - }; |
221 | | - if constexpr (lgn) for(size_t i = 0; i < lgn; i++) { add(i); } |
222 | | - if constexpr (lgn) if constexpr (lgn) for(size_t k = 0; k < lgn; k++) { |
223 | | - res[k] = montgomery_reduce(res[k], mod, imod); |
224 | | - res[k] = montgomery_mul(res[k], r4, mod, imod); |
225 | | - a[k] = res[k] >= mod ? res[k] - mod : res[k]; |
226 | 226 | } |
227 | | - }, f, g); |
228 | | - |
229 | | - outpa[0] = f[0] * g[0]; |
230 | | - for(size_t i = 1; i < std::size(f); i++) { |
231 | | - outpa[i] += f[i] * g[0] + f[0] * g[i]; |
| 227 | + }; |
| 228 | + if constexpr (lgn) for(size_t i = 0; i < lgn; i++) { add(i); } |
| 229 | + if constexpr (lgn) if constexpr (lgn) for(size_t k = 0; k < lgn; k++) { |
| 230 | + res[k] = montgomery_reduce(res[k], mod, imod); |
| 231 | + res[k] = montgomery_mul(res[k], r4, mod, imod); |
| 232 | + a[k] = res[k] >= mod ? res[k] - mod : res[k]; |
232 | 233 | } |
233 | | - checkpoint("fix 0"); |
234 | | - }); |
| 234 | + }, f, g); |
| 235 | + |
| 236 | + outpa[0] = f[0] * g[0]; |
| 237 | + for(size_t i = 1; i < std::size(f); i++) { |
| 238 | + outpa[i] += f[i] * g[0] + f[0] * g[i]; |
| 239 | + } |
| 240 | + checkpoint("fix 0"); |
235 | 241 | return outpa; |
236 | 242 | } |
237 | 243 |
|
|
0 commit comments