|
| 1 | +#ifndef CP_ALGO_MATH_FFT64_HPP |
| 2 | +#define CP_ALGO_MATH_FFT64_HPP |
| 3 | +#include "../random/rng.hpp" |
| 4 | +#include "../math/common.hpp" |
| 5 | +#include "../math/cvector.hpp" |
| 6 | + |
| 7 | +namespace cp_algo::math::fft { |
| 8 | + struct dft64 { |
| 9 | + std::vector<cp_algo::math::fft::cvector> cv; |
| 10 | + |
| 11 | + static uint64_t factor, ifactor; |
| 12 | + static bool _init; |
| 13 | + |
| 14 | + static void init() { |
| 15 | + if(_init) return; |
| 16 | + _init = true; |
| 17 | + factor = random::rng(); |
| 18 | + if(factor % 2 == 0) {factor++;} |
| 19 | + ifactor = inv2(factor); |
| 20 | + } |
| 21 | + |
| 22 | + dft64(auto const& a, size_t n): cv(4, n) { |
| 23 | + init(); |
| 24 | + uint64_t cur = 1, step = bpow(factor, n); |
| 25 | + for(size_t i = 0; i < std::min(std::size(a), n); i++) { |
| 26 | + auto split = [&](size_t i, uint64_t mul) -> std::array<int16_t, 4> { |
| 27 | + uint64_t x = i < std::size(a) ? a[i] * mul : 0; |
| 28 | + std::array<int16_t, 4> res; |
| 29 | + for(int z = 0; z < 4; z++) { |
| 30 | + res[z] = int16_t(x); |
| 31 | + x = (x >> 16) + (res[z] < 0); |
| 32 | + } |
| 33 | + return res; |
| 34 | + }; |
| 35 | + auto re = split(i, cur); |
| 36 | + auto im = split(n + i, cur * step); |
| 37 | + for(int z = 0; z < 4; z++) { |
| 38 | + real(cv[z].at(i))[i % 4] = re[z]; |
| 39 | + imag(cv[z].at(i))[i % 4] = im[z]; |
| 40 | + } |
| 41 | + cur *= factor; |
| 42 | + } |
| 43 | + checkpoint("dft64 init"); |
| 44 | + for(auto &x: cv) { |
| 45 | + x.fft(); |
| 46 | + } |
| 47 | + } |
| 48 | + |
| 49 | + void dot(dft64 const& t) { |
| 50 | + size_t N = cv[0].size(); |
| 51 | + cvector::exec_on_evals<1>(N / flen, [&](size_t k, point rt) { |
| 52 | + k *= flen; |
| 53 | + auto [A0x, A0y] = cv[0].at(k); |
| 54 | + auto [A1x, A1y] = cv[1].at(k); |
| 55 | + auto [A2x, A2y] = cv[2].at(k); |
| 56 | + auto [A3x, A3y] = cv[3].at(k); |
| 57 | + std::array B = { |
| 58 | + t.cv[0].at(k), |
| 59 | + t.cv[1].at(k), |
| 60 | + t.cv[2].at(k), |
| 61 | + t.cv[3].at(k) |
| 62 | + }; |
| 63 | + |
| 64 | + std::array<vpoint, 4> C = {vz, vz, vz, vz}; |
| 65 | + for (size_t i = 0; i < flen; i++) { |
| 66 | + std::array A = { |
| 67 | + vpoint{vz + A0x[i], vz + A0y[i]}, |
| 68 | + vpoint{vz + A1x[i], vz + A1y[i]}, |
| 69 | + vpoint{vz + A2x[i], vz + A2y[i]}, |
| 70 | + vpoint{vz + A3x[i], vz + A3y[i]} |
| 71 | + }; |
| 72 | + for(size_t k = 0; k < 4; k++) { |
| 73 | + for(size_t i = 0; i <= k; i++) { |
| 74 | + C[k] += A[i] * B[k - i]; |
| 75 | + } |
| 76 | + } |
| 77 | + for(size_t k = 0; k < 4; k++) { |
| 78 | + real(B[k]) = rotate_right(real(B[k])); |
| 79 | + imag(B[k]) = rotate_right(imag(B[k])); |
| 80 | + auto bx = real(B[k])[0], by = imag(B[k])[0]; |
| 81 | + real(B[k])[0] = bx * real(rt) - by * imag(rt); |
| 82 | + imag(B[k])[0] = bx * imag(rt) + by * real(rt); |
| 83 | + } |
| 84 | + } |
| 85 | + cv[0].at(k) = C[0]; |
| 86 | + cv[1].at(k) = C[1]; |
| 87 | + cv[2].at(k) = C[2]; |
| 88 | + cv[3].at(k) = C[3]; |
| 89 | + }); |
| 90 | + checkpoint("dot"); |
| 91 | + for(auto &x: cv) { |
| 92 | + x.ifft(); |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + void recover_mod(auto &res, size_t k) { |
| 97 | + size_t n = cv[0].size(); |
| 98 | + uint64_t cur = 1, step = bpow(ifactor, n); |
| 99 | + for(size_t i = 0; i < std::min(k, n); i++) { |
| 100 | + std::array re = {real(cv[0].get(i)), real(cv[1].get(i)), real(cv[2].get(i)), real(cv[3].get(i))}; |
| 101 | + std::array im = {imag(cv[0].get(i)), imag(cv[1].get(i)), imag(cv[2].get(i)), imag(cv[3].get(i))}; |
| 102 | + auto set_i = [&](size_t i, auto &x, auto mul) { |
| 103 | + if (i >= k) return; |
| 104 | + res[i] = llround(x[0]) + (llround(x[1]) << 16) + (llround(x[2]) << 32) + (llround(x[3]) << 48); |
| 105 | + res[i] *= mul; |
| 106 | + }; |
| 107 | + set_i(i, re, cur); |
| 108 | + set_i(n + i, im, cur * step); |
| 109 | + cur *= ifactor; |
| 110 | + } |
| 111 | + cp_algo::checkpoint("recover mod"); |
| 112 | + } |
| 113 | + }; |
| 114 | + uint64_t dft64::factor = 1, dft64::ifactor = 1; |
| 115 | + bool dft64::_init = false; |
| 116 | + |
| 117 | + void conv64(auto& a, auto const& b) { |
| 118 | + size_t n = a.size(), m = b.size(); |
| 119 | + size_t N = std::max(flen, std::bit_ceil(n + m - 1) / 2); |
| 120 | + dft64 A(a, N), B(b, N); |
| 121 | + A.dot(B); |
| 122 | + a.resize(n + m - 1); |
| 123 | + A.recover_mod(a, n + m - 1); |
| 124 | + } |
| 125 | +} |
| 126 | +#endif // CP_ALGO_MATH_FFT64_HPP |
0 commit comments