Skip to content

Commit a017702

Browse files
committed
fix: Ensure ConstantTypedExpr and variant types are equivalent
1 parent 7df0c83 commit a017702

File tree

4 files changed

+47
-8
lines changed

4 files changed

+47
-8
lines changed

velox/core/Expressions.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ class ConstantTypedExpr : public ITypedExpr {
6161
// Variant::null() value is supported.
6262
ConstantTypedExpr(TypePtr type, Variant value)
6363
: ITypedExpr{ExprKind::kConstant, std::move(type)},
64-
value_{std::move(value)} {}
64+
value_{std::move(value)} {
65+
VELOX_CHECK(
66+
value_.isTypeCompatible(ITypedExpr::type()),
67+
"Expression type {} does not match variant type {}",
68+
ITypedExpr::type()->toString(),
69+
value_.inferType()->toString());
70+
}
6571

6672
// Creates constant expression of scalar or complex type. The value comes from
6773
// index zero.

velox/core/tests/ConstantTypedExprTest.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
#include <gtest/gtest.h>
1717

18+
#include "velox/common/base/tests/GTestUtils.h"
1819
#include "velox/common/memory/Memory.h"
1920
#include "velox/core/Expressions.h"
2021
#include "velox/functions/prestosql/types/HyperLogLogType.h"
@@ -74,11 +75,10 @@ class ConstantTypedExprTest : public ::testing::Test,
7475
}
7576

7677
// Helper functions
77-
template <typename T>
7878
std::shared_ptr<ConstantTypedExpr> createVariantExpr(
7979
const TypePtr& type,
80-
const T& value) {
81-
return std::make_shared<ConstantTypedExpr>(type, variant(value));
80+
const Variant& value) {
81+
return std::make_shared<ConstantTypedExpr>(type, value);
8282
}
8383

8484
std::shared_ptr<ConstantTypedExpr> createNullVariantExpr(
@@ -579,4 +579,38 @@ TEST_F(ConstantTypedExprTest, toStringComplexTypes) {
579579
<< "toString mismatch for OPAQUE variant vs vector";
580580
}
581581

582+
TEST_F(ConstantTypedExprTest, variantTypeCheck) {
583+
auto testVariantExpr = [&](const Variant& value,
584+
const TypePtr& type,
585+
const TypePtr& expectedType) {
586+
VELOX_ASSERT_THROW(
587+
createVariantExpr(type, value),
588+
fmt::format(
589+
"Expression type {} does not match variant type {}",
590+
type->toString(),
591+
expectedType->toString()));
592+
if (type->isPrimitiveType()) {
593+
VELOX_ASSERT_THROW(
594+
createVariantExpr(type, Variant::null(expectedType->kind())),
595+
fmt::format(
596+
"Expression type {} does not match variant type {}",
597+
type->toString(),
598+
expectedType->toString()));
599+
} else {
600+
ASSERT_NO_THROW(
601+
createVariantExpr(type, Variant::null(expectedType->kind())));
602+
}
603+
};
604+
605+
testVariantExpr("abc", INTEGER(), VARCHAR());
606+
testVariantExpr(variant(123LL), INTEGER(), BIGINT());
607+
testVariantExpr(2.0, BIGINT(), DOUBLE());
608+
testVariantExpr(
609+
variant::array({1, 2, 3}), ARRAY(VARCHAR()), ARRAY(INTEGER()));
610+
testVariantExpr(
611+
variant::map({{2.0, "xyz"}}),
612+
MAP(INTEGER(), VARCHAR()),
613+
MAP(DOUBLE(), VARCHAR()));
614+
}
615+
582616
} // namespace facebook::velox::core::test

velox/expression/tests/ExprRewriteRegistryTest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ TEST_F(ExprRewriteRegistryTest, basic) {
3131
};
3232
registry.registerRewrite(testRewrite);
3333

34-
auto input =
35-
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(123));
34+
auto input = std::make_shared<core::ConstantTypedExpr>(BIGINT(), 123LL);
3635
const auto rewritten = registry.rewrite(input);
3736
ASSERT_TRUE(rewritten->isCallKind());
3837
ASSERT_TRUE(rewritten->type()->isBigint());

velox/expression/tests/ExprTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,8 +2595,8 @@ TEST_P(ParameterizedExprTest, constantToString) {
25952595

25962596
TEST_F(ExprTest, constantEqualsNullConsistency) {
25972597
// Constant expr created using variant
2598-
auto nullVariantToExpr =
2599-
std::make_shared<core::ConstantTypedExpr>(VARCHAR(), Variant{});
2598+
auto nullVariantToExpr = std::make_shared<core::ConstantTypedExpr>(
2599+
VARCHAR(), variant::null(TypeKind::VARCHAR));
26002600
auto nonNullVariantToExpr =
26012601
std::make_shared<core::ConstantTypedExpr>(VARCHAR(), Variant{"test"});
26022602

0 commit comments

Comments
 (0)