diff --git a/base/base/sort.h b/base/base/sort.h index 592a899a291..78533e1d455 100644 --- a/base/base/sort.h +++ b/base/base/sort.h @@ -2,27 +2,104 @@ #include +#ifndef NDEBUG +#include +#include +/** Same as libcxx std::__debug_less. Just without dependency on private part of standard library. + * Check that Comparator induce strict weak ordering. + */ +template +class DebugLessComparator +{ +public: + constexpr DebugLessComparator(Comparator & cmp_) + : cmp(cmp_) + {} + template + constexpr bool operator()(const LhsType & lhs, const RhsType & rhs) + { + bool lhs_less_than_rhs = cmp(lhs, rhs); + if (lhs_less_than_rhs) + assert(!cmp(rhs, lhs)); + return lhs_less_than_rhs; + } + template + constexpr bool operator()(LhsType & lhs, RhsType & rhs) + { + bool lhs_less_than_rhs = cmp(lhs, rhs); + if (lhs_less_than_rhs) + assert(!cmp(rhs, lhs)); + return lhs_less_than_rhs; + } +private: + Comparator & cmp; +}; +template +using ComparatorWrapper = DebugLessComparator; +template +void shuffle(RandomIt first, RandomIt last) +{ + static thread_local pcg64 rng(getThreadId()); + std::shuffle(first, last, rng); +} +#else +template +using ComparatorWrapper = Comparator; +#endif #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wold-style-cast" #include -template -void nth_element(RandomIt first, RandomIt nth, RandomIt last) +template +void nth_element(RandomIt first, RandomIt nth, RandomIt last, Compare compare) { - ::miniselect::floyd_rivest_select(first, nth, last); +#ifndef NDEBUG + ::shuffle(first, last); +#endif + + ComparatorWrapper compare_wrapper = compare; + ::miniselect::floyd_rivest_select(first, nth, last, compare_wrapper); + +#ifndef NDEBUG + ::shuffle(first, nth); + + if (nth != last) + ::shuffle(nth + 1, last); +#endif } template -void partial_sort(RandomIt first, RandomIt middle, RandomIt last) +void nth_element(RandomIt first, RandomIt nth, RandomIt last) { - ::miniselect::floyd_rivest_partial_sort(first, middle, last); + using value_type = typename std::iterator_traits::value_type; + using comparator = std::less; + + ::nth_element(first, nth, last, comparator()); } template void partial_sort(RandomIt first, RandomIt middle, RandomIt last, Compare compare) { - ::miniselect::floyd_rivest_partial_sort(first, middle, last, compare); +#ifndef NDEBUG + ::shuffle(first, last); +#endif + + ComparatorWrapper compare_wrapper = compare; + ::miniselect::floyd_rivest_partial_sort(first, middle, last, compare_wrapper); + +#ifndef NDEBUG + ::shuffle(middle, last); +#endif +} + +template +void partial_sort(RandomIt first, RandomIt middle, RandomIt last) +{ + using value_type = typename std::iterator_traits::value_type; + using comparator = std::less; + + ::partial_sort(first, middle, last, comparator()); } #pragma GCC diagnostic pop diff --git a/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md b/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md new file mode 100644 index 00000000000..9bee0c29e7a --- /dev/null +++ b/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md @@ -0,0 +1,45 @@ + --- + toc_priority: 112 + --- + + # groupArraySorted {#groupArraySorted} + + Returns an array with the first N items in ascending order. + + ``` sql + groupArraySorted(N)(column) + ``` + + **Arguments** + + - `N` – The number of elements to return. + + - `column` – The value (Integer, String, Float and other Generic types). + + **Example** + + Gets the first 10 numbers: + + ``` sql + SELECT groupArraySorted(10)(number) FROM numbers(100) + ``` + + ``` text + ┌─groupArraySorted(10)(number)─┐ + │ [0,1,2,3,4,5,6,7,8,9] │ + └──────────────────────────────┘ + ``` + + + Gets all the String implementations of all numbers in column: + + ``` sql +SELECT groupArraySorted(5)(str) FROM (SELECT toString(number) as str FROM numbers(5)); + + ``` + + ``` text +┌─groupArraySorted(5)(str)─┐ +│ ['0','1','2','3','4'] │ +└──────────────────────────┘ + ``` \ No newline at end of file diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index 4b2c059709a..c5b6bf62c93 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -48,7 +48,7 @@ inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataType // return std::make_shared>(argument_type, std::forward(args)...); } - +template AggregateFunctionPtr createAggregateFunctionGroupArray( const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) { @@ -79,9 +79,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArray( ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!limit_size) - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters); else - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems); } AggregateFunctionPtr createAggregateFunctionGroupArraySample( @@ -114,7 +114,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample( else seed = thread_local_rng(); - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems, seed); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems, seed); } } @@ -124,8 +124,9 @@ void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory) { AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true }; - factory.registerFunction("group_array", { createAggregateFunctionGroupArray, properties }); + factory.registerFunction("group_array", { createAggregateFunctionGroupArray, properties }); factory.registerFunction("group_array_sample", { createAggregateFunctionGroupArraySample, properties }); + factory.registerFunction("group_array_last", { createAggregateFunctionGroupArray, properties }); } } diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.h b/src/AggregateFunctions/AggregateFunctionGroupArray.h index c844a6d6cdb..987b6e6356b 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -40,16 +40,19 @@ enum class Sampler DETERMINATOR // TODO }; -template +template struct GroupArrayTrait { static constexpr bool has_limit = Thas_limit; + static constexpr bool last = Tlast; static constexpr Sampler sampler = Tsampler; }; template static constexpr const char * getNameByTrait() { + if (Trait::last) + return "group_array_last"; if (Trait::sampler == Sampler::NONE) return "group_array"; else if (Trait::sampler == Sampler::RNG) @@ -100,6 +103,8 @@ struct GroupArrayNumericData using Allocator = MixedAlignedArenaAllocator; using Array = PODArray; + // For group_array_last() + size_t total_values = 0; Array value; }; @@ -131,11 +136,13 @@ class GroupArrayNumericImpl final DataTypePtr getReturnType() const override { return std::make_shared(this->argument_types[0]); } - void insert(Data & a, const T & v, Arena * arena) const + void insertWithSampler(Data & a, const T & v, Arena * arena) const { ++a.total_values; if (a.value.size() < max_elems) + { a.value.push_back(v, arena); + } else { UInt64 rnd = a.genRandom(a.total_values); @@ -153,25 +160,33 @@ class GroupArrayNumericImpl final void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { + const auto & row_value = assert_cast &>(*columns[0]).getData()[row_num]; + auto & cur_elems = this->data(place); + + ++cur_elems.total_values; + if constexpr (Trait::sampler == Sampler::NONE) { - if (limit_num_elems && this->data(place).value.size() >= max_elems) + if (limit_num_elems && cur_elems.value.size() >= max_elems) + { + if constexpr (Trait::last) + cur_elems.value[(cur_elems.total_values - 1) % max_elems] = row_value; + return; + } - this->data(place).value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + cur_elems.value.push_back(row_value, arena); } if constexpr (Trait::sampler == Sampler::RNG) { - auto & a = this->data(place); - ++a.total_values; - if (a.value.size() < max_elems) - a.value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + if (cur_elems.value.size() < max_elems) + cur_elems.value.push_back(row_value, arena); else { - UInt64 rnd = a.genRandom(a.total_values); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); if (rnd < max_elems) - a.value[rnd] = assert_cast &>(*columns[0]).getData()[row_num]; + cur_elems.value[rnd] = row_value; } } // TODO @@ -180,59 +195,75 @@ class GroupArrayNumericImpl final void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if constexpr (Trait::sampler == Sampler::NONE) - { - auto & cur_elems = this->data(place); - auto & rhs_elems = this->data(rhs); + auto & cur_elems = this->data(place); + auto & rhs_elems = this->data(rhs); - if (!limit_num_elems) - { - if (rhs_elems.value.size()) - cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena); - } - else - { - UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); - if (elems_to_insert) - cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena); - } - } + if (rhs_elems.value.empty()) + return; - if constexpr (Trait::sampler == Sampler::RNG) + if constexpr (Trait::last) + mergeNoSamplerLast(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::NONE) + mergeNoSampler(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::RNG) + mergeWithRNGSampler(cur_elems, rhs_elems, arena); + } + + void mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + UInt64 new_elements = std::min(static_cast(max_elems), cur_elems.value.size() + rhs_elems.value.size()); + cur_elems.value.resize(new_elements, arena); + for (auto & value : rhs_elems.value) { - if (this->data(rhs).value.empty()) /// rhs state is empty - return; + cur_elems.value[cur_elems.total_values % max_elems] = value; + ++cur_elems.total_values; + } + assert(rhs_elems.total_values >= rhs_elems.value.size()); + cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size(); + } - auto & a = this->data(place); - auto & b = this->data(rhs); + void mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + if (!limit_num_elems) + { + if (rhs_elems.value.size()) + cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena); + } + else + { + UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); + if (elems_to_insert) + cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena); + } + } - if (b.total_values <= max_elems) - { - for (size_t i = 0; i < b.value.size(); ++i) - insert(a, b.value[i], arena); - } - else if (a.total_values <= max_elems) - { - decltype(a.value) from; - from.swap(a.value, arena); - a.value.assign(b.value.begin(), b.value.end(), arena); - a.total_values = b.total_values; - for (size_t i = 0; i < from.size(); ++i) - insert(a, from[i], arena); - } - else + void mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + if (rhs_elems.total_values <= max_elems) + { + for (size_t i = 0; i < rhs_elems.value.size(); ++i) + insertWithSampler(cur_elems, rhs_elems.value[i], arena); + } + else if (cur_elems.total_values <= max_elems) + { + decltype(cur_elems.value) from; + from.swap(cur_elems.value, arena); + cur_elems.value.assign(rhs_elems.value.begin(), rhs_elems.value.end(), arena); + cur_elems.total_values = rhs_elems.total_values; + for (size_t i = 0; i < from.size(); ++i) + insertWithSampler(cur_elems, from[i], arena); + } + else + { + cur_elems.randomShuffle(); + cur_elems.total_values += rhs_elems.total_values; + for (size_t i = 0; i < max_elems; ++i) { - a.randomShuffle(); - a.total_values += b.total_values; - for (size_t i = 0; i < max_elems; ++i) - { - UInt64 rnd = a.genRandom(a.total_values); - if (rnd < b.total_values) - a.value[i] = b.value[i]; - } + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); + if (rnd < rhs_elems.total_values) + cur_elems.value[i] = rhs_elems.value[i]; } } - // TODO // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } @@ -244,6 +275,9 @@ class GroupArrayNumericImpl final writeVarUInt(size, buf); buf.write(reinterpret_cast(value.data()), size * sizeof(value[0])); + if constexpr (Trait::last) + DB::writeIntBinary(this->data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::writeIntBinary(this->data(place).total_values, buf); @@ -272,6 +306,9 @@ class GroupArrayNumericImpl final value.resize(size, arena); buf.readStrict(reinterpret_cast(value.data()), size * sizeof(value[0])); + if constexpr (Trait::last) + DB::readIntBinary(this->data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::readIntBinary(this->data(place).total_values, buf); @@ -398,6 +435,8 @@ struct GroupArrayGeneralData using Allocator = MixedAlignedArenaAllocator; using Array = PODArray; + // For group_array_last() + size_t total_values = 0; Array value; }; @@ -434,7 +473,7 @@ class GroupArrayGeneralImpl final DataTypePtr getReturnType() const override { return std::make_shared(data_type); } - void insert(Data & a, const Node * v, Arena * arena) const + void insertWithSampler(Data & a, const Node * v, Arena * arena) const { ++a.total_values; if (a.value.size() < max_elems) @@ -454,28 +493,49 @@ class GroupArrayGeneralImpl final a->rng.seed(seed); } + void destroy(AggregateDataPtr __restrict place) const noexcept override + { + data(place).~Data(); + } + + bool hasTrivialDestructor() const override + { + return std::is_trivially_destructible_v; + } + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { + auto & cur_elems = data(place); + + ++cur_elems.total_values; + if constexpr (Trait::sampler == Sampler::NONE) { - if (limit_num_elems && data(place).value.size() >= max_elems) + if (limit_num_elems && cur_elems.value.size() >= max_elems) + { + if (Trait::last) + { + Node * node = Node::allocate(*columns[0], row_num, arena); + cur_elems.value[(cur_elems.total_values - 1) % max_elems] = node; + } return; + } Node * node = Node::allocate(*columns[0], row_num, arena); - data(place).value.push_back(node, arena); + cur_elems.value.push_back(node, arena); } if constexpr (Trait::sampler == Sampler::RNG) { - auto & a = data(place); - ++a.total_values; - if (a.value.size() < max_elems) - a.value.push_back(Node::allocate(*columns[0], row_num, arena), arena); + if (cur_elems.value.size() < max_elems) + { + cur_elems.value.push_back(Node::allocate(*columns[0], row_num, arena), arena); + } else { - UInt64 rnd = a.genRandom(a.total_values); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); if (rnd < max_elems) - a.value[rnd] = Node::allocate(*columns[0], row_num, arena); + cur_elems.value[rnd] = Node::allocate(*columns[0], row_num, arena); } } // TODO @@ -484,68 +544,83 @@ class GroupArrayGeneralImpl final void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if constexpr (Trait::sampler == Sampler::NONE) - mergeNoSampler(place, rhs, arena); + auto & cur_elems = data(place); + auto & rhs_elems = data(rhs); + + if (rhs_elems.value.empty()) + return; + + if constexpr (Trait::last) + mergeNoSamplerLast(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::NONE) + mergeNoSampler(cur_elems, rhs_elems, arena); else if constexpr (Trait::sampler == Sampler::RNG) - mergeWithRNGSampler(place, rhs, arena); + mergeWithRNGSampler(cur_elems, rhs_elems, arena); // TODO // else if constexpr (Trait::sampler == Sampler::DETERMINATOR) } - void ALWAYS_INLINE mergeNoSampler(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const + void ALWAYS_INLINE mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const { - if (data(rhs).value.empty()) /// rhs state is empty - return; + UInt64 new_elements = std::min(static_cast(max_elems), cur_elems.value.size() + rhs_elems.value.size()); + cur_elems.value.resize(new_elements, arena); + + for (auto & value : rhs_elems.value) + { + cur_elems.value[cur_elems.total_values % max_elems] = value->clone(arena); + ++cur_elems.total_values; + } + + + assert(rhs_elems.total_values >= rhs_elems.value.size()); + cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size(); + } + void ALWAYS_INLINE mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { UInt64 new_elems; if (limit_num_elems) { - if (data(place).value.size() >= max_elems) + if (cur_elems.value.size() >= max_elems) return; - - new_elems = std::min(data(rhs).value.size(), static_cast(max_elems) - data(place).value.size()); + new_elems = std::min(rhs_elems.value.size(), static_cast(max_elems) - cur_elems.value.size()); } else - new_elems = data(rhs).value.size(); + new_elems = rhs_elems.value.size(); - auto & a = data(place).value; - auto & b = data(rhs).value; for (UInt64 i = 0; i < new_elems; ++i) - a.push_back(b[i]->clone(arena), arena); + cur_elems.value.push_back(rhs_elems.value[i]->clone(arena), arena); + } - void ALWAYS_INLINE mergeWithRNGSampler(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const + void ALWAYS_INLINE mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const { - if (data(rhs).value.empty()) /// rhs state is empty - return; - - auto & a = data(place); - auto & b = data(rhs); - - if (b.total_values <= max_elems) + if (rhs_elems.total_values <= max_elems) { - for (size_t i = 0; i < b.value.size(); ++i) - insert(a, b.value[i], arena); + for (size_t i = 0; i < rhs_elems.value.size(); ++i) + insertWithSampler(cur_elems, rhs_elems.value[i], arena); } - else if (a.total_values <= max_elems) + else if (cur_elems.total_values <= max_elems) { - decltype(a.value) from; - from.swap(a.value, arena); - for (auto & node : b.value) - a.value.push_back(node->clone(arena), arena); - a.total_values = b.total_values; + decltype(cur_elems.value) from; + from.swap(cur_elems.value, arena); + for (auto & node : rhs_elems.value) + cur_elems.value.push_back(node->clone(arena), arena); + + cur_elems.total_values = rhs_elems.total_values; for (size_t i = 0; i < from.size(); ++i) - insert(a, from[i], arena); + insertWithSampler(cur_elems, from[i], arena); } else { - a.randomShuffle(); - a.total_values += b.total_values; + cur_elems.randomShuffle(); + cur_elems.total_values += rhs_elems.total_values; for (size_t i = 0; i < max_elems; ++i) { - UInt64 rnd = a.genRandom(a.total_values); - if (rnd < b.total_values) - a.value[i] = b.value[i]->clone(arena); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); + if (rnd < rhs_elems.total_values) + cur_elems.value[i] = rhs_elems.value[i]->clone(arena); + } } } @@ -558,6 +633,9 @@ class GroupArrayGeneralImpl final for (auto & node : value) node->write(buf); + if constexpr (Trait::last) + DB::writeIntBinary(data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::writeIntBinary(data(place).total_values, buf); @@ -590,6 +668,9 @@ class GroupArrayGeneralImpl final for (UInt64 i = 0; i < elems; ++i) value[i] = Node::read(buf, arena); + if constexpr (Trait::last) + DB::readIntBinary(data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::readIntBinary(data(place).total_values, buf); diff --git a/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp b/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp new file mode 100644 index 00000000000..df99d026cb2 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp @@ -0,0 +1,431 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ + +struct Settings; + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; + extern const int TOO_LARGE_ARRAY_SIZE; +} + +namespace +{ + +enum class GroupArraySortedStrategy : uint8_t +{ + heap, + sort +}; + + +constexpr size_t group_array_sorted_sort_strategy_max_elements_threshold = 1000000; + +template +struct GroupArraySortedData +{ + static constexpr bool is_value_generic_field = std::is_same_v; + + using Allocator = MixedAlignedArenaAllocator; + using Array = typename std::conditional_t, PODArray>; + + static constexpr size_t partial_sort_max_elements_factor = 2; + + Array values; + + static bool compare(const T & lhs, const T & rhs) + { + if constexpr (is_value_generic_field) + { + return lhs < rhs; + } + else + { + return CompareHelper::less(lhs, rhs, -1); + } + } + + struct Comparator + { + bool operator()(const T & lhs, const T & rhs) + { + return compare(lhs, rhs); + } + }; + + ALWAYS_INLINE void heapReplaceTop() + { + size_t size = values.size(); + if (size < 2) + return; + + size_t child_index = 1; + + if (values.size() > 2 && compare(values[1], values[2])) + ++child_index; + + /// Check if we are in order + if (compare(values[child_index], values[0])) + return; + + size_t current_index = 0; + auto current = values[current_index]; + + do + { + /// We are not in heap-order, swap the parent with it's largest child. + values[current_index] = values[child_index]; + current_index = child_index; + + // Recompute the child based off of the updated parent + child_index = 2 * child_index + 1; + + if (child_index >= size) + break; + + if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1])) + { + /// Right child exists and is greater than left child. + ++child_index; + } + + /// Check if we are in order. + } while (!compare(values[child_index], current)); + + values[current_index] = current; + } + + ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena) + { + if constexpr (is_value_generic_field) + { + ::sort(values.begin(), values.end(), Comparator()); + } + else + { + bool try_sort = trySort(values.begin(), values.end(), Comparator()); + if (!try_sort) + RadixSort>::executeLSD(values.data(), values.size()); + } + + if (values.size() > max_elements) + resize(max_elements, arena); + } + + ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena) + { + if (values.size() < max_elements * partial_sort_max_elements_factor) + return; + + ::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator()); + resize(max_elements, arena); + } + + ALWAYS_INLINE void resize(size_t n, Arena * arena) + { + if constexpr (is_value_generic_field) + values.resize(n); + else + values.resize(n, arena); + } + + ALWAYS_INLINE void push_back(T && element, Arena * arena) + { + if constexpr (is_value_generic_field) + values.push_back(element); + else + values.push_back(element, arena); + } + + ALWAYS_INLINE void addElement(T && element, size_t max_elements, Arena * arena) + { + if constexpr (strategy == GroupArraySortedStrategy::heap) + { + if (values.size() >= max_elements) + { + /// Element is greater or equal than current max element, it cannot be in k min elements + if (!compare(element, values[0])) + return; + + values[0] = std::move(element); + heapReplaceTop(); + return; + } + + push_back(std::move(element), arena); + std::push_heap(values.begin(), values.end(), Comparator()); + } + else + { + push_back(std::move(element), arena); + partialSortAndLimitIfNeeded(max_elements, arena); + } + } + + ALWAYS_INLINE void insertResultInto(IColumn & to, size_t max_elements, Arena * arena) + { + auto & result_array = assert_cast(to); + auto & result_array_offsets = result_array.getOffsets(); + + sortAndLimit(max_elements, arena); + + result_array_offsets.push_back(result_array_offsets.back() + values.size()); + + if (values.empty()) + return; + + if constexpr (is_value_generic_field) + { + auto & result_array_data = result_array.getData(); + for (auto & value : values) + result_array_data.insert(value); + } + else + { + auto & result_array_data = assert_cast &>(result_array.getData()).getData(); + + size_t result_array_data_insert_begin = result_array_data.size(); + result_array_data.resize(result_array_data_insert_begin + values.size()); + + for (size_t i = 0; i < values.size(); ++i) + result_array_data[result_array_data_insert_begin + i] = values[i]; + } + } +}; + +template +using GroupArraySortedDataHeap = GroupArraySortedData; + +template +using GroupArraySortedDataSort = GroupArraySortedData; + +constexpr UInt64 aggregate_function_group_array_sorted_max_element_size = 0xFFFFFF; + +template +class GroupArraySorted final + : public IAggregateFunctionDataHelper> +{ +public: + explicit GroupArraySorted( + const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elements_) + : IAggregateFunctionDataHelper>( + {data_type_}, parameters_, std::make_shared(data_type_)) + , max_elements(max_elements_) + , serialization(data_type_->getDefaultSerialization()) + { + if (max_elements > aggregate_function_group_array_sorted_max_element_size) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Too large limit parameter for groupArraySorted aggregate function, it should not exceed {}", + aggregate_function_group_array_sorted_max_element_size); + } + + String getName() const override { return "group_array_sorted"; } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + if constexpr (std::is_same_v) + { + auto row_value = (*columns[0])[row_num]; + this->data(place).addElement(std::move(row_value), max_elements, arena); + } + else + { + auto row_value = assert_cast &>(*columns[0]).getData()[row_num]; + this->data(place).addElement(std::move(row_value), max_elements, arena); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + auto & rhs_values = this->data(rhs).values; + for (auto rhs_element : rhs_values) + this->data(place).addElement(std::move(rhs_element), max_elements, arena); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + { + auto & values = this->data(place).values; + size_t size = values.size(); + writeVarUInt(size, buf); + + if constexpr (std::is_same_v) + { + for (const Field & element : values) + { + if (element.isNull()) + { + writeBinary(false, buf); + } + else + { + writeBinary(true, buf); + serialization->serializeBinary(element, buf, {}); + } + } + } + else + { + if constexpr (std::endian::native == std::endian::little) + { + buf.write(reinterpret_cast(values.data()), size * sizeof(values[0])); + } + else + { + for (const auto & element : values) + writeBinary(element, buf); + } + } + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena * arena) const override + { + size_t size = 0; + readVarUInt(size, buf); + + if (unlikely(size > max_elements)) + throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elements); + + auto & values = this->data(place).values; + + if constexpr (Data::is_value_generic_field) + { + values.resize(size); + for (Field & element : values) + { + bool has_value = false; + readBinary(has_value, buf); + if (has_value) + serialization->deserializeBinary(element, buf, {}); + } + } + else + { + values.resize(size, arena); + if constexpr (std::endian::native == std::endian::little) + { + buf.readStrict(reinterpret_cast(values.data()), size * sizeof(values[0])); + } + else + { + for (auto & element : values) + readBinary(element, buf); + } + } + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + this->data(place).insertResultInto(to, max_elements, arena); + } + + bool allocatesMemoryInArena() const override { return true; } + +private: + UInt64 max_elements; + SerializationPtr serialization; +}; + +template +using GroupArraySortedHeap = GroupArraySorted, T>; + +template +using GroupArraySortedSort = GroupArraySorted, T>; + +template