@@ -15,14 +15,14 @@ namespace cp_algo::math {
1515
1616 enum transform_dir { forw, inv };
1717
18- template <auto N, transform_dir direction, size_t L >
18+ template <auto N, transform_dir direction>
1919 inline void or_transform (auto &&a) {
2020 [[gnu::assume (N <= 1ull << 30 )]];
2121 if constexpr (N <= 32 ) {
2222 for (size_t i = 1 ; i < N; i *= 2 ) {
2323 for (size_t j = 0 ; j < N; j += 2 * i) {
2424 for (size_t k = j; k < j + i; k++) {
25- for (size_t z = 0 ; z < L ; z++) {
25+ for (size_t z = 0 ; z < logn ; z++) {
2626 if constexpr (direction == forw) {
2727 a[k + i][z] += a[k][z];
2828 } else {
@@ -34,11 +34,11 @@ namespace cp_algo::math {
3434 }
3535 } else {
3636 constexpr auto half = N / 2 ;
37- or_transform<half, direction, L >(&a[0 ]);
38- or_transform<half, direction, L >(&a[half]);
37+ or_transform<half, direction>(&a[0 ]);
38+ or_transform<half, direction>(&a[half]);
3939 for (size_t i = 0 ; i < half; i++) {
40- #pragma GCC unroll L
41- for (size_t z = 0 ; z < L ; z++) {
40+ #pragma GCC unroll logn
41+ for (size_t z = 0 ; z < logn ; z++) {
4242 if constexpr (direction == forw) {
4343 a[i + half][z] += a[i][z];
4444 } else {
@@ -49,17 +49,17 @@ namespace cp_algo::math {
4949 }
5050 }
5151
52- template <transform_dir direction, size_t L >
52+ template <transform_dir direction>
5353 inline void or_transform (auto &&a, auto n) {
5454 cp_algo::with_bit_floor (n, [&]<auto NN>() {
5555 assert (NN == n);
56- or_transform<NN, direction, L >(a);
56+ or_transform<NN, direction>(a);
5757 });
5858 }
5959
60- template <transform_dir direction = forw, size_t L = logn >
60+ template <transform_dir direction = forw>
6161 inline void or_transform (auto &&a) {
62- or_transform<direction, L >(a, std::size (a));
62+ or_transform<direction>(a, std::size (a));
6363 }
6464
6565 template <typename base>
@@ -86,75 +86,136 @@ namespace cp_algo::math {
8686 a[k] = res[k] >= base::mod () ? res[k] - base::mod () : res[k];
8787 }
8888 }
89-
90- template <typename base>
91- auto subset_convolution (auto const & inpa, auto const & inpb) {
92- auto outpa = inpa;
93- std::ranges::fill (outpa, base (0 ));
94- auto N = std::size (inpa);
95- constexpr size_t K = 4 ;
96- N = std::max (N, K);
89+
90+ // Generic rank vectors processor with variadic inputs
91+ // Assumes output[0] = 0, caller is responsible for handling rank 0
92+ // Returns the output array
93+ template <size_t K = 4 >
94+ auto on_rank_vectors (auto &&cb, auto const & ...inputs) {
95+ static_assert (sizeof ...(inputs) >= 1 , " on_rank_vectors requires at least one input" );
96+
97+ // Create tuple of input references once
98+ auto input_tuple = std::forward_as_tuple (inputs...);
99+ auto const & first_input = std::get<0 >(input_tuple);
100+
101+ auto out = first_input;
102+ using base = std::decay_t <decltype (first_input[0 ])>;
103+ std::ranges::fill (out, base (0 ));
104+
105+ auto N = std::size (first_input);
106+ constexpr size_t LOCAL_K = K;
107+ N = std::max (N, LOCAL_K);
97108 const size_t n = std::bit_width (N) - 1 ;
98109 const size_t T = std::min<size_t >(n - 2 , 4 );
99110 const size_t bottoms = 1 << (n - T);
100- const auto M = size (outpa);
101- cp_algo::big_vector<std::array<base, logn>> a (bottoms), b (bottoms);
111+ const auto M = std::size (first_input);
112+
113+ // Create array buffers for each input
114+ auto create_buffers = [bottoms]<typename ... Args>(const Args&...) {
115+ return std::make_tuple (
116+ cp_algo::big_vector<std::array<typename std::decay_t <Args>::value_type, logn>>(bottoms)...
117+ );
118+ };
119+ auto buffers = std::apply (create_buffers, input_tuple);
120+
102121 cp_algo::big_vector<uint32_t > counts (N);
103122 for (size_t i = 1 ; i < N; i++) {
104123 counts[i] = (uint32_t )std::popcount (i);
105124 }
106125 cp_algo::checkpoint (" prepare" );
126+
107127 for (size_t top = 0 ; top < N; top += bottoms) {
108- memset (a.data (), 0 , sizeof (a[0 ]) * bottoms);
109- memset (b.data (), 0 , sizeof (b[0 ]) * bottoms);
110- for (size_t mask = top; ; mask = (mask - bottoms) & top) {
111- size_t limit = std::min (M, mask + bottoms) - mask;
112- uint32_t count = counts[mask / bottoms] - 1 ;
113- for (size_t bottom = (mask == 0 ); bottom < limit; bottom++) {
114- size_t i = bottom | mask;
115- a[bottom][count + counts[bottom]] += inpa[i];
116- b[bottom][count + counts[bottom]] += inpb[i];
117- }
118- if (!mask) break ;
119- }
128+ // Clear all buffers
129+ std::apply ([bottoms](auto &... bufs) {
130+ (..., memset (bufs.data (), 0 , sizeof (bufs[0 ]) * bottoms));
131+ }, buffers);
132+
133+ // Initialize buffers from inputs
134+ std::apply ([&](auto const &... inps) {
135+ std::apply ([&](auto &... bufs) {
136+ auto init_one = [&](auto const & inp, auto & buf) {
137+ for (size_t mask = top; ; mask = (mask - bottoms) & top) {
138+ size_t limit = std::min (M, mask + bottoms) - mask;
139+ uint32_t count = counts[mask / bottoms] - 1 ;
140+ for (size_t bottom = (mask == 0 ); bottom < limit; bottom++) {
141+ size_t i = bottom | mask;
142+ buf[bottom][count + counts[bottom]] += inp[i];
143+ }
144+ if (!mask) break ;
145+ }
146+ };
147+ (init_one (inps, bufs), ...);
148+ }, buffers);
149+ }, input_tuple);
150+
120151 cp_algo::checkpoint (" init" );
121- or_transform (a);
122- or_transform (b);
152+ std::apply ([](auto &... bufs) {
153+ (..., or_transform (bufs));
154+ }, buffers);
123155 cp_algo::checkpoint (" transform" );
124- assert (bottoms % K == 0 );
125- for (size_t i = 0 ; i < bottoms; i += K) {
126- std::array<cp_algo::u64x4, logn> aa, bb;
127- for (size_t j = 0 ; j < logn; j++) {
128- for (size_t z = 0 ; z < K; z++) {
129- aa[j][z] = a[i + z][j].getr ();
130- bb[j][z] = b[i + z][j].getr ();
131- }
132- }
133- convolve_logn<base>(aa, bb);
134- for (size_t j = 0 ; j < logn; j++) {
135- for (size_t z = 0 ; z < K; z++) {
136- a[i + z][j].setr ((uint32_t )aa[j][z]);
156+
157+ assert (bottoms % LOCAL_K == 0 );
158+ for (size_t i = 0 ; i < bottoms; i += LOCAL_K) {
159+ std::apply ([&](auto &... bufs) {
160+ auto extract_one = [&](auto & buf) {
161+ std::array<cp_algo::u64x4, logn> aa;
162+ for (size_t j = 0 ; j < logn; j++) {
163+ for (size_t z = 0 ; z < LOCAL_K; z++) {
164+ aa[j][z] = buf[i + z][j].getr ();
165+ }
166+ }
167+ return aa;
168+ };
169+
170+ auto aa_tuple = std::make_tuple (extract_one (bufs)...);
171+ std::apply (cb, aa_tuple);
172+
173+ // Write results back: only first array needs to be written
174+ auto & first_buf = std::get<0 >(std::forward_as_tuple (bufs...));
175+ const auto & first_aa = std::get<0 >(aa_tuple);
176+ for (size_t j = 0 ; j < logn; j++) {
177+ for (size_t z = 0 ; z < LOCAL_K; z++) {
178+ first_buf[i + z][j].setr ((uint32_t )first_aa[j][z]);
179+ }
137180 }
138- }
181+ }, buffers);
139182 }
183+
140184 cp_algo::checkpoint (" dot" );
141- or_transform<inv>(a);
185+ auto & first_buf = std::get<0 >(buffers);
186+ or_transform<inv>(first_buf);
142187 cp_algo::checkpoint (" transform" );
188+
189+ // Gather results from first buffer
143190 for (size_t mask = top; mask < N; mask = (mask + bottoms) | top) {
144191 bool parity = __builtin_parity (uint32_t (mask ^ top));
145192 size_t limit = std::min (M, mask + bottoms) - mask;
146193 uint32_t count = counts[mask / bottoms] - 1 ;
147194 for (size_t bottom = (mask == 0 ); bottom < limit; bottom++) {
148195 size_t i = bottom | mask;
149196 if (parity) {
150- outpa [i] -= a [bottom][count + counts[bottom]];
197+ out [i] -= first_buf [bottom][count + counts[bottom]];
151198 } else {
152- outpa [i] += a [bottom][count + counts[bottom]];
199+ out [i] += first_buf [bottom][count + counts[bottom]];
153200 }
154201 }
155202 }
156203 cp_algo::checkpoint (" gather" );
157204 }
205+ return out;
206+ }
207+
208+ template <typename base>
209+ auto subset_convolution (auto const & inpa, auto const & inpb) {
210+ auto M = std::size (inpa);
211+
212+ auto callback = [&](auto &aa, auto const & bb) {
213+ convolve_logn<base>(aa, bb);
214+ };
215+
216+ constexpr size_t K = 4 ;
217+ auto outpa = on_rank_vectors<K>(callback, inpa, inpb);
218+
158219 outpa[0 ] = inpa[0 ] * inpb[0 ];
159220 for (size_t i = 1 ; i < M; i++) {
160221 outpa[i] += inpa[i] * inpb[0 ] + inpa[0 ] * inpb[i];
0 commit comments