Skip to content

Commit 5bccf77

Browse files
natashasehgalfacebook-github-bot
authored andcommitted
TDigest Aggregate Fuzzer Test (facebookincubator#13301)
Summary: Pull Request resolved: facebookincubator#13301 Reviewed By: kagamiori Differential Revision: D74505772 fbshipit-source-id: 57bf7e566426c1eb170ea4a9322337e739f2ed18
1 parent 4134ca9 commit 5bccf77

File tree

4 files changed

+279
-6
lines changed

4 files changed

+279
-6
lines changed

velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
#include "velox/functions/prestosql/fuzzer/NoisySumResultVerifier.h"
4242
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
4343
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
44+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
45+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
4446
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
4547
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
4648
#include "velox/vector/fuzzer/VectorFuzzer.h"
@@ -86,6 +88,7 @@ getCustomInputGenerators() {
8688
{"approx_distinct", std::make_shared<ApproxDistinctInputGenerator>()},
8789
{"approx_set", std::make_shared<ApproxDistinctInputGenerator>()},
8890
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
91+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
8992
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
9093
{"map_union_sum", std::make_shared<MapUnionSumInputGenerator>()},
9194
{"noisy_count_if_gaussian",
@@ -145,9 +148,6 @@ int main(int argc, char** argv) {
145148
// Skip non-deterministic functions.
146149
// https://github.com/facebookincubator/velox/issues/13547
147150
"merge",
148-
// Will be added in follow up PR:
149-
// https://github.com/facebookincubator/velox/pull/13301
150-
"tdigest_agg",
151151
};
152152

153153
static const std::unordered_set<std::string> functionsRequireSortedInput = {
@@ -165,6 +165,7 @@ int main(int argc, char** argv) {
165165
using facebook::velox::exec::test::NoisySumResultVerifier;
166166
using facebook::velox::exec::test::QDigestAggResultVerifier;
167167
using facebook::velox::exec::test::setupReferenceQueryRunner;
168+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
168169
using facebook::velox::exec::test::TransformResultVerifier;
169170

170171
auto makeArrayVerifier = []() {
@@ -192,6 +193,7 @@ int main(int argc, char** argv) {
192193
{"approx_set", std::make_shared<ApproxDistinctResultVerifier>(true)},
193194
{"approx_percentile",
194195
std::make_shared<ApproxPercentileResultVerifier>()},
196+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
195197
{"qdigest_agg", std::make_shared<QDigestAggResultVerifier>()},
196198
{"arbitrary", std::make_shared<ArbitraryResultVerifier>()},
197199
{"any_value", nullptr},
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <boost/random/uniform_int_distribution.hpp>
19+
#include <boost/random/uniform_real_distribution.hpp>
20+
21+
#include "velox/exec/fuzzer/InputGenerator.h"
22+
#include "velox/vector/fuzzer/VectorFuzzer.h"
23+
24+
namespace facebook::velox::exec::test {
25+
26+
class TDigestAggregateInputGenerator : public InputGenerator {
27+
public:
28+
std::vector<VectorPtr> generate(
29+
const std::vector<TypePtr>& types,
30+
VectorFuzzer& fuzzer,
31+
FuzzerGenerator& rng,
32+
memory::MemoryPool* pool) override {
33+
VELOX_CHECK_GE(types.size(), 1);
34+
VELOX_CHECK_LE(types.size(), 3);
35+
36+
std::vector<VectorPtr> inputs;
37+
inputs.reserve(types.size());
38+
39+
// Values vector
40+
VELOX_CHECK(types[0]->isDouble());
41+
auto valuesVector = fuzzer.fuzz(types[0]);
42+
inputs.push_back(valuesVector);
43+
44+
// Weight is optional
45+
if (types.size() > 1) {
46+
VELOX_CHECK(types[1]->isBigint());
47+
auto weightsVector = fuzzer.fuzz(types[1]);
48+
inputs.push_back(weightsVector);
49+
}
50+
51+
// Compression is optional
52+
if (types.size() > 2) {
53+
VELOX_CHECK(types[2]->isDouble());
54+
const auto size = fuzzer.getOptions().vectorSize;
55+
// Make sure to use the same value of 'compression' for all batches in a
56+
// given Fuzzer iteration.
57+
if (!compression_.has_value()) {
58+
boost::random::uniform_real_distribution<double> dist(10.0, 1000.0);
59+
compression_ = dist(rng);
60+
}
61+
inputs.push_back(BaseVector::createConstant(
62+
DOUBLE(), compression_.value(), size, pool));
63+
}
64+
65+
return inputs;
66+
}
67+
68+
void reset() override {
69+
compression_.reset();
70+
}
71+
72+
private:
73+
std::optional<double> compression_;
74+
};
75+
76+
} // namespace facebook::velox::exec::test
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/core/PlanNode.h"
19+
#include "velox/exec/fuzzer/ResultVerifier.h"
20+
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
21+
#include "velox/exec/tests/utils/PlanBuilder.h"
22+
#include "velox/functions/lib/TDigest.h"
23+
#include "velox/vector/ComplexVector.h"
24+
25+
namespace facebook::velox::exec::test {
26+
27+
class TDigestAggregateResultVerifier : public ResultVerifier {
28+
public:
29+
bool supportsCompare() override {
30+
return true;
31+
}
32+
33+
bool supportsVerify() override {
34+
return false;
35+
}
36+
37+
void initialize(
38+
const std::vector<RowVectorPtr>& /*input*/,
39+
const std::vector<core::ExprPtr>& /*projections*/,
40+
const std::vector<std::string>& groupingKeys,
41+
const core::AggregationNode::Aggregate& aggregate,
42+
const std::string& aggregateName) override {
43+
keys_ = groupingKeys;
44+
resultName_ = aggregateName;
45+
46+
// Check TDigest types
47+
validateTDigestTypes(aggregate.call);
48+
}
49+
50+
void initializeWindow(
51+
const std::vector<RowVectorPtr>& /*input*/,
52+
const std::vector<core::ExprPtr>& /*projections*/,
53+
const std::vector<std::string>& /*partitionByKeys*/,
54+
const std::vector<SortingKeyAndOrder>& /*sortingKeysAndOrders*/,
55+
const core::WindowNode::Function& function,
56+
const std::string& /*frame*/,
57+
const std::string& windowName) override {
58+
keys_ = {"row_number"};
59+
resultName_ = windowName;
60+
61+
// Check TDigest types
62+
validateTDigestTypes(function.functionCall);
63+
}
64+
65+
bool compare(const RowVectorPtr& result, const RowVectorPtr& altResult)
66+
override {
67+
VELOX_CHECK_EQ(result->size(), altResult->size());
68+
69+
auto projection = keys_;
70+
projection.push_back(resultName_);
71+
72+
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
73+
auto builder = PlanBuilder(planNodeIdGenerator).values({result});
74+
if (!keys_.empty()) {
75+
builder = builder.orderBy(keys_, false);
76+
}
77+
auto sortByKeys = builder.project(projection).planNode();
78+
auto sortedResult =
79+
AssertQueryBuilder(sortByKeys).copyResults(result->pool());
80+
81+
builder = PlanBuilder(planNodeIdGenerator).values({altResult});
82+
if (!keys_.empty()) {
83+
builder = builder.orderBy(keys_, false);
84+
}
85+
sortByKeys = builder.project(projection).planNode();
86+
auto sortedAltResult =
87+
AssertQueryBuilder(sortByKeys).copyResults(altResult->pool());
88+
89+
VELOX_CHECK_EQ(sortedResult->size(), sortedAltResult->size());
90+
auto size = sortedResult->size();
91+
for (auto i = 0; i < size; i++) {
92+
auto resultIsNull = sortedResult->childAt(resultName_)->isNullAt(i);
93+
auto altResultIsNull = sortedAltResult->childAt(resultName_)->isNullAt(i);
94+
if (resultIsNull || altResultIsNull) {
95+
VELOX_CHECK(resultIsNull && altResultIsNull);
96+
continue;
97+
}
98+
99+
auto resultValue = sortedResult->childAt(resultName_)
100+
->as<SimpleVector<StringView>>()
101+
->valueAt(i);
102+
auto altResultValue = sortedAltResult->childAt(resultName_)
103+
->as<SimpleVector<StringView>>()
104+
->valueAt(i);
105+
if (resultValue == altResultValue) {
106+
continue;
107+
} else {
108+
checkEquivalentTDigest(resultValue, altResultValue);
109+
}
110+
}
111+
return true;
112+
}
113+
114+
bool verify(const RowVectorPtr& /*result*/) override {
115+
VELOX_UNSUPPORTED();
116+
}
117+
118+
void reset() override {
119+
keys_.clear();
120+
resultName_.clear();
121+
}
122+
123+
private:
124+
// Helper method to check TDigest input and return types
125+
void validateTDigestTypes(const core::CallTypedExprPtr& call) const {
126+
// Check input type is double
127+
auto inputType = call->inputs()[0]->type();
128+
if (inputType->kind() != TypeKind::DOUBLE) {
129+
VELOX_FAIL(
130+
"TDigest only supports DOUBLE input type, got {}",
131+
inputType->toString());
132+
}
133+
auto returnType = call->type();
134+
if (returnType->kind() != TypeKind::VARBINARY) {
135+
VELOX_FAIL(
136+
"TDigest return type must be VARBINARY, got {}",
137+
returnType->toString());
138+
}
139+
}
140+
141+
void checkEquivalentTDigest(
142+
const StringView& result,
143+
const StringView& altResult) {
144+
// Create TDigests from serialized data
145+
facebook::velox::functions::TDigest<> resultTdigest;
146+
facebook::velox::functions::TDigest<> altResultTdigest;
147+
std::vector<int16_t> positions;
148+
149+
try {
150+
resultTdigest.mergeDeserialized(positions, result.data());
151+
resultTdigest.compress(positions);
152+
153+
positions.clear();
154+
altResultTdigest.mergeDeserialized(positions, altResult.data());
155+
altResultTdigest.compress(positions);
156+
} catch (const std::exception& e) {
157+
VELOX_FAIL("Failed to deserialize TDigest: {}", e.what());
158+
}
159+
160+
// Compare TDigest values at specific quantiles
161+
for (auto quantile : kQuantiles) {
162+
double resultQuantile = resultTdigest.estimateQuantile(quantile);
163+
double altResultQuantile = altResultTdigest.estimateQuantile(quantile);
164+
165+
variant resultVariant(resultQuantile);
166+
variant altResultVariant(altResultQuantile);
167+
VELOX_CHECK(
168+
resultVariant.equalsWithEpsilon(altResultVariant),
169+
"TDigest quantile values differ at {}: {} vs {}",
170+
quantile,
171+
resultQuantile,
172+
altResultQuantile);
173+
}
174+
}
175+
176+
static constexpr double kQuantiles[] = {
177+
0.01,
178+
0.05,
179+
0.1,
180+
0.25,
181+
0.50,
182+
0.75,
183+
0.9,
184+
0.95,
185+
0.99,
186+
};
187+
188+
std::vector<std::string> keys_;
189+
std::string resultName_;
190+
};
191+
192+
} // namespace facebook::velox::exec::test

velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h"
3030
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3131
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
32+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
33+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
3234
#include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h"
3335
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
3436
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -75,6 +77,7 @@ getCustomInputGenerators() {
7577
{"approx_distinct", std::make_shared<ApproxDistinctInputGenerator>()},
7678
{"approx_set", std::make_shared<ApproxDistinctInputGenerator>()},
7779
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
80+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
7881
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
7982
{"lead", std::make_shared<WindowOffsetInputGenerator>(1)},
8083
{"lag", std::make_shared<WindowOffsetInputGenerator>(1)},
@@ -126,9 +129,6 @@ int main(int argc, char** argv) {
126129
"noisy_sum_gaussian",
127130
// https://github.com/facebookincubator/velox/issues/13547
128131
"merge",
129-
// Will be added in follow up PR:
130-
// https://github.com/facebookincubator/velox/pull/13301
131-
"tdigest_agg",
132132
};
133133

134134
if (!FLAGS_presto_url.empty()) {
@@ -148,6 +148,7 @@ int main(int argc, char** argv) {
148148
using facebook::velox::exec::test::ApproxPercentileResultVerifier;
149149
using facebook::velox::exec::test::AverageResultVerifier;
150150
using facebook::velox::exec::test::QDigestAggResultVerifier;
151+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
151152

152153
static const std::unordered_map<
153154
std::string,
@@ -159,6 +160,7 @@ int main(int argc, char** argv) {
159160
{"approx_percentile",
160161
std::make_shared<ApproxPercentileResultVerifier>()},
161162
{"approx_most_frequent", nullptr},
163+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
162164
{"qdigest_agg", std::make_shared<QDigestAggResultVerifier>()},
163165
{"merge", nullptr},
164166
// Semantically inconsistent functions
@@ -196,6 +198,7 @@ int main(int argc, char** argv) {
196198
"max_by",
197199
"min_by",
198200
"multimap_agg",
201+
"tdigest_agg",
199202
"qdigest_agg",
200203
};
201204

0 commit comments

Comments
 (0)