diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java index d4a04f7efd58e..da55668feced3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.ExpressionInterpreter; @@ -30,18 +31,18 @@ import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.Collection; @@ -58,6 +59,7 @@ import static com.facebook.presto.sql.tree.ComparisonExpression.Type.IS_DISTINCT_FROM; import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Type.OR; import static com.facebook.presto.util.ImmutableCollectors.toImmutableList; +import static com.facebook.presto.util.ImmutableCollectors.toImmutableMap; import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static java.util.Objects.requireNonNull; @@ -94,6 +96,7 @@ private static class Rewriter private final Session session; private final Map types; private final PlanNodeIdAllocator idAllocator; + private Map expressionAssignments; public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map types, PlanNodeIdAllocator idAllocator) { @@ -108,8 +111,10 @@ public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map context) { PlanNode source = context.rewrite(node.getSource()); - Map assignments = ImmutableMap.copyOf(Maps.transformValues(node.getAssignments(), this::simplifyExpression)); - return new ProjectNode(node.getId(), source, assignments); + expressionAssignments = node.getAssignments(); + Map simplifiedAssignments = expressionAssignments.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey(), entry -> simplifyExpression(entry.getValue()))); + return new ProjectNode(node.getId(), source, simplifiedAssignments); } @Override @@ -150,6 +155,11 @@ private Expression simplifyExpression(Expression expression) if (expression instanceof SymbolReference) { return expression; } + + if (expressionAssignments != null && types != null) { + RemoveIdentityCastContext removeIdentityCastContext = new RemoveIdentityCastContext(expressionAssignments, types); + expression = ExpressionTreeRewriter.rewriteWith(new RemoveIdentityCastsRewriter(), expression, removeIdentityCastContext); + } expression = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression); expression = ExpressionTreeRewriter.rewriteWith(new ExtractCommonPredicatesExpressionRewriter(), expression, NodeContext.ROOT_NODE); IdentityHashMap expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */); @@ -271,4 +281,70 @@ private static List removeAll(Collection collection, Collection ele .collect(toImmutableList()); } } + + private static class RemoveIdentityCastContext + { + private final Map expressionAssignments; + private final Map typeAssignments; + + public RemoveIdentityCastContext(Map expressionAssignments, + Map typeAssignments) + { + requireNonNull(expressionAssignments); + requireNonNull(typeAssignments); + + this.expressionAssignments = expressionAssignments; + this.typeAssignments = typeAssignments; + } + + public Map getExpressionAssignments() + { + return expressionAssignments; + } + + public Map getTypeAssignments() + { + return typeAssignments; + } + } + + private static class RemoveIdentityCastsRewriter + extends ExpressionRewriter + { + @Override + public Expression rewriteExpression(Expression node, RemoveIdentityCastContext context, + ExpressionTreeRewriter treeRewriter) + { + Map typeAssignments = context.getTypeAssignments(); + for (Map.Entry expressionAssignment : context.getExpressionAssignments().entrySet()) { + Symbol assignmentSymbol = expressionAssignment.getKey(); + Expression expression = expressionAssignment.getValue(); + if (expression == node) { + if (!(expression instanceof Cast)) { + return expression; + } + + Expression expressionToCast = ((Cast) expression).getExpression(); + TypeSignature typeOfExpressionToCastTo = typeAssignments.get(assignmentSymbol).getTypeSignature(); + TypeSignature typeOfExpressionToCast; + if (expressionToCast instanceof SymbolReference) { + Symbol expressionSymbol = new Symbol(((SymbolReference) expressionToCast).getName()); + typeOfExpressionToCast = typeAssignments.get(expressionSymbol).getTypeSignature(); + } + else if (expressionToCast instanceof GenericLiteral) { + typeOfExpressionToCast = TypeSignature.parseTypeSignature(((GenericLiteral) expressionToCast).getType()); + } + else { + return expression; + } + + if (typeOfExpressionToCast.equals(typeOfExpressionToCastTo)) { + return expressionToCast; + } + return expression; + } + } + return node; + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/PlanTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/PlanTester.java new file mode 100644 index 0000000000000..596748eedf172 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/PlanTester.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import org.intellij.lang.annotations.Language; + +public interface PlanTester +{ + public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern); + + public void assertPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern); + + public Plan plan(@Language("SQL") String sql); +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index fd21aa19cece8..8bd3594926c0e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -27,6 +27,8 @@ import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; import java.util.List; @@ -54,10 +56,12 @@ import static org.testng.Assert.fail; public class TestLogicalPlanner + implements PlanTester { - private final LocalQueryRunner queryRunner; + private LocalQueryRunner queryRunner; - public TestLogicalPlanner() + @BeforeTest + public void setUp() { this.queryRunner = new LocalQueryRunner(testSessionBuilder() .setCatalog("local") @@ -72,7 +76,7 @@ public TestLogicalPlanner() @Test public void testJoin() { - assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey = o.orderkey", + assertPlanMatches("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey = o.orderkey", anyTree( join(INNER, ImmutableList.of(aliasPair("O", "L")), any( @@ -84,7 +88,7 @@ public void testJoin() @Test public void testUncorrelatedSubqueries() { - assertPlan("SELECT * FROM orders WHERE orderkey = (SELECT orderkey FROM lineitem ORDER BY orderkey LIMIT 1)", + assertPlanMatches("SELECT * FROM orders WHERE orderkey = (SELECT orderkey FROM lineitem ORDER BY orderkey LIMIT 1)", anyTree( join(INNER, ImmutableList.of(aliasPair("X", "Y")), project( @@ -94,7 +98,7 @@ public void testUncorrelatedSubqueries() anyTree( tableScan("lineitem").withSymbol("orderkey", "Y"))))))); - assertPlan("SELECT * FROM orders WHERE orderkey IN (SELECT orderkey FROM lineitem WHERE linenumber % 4 = 0)", + assertPlanMatches("SELECT * FROM orders WHERE orderkey IN (SELECT orderkey FROM lineitem WHERE linenumber % 4 = 0)", anyTree( filter("S", project( @@ -104,7 +108,7 @@ public void testUncorrelatedSubqueries() anyTree( tableScan("lineitem").withSymbol("orderkey", "Y"))))))); - assertPlan("SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)", + assertPlanMatches("SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)", anyTree( filter("NOT S", project( @@ -122,7 +126,7 @@ public void testPushDownJoinConditionConjunctsToInnerSideBasedOnInheritedPredica .put("name", singleValue(createVarcharType(25), utf8Slice("blah"))) .build(); - assertPlan( + assertPlanMatches( "SELECT nationkey FROM nation LEFT OUTER JOIN region " + "ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah'", anyTree( @@ -187,7 +191,7 @@ private void assertPlanContainsNoApplyOrJoin(String sql) @Test public void testCorrelatedSubqueries() { - assertPlan( + assertPlanMatches( "SELECT orderkey FROM orders WHERE 3 = (SELECT orderkey)", LogicalPlanner.Stage.OPTIMIZED, anyTree( @@ -200,7 +204,7 @@ public void testCorrelatedSubqueries() )))))); // double nesting - assertPlan( + assertPlanMatches( "SELECT orderkey FROM orders o " + "WHERE 3 IN (SELECT o.custkey FROM lineitem l WHERE (SELECT l.orderkey = o.orderkey))", LogicalPlanner.Stage.OPTIMIZED, @@ -218,21 +222,26 @@ public void testCorrelatedSubqueries() )))))))); } - private void assertPlan(String sql, PlanMatchPattern pattern) + public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern) { - assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern); + throw new UnsupportedOperationException("assertPlanDoesNotMatch() is not supported"); } - private void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern) + public void assertPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern) + { + assertPlanMatches(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern); + } + + private void assertPlanMatches(@Language("SQL") String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern) { Plan actualPlan = plan(sql, stage); queryRunner.inTransaction(transactionSession -> { - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); return null; }); } - private Plan plan(String sql) + public Plan plan(String sql) { return plan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 310d5d7eea566..cf8c662546090 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.sql.tree.AstVisitor; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.GenericLiteral; @@ -115,6 +116,16 @@ else if (expression instanceof GenericLiteral) { } } + @Override + protected Boolean visitCast(Cast actual, Expression expectedExpession) + { + if (expectedExpession instanceof Cast) { + Cast expected = (Cast) expectedExpession; + return process(actual.getExpression(), expected.getExpression()); + } + return false; + } + @Override protected Boolean visitStringLiteral(StringLiteral actual, Expression expectedExpression) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index 5708beb408128..1b702f1de4d02 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -20,20 +20,37 @@ import static com.facebook.presto.sql.planner.PlanPrinter.textLogicalPlan; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public final class PlanAssert { private PlanAssert() {} - public static void assertPlan(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern) + public static void assertPlanMatches(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern) + { + assertPlan(session, metadata, actual, pattern, true); + } + + public static void assertPlanDoesNotMatch(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern) + { + assertPlan(session, metadata, actual, pattern, false); + } + + private static void assertPlan(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern, boolean expectedMatch) { requireNonNull(actual, "root is null"); - boolean matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata), new PlanMatchingContext(pattern)); - if (!matches) { + boolean actualMatch = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata), new PlanMatchingContext(pattern)); + if (expectedMatch != actualMatch) { String logicalPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, session); - assertTrue(matches, format("Plan does not match:\n %s\n, to pattern:\n%s", logicalPlan, pattern)); + String errorMessage; + if (expectedMatch) { + errorMessage = format("Plan does not match:\n%s\nto pattern:\n%s", logicalPlan, pattern); + } + else { + errorMessage = format("Plan matches:\n%s\nto pattern:\n%s", logicalPlan, pattern); + } + fail(errorMessage); } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 207436d364eb3..d44dbcd85fc62 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -38,6 +39,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import static com.facebook.presto.util.ImmutableCollectors.toImmutableList; import static com.google.common.base.Preconditions.checkState; @@ -147,6 +149,11 @@ List matches(PlanNode node, Session session, Metadata metadat return states.build(); } + public PlanMatchPattern withAssignments(Map assignments) + { + return with(new ProjectNodeMatcher(assignments)); + } + public PlanMatchPattern withSymbol(String pattern, String alias) { return with(new SymbolMatcher(pattern, alias)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ProjectNodeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ProjectNodeMatcher.java new file mode 100644 index 0000000000000..577749edf247c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ProjectNodeMatcher.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Map; +import java.util.regex.Pattern; + +final class ProjectNodeMatcher + implements Matcher +{ + private final Map assignments; + + ProjectNodeMatcher(Map assignments) + { + this.assignments = assignments; + } + + @Override + public boolean matches(PlanNode node, Session session, Metadata metadata, ExpressionAliases expressionAliases) + { + if (node instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) node; + if (projectNode.getAssignments().equals(assignments)) { + return true; + } + } + return false; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeIdenticalWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeIdenticalWindows.java index a2a9a25d82238..eaf3caee58bc6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeIdenticalWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeIdenticalWindows.java @@ -126,7 +126,7 @@ public void testMergeableWindowsAllOptimizers() Plan actualPlan = queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql)); queryRunner.inTransaction(transactionSession -> { - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); return null; }); } @@ -221,7 +221,7 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter { Plan actualPlan = unitPlan(sql); queryRunner.inTransaction(transactionSession -> { - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java index 07a68a85b72eb..c1a3938d35eaf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java @@ -13,20 +13,42 @@ */ package com.facebook.presto.sql.planner.optimizations; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.PlanTester; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.assertions.PlanAssert; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import javax.inject.Provider; + +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -37,14 +59,39 @@ import static com.facebook.presto.sql.ExpressionUtils.binaryExpression; import static com.facebook.presto.sql.ExpressionUtils.extractPredicates; import static com.facebook.presto.sql.ExpressionUtils.rewriteQualifiedNamesToSymbolReferences; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; public class TestSimplifyExpressions + implements PlanTester { private static final SqlParser SQL_PARSER = new SqlParser(); private static final SimplifyExpressions SIMPLIFIER = new SimplifyExpressions(createTestMetadataManager(), SQL_PARSER); + private LocalQueryRunner queryRunner; + private Map inputAssignments; + private Map expectedAssignments; + private Map typeAssignments; + + @BeforeTest + public void setUp() + { + this.queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .build()); + + queryRunner.createCatalog(queryRunner.getDefaultSession().getCatalog().get(), + new TpchConnectorFactory(queryRunner.getNodeManager(), 1), + ImmutableMap.of()); + + inputAssignments = new HashMap(); + expectedAssignments = new HashMap(); + typeAssignments = new HashMap(); + } @Test public void testPushesDownNegations() @@ -88,16 +135,108 @@ public void testExtractsCommonPredicate() assertSimplifies("(A AND B AND C AND D) OR (A AND B AND E) OR (A AND F)", "A AND ((B AND C AND D) OR (B AND E) OR F)"); } + @Test + public void testRemoveIdentityCastBigintLiteral() + { + inputAssignments.put(new Symbol("expr"), new Cast(new GenericLiteral("BIGINT", "5"), "BIGINT")); + expectedAssignments.put(new Symbol("expr"), new GenericLiteral("BIGINT", "5")); + typeAssignments.put(new Symbol("expr"), BigintType.BIGINT); + ProjectNode simplifiedProjectNode = simplifyProjectNode(inputAssignments, typeAssignments); + Map actualAssignments = simplifiedProjectNode.getAssignments(); + assert (actualAssignments.equals(expectedAssignments)); + } + + @Test + public void testRemoveIdentityCastMultiplyBigintLiteral() + { + inputAssignments.put(new Symbol("expr"), + new Cast(new ArithmeticBinaryExpression(BigintType.BIGINT, + new GenericLiteral("BIGINT", "3"), + ), + "BIGINT")) + } + + /* + assertCastNotInPlan("SELECT 3 * CAST(BIGINT '5' as BIGINT)"); + assertCastNotInPlan("SELECT CAST(nationkey AS BIGINT) FROM nation"); + assertCastNotInPlan("SELECT 3 * CAST(nationkey AS BIGINT) FROM nation"); + assertCastNotInPlan("SELECT CAST(nationkey AS BIGINT) FROM nation WHERE nationkey > 10 " + + "AND nationkey < 20"); + assertCastNotInPlan("SELECT CAST(COUNT(*) AS BIGINT) FROM nation WHERE nationkey > 10 " + + "AND nationkey < 20"); + assertCastNotInPlan("SELECT CAST(COUNT(*) AS BIGINT) FROM nation WHERE nationkey > 10 " + + "AND nationkey < 20"); + assertCastNotInPlan("SELECT CAST(COUNT(*) AS BIGINT) AS count FROM nation WHERE nationkey > 10 " + + "AND nationkey < 20 GROUP BY regionkey HAVING COUNT(*) > 1 ORDER BY count ASC"); + assertCastNotInPlan("SELECT CAST(name AS VARCHAR(25)) FROM nation"); + + assertCastInPlan("SELECT CAST(nationkey AS SMALLINT) FROM nation"); + assertCastInPlan("SELECT CAST(name AS VARCHAR(30)) FROM nation"); + assertCastInPlan("SELECT CAST(name AS VARCHAR) FROM nation");*/ + + + private void assertRemovesIdentityCast() + + public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern) + { + Plan actualPlan = plan(sql); + queryRunner.inTransaction(transactionSession -> { + PlanAssert.assertPlanDoesNotMatch(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + return null; + }); + } + + public void assertPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern) + { + Plan actualPlan = plan(sql); + queryRunner.inTransaction(transactionSession -> { + PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + return null; + }); + } + + public Plan plan(@Language("SQL") String sql) + { + return queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql)); + } + + private void assertUnitPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern) + { + Plan actualPlan = unitPlan(sql); + queryRunner.inTransaction(transactionSession -> { + PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + return null; + }); + } + + private Plan unitPlan(@Language("SQL") String sql) + { + FeaturesConfig featuresConfig = new FeaturesConfig() + .setExperimentalSyntaxEnabled(true) + .setDistributedIndexJoinsEnabled(false) + .setOptimizeHashGeneration(true); + Metadata metadata = new MetadataManager(featuresConfig, + new TypeRegistry(), + ) + Provider> optimizerProvider = () -> ImmutableList.of( + new UnaliasSymbolReferences(), + new PruneIdentityProjections(), + new MergeIdenticalWindows(), + new PruneUnreferencedOutputs(), + new SimplifyExpressions()); + return queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql, featuresConfig, optimizerProvider)); + } + private static void assertSimplifies(String expression, String expected) { Expression actualExpression = rewriteQualifiedNamesToSymbolReferences(SQL_PARSER.createExpression(expression)); Expression expectedExpression = rewriteQualifiedNamesToSymbolReferences(SQL_PARSER.createExpression(expected)); assertEquals( - normalize(simplifyExpressions(actualExpression)), + normalize(simplifyFilterNode(actualExpression)), normalize(expectedExpression)); } - private static Expression simplifyExpressions(Expression expression) + private static Expression simplifyFilterNode(Expression expression) { PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); FilterNode filterNode = new FilterNode( @@ -112,6 +251,19 @@ private static Expression simplifyExpressions(Expression expression) return simplifiedNode.getPredicate(); } + private static ProjectNode simplifyProjectNode(Map assignments, Map typeAssignments) + { + PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); + ProjectNode projectNode = new ProjectNode(planNodeIdAllocator.getNextId(), + new ValuesNode(planNodeIdAllocator.getNextId(), emptyList(), emptyList()), + assignments); + return (ProjectNode) SIMPLIFIER.optimize(projectNode, + TEST_SESSION, + typeAssignments, + new SymbolAllocator(), + planNodeIdAllocator); + } + private static Map booleanSymbolTypeMapFor(Expression expression) { return DependencyExtractor.extractUnique(expression).stream()