Skip to content

Commit 02d6019

Browse files
committed
Improve init / gather
1 parent 76bd173 commit 02d6019

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

cp-algo/math/subset_convolution.hpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ namespace cp_algo::math {
8888
};
8989
auto buffers = std::apply(create_buffers, input_tuple);
9090

91-
big_vector<uint32_t> counts(N);
92-
for(size_t i = 1; i < N; i++) {
93-
counts[i] = (uint32_t)std::popcount(i) - 1;
91+
checkpoint("alloc buffers");
92+
big_vector<uint32_t> counts(2 * bottoms);
93+
for(size_t i = 1; i < 2 * bottoms; i++) {
94+
counts[i] = (uint32_t)std::popcount(i);
9495
}
9596
checkpoint("prepare");
9697

@@ -105,12 +106,16 @@ namespace cp_algo::math {
105106
std::apply([&](auto const&... inps) {
106107
std::apply([&](auto&... bufs) {
107108
auto init_one = [&](auto const& inp, auto& buf) {
108-
for(size_t i = 1; i < M; i++) {
109-
size_t bottom = (i >> 1) & (bottoms - 1);
110-
if (__builtin_parity(uint32_t((i >> 1) & top))) {
111-
buf[bottom][counts[i]] -= inp[i];
112-
} else {
113-
buf[bottom][counts[i]] += inp[i];
109+
for(size_t i = 0; i < M; i += 2 * bottoms) {
110+
bool parity = __builtin_parity(uint32_t((i >> 1) & top));
111+
size_t limit = std::min(M, i + 2 * bottoms) - i;
112+
uint32_t count = (uint32_t)std::popcount(i) - 1;
113+
for(size_t bottom = (i == 0); bottom < limit; bottom++) {
114+
if (parity) {
115+
buf[bottom >> 1][count + counts[bottom]] -= inp[i + bottom];
116+
} else {
117+
buf[bottom >> 1][count + counts[bottom]] += inp[i + bottom];
118+
}
114119
}
115120
}
116121
};
@@ -158,13 +163,16 @@ namespace cp_algo::math {
158163

159164
// Gather results from first buffer
160165

161-
162-
for(size_t i = 1; i < M; i++) {
163-
size_t bottom = (i >> 1) & (bottoms - 1);
164-
if (__builtin_parity(uint32_t((i >> 1) & top))) {
165-
out[i] -= first_buf[bottom][counts[i]];
166-
} else {
167-
out[i] += first_buf[bottom][counts[i]];
166+
for(size_t i = 0; i < M; i += 2 * bottoms) {
167+
bool parity = __builtin_parity(uint32_t((i >> 1) & top));
168+
size_t limit = std::min(M, i + 2 * bottoms) - i;
169+
uint32_t count = (uint32_t)std::popcount(i) - 1;
170+
for(size_t bottom = (i == 0); bottom < limit; bottom++) {
171+
if (parity) {
172+
out[i + bottom] -= first_buf[bottom >> 1][count + counts[bottom]];
173+
} else {
174+
out[i + bottom] += first_buf[bottom >> 1][count + counts[bottom]];
175+
}
168176
}
169177
}
170178
checkpoint("gather");

0 commit comments

Comments
 (0)