diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/AbstractVarcharType.java b/presto-common/src/main/java/com/facebook/presto/common/type/AbstractVarcharType.java index 1e7a3a454f80f..56c40511eed49 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/AbstractVarcharType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/AbstractVarcharType.java @@ -18,9 +18,13 @@ import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.function.SqlFunctionProperties; import io.airlift.slice.Slice; +import io.airlift.slice.SliceUtf8; import io.airlift.slice.Slices; import java.util.Objects; +import java.util.Optional; + +import static java.lang.Character.MAX_CODE_POINT; public class AbstractVarcharType extends AbstractVariableWidthType @@ -106,6 +110,27 @@ public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int ri return leftBlock.compareTo(leftPosition, 0, leftLength, rightBlock, rightPosition, 0, rightLength); } + @Override + public Optional getRange() + { + if (length > 100) { + // The max/min values may be materialized in the plan, so we don't want them to be too large. + // Range comparison against large values are usually nonsensical, too, so no need to support them + // beyond a certain size. They specific choice above is arbitrary and can be adjusted if needed. + return Optional.empty(); + } + + int codePointSize = SliceUtf8.lengthOfCodePoint(MAX_CODE_POINT); + + Slice max = Slices.allocate(codePointSize * length); + int position = 0; + for (int i = 0; i < length; i++) { + position += SliceUtf8.setCodePointAt(MAX_CODE_POINT, max, position); + } + + return Optional.of(new Range(Slices.EMPTY_SLICE, max)); + } + @Override public void appendTo(Block block, int position, BlockBuilder blockBuilder) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/BigintType.java b/presto-common/src/main/java/com/facebook/presto/common/type/BigintType.java index 30101f86c5296..089079bfc1822 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/BigintType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/BigintType.java @@ -16,6 +16,8 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; public final class BigintType @@ -50,4 +52,10 @@ public int hashCode() { return getClass().hashCode(); } + + @Override + public Optional getRange() + { + return Optional.of(new Range(Long.MIN_VALUE, Long.MAX_VALUE)); + } } diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java b/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java index 877d14b916bba..c7b98e101aa76 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java @@ -21,6 +21,8 @@ import com.facebook.presto.common.block.UncheckedBlock; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.TypeUtils.doubleCompare; import static com.facebook.presto.common.type.TypeUtils.doubleEquals; @@ -170,4 +172,12 @@ public int hashCode() { return getClass().hashCode(); } + + @Override + public Optional getRange() + { + // The range for double is undefined because NaN is a special value that + // is *not* in any reasonable definition of a range for this type. + return Optional.empty(); + } } diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/IntegerType.java b/presto-common/src/main/java/com/facebook/presto/common/type/IntegerType.java index 41a5578c14cd5..98ae72604ab5e 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/IntegerType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/IntegerType.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static java.lang.String.format; @@ -54,6 +56,12 @@ else if (value < Integer.MIN_VALUE) { blockBuilder.writeInt((int) value).closeEntry(); } + @Override + public Optional getRange() + { + return Optional.of(new Range((long) Integer.MIN_VALUE, (long) Integer.MAX_VALUE)); + } + @Override @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") public boolean equals(Object other) diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java b/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java index 5bb530043b468..35203709e5ef4 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.TypeUtils.realCompare; import static com.facebook.presto.common.type.TypeUtils.realEquals; @@ -106,4 +108,12 @@ public int hashCode() { return getClass().hashCode(); } + + @Override + public Optional getRange() + { + // The range for real is undefined because NaN is a special value that + // is *not* in any reasonable definition of a range for this type. + return Optional.empty(); + } } diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/SmallintType.java b/presto-common/src/main/java/com/facebook/presto/common/type/SmallintType.java index 5cc1f41ee17c6..74f31e5618c52 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/SmallintType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/SmallintType.java @@ -22,6 +22,8 @@ import com.facebook.presto.common.block.UncheckedBlock; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static java.lang.Long.rotateLeft; import static java.lang.String.format; @@ -116,6 +118,12 @@ public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int ri return Short.compare(leftValue, rightValue); } + @Override + public Optional getRange() + { + return Optional.of(new Range((long) Short.MIN_VALUE, (long) Short.MAX_VALUE)); + } + @Override public void appendTo(Block block, int position, BlockBuilder blockBuilder) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java b/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java index b151f9291e61f..315b72898df36 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java @@ -22,6 +22,8 @@ import com.facebook.presto.common.block.UncheckedBlock; import com.facebook.presto.common.function.SqlFunctionProperties; +import java.util.Optional; + import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static java.lang.Long.rotateLeft; import static java.lang.String.format; @@ -116,6 +118,12 @@ public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int ri return Byte.compare(leftValue, rightValue); } + @Override + public Optional getRange() + { + return Optional.of(new Range((long) Byte.MIN_VALUE, (long) Byte.MAX_VALUE)); + } + @Override public void appendTo(Block block, int position, BlockBuilder blockBuilder) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/Type.java b/presto-common/src/main/java/com/facebook/presto/common/type/Type.java index 251b3b11c055d..c9d4a944094ca 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/Type.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/Type.java @@ -22,6 +22,9 @@ import io.airlift.slice.Slice; import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; public interface Type { @@ -194,4 +197,36 @@ default boolean equalValuesAreIdentical() * Compare the values in the specified block at the specified positions equal. */ int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition); + + /** + * Return the range of possible values for this type, if available. + * + * The type of the values must match {@link #getJavaType} + */ + default Optional getRange() + { + return Optional.empty(); + } + + final class Range + { + private final Object min; + private final Object max; + + public Range(Object min, Object max) + { + this.min = requireNonNull(min, "min is null"); + this.max = requireNonNull(max, "max is null"); + } + + public Object getMin() + { + return min; + } + + public Object getMax() + { + return max; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 25c69d0a47edb..f2858c86eae65 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -346,6 +346,7 @@ public final class SystemSessionProperties public static final String UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING = "utilize_unique_property_in_query_planning"; public static final String PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS = "pushdown_subfields_for_map_functions"; public static final String MAX_SERIALIZABLE_OBJECT_SIZE = "max_serializable_object_size"; + public static final String UNWRAP_CASTS = "unwrap_casts"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; @@ -1993,6 +1994,10 @@ public SystemSessionProperties( booleanProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "Add distinct aggregation below semi join build", featuresConfig.isAddDistinctBelowSemiJoinBuild(), + false), + booleanProperty(UNWRAP_CASTS, + "Enable optimization to unwrap CAST expression", + featuresConfig.isUnwrapCasts(), false)); } @@ -3394,4 +3399,9 @@ public static long getMaxSerializableObjectSize(Session session) { return session.getSystemProperty(MAX_SERIALIZABLE_OBJECT_SIZE, Long.class); } + + public static boolean isUnwrapCasts(Session session) + { + return session.getSystemProperty(UNWRAP_CASTS, Boolean.class); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 7214485a7d7a6..73d54bc24d51c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -321,6 +321,8 @@ public class FeaturesConfig private boolean builtInSidecarFunctionsEnabled; + private boolean unwrapCasts = true; + public enum PartitioningPrecisionStrategy { // Let Presto decide when to repartition @@ -3194,4 +3196,16 @@ public boolean isBuiltInSidecarFunctionsEnabled() { return this.builtInSidecarFunctionsEnabled; } + + public boolean isUnwrapCasts() + { + return unwrapCasts; + } + + @Config("optimizer.unwrap-casts") + public FeaturesConfig setUnwrapCasts(boolean unwrapCasts) + { + this.unwrapCasts = unwrapCasts; + return this; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 2d7c8be053645..e0bbc4e7e4f58 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -154,6 +154,7 @@ import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin; +import com.facebook.presto.sql.planner.iterative.rule.UnwrapCastInComparison; import com.facebook.presto.sql.planner.optimizations.AddExchanges; import com.facebook.presto.sql.planner.optimizations.AddExchangesForSingleNodeExecution; import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; @@ -340,6 +341,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .addAll(new SimplifyRowExpressions(metadata, expressionOptimizerManager).rules()) + .addAll(new UnwrapCastInComparison(metadata, expressionOptimizerManager).rules()) .add(new PruneRedundantProjectionAssignments()) .build()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/UnwrapCastInComparison.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/UnwrapCastInComparison.java new file mode 100644 index 0000000000000..1eb9a4e5ea792 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -0,0 +1,441 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.Utils; +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.expressions.LogicalRowExpressions; +import com.facebook.presto.expressions.RowExpressionRewriter; +import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.CastType; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.OperatorNotFoundException; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.InterpretedFunctionInvoker; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.planner.LiteralEncoder; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; +import com.facebook.presto.sql.relational.FunctionResolution; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isUnwrapCasts; +import static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.expressions.LogicalRowExpressions.and; +import static com.facebook.presto.expressions.LogicalRowExpressions.or; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.comparisonExpression; +import static com.facebook.presto.sql.relational.Expressions.constantNull; +import static com.facebook.presto.sql.relational.Expressions.isComparison; +import static com.facebook.presto.sql.relational.Expressions.specialForm; +import static java.util.Objects.requireNonNull; + +/** + * Given s of type S, a constant expression t of type T, and when an implicit + * cast exists between S->T, converts expression of the form: + * + *
+ * CAST(s as T) = t
+ * 
+ * + * into + * + *
+ * s = CAST(t as S)
+ * 
+ * + * For example: + * + *
+ * CAST(x AS bigint) = bigint '1'
+ *
+ * + * turns into + * + *
+ * x = smallint '1'
+ * 
+ * + * It can simplify expressions that are known to be true or false, and + * remove the comparisons altogether. For example, give x::smallint, + * for an expression like: + * + *
+ * CAST(x AS bigint) > bigint '10000000'
+ *
+ */ +public class UnwrapCastInComparison + extends RowExpressionRewriteRuleSet +{ + public UnwrapCastInComparison(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) + { + super(createRewriter(metadata, expressionOptimizerManager)); + } + + private static PlanRowExpressionRewriter createRewriter(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) + { + requireNonNull(metadata, "metadata is null"); + requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + return (expression, context) -> UnWrapCastInComparisonRewriter.rewrite( + expression, + context.getSession(), + metadata, + expressionOptimizerManager); + } + + public static class UnWrapCastInComparisonRewriter + { + private UnWrapCastInComparisonRewriter() {} + + public static RowExpression rewrite(RowExpression expression, Session session, Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) + { + if (isUnwrapCasts(session)) { + RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(new Visitor(session, metadata), expression); + return expressionOptimizerManager.getExpressionOptimizer(session.toConnectorSession()).optimize(rewritten, SERIALIZABLE, session.toConnectorSession()); + } + return null; + } + + private static class Visitor + extends RowExpressionRewriter + { + private final Session session; + private final FunctionAndTypeManager functionAndTypeManager; + private final FunctionResolution functionResolution; + private final InterpretedFunctionInvoker functionInvoker; + + public Visitor(Session session, Metadata metadata) + { + this.session = requireNonNull(session, "session is null"); + this.functionAndTypeManager = metadata.getFunctionAndTypeManager(); + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + this.functionInvoker = new InterpretedFunctionInvoker(functionAndTypeManager); + } + + @Override + public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) + { + if (!isComparison(node)) { + return null; + } + RowExpression expression = treeRewriter.defaultRewrite(node, context); + return unwrapCast(expression); + } + + private RowExpression unwrapCast(RowExpression rowExpression) + { + if (!(rowExpression instanceof CallExpression)) { + return null; + } + + CallExpression callExpression = (CallExpression) rowExpression; + if (!isComparison(callExpression) || callExpression.getArguments().size() != 2) { + return null; + } + + OperatorType operatorType = functionAndTypeManager.getFunctionMetadata(callExpression.getFunctionHandle()).getOperatorType().orElse(null); + if (operatorType == null) { + return null; + } + + RowExpression leftExpression = callExpression.getArguments().get(0); + RowExpression rightExpression = callExpression.getArguments().get(1); + + // Canonicalize Expression + if (leftExpression instanceof ConstantExpression && !(rightExpression instanceof ConstantExpression)) { + leftExpression = rightExpression; + rightExpression = callExpression.getArguments().get(0); + operatorType = OperatorType.flip(operatorType); + } + + if (!(leftExpression instanceof CallExpression)) { + return null; + } + + if (!functionResolution.isCastFunction(((CallExpression) leftExpression).getFunctionHandle())) { + return null; + } + + Object right = new RowExpressionInterpreter(rightExpression, functionAndTypeManager, session.toConnectorSession(), Level.OPTIMIZED).optimize(); + if (right == null || (right instanceof ConstantExpression && ((ConstantExpression) right).isNull())) { + switch (operatorType) { + case EQUAL: + case NOT_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return constantNull(BOOLEAN); + case IS_DISTINCT_FROM: + return call("NOT", functionResolution.notFunction(), BOOLEAN, specialForm(IS_NULL, BOOLEAN, leftExpression)); + default: + throw new UnsupportedOperationException("Not yet implemented"); + } + } + + RowExpression castArgument = ((CallExpression) leftExpression).getArguments().get(0); + + if (right instanceof RowExpression) { + return null; + } + + Type sourceType = castArgument.getType(); + Type targetType = rightExpression.getType(); + + if (!hasInjectiveImplicitCoercion(sourceType, targetType)) { + return null; + } + + FunctionHandle sourceToTarget = functionAndTypeManager.lookupCast(CastType.CAST, sourceType, targetType); + Optional sourceRange = sourceType.getRange(); + if (sourceRange.isPresent()) { + Object max = sourceRange.get().getMax(); + Object maxInTargetType = coerce(max, sourceToTarget); + + int upperBoundComparison = compare(targetType, right, maxInTargetType); + if (upperBoundComparison > 0) { + // larger than maximum representable value + switch (operatorType) { + case EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return falseIfNotNull(castArgument); + case NOT_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return trueIfNotNull(castArgument); + case IS_DISTINCT_FROM: + return LogicalRowExpressions.TRUE_CONSTANT; + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + + if (upperBoundComparison == 0) { + // equal to max representable value + switch (operatorType) { + case GREATER_THAN: + return falseIfNotNull(castArgument); + case GREATER_THAN_OR_EQUAL: + return comparisonExpression(functionResolution, EQUAL, castArgument, LiteralEncoder.toRowExpression(max, sourceType)); + case LESS_THAN_OR_EQUAL: + return trueIfNotNull(castArgument); + case LESS_THAN: + return comparisonExpression(functionResolution, NOT_EQUAL, castArgument, LiteralEncoder.toRowExpression(max, sourceType)); + case EQUAL: + case NOT_EQUAL: + case IS_DISTINCT_FROM: + return comparisonExpression(functionResolution, operatorType, castArgument, LiteralEncoder.toRowExpression(max, sourceType)); + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + + Object min = sourceRange.get().getMin(); + Object minInTargetType = coerce(min, sourceToTarget); + + int lowerBoundComparison = compare(targetType, right, minInTargetType); + if (lowerBoundComparison < 0) { + // smaller than minimum representable value + switch (operatorType) { + case NOT_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return trueIfNotNull(castArgument); + case EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return falseIfNotNull(castArgument); + case IS_DISTINCT_FROM: + return LogicalRowExpressions.TRUE_CONSTANT; + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + + if (lowerBoundComparison == 0) { + // equal to min representable value + switch (operatorType) { + case LESS_THAN: + return falseIfNotNull(castArgument); + case LESS_THAN_OR_EQUAL: + return comparisonExpression(functionResolution, EQUAL, castArgument, LiteralEncoder.toRowExpression(min, sourceType)); + case GREATER_THAN_OR_EQUAL: + return trueIfNotNull(castArgument); + case GREATER_THAN: + return comparisonExpression(functionResolution, NOT_EQUAL, castArgument, LiteralEncoder.toRowExpression(min, sourceType)); + case EQUAL: + case NOT_EQUAL: + case IS_DISTINCT_FROM: + return comparisonExpression(functionResolution, operatorType, castArgument, LiteralEncoder.toRowExpression(min, sourceType)); + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + } + + FunctionHandle targetToSource; + try { + targetToSource = functionAndTypeManager.lookupCast(CastType.CAST, targetType, sourceType); + } + catch (OperatorNotFoundException e) { + // Without a cast between target -> source, there's nothing more we can do + return null; + } + + Object literalInSourceType; + try { + literalInSourceType = coerce(right, targetToSource); + } + catch (PrestoException e) { + // A failure to cast from target -> source type could be because: + // 1. missing cast + // 2. bad implementation + // 3. out of range or otherwise unrepresentable value + // Since we can't distinguish between those cases, take the conservative option + // and bail out. + return null; + } + + Object roundtripLiteral = coerce(literalInSourceType, sourceToTarget); + + int literalVsRoundtripped = compare(targetType, right, roundtripLiteral); + + if (literalVsRoundtripped > 0) { + // cast rounded down + switch (operatorType) { + case EQUAL: + return falseIfNotNull(castArgument); + case NOT_EQUAL: + return trueIfNotNull(castArgument); + case IS_DISTINCT_FROM: + return LogicalRowExpressions.TRUE_CONSTANT; + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMin(), literalInSourceType) == 0) { + return comparisonExpression(functionResolution, EQUAL, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + } + return comparisonExpression(functionResolution, LESS_THAN_OR_EQUAL, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value + // larger than the next value in the source type + return comparisonExpression(functionResolution, GREATER_THAN, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + + if (literalVsRoundtripped < 0) { + // cast rounded up + switch (operatorType) { + case EQUAL: + return falseIfNotNull(castArgument); + case NOT_EQUAL: + return trueIfNotNull(castArgument); + case IS_DISTINCT_FROM: + return LogicalRowExpressions.TRUE_CONSTANT; + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value + // smaller than the next value in the source type + return comparisonExpression(functionResolution, LESS_THAN, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMax(), literalInSourceType) == 0) { + return comparisonExpression(functionResolution, EQUAL, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + } + return comparisonExpression(functionResolution, GREATER_THAN_OR_EQUAL, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + default: + throw new UnsupportedOperationException("Not yet implemented: " + operatorType); + } + } + return comparisonExpression(functionResolution, operatorType, castArgument, LiteralEncoder.toRowExpression(literalInSourceType, sourceType)); + } + + private boolean hasInjectiveImplicitCoercion(Type source, Type target) + { + if ((source.equals(BIGINT) && target.equals(DOUBLE)) || + (source.equals(BIGINT) && target.equals(REAL)) || + (source.equals(INTEGER) && target.equals(REAL))) { + // Not every BIGINT fits in DOUBLE/REAL due to 64 bit vs 53-bit/23-bit mantissa. Similarly, + // not every INTEGER fits in a REAL (32-bit vs 23-bit mantissa) + return false; + } + + if (source instanceof DecimalType) { + int precision = ((DecimalType) source).getPrecision(); + + if (precision > 15 && target.equals(DOUBLE)) { + // decimal(p,s) with p > 15 doesn't fit in a double without loss + return false; + } + + if (precision > 7 && target.equals(REAL)) { + // decimal(p,s) with p > 7 doesn't fit in a double without loss + return false; + } + } + + // Well-behaved implicit casts are injective + return functionResolution.canCoerce(source, target); + } + + private Object coerce(Object value, FunctionHandle coercion) + { + return functionInvoker.invoke(coercion, session.toConnectorSession().getSqlFunctionProperties(), value); + } + + private int compare(Type type, Object first, Object second) + { + return type.compareTo( + Utils.nativeValueToBlock(type, first), + 0, + Utils.nativeValueToBlock(type, second), + 0); + } + + private RowExpression falseIfNotNull(RowExpression argument) + { + return and(specialForm(IS_NULL, BOOLEAN, argument), constantNull(BOOLEAN)); + } + + private RowExpression trueIfNotNull(RowExpression argument) + { + return or(call("NOT", functionResolution.notFunction(), BOOLEAN, specialForm(IS_NULL, BOOLEAN, argument)), constantNull(BOOLEAN)); + } + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index 3aa2a3b544ca7..58cd97407a5d2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -438,4 +438,9 @@ public FunctionHandle lookupBuiltInFunction(String functionName, List inpu { return functionAndTypeResolver.lookupFunction(functionName, fromTypes(inputTypes)); } + + public boolean canCoerce(Type actualType, Type expectedType) + { + return functionAndTypeResolver.canCoerce(actualType, expectedType); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 9f55c3b2eab19..9ca4349276c5a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -271,7 +271,8 @@ public void testDefaults() .setInEqualityJoinPushdownEnabled(false) .setRewriteMinMaxByToTopNEnabled(false) .setPrestoSparkExecutionEnvironment(false) - .setMaxSerializableObjectSize(1000)); + .setMaxSerializableObjectSize(1000) + .setUnwrapCasts(true)); } @Test @@ -490,6 +491,7 @@ public void testExplicitPropertyMappings() .put("optimizer.utilize-unique-property-in-query-planning", "false") .put("optimizer.add-exchange-below-partial-aggregation-over-group-id", "true") .put("max_serializable_object_size", "50") + .put("optimizer.unwrap-casts", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -706,7 +708,8 @@ public void testExplicitPropertyMappings() .setRewriteMinMaxByToTopNEnabled(true) .setInnerJoinPushdownEnabled(true) .setPrestoSparkExecutionEnvironment(true) - .setMaxSerializableObjectSize(50); + .setMaxSerializableObjectSize(50) + .setUnwrapCasts(false); assertFullMapping(properties, expected); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestUnwrapCastInComparison.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestUnwrapCastInComparison.java new file mode 100644 index 0000000000000..386fd77f05c30 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestUnwrapCastInComparison.java @@ -0,0 +1,725 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import org.testng.annotations.Test; + +import java.util.Arrays; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.lang.String.format; + +public class TestUnwrapCastInComparison + extends BasePlanTest +{ + @Test + public void testEquals() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '1'", + anyTree( + filter("A = SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '1.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '1.9'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '32766'", + anyTree( + filter("A = SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '32766.9'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '32767'", + anyTree( + filter("A = SMALLINT '32767'", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '-32767'", + anyTree( + filter("A = SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '-32767.9'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '-32768'", + anyTree( + filter("A = SMALLINT '-32768'", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = DOUBLE '-32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + } + + @Test + public void testNotEquals() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '1'", + anyTree( + filter("A <> SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '1.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '1.9'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '32766'", + anyTree( + filter("A <> SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '32766.9'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '32767'", + anyTree( + filter("A <> SMALLINT '32767'", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '-32767'", + anyTree( + filter("A <> SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '-32767.9'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '-32768'", + anyTree( + filter("A <> SMALLINT '-32768'", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> DOUBLE '-32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + } + + @Test + public void testLessThan() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '1'", + anyTree( + filter("A < SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '1.1'", + anyTree( + filter("A <= SMALLINT '1'", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '1.9'", + anyTree( + filter("A < SMALLINT '2'", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '32766'", + anyTree( + filter("A < SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '32766.9'", + anyTree( + filter("A < SMALLINT '32767'", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '32767'", + anyTree( + filter("A <> SMALLINT '32767'", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '-32767'", + anyTree( + filter("A < SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '-32767.9'", + anyTree( + filter("A = SMALLINT '-32768'", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '-32768'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < DOUBLE '-32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + } + + @Test + public void testLessThanOrEqual() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '1'", + anyTree( + filter("A <= SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '1.1'", + anyTree( + filter("A <= SMALLINT '1'", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '1.9'", + anyTree( + filter("A < SMALLINT '2'", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '32766'", + anyTree( + filter("A <= SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '32766.9'", + anyTree( + filter("A < SMALLINT '32767'", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '32767'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '-32767'", + anyTree( + filter("A <= SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '-32767.9'", + anyTree( + filter("A = SMALLINT '-32768'", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '-32768'", + anyTree( + filter("A = SMALLINT '-32768'", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= DOUBLE '-32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + } + + @Test + public void testGreaterThan() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '1'", + anyTree( + filter("A > SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '1.1'", + anyTree( + filter("A > SMALLINT '1'", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '1.9'", + anyTree( + filter("A >= SMALLINT '2'", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '32766'", + anyTree( + filter("A > SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '32766.9'", + anyTree( + filter("A = SMALLINT '32767'", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '32767'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '-32767'", + anyTree( + filter("A > SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '-32767.9'", + anyTree( + filter("A > SMALLINT '-32768'", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '-32768'", + anyTree( + filter("A <> SMALLINT '-32768'", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > DOUBLE '-32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + } + + @Test + public void testGreaterThanOrEqual() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '1'", + anyTree( + filter("A >= SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '1.1'", + anyTree( + filter("A > SMALLINT '1'", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '1.9'", + anyTree( + filter("A >= SMALLINT '2'", + values("A")))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '32766'", + anyTree( + filter("A >= SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '32766.9'", + anyTree( + filter("A = SMALLINT '32767'", + values("A")))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '32767'", + anyTree( + filter("A = SMALLINT '32767'", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '32768.1'", + anyTree( + filter("A IS NULL AND NULL", + values("A")))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '-32767'", + anyTree( + filter("A >= SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '-32767.9'", + anyTree( + filter("A > SMALLINT '-32768' ", + values("A")))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '-32768'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= DOUBLE '-32768.1'", + anyTree( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + } + + @Test + public void testDistinctFrom() + { + // representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '1'", + anyTree( + filter("A IS DISTINCT FROM SMALLINT '1'", + values("A")))); + + // non-representable + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '1.1'", + output( + values("A"))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '1.9'", + output( + values("A"))); + + // below top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '32766'", + anyTree( + filter("A IS DISTINCT FROM SMALLINT '32766'", + values("A")))); + + // round to top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '32766.9'", + output( + values("A"))); + + // top of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '32767'", + anyTree( + filter("A IS DISTINCT FROM SMALLINT '32767'", + values("A")))); + + // above range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '32768.1'", + output( + values("A"))); + + // above bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '-32767'", + anyTree( + filter("A IS DISTINCT FROM SMALLINT '-32767'", + values("A")))); + + // round to bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '-32767.9'", + output( + values("A"))); + + // bottom of range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '-32768'", + anyTree( + filter("A IS DISTINCT FROM SMALLINT '-32768'", + values("A")))); + + // below range + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM DOUBLE '-32768.1'", + output( + values("A"))); + } + + @Test + public void testNull() + { + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a > CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a >= CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <= CAST(NULL AS DOUBLE)", + output( + filter("NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM CAST(NULL AS DOUBLE)", + output( + filter("NOT (CAST(A AS DOUBLE) IS NULL)", + values("A")))); + } + + @Test + public void smokeTests() + { + // smoke tests for various type combinations + for (String type : Arrays.asList("SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE")) { + assertPlan( + format("SELECT * FROM (VALUES TINYINT '1') t(a) WHERE a = %s '1'", type), + anyTree( + filter("A = TINYINT '1'", + values("A")))); + } + + for (String type : Arrays.asList("INTEGER", "BIGINT", "REAL", "DOUBLE")) { + assertPlan( + format("SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = %s '1'", type), + anyTree( + filter("A = SMALLINT '1'", + values("A")))); + } + + for (String type : Arrays.asList("BIGINT", "DOUBLE")) { + assertPlan( + format("SELECT * FROM (VALUES INTEGER '1') t(a) WHERE a = %s '1'", type), + anyTree( + filter("A = 1", + values("A")))); + } + + assertPlan("SELECT * FROM (VALUES REAL '1') t(a) WHERE a = DOUBLE '1'", + anyTree( + filter("A = REAL '1.0'", + values("A")))); + } + + @Test + public void testTermOrder() + { + // ensure the optimization works when the terms of the comparison are reversed + // vs the canonical form + assertPlan("SELECT * FROM (VALUES REAL '1') t(a) WHERE DOUBLE '1' = a", + anyTree( + filter("A = REAL '1.0'", + values("A")))); + } + + @Test + public void testNoEffect() + { + // BIGINT->DOUBLE implicit cast is not injective + assertPlan( + "SELECT * FROM (VALUES BIGINT '1') t(a) WHERE a = DOUBLE '1'", + anyTree( + filter("CAST(A AS DOUBLE) = 1e0", + values("A")))); + + // BIGINT->REAL implicit cast is not injective + assertPlan( + "SELECT * FROM (VALUES BIGINT '1') t(a) WHERE a = REAL '1'", + anyTree( + filter("CAST(A AS REAL) = REAL '1.0'", + values("A")))); + + // INTEGER->REAL implicit cast is not injective + assertPlan( + "SELECT * FROM (VALUES INTEGER '1') t(a) WHERE a = REAL '1'", + anyTree( + filter("CAST(A AS REAL) = REAL '1.0'", + values("A")))); + + // DECIMAL(p)->DOUBLE not injective for p > 15 + assertPlan( + "SELECT * FROM (VALUES CAST('1' AS DECIMAL(16))) t(a) WHERE a = DOUBLE '1'", + anyTree( + filter("CAST(A AS DOUBLE) = 1E0", + values("A")))); + + // DECIMAL(p)->REAL not injective for p > 7 + assertPlan( + "SELECT * FROM (VALUES CAST('1' AS DECIMAL(8))) t(a) WHERE a = REAL '1'", + anyTree( + filter("CAST(A AS REAL) = REAL '1.0'", + values("A")))); + + // no implicit cast between VARCHAR->INTEGER + assertPlan( + "SELECT * FROM (VALUES VARCHAR '1') t(a) WHERE CAST(a AS INTEGER) = INTEGER '1'", + anyTree( + filter("CAST(A AS INTEGER) = 1", + values("A")))); + + // no implicit cast between DOUBLE->INTEGER + assertPlan( + "SELECT * FROM (VALUES DOUBLE '1') t(a) WHERE CAST(a AS INTEGER) = INTEGER '1'", + anyTree( + filter("CAST(A AS INTEGER) = 1", + values("A")))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java index ec59272a234fd..5c03200640ab7 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java @@ -163,4 +163,11 @@ protected void executeExclusively(Runnable executionBlock) runner.getExclusiveLock().unlock(); } } + + public MaterializedResult execute(@Language("SQL") String query) + { + MaterializedResult actualResults; + actualResults = runner.execute(runner.getDefaultSession(), query).toTestTypes(); + return actualResults; + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestUnwrapCastInComparison.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestUnwrapCastInComparison.java new file mode 100644 index 0000000000000..272143acebba3 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestUnwrapCastInComparison.java @@ -0,0 +1,220 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.query; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestUnwrapCastInComparison +{ + private QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testTinyint() + { + for (Number from : Arrays.asList(null, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE)) { + String fromType = "TINYINT"; + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (Number to : Arrays.asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validate(operator, fromType, from, "SMALLINT", to); + } + + for (Number to : Arrays.asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validate(operator, fromType, from, "INTEGER", to); + } + + for (Number to : Arrays.asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validate(operator, fromType, from, "BIGINT", to); + } + + for (Number to : Arrays.asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validate(operator, fromType, from, "REAL", to); + } + + for (Number to : Arrays.asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validate(operator, fromType, from, "DOUBLE", to); + } + } + } + } + + @Test + public void testSmallint() + { + for (Number from : Arrays.asList(null, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE)) { + String fromType = "SMALLINT"; + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (Number to : Arrays.asList(null, Short.MIN_VALUE - 1, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE, Short.MAX_VALUE + 1)) { + validate(operator, fromType, from, "INTEGER", to); + } + + for (Number to : Arrays.asList(null, Short.MIN_VALUE - 1, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE, Short.MAX_VALUE + 1)) { + validate(operator, fromType, from, "BIGINT", to); + } + + for (Number to : Arrays.asList(null, Short.MIN_VALUE - 1, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE, Short.MAX_VALUE + 1)) { + validate(operator, fromType, from, "REAL", to); + } + + for (Number to : Arrays.asList(null, Short.MIN_VALUE - 1, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE, Short.MAX_VALUE + 1)) { + validate(operator, fromType, from, "DOUBLE", to); + } + } + } + } + + @Test + public void testInteger() + { + for (Number from : Arrays.asList(null, Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE)) { + String fromType = "INTEGER"; + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (Number to : Arrays.asList(null, Integer.MIN_VALUE - 1L, Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE, Integer.MAX_VALUE + 1L)) { + validate(operator, fromType, from, "BIGINT", to); + } + + for (Number to : Arrays.asList(null, Integer.MIN_VALUE - 1L, Integer.MIN_VALUE, 0, 0.1, 0.9, 1, Integer.MAX_VALUE, Integer.MAX_VALUE + 1L)) { + validate(operator, fromType, from, "DOUBLE", to); + } + } + } + } + + @Test + public void testReal() + { + String fromType = "REAL"; + String toType = "DOUBLE"; + + for (String from : toLiteral(fromType, Arrays.asList(null, Float.NEGATIVE_INFINITY, -Float.MAX_VALUE, 0, 0.1, 0.9, 1, Float.MAX_VALUE, Float.POSITIVE_INFINITY, Float.NaN))) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : toLiteral(toType, Arrays.asList(null, Double.NEGATIVE_INFINITY, Math.nextDown((double) -Float.MIN_VALUE), (double) -Float.MIN_VALUE, 0, 0.1, 0.9, 1, (double) Float.MAX_VALUE, Math.nextUp((double) Float.MAX_VALUE), Double.POSITIVE_INFINITY, Double.NaN))) { + validate(operator, fromType, from, toType, to); + } + } + } + } + + @Test + public void testDecimal() + { + // decimal(15) -> double + List values = ImmutableList.of("-999999999999999", "999999999999999"); + for (String from : values) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : values) { + validate(operator, "DECIMAL(15, 0)", from, "DOUBLE", Double.valueOf(to)); + } + } + } + + // decimal(16) -> double + values = ImmutableList.of("-9999999999999999", "9999999999999999"); + for (String from : values) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : values) { + validate(operator, "DECIMAL(16, 0)", from, "DOUBLE", Double.valueOf(to)); + } + } + } + + // decimal(7) -> real + values = ImmutableList.of("-999999", "999999"); + for (String from : values) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : values) { + validate(operator, "DECIMAL(7, 0)", from, "REAL", Double.valueOf(to)); + } + } + } + + // decimal(8) -> real + values = ImmutableList.of("-9999999", "9999999"); + for (String from : values) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : values) { + validate(operator, "DECIMAL(8, 0)", from, "REAL", Double.valueOf(to)); + } + } + } + } + + @Test + public void testVarchar() + { + for (String from : Arrays.asList(null, "''", "'a'", "'b'")) { + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : Arrays.asList(null, "''", "'a'", "'aa'", "'b'", "'bb'")) { + validate(operator, "VARCHAR(1)", from, "VARCHAR(2)", to); + } + } + } + + // type with no range + for (String operator : Arrays.asList("=", "<>", ">=", ">", "<=", "<", "IS DISTINCT FROM")) { + for (String to : Arrays.asList("'" + Strings.repeat("a", 200) + "'", "'" + Strings.repeat("b", 200) + "'")) { + validate(operator, "VARCHAR(200)", "'" + Strings.repeat("a", 200) + "'", "VARCHAR(300)", to); + } + } + } + + private void validate(String operator, String fromType, Object fromValue, String toType, Object toValue) + { + String query = format( + "SELECT (CAST(v AS %s) %s CAST(%s AS %s)) " + + "IS NOT DISTINCT FROM " + + "(CAST(%s AS %s) %s CAST(%s AS %s)) " + + "FROM (VALUES CAST(%s AS %s)) t(v)", + toType, operator, toValue, toType, + fromValue, toType, operator, toValue, toType, + fromValue, fromType); + + boolean result = (boolean) assertions.execute(query) + .getMaterializedRows() + .get(0) + .getField(0); + + assertTrue(result, "Query evaluated to false: " + query); + } + + private static List toLiteral(String type, List values) + { + return values.stream() + .map(value -> value == null ? "NULL" : type + "'" + value + "'") + .collect(toImmutableList()); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/type/TestVarcharType.java b/presto-main-base/src/test/java/com/facebook/presto/type/TestVarcharType.java index 9a102d9786458..f5a814c5dd619 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/TestVarcharType.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/TestVarcharType.java @@ -15,10 +15,15 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import org.testng.annotations.Test; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static org.testng.Assert.assertEquals; public class TestVarcharType extends AbstractTestType @@ -50,4 +55,23 @@ protected Object getGreaterValue(Object value) { return Slices.utf8Slice(((Slice) value).toStringUtf8() + "_"); } + + @Test + public void testRange() + { + VarcharType type = createVarcharType(5); + + Type.Range range = type.getRange().get(); + + String expectedMax = new StringBuilder() + .appendCodePoint(Character.MAX_CODE_POINT) + .appendCodePoint(Character.MAX_CODE_POINT) + .appendCodePoint(Character.MAX_CODE_POINT) + .appendCodePoint(Character.MAX_CODE_POINT) + .appendCodePoint(Character.MAX_CODE_POINT) + .toString(); + + assertEquals(Slices.utf8Slice(""), range.getMin()); + assertEquals(Slices.utf8Slice(expectedMax), range.getMax()); + } }