Skip to content

Commit 9fcaff0

Browse files
committed
feat: polynomial
1 parent b0c3f77 commit 9fcaff0

11 files changed

Lines changed: 427 additions & 0 deletions

modular-arithmetic/binpow.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
/**
4+
* Author: Teetat T.
5+
* Date: 2024-01-15
6+
* Description: n-th power using divide and conquer
7+
* Time: $O(\log b)$
8+
*/
9+
10+
template<class T>
11+
constexpr T binpow(T a,ll b){
12+
T res=1;
13+
for(;b>0;b>>=1,a*=a)if(b&1)res*=a;
14+
return res;
15+
}
16+

polynomial/fft.hpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#pragma once
2+
3+
/**
4+
* Author: Teetat T.
5+
* Date: 2024-03-17
6+
* Description: Fast Fourier Transform
7+
* Time: $O(N \log N)$
8+
*/
9+
10+
template<class T=ll,int mod=0>
11+
struct FFT{
12+
using vt = vector<T>;
13+
using cd = complex<db>;
14+
using vc = vector<cd>;
15+
16+
static const bool INT=true;
17+
18+
static void fft(vc &a){
19+
int n=a.size(),L=31-__builtin_clz(n);
20+
vc rt(n);
21+
rt[1]=1;
22+
for(int k=2;k<n;k*=2){
23+
cd z=polar(db(1),PI/k);
24+
for(int i=k;i<2*k;i++)rt[i]=i&1?rt[i/2]*z:rt[i/2];
25+
}
26+
vector<int> rev(n);
27+
for(int i=1;i<n;i++)rev[i]=(rev[i/2]|(i&1)<<L)/2;
28+
for(int i=1;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
29+
for(int k=1;k<n;k*=2)for(int i=0;i<n;i+=2*k)for(int j=0;j<k;j++){
30+
cd z=rt[j+k]*a[i+j+k];
31+
a[i+j+k]=a[i+j]-z;
32+
a[i+j]+=z;
33+
}
34+
}
35+
template<class U>
36+
static db norm(const U &x){
37+
return INT?round(x):x;
38+
}
39+
static vt conv(const vt &a,const vt &b){
40+
if(a.empty()||b.empty())return {};
41+
vt res(a.size()+b.size()-1);
42+
int L=32-__builtin_clz(res.size()),n=1<<L;
43+
vc in(n),out(n);
44+
copy(a.begin(),a.end(),in.begin());
45+
for(int i=0;i<b.size();i++)in[i].imag(b[i]);
46+
fft(in);
47+
for(auto &x:in)x*=x;
48+
for(int i=0;i<n;i++)out[i]=in[-i&(n-1)]-conj(in[i]);
49+
fft(out);
50+
for(int i=0;i<res.size();i++)res[i]=norm(imag(out[i])/(4*n));
51+
return res;
52+
}
53+
static vt convMod(const vt &a,const vt &b){
54+
assert(mod>0);
55+
if(a.empty()||b.empty())return {};
56+
vt res(a.size()+b.size()-1);
57+
int L=32-__builtin_clz(res.size()),n=1<<L;
58+
ll cut=int(sqrt(mod));
59+
vc in1(n),in2(n),out1(n),out2(n);
60+
for(int i=0;i<a.size();i++)in1[i]=cd(ll(a[i])/cut,ll(a[i])%cut); // a1 + i * a2
61+
for(int i=0;i<b.size();i++)in2[i]=cd(ll(b[i])/cut,ll(b[i])%cut); // b1 + i * b2
62+
fft(in1),fft(in2);
63+
for(int i=0;i<n;i++){
64+
int j=-i&(n-1);
65+
out1[j]=(in1[i]+conj(in1[j]))*in2[i]/(2.l*n); // f1 * (g1 + i * g2) = f1 * g1 + i f1 * g2
66+
out2[j]=(in1[i]-conj(in1[j]))*in2[i]/cd(0.l,2.l*n); // f2 * (g1 + i * g2) = f2 * g1 + i f2 * g2
67+
}
68+
fft(out1),fft(out2);
69+
for(int i=0;i<res.size();i++){
70+
ll x=round(real(out1[i])),y=round(imag(out1[i]))+round(real(out2[i])),z=round(imag(out2[i]));
71+
res[i]=((x%mod*cut+y)%mod*cut+z)%mod; // a1 * b1 * cut^2 + (a1 * b2 + a2 * b1) * cut + a2 * b2
72+
}
73+
return res;
74+
}
75+
vt operator()(const vt &a,const vt &b){
76+
return mod>0?convMod(a,b):conv(a,b);
77+
}
78+
};
79+
template<>
80+
struct FFT<db>{
81+
static const bool INT=false;
82+
};
83+

polynomial/formal-power-series.hpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#pragma once
2+
#include "polynomial/ntt.hpp"
3+
4+
/**
5+
* Author: Teetat T.
6+
* Date: 2024-03-17
7+
* Description: basic operations of formal power series
8+
*/
9+
10+
template<class mint>
11+
struct FormalPowerSeries:vector<mint>{
12+
using vector<mint>::vector;
13+
using FPS = FormalPowerSeries;
14+
15+
FPS &operator+=(const FPS &rhs){
16+
if(rhs.size()>this->size())this->resize(rhs.size());
17+
for(int i=0;i<rhs.size();i++)(*this)[i]+=rhs[i];
18+
return *this;
19+
}
20+
FPS &operator+=(const mint &rhs){
21+
if(this->empty())this->resize(1);
22+
(*this)[0]+=rhs;
23+
return *this;
24+
}
25+
FPS &operator-=(const FPS &rhs){
26+
if(rhs.size()>this->size())this->resize(rhs.size());
27+
for(int i=0;i<rhs.size();i++)(*this)[i]-=rhs[i];
28+
return *this;
29+
}
30+
FPS &operator-=(const mint &rhs){
31+
if(this->empty())this->resize(1);
32+
(*this)[0]-=rhs;
33+
return *this;
34+
}
35+
FPS &operator*=(const FPS &rhs){
36+
auto res=NTT<mint>()(*this,rhs);
37+
return *this=FPS(res.begin(),res.end());
38+
}
39+
FPS &operator*=(const mint &rhs){
40+
for(auto &a:*this)a*=rhs;
41+
return *this;
42+
}
43+
friend FPS operator+(FPS lhs,const FPS &rhs){return lhs+=rhs;}
44+
friend FPS operator+(FPS lhs,const mint &rhs){return lhs+=rhs;}
45+
friend FPS operator+(const mint &lhs,FPS &rhs){return rhs+=lhs;}
46+
friend FPS operator-(FPS lhs,const FPS &rhs){return lhs-=rhs;}
47+
friend FPS operator-(FPS lhs,const mint &rhs){return lhs-=rhs;}
48+
friend FPS operator-(const mint &lhs,FPS rhs){return -(rhs-lhs);}
49+
friend FPS operator*(FPS lhs,const FPS &rhs){return lhs*=rhs;}
50+
friend FPS operator*(FPS lhs,const mint &rhs){return lhs*=rhs;}
51+
friend FPS operator*(const mint &lhs,FPS rhs){return rhs*=lhs;}
52+
53+
FPS operator-(){return (*this)*-1;}
54+
55+
FPS rev(){
56+
FPS res(*this);
57+
reverse(res.beign(),res.end());
58+
return res;
59+
}
60+
FPS pre(int sz){
61+
FPS res(this->begin(),this->begin()+min((int)this->size(),sz));
62+
if(res.size()<sz)res.resize(sz);
63+
return res;
64+
}
65+
FPS shrink(){
66+
FPS res(*this);
67+
while(!res.empty()&&res.back()==mint{})res.pop_back();
68+
return res;
69+
}
70+
FPS operator>>(int sz){
71+
if(this->size()<=sz)return {};
72+
FPS res(*this);
73+
res.erase(res.begin(),res.begin()+sz);
74+
return res;
75+
}
76+
FPS operator<<(int sz){
77+
FPS res(*this);
78+
res.insert(res.begin(),sz,mint{});
79+
return res;
80+
}
81+
FPS diff(){
82+
const int n=this->size();
83+
FPS res(max(0,n-1));
84+
for(int i=1;i<n;i++)res[i-1]=(*this)[i]*mint(i);
85+
return res;
86+
}
87+
FPS integral(){
88+
const int n=this->size();
89+
FPS res(n+1);
90+
res[0]=0;
91+
if(n>0)res[1]=1;
92+
ll mod=mint::get_mod();
93+
for(int i=2;i<=n;i++)res[i]=(-res[mod%i])*(mod/i);
94+
for(int i=0;i<n;i++)res[i+1]*=(*this)[i];
95+
return res;
96+
}
97+
mint eval(const mint &x){
98+
mint res=0,w=1;
99+
for(auto &a:*this)res+=a*w,w*=x;
100+
return res;
101+
}
102+
103+
FPS inv(int deg=-1){
104+
assert(!this->empty()&&(*this)[0]!=mint(0));
105+
if(deg==-1)deg=this->size();
106+
FPS res{mint(1)/(*this)[0]};
107+
for(int i=2;i>>1<deg;i<<=1){
108+
res=(res*(mint(2)-res*pre(i))).pre(i);
109+
}
110+
return res.pre(deg);
111+
}
112+
FPS log(int deg=-1){
113+
assert(!this->empty()&&(*this)[0]==mint(1));
114+
if(deg==-1)deg=this->size();
115+
return (pre(deg).diff()*inv(deg)).pre(deg-1).integral();
116+
}
117+
FPS exp(int deg=-1){
118+
assert(this->empty()||(*this)[0]==mint(0));
119+
if(deg==-1)deg=this->size();
120+
FPS res{mint(1)};
121+
for(int i=2;i>>1<deg;i<<=1){
122+
res=(res*(pre(i)-res.log(i)+mint(1))).pre(i);
123+
}
124+
return res.pre(deg);
125+
}
126+
FPS pow(ll k,int deg=-1){
127+
const int n=this->size();
128+
if(deg==-1)deg=n;
129+
if(k==0){
130+
FPS res(deg);
131+
if(deg)res[0]=mint(1);
132+
return res;
133+
}
134+
for(int i=0;i<n;i++){
135+
if(__int128_t(i)*k>=deg)return FPS(deg,mint(0));
136+
if((*this)[i]==mint(0))continue;
137+
mint rev=mint(1)/(*this)[i];
138+
FPS res=(((*this*rev)>>i).log(deg)*k).exp(deg);
139+
res=((res*binpow((*this)[i],k))<<(i*k)).pre(deg);
140+
return res;
141+
}
142+
return FPS(deg,mint(0));
143+
}
144+
};
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
/**
4+
* Author: Teetat T.
5+
* Description: Lagrange interpolation. Given f(0)...f(n) return a polynomial with degree n.
6+
* Time: $O(N)$
7+
*/
8+
9+
template<class mint>
10+
mint lagrange_interpolate(vector<mint> &f,mint c){
11+
int n=f.size();
12+
if(c.val()<n)return f[c.val()];
13+
vector<mint> l(n+1),r(n+1);
14+
l[0]=r[n]=1;
15+
for(int i=0;i<n;i++)l[i+1]=l[i]*(c-i);
16+
for(int i=n-1;i>=0;i--)r[i]=r[i+1]*(c-i);
17+
mint ans=0;
18+
for(int i=0;i<n;i++){
19+
mint cur=f[i]*comb.ifac(i)*comb.ifac(n-i-1);
20+
if((n-i-1)&1)cur*=-1;
21+
ans+=cur*l[i]*r[i+1];
22+
}
23+
return ans;
24+
}

polynomial/ntt.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
#include "modular-arithmetic/binpow.hpp"
3+
#include "modular-arithmetic/montgomery-modint.hpp"
4+
5+
/**
6+
* Author: Teetat T.
7+
* Description: Number Theoretic Transform
8+
* Time: $O(N \log N)$
9+
*/
10+
11+
template<class mint>
12+
struct NTT{
13+
using vm = vector<mint>;
14+
15+
static constexpr mint root=mint::get_root();
16+
static_assert(root!=0);
17+
18+
static void ntt(vm &a){
19+
int n=a.size(),L=31-__builtin_clz(n);
20+
vm rt(n);
21+
rt[1]=1;
22+
for(int k=2,s=2;k<n;k*=2,s++){
23+
mint z[]={1,binpow(root,MOD>>s)};
24+
for(int i=k;i<2*k;i++)rt[i]=rt[i/2]*z[i&1];
25+
}
26+
vector<int> rev(n);
27+
for(int i=1;i<n;i++)rev[i]=(rev[i/2]|(i&1)<<L)/2;
28+
for(int i=1;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
29+
for(int k=1;k<n;k*=2)for(int i=0;i<n;i+=2*k)for(int j=0;j<k;j++){
30+
mint z=rt[j+k]*a[i+j+k];
31+
a[i+j+k]=a[i+j]-z;
32+
a[i+j]+=z;
33+
}
34+
}
35+
static vm conv(const vm &a,const vm &b){
36+
if(a.empty()||b.empty())return {};
37+
int s=a.size()+b.size()-1,n=1<<(32-__builtin_clz(s));
38+
mint inv=mint(n).inv();
39+
vm in1(a),in2(b),out(n);
40+
in1.resize(n),in2.resize(n);
41+
ntt(in1),ntt(in2);
42+
for(int i=0;i<n;i++)out[-i&(n-1)]=in1[i]*in2[i]*inv;
43+
ntt(out);
44+
return vm(out.begin(),out.begin()+s);
45+
}
46+
vm operator()(const vm &a,const vm &b){
47+
return conv(a,b);
48+
}
49+
};
50+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"
2+
#include "template.hpp"
3+
#include "modular-arithmetic/montgomery-modint.hpp"
4+
#include "polynomial/ntt.hpp"
5+
6+
using mint = mint998;
7+
8+
int main(){
9+
cin.tie(nullptr)->sync_with_stdio(false);
10+
int n,m;
11+
cin >> n >> m;
12+
vector<mint> a(n),b(m);
13+
for(auto &x:a)cin >> x;
14+
for(auto &x:b)cin >> x;
15+
auto c=NTT<mint>()(a,b);
16+
for(auto x:c)cout << x << " ";
17+
cout << "\n";
18+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_1000000007"
2+
#include "template.hpp"
3+
#include "polynomial/fft.hpp"
4+
5+
int main(){
6+
cin.tie(nullptr)->sync_with_stdio(false);
7+
int n,m;
8+
cin >> n >> m;
9+
vector<ll> a(n),b(m);
10+
for(auto &x:a)cin >> x;
11+
for(auto &x:b)cin >> x;
12+
auto c=FFT<ll,MOD2>()(a,b);
13+
for(auto x:c)cout << x << " ";
14+
cout << "\n";
15+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/exp_of_formal_power_series"
2+
#include "template.hpp"
3+
#include "modular-arithmetic/montgomery-modint.hpp"
4+
#include "polynomial/formal-power-series.hpp"
5+
6+
7+
using mint = mint998;
8+
using FPS = FormalPowerSeries<mint>;
9+
10+
int main(){
11+
cin.tie(nullptr)->sync_with_stdio(false);
12+
int n;
13+
cin >> n;
14+
FPS a(n);
15+
for(auto &x:a)cin >> x;
16+
a=a.exp();
17+
for(auto x:a)cout << x << " ";
18+
cout << "\n";
19+
}

0 commit comments

Comments
 (0)