Skip to content

Commit d908948

Browse files
committed
fix: Ensure ConstantTypedExpr and variant types are equivalent
1 parent 95bd54a commit d908948

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed

velox/core/Expressions.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ class ConstantTypedExpr : public ITypedExpr {
6161
// Variant::null() value is supported.
6262
ConstantTypedExpr(TypePtr type, Variant value)
6363
: ITypedExpr{ExprKind::kConstant, std::move(type)},
64-
value_{std::move(value)} {}
64+
value_{std::move(value)} {
65+
VELOX_CHECK(
66+
value_.isTypeCompatible(ITypedExpr::type()),
67+
"Expression type {} does not match variant type {}",
68+
ITypedExpr::type()->toString(),
69+
value_.inferType()->toString());
70+
}
6571

6672
// Creates constant expression of scalar or complex type. The value comes from
6773
// index zero.

velox/core/tests/ConstantTypedExprTest.cpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
#include <gtest/gtest.h>
1717

18+
#include "velox/common/base/tests/GTestUtils.h"
1819
#include "velox/common/memory/Memory.h"
1920
#include "velox/core/Expressions.h"
2021
#include "velox/functions/prestosql/types/HyperLogLogType.h"
@@ -74,11 +75,10 @@ class ConstantTypedExprTest : public ::testing::Test,
7475
}
7576

7677
// Helper functions
77-
template <typename T>
7878
std::shared_ptr<ConstantTypedExpr> createVariantExpr(
7979
const TypePtr& type,
80-
const T& value) {
81-
return std::make_shared<ConstantTypedExpr>(type, variant(value));
80+
const Variant& value) {
81+
return std::make_shared<ConstantTypedExpr>(type, value);
8282
}
8383

8484
std::shared_ptr<ConstantTypedExpr> createNullVariantExpr(
@@ -579,4 +579,34 @@ TEST_F(ConstantTypedExprTest, toStringComplexTypes) {
579579
<< "toString mismatch for OPAQUE variant vs vector";
580580
}
581581

582+
TEST_F(ConstantTypedExprTest, variantTypeCheck) {
583+
auto testVariantExpr =
584+
[&](Variant value, const TypePtr type, const TypePtr expectedType) {
585+
VELOX_ASSERT_THROW(
586+
createVariantExpr(type, value),
587+
fmt::format(
588+
"Expression type {} does not match variant type {}",
589+
type->toString(),
590+
expectedType->toString()));
591+
if (type->isPrimitiveType()) {
592+
VELOX_ASSERT_THROW(
593+
createVariantExpr(type, Variant::null(expectedType->kind())),
594+
fmt::format(
595+
"Expression type {} does not match variant type {}",
596+
type->toString(),
597+
expectedType->toString()));
598+
}
599+
};
600+
601+
testVariantExpr("abc", INTEGER(), VARCHAR());
602+
testVariantExpr(variant(123LL), INTEGER(), BIGINT());
603+
testVariantExpr(2.0, BIGINT(), DOUBLE());
604+
testVariantExpr(
605+
variant::array({1, 2, 3}), ARRAY(VARCHAR()), ARRAY(INTEGER()));
606+
testVariantExpr(
607+
variant::map({{2.0, "xyz"}}),
608+
MAP(INTEGER(), VARCHAR()),
609+
MAP(DOUBLE(), VARCHAR()));
610+
}
611+
582612
} // namespace facebook::velox::core::test

velox/expression/tests/ExprRewriteRegistryTest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ TEST_F(ExprRewriteRegistryTest, basic) {
3131
};
3232
registry.registerRewrite(testRewrite);
3333

34-
auto input =
35-
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(123));
34+
auto input = std::make_shared<core::ConstantTypedExpr>(BIGINT(), 123LL);
3635
const auto rewritten = registry.rewrite(input);
3736
ASSERT_TRUE(rewritten->isCallKind());
3837
ASSERT_TRUE(rewritten->type()->isBigint());

velox/expression/tests/ExprTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,8 +2595,8 @@ TEST_P(ParameterizedExprTest, constantToString) {
25952595

25962596
TEST_F(ExprTest, constantEqualsNullConsistency) {
25972597
// Constant expr created using variant
2598-
auto nullVariantToExpr =
2599-
std::make_shared<core::ConstantTypedExpr>(VARCHAR(), Variant{});
2598+
auto nullVariantToExpr = std::make_shared<core::ConstantTypedExpr>(
2599+
VARCHAR(), variant::null(TypeKind::VARCHAR));
26002600
auto nonNullVariantToExpr =
26012601
std::make_shared<core::ConstantTypedExpr>(VARCHAR(), Variant{"test"});
26022602

0 commit comments

Comments
 (0)