@@ -88,9 +88,10 @@ namespace cp_algo::math {
8888 };
8989 auto buffers = std::apply (create_buffers, input_tuple);
9090
91- big_vector<uint32_t > counts (N);
92- for (size_t i = 1 ; i < N; i++) {
93- counts[i] = (uint32_t )std::popcount (i) - 1 ;
91+ checkpoint (" alloc buffers" );
92+ big_vector<uint32_t > counts (2 * bottoms);
93+ for (size_t i = 1 ; i < 2 * bottoms; i++) {
94+ counts[i] = (uint32_t )std::popcount (i);
9495 }
9596 checkpoint (" prepare" );
9697
@@ -105,12 +106,16 @@ namespace cp_algo::math {
105106 std::apply ([&](auto const &... inps) {
106107 std::apply ([&](auto &... bufs) {
107108 auto init_one = [&](auto const & inp, auto & buf) {
108- for (size_t i = 1 ; i < M; i++) {
109- size_t bottom = (i >> 1 ) & (bottoms - 1 );
110- if (__builtin_parity (uint32_t ((i >> 1 ) & top))) {
111- buf[bottom][counts[i]] -= inp[i];
112- } else {
113- buf[bottom][counts[i]] += inp[i];
109+ for (size_t i = 0 ; i < M; i += 2 * bottoms) {
110+ bool parity = __builtin_parity (uint32_t ((i >> 1 ) & top));
111+ size_t limit = std::min (M, i + 2 * bottoms) - i;
112+ uint32_t count = (uint32_t )std::popcount (i) - 1 ;
113+ for (size_t bottom = (i == 0 ); bottom < limit; bottom++) {
114+ if (parity) {
115+ buf[bottom >> 1 ][count + counts[bottom]] -= inp[i + bottom];
116+ } else {
117+ buf[bottom >> 1 ][count + counts[bottom]] += inp[i + bottom];
118+ }
114119 }
115120 }
116121 };
@@ -158,13 +163,16 @@ namespace cp_algo::math {
158163
159164 // Gather results from first buffer
160165
161-
162- for (size_t i = 1 ; i < M; i++) {
163- size_t bottom = (i >> 1 ) & (bottoms - 1 );
164- if (__builtin_parity (uint32_t ((i >> 1 ) & top))) {
165- out[i] -= first_buf[bottom][counts[i]];
166- } else {
167- out[i] += first_buf[bottom][counts[i]];
166+ for (size_t i = 0 ; i < M; i += 2 * bottoms) {
167+ bool parity = __builtin_parity (uint32_t ((i >> 1 ) & top));
168+ size_t limit = std::min (M, i + 2 * bottoms) - i;
169+ uint32_t count = (uint32_t )std::popcount (i) - 1 ;
170+ for (size_t bottom = (i == 0 ); bottom < limit; bottom++) {
171+ if (parity) {
172+ out[i + bottom] -= first_buf[bottom >> 1 ][count + counts[bottom]];
173+ } else {
174+ out[i + bottom] += first_buf[bottom >> 1 ][count + counts[bottom]];
175+ }
168176 }
169177 }
170178 checkpoint (" gather" );
0 commit comments