+for(size_t z=0;z<L;z++){if constexpr(direction==forw){a[i+half][z]+=a[i][z];}else{a[i+half][z]-=a[i][z];}}}}}template<transform_dir direction,size_t L>inline void or_transform(auto&&a,auto n){cp_algo::with_bit_floor(n,[&]<auto NN>(){assert(NN==n);or_transform<NN,direction,L>(a);});}template<transform_dir direction=forw,size_t L=logn>inline void or_transform(auto&&a){or_transform<direction,L>(a,std::size(a));}template<typename base>void convolve_logn(auto&a,auto const&b){std::decay_t<decltype(a)>res={};const auto mod=base::mod();const auto imod=cp_algo::math::inv2(-mod);const auto r2=cp_algo::u64x4()+uint32_t(-1)%mod+1;const auto r4=cp_algo::u64x4()+uint64_t(-1)%mod+1;for(size_t i=0;i<logn;i++){for(size_t j=0;i+j+1<logn;j++){res[i+j+1]+=(cp_algo::u64x4)_mm256_mul_epu32(__m256i(a[i]),__m256i(b[j]));}if(i==logn/2){for(size_t k=logn/2;k<logn;k++){res[k]=cp_algo::montgomery_reduce(res[k],mod,imod);res[k]=(cp_algo::u64x4)_mm256_mul_epu32(__m256i(res[k]),__m256i(r2));}}}for(size_t k=0;k<logn;k++){res[k]=cp_algo::montgomery_reduce(res[k],mod,imod);res[k]=cp_algo::montgomery_mul(res[k],r4,mod,imod);a[k]=res[k]>=base::mod()?res[k]-base::mod():res[k];}}template<typename base>auto subset_convolution(auto const&inpa,auto const&inpb){auto outpa=inpa;std::ranges::fill(outpa,base(0));auto N=std::size(inpa);constexpr size_t K=4;N=std::max(N,K);const size_t n=std::bit_width(N)-1;const size_t T=std::min<size_t>(n-2,4);const size_t bottoms=1<<(n-T);const auto M=size(outpa);cp_algo::big_vector<std::array<base,logn>>a(bottoms),b(bottoms);cp_algo::big_vector<uint32_t>counts(N);for(size_t i=1;i<N;i++){counts[i]=(uint32_t)std::popcount(i);}cp_algo::checkpoint("prepare");for(size_t top=0;top<N;top+=bottoms){memset(a.data(),0,sizeof(a[0])*bottoms);memset(b.data(),0,sizeof(b[0])*bottoms);for(size_t mask=top;;mask=(mask-bottoms)&top){size_t limit=std::min(M,mask+bottoms)-mask;uint32_t count=counts[mask/bottoms]-1;for(size_t bottom=(mask==0);bottom<limit;bottom++){size_t i=bottom|mask;a[bottom][count+counts[bottom]]+=inpa[i];b[bottom][count+counts[bottom]]+=inpb[i];}if(!mask)break;}cp_algo::checkpoint("init");or_transform(a);or_transform(b);cp_algo::checkpoint("transform");assert(bottoms%K==0);for(size_t i=0;i<bottoms;i+=K){std::array<cp_algo::u64x4,logn>aa,bb;for(size_t j=0;j<logn;j++){for(size_t z=0;z<K;z++){aa[j][z]=a[i+z][j].getr();bb[j][z]=b[i+z][j].getr();}}convolve_logn<base>(aa,bb);for(size_t j=0;j<logn;j++){for(size_t z=0;z<K;z++){a[i+z][j].setr((uint32_t)aa[j][z]);}}}cp_algo::checkpoint("dot");or_transform<inv>(a);cp_algo::checkpoint("transform");for(size_t mask=top;mask<N;mask=(mask+bottoms)|top){bool parity=__builtin_parity(uint32_t(mask^top));size_t limit=std::min(M,mask+bottoms)-mask;uint32_t count=counts[mask/bottoms]-1;for(size_t bottom=(mask==0);bottom<limit;bottom++){size_t i=bottom|mask;if(parity){outpa[i]-=a[bottom][count+counts[bottom]];}else{outpa[i]+=a[bottom][count+counts[bottom]];}}}cp_algo::checkpoint("gather");}outpa[0]=inpa[0]*inpb[0];for(size_t i=1;i<M;i++){outpa[i]+=inpa[i]*inpb[0]+inpa[0]*inpb[i];}cp_algo::checkpoint("fix 0");return outpa;}}
0 commit comments