Skip to content

Commit 2a49c11

Browse files
committed
on_rank_vectors variadic caller
1 parent 1120fb7 commit 2a49c11

File tree

1 file changed

+112
-51
lines changed

1 file changed

+112
-51
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 112 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ namespace cp_algo::math {
1515

1616
enum transform_dir { forw, inv };
1717

18-
template<auto N, transform_dir direction, size_t L>
18+
template<auto N, transform_dir direction>
1919
inline void or_transform(auto &&a) {
2020
[[gnu::assume(N <= 1ull << 30)]];
2121
if constexpr (N <= 32) {
2222
for(size_t i = 1; i < N; i *= 2) {
2323
for(size_t j = 0; j < N; j += 2 * i) {
2424
for(size_t k = j; k < j + i; k++) {
25-
for(size_t z = 0; z < L; z++) {
25+
for(size_t z = 0; z < logn; z++) {
2626
if constexpr (direction == forw) {
2727
a[k + i][z] += a[k][z];
2828
} else {
@@ -34,11 +34,11 @@ namespace cp_algo::math {
3434
}
3535
} else {
3636
constexpr auto half = N / 2;
37-
or_transform<half, direction, L>(&a[0]);
38-
or_transform<half, direction, L>(&a[half]);
37+
or_transform<half, direction>(&a[0]);
38+
or_transform<half, direction>(&a[half]);
3939
for (size_t i = 0; i < half; i++) {
40-
#pragma GCC unroll L
41-
for(size_t z = 0; z < L; z++) {
40+
#pragma GCC unroll logn
41+
for(size_t z = 0; z < logn; z++) {
4242
if constexpr (direction == forw) {
4343
a[i + half][z] += a[i][z];
4444
} else {
@@ -49,17 +49,17 @@ namespace cp_algo::math {
4949
}
5050
}
5151

52-
template<transform_dir direction, size_t L>
52+
template<transform_dir direction>
5353
inline void or_transform(auto &&a, auto n) {
5454
cp_algo::with_bit_floor(n, [&]<auto NN>() {
5555
assert(NN == n);
56-
or_transform<NN, direction, L>(a);
56+
or_transform<NN, direction>(a);
5757
});
5858
}
5959

60-
template<transform_dir direction = forw, size_t L = logn>
60+
template<transform_dir direction = forw>
6161
inline void or_transform(auto &&a) {
62-
or_transform<direction, L>(a, std::size(a));
62+
or_transform<direction>(a, std::size(a));
6363
}
6464

6565
template<typename base>
@@ -86,75 +86,136 @@ namespace cp_algo::math {
8686
a[k] = res[k] >= base::mod() ? res[k] - base::mod() : res[k];
8787
}
8888
}
89-
90-
template<typename base>
91-
auto subset_convolution(auto const& inpa, auto const& inpb) {
92-
auto outpa = inpa;
93-
std::ranges::fill(outpa, base(0));
94-
auto N = std::size(inpa);
95-
constexpr size_t K = 4;
96-
N = std::max(N, K);
89+
90+
// Generic rank vectors processor with variadic inputs
91+
// Assumes output[0] = 0, caller is responsible for handling rank 0
92+
// Returns the output array
93+
template<size_t K = 4>
94+
auto on_rank_vectors(auto &&cb, auto const& ...inputs) {
95+
static_assert(sizeof...(inputs) >= 1, "on_rank_vectors requires at least one input");
96+
97+
// Create tuple of input references once
98+
auto input_tuple = std::forward_as_tuple(inputs...);
99+
auto const& first_input = std::get<0>(input_tuple);
100+
101+
auto out = first_input;
102+
using base = std::decay_t<decltype(first_input[0])>;
103+
std::ranges::fill(out, base(0));
104+
105+
auto N = std::size(first_input);
106+
constexpr size_t LOCAL_K = K;
107+
N = std::max(N, LOCAL_K);
97108
const size_t n = std::bit_width(N) - 1;
98109
const size_t T = std::min<size_t>(n - 2, 4);
99110
const size_t bottoms = 1 << (n - T);
100-
const auto M = size(outpa);
101-
cp_algo::big_vector<std::array<base, logn>> a(bottoms), b(bottoms);
111+
const auto M = std::size(first_input);
112+
113+
// Create array buffers for each input
114+
auto create_buffers = [bottoms]<typename... Args>(const Args&...) {
115+
return std::make_tuple(
116+
cp_algo::big_vector<std::array<typename std::decay_t<Args>::value_type, logn>>(bottoms)...
117+
);
118+
};
119+
auto buffers = std::apply(create_buffers, input_tuple);
120+
102121
cp_algo::big_vector<uint32_t> counts(N);
103122
for(size_t i = 1; i < N; i++) {
104123
counts[i] = (uint32_t)std::popcount(i);
105124
}
106125
cp_algo::checkpoint("prepare");
126+
107127
for(size_t top = 0; top < N; top += bottoms) {
108-
memset(a.data(), 0, sizeof(a[0]) * bottoms);
109-
memset(b.data(), 0, sizeof(b[0]) * bottoms);
110-
for(size_t mask = top; ; mask = (mask - bottoms) & top) {
111-
size_t limit = std::min(M, mask + bottoms) - mask;
112-
uint32_t count = counts[mask / bottoms] - 1;
113-
for(size_t bottom = (mask == 0); bottom < limit; bottom++) {
114-
size_t i = bottom | mask;
115-
a[bottom][count + counts[bottom]] += inpa[i];
116-
b[bottom][count + counts[bottom]] += inpb[i];
117-
}
118-
if (!mask) break;
119-
}
128+
// Clear all buffers
129+
std::apply([bottoms](auto&... bufs) {
130+
(..., memset(bufs.data(), 0, sizeof(bufs[0]) * bottoms));
131+
}, buffers);
132+
133+
// Initialize buffers from inputs
134+
std::apply([&](auto const&... inps) {
135+
std::apply([&](auto&... bufs) {
136+
auto init_one = [&](auto const& inp, auto& buf) {
137+
for(size_t mask = top; ; mask = (mask - bottoms) & top) {
138+
size_t limit = std::min(M, mask + bottoms) - mask;
139+
uint32_t count = counts[mask / bottoms] - 1;
140+
for(size_t bottom = (mask == 0); bottom < limit; bottom++) {
141+
size_t i = bottom | mask;
142+
buf[bottom][count + counts[bottom]] += inp[i];
143+
}
144+
if (!mask) break;
145+
}
146+
};
147+
(init_one(inps, bufs), ...);
148+
}, buffers);
149+
}, input_tuple);
150+
120151
cp_algo::checkpoint("init");
121-
or_transform(a);
122-
or_transform(b);
152+
std::apply([](auto&... bufs) {
153+
(..., or_transform(bufs));
154+
}, buffers);
123155
cp_algo::checkpoint("transform");
124-
assert(bottoms % K == 0);
125-
for(size_t i = 0; i < bottoms; i += K) {
126-
std::array<cp_algo::u64x4, logn> aa, bb;
127-
for(size_t j = 0; j < logn; j++) {
128-
for(size_t z = 0; z < K; z++) {
129-
aa[j][z] = a[i + z][j].getr();
130-
bb[j][z] = b[i + z][j].getr();
131-
}
132-
}
133-
convolve_logn<base>(aa, bb);
134-
for(size_t j = 0; j < logn; j++) {
135-
for(size_t z = 0; z < K; z++) {
136-
a[i + z][j].setr((uint32_t)aa[j][z]);
156+
157+
assert(bottoms % LOCAL_K == 0);
158+
for(size_t i = 0; i < bottoms; i += LOCAL_K) {
159+
std::apply([&](auto&... bufs) {
160+
auto extract_one = [&](auto& buf) {
161+
std::array<cp_algo::u64x4, logn> aa;
162+
for(size_t j = 0; j < logn; j++) {
163+
for(size_t z = 0; z < LOCAL_K; z++) {
164+
aa[j][z] = buf[i + z][j].getr();
165+
}
166+
}
167+
return aa;
168+
};
169+
170+
auto aa_tuple = std::make_tuple(extract_one(bufs)...);
171+
std::apply(cb, aa_tuple);
172+
173+
// Write results back: only first array needs to be written
174+
auto& first_buf = std::get<0>(std::forward_as_tuple(bufs...));
175+
const auto& first_aa = std::get<0>(aa_tuple);
176+
for(size_t j = 0; j < logn; j++) {
177+
for(size_t z = 0; z < LOCAL_K; z++) {
178+
first_buf[i + z][j].setr((uint32_t)first_aa[j][z]);
179+
}
137180
}
138-
}
181+
}, buffers);
139182
}
183+
140184
cp_algo::checkpoint("dot");
141-
or_transform<inv>(a);
185+
auto& first_buf = std::get<0>(buffers);
186+
or_transform<inv>(first_buf);
142187
cp_algo::checkpoint("transform");
188+
189+
// Gather results from first buffer
143190
for(size_t mask = top; mask < N; mask = (mask + bottoms) | top) {
144191
bool parity = __builtin_parity(uint32_t(mask ^ top));
145192
size_t limit = std::min(M, mask + bottoms) - mask;
146193
uint32_t count = counts[mask / bottoms] - 1;
147194
for(size_t bottom = (mask == 0); bottom < limit; bottom++) {
148195
size_t i = bottom | mask;
149196
if (parity) {
150-
outpa[i] -= a[bottom][count + counts[bottom]];
197+
out[i] -= first_buf[bottom][count + counts[bottom]];
151198
} else {
152-
outpa[i] += a[bottom][count + counts[bottom]];
199+
out[i] += first_buf[bottom][count + counts[bottom]];
153200
}
154201
}
155202
}
156203
cp_algo::checkpoint("gather");
157204
}
205+
return out;
206+
}
207+
208+
template<typename base>
209+
auto subset_convolution(auto const& inpa, auto const& inpb) {
210+
auto M = std::size(inpa);
211+
212+
auto callback = [&](auto &aa, auto const& bb) {
213+
convolve_logn<base>(aa, bb);
214+
};
215+
216+
constexpr size_t K = 4;
217+
auto outpa = on_rank_vectors<K>(callback, inpa, inpb);
218+
158219
outpa[0] = inpa[0] * inpb[0];
159220
for(size_t i = 1; i < M; i++) {
160221
outpa[i] += inpa[i] * inpb[0] + inpa[0] * inpb[i];

0 commit comments

Comments
 (0)