|
13 | 13 | #endif |
14 | 14 |
|
namespace cp_algo {
    // Allocator that over-aligns every allocation to at least `Align` bytes
    // and, when CP_ALGO_USE_MMAP is enabled, serves requests of one megabyte
    // or more straight from anonymous mmap with huge-page / pre-population
    // hints. Everything else goes through aligned ::operator new.
    //
    // T     - element type.
    // Align - minimum alignment in bytes; a power of two that is at least
    //         alignof(void*). The effective alignment is
    //         max(alignof(T), Align).
    template <typename T, std::size_t Align = 32>
    class big_alloc {
        static_assert( Align >= alignof(void*), "Align must be at least pointer-size");
        // Align is non-zero here (checked above), so the usual bit trick is
        // an exact single-bit test and needs no <bit> support.
        static_assert((Align & (Align - 1)) == 0, "Align must be a power of two");
    public:
        using value_type = T;
        // Explicit rebind: std::allocator_traits cannot synthesize one for a
        // template that has a non-type parameter.
        template <class U> struct rebind { using other = big_alloc<U, Align>; };

        big_alloc() noexcept = default;
        // Converting constructor; the allocator holds no state to copy.
        template <typename U, std::size_t A>
        big_alloc(const big_alloc<U, A>&) noexcept {}

        // Allocates uninitialized storage for n objects of type T.
        // Returns a pointer aligned to max(alignof(T), Align).
        // Throws std::bad_array_new_length if the byte count would overflow
        // std::size_t, and std::bad_alloc if the memory cannot be obtained.
        [[nodiscard]] T* allocate(std::size_t n) {
            if (n > max_elements()) {
                throw std::bad_array_new_length();
            }
            const std::size_t padded = round_up(n * sizeof(T));
#if CP_ALGO_USE_MMAP
            if (use_mmap(padded)) {
                void* raw = mmap(nullptr, padded,
                                 PROT_READ | PROT_WRITE,
                                 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
                // mmap reports failure as MAP_FAILED, not nullptr; handing
                // that value to the caller would look like a valid pointer.
                if (raw == MAP_FAILED) {
                    throw std::bad_alloc();
                }
                // Both hints are best-effort; a failure leaves the mapping
                // usable, so the results are deliberately ignored.
                (void)madvise(raw, padded, MADV_HUGEPAGE);
                (void)madvise(raw, padded, MADV_POPULATE_WRITE);
                return static_cast<T*>(raw);
            }
#endif
            return static_cast<T*>(::operator new(padded, std::align_val_t(alignment())));
        }

        // Releases storage obtained from allocate(n) with the same n.
        // A null pointer is ignored.
        void deallocate(T* p, std::size_t n) noexcept {
            if (!p) return;
            // n was accepted by allocate(), so this cannot overflow.
            const std::size_t padded = round_up(n * sizeof(T));
#if CP_ALGO_USE_MMAP
            // Must mirror the decision taken in allocate().
            if (use_mmap(padded)) { munmap(p, padded); return; }
#endif
#if defined(__cpp_sized_deallocation)
            ::operator delete(p, padded, std::align_val_t(alignment()));
#else
            // The sized overload is only declared when sized deallocation is
            // enabled; the unsized aligned form is always a valid match.
            ::operator delete(p, std::align_val_t(alignment()));
#endif
        }

    private:
        static constexpr std::size_t MEGABYTE = 1 << 20;
        // Smallest page size assumed for the mmap path; mmap only guarantees
        // page alignment, so larger alignments must not be routed through it.
        static constexpr std::size_t MIN_PAGE = 4096;

        // Effective alignment. Computed inside a function so that T only has
        // to be complete when allocate/deallocate are instantiated.
        static constexpr std::size_t alignment() noexcept {
            return alignof(T) > Align ? alignof(T) : Align;
        }
        // Largest n for which round_up(n * sizeof(T)) does not wrap around.
        static constexpr std::size_t max_elements() noexcept {
            return (static_cast<std::size_t>(-1) - (Align - 1)) / sizeof(T);
        }
        // Rounds x up to the next multiple of Align.
        static constexpr std::size_t round_up(std::size_t x) noexcept {
            return (x + Align - 1) / Align * Align;
        }
        // True when a request of `padded` bytes is served by mmap.
        static constexpr bool use_mmap(std::size_t padded) noexcept {
            return padded >= MEGABYTE && alignment() <= MIN_PAGE;
        }
    };

    // Stateless allocators are interchangeable as long as they agree on
    // Align: the padded size and the alignment handed to the deallocation
    // function both depend on it.
    template <typename T, typename U, std::size_t A, std::size_t B>
    constexpr bool operator==(const big_alloc<T, A>&, const big_alloc<U, B>&) noexcept {
        return A == B;
    }
    template <typename T, typename U, std::size_t A, std::size_t B>
    constexpr bool operator!=(const big_alloc<T, A>&, const big_alloc<U, B>&) noexcept {
        return A != B;
    }
}
55 | 61 | #endif // CP_ALGO_UTIL_big_alloc_HPP |
0 commit comments