Skip to content

Commit 07f2eb1

Browse files
committed
Make fft/ifft radix-4
1 parent 22e22f3 commit 07f2eb1

File tree

1 file changed

+61
-20
lines changed

1 file changed

+61
-20
lines changed

cp-algo/math/cvector.hpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ namespace cp_algo::math::fft {
105105
callback(i, cur);
106106
}
107107
}
108+
template<int step = 1>
108109
static void exec_on_evals(size_t n, auto &&callback) {
109110
for(size_t i = 0; i < n; i++) {
110-
callback(i, eval_point(i));
111+
callback(i, eval_point(step * i));
111112
}
112113
}
113114
static auto dot_block(size_t k, cvector const& A, cvector const& B) {
@@ -145,16 +146,36 @@ namespace cp_algo::math::fft {
145146
void ifft() {
146147
size_t n = size();
147148
for(size_t i = flen; i <= n / 2; i *= 2) {
148-
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
149-
k *= 2 * i;
150-
vpoint cvrt = {real(rt), -imag(rt)};
151-
for(size_t j = k; j < k + i; j += flen) {
152-
auto A = get<vpoint>(j) + get<vpoint>(j + i);
153-
auto B = get<vpoint>(j) - get<vpoint>(j + i);
154-
set(j, A);
155-
set(j + i, B * cvrt);
156-
}
157-
});
149+
if (4 * i <= n) { // radix-4
150+
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
151+
k *= 4 * i;
152+
vpoint v1 = {real(rt), -imag(rt)};
153+
vpoint v2 = v1 * v1;
154+
vpoint v3 = v1 * v2;
155+
for(size_t j = k; j < k + i; j += flen) {
156+
auto A = get<vpoint>(j);
157+
auto B = get<vpoint>(j + i);
158+
auto C = get<vpoint>(j + 2 * i);
159+
auto D = get<vpoint>(j + 3 * i);
160+
set(j , (A + B + C + D));
161+
set(j + 2 * i, (A + B - C - D) * v2);
162+
set(j + i, (A - B - vpoint(0, 1) * (C - D)) * v1);
163+
set(j + 3 * i, (A - B + vpoint(0, 1) * (C - D)) * v3);
164+
}
165+
});
166+
i *= 2;
167+
} else { // radix-2 fallback
168+
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
169+
k *= 2 * i;
170+
vpoint cvrt = {real(rt), -imag(rt)};
171+
for(size_t j = k; j < k + i; j += flen) {
172+
auto A = get<vpoint>(j) + get<vpoint>(j + i);
173+
auto B = get<vpoint>(j) - get<vpoint>(j + i);
174+
set(j, A);
175+
set(j + i, B * cvrt);
176+
}
177+
});
178+
}
158179
}
159180
checkpoint("ifft");
160181
for(size_t k = 0; k < n; k += flen) {
@@ -164,15 +185,35 @@ namespace cp_algo::math::fft {
164185
void fft() {
165186
size_t n = size();
166187
for(size_t i = n / 2; i >= flen; i /= 2) {
167-
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
168-
k *= 2 * i;
169-
vpoint vrt = {real(rt), imag(rt)};
170-
for(size_t j = k; j < k + i; j += flen) {
171-
auto t = get<vpoint>(j + i) * vrt;
172-
set(j + i, get<vpoint>(j) - t);
173-
set(j, get<vpoint>(j) + t);
174-
}
175-
});
188+
if (i / 2 >= flen) { // radix-4
189+
i /= 2;
190+
exec_on_evals<2>(n / (4 * i), [&](size_t k, point rt) {
191+
k *= 4 * i;
192+
vpoint v1 = {real(rt), imag(rt)};
193+
vpoint v2 = v1 * v1;
194+
vpoint v3 = v1 * v2;
195+
for(size_t j = k; j < k + i; j += flen) {
196+
auto A = get<vpoint>(j);
197+
auto B = get<vpoint>(j + i) * v1;
198+
auto C = get<vpoint>(j + 2 * i) * v2;
199+
auto D = get<vpoint>(j + 3 * i) * v3;
200+
set(j , (A + C) + (B + D));
201+
set(j + i, (A + C) - (B + D));
202+
set(j + 2 * i, (A - C) + vpoint(0, 1) * (B - D));
203+
set(j + 3 * i, (A - C) - vpoint(0, 1) * (B - D));
204+
}
205+
});
206+
} else { // radix-2 fallback
207+
exec_on_evals(n / (2 * i), [&](size_t k, point rt) {
208+
k *= 2 * i;
209+
vpoint vrt = {real(rt), imag(rt)};
210+
for(size_t j = k; j < k + i; j += flen) {
211+
auto t = get<vpoint>(j + i) * vrt;
212+
set(j + i, get<vpoint>(j) - t);
213+
set(j, get<vpoint>(j) + t);
214+
}
215+
});
216+
}
176217
}
177218
checkpoint("fft");
178219
}

0 commit comments

Comments
 (0)