Skip to content

Commit 8af2b95

Browse files
committed
Optimize dot, use vector<vpoint> instead of vector<vftype>
1 parent 6e0f333 commit 8af2b95

File tree

3 files changed

+97
-99
lines changed

3 files changed

+97
-99
lines changed

cp-algo/math/cvector.hpp

Lines changed: 75 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,67 +13,40 @@ namespace cp_algo::math::fft {
1313
using point = complex<ftype>;
1414
using vftype [[gnu::vector_size(bytes)]] = ftype;
1515
using vpoint = complex<vftype>;
16-
static constexpr vftype fz = {};
16+
static constexpr vftype vz = {};
17+
static constexpr vpoint vi = {vz, vz + 1};
1718

1819
struct cvector {
19-
std::vector<vftype> x, y;
20+
std::vector<vpoint> r;
2021
cvector(size_t n) {
2122
n = std::max(flen, std::bit_ceil(n));
22-
x.resize(n / flen);
23-
y.resize(n / flen);
23+
r.resize(n / flen);
2424
checkpoint("cvector create");
2525
}
26+
27+
vpoint& at(size_t k) {return r[k / flen];}
28+
vpoint at(size_t k) const {return r[k / flen];}
2629
template<class pt = point>
2730
void set(size_t k, pt t) {
2831
if constexpr(std::is_same_v<pt, point>) {
29-
x[k / flen][k % flen] = real(t);
30-
y[k / flen][k % flen] = imag(t);
32+
real(r[k / flen])[k % flen] = real(t);
33+
imag(r[k / flen])[k % flen] = imag(t);
3134
} else {
32-
x[k / flen] = real(t);
33-
y[k / flen] = imag(t);
35+
at(k) = t;
3436
}
3537
}
3638
template<class pt = point>
3739
pt get(size_t k) const {
3840
if constexpr(std::is_same_v<pt, point>) {
39-
return {x[k / flen][k % flen], y[k / flen][k % flen]};
41+
return {real(r[k / flen])[k % flen], imag(r[k / flen])[k % flen]};
4042
} else {
41-
return {x[k / flen], y[k / flen]};
43+
return at(k);
4244
}
4345
}
44-
vpoint vget(size_t k) const {
45-
return get<vpoint>(k);
46-
}
4746

4847
size_t size() const {
49-
return flen * std::size(x);
48+
return flen * std::size(r);
5049
}
51-
52-
static constexpr size_t pre_roots = 1 << 16;
53-
static constexpr std::array<point, pre_roots> roots = []() {
54-
std::array<point, pre_roots> res = {};
55-
for(size_t n = 1; n < res.size(); n *= 2) {
56-
for(size_t k = 0; k < n; k++) {
57-
res[n + k] = polar(1., std::numbers::pi / ftype(n) * ftype(k));
58-
}
59-
}
60-
return res;
61-
}();
62-
static constexpr std::array<size_t, pre_roots> eval_args = []() {
63-
std::array<size_t, pre_roots> res = {};
64-
for(size_t i = 1; i < pre_roots; i++) {
65-
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
66-
}
67-
return res;
68-
}();
69-
static constexpr std::array<point, pre_roots> evalp = []() {
70-
std::array<point, pre_roots> res = {};
71-
res[0] = 1;
72-
for(size_t n = 1; n < pre_roots; n++) {
73-
res[n] = polar(1., std::numbers::pi * ftype(eval_args[n]) / ftype(2 * std::bit_floor(n)));
74-
}
75-
return res;
76-
}();
7750
static size_t eval_arg(size_t n) {
7851
if(n < pre_roots) {
7952
return eval_args[n];
@@ -118,15 +91,17 @@ namespace cp_algo::math::fft {
11891
if(k / flen % 2) {
11992
rt = -rt;
12093
}
121-
auto [Bvx, Bvy] = B.vget(k);
122-
auto [Brvx, Brvy] = vpoint(Bvx, Bvy) * vpoint(fz + real(rt), fz + imag(rt));
123-
auto [Ax, Ay] = A.vget(k);
124-
vftype Bx[2] = {Brvx, Bvx}, By[2] = {Brvy, Bvy};
125-
vpoint res = {fz, fz};
94+
auto [Ax, Ay] = A.at(k);
95+
auto Bv = B.at(k);
96+
vpoint res = {vz, vz};
12697
for (size_t i = 0; i < flen; i++) {
127-
auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128-
auto Bsy = (vftype*)((ftype*)By + flen - i);
129-
res += vpoint(fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
98+
res += vpoint(vz + Ax[i], vz + Ay[i]) * Bv;
99+
real(Bv) = __builtin_shufflevector(real(Bv), real(Bv), 3, 0, 1, 2);
100+
imag(Bv) = __builtin_shufflevector(imag(Bv), imag(Bv), 3, 0, 1, 2);
101+
auto x = real(Bv)[0];
102+
auto y = imag(Bv)[0];
103+
real(Bv)[0] = x * real(rt) - y * imag(rt);
104+
imag(Bv)[0] = x * imag(rt) + y * real(rt);
130105
}
131106
return res;
132107
}
@@ -145,37 +120,36 @@ namespace cp_algo::math::fft {
145120
if (4 * i <= n) { // radix-4
146121
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
147122
k *= 4 * i;
148-
vpoint v1 = {fz + real(rt), fz - imag(rt)};
123+
vpoint v1 = {vz + real(rt), vz - imag(rt)};
149124
vpoint v2 = v1 * v1;
150125
vpoint v3 = v1 * v2;
151126
for(size_t j = k; j < k + i; j += flen) {
152-
auto A = get<vpoint>(j);
153-
auto B = get<vpoint>(j + i);
154-
auto C = get<vpoint>(j + 2 * i);
155-
auto D = get<vpoint>(j + 3 * i);
156-
set(j , (A + B + C + D));
157-
set(j + 2 * i, (A + B - C - D) * v2);
158-
set(j + i, (A - B - vpoint(fz, fz + 1) * (C - D)) * v1);
159-
set(j + 3 * i, (A - B + vpoint(fz, fz + 1) * (C - D)) * v3);
127+
auto A = at(j);
128+
auto B = at(j + i);
129+
auto C = at(j + 2 * i);
130+
auto D = at(j + 3 * i);
131+
at(j) = (A + B + C + D);
132+
at(j + 2 * i) = (A + B - C - D) * v2;
133+
at(j + i) = (A - B - vi * (C - D)) * v1;
134+
at(j + 3 * i) = (A - B + vi * (C - D)) * v3;
160135
}
161136
});
162137
i *= 2;
163138
} else { // radix-2 fallback
164139
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
165140
k *= 2 * i;
166-
vpoint cvrt = {fz + real(rt), fz - imag(rt)};
141+
vpoint cvrt = {vz + real(rt), vz - imag(rt)};
167142
for(size_t j = k; j < k + i; j += flen) {
168-
auto A = get<vpoint>(j) + get<vpoint>(j + i);
169-
auto B = get<vpoint>(j) - get<vpoint>(j + i);
170-
set(j, A);
171-
set(j + i, B * cvrt);
143+
auto B = at(j) - at(j + i);
144+
at(j) += at(j + i);
145+
at(j + i) = B * cvrt;
172146
}
173147
});
174148
}
175149
}
176150
checkpoint("ifft");
177151
for(size_t k = 0; k < n; k += flen) {
178-
set(k, get<vpoint>(k) /= fz + (ftype)(n / flen));
152+
set(k, get<vpoint>(k) /= vz + (ftype)(n / flen));
179153
}
180154
}
181155
void fft() {
@@ -185,34 +159,59 @@ namespace cp_algo::math::fft {
185159
i /= 2;
186160
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
187161
k *= 4 * i;
188-
vpoint v1 = {fz + real(rt), fz + imag(rt)};
162+
vpoint v1 = {vz + real(rt), vz + imag(rt)};
189163
vpoint v2 = v1 * v1;
190164
vpoint v3 = v1 * v2;
191165
for(size_t j = k; j < k + i; j += flen) {
192-
auto A = get<vpoint>(j);
193-
auto B = get<vpoint>(j + i) * v1;
194-
auto C = get<vpoint>(j + 2 * i) * v2;
195-
auto D = get<vpoint>(j + 3 * i) * v3;
196-
set(j , (A + C) + (B + D));
197-
set(j + i, (A + C) - (B + D));
198-
set(j + 2 * i, (A - C) + vpoint(fz, fz + 1) * (B - D));
199-
set(j + 3 * i, (A - C) - vpoint(fz, fz + 1) * (B - D));
166+
auto A = at(j);
167+
auto B = at(j + i) * v1;
168+
auto C = at(j + 2 * i) * v2;
169+
auto D = at(j + 3 * i) * v3;
170+
at(j) = (A + C) + (B + D);
171+
at(j + i) = (A + C) - (B + D);
172+
at(j + 2 * i) = (A - C) + vi * (B - D);
173+
at(j + 3 * i) = (A - C) - vi * (B - D);
200174
}
201175
});
202176
} else { // radix-2 fallback
203177
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
204178
k *= 2 * i;
205-
vpoint vrt = {fz + real(rt), fz + imag(rt)};
179+
vpoint vrt = {vz + real(rt), vz + imag(rt)};
206180
for(size_t j = k; j < k + i; j += flen) {
207-
auto t = get<vpoint>(j + i) * vrt;
208-
set(j + i, get<vpoint>(j) - t);
209-
set(j, get<vpoint>(j) + t);
181+
auto t = at(j + i) * vrt;
182+
at(j + i) = at(j) - t;
183+
at(j) += t;
210184
}
211185
});
212186
}
213187
}
214188
checkpoint("fft");
215189
}
190+
static constexpr size_t pre_roots = 1 << 16;
191+
static constexpr std::array<point, pre_roots> roots = []() {
192+
std::array<point, pre_roots> res = {};
193+
for(size_t n = 1; n < res.size(); n *= 2) {
194+
for(size_t k = 0; k < n; k++) {
195+
res[n + k] = polar(1., std::numbers::pi / ftype(n) * ftype(k));
196+
}
197+
}
198+
return res;
199+
}();
200+
static constexpr std::array<size_t, pre_roots> eval_args = []() {
201+
std::array<size_t, pre_roots> res = {};
202+
for(size_t i = 1; i < pre_roots; i++) {
203+
res[i] = res[i >> 1] | (i & 1) << (std::bit_width(i) - 1);
204+
}
205+
return res;
206+
}();
207+
static constexpr std::array<point, pre_roots> evalp = []() {
208+
std::array<point, pre_roots> res = {};
209+
res[0] = 1;
210+
for(size_t n = 1; n < pre_roots; n++) {
211+
res[n] = polar(1., std::numbers::pi * ftype(eval_args[n]) / ftype(2 * std::bit_floor(n)));
212+
}
213+
return res;
214+
}();
216215
};
217216
}
218217
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/math/fft.hpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,28 @@ namespace cp_algo::math::fft {
4444
if(k / flen % 2) {
4545
rt = -rt;
4646
}
47-
auto [Ax, Ay] = A.vget(k);
48-
auto [Bx, By] = B.vget(k);
49-
auto [Cvx, Cvy] = C.vget(k);
50-
auto [Dvx, Dvy] = D.vget(k);
51-
auto [Crvx, Crvy] = vpoint(Cvx, Cvy) * vpoint(fz + real(rt), fz + imag(rt));
52-
auto [Drvx, Drvy] = vpoint(Dvx, Dvy) * vpoint(fz + real(rt), fz + imag(rt));
53-
vftype Cx[2] = {Crvx, Cvx}, Cy[2] = {Crvy, Cvy};
54-
vftype Dx[2] = {Drvx, Dvx}, Dy[2] = {Drvy, Dvy};
47+
auto [Ax, Ay] = A.at(k);
48+
auto [Bx, By] = B.at(k);
5549
vpoint AC, AD, BC, BD;
56-
AC = AD = BC = BD = {fz, fz};
57-
for(size_t i = 0; i < flen; i++) {
58-
auto Csx = (vftype*)((ftype*)Cx + flen - i);
59-
auto Csy = (vftype*)((ftype*)Cy + flen - i);
60-
auto Dsx = (vftype*)((ftype*)Dx + flen - i);
61-
auto Dsy = (vftype*)((ftype*)Dy + flen - i);
62-
vpoint As = {fz + Ax[i], fz + Ay[i]}, Bs = {fz + Bx[i], fz + By[i]};
63-
vpoint Cs = {*Csx, *Csy}, Ds = {*Dsx, *Dsy};
64-
AC += As * Cs; AD += As * Ds;
65-
BC += Bs * Cs; BD += Bs * Ds;
50+
auto Cv = C.at(k), Dv = D.at(k);
51+
for (size_t i = 0; i < flen; i++) {
52+
vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
53+
AC += Av * Cv; AD += Av * Dv;
54+
BC += Bv * Cv; BD += Bv * Dv;
55+
real(Cv) = __builtin_shufflevector(real(Cv), real(Cv), 3, 0, 1, 2);
56+
imag(Cv) = __builtin_shufflevector(imag(Cv), imag(Cv), 3, 0, 1, 2);
57+
real(Dv) = __builtin_shufflevector(real(Dv), real(Dv), 3, 0, 1, 2);
58+
imag(Dv) = __builtin_shufflevector(imag(Dv), imag(Dv), 3, 0, 1, 2);
59+
auto cx = real(Cv)[0], cy = imag(Cv)[0];
60+
auto dx = real(Dv)[0], dy = imag(Dv)[0];
61+
real(Cv)[0] = cx * real(rt) - cy * imag(rt);
62+
imag(Cv)[0] = cx * imag(rt) + cy * real(rt);
63+
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
64+
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
6665
}
67-
A.set(k, AC);
68-
C.set(k, AD + BC);
69-
B.set(k, BD);
66+
A.at(k) = AC;
67+
C.at(k) = AD + BC;
68+
B.at(k) = BD;
7069
}
7170
checkpoint("dot");
7271
A.ifft();

cp-algo/util/complex.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ namespace cp_algo {
88
struct complex {
99
using value_type = T;
1010
T x, y;
11-
constexpr complex() {}
12-
constexpr complex(T x): x(x), y(0) {}
11+
constexpr complex(): x(), y() {}
12+
constexpr complex(T x): x(x), y() {}
1313
constexpr complex(T x, T y): x(x), y(y) {}
1414
complex& operator *= (T t) {x *= t; y *= t; return *this;}
1515
complex& operator /= (T t) {x /= t; y /= t; return *this;}

0 commit comments

Comments
 (0)