|
1 | 1 | #ifndef CP_ALGO_MATH_CVECTOR_HPP |
2 | 2 | #define CP_ALGO_MATH_CVECTOR_HPP |
| 3 | +#include "../util/simd.hpp" |
3 | 4 | #include "../util/complex.hpp" |
4 | 5 | #include "../util/checkpoint.hpp" |
5 | 6 | #include "../util/big_alloc.hpp" |
6 | | -#include <experimental/simd> |
7 | 7 | #include <ranges> |
8 | 8 |
|
9 | 9 | namespace stdx = std::experimental; |
10 | 10 | namespace cp_algo::math::fft { |
11 | | - using ftype = double; |
12 | 11 | static constexpr size_t flen = 4; |
13 | | - static constexpr size_t bytes = flen * sizeof(ftype); |
| 12 | + using ftype = double; |
| 13 | + using vftype = simd<ftype, flen>; |
14 | 14 | using point = complex<ftype>; |
15 | | - using vftype [[gnu::vector_size(bytes)]] = ftype; |
16 | 15 | using vpoint = complex<vftype>; |
17 | 16 | static constexpr vftype vz = {}; |
18 | 17 | vpoint vi(vpoint const& r) { |
19 | 18 | return {-imag(r), real(r)}; |
20 | 19 | } |
21 | | - vftype abs(vftype a) { |
22 | | - return a < 0 ? -a : a; |
23 | | - } |
24 | | - using i64x4 [[gnu::vector_size(bytes)]] = int64_t; |
25 | | - using u64x4 [[gnu::vector_size(bytes)]] = uint64_t; |
26 | | - auto lround(vftype a) { |
27 | | - return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4); |
28 | | - } |
29 | | - auto round(vftype a) { |
30 | | - return __builtin_convertvector(lround(a), vftype); |
31 | | - } |
32 | | - u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) { |
33 | | - auto x_ninv = _mm256_mul_epu32(__m256i(x), __m256i(imod)); |
34 | | - auto x_res = _mm256_add_epi64(__m256i(x), _mm256_mul_epu32(x_ninv, __m256i(mod))); |
35 | | - return u64x4(_mm256_bsrli_epi128(x_res, 4)); |
36 | | - } |
37 | | - u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) { |
38 | | - return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod); |
39 | | - } |
40 | 20 |
|
41 | 21 | struct cvector { |
42 | 22 | std::vector<vpoint, big_alloc<vpoint>> r; |
@@ -99,8 +79,7 @@ namespace cp_algo::math::fft { |
99 | 79 | } |
100 | 80 | template<int step> |
101 | 81 | static void exec_on_eval(size_t n, size_t k, auto &&callback) { |
102 | | - point factor = root(4 * step * n); |
103 | | - callback(factor * eval_point(step * k)); |
| 82 | + callback(root(4 * step * n) * eval_point(step * k)); |
104 | 83 | } |
105 | 84 |
|
106 | 85 | void dot(cvector const& t) { |
|
0 commit comments