@@ -19,7 +19,7 @@ namespace cp_algo::math::fft {
1919 }
2020 static u64x4 mod, imod;
2121
22- void init () {
22+ static void init () {
2323 if (!_init) {
2424 factor = 1 + random::rng () % (base::mod () - 1 );
2525 ifactor = base (1 ) / factor;
@@ -40,16 +40,16 @@ namespace cp_algo::math::fft {
4040 };
4141 u64x4 step4 = u64x4{} + (bpow (factor, 4 ) * b2x32).getr ();
4242 u64x4 stepn = u64x4{} + (bpow (factor, n) * b2x32).getr ();
43- for (size_t i = 0 ; i < std::min (n, size (a)); i += flen) {
43+ for (size_t i = 0 ; i < std::min (n, std:: size (a)); i += flen) {
4444 auto splt = [&](size_t i, auto mul) {
45- if (i >= size (a)) {
45+ if (i >= std:: size (a)) {
4646 return std::pair{vftype (), vftype ()};
4747 }
4848 u64x4 au = {
49- i < size (a) ? a[i].getr () : 0 ,
50- i + 1 < size (a) ? a[i + 1 ].getr () : 0 ,
51- i + 2 < size (a) ? a[i + 2 ].getr () : 0 ,
52- i + 3 < size (a) ? a[i + 3 ].getr () : 0
49+ i < std:: size (a) ? a[i].getr () : 0 ,
50+ i + 1 < std:: size (a) ? a[i + 1 ].getr () : 0 ,
51+ i + 2 < std:: size (a) ? a[i + 2 ].getr () : 0 ,
52+ i + 3 < std:: size (a) ? a[i + 3 ].getr () : 0
5353 };
5454 au = montgomery_mul (au, mul, mod, imod);
5555 au = au >= base::mod () ? au - base::mod () : au;
@@ -101,7 +101,8 @@ namespace cp_algo::math::fft {
101101 }
102102
103103 void recover_mod (auto &&C, auto &res, size_t k) {
104- res.assign ((k / flen + 1 ) * flen, base (0 ));
104+ size_t check = (k + flen - 1 ) / flen * flen;
105+ assert (res.size () >= check);
105106 size_t n = A.size ();
106107 auto const splitsplit = base (split () * split ()).getr ();
107108 base b2x32 = bpow (base (2 ), 32 );
@@ -134,7 +135,6 @@ namespace cp_algo::math::fft {
134135 }
135136 cur = montgomery_mul (cur, step4, mod, imod);
136137 }
137- res.resize (k);
138138 checkpoint (" recover mod" );
139139 }
140140
@@ -158,12 +158,12 @@ namespace cp_algo::math::fft {
158158 mul (cvector (B.A ), B.B , res, k);
159159 }
160160 std::vector<base, big_alloc<base>> operator *= (dft &B) {
161- std::vector<base, big_alloc<base>> res;
161+ std::vector<base, big_alloc<base>> res ( 2 * A. size ()) ;
162162 mul_inplace (B, res, 2 * A.size ());
163163 return res;
164164 }
165165 std::vector<base, big_alloc<base>> operator *= (dft const & B) {
166- std::vector<base, big_alloc<base>> res;
166+ std::vector<base, big_alloc<base>> res ( 2 * A. size ()) ;
167167 mul (B, res, 2 * A.size ());
168168 return res;
169169 }
@@ -180,11 +180,11 @@ namespace cp_algo::math::fft {
180180 template <modint_type base> u64x4 dft<base>::imod = {};
181181
182182 void mul_slow (auto &a, auto const & b, size_t k) {
183- if (empty (a) || empty (b)) {
183+ if (std:: empty (a) || std:: empty (b)) {
184184 a.clear ();
185185 } else {
186- size_t n = std::min (k, size (a));
187- size_t m = std::min (k, size (b));
186+ size_t n = std::min (k, std:: size (a));
187+ size_t m = std::min (k, std:: size (b));
188188 a.resize (k);
189189 for (int j = int (k - 1 ); j >= 0 ; j--) {
190190 a[j] *= b[0 ];
@@ -202,55 +202,103 @@ namespace cp_algo::math::fft {
202202 }
203203 void mul_truncate (auto &a, auto const & b, size_t k) {
204204 using base = std::decay_t <decltype (a[0 ])>;
205- if (std::min ({k, size (a), size (b)}) < magic) {
205+ if (std::min ({k, std:: size (a), std:: size (b)}) < magic) {
206206 mul_slow (a, b, k);
207207 return ;
208208 }
209209 auto n = std::max (flen, std::bit_ceil (
210- std::min (k, size (a)) + std::min (k, size (b)) - 1
210+ std::min (k, std:: size (a)) + std::min (k, std:: size (b)) - 1
211211 ) / 2 );
212212 auto A = dft<base>(a | std::views::take (k), n);
213- if (&a == &b) {
214- A.mul (A, a, k);
215- } else {
216- A.mul_inplace (dft<base>(b | std::views::take (k), n), a, k);
213+ auto B = dft<base>(b | std::views::take (k), n);
214+ a.resize ((k + flen - 1 ) / flen * flen);
215+ A.mul_inplace (B, a, k);
216+ a.resize (k);
217+ }
218+
219+ // store mod x^n-k in first half, x^n+k in second half
220+ void mod_split (auto &&x, size_t n, auto k) {
221+ using base = std::decay_t <decltype (k)>;
222+ dft<base>::init ();
223+ assert (std::size (x) == 2 * n);
224+ u64x4 cur = u64x4{} + (k * bpow (base (2 ), 32 )).getr ();
225+ for (size_t i = 0 ; i < n; i += flen) {
226+ u64x4 xl = {
227+ x[i].getr (),
228+ x[i + 1 ].getr (),
229+ x[i + 2 ].getr (),
230+ x[i + 3 ].getr ()
231+ };
232+ u64x4 xr = {
233+ x[n + i].getr (),
234+ x[n + i + 1 ].getr (),
235+ x[n + i + 2 ].getr (),
236+ x[n + i + 3 ].getr ()
237+ };
238+ xr = montgomery_mul (xr, cur, dft<base>::mod, dft<base>::imod);
239+ xr = xr >= base::mod () ? xr - base::mod () : xr;
240+ auto t = xr;
241+ xr = xl - t;
242+ xl += t;
243+ xl = xl >= base::mod () ? xl - base::mod () : xl;
244+ xr = xr >= base::mod () ? xr + base::mod () : xr;
245+ for (size_t k = 0 ; k < flen; k++) {
246+ x[i + k].setr (typename base::UInt (xl[k]));
247+ x[n + i + k].setr (typename base::UInt (xr[k]));
248+ }
217249 }
250+ cp_algo::checkpoint (" mod split" );
218251 }
219- void mul (auto &a, auto const & b) {
252+ void cyclic_mul (auto &a, auto &&b, size_t k) {
253+ assert (std::popcount (k) == 1 );
254+ assert (std::size (a) == std::size (b) && std::size (a) == k);
255+ using base = std::decay_t <decltype (a[0 ])>;
256+ dft<base>::init ();
257+ if (k <= (1 << 16 )) {
258+ auto ap = std::ranges::to<std::vector<base, big_alloc<base>>>(a);
259+ mul_truncate (ap, b, 2 * k);
260+ mod_split (ap, k, bpow (dft<base>::factor, k));
261+ std::ranges::copy (ap | std::views::take (k), begin (a));
262+ return ;
263+ }
264+ k /= 2 ;
265+ auto factor = bpow (dft<base>::factor, k);
266+ mod_split (a, k, factor);
267+ mod_split (b, k, factor);
268+ auto la = std::span (a).first (k);
269+ auto lb = std::span (b).first (k);
270+ auto ra = std::span (a).last (k);
271+ auto rb = std::span (b).last (k);
272+ cyclic_mul (la, lb, k);
273+ auto A = dft<base>(ra, k / 2 );
274+ auto B = dft<base>(rb, k / 2 );
275+ A.mul_inplace (B, ra, k);
276+ base i2 = base (2 ).inv ();
277+ factor = factor.inv () * i2;
278+ for (size_t i = 0 ; i < k; i++) {
279+ auto t = (a[i] + a[i + k]) * i2;
280+ a[i + k] = (a[i] - a[i + k]) * factor;
281+ a[i] = t;
282+ }
283+ cp_algo::checkpoint (" mod join" );
284+ }
285+ void cyclic_mul (auto &a, auto const & b, size_t k) {
286+ return cyclic_mul (a, make_copy (b), k);
287+ }
288+ void mul (auto &a, auto &&b) {
220289 size_t N = size (a) + size (b) - 1 ;
221- if (std::max (size (a), size (b)) > (1 << 23 )) {
222- using T = std::decay_t <decltype (a[0 ])>;
223- // do karatsuba to save memory
224- auto n = (std::max (size (a), size (b)) + 1 ) / 2 ;
225- auto a0 = to<std::vector<T, big_alloc<T>>>(a | std::views::take (n));
226- auto a1 = to<std::vector<T, big_alloc<T>>>(a | std::views::drop (n));
227- auto b0 = to<std::vector<T, big_alloc<T>>>(b | std::views::take (n));
228- auto b1 = to<std::vector<T, big_alloc<T>>>(b | std::views::drop (n));
229- a0.resize (n); a1.resize (n);
230- b0.resize (n); b1.resize (n);
231- auto a01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform (std::plus{}, a0, a1));
232- auto b01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform (std::plus{}, b0, b1));
233- checkpoint (" karatsuba split" );
234- mul (a0, b0);
235- mul (a1, b1);
236- mul (a01, b01);
237- a.assign (4 * n, 0 );
238- for (auto [i, ai]: a0 | std::views::enumerate) {
239- a[i] += ai;
240- a[i + n] -= ai;
241- }
242- for (auto [i, ai]: a1 | std::views::enumerate) {
243- a[i + n] -= ai;
244- a[i + 2 * n] += ai;
245- }
246- for (auto [i, ai]: a01 | std::views::enumerate) {
247- a[i + n] += ai;
248- }
290+ if (N > (1 << 19 )) {
291+ size_t NN = std::bit_ceil (N);
292+ a.resize (NN);
293+ b.resize (NN);
294+ cyclic_mul (a, b, NN);
249295 a.resize (N);
250- checkpoint (" karatsuba join" );
251- } else if (size (a)) {
296+ } else {
252297 mul_truncate (a, b, N);
253298 }
254299 }
300+ void mul (auto &a, auto const & b) {
301+ mul (a, make_copy (b));
302+ }
255303}
256304#endif // CP_ALGO_MATH_FFT_HPP
0 commit comments