@@ -112,10 +112,23 @@ namespace cp_algo::math::fft {
112112
113113 void ifft () {
114114 size_t n = size ();
115- for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
116- if (4 * i <= n) { // radix-4
117- exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
118- k *= 4 * i;
115+ bool parity = std::countr_zero (n) % 2 ;
116+ if (parity) {
117+ exec_on_evals<2 >(n / (2 * flen), [&](size_t k, point rt) {
118+ k *= 2 * flen;
119+ vpoint cvrt = {vz + real (rt), vz - imag (rt)};
120+ auto B = at (k) - at (k + flen);
121+ at (k) += at (k + flen);
122+ at (k + flen) = B * cvrt;
123+ });
124+ }
125+
126+ for (size_t leaf = 3 * flen; leaf < n; leaf += 4 * flen) {
127+ size_t level = std::countr_one (leaf + 3 );
128+ for (size_t lvl = 4 + parity; lvl <= level; lvl += 2 ) {
129+ size_t i = (1 << lvl) / 4 ;
130+ exec_on_eval<4 >(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
131+ k <<= lvl;
119132 vpoint v1 = {vz + real (rt), vz - imag (rt)};
120133 vpoint v2 = v1 * v1;
121134 vpoint v3 = v1 * v2;
@@ -124,21 +137,10 @@ namespace cp_algo::math::fft {
124137 auto B = at (j + i);
125138 auto C = at (j + 2 * i);
126139 auto D = at (j + 3 * i);
127- at (j) = (A + B + C + D);
128- at (j + 2 * i) = (A + B - C - D) * v2;
129- at (j + i) = (A - B - vi (C - D)) * v1;
130- at (j + 3 * i) = (A - B + vi (C - D)) * v3;
131- }
132- });
133- i *= 2 ;
134- } else { // radix-2 fallback
135- exec_on_evals<2 >(n / (2 * i), [&](size_t k, point rt) {
136- k *= 2 * i;
137- vpoint cvrt = {vz + real (rt), vz - imag (rt)};
138- for (size_t j = k; j < k + i; j += flen) {
139- auto B = at (j) - at (j + i);
140- at (j) += at (j + i);
141- at (j + i) = B * cvrt;
140+ at (j) = ((A + B) + (C + D));
141+ at (j + 2 * i) = ((A + B) - (C + D)) * v2;
142+ at (j + i) = ((A - B) - vi (C - D)) * v1;
143+ at (j + 3 * i) = ((A - B) + vi (C - D)) * v3;
142144 }
143145 });
144146 }
@@ -150,11 +152,14 @@ namespace cp_algo::math::fft {
150152 }
151153 void fft () {
152154 size_t n = size ();
153- for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
154- if (i / 2 >= flen) { // radix-4
155- i /= 2 ;
156- exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
157- k *= 4 * i;
155+ bool parity = std::countr_zero (n) % 2 ;
156+ for (size_t leaf = 0 ; leaf < n; leaf += 4 * flen) {
157+ size_t level = std::countr_zero (n + leaf);
158+ level -= level % 2 != parity;
159+ for (size_t lvl = level; lvl >= 4 ; lvl -= 2 ) {
160+ size_t i = (1 << lvl) / 4 ;
161+ exec_on_eval<4 >(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
162+ k <<= lvl;
158163 vpoint v1 = {vz + real (rt), vz + imag (rt)};
159164 vpoint v2 = v1 * v1;
160165 vpoint v3 = v1 * v2;
@@ -169,18 +174,17 @@ namespace cp_algo::math::fft {
169174 at (j + 3 * i) = (A - C) - vi (B - D);
170175 }
171176 });
172- } else { // radix-2 fallback
173- exec_on_evals<2 >(n / (2 * i), [&](size_t k, point rt) {
174- k *= 2 * i;
175- vpoint vrt = {vz + real (rt), vz + imag (rt)};
176- for (size_t j = k; j < k + i; j += flen) {
177- auto t = at (j + i) * vrt;
178- at (j + i) = at (j) - t;
179- at (j) += t;
180- }
181- });
182177 }
183178 }
179+ if (parity) {
180+ exec_on_evals<2 >(n / (2 * flen), [&](size_t k, point rt) {
181+ k *= 2 * flen;
182+ vpoint vrt = {vz + real (rt), vz + imag (rt)};
183+ auto t = at (k + flen) * vrt;
184+ at (k + flen) = at (k) - t;
185+ at (k) += t;
186+ });
187+ }
184188 checkpoint (" fft" );
185189 }
186190 static constexpr size_t pre_evals = 1 << 16 ;
0 commit comments