@@ -105,9 +105,10 @@ namespace cp_algo::math::fft {
105
105
callback (i, cur);
106
106
}
107
107
}
108
+ template <int step = 1 >
108
109
static void exec_on_evals (size_t n, auto &&callback) {
109
110
for (size_t i = 0 ; i < n; i++) {
110
- callback (i, eval_point (i));
111
+ callback (i, eval_point (step * i));
111
112
}
112
113
}
113
114
static auto dot_block (size_t k, cvector const & A, cvector const & B) {
@@ -145,16 +146,36 @@ namespace cp_algo::math::fft {
145
146
void ifft () {
146
147
size_t n = size ();
147
148
for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
148
- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
149
- k *= 2 * i;
150
- vpoint cvrt = {real (rt), -imag (rt)};
151
- for (size_t j = k; j < k + i; j += flen) {
152
- auto A = get<vpoint>(j) + get<vpoint>(j + i);
153
- auto B = get<vpoint>(j) - get<vpoint>(j + i);
154
- set (j, A);
155
- set (j + i, B * cvrt);
156
- }
157
- });
149
+ if (4 * i <= n) { // radix-4
150
+ exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
151
+ k *= 4 * i;
152
+ vpoint v1 = {real (rt), -imag (rt)};
153
+ vpoint v2 = v1 * v1;
154
+ vpoint v3 = v1 * v2;
155
+ for (size_t j = k; j < k + i; j += flen) {
156
+ auto A = get<vpoint>(j);
157
+ auto B = get<vpoint>(j + i);
158
+ auto C = get<vpoint>(j + 2 * i);
159
+ auto D = get<vpoint>(j + 3 * i);
160
+ set (j , (A + B + C + D));
161
+ set (j + 2 * i, (A + B - C - D) * v2);
162
+ set (j + i, (A - B - vpoint (0 , 1 ) * (C - D)) * v1);
163
+ set (j + 3 * i, (A - B + vpoint (0 , 1 ) * (C - D)) * v3);
164
+ }
165
+ });
166
+ i *= 2 ;
167
+ } else { // radix-2 fallback
168
+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
169
+ k *= 2 * i;
170
+ vpoint cvrt = {real (rt), -imag (rt)};
171
+ for (size_t j = k; j < k + i; j += flen) {
172
+ auto A = get<vpoint>(j) + get<vpoint>(j + i);
173
+ auto B = get<vpoint>(j) - get<vpoint>(j + i);
174
+ set (j, A);
175
+ set (j + i, B * cvrt);
176
+ }
177
+ });
178
+ }
158
179
}
159
180
checkpoint (" ifft" );
160
181
for (size_t k = 0 ; k < n; k += flen) {
@@ -164,15 +185,35 @@ namespace cp_algo::math::fft {
164
185
void fft () {
165
186
size_t n = size ();
166
187
for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
167
- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
168
- k *= 2 * i;
169
- vpoint vrt = {real (rt), imag (rt)};
170
- for (size_t j = k; j < k + i; j += flen) {
171
- auto t = get<vpoint>(j + i) * vrt;
172
- set (j + i, get<vpoint>(j) - t);
173
- set (j, get<vpoint>(j) + t);
174
- }
175
- });
188
+ if (i / 2 >= flen) { // radix-4
189
+ i /= 2 ;
190
+ exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
191
+ k *= 4 * i;
192
+ vpoint v1 = {real (rt), imag (rt)};
193
+ vpoint v2 = v1 * v1;
194
+ vpoint v3 = v1 * v2;
195
+ for (size_t j = k; j < k + i; j += flen) {
196
+ auto A = get<vpoint>(j);
197
+ auto B = get<vpoint>(j + i) * v1;
198
+ auto C = get<vpoint>(j + 2 * i) * v2;
199
+ auto D = get<vpoint>(j + 3 * i) * v3;
200
+ set (j , (A + C) + (B + D));
201
+ set (j + i, (A + C) - (B + D));
202
+ set (j + 2 * i, (A - C) + vpoint (0 , 1 ) * (B - D));
203
+ set (j + 3 * i, (A - C) - vpoint (0 , 1 ) * (B - D));
204
+ }
205
+ });
206
+ } else { // radix-2 fallback
207
+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
208
+ k *= 2 * i;
209
+ vpoint vrt = {real (rt), imag (rt)};
210
+ for (size_t j = k; j < k + i; j += flen) {
211
+ auto t = get<vpoint>(j + i) * vrt;
212
+ set (j + i, get<vpoint>(j) - t);
213
+ set (j, get<vpoint>(j) + t);
214
+ }
215
+ });
216
+ }
176
217
}
177
218
checkpoint (" fft" );
178
219
}
0 commit comments