From 53b3a3b9bf303b1efaab9309c7804e0b66bbc13a Mon Sep 17 00:00:00 2001 From: Jasmine-ge Date: Thu, 20 Feb 2025 11:29:25 +0800 Subject: [PATCH 1/2] fix test and comment support group array last Introduce groupArrayLast() (useful to store last X values) (#44521) * Cleanup DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions() Signed-off-by: Azat Khuzhin * Remove unused GroupArrayGeneralListImpl Signed-off-by: Azat Khuzhin * Introduce groupArrayLast() (useful to store last X values) Also do some refactoring to make code cleaner: - rename insert() to insertWithSampler() (since it is used only for groupArraySample()) - split merge methods into Last/RNG/... Signed-off-by: Azat Khuzhin Signed-off-by: Azat Khuzhin Bugfix/issue 7122 group array missing arg (#7135) (#7198) --- .../AggregateFunctionGroupArray.cpp | 11 +- .../AggregateFunctionGroupArray.h | 279 +++++++++++------- .../DataTypeCustomSimpleAggregateFunction.cpp | 2 +- src/Interpreters/ExpressionAnalyzer.cpp | 76 ++++- .../02520_group_array_last.reference | 45 +++ .../0_stateless/02520_group_array_last.sql | 32 ++ ...rray_additional_arg_for_max_elem.reference | 10 + ...roup_array_additional_arg_for_max_elem.sql | 53 ++++ 8 files changed, 395 insertions(+), 113 deletions(-) create mode 100644 tests/queries_ported/0_stateless/02520_group_array_last.reference create mode 100644 tests/queries_ported/0_stateless/02520_group_array_last.sql create mode 100644 tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.reference create mode 100644 tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.sql diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index 4b2c059709a..c5b6bf62c93 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -48,7 +48,7 @@ inline AggregateFunctionPtr 
createAggregateFunctionGroupArrayImpl(const DataType // return std::make_shared>(argument_type, std::forward(args)...); } - +template AggregateFunctionPtr createAggregateFunctionGroupArray( const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) { @@ -79,9 +79,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArray( ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!limit_size) - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters); else - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems); } AggregateFunctionPtr createAggregateFunctionGroupArraySample( @@ -114,7 +114,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample( else seed = thread_local_rng(); - return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems, seed); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], parameters, max_elems, seed); } } @@ -124,8 +124,9 @@ void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory) { AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true }; - factory.registerFunction("group_array", { createAggregateFunctionGroupArray, properties }); + factory.registerFunction("group_array", { createAggregateFunctionGroupArray, properties }); factory.registerFunction("group_array_sample", { createAggregateFunctionGroupArraySample, properties }); + factory.registerFunction("group_array_last", { createAggregateFunctionGroupArray, properties }); } } diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.h b/src/AggregateFunctions/AggregateFunctionGroupArray.h index c844a6d6cdb..987b6e6356b 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ 
b/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -40,16 +40,19 @@ enum class Sampler DETERMINATOR // TODO }; -template +template struct GroupArrayTrait { static constexpr bool has_limit = Thas_limit; + static constexpr bool last = Tlast; static constexpr Sampler sampler = Tsampler; }; template static constexpr const char * getNameByTrait() { + if (Trait::last) + return "group_array_last"; if (Trait::sampler == Sampler::NONE) return "group_array"; else if (Trait::sampler == Sampler::RNG) @@ -100,6 +103,8 @@ struct GroupArrayNumericData using Allocator = MixedAlignedArenaAllocator; using Array = PODArray; + // For group_array_last() + size_t total_values = 0; Array value; }; @@ -131,11 +136,13 @@ class GroupArrayNumericImpl final DataTypePtr getReturnType() const override { return std::make_shared(this->argument_types[0]); } - void insert(Data & a, const T & v, Arena * arena) const + void insertWithSampler(Data & a, const T & v, Arena * arena) const { ++a.total_values; if (a.value.size() < max_elems) + { a.value.push_back(v, arena); + } else { UInt64 rnd = a.genRandom(a.total_values); @@ -153,25 +160,33 @@ class GroupArrayNumericImpl final void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { + const auto & row_value = assert_cast &>(*columns[0]).getData()[row_num]; + auto & cur_elems = this->data(place); + + ++cur_elems.total_values; + if constexpr (Trait::sampler == Sampler::NONE) { - if (limit_num_elems && this->data(place).value.size() >= max_elems) + if (limit_num_elems && cur_elems.value.size() >= max_elems) + { + if constexpr (Trait::last) + cur_elems.value[(cur_elems.total_values - 1) % max_elems] = row_value; + return; + } - this->data(place).value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + cur_elems.value.push_back(row_value, arena); } if constexpr (Trait::sampler == Sampler::RNG) { - auto & a = this->data(place); - ++a.total_values; - if (a.value.size() < 
max_elems) - a.value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + if (cur_elems.value.size() < max_elems) + cur_elems.value.push_back(row_value, arena); else { - UInt64 rnd = a.genRandom(a.total_values); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); if (rnd < max_elems) - a.value[rnd] = assert_cast &>(*columns[0]).getData()[row_num]; + cur_elems.value[rnd] = row_value; } } // TODO @@ -180,59 +195,75 @@ class GroupArrayNumericImpl final void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if constexpr (Trait::sampler == Sampler::NONE) - { - auto & cur_elems = this->data(place); - auto & rhs_elems = this->data(rhs); + auto & cur_elems = this->data(place); + auto & rhs_elems = this->data(rhs); - if (!limit_num_elems) - { - if (rhs_elems.value.size()) - cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena); - } - else - { - UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); - if (elems_to_insert) - cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena); - } - } + if (rhs_elems.value.empty()) + return; - if constexpr (Trait::sampler == Sampler::RNG) + if constexpr (Trait::last) + mergeNoSamplerLast(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::NONE) + mergeNoSampler(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::RNG) + mergeWithRNGSampler(cur_elems, rhs_elems, arena); + } + + void mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + UInt64 new_elements = std::min(static_cast(max_elems), cur_elems.value.size() + rhs_elems.value.size()); + cur_elems.value.resize(new_elements, arena); + for (auto & value : rhs_elems.value) { - if (this->data(rhs).value.empty()) /// rhs state is empty - return; + cur_elems.value[cur_elems.total_values % max_elems] = value; + ++cur_elems.total_values; 
+ } + assert(rhs_elems.total_values >= rhs_elems.value.size()); + cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size(); + } - auto & a = this->data(place); - auto & b = this->data(rhs); + void mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + if (!limit_num_elems) + { + if (rhs_elems.value.size()) + cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena); + } + else + { + UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); + if (elems_to_insert) + cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena); + } + } - if (b.total_values <= max_elems) - { - for (size_t i = 0; i < b.value.size(); ++i) - insert(a, b.value[i], arena); - } - else if (a.total_values <= max_elems) - { - decltype(a.value) from; - from.swap(a.value, arena); - a.value.assign(b.value.begin(), b.value.end(), arena); - a.total_values = b.total_values; - for (size_t i = 0; i < from.size(); ++i) - insert(a, from[i], arena); - } - else + void mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { + if (rhs_elems.total_values <= max_elems) + { + for (size_t i = 0; i < rhs_elems.value.size(); ++i) + insertWithSampler(cur_elems, rhs_elems.value[i], arena); + } + else if (cur_elems.total_values <= max_elems) + { + decltype(cur_elems.value) from; + from.swap(cur_elems.value, arena); + cur_elems.value.assign(rhs_elems.value.begin(), rhs_elems.value.end(), arena); + cur_elems.total_values = rhs_elems.total_values; + for (size_t i = 0; i < from.size(); ++i) + insertWithSampler(cur_elems, from[i], arena); + } + else + { + cur_elems.randomShuffle(); + cur_elems.total_values += rhs_elems.total_values; + for (size_t i = 0; i < max_elems; ++i) { - a.randomShuffle(); - a.total_values += b.total_values; - for (size_t i = 0; i < max_elems; ++i) - { - UInt64 rnd = a.genRandom(a.total_values); - if (rnd < b.total_values) - a.value[i] 
= b.value[i]; - } + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); + if (rnd < rhs_elems.total_values) + cur_elems.value[i] = rhs_elems.value[i]; } } - // TODO // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } @@ -244,6 +275,9 @@ class GroupArrayNumericImpl final writeVarUInt(size, buf); buf.write(reinterpret_cast(value.data()), size * sizeof(value[0])); + if constexpr (Trait::last) + DB::writeIntBinary(this->data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::writeIntBinary(this->data(place).total_values, buf); @@ -272,6 +306,9 @@ class GroupArrayNumericImpl final value.resize(size, arena); buf.readStrict(reinterpret_cast(value.data()), size * sizeof(value[0])); + if constexpr (Trait::last) + DB::readIntBinary(this->data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::readIntBinary(this->data(place).total_values, buf); @@ -398,6 +435,8 @@ struct GroupArrayGeneralData using Allocator = MixedAlignedArenaAllocator; using Array = PODArray; + // For group_array_last() + size_t total_values = 0; Array value; }; @@ -434,7 +473,7 @@ class GroupArrayGeneralImpl final DataTypePtr getReturnType() const override { return std::make_shared(data_type); } - void insert(Data & a, const Node * v, Arena * arena) const + void insertWithSampler(Data & a, const Node * v, Arena * arena) const { ++a.total_values; if (a.value.size() < max_elems) @@ -454,28 +493,49 @@ class GroupArrayGeneralImpl final a->rng.seed(seed); } + void destroy(AggregateDataPtr __restrict place) const noexcept override + { + data(place).~Data(); + } + + bool hasTrivialDestructor() const override + { + return std::is_trivially_destructible_v; + } + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { + auto & cur_elems = data(place); + + ++cur_elems.total_values; + if constexpr (Trait::sampler == Sampler::NONE) { - if (limit_num_elems && data(place).value.size() >= 
max_elems) + if (limit_num_elems && cur_elems.value.size() >= max_elems) + { + if (Trait::last) + { + Node * node = Node::allocate(*columns[0], row_num, arena); + cur_elems.value[(cur_elems.total_values - 1) % max_elems] = node; + } return; + } Node * node = Node::allocate(*columns[0], row_num, arena); - data(place).value.push_back(node, arena); + cur_elems.value.push_back(node, arena); } if constexpr (Trait::sampler == Sampler::RNG) { - auto & a = data(place); - ++a.total_values; - if (a.value.size() < max_elems) - a.value.push_back(Node::allocate(*columns[0], row_num, arena), arena); + if (cur_elems.value.size() < max_elems) + { + cur_elems.value.push_back(Node::allocate(*columns[0], row_num, arena), arena); + } else { - UInt64 rnd = a.genRandom(a.total_values); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); if (rnd < max_elems) - a.value[rnd] = Node::allocate(*columns[0], row_num, arena); + cur_elems.value[rnd] = Node::allocate(*columns[0], row_num, arena); } } // TODO @@ -484,68 +544,83 @@ class GroupArrayGeneralImpl final void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if constexpr (Trait::sampler == Sampler::NONE) - mergeNoSampler(place, rhs, arena); + auto & cur_elems = data(place); + auto & rhs_elems = data(rhs); + + if (rhs_elems.value.empty()) + return; + + if constexpr (Trait::last) + mergeNoSamplerLast(cur_elems, rhs_elems, arena); + else if constexpr (Trait::sampler == Sampler::NONE) + mergeNoSampler(cur_elems, rhs_elems, arena); else if constexpr (Trait::sampler == Sampler::RNG) - mergeWithRNGSampler(place, rhs, arena); + mergeWithRNGSampler(cur_elems, rhs_elems, arena); // TODO // else if constexpr (Trait::sampler == Sampler::DETERMINATOR) } - void ALWAYS_INLINE mergeNoSampler(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const + void ALWAYS_INLINE mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const { - if 
(data(rhs).value.empty()) /// rhs state is empty - return; + UInt64 new_elements = std::min(static_cast(max_elems), cur_elems.value.size() + rhs_elems.value.size()); + cur_elems.value.resize(new_elements, arena); + + for (auto & value : rhs_elems.value) + { + cur_elems.value[cur_elems.total_values % max_elems] = value->clone(arena); + ++cur_elems.total_values; + } + + + assert(rhs_elems.total_values >= rhs_elems.value.size()); + cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size(); + } + void ALWAYS_INLINE mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const + { UInt64 new_elems; if (limit_num_elems) { - if (data(place).value.size() >= max_elems) + if (cur_elems.value.size() >= max_elems) return; - - new_elems = std::min(data(rhs).value.size(), static_cast(max_elems) - data(place).value.size()); + new_elems = std::min(rhs_elems.value.size(), static_cast(max_elems) - cur_elems.value.size()); } else - new_elems = data(rhs).value.size(); + new_elems = rhs_elems.value.size(); - auto & a = data(place).value; - auto & b = data(rhs).value; for (UInt64 i = 0; i < new_elems; ++i) - a.push_back(b[i]->clone(arena), arena); + cur_elems.value.push_back(rhs_elems.value[i]->clone(arena), arena); + } - void ALWAYS_INLINE mergeWithRNGSampler(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const + void ALWAYS_INLINE mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const { - if (data(rhs).value.empty()) /// rhs state is empty - return; - - auto & a = data(place); - auto & b = data(rhs); - - if (b.total_values <= max_elems) + if (rhs_elems.total_values <= max_elems) { - for (size_t i = 0; i < b.value.size(); ++i) - insert(a, b.value[i], arena); + for (size_t i = 0; i < rhs_elems.value.size(); ++i) + insertWithSampler(cur_elems, rhs_elems.value[i], arena); } - else if (a.total_values <= max_elems) + else if (cur_elems.total_values <= max_elems) { - decltype(a.value) from; - 
from.swap(a.value, arena); - for (auto & node : b.value) - a.value.push_back(node->clone(arena), arena); - a.total_values = b.total_values; + decltype(cur_elems.value) from; + from.swap(cur_elems.value, arena); + for (auto & node : rhs_elems.value) + cur_elems.value.push_back(node->clone(arena), arena); + + cur_elems.total_values = rhs_elems.total_values; for (size_t i = 0; i < from.size(); ++i) - insert(a, from[i], arena); + insertWithSampler(cur_elems, from[i], arena); } else { - a.randomShuffle(); - a.total_values += b.total_values; + cur_elems.randomShuffle(); + cur_elems.total_values += rhs_elems.total_values; for (size_t i = 0; i < max_elems; ++i) { - UInt64 rnd = a.genRandom(a.total_values); - if (rnd < b.total_values) - a.value[i] = b.value[i]->clone(arena); + UInt64 rnd = cur_elems.genRandom(cur_elems.total_values); + if (rnd < rhs_elems.total_values) + cur_elems.value[i] = rhs_elems.value[i]->clone(arena); + } } } @@ -558,6 +633,9 @@ class GroupArrayGeneralImpl final for (auto & node : value) node->write(buf); + if constexpr (Trait::last) + DB::writeIntBinary(data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::writeIntBinary(data(place).total_values, buf); @@ -590,6 +668,9 @@ class GroupArrayGeneralImpl final for (UInt64 i = 0; i < elems; ++i) value[i] = Node::read(buf, arena); + if constexpr (Trait::last) + DB::readIntBinary(data(place).total_values, buf); + if constexpr (Trait::sampler == Sampler::RNG) { DB::readIntBinary(data(place).total_values, buf); diff --git a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp index 90cb39ca218..9ab0e43da34 100644 --- a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp +++ b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp @@ -32,7 +32,7 @@ void DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(const Aggreg /// TODO Make it sane. 
static const std::vector supported_functions{"any", "any_last", "min", "max", "sum", "sum_with_overflow", "group_bit_and", "group_bit_or", "group_bit_xor", - "sum_map", "min_map", "max_map", "group_array_array", "group_uniq_array_array", + "sum_map", "min_map", "max_map", "group_array_array", "group_array_last_array", "group_uniq_array_array", "sum_mapped_arrays", "min_mapped_arrays", "max_mapped_arrays"}; // check function diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index be34a8220f8..20b8bc0e862 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -117,13 +117,21 @@ bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & sett Poco::Logger * getLogger() { return &Poco::Logger::get("ExpressionAnalyzer"); } /// proton: starts. +/// Need exact match because _array is a special combinator suffix +/// that would otherwise filter these functions incorrectly +static const std::unordered_set exact_match_functions = { + "group_array", + "group_uniq_array", + "group_array_last_array", +}; + void tryTranslateToParametricAggregateFunction( const ASTFunction * node, DataTypes & types, Array & parameters, Names & argument_names, ContextPtr context) { if (!parameters.empty() || argument_names.empty()) return; - if (AggregateFunctionCombinatorFactory::instance().tryFindSuffix(node->name)) + if (AggregateFunctionCombinatorFactory::instance().tryFindSuffix(node->name) && !exact_match_functions.contains(node->name)) return; assert(node->arguments); @@ -150,8 +158,7 @@ void tryTranslateToParametricAggregateFunction( /// Translate `top_k_exact(key, num[, with_count, limit_memory_size])` to `top_k_exact(num[, with_count, limit_memory_size])(key)` auto size = arguments.size(); if (size < 2 || size > 4) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 4 arguments.", node->name); + throw 
Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 4 arguments.", node->name); ASTPtr expression_list = std::make_shared(); expression_list->children.assign(arguments.begin() + 1, arguments.end()); @@ -166,8 +173,35 @@ void tryTranslateToParametricAggregateFunction( /// Translate `top_k_exact_weighted(key, weight, num, [, with_count, limit_memory_size])` to `top_k_exact_weighted(num[, with_count, limit_memory_size])(key, weighted)` auto size = arguments.size(); if (size < 3 || size > 5) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 to 5 arguments.", node->name); + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 to 5 arguments.", node->name); + + ASTPtr expression_list = std::make_shared(); + expression_list->children.assign(arguments.begin() + 2, arguments.end()); + parameters = getAggregateFunctionParametersArray(expression_list, "", context); + + argument_names = {argument_names[0], argument_names[1]}; + types = {types[0], types[1]}; + } + else if (lower_name == "approx_top_k" || lower_name == "approx_top_k_count") + { + /// approx_top_k(key, k, reserved) to approx_top_k(k, reserved)(key) + auto size = arguments.size(); + if (size < 2 || size > 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 3 arguments.", node->name); + + ASTPtr expression_list = std::make_shared(); + expression_list->children.assign(arguments.begin() + 1, arguments.end()); + parameters = getAggregateFunctionParametersArray(expression_list, "", context); + + argument_names = {argument_names[0]}; + types = {types[0]}; + } + else if (lower_name == "approx_top_k_sum") + { + /// approx_top_k_sum(key, weight, k, reserved) to approx_top_k_sum(k, reserved)(key, weight) + auto size = arguments.size(); + if (size < 3 || size > 4) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate
function {} requires 3 to 4 arguments.", node->name); ASTPtr expression_list = std::make_shared(); expression_list->children.assign(arguments.begin() + 2, arguments.end()); @@ -191,7 +225,6 @@ void tryTranslateToParametricAggregateFunction( for (size_t i = 2; i < arg_size; ++i) expression_list->children.push_back(arguments[i]); parameters = getAggregateFunctionParametersArray(expression_list, "", context); - } argument_names = {argument_names[0], argument_names[1]}; types = {types[0], types[1]}; @@ -267,8 +300,7 @@ void tryTranslateToParametricAggregateFunction( { /// Translate `largest_triangle_three_buckets(x, y, n)` to `largest_triangle_three_buckets(n)(x, y)` if (arguments.size() != 3) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 arguments", node->name); + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 arguments", node->name); ASTPtr expression_list = std::make_shared(); expression_list->children.emplace_back(arguments.back()); @@ -277,6 +309,22 @@ void tryTranslateToParametricAggregateFunction( argument_names.pop_back(); types.pop_back(); } + else if (lower_name == "group_array") + { + if (arguments.size() != 1 && arguments.size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 1 or 2 arguments", node->name); + + if (arguments.size() == 2) + { + /// Translate `group_array(column, max_elems)` to `group_array(max_elems)(column)` + ASTPtr expression_list = std::make_shared(); + expression_list->children.push_back(arguments[1]); + parameters = getAggregateFunctionParametersArray(expression_list, "", context); + + argument_names.pop_back(); + types.pop_back(); + } + } else if (lower_name == "group_concat") { /// Translate `group_concat(expression, delimiter, limit)` to `group_concat(delimiter, limit)(expression)` @@ -298,6 +346,18 @@ void tryTranslateToParametricAggregateFunction( argument_names = 
{argument_names[0]}; types = {types[0]}; } + else if (lower_name.starts_with("group_array_last")) + { + if (arguments.size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 arguments", node->name); + + /// Translate `group_array_last(column, max_size)` to `group_array_last(max_size)(column)` + ASTPtr expression_list = std::make_shared(); + expression_list->children.push_back(arguments[1]); + parameters = getAggregateFunctionParametersArray(expression_list, "", context); + argument_names.pop_back(); + types.pop_back(); + } }; /// proton: starts. Add 'is_changelog_input' param to allow aggregate function being aware whether the input stream is a changelog diff --git a/tests/queries_ported/0_stateless/02520_group_array_last.reference b/tests/queries_ported/0_stateless/02520_group_array_last.reference new file mode 100644 index 00000000000..3d1c3efd352 --- /dev/null +++ b/tests/queries_ported/0_stateless/02520_group_array_last.reference @@ -0,0 +1,45 @@ +-- { echo } +-- NUMBER_OF_ARGUMENTS_DOESNT_MATCH +select group_array_last(number+1) from numbers(5); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select group_array_last_array([number+1]) from numbers(5); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +-- group_array_last by number +select group_array_last(number+1, 1) from numbers(5); +[5] +select group_array_last(number+1, 3) from numbers(5); +[4,5,3] +select group_array_last(number+1, 3) from numbers(10); +[10,8,9] +-- group_array_last by String +select group_array_last((number+1)::string, 3) from numbers(5); +['4','5','3'] +select group_array_last((number+1)::string, 3) from numbers(10); +['10','8','9'] +-- group_array_last_array +select group_array_last_array([1,2,3,4,5,6], 3); +[4,5,6] +select group_array_last_array(['1','2','3','4','5','6'], 3); +['4','5','6'] +-- group_array_last_merge +-- [10,8,9] + [10,8,9] => [10,10,9] => [10,10,8] => [9,10,8] +-- ^ ^ ^ ^^ +-- (position to insert at) 
+select group_array_last(number+1, 3) from (select * from numbers(10)); +[10,8,9] +select group_array_last((number+1)::string, 3) from (select * from numbers(10)); +['10','8','9'] +select group_array_last([number+1], 3) from (select * from numbers(10)); +[[10],[8],[9]] +select group_array_last(number+1, 100) from (select * from numbers(10)); +[1,2,3,4,5,6,7,8,9,10] +select group_array_last((number+1)::string, 100) from (select * from numbers(10)); +['1','2','3','4','5','6','7','8','9','10'] +select group_array_last([number+1], 100) from (select * from numbers(10)); +[[1],[2],[3],[4],[5],[6],[7],[8],[9],[10]] +-- SimpleAggregateFunction +create stream simple_agg_group_array_last_array (key int, value simple_aggregate_function(group_array_last_array(5), array(uint64))) engine=MergeTree() order by key; +insert into simple_agg_group_array_last_array (key, value) values (1, [1,2,3]), (1, [4,5,6]), (2, [4,5,6]), (2, [1,2,3]); +select sleep(3); +0 +select key, group_array_last_array(value, 5) from simple_agg_group_array_last_array group by key order by key; +1 [6,2,3,4,5] +2 [3,5,6,1,2] diff --git a/tests/queries_ported/0_stateless/02520_group_array_last.sql b/tests/queries_ported/0_stateless/02520_group_array_last.sql new file mode 100644 index 00000000000..960eca38818 --- /dev/null +++ b/tests/queries_ported/0_stateless/02520_group_array_last.sql @@ -0,0 +1,32 @@ +set query_mode='table'; +drop stream if exists simple_agg_group_array_last_array; + +-- { echo } +-- NUMBER_OF_ARGUMENTS_DOESNT_MATCH +select group_array_last(number+1) from numbers(5); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select group_array_last_array([number+1]) from numbers(5); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +-- group_array_last by number +select group_array_last(number+1, 1) from numbers(5); +select group_array_last(number+1, 3) from numbers(5); +select group_array_last(number+1, 3) from numbers(10); +-- group_array_last by String +select 
group_array_last((number+1)::string, 3) from numbers(5); +select group_array_last((number+1)::string, 3) from numbers(10); +-- group_array_last_array +select group_array_last_array([1,2,3,4,5,6], 3); +select group_array_last_array(['1','2','3','4','5','6'], 3); +-- group_array_last_merge +-- [10,8,9] + [10,8,9] => [10,10,9] => [10,10,8] => [9,10,8] +-- ^ ^ ^ ^^ +-- (position to insert at) +select group_array_last(number+1, 3) from (select * from numbers(10)); +select group_array_last((number+1)::string, 3) from (select * from numbers(10)); +select group_array_last([number+1], 3) from (select * from numbers(10)); +select group_array_last(number+1, 100) from (select * from numbers(10)); +select group_array_last((number+1)::string, 100) from (select * from numbers(10)); +select group_array_last([number+1], 100) from (select * from numbers(10)); +-- SimpleAggregateFunction +create stream simple_agg_group_array_last_array (key int, value simple_aggregate_function(group_array_last_array(5), array(uint64))) engine=MergeTree() order by key; +insert into simple_agg_group_array_last_array (key, value) values (1, [1,2,3]), (1, [4,5,6]), (2, [4,5,6]), (2, [1,2,3]); +select sleep(3); +select key, group_array_last_array(value, 5) from simple_agg_group_array_last_array group by key order by key; diff --git a/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.reference b/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.reference new file mode 100644 index 00000000000..0dfe1e75577 --- /dev/null +++ b/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.reference @@ -0,0 +1,10 @@ +[0,1,2,3,4] +[0,1,2] +[[1,2]] +[1] +[1,2,3] +[1,2,3,4,5] +[0,4,2,1,3] +[0,2,1] +[[1,2]] +[1] diff --git a/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.sql b/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.sql new file mode 100644 index 00000000000..227f3758b69 
--- /dev/null +++ b/tests/queries_ported/0_stateless/99039_group_array_additional_arg_for_max_elem.sql @@ -0,0 +1,53 @@ +-- Using group_array with numbers +SELECT group_array(number) +FROM numbers(5); -- This will give you [0,1,2,3,4] + +-- With max_size limit +SELECT group_array(number, 3) +FROM numbers(5); -- This will give you [0,1,2] + + +SELECT group_array([1,2], 1); + +-- Instead, you could do: +SELECT group_array(number + 1, 1) +FROM numbers(2); -- This will give you [1] + +-- The test below is not stable, so it is disabled. +-- SELECT group_array(x, 1) +-- FROM (SELECT 1 AS x UNION ALL SELECT 2); + + +drop stream if exists test99039; +create stream test99039(id int); +select sleep(3) FORMAT Null; +insert into test99039(id) values (1)(2)(3)(4)(5); +select sleep(3) FORMAT Null; +select group_array(id, 3) from table(test99039); +select group_array(id) from table(test99039); +select group_array(id, 0) from table(test99039); -- { serverError BAD_ARGUMENTS } +select group_array(id, -1) from table(test99039); -- { serverError BAD_ARGUMENTS } + + + + +-- Using group_uniq_array with numbers +SELECT group_uniq_array(number) +FROM numbers(5); -- This will give you [0,4,2,1,3] + +-- With max_size limit +SELECT group_uniq_array(number, 3) +FROM numbers(5); -- This will give you [0,2,1] + + +SELECT group_uniq_array([1,2], 1); + +-- Instead, you could do: +SELECT group_uniq_array(number + 1, 1) +FROM numbers(2); -- This will give you [1] + +-- The test below is not stable, so it is disabled.
+-- SELECT group_uniq_array(x, 1) +-- FROM (SELECT 1 AS x UNION ALL SELECT 2); + + From 97a32c4297738f6bb9051525441bf87c461cd429 Mon Sep 17 00:00:00 2001 From: Jasmine-ge Date: Thu, 20 Feb 2025 18:07:45 +0800 Subject: [PATCH 2/2] support group array sorted function Fix a crash and a leak in AggregateFunctionGroupArraySorted std::vector Fix memory like in groupArraySorted Fix multiple bugs in groupArraySorted Fix problem detected by ubsan Fix groupArraySorted documentation Revert Revert Add new aggregation function groupArraySorted() Updated implementation Fixed style check Fixed code review issues Fixed code review issues Revert Add new aggregation function groupArraySorted() created groupSortedArray by ..Moving --- base/base/sort.h | 89 +++- .../reference/grouparraysorted.md | 45 ++ .../AggregateFunctionGroupArraySorted.cpp | 431 ++++++++++++++++++ .../registerAggregateFunctions.cpp | 2 + src/Interpreters/ExpressionAnalyzer.cpp | 15 + tests/performance/group_array_sorted.xml | 31 ++ .../02841_group_array_sorted.reference | 12 + .../0_stateless/02841_group_array_sorted.sql | 41 ++ .../03008_groupSortedArray_field.reference | 3 + .../03008_groupSortedArray_field.sql | 6 + .../03094_grouparraysorted_memory.reference | 0 .../03094_grouparraysorted_memory.sql | 36 ++ .../02841_group_array_sorted.reference | 18 + .../0_stateless/02841_group_array_sorted.sql | 47 ++ .../03008_groupSortedArray_field.reference | 3 + .../03008_groupSortedArray_field.sql | 5 + .../03094_grouparraysorted_memory.reference | 0 .../03094_grouparraysorted_memory.sql | 34 ++ 18 files changed, 812 insertions(+), 6 deletions(-) create mode 100644 docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md create mode 100644 src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp create mode 100644 tests/performance/group_array_sorted.xml create mode 100644 tests/queries/0_stateless/02841_group_array_sorted.reference create mode 100644 
tests/queries/0_stateless/02841_group_array_sorted.sql create mode 100644 tests/queries/0_stateless/03008_groupSortedArray_field.reference create mode 100644 tests/queries/0_stateless/03008_groupSortedArray_field.sql create mode 100644 tests/queries/0_stateless/03094_grouparraysorted_memory.reference create mode 100644 tests/queries/0_stateless/03094_grouparraysorted_memory.sql create mode 100644 tests/queries_ported/0_stateless/02841_group_array_sorted.reference create mode 100644 tests/queries_ported/0_stateless/02841_group_array_sorted.sql create mode 100644 tests/queries_ported/0_stateless/03008_groupSortedArray_field.reference create mode 100644 tests/queries_ported/0_stateless/03008_groupSortedArray_field.sql create mode 100644 tests/queries_ported/0_stateless/03094_grouparraysorted_memory.reference create mode 100644 tests/queries_ported/0_stateless/03094_grouparraysorted_memory.sql diff --git a/base/base/sort.h b/base/base/sort.h index 592a899a291..78533e1d455 100644 --- a/base/base/sort.h +++ b/base/base/sort.h @@ -2,27 +2,104 @@ #include +#ifndef NDEBUG +#include +#include +/** Same as libcxx std::__debug_less, just without a dependency on the private part of the standard library. + * Checks that Comparator induces a strict weak ordering.
+ */ +template +class DebugLessComparator +{ +public: + constexpr DebugLessComparator(Comparator & cmp_) + : cmp(cmp_) + {} + template + constexpr bool operator()(const LhsType & lhs, const RhsType & rhs) + { + bool lhs_less_than_rhs = cmp(lhs, rhs); + if (lhs_less_than_rhs) + assert(!cmp(rhs, lhs)); + return lhs_less_than_rhs; + } + template + constexpr bool operator()(LhsType & lhs, RhsType & rhs) + { + bool lhs_less_than_rhs = cmp(lhs, rhs); + if (lhs_less_than_rhs) + assert(!cmp(rhs, lhs)); + return lhs_less_than_rhs; + } +private: + Comparator & cmp; +}; +template +using ComparatorWrapper = DebugLessComparator; +template +void shuffle(RandomIt first, RandomIt last) +{ + static thread_local pcg64 rng(getThreadId()); + std::shuffle(first, last, rng); +} +#else +template +using ComparatorWrapper = Comparator; +#endif #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wold-style-cast" #include -template -void nth_element(RandomIt first, RandomIt nth, RandomIt last) +template +void nth_element(RandomIt first, RandomIt nth, RandomIt last, Compare compare) { - ::miniselect::floyd_rivest_select(first, nth, last); +#ifndef NDEBUG + ::shuffle(first, last); +#endif + + ComparatorWrapper compare_wrapper = compare; + ::miniselect::floyd_rivest_select(first, nth, last, compare_wrapper); + +#ifndef NDEBUG + ::shuffle(first, nth); + + if (nth != last) + ::shuffle(nth + 1, last); +#endif } template -void partial_sort(RandomIt first, RandomIt middle, RandomIt last) +void nth_element(RandomIt first, RandomIt nth, RandomIt last) { - ::miniselect::floyd_rivest_partial_sort(first, middle, last); + using value_type = typename std::iterator_traits::value_type; + using comparator = std::less; + + ::nth_element(first, nth, last, comparator()); } template void partial_sort(RandomIt first, RandomIt middle, RandomIt last, Compare compare) { - ::miniselect::floyd_rivest_partial_sort(first, middle, last, compare); +#ifndef NDEBUG + ::shuffle(first, last); +#endif + + 
ComparatorWrapper compare_wrapper = compare; + ::miniselect::floyd_rivest_partial_sort(first, middle, last, compare_wrapper); + +#ifndef NDEBUG + ::shuffle(middle, last); +#endif +} + +template +void partial_sort(RandomIt first, RandomIt middle, RandomIt last) +{ + using value_type = typename std::iterator_traits::value_type; + using comparator = std::less; + + ::partial_sort(first, middle, last, comparator()); } #pragma GCC diagnostic pop diff --git a/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md b/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md new file mode 100644 index 00000000000..9bee0c29e7a --- /dev/null +++ b/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md @@ -0,0 +1,45 @@ + --- + toc_priority: 112 + --- + + # groupArraySorted {#groupArraySorted} + + Returns an array with the first N items in ascending order. + + ``` sql + groupArraySorted(N)(column) + ``` + + **Arguments** + + - `N` – The number of elements to return. + + - `column` – The value (Integer, String, Float and other Generic types). 
+ + **Example** + + Gets the first 10 numbers: + + ``` sql + SELECT groupArraySorted(10)(number) FROM numbers(100) + ``` + + ``` text + ┌─groupArraySorted(10)(number)─┐ + │ [0,1,2,3,4,5,6,7,8,9] │ + └──────────────────────────────┘ + ``` + + + Gets the string representations of all numbers in the column: + + ``` sql +SELECT groupArraySorted(5)(str) FROM (SELECT toString(number) as str FROM numbers(5)); + + ``` + + ``` text +┌─groupArraySorted(5)(str)─┐ +│ ['0','1','2','3','4'] │ +└──────────────────────────┘ + ``` \ No newline at end of file diff --git a/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp b/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp new file mode 100644 index 00000000000..df99d026cb2 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp @@ -0,0 +1,431 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ + +struct Settings; + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; + extern const int TOO_LARGE_ARRAY_SIZE; +} + +namespace +{ + +enum class GroupArraySortedStrategy : uint8_t +{ + heap, + sort +}; + + +constexpr size_t group_array_sorted_sort_strategy_max_elements_threshold = 1000000; + +template +struct GroupArraySortedData +{ + static constexpr bool is_value_generic_field = std::is_same_v; + + using Allocator = MixedAlignedArenaAllocator; + using Array = typename std::conditional_t, PODArray>; + + static constexpr size_t partial_sort_max_elements_factor = 2; + + Array values; + + static bool compare(const T & lhs, const T & rhs) + { + if constexpr (is_value_generic_field) + { + return lhs < rhs; + } + else + { + return CompareHelper::less(lhs, rhs, -1); + } + } + + struct
Comparator + { + bool operator()(const T & lhs, const T & rhs) + { + return compare(lhs, rhs); + } + }; + + ALWAYS_INLINE void heapReplaceTop() + { + size_t size = values.size(); + if (size < 2) + return; + + size_t child_index = 1; + + if (values.size() > 2 && compare(values[1], values[2])) + ++child_index; + + /// Check if we are in order + if (compare(values[child_index], values[0])) + return; + + size_t current_index = 0; + auto current = values[current_index]; + + do + { + /// We are not in heap-order, swap the parent with its largest child. + values[current_index] = values[child_index]; + current_index = child_index; + + // Recompute the child based off of the updated parent + child_index = 2 * child_index + 1; + + if (child_index >= size) + break; + + if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1])) + { + /// Right child exists and is greater than left child. + ++child_index; + } + + /// Check if we are in order. + } while (!compare(values[child_index], current)); + + values[current_index] = current; + } + + ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena) + { + if constexpr (is_value_generic_field) + { + ::sort(values.begin(), values.end(), Comparator()); + } + else + { + bool try_sort = trySort(values.begin(), values.end(), Comparator()); + if (!try_sort) + RadixSort>::executeLSD(values.data(), values.size()); + } + + if (values.size() > max_elements) + resize(max_elements, arena); + } + + ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena) + { + if (values.size() < max_elements * partial_sort_max_elements_factor) + return; + + ::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator()); + resize(max_elements, arena); + } + + ALWAYS_INLINE void resize(size_t n, Arena * arena) + { + if constexpr (is_value_generic_field) + values.resize(n); + else + values.resize(n, arena); + } + + ALWAYS_INLINE void push_back(T && element, Arena * arena) + {
+ if constexpr (is_value_generic_field) + values.push_back(element); + else + values.push_back(element, arena); + } + + ALWAYS_INLINE void addElement(T && element, size_t max_elements, Arena * arena) + { + if constexpr (strategy == GroupArraySortedStrategy::heap) + { + if (values.size() >= max_elements) + { + /// Element is greater or equal than current max element, it cannot be in k min elements + if (!compare(element, values[0])) + return; + + values[0] = std::move(element); + heapReplaceTop(); + return; + } + + push_back(std::move(element), arena); + std::push_heap(values.begin(), values.end(), Comparator()); + } + else + { + push_back(std::move(element), arena); + partialSortAndLimitIfNeeded(max_elements, arena); + } + } + + ALWAYS_INLINE void insertResultInto(IColumn & to, size_t max_elements, Arena * arena) + { + auto & result_array = assert_cast(to); + auto & result_array_offsets = result_array.getOffsets(); + + sortAndLimit(max_elements, arena); + + result_array_offsets.push_back(result_array_offsets.back() + values.size()); + + if (values.empty()) + return; + + if constexpr (is_value_generic_field) + { + auto & result_array_data = result_array.getData(); + for (auto & value : values) + result_array_data.insert(value); + } + else + { + auto & result_array_data = assert_cast &>(result_array.getData()).getData(); + + size_t result_array_data_insert_begin = result_array_data.size(); + result_array_data.resize(result_array_data_insert_begin + values.size()); + + for (size_t i = 0; i < values.size(); ++i) + result_array_data[result_array_data_insert_begin + i] = values[i]; + } + } +}; + +template +using GroupArraySortedDataHeap = GroupArraySortedData; + +template +using GroupArraySortedDataSort = GroupArraySortedData; + +constexpr UInt64 aggregate_function_group_array_sorted_max_element_size = 0xFFFFFF; + +template +class GroupArraySorted final + : public IAggregateFunctionDataHelper> +{ +public: + explicit GroupArraySorted( + const DataTypePtr & data_type_, 
const Array & parameters_, UInt64 max_elements_) + : IAggregateFunctionDataHelper>( + {data_type_}, parameters_, std::make_shared(data_type_)) + , max_elements(max_elements_) + , serialization(data_type_->getDefaultSerialization()) + { + if (max_elements > aggregate_function_group_array_sorted_max_element_size) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Too large limit parameter for groupArraySorted aggregate function, it should not exceed {}", + aggregate_function_group_array_sorted_max_element_size); + } + + String getName() const override { return "group_array_sorted"; } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + if constexpr (std::is_same_v) + { + auto row_value = (*columns[0])[row_num]; + this->data(place).addElement(std::move(row_value), max_elements, arena); + } + else + { + auto row_value = assert_cast &>(*columns[0]).getData()[row_num]; + this->data(place).addElement(std::move(row_value), max_elements, arena); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + auto & rhs_values = this->data(rhs).values; + for (auto rhs_element : rhs_values) + this->data(place).addElement(std::move(rhs_element), max_elements, arena); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + { + auto & values = this->data(place).values; + size_t size = values.size(); + writeVarUInt(size, buf); + + if constexpr (std::is_same_v) + { + for (const Field & element : values) + { + if (element.isNull()) + { + writeBinary(false, buf); + } + else + { + writeBinary(true, buf); + serialization->serializeBinary(element, buf, {}); + } + } + } + else + { + if constexpr (std::endian::native == std::endian::little) + { + buf.write(reinterpret_cast(values.data()), size * sizeof(values[0])); + } + else + { + for (const auto & element : values) + writeBinary(element, 
buf); + } + } + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena * arena) const override + { + size_t size = 0; + readVarUInt(size, buf); + + if (unlikely(size > max_elements)) + throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elements); + + auto & values = this->data(place).values; + + if constexpr (Data::is_value_generic_field) + { + values.resize(size); + for (Field & element : values) + { + bool has_value = false; + readBinary(has_value, buf); + if (has_value) + serialization->deserializeBinary(element, buf, {}); + } + } + else + { + values.resize(size, arena); + if constexpr (std::endian::native == std::endian::little) + { + buf.readStrict(reinterpret_cast(values.data()), size * sizeof(values[0])); + } + else + { + for (auto & element : values) + readBinary(element, buf); + } + } + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + this->data(place).insertResultInto(to, max_elements, arena); + } + + bool allocatesMemoryInArena() const override { return true; } + +private: + UInt64 max_elements; + SerializationPtr serialization; +}; + +template +using GroupArraySortedHeap = GroupArraySorted, T>; + +template +using GroupArraySortedSort = GroupArraySorted, T>; + +template