Skip to content

Commit 9efb2ed

Browse files
committed
add lazy_convolution.hpp, minor improvements
1 parent c66deb1 commit 9efb2ed

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

cp-algo/math/convolution.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ namespace cp_algo::math {
1414
// Writes the result into `a`; performs in-place when possible (modint path).
1515
template<class VecA, class VecB>
1616
void convolution_prefix(VecA& a, VecB const& b, size_t need) {
17-
using T = typename std::decay_t<VecA>::value_type;
17+
using T = std::decay_t<decltype(a[0])>;
18+
if constexpr (modint_type<T>) {
19+
// Use NTT-based truncated multiplication. Works in-place on `a`.
20+
fft::mul_truncate(a, b, need);
21+
return;
22+
}
1823
size_t na = std::min(need, std::size(a));
1924
size_t nb = std::min(need, std::size(b));
2025
a.resize(na);
@@ -24,11 +29,7 @@ void convolution_prefix(VecA& a, VecB const& b, size_t need) {
2429
a.clear();
2530
return;
2631
}
27-
28-
if constexpr (modint_type<T>) {
29-
// Use NTT-based truncated multiplication. Works in-place on `a`.
30-
fft::mul_truncate(a, bv, need);
31-
} else if constexpr (std::is_same_v<T, fft::point>) {
32+
if constexpr (std::is_same_v<T, fft::point>) {
3233
size_t conv_len = na + nb - 1;
3334
size_t n = std::bit_ceil(conv_len);
3435
n = std::max(n, (size_t)fft::flen);

cp-algo/math/lazy_convolution.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef CP_ALGO_MATH_LAZY_MULTIPLY_HPP
2+
#define CP_ALGO_MATH_LAZY_MULTIPLY_HPP
3+
4+
#include "convolution.hpp"
5+
#include <algorithm>
6+
7+
namespace cp_algo::math {
8+
template<typename base>
9+
auto lazy_multiply(base a0, base b0, auto &&get_ab, size_t n) {
10+
big_vector<base> A = {a0}, B = {b0};
11+
big_vector<base> C(n);
12+
C[0] = a0 * b0;
13+
auto cdq = [&](this auto &&cdq, size_t l, size_t r) -> void {
14+
if (r - l == 1) {
15+
auto [al, bl] = get_ab(A, B, C, l);
16+
A.push_back(al);
17+
B.push_back(bl);
18+
C[l] += A[l] * B[0] + A[0] * B[l];
19+
return;
20+
}
21+
auto m = (l + r) / 2;
22+
cdq(l, m);
23+
auto A_pref = std::span(A).subspan(0, std::min(m, r - l));
24+
auto B_pref = std::span(B).subspan(0, std::min(l, r - l));
25+
big_vector<base> A_suf(std::from_range, std::span(A).subspan(l, m - l));
26+
big_vector<base> B_suf(std::from_range, std::span(B).subspan(l, m - l));
27+
convolution_prefix(A_suf, B_pref, r - l);
28+
convolution_prefix(B_suf, A_pref, r - l);
29+
A_suf.resize(r - l);
30+
B_suf.resize(r - l);
31+
for(size_t i = m; i < r; i++) {
32+
C[i] += A_suf[i - l] + B_suf[i - l];
33+
}
34+
cdq(m, r);
35+
};
36+
cdq(1, n);
37+
return C;
38+
}
39+
}
40+
41+
#endif // CP_ALGO_MATH_LAZY_MULTIPLY_HPP

cp-algo/number_theory/modint.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace cp_algo::math {
2525
constexpr modint_base(Int2 rr) {
2626
to_modint().setr(UInt((rr + modmod()) % mod()));
2727
}
28-
modint inv() const {
28+
constexpr modint inv() const {
2929
return bpow(to_modint(), mod() - 2);
3030
}
3131
modint operator - () const {

cp-algo/util/simd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace cp_algo {
3131
using u16x4 = simd<uint16_t, 4>;
3232
using i16x4 = simd<int16_t, 4>;
3333
using u8x32 = simd<uint8_t, 32>;
34+
using u8x16 = simd<uint8_t, 16>;
3435
using u8x8 = simd<uint8_t, 8>;
3536
using u8x4 = simd<uint8_t, 4>;
3637
using dx4 = simd<double, 4>;

0 commit comments

Comments
 (0)