Skip to content

Commit 6e0f333

Browse files
committed
use gnu::vector_size instead of std::experimental::simd
1 parent 3cbf663 commit 6e0f333

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

cp-algo/math/cvector.hpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
namespace stdx = std::experimental;
99
namespace cp_algo::math::fft {
1010
using ftype = double;
11+
static constexpr size_t bytes = 32;
12+
static constexpr size_t flen = bytes / sizeof(ftype);
1113
using point = complex<ftype>;
12-
using vftype = stdx::native_simd<ftype>;
14+
using vftype [[gnu::vector_size(bytes)]] = ftype;
1315
using vpoint = complex<vftype>;
14-
static constexpr size_t flen = vftype::size();
16+
static constexpr vftype fz = {};
1517

1618
struct cvector {
1719
std::vector<vftype> x, y;
@@ -117,15 +119,14 @@ namespace cp_algo::math::fft {
117119
rt = -rt;
118120
}
119121
auto [Bvx, Bvy] = B.vget(k);
120-
auto [Brvx, Brvy] = vpoint(Bvx, Bvy) * vpoint(real(rt), imag(rt));
122+
auto [Brvx, Brvy] = vpoint(Bvx, Bvy) * vpoint(fz + real(rt), fz + imag(rt));
121123
auto [Ax, Ay] = A.vget(k);
122124
vftype Bx[2] = {Brvx, Bvx}, By[2] = {Brvy, Bvy};
123-
vpoint res = {0, 0};
125+
vpoint res = {fz, fz};
124126
for (size_t i = 0; i < flen; i++) {
125-
vftype Bsx, Bsy;
126-
Bsx.copy_from((ftype*)Bx + flen - i, stdx::element_aligned);
127-
Bsy.copy_from((ftype*)By + flen - i, stdx::element_aligned);
128-
res += vpoint(Ax[i], Ay[i]) * vpoint{Bsx, Bsy};
127+
auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128+
auto Bsy = (vftype*)((ftype*)By + flen - i);
129+
res += vpoint(fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
129130
}
130131
return res;
131132
}
@@ -144,7 +145,7 @@ namespace cp_algo::math::fft {
144145
if (4 * i <= n) { // radix-4
145146
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
146147
k *= 4 * i;
147-
vpoint v1 = {real(rt), -imag(rt)};
148+
vpoint v1 = {fz + real(rt), fz - imag(rt)};
148149
vpoint v2 = v1 * v1;
149150
vpoint v3 = v1 * v2;
150151
for(size_t j = k; j < k + i; j += flen) {
@@ -154,15 +155,15 @@ namespace cp_algo::math::fft {
154155
auto D = get<vpoint>(j + 3 * i);
155156
set(j , (A + B + C + D));
156157
set(j + 2 * i, (A + B - C - D) * v2);
157-
set(j + i, (A - B - vpoint(0, 1) * (C - D)) * v1);
158-
set(j + 3 * i, (A - B + vpoint(0, 1) * (C - D)) * v3);
158+
set(j + i, (A - B - vpoint(fz, fz + 1) * (C - D)) * v1);
159+
set(j + 3 * i, (A - B + vpoint(fz, fz + 1) * (C - D)) * v3);
159160
}
160161
});
161162
i *= 2;
162163
} else { // radix-2 fallback
163164
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
164165
k *= 2 * i;
165-
vpoint cvrt = {real(rt), -imag(rt)};
166+
vpoint cvrt = {fz + real(rt), fz - imag(rt)};
166167
for(size_t j = k; j < k + i; j += flen) {
167168
auto A = get<vpoint>(j) + get<vpoint>(j + i);
168169
auto B = get<vpoint>(j) - get<vpoint>(j + i);
@@ -174,7 +175,7 @@ namespace cp_algo::math::fft {
174175
}
175176
checkpoint("ifft");
176177
for(size_t k = 0; k < n; k += flen) {
177-
set(k, get<vpoint>(k) /= (ftype)(n / flen));
178+
set(k, get<vpoint>(k) /= fz + (ftype)(n / flen));
178179
}
179180
}
180181
void fft() {
@@ -184,7 +185,7 @@ namespace cp_algo::math::fft {
184185
i /= 2;
185186
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
186187
k *= 4 * i;
187-
vpoint v1 = {real(rt), imag(rt)};
188+
vpoint v1 = {fz + real(rt), fz + imag(rt)};
188189
vpoint v2 = v1 * v1;
189190
vpoint v3 = v1 * v2;
190191
for(size_t j = k; j < k + i; j += flen) {
@@ -194,14 +195,14 @@ namespace cp_algo::math::fft {
194195
auto D = get<vpoint>(j + 3 * i) * v3;
195196
set(j , (A + C) + (B + D));
196197
set(j + i, (A + C) - (B + D));
197-
set(j + 2 * i, (A - C) + vpoint(0, 1) * (B - D));
198-
set(j + 3 * i, (A - C) - vpoint(0, 1) * (B - D));
198+
set(j + 2 * i, (A - C) + vpoint(fz, fz + 1) * (B - D));
199+
set(j + 3 * i, (A - C) - vpoint(fz, fz + 1) * (B - D));
199200
}
200201
});
201202
} else { // radix-2 fallback
202203
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
203204
k *= 2 * i;
204-
vpoint vrt = {real(rt), imag(rt)};
205+
vpoint vrt = {fz + real(rt), fz + imag(rt)};
205206
for(size_t j = k; j < k + i; j += flen) {
206207
auto t = get<vpoint>(j + i) * vrt;
207208
set(j + i, get<vpoint>(j) - t);

cp-algo/math/fft.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,19 @@ namespace cp_algo::math::fft {
4848
auto [Bx, By] = B.vget(k);
4949
auto [Cvx, Cvy] = C.vget(k);
5050
auto [Dvx, Dvy] = D.vget(k);
51-
auto [Crvx, Crvy] = vpoint(Cvx, Cvy) * vpoint(real(rt), imag(rt));
52-
auto [Drvx, Drvy] = vpoint(Dvx, Dvy) * vpoint(real(rt), imag(rt));
51+
auto [Crvx, Crvy] = vpoint(Cvx, Cvy) * vpoint(fz + real(rt), fz + imag(rt));
52+
auto [Drvx, Drvy] = vpoint(Dvx, Dvy) * vpoint(fz + real(rt), fz + imag(rt));
5353
vftype Cx[2] = {Crvx, Cvx}, Cy[2] = {Crvy, Cvy};
5454
vftype Dx[2] = {Drvx, Dvx}, Dy[2] = {Drvy, Dvy};
5555
vpoint AC, AD, BC, BD;
56-
AC = AD = BC = BD = {0, 0};
56+
AC = AD = BC = BD = {fz, fz};
5757
for(size_t i = 0; i < flen; i++) {
58-
vftype Csx, Csy, Dsx, Dsy;
59-
Csx.copy_from((ftype*)Cx + flen - i, stdx::element_aligned);
60-
Csy.copy_from((ftype*)Cy + flen - i, stdx::element_aligned);
61-
Dsx.copy_from((ftype*)Dx + flen - i, stdx::element_aligned);
62-
Dsy.copy_from((ftype*)Dy + flen - i, stdx::element_aligned);
63-
vpoint As = {Ax[i], Ay[i]}, Bs = {Bx[i], By[i]};
64-
vpoint Cs = {Csx, Csy}, Ds = {Dsx, Dsy};
58+
auto Csx = (vftype*)((ftype*)Cx + flen - i);
59+
auto Csy = (vftype*)((ftype*)Cy + flen - i);
60+
auto Dsx = (vftype*)((ftype*)Dx + flen - i);
61+
auto Dsy = (vftype*)((ftype*)Dy + flen - i);
62+
vpoint As = {fz + Ax[i], fz + Ay[i]}, Bs = {fz + Bx[i], fz + By[i]};
63+
vpoint Cs = {*Csx, *Csy}, Ds = {*Dsx, *Dsy};
6564
AC += As * Cs; AD += As * Ds;
6665
BC += Bs * Cs; BD += Bs * Ds;
6766
}

0 commit comments

Comments
 (0)