diff --git a/pom.xml b/pom.xml index e733d6a8c3d60..5d09e062b9a53 100644 --- a/pom.xml +++ b/pom.xml @@ -588,6 +588,12 @@ 42.0.0 + + org.pcollections + pcollections + 2.1.2 + + org.antlr antlr4-runtime diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml index aff146e980a21..be416992230d2 100644 --- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml +++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml @@ -2,7 +2,7 @@ datasource: presto query-names: presto/tpcds/${query}.sql runs: 6 prewarm-runs: 2 -before-execution: sleep-4s, presto/session_set_reorder_joins.sql +before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql frequency: 7 database: hive tpcds_small: tpcds_10gb_orc @@ -12,22 +12,22 @@ variables: 1: query: q01,q06,q14_1,q39_1,q39_2,q47,q57,q67,q81 schema: ${tpcds_small} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 2: query: q02,q03,q04,q05,q07,q09,q10,q11,q13,q14_2,q16,q17,q19,q22,q23_1,q23_2,q24_1,q24_2,q25,q28,q29,q30,q31,q32,q33,q35,q37,q38,q42,q43,q44,q46,q48,q49,q50,q51,q52,q53,q54,q55,q56,q58,q59,q60,q61,q63,q65,q66,q68,q69,q70,q71,q72,q74,q75,q77,q78,q80,q82,q88,q89,q94,q95 schema: ${tpcds_medium} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 3: # query not passing quick enough without reordering query: q18,q64 schema: ${tpcds_medium} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 4: query: q08,q12,q15,q20,q21,q26,q27,q34,q36,q40,q41,q45,q62,q73,q76,q79,q83,q84,q85,q86,q87,q90,q91,q92,q93,q96,q97,q98,q99 schema: ${tpcds_large} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 5: # extra runs with reordering on 1tb schema (too slow without reordering on 1tb). For 100g we keep both runs, with and without reordering query: q03,q37,q42,q43,q52,q53 schema: ${tpcds_large} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOIN diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml index 3a19db13125a7..2463762311a10 100644 --- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml +++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml @@ -2,7 +2,7 @@ datasource: presto query-names: presto/tpch/${query}.sql runs: 6 prewarm-runs: 2 -before-execution: sleep-4s, presto/session_set_reorder_joins.sql +before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql frequency: 7 database: hive tpch_small: tpch_10gb_orc @@ -14,28 +14,28 @@ variables: # queries too slow to run on 100gb without reordering query: q2, q8, q9 schema: ${tpch_small} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 2: # queries too slow to run on 100gb without reordering query: q8, q9 schema: ${tpch_medium} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 3: # queries too slow to run on 100gb without reordering query: q2 schema: ${tpch_large} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 4: # queries too slow to run on 1tb query: q3, q4, q5, q7, q17, q18, q21 schema: ${tpch_medium} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 5: query: q10, q11, q12, q13, q14, q15, q16, q19, q20, q22 schema: ${tpch_large} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 6: # queries without joins query: q1, q6 schema: ${tpch_large} - reorder_joins: false + join_reordering_strategy: NONE diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql new file mode 100644 index 0000000000000..a832822eda53a --- /dev/null +++ b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql @@ -0,0 +1 @@ +SET SESSION join_reordering_strategy='${join_reordering_strategy}' diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql deleted file mode 100644 index 43d4a7faa4705..0000000000000 --- a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql +++ /dev/null @@ -1 +0,0 @@ -SET SESSION reorder_joins='${reorder_joins}' diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java index 1b4ebb9a0d725..a15ac6f661ea9 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java @@ -143,27 +143,27 @@ public void testUpdateSessionParameters() ClientSession session = options.toClientSession(); SqlParser sqlParser = new SqlParser(); - ImmutableMap existingProperties = ImmutableMap.of("query_max_memory", "10GB", "distributed_join", "true"); + ImmutableMap existingProperties = ImmutableMap.of("query_max_memory", "10GB", "join_distribution_type", "repartitioned"); ImmutableMap preparedStatements = ImmutableMap.of("my_query", "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog.test_schema"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog"); assertEquals(session.getSchema(), "test_schema"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_schema_b"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog"); assertEquals(session.getSchema(), "test_schema_b"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog_2.test_schema"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog_2"); assertEquals(session.getSchema(), "test_schema"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo"); } } diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 6284c52f07bf1..2b79bcc5562a1 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -6,26 +6,34 @@ This section describes the most important config properties that may be used to tune Presto or alter its behavior when required. .. contents:: - :local: +:local: :backlinks: none - :depth: 1 + :depth: 1 General Properties ------------------ -``distributed-joins-enabled`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``join-distribution-type`` +^^^^^^^^^^^^^^^^^^^^^^^^^^ - * **Type:** ``boolean`` - * **Default value:** ``true`` - - Use hash distributed joins instead of broadcast joins. Distributed joins - require redistributing both tables using a hash of the join key. This can - be slower (sometimes substantially) than broadcast joins, but allows much - larger joins. Broadcast joins require that the tables on the right side of - the join after filtering fit in memory on each node, whereas distributed joins - only need to fit in distributed memory across all nodes. This can also be - specified on a per-query basis using the ``distributed_join`` session property. + * **Type:** ``string`` + * **Allowed values:** ``AUTOMATIC``, ``REPARTITIONED``, ``REPLICATED`` + * **Default value:** ``REPARTITIONED`` + + The type of distributed join to use. When set to ``REPARTITIONED``, presto will + use hash distributed joins. When set to ``REPLICATED``, it will broadcast the + right table to all nodes in the cluster that have data from the left table. + Repartitioned joins require redistributing both tables using a hash of the join key. + This can be slower (sometimes substantially) than broadcast joins, but allows much + larger joins. In particular broadcast joins will be faster if the right table is + much smaller than the left. However, broadcast joins require that the tables on the right + side of the join after filtering fit in memory on each node, whereas distributed joins + only need to fit in distributed memory across all nodes. When set to ``AUTOMATIC``, + Presto will make a cost based decision as to which distribution type is optimal. + It will also consider switching the left and right inputs to the join. In ``AUTOMATIC`` + mode, Presto will default to replicated joins if no cost could be computed, such as if + the tables do not have statistics. This can also be specified on a per-query basis using + the ``join_distribution_type`` session property. ``redistribute-writes`` ^^^^^^^^^^^^^^^^^^^^^^^ @@ -368,6 +376,22 @@ Optimizer Properties using the ``push_table_write_through_union`` session property. +``optimizer.join-reordering-strategy`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``string`` + * **Allowed values:** ``COST_BASED``, ``ELIMINATE_CROSS_JOINS``, ``NONE`` + * **Default value:** ``ELIMINATE_CROSS_JOINS`` + + The join reordering strategy to use. ``NONE`` maintains the order the tables are listed in the + query. ``ELIMINATE_CROSS_JOINS`` reorders joins to eliminate cross joins where possible and + otherwise maintains the original query order. When reordering joins it also strives to maintain the + original table order as much as possible. ``COST_BASED`` enumerates possible orders and uses + statistics-based cost estimation to determine the least cost order. If stats are not available or if + for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. This can + also be specified on a per-query basis using the ``join_reordering_strategy`` session property. + + Regular Expression Function Properties -------------------------------------- diff --git a/presto-main/etc/config.properties b/presto-main/etc/config.properties index 44d7c148de4e0..30c32d6f1fcb3 100644 --- a/presto-main/etc/config.properties +++ b/presto-main/etc/config.properties @@ -41,5 +41,4 @@ plugin.bundles=\ ../presto-postgresql/pom.xml presto.version=testversion -distributed-joins-enabled=true node-scheduler.include-coordinator=true diff --git a/presto-main/pom.xml b/presto-main/pom.xml index dd535dce447c9..0aedfbf788d72 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -262,6 +262,11 @@ jgrapht-core + + org.pcollections + pcollections + + org.apache.bval diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3ed749a98a314..02d717269eefc 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -20,6 +20,8 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; +import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -27,6 +29,7 @@ import javax.inject.Inject; import java.util.List; +import java.util.stream.Stream; import static com.facebook.presto.spi.session.PropertyMetadata.booleanSessionProperty; import static com.facebook.presto.spi.session.PropertyMetadata.integerSessionProperty; @@ -37,11 +40,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; +import static java.util.stream.Collectors.joining; public final class SystemSessionProperties { public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation"; - public static final String DISTRIBUTED_JOIN = "distributed_join"; + public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type"; public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join"; public static final String HASH_PARTITION_COUNT = "hash_partition_count"; public static final String PREFER_STREAMING_OPERATORS = "prefer_streaming_operators"; @@ -58,7 +62,7 @@ public final class SystemSessionProperties public static final String DICTIONARY_AGGREGATION = "dictionary_aggregation"; public static final String PLAN_WITH_TABLE_NODE_PARTITIONING = "plan_with_table_node_partitioning"; public static final String COLOCATED_JOIN = "colocated_join"; - public static final String REORDER_JOINS = "reorder_joins"; + public static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy"; public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node"; public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval"; public static final String OPTIMIZE_METADATA_QUERIES = "optimize_metadata_queries"; @@ -101,11 +105,18 @@ public SystemSessionProperties( "Compute hash codes for distribution, joins, and aggregations early in query plan", featuresConfig.isOptimizeHashGeneration(), false), - booleanSessionProperty( - DISTRIBUTED_JOIN, - "Use a distributed join instead of a broadcast join", - featuresConfig.isDistributedJoinsEnabled(), - false), + new PropertyMetadata<>( + JOIN_DISTRIBUTION_TYPE, + format("The join method to use. Options are %s", + Stream.of(JoinDistributionType.values()) + .map(FeaturesConfig.JoinDistributionType::name) + .collect(joining(","))), + VARCHAR, + JoinDistributionType.class, + featuresConfig.getJoinDistributionType(), + false, + value -> JoinDistributionType.valueOf(((String) value).toUpperCase()), + JoinDistributionType::name), booleanSessionProperty( DISTRIBUTED_INDEX_JOIN, "Distribute index joins on join keys instead of executing inline", @@ -240,11 +251,18 @@ public SystemSessionProperties( "Experimental: Adapt plan to pre-partitioned tables", true, false), - booleanSessionProperty( - REORDER_JOINS, - "Experimental: Reorder joins to optimize plan", - featuresConfig.isJoinReorderingEnabled(), - false), + new PropertyMetadata<>( + JOIN_REORDERING_STRATEGY, + format("The join reordering strategy to use. Options are %s", + Stream.of(JoinReorderingStrategy.values()) + .map(FeaturesConfig.JoinReorderingStrategy::name) + .collect(joining(","))), + VARCHAR, + JoinReorderingStrategy.class, + featuresConfig.getJoinReorderingStrategy(), + false, + value -> JoinReorderingStrategy.valueOf(((String) value).toUpperCase()), + JoinReorderingStrategy::name), booleanSessionProperty( FAST_INEQUALITY_JOINS, "Use faster handling of inequality join if it is possible", @@ -347,11 +365,6 @@ public static boolean isOptimizeHashGenerationEnabled(Session session) return session.getSystemProperty(OPTIMIZE_HASH_GENERATION, Boolean.class); } - public static boolean isDistributedJoinEnabled(Session session) - { - return session.getSystemProperty(DISTRIBUTED_JOIN, Boolean.class); - } - public static boolean isDistributedIndexJoinEnabled(Session session) { return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class); @@ -427,9 +440,9 @@ public static boolean isFastInequalityJoin(Session session) return session.getSystemProperty(FAST_INEQUALITY_JOINS, Boolean.class); } - public static boolean isJoinReorderingEnabled(Session session) + public static JoinReorderingStrategy getJoinReorderingStrategy(Session session) { - return session.getSystemProperty(REORDER_JOINS, Boolean.class); + return session.getSystemProperty(JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.class); } public static boolean isColocatedJoinEnabled(Session session) @@ -515,4 +528,9 @@ public static boolean isUseNewStatsCalculator(Session session) { return session.getSystemProperty(USE_NEW_STATS_CALCULATOR, Boolean.class); } + + public static JoinDistributionType getJoinDistributionType(Session session) + { + return session.getSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java new file mode 100644 index 0000000000000..4d57c0cc7eeac --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class CachingCostCalculator + implements CostCalculator +{ + private final CostCalculator costCalculator; + private final Map costs = new HashMap<>(); + + public CachingCostCalculator(CostCalculator costCalculator) + { + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + } + + @Override + public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map types) + { + if (!costs.containsKey(planNode)) { + // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback + PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(planNode, lookup, session, types); + requireNonNull(costs, "computed cost can not be null"); + checkState(costs.put(planNode, cost) == null, "cost for " + planNode + " already computed"); + } + return costs.get(planNode); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java new file mode 100644 index 0000000000000..80d5947c06fa9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class CachingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator statsCalculator; + private final Map stats = new HashMap<>(); + + public CachingStatsCalculator(StatsCalculator statsCalculator) + { + this.statsCalculator = statsCalculator; + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types) + { + if (!stats.containsKey(planNode)) { + // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback + PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types); + requireNonNull(stats, "computed stats can not be null"); + checkState(stats.put(planNode, statsEstimate) == null, "statistics for " + planNode + " already computed"); + } + return stats.get(planNode); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java index abccad58b89fc..784996cf30bf0 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java @@ -21,8 +21,6 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.inject.BindingAnnotation; -import javax.annotation.concurrent.ThreadSafe; - import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.util.Map; @@ -36,7 +34,6 @@ * Computes estimated cost of executing given PlanNode. * Implementation may use lookup to compute needed traits for self/source nodes. */ -@ThreadSafe public interface CostCalculator { PlanNodeCostEstimate calculateCost( diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java index 94bc8e96a5718..c8d4b20d88ff1 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java @@ -18,9 +18,10 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; -import java.util.HashMap; import java.util.Map; +import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom; +import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS; import static com.google.common.base.Predicates.not; public class EnsureStatsMatchOutput @@ -29,16 +30,16 @@ public class EnsureStatsMatchOutput @Override public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types) { - Map symbolSymbolStats = new HashMap<>(); - estimate.getSymbolsWithKnownStatistics().stream() - .filter(node.getOutputSymbols()::contains) - .forEach(symbol -> symbolSymbolStats.put(symbol, estimate.getSymbolStatistics(symbol))); + PlanNodeStatsEstimate.Builder builder = buildFrom(estimate); node.getOutputSymbols().stream() .filter(not(estimate.getSymbolsWithKnownStatistics()::contains)) - .filter(not(symbolSymbolStats::containsKey)) - .forEach(symbol -> symbolSymbolStats.put(symbol, SymbolStatsEstimate.UNKNOWN_STATS)); + .forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS)); + + estimate.getSymbolsWithKnownStatistics().stream() + .filter(not(node.getOutputSymbols()::contains)) + .forEach(builder::removeSymbolStatistics); - return PlanNodeStatsEstimate.buildFrom(estimate).setSymbolStatistics(symbolSymbolStats).build(); + return builder.build(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java new file mode 100644 index 0000000000000..8a95329713d91 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class JoinNodeCachingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator statsCalculator; + private final Map stats = new HashMap<>(); + + public JoinNodeCachingStatsCalculator(StatsCalculator statsCalculator) + { + this.statsCalculator = statsCalculator; + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types) + { + if (!(planNode instanceof JoinNode)) { + return statsCalculator.calculateStats(planNode, lookup, session, types); + } + + PlanNodeId key = planNode.getId(); + if (!stats.containsKey(key)) { + // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback + PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types); + requireNonNull(stats, "computed stats can not be null"); + checkState(stats.put(key, statsEstimate) == null, "statistics for " + planNode + " already computed"); + } + return stats.get(key); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index 3d59f90dab49c..cf081a6926c5f 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -14,14 +14,13 @@ package com.facebook.presto.cost; import com.facebook.presto.sql.planner.Symbol; -import com.google.common.collect.ImmutableMap; +import org.pcollections.HashTreePMap; +import org.pcollections.PMap; -import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -34,13 +33,13 @@ public class PlanNodeStatsEstimate public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10; private final double outputRowCount; - private final Map symbolStatistics; + private final PMap symbolStatistics; - private PlanNodeStatsEstimate(double outputRowCount, Map symbolStatistics) + private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; - this.symbolStatistics = ImmutableMap.copyOf(symbolStatistics); + this.symbolStatistics = symbolStatistics; } /** @@ -83,26 +82,15 @@ public PlanNodeStatsEstimate mapOutputRowCount(Function mappingF public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction) { return buildFrom(this) - .setSymbolStatistics(symbolStatistics.entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - e -> { - if (e.getKey().equals(symbol)) { - return mappingFunction.apply(e.getValue()); - } - return e.getValue(); - }))) + .addSymbolStatistics(symbol, mappingFunction.apply(symbolStatistics.get(symbol))) .build(); } public PlanNodeStatsEstimate add(PlanNodeStatsEstimate other) { // TODO this is broken (it does not operate on symbol stats at all). Remove or fix - ImmutableMap.Builder symbolsStatsBuilder = ImmutableMap.builder(); - symbolsStatsBuilder.putAll(symbolStatistics).putAll(other.symbolStatistics); // This may not count all information - - PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); - return statsBuilder.setSymbolStatistics(symbolsStatsBuilder.build()) + return buildFrom(this) + .addSymbolStatistics(other.symbolStatistics) .setOutputRowCount(getOutputRowCount() + other.getOutputRowCount()) .build(); } @@ -153,14 +141,24 @@ public static Builder builder() public static Builder buildFrom(PlanNodeStatsEstimate other) { - return builder().setOutputRowCount(other.getOutputRowCount()) - .setSymbolStatistics(other.symbolStatistics); + return new Builder(other.getOutputRowCount(), other.symbolStatistics); } public static final class Builder { - private double outputRowCount = NaN; - private Map symbolStatistics = new HashMap<>(); + private double outputRowCount; + private PMap symbolStatistics; + + public Builder() + { + this(NaN, HashTreePMap.empty()); + } + + private Builder(double outputRowCount, PMap symbolStatistics) + { + this.outputRowCount = outputRowCount; + this.symbolStatistics = symbolStatistics; + } public Builder setOutputRowCount(double outputRowCount) { @@ -168,15 +166,21 @@ public Builder setOutputRowCount(double outputRowCount) return this; } - public Builder setSymbolStatistics(Map symbolStatistics) + public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) { - this.symbolStatistics = new HashMap<>(symbolStatistics); + symbolStatistics = symbolStatistics.plus(symbol, statistics); return this; } - public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) + public Builder addSymbolStatistics(Map symbolStatistics) + { + this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics); + return this; + } + + public Builder removeSymbolStatistics(Symbol symbol) { - this.symbolStatistics.put(symbol, statistics); + symbolStatistics = symbolStatistics.minus(symbol); return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java index 234bc6729bb29..713e9c757a308 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java @@ -77,7 +77,7 @@ public Optional calculate(PlanNode node, Lookup lookup, S return Optional.of(PlanNodeStatsEstimate.builder() .setOutputRowCount(tableStatistics.getRowCount().getValue()) - .setSymbolStatistics(outputSymbolStats) + .addSymbolStatistics(outputSymbolStats) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java index 4ecc694ec6a00..012ea3a6b51db 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java @@ -26,7 +26,6 @@ import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.SymbolReference; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -103,7 +102,17 @@ public static Expression binaryExpression(LogicalBinaryExpression.Type type, Col { requireNonNull(type, "type is null"); requireNonNull(expressions, "expressions is null"); - Preconditions.checkArgument(!expressions.isEmpty(), "expressions is empty"); + + if (expressions.isEmpty()) { + switch (type) { + case AND: + return TRUE_LITERAL; + case OR: + return FALSE_LITERAL; + default: + throw new IllegalArgumentException("Unsupported LogicalBinaryExpression type"); + } + } // Build balanced tree for efficient recursive processing that // preserves the evaluation order of the input expressions. @@ -309,7 +318,8 @@ public static Expression normalize(Expression expression) public static Expression rewriteIdentifiersToSymbolReferences(Expression expression) { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() { + return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { @Override public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter treeRewriter) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 4c7b3a0d9a99b..b5b3205dd0683 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -27,8 +27,11 @@ import java.nio.file.Paths; import java.util.List; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; @DefunctConfig({ @@ -42,11 +45,11 @@ public class FeaturesConfig private double cpuCostWeight = 0.75; private double memoryCostWeight = 0; private double networkCostWeight = 0.25; + private boolean distributedIndexJoinsEnabled; - private boolean distributedJoinsEnabled = true; private boolean colocatedJoinsEnabled; private boolean fastInequalityJoins = true; - private boolean reorderJoins = true; + private JoinReorderingStrategy joinReorderingStrategy = ELIMINATE_CROSS_JOINS; private boolean redistributeWrites = true; private boolean optimizeMetadataQueries; private boolean optimizeHashGeneration = true; @@ -59,6 +62,7 @@ public class FeaturesConfig private boolean legacyMapSubscript; private boolean newMapBlock = true; private boolean optimizeMixedDistinctAggregations; + private JoinDistributionType joinDistributionType = REPARTITIONED; private boolean dictionaryAggregation; private boolean resourceGroups; @@ -77,6 +81,30 @@ public class FeaturesConfig private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata + public enum JoinReorderingStrategy + { + ELIMINATE_CROSS_JOINS, + COST_BASED, + NONE + } + + public enum JoinDistributionType + { + AUTOMATIC, + REPLICATED, + REPARTITIONED; + + public boolean canRepartition() + { + return this == REPARTITIONED || this == AUTOMATIC; + } + + public boolean canReplicate() + { + return this == REPLICATED || this == AUTOMATIC; + } + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -137,11 +165,6 @@ public FeaturesConfig setDistributedIndexJoinsEnabled(boolean distributedIndexJo return this; } - public boolean isDistributedJoinsEnabled() - { - return distributedJoinsEnabled; - } - @Config("deprecated.legacy-array-agg") public FeaturesConfig setLegacyArrayAgg(boolean legacyArrayAgg) { @@ -190,13 +213,6 @@ public boolean isNewMapBlock() return newMapBlock; } - @Config("distributed-joins-enabled") - public FeaturesConfig setDistributedJoinsEnabled(boolean distributedJoinsEnabled) - { - this.distributedJoinsEnabled = distributedJoinsEnabled; - return this; - } - public boolean isColocatedJoinsEnabled() { return colocatedJoinsEnabled; @@ -223,16 +239,16 @@ public boolean isFastInequalityJoins() return fastInequalityJoins; } - public boolean isJoinReorderingEnabled() + public JoinReorderingStrategy getJoinReorderingStrategy() { - return reorderJoins; + return joinReorderingStrategy; } - @Config("reorder-joins") - @ConfigDescription("Experimental: Reorder joins to optimize plan") - public FeaturesConfig setJoinReorderingEnabled(boolean reorderJoins) + @Config("optimizer.join-reordering-strategy") + @ConfigDescription("The strategy to use for reordering joins") + public FeaturesConfig setJoinReorderingStrategy(JoinReorderingStrategy joinReorderingStrategy) { - this.reorderJoins = reorderJoins; + this.joinReorderingStrategy = joinReorderingStrategy; return this; } @@ -490,4 +506,16 @@ public boolean isUseNewStatsCalculator() { return useNewStatsCalculator; } + + @Config("join-distribution-type") + public FeaturesConfig setJoinDistributionType(JoinDistributionType joinDistributionType) + { + this.joinDistributionType = requireNonNull(joinDistributionType, "joinDistributionType is null"); + return this; + } + + public JoinDistributionType getJoinDistributionType() + { + return joinDistributionType; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index a1f4c44c28340..c5d07fec5a087 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -55,6 +55,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant; import com.facebook.presto.sql.planner.iterative.rule.SingleMarkDistinctToGroupBy; import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; @@ -322,13 +323,32 @@ public PlanOptimizers( ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins()), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again ImmutableSet.of(new EliminateCrossJoins()) ), + new PredicatePushDown(metadata, sqlParser), new IterativeOptimizer( stats, statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushDownTableConstraints(metadata, sqlParser))), - projectionPushDown); + projectionPushDown, + new PruneUnreferencedOutputs(), + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new RemoveRedundantIdentityProjections()) + ), + + // Because ReorderJoins runs only once, + // PredicatePushDown, PruneUnreferenedOutputpus and RemoveRedundantIdentityProjections + // need to run beforehand in order to produce an optimal join order + // It also needs to run after EliminateCrossJoins so that its chosen order doesn't get undone. + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new ReorderJoins(costComparator, statsCalculator, costCalculator)) + )); if (featuresConfig.isOptimizeSingleDistinct()) { builder.add( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java index 2252155b6697f..d76bea4dd65bf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -15,6 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.cost.CachingCostCalculator; +import com.facebook.presto.cost.CachingStatsCalculator; import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.matching.MatchingEngine; @@ -81,7 +83,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } Memo memo = new Memo(idAllocator, plan); - Lookup lookup = new MemoBasedLookup(memo, statsCalculator, costCalculator); + Lookup lookup = Lookup.from(memo::resolve, new CachingStatsCalculator(statsCalculator), new CachingCostCalculator(costCalculator)); Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); exploreGroup(memo.getRootGroup(), new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 65b2304952fcf..d8e5146690f1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -102,7 +102,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) { - return costCalculator.calculateCumulativeCost(node, this, session, types); + return costCalculator.calculateCumulativeCost(resolve(node), this, session, types); } }; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java deleted file mode 100644 index b718fdfc5dedd..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCostEstimate; -import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.StatsCalculator; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.PlanNode; - -import java.util.HashMap; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkState; -import static java.util.Objects.requireNonNull; - -public class MemoBasedLookup - implements Lookup -{ - private final Memo memo; - private final Map stats = new HashMap<>(); - private final Map costs = new HashMap<>(); - private final StatsCalculator statsCalculator; - private final CostCalculator costCalculator; - - public MemoBasedLookup(Memo memo, StatsCalculator statsCalculator, CostCalculator costCalculator) - { - this.memo = requireNonNull(memo, "memo can not be null"); - this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); - } - - @Override - public PlanNode resolve(PlanNode node) - { - if (node instanceof GroupReference) { - return memo.getNode(((GroupReference) node).getGroupId()); - } - return node; - } - - // todo[LO] maybe lookup passed to stats/cost calculator should be constrained so only - // methods for obtaining traits and only for self and sources would be allowed? - - @Override - public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types) - { - PlanNode key = resolve(planNode); - if (!stats.containsKey(key)) { - // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback - PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(key, this, session, types); - requireNonNull(stats, "computed stats can not be null"); - checkState(stats.put(key, statsEstimate) == null, "statistics for " + key + " already computed"); - } - return stats.get(key); - } - - @Override - public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types) - { - PlanNode key = resolve(planNode); - if (!costs.containsKey(key)) { - // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback - PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(key, this, session, types); - requireNonNull(costs, "computed cost can not be null"); - checkState(costs.put(key, cost) == null, "cost for " + key + " already computed"); - } - return costs.get(key); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 2452f75ae7a40..a4e7907a8fff0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; @@ -41,6 +40,9 @@ import java.util.PriorityQueue; import java.util.Set; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -65,7 +67,8 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + // we run this for cost_based reordering also for cases when some of the tables do not have statistics + if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && getJoinReorderingStrategy(session) != COST_BASED) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java new file mode 100644 index 0000000000000..ea394fd43b8ef --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; + +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.planner.DeterminismEvaluator.isDeterministic; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * This class represents a set of inner joins that can be executed in any order. + */ +class MultiJoinNode +{ + private static final int JOIN_LIMIT = 10; + + private final List sources; + private final Expression filter; + private final List outputSymbols; + + public MultiJoinNode(List sources, Expression filter, List outputSymbols) + { + this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null")); + this.filter = requireNonNull(filter, "filter is null"); + this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); + + List inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList()); + checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols"); + } + + public Expression getFilter() + { + return filter; + } + + public List getSources() + { + return sources; + } + + public List getOutputSymbols() + { + return outputSymbols; + } + + static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup) + { + return new MultiJoinNodeBuilder(joinNode, lookup).toMultiJoinNode(); + } + + private static class MultiJoinNodeBuilder + { + private final List sources = new ArrayList<>(); + private final List filters = new ArrayList<>(); + private final List outputSymbols; + private final Lookup lookup; + + MultiJoinNodeBuilder(JoinNode node, Lookup lookup) + { + requireNonNull(node, "node is null"); + checkState(node.getType() == INNER, "join type must be INNER"); + this.outputSymbols = node.getOutputSymbols(); + this.lookup = requireNonNull(lookup, "lookup is null"); + flattenNode(node); + } + + private void flattenNode(PlanNode node) + { + PlanNode resolved = lookup.resolve(node); + if (resolved instanceof JoinNode && sources.size() < JOIN_LIMIT) { + JoinNode joinNode = (JoinNode) resolved; + if (joinNode.getType() == INNER && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL))) { + flattenNode(joinNode.getLeft()); + flattenNode(joinNode.getRight()); + joinNode.getCriteria().stream() + .map(JoinNode.EquiJoinClause::toExpression) + .forEach(filters::add); + joinNode.getFilter().ifPresent(filters::add); + return; + } + } + sources.add(node); + } + + MultiJoinNode toMultiJoinNode() + { + return new MultiJoinNode(sources, and(filters), outputSymbols); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java new file mode 100644 index 0000000000000..f6dee21226964 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -0,0 +1,416 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CachingCostCalculator; +import com.facebook.presto.cost.CachingStatsCalculator; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.JoinNodeCachingStatsCalculator; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.EqualityInference; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; +import com.google.common.collect.Sets; +import io.airlift.log.Logger; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.cost.PlanNodeCostEstimate.INFINITE_COST; +import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; +import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; +import static com.facebook.presto.sql.planner.plan.Assignments.identity; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.in; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.stream.StreamSupport.stream; + +public class ReorderJoins + implements Rule +{ + private static final Logger log = Logger.get(ReorderJoins.class); + + private final CostComparator costComparator; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; + + public ReorderJoins(CostComparator costComparator, StatsCalculator statsCalculator, CostCalculator costCalculator) + { + this.costComparator = requireNonNull(costComparator, "costComparator is null"); + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof JoinNode) || getJoinReorderingStrategy(session) != COST_BASED) { + return Optional.empty(); + } + + JoinNode joinNode = (JoinNode) node; + // We check that join distribution type is absent because we only want to do this transformation once (reordered joins will have distribution type already set). + // We check determinisitic filters because we can't reorder joins with non-deterministic filters + if (!(joinNode.getType() == INNER) || joinNode.getDistributionType().isPresent()) { + return Optional.empty(); + } + + MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, lookup); + if (multiJoinNode.getSources().size() < 2) { + return Optional.empty(); + } + + Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(statsCalculator)), new CachingCostCalculator(costCalculator)); + JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); + return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? Optional.empty() : result.getPlanNode(); + } + + @VisibleForTesting + static class JoinEnumerator + { + private final Map, JoinEnumerationResult> memo = new HashMap<>(); + private final PlanNodeIdAllocator idAllocator; + private final Session session; + private final Ordering resultOrdering; + private final EqualityInference allInference; + private final Expression allFilter; + private final SymbolAllocator symbolAllocator; + private final Lookup lookup; + + @VisibleForTesting + JoinEnumerator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator) + { + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(symbolAllocator, "symbolAllocator is null"); + requireNonNull(session, "session is null"); + requireNonNull(lookup, "lookup is null"); + requireNonNull(filter, "filter is null"); + requireNonNull(costComparator, "costComparator is null"); + this.idAllocator = idAllocator; + this.symbolAllocator = symbolAllocator; + this.session = session; + this.lookup = lookup; + this.resultOrdering = getResultOrdering(costComparator, session); + this.allInference = createEqualityInference(filter); + this.allFilter = filter; + } + + private static Ordering getResultOrdering(CostComparator costComparator, Session session) + { + return new Ordering() + { + @Override + public int compare(JoinEnumerationResult result1, JoinEnumerationResult result2) + { + return costComparator.compare(session, result1.cost, result2.cost); + } + }; + } + + private JoinEnumerationResult chooseJoinOrder(List sources, List outputSymbols) + { + Set multiJoinKey = ImmutableSet.copyOf(sources); + JoinEnumerationResult bestResult = memo.get(multiJoinKey); + if (bestResult == null) { + checkState(sources.size() > 1, "sources size is less than or equal to one"); + ImmutableList.Builder resultBuilder = ImmutableList.builder(); + Set> partitions = generatePartitions(sources.size()).collect(toImmutableSet()); + for (Set partition : partitions) { + JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputSymbols, partition); + if (result.cost.isUnknown()) { + memo.put(multiJoinKey, result); + return result; + } + if (!result.cost.equals(INFINITE_COST)) { + resultBuilder.add(result); + } + } + + List results = resultBuilder.build(); + if (results.isEmpty()) { + memo.put(multiJoinKey, INFINITE_COST_RESULT); + return INFINITE_COST_RESULT; + } + + bestResult = resultOrdering.min(resultBuilder.build()); + memo.put(multiJoinKey, bestResult); + } + if (bestResult.planNode.isPresent()) { + log.debug("Least cost join was: " + bestResult.planNode.get().toString()); + } + return bestResult; + } + + /** + * This method generates all the ways of dividing totalNodes into two sets + * each containing at least one node. It will generate one set for each + * possible partitioning. The other partition is implied in the absent values. + * In order not to generate the inverse of any set, we always include the 0th + * node in our sets. + * + * @param totalNodes + * @return A set of sets each of which defines a partitioning of totalNodes + */ + @VisibleForTesting + static Stream> generatePartitions(int totalNodes) + { + checkArgument(totalNodes >= 2, "totalNodes must be greater than or equal to 2"); + Set numbers = IntStream.range(0, totalNodes) + .boxed() + .collect(toImmutableSet()); + return Sets.powerSet(numbers).stream() + .filter(subSet -> subSet.contains(0)) + .filter(subSet -> subSet.size() < numbers.size()); + } + + JoinEnumerationResult createJoinAccordingToPartitioning(List sources, List outputSymbols, Set partitioning) + { + Set leftSources = partitioning.stream() + .map(sources::get) + .collect(toImmutableSet()); + Set rightSources = Sets.difference(ImmutableSet.copyOf(sources), ImmutableSet.copyOf(leftSources)); + return createJoin(leftSources, rightSources, outputSymbols); + } + + private JoinEnumerationResult createJoin(Set leftSources, Set rightSources, List outputSymbols) + { + Set leftSymbols = leftSources.stream() + .flatMap(node -> node.getOutputSymbols().stream()) + .collect(toImmutableSet()); + Set rightSymbols = rightSources.stream() + .flatMap(node -> node.getOutputSymbols().stream()) + .collect(toImmutableSet()); + ImmutableList.Builder joinPredicatesBuilder = ImmutableList.builder(); + + // add join conjucts that were not used for inference + stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false) + .map(conjuct -> allInference.rewriteExpression(conjuct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))) + .filter(Objects::nonNull) + // filter expressions that contain only left or right symbols + .filter(conjuct -> allInference.rewriteExpression(conjuct, leftSymbols::contains) == null) + .filter(conjuct -> allInference.rewriteExpression(conjuct, rightSymbols::contains) == null) + .forEach(joinPredicatesBuilder::add); + + // create equality inference on available symbols + // TODO: make generateEqualitiesPartitionedBy take left and right scope + List joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); + EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[joinEqualities.size()])); + joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); + + List joinPredicates = joinPredicatesBuilder.build(); + List joinConditions = joinPredicates.stream() + .filter(JoinEnumerator::isJoinEqualityCondition) + .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols)) + .collect(toImmutableList()); + if (joinConditions.isEmpty()) { + return INFINITE_COST_RESULT; + } + List joinFilters = joinPredicates.stream() + .filter(predicate -> !isJoinEqualityCondition(predicate)) + .collect(toImmutableList()); + + Set requiredJoinSymbols = ImmutableSet.builder() + .addAll(outputSymbols) + .addAll(SymbolsExtractor.extractUnique(joinPredicates)) + .build(); + + JoinEnumerationResult leftResult = getJoinSource( + idAllocator, + ImmutableList.copyOf(leftSources), + requiredJoinSymbols.stream().filter(leftSymbols::contains).collect(toImmutableList())); + if (leftResult.cost.isUnknown()) { + return UNKNOWN_COST_RESULT; + } + if (leftResult.cost.equals(INFINITE_COST)) { + return INFINITE_COST_RESULT; + } + PlanNode left = leftResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present")); + JoinEnumerationResult rightResult = getJoinSource( + idAllocator, + ImmutableList.copyOf(rightSources), + requiredJoinSymbols.stream() + .filter(rightSymbols::contains) + .collect(toImmutableList())); + if (rightResult.cost.isUnknown()) { + return UNKNOWN_COST_RESULT; + } + if (rightResult.cost.equals(INFINITE_COST)) { + return INFINITE_COST_RESULT; + } + PlanNode right = rightResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present")); + + // sort output symbols so that the left input symbols are first + List sortedOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream()) + .filter(outputSymbols::contains) + .collect(toImmutableList()); + + // Cross joins can't filter symbols as part of the join + // If we're doing a cross join, use all output symbols from the inputs and add a project node + // on top + List joinOutputSymbols = sortedOutputSymbols; + if (joinConditions.isEmpty() && joinFilters.isEmpty()) { + joinOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream()) + .collect(toImmutableList()); + } + + JoinEnumerationResult result = setJoinNodeProperties(new JoinNode( + idAllocator.getNextId(), + INNER, + left, + right, + joinConditions, + joinOutputSymbols, + joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)), + Optional.empty(), + Optional.empty(), + Optional.empty())); + + if (!joinOutputSymbols.equals(sortedOutputSymbols)) { + PlanNode resultNode = new ProjectNode(idAllocator.getNextId(), result.planNode.get(), identity(sortedOutputSymbols)); + result = new JoinEnumerationResult(lookup.getCumulativeCost(resultNode, session, symbolAllocator.getTypes()), Optional.of(resultNode)); + } + + return result; + } + + private JoinEnumerationResult getJoinSource(PlanNodeIdAllocator idAllocator, List nodes, List outputSymbols) + { + PlanNode planNode; + if (nodes.size() == 1) { + planNode = getOnlyElement(nodes); + ImmutableList.Builder predicates = ImmutableList.builder(); + predicates.addAll(allInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities()); + stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false) + .map(conjuct -> allInference.rewriteExpression(conjuct, outputSymbols::contains)) + .filter(Objects::nonNull) + .forEach(predicates::add); + Expression filter = combineConjuncts(predicates.build()); + if (!(TRUE_LITERAL).equals(filter)) { + planNode = new FilterNode(idAllocator.getNextId(), planNode, filter); + } + return new JoinEnumerationResult(lookup.getCumulativeCost(planNode, session, symbolAllocator.getTypes()), Optional.of(planNode)); + } + return chooseJoinOrder(nodes, outputSymbols); + } + + private static boolean isJoinEqualityCondition(Expression expression) + { + return expression instanceof ComparisonExpression + && ((ComparisonExpression) expression).getType() == EQUAL + && ((ComparisonExpression) expression).getLeft() instanceof SymbolReference + && ((ComparisonExpression) expression).getRight() instanceof SymbolReference; + } + + private static JoinNode.EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftSymbols) + { + Symbol leftSymbol = Symbol.from(equality.getLeft()); + Symbol rightSymbol = Symbol.from(equality.getRight()); + JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(leftSymbol, rightSymbol); + return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip(); + } + + private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) + { + List possibleJoinNodes = new ArrayList<>(); + FeaturesConfig.JoinDistributionType joinDistributionType = getJoinDistributionType(session); + if (joinDistributionType.canRepartition() && !joinNode.isCrossJoin()) { + JoinNode node = joinNode.withDistributionType(PARTITIONED); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + node = node.flipChildren(); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + } + if (joinDistributionType.canReplicate()) { + JoinNode node = joinNode.withDistributionType(REPLICATED); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + node = node.flipChildren(); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + } + if (possibleJoinNodes.stream().anyMatch(result -> result.cost.isUnknown())) { + return UNKNOWN_COST_RESULT; + } + return resultOrdering.min(possibleJoinNodes); + } + } + + @VisibleForTesting + static class JoinEnumerationResult + { + static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(UNKNOWN_COST, Optional.empty()); + static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(INFINITE_COST, Optional.empty()); + + private final Optional planNode; + private final PlanNodeCostEstimate cost; + + private JoinEnumerationResult(PlanNodeCostEstimate cost, Optional planNode) + { + this.cost = requireNonNull(cost); + this.planNode = requireNonNull(planNode); + checkArgument(cost.isUnknown() || cost.equals(INFINITE_COST) || planNode.isPresent(), "planNode must be present if cost is known"); + } + + public Optional getPlanNode() + { + return planNode; + } + + public PlanNodeCostEstimate getCost() + { + return cost; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java index cb912fbbb0925..424a8f4f31354 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java @@ -27,7 +27,7 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.SystemSessionProperties.isDistributedJoinEnabled; +import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -111,9 +111,12 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context) private JoinNode.DistributionType getTargetJoinDistributionType(JoinNode node) { + if (node.getDistributionType().isPresent()) { + return node.getDistributionType().get(); + } // The implementation of full outer join only works if the data is hash partitioned. See LookupJoinOperators#buildSideOuterJoinUnvisitedPositions JoinNode.Type type = node.getType(); - if (type == RIGHT || type == FULL || (isDistributedJoinEnabled(session) && !mustBroadcastJoin(node))) { + if (type == RIGHT || type == FULL || (isRepartitionedJoinEnabled(session) && !mustBroadcastJoin(node))) { return JoinNode.DistributionType.PARTITIONED; } @@ -132,11 +135,16 @@ private static boolean isCrossJoin(JoinNode node) private SemiJoinNode.DistributionType getTargetSemiJoinDistributionType(boolean isDeleteQuery) { - if (isDistributedJoinEnabled(session) && !isDeleteQuery) { + if (isRepartitionedJoinEnabled(session) && !isDeleteQuery) { return SemiJoinNode.DistributionType.PARTITIONED; } return SemiJoinNode.DistributionType.REPLICATED; } + + private static boolean isRepartitionedJoinEnabled(Session session) + { + return getJoinDistributionType(session).canRepartition(); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java index 89a4eccfc674c..57f29e4dad791 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; @@ -28,6 +27,9 @@ import java.util.Map; import java.util.Objects; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.buildJoinTree; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; @@ -47,7 +49,7 @@ public PlanNode optimize( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { - if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && getJoinReorderingStrategy(session) != COST_BASED) { return plan; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index 3ecc8862ae35b..0eb4615f145fc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -28,10 +28,14 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.Collectors; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SortExpressionExtractor.extractSortExpression; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -100,6 +104,58 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin"); } + public JoinNode flipChildren() + { + return new JoinNode( + getId(), + flipType(type), + right, + left, + flipJoinCriteria(criteria), + flipOutputSymbols(getOutputSymbols(), left, right), + filter, + rightHashSymbol, + leftHashSymbol, + distributionType); + } + + private static Type flipType(Type type) + { + switch (type) { + case INNER: + return INNER; + case FULL: + return FULL; + case LEFT: + return RIGHT; + case RIGHT: + return LEFT; + default: + throw new IllegalStateException("No inverse defined for join type: " + type); + } + } + + private static List flipJoinCriteria(List joinCriteria) + { + return joinCriteria.stream() + .map(EquiJoinClause::flip) + .collect(toImmutableList()); + } + + private static List flipOutputSymbols(List outputSymbols, PlanNode left, PlanNode right) + { + List leftSymbols = outputSymbols.stream() + .filter(symbol -> left.getOutputSymbols().contains(symbol)) + .collect(Collectors.toList()); + List rightSymbols = outputSymbols.stream() + .filter(symbol -> right.getOutputSymbols().contains(symbol)) + .collect(Collectors.toList()); + return ImmutableList.builder() + .addAll(rightSymbols) + .addAll(leftSymbols) + .build(); + } + public enum DistributionType { PARTITIONED, @@ -230,6 +286,11 @@ public PlanNode replaceChildren(List newChildren) return new JoinNode(getId(), type, newLeft, newRight, criteria, newOutputSymbols, filter, leftHashSymbol, rightHashSymbol, distributionType); } + public JoinNode withDistributionType(DistributionType distributionType) + { + return new JoinNode(getId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.of(distributionType)); + } + public boolean isCrossJoin() { return criteria.isEmpty() && !filter.isPresent() && type == INNER; @@ -264,6 +325,11 @@ public ComparisonExpression toExpression() return new ComparisonExpression(ComparisonExpressionType.EQUAL, left.toSymbolReference(), right.toSymbolReference()); } + public EquiJoinClause flip() + { + return new EquiJoinClause(right, left); + } + @Override public boolean equals(Object obj) { diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index d938c37b8e7fa..268336a2e88b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -251,15 +251,16 @@ public LocalQueryRunner(Session defaultSession) new FeaturesConfig() .setOptimizeMixedDistinctAggregations(true) .setIterativeOptimizerEnabled(true), - false); + false, + 1); } public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig) { - this(defaultSession, featuresConfig, false); + this(defaultSession, featuresConfig, false, 1); } - private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction) + private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction, int nodeCountForStats) { requireNonNull(defaultSession, "defaultSession is null"); checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction"); @@ -379,14 +380,19 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new CoefficientBasedStatsCalculator(metadata), ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata))); this.costCalculator = new CostCalculatorUsingExchanges(getNodeCount()); - this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, getNodeCount()); + this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, nodeCountForStats); this.lookup = new StatelessLookup(statsCalculator, costCalculator); } public static LocalQueryRunner queryRunnerWithInitialTransaction(Session defaultSession) { checkArgument(!defaultSession.getTransactionId().isPresent(), "Already in transaction!"); - return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true); + return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true, 1); + } + + public static LocalQueryRunner queryRunnerWithFakeNodeCountForStats(Session defaultSession, int nodeCount) + { + return new LocalQueryRunner(defaultSession, new FeaturesConfig(), false, nodeCount); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java new file mode 100644 index 0000000000000..0bf9e06ac13d8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.testing; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class TestingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator statsCalculator; + private final Map stats; + + public TestingStatsCalculator(StatsCalculator statsCalculator, Map stats) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.stats = ImmutableMap.copyOf(requireNonNull(stats, "stats is null")); + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types) + { + if (stats.containsKey(planNode.getId())) { + return stats.get(planNode.getId()); + } + + return statsCalculator.calculateStats(planNode, lookup, session, types); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java index 7f8b8e0a6c32a..f516618b1563e 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java @@ -27,8 +27,8 @@ import java.util.Locale; -import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; @@ -59,7 +59,7 @@ public void testCreateSession() .put(PRESTO_TIME_ZONE, "Asia/Taipei") .put(PRESTO_CLIENT_INFO, "client-info") .put(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB") - .put(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43") + .put(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43") .put(PRESTO_PREPARED_STATEMENT, "query1=select * from foo,query2=select * from bar") .build(), "testRemote"); @@ -82,7 +82,7 @@ public void testCreateSession() assertEquals(session.getClientInfo().get(), "client-info"); assertEquals(session.getSystemProperties(), ImmutableMap.builder() .put(QUERY_MAX_MEMORY, "1GB") - .put(DISTRIBUTED_JOIN, "true") + .put(JOIN_DISTRIBUTION_TYPE, "repartitioned") .put(HASH_PARTITION_COUNT, "43") .build()); assertEquals(session.getPreparedStatements(), ImmutableMap.builder() diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java index 3fad76a15ea32..a2ddec819a837 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java @@ -35,8 +35,8 @@ import java.net.URI; import java.util.List; -import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; @@ -134,7 +134,7 @@ public void testQuery() .setHeader(PRESTO_SCHEMA, "schema") .setHeader(PRESTO_CLIENT_INFO, "{\"clientVersion\":\"testVersion\"}") .addHeader(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB") - .addHeader(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43") + .addHeader(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43") .addHeader(PRESTO_PREPARED_STATEMENT, "foo=select * from bar") .build(); @@ -146,7 +146,7 @@ public void testQuery() // verify session properties assertEquals(queryInfo.getSession().getSystemProperties(), ImmutableMap.builder() .put(QUERY_MAX_MEMORY, "1GB") - .put(DISTRIBUTED_JOIN, "true") + .put(JOIN_DISTRIBUTION_TYPE, "repartitioned") .put(HASH_PARTITION_COUNT, "43") .build()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 780c48507c263..bb4db2ba43b46 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -21,6 +21,10 @@ import java.util.Map; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPLICATED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.facebook.presto.sql.analyzer.RegexLibrary.RE2J; import static io.airlift.configuration.testing.ConfigAssertions.assertDeprecatedEquivalence; @@ -40,10 +44,10 @@ public void testDefaults() .setNetworkCostWeight(0.25) .setResourceGroupsEnabled(false) .setDistributedIndexJoinsEnabled(false) - .setDistributedJoinsEnabled(true) + .setJoinDistributionType(REPARTITIONED) .setFastInequalityJoins(true) .setColocatedJoinsEnabled(false) - .setJoinReorderingEnabled(true) + .setJoinReorderingStrategy(ELIMINATE_CROSS_JOINS) .setRedistributeWrites(true) .setOptimizeMetadataQueries(false) .setOptimizeHashGeneration(true) @@ -86,10 +90,10 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") - .put("distributed-joins-enabled", "false") + .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "false") + .put("optimizer.join-reordering-strategy", "NONE") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", "false") @@ -122,10 +126,10 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") - .put("distributed-joins-enabled", "false") + .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "false") + .put("optimizer.join-reordering-strategy", "NONE") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", "false") @@ -155,10 +159,10 @@ public void testExplicitPropertyMappings() .setIterativeOptimizerEnabled(false) .setIterativeOptimizerTimeout(new Duration(10, SECONDS)) .setDistributedIndexJoinsEnabled(true) - .setDistributedJoinsEnabled(false) + .setJoinDistributionType(REPLICATED) .setFastInequalityJoins(false) .setColocatedJoinsEnabled(true) - .setJoinReorderingEnabled(false) + .setJoinReorderingStrategy(NONE) .setRedistributeWrites(false) .setOptimizeMetadataQueries(true) .setOptimizeHashGeneration(false) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index a78c856247075..6d002c104dd7d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -37,12 +37,14 @@ final class JoinMatcher private final JoinNode.Type joinType; private final List> equiCriteria; private final Optional filter; + private final Optional distributionType; - JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter) + JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter, Optional distributionType) { this.joinType = requireNonNull(joinType, "joinType is null"); this.equiCriteria = requireNonNull(equiCriteria, "equiCriteria is null"); this.filter = requireNonNull(filter, "filter can not be null"); + this.distributionType = requireNonNull(distributionType, "distribtuionType cannot be null"); } @Override @@ -81,6 +83,10 @@ public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Ses } } + if (distributionType.isPresent() && !distributionType.equals(joinNode.getDistributionType())) { + return NO_MATCH; + } + /* * Have to use order-independent comparison; there are no guarantees what order * the equi criteria will have after planning and optimizing. diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 9e39e36ec9506..ce7c03dabffef 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -266,12 +266,18 @@ public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, PlanMatchPattern left, PlanMatchPattern right) + { + return join(joinType, expectedEquiCriteria, expectedFilter, Optional.empty(), left, right); + } + + public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, Optional distributionType, PlanMatchPattern left, PlanMatchPattern right) { return node(JoinNode.class, left, right).with( new JoinMatcher( joinType, expectedEquiCriteria, - expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))))); + expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))), + distributionType)); } public static PlanMatchPattern exchange(PlanMatchPattern... sources) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java new file mode 100644 index 0000000000000..d6fb0ea483d0a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.openjdk.jmh.annotations.Mode.AverageTime; +import static org.openjdk.jmh.annotations.Scope.Thread; + +@State(Thread) +@OutputTimeUnit(MILLISECONDS) +@BenchmarkMode(AverageTime) +@Fork(3) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +public class BenchmarkReorderJoinsConnectedGraph +{ + @Benchmark + public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo) + { + return benchmarkInfo.getQueryRunner().execute(benchmarkInfo.getQuery()); + } + + @State(Thread) + public static class BenchmarkInfo + { + @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"}) + private String joinReorderingStrategy; + + @Param({"2", "4", "6", "8", "10"}) + private int numberOfTables; + + private String query; + private LocalQueryRunner queryRunner; + + @Setup + public void setup() + { + checkState(numberOfTables >= 2, "numberOfTables must be >= 2"); + Session session = testSessionBuilder() + .setSystemProperty("join_reordering_strategy", joinReorderingStrategy) + .setSystemProperty("join_distribution_type", "AUTOMATIC") + .setCatalog("tpch") + .setSchema("tiny") + .build(); + queryRunner = new LocalQueryRunner(session); + queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of()); + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("EXPLAIN SELECT * FROM NATION n1"); + for (int i = 2; i <= numberOfTables; i++) { + stringBuilder.append(format(" JOIN nation n%s on n%s.nationkey = n%s.nationkey", i, i - 1, i)); + } + query = stringBuilder.toString(); + } + + public String getQuery() + { + return query; + } + + public QueryRunner getQueryRunner() + { + return queryRunner; + } + + @TearDown + public void tearDown() + { + queryRunner.close(); + } + } + + public static void main(String[] args) + throws RunnerException + { + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkReorderJoinsConnectedGraph.class.getSimpleName() + ".*") + .build(); + + new Runner(options).run(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java new file mode 100644 index 0000000000000..13c1644c17855 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.openjdk.jmh.annotations.Mode.AverageTime; +import static org.openjdk.jmh.annotations.Scope.Thread; + +@State(Thread) +@OutputTimeUnit(MILLISECONDS) +@BenchmarkMode(AverageTime) +@Fork(3) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +public class BenchmarkReorderJoinsLinearGraph +{ + @Benchmark + public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo) + { + return benchmarkInfo.getQueryRunner().execute( + "EXPLAIN SELECT * FROM " + + "nation n1 JOIN nation n2 ON n1.nationkey = n2.nationkey " + + "JOIN nation n3 on n2.comment = n3.comment " + + "JOIN nation n4 on n3.name = n4.name " + + "JOIN region r1 on n4.regionkey = r1.regionkey " + + "JOIN region r2 on r2.name = r2.name " + + "JOIN region r3 on r3.comment = r2.comment " + + "join region r4 on r4.regionkey = r3.regionkey"); + } + + @State(Thread) + public static class BenchmarkInfo + { + @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"}) + private String joinReorderingStrategy; + + private LocalQueryRunner queryRunner; + + @Setup + public void setup() + { + Session session = testSessionBuilder() + .setSystemProperty("join_reordering_strategy", joinReorderingStrategy) + .setSystemProperty("join_distribution_type", "AUTOMATIC") + .setCatalog("tpch") + .setSchema("tiny") + .build(); + queryRunner = new LocalQueryRunner(session); + queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of()); + } + + public QueryRunner getQueryRunner() + { + return queryRunner; + } + + @TearDown + public void tearDown() + { + queryRunner.close(); + } + } + + public static void main(String[] args) + throws RunnerException + { + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkReorderJoinsLinearGraph.class.getSimpleName() + ".*") + .build(); + + new Runner(options).run(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index e46826bfb567e..3609523de8c5b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -35,7 +35,7 @@ import java.util.Optional; import java.util.function.Function; -import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS; +import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; @@ -60,7 +60,7 @@ public class TestEliminateCrossJoins public void testEliminateCrossJoin() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( join(INNER, @@ -79,7 +79,7 @@ public void testEliminateCrossJoin() public void testRetainOutgoingGroupReferences() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( node(JoinNode.class, @@ -96,7 +96,7 @@ public void testRetainOutgoingGroupReferences() public void testDoNotReorderOuterJoin() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(JoinNode.Type.LEFT)) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java new file mode 100644 index 0000000000000..c2449777969c1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Set; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.testing.Closeables.closeAllRuntimeException; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; + +public class TestJoinEnumerator +{ + private LocalQueryRunner queryRunner; + + @BeforeClass + public void setUp() + { + queryRunner = new LocalQueryRunner(testSessionBuilder().build()); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(queryRunner); + queryRunner = null; + } + + @Test + public void testGeneratePartitions() + { + Set> partitions = generatePartitions(4).collect(toImmutableSet()); + assertEquals(partitions, + ImmutableSet.of( + ImmutableSet.of(0), + ImmutableSet.of(0, 1), + ImmutableSet.of(0, 2), + ImmutableSet.of(0, 3), + ImmutableSet.of(0, 1, 2), + ImmutableSet.of(0, 1, 3), + ImmutableSet.of(0, 2, 3))); + + partitions = generatePartitions(3).collect(toImmutableSet()); + assertEquals(partitions, + ImmutableSet.of( + ImmutableSet.of(0), + ImmutableSet.of(0, 1), + ImmutableSet.of(0, 2))); + } + + @Test + public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + MultiJoinNode multiJoinNode = new MultiJoinNode( + ImmutableList.of(planBuilder.values(a1), planBuilder.values(b1)), + TRUE_LITERAL, + ImmutableList.of(a1, b1)); + ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator( + idAllocator, + new SymbolAllocator(), + queryRunner.getDefaultSession(), + queryRunner.getLookup(), + multiJoinNode.getFilter(), + new CostComparator(1, 1, 1)); + JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols(), ImmutableSet.of(0)); + assertFalse(actual.getPlanNode().isPresent()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java index 7bf5a19317a4d..0aed522491b14 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java @@ -24,7 +24,6 @@ import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Memo; -import com.facebook.presto.sql.planner.iterative.MemoBasedLookup; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.testing.LocalQueryRunner; @@ -60,7 +59,7 @@ public void testResolvesGroupReferenceNode() PlanNode plan = node(source); Memo memo = new Memo(idAllocator, plan); - MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); + Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); PlanNode memoSource = Iterables.getOnlyElement(memo.getNode(memo.getRootGroup()).getSources()); checkState(memoSource instanceof GroupReference, "expected GroupReference"); assertEquals(lookup.resolve(memoSource), source); @@ -71,7 +70,7 @@ public void testComputesStatsAndResolvesNodes() { PlanNode plan = node(node(node())); Memo memo = new Memo(idAllocator, plan); - MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); + Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); PlanNodeStatsEstimate actualStats = lookup.getStats(memo.getNode(memo.getRootGroup()), queryRunner.getDefaultSession(), ImmutableMap.of()); PlanNodeStatsEstimate expectedStats = PlanNodeStatsEstimate.builder().setOutputRowCount(3).build(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java new file mode 100644 index 0000000000000..1f9d8d402de64 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java @@ -0,0 +1,248 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; +import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Type.ADD; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.NOT_EQUAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; + +public class TestMultiJoinNodeBuilder +{ + private final LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder().build()); + + @Test(expectedExceptions = IllegalStateException.class) + public void testDoesNotFireForOuterJoins() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), queryRunner.getMetadata()); + JoinNode outerJoin = p.join( + JoinNode.Type.FULL, + p.values(p.symbol("A1", BIGINT)), + p.values(p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty()); + toMultiJoinNode(outerJoin, queryRunner.getLookup()); + } + + @Test + public void testDoesNotConvertNestedOuterJoins() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + JoinNode leftJoin = planBuilder.join( + LEFT, + planBuilder.values(a1), + planBuilder.values(b1), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); + ValuesNode valuesC = planBuilder.values(c1); + JoinNode joinNode = planBuilder.join( + INNER, + leftJoin, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)), + ImmutableList.of(a1, b1, c1), + Optional.empty()); + + MultiJoinNode expected = new MultiJoinNode(ImmutableList.of(leftJoin, valuesC), new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()), ImmutableList.of(a1, b1, c1)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testRetainsOutputSymbols() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol b2 = planBuilder.symbol("B2", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol c2 = planBuilder.symbol("C2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1, b2); + ValuesNode valuesC = planBuilder.values(c1, c2); + JoinNode joinNode = planBuilder.join( + INNER, + valuesA, + planBuilder.join( + INNER, + valuesB, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)), + ImmutableList.of( + b1, + b2, + c1, + c2), + Optional.empty()), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC), + and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference())), + ImmutableList.of(a1, b1)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testCombinesCriteriaAndFilters() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol b2 = planBuilder.symbol("B2", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol c2 = planBuilder.symbol("C2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1, b2); + ValuesNode valuesC = planBuilder.values(c1, c2); + Expression bcFilter = and( + new ComparisonExpression(GREATER_THAN, c2.toSymbolReference(), new LongLiteral("0")), + new ComparisonExpression(NOT_EQUAL, c2.toSymbolReference(), new LongLiteral("7")), + new ComparisonExpression(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); + ComparisonExpression abcFilter = new ComparisonExpression( + LESS_THAN, + new ArithmeticBinaryExpression(ADD, a1.toSymbolReference(), c1.toSymbolReference()), + b1.toSymbolReference()); + JoinNode joinNode = planBuilder.join( + INNER, + valuesA, + planBuilder.join( + INNER, + valuesB, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)), + ImmutableList.of( + b1, + b2, + c1, + c2), + Optional.of(bcFilter)), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1, b2, c1, c2), + Optional.of(abcFilter)); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC), + and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter), + ImmutableList.of(a1, b1, b2, c1, c2)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testConvertsBushyTrees() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol d1 = planBuilder.symbol("D1", BIGINT); + Symbol d2 = planBuilder.symbol("D2", BIGINT); + Symbol e1 = planBuilder.symbol("E1", BIGINT); + Symbol e2 = planBuilder.symbol("E2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1); + ValuesNode valuesC = planBuilder.values(c1); + ValuesNode valuesD = planBuilder.values(d1, d2); + ValuesNode valuesE = planBuilder.values(e1, e2); + JoinNode joinNode = planBuilder.join( + INNER, + planBuilder.join( + INNER, + planBuilder.join( + INNER, + valuesA, + valuesB, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()), + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)), + ImmutableList.of(a1, b1, c1), + Optional.empty()), + planBuilder.join( + INNER, + valuesD, + valuesE, + ImmutableList.of( + new JoinNode.EquiJoinClause(d1, e1), + new JoinNode.EquiJoinClause(d2, e2)), + ImmutableList.of( + d1, + d2, + e1, + e2), + Optional.empty()), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, e1)), + ImmutableList.of( + a1, + b1, + c1, + d1, + d2, + e1, + e2), + Optional.empty()); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC, valuesD, valuesE), + and( + new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), + new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()), + new ComparisonExpression(EQUAL, d1.toSymbolReference(), e1.toSymbolReference()), + new ComparisonExpression(EQUAL, d2.toSymbolReference(), e2.toSymbolReference()), + new ComparisonExpression(EQUAL, b1.toSymbolReference(), e1.toSymbolReference())), + ImmutableList.of(a1, b1, c1, d1, d2, e1, e2)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + private static void assertMultijoinEquals(MultiJoinNode actual, MultiJoinNode expected) + { + assertEquals(ImmutableSet.copyOf(actual.getSources()), ImmutableSet.copyOf(expected.getSources())); + assertEquals(ImmutableSet.copyOf(extractConjuncts(actual.getFilter())), ImmutableSet.copyOf(extractConjuncts(expected.getFilter()))); + assertEquals(actual.getOutputSymbols(), expected.getOutputSymbols()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java new file mode 100644 index 0000000000000..c64c901c3184b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -0,0 +1,382 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.spi.TestingColumnHandle; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.ComparisonExpressionType; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingStatsCalculator; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithFakeNodeCountForStats; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static io.airlift.testing.Closeables.closeAllRuntimeException; + +public class TestReorderJoins +{ + private RuleTester tester; + private StatsCalculator statsCalculator; + private CostCalculator costCalculator; + + @BeforeClass + public void setUp() + { + Session session = testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .setSystemProperty("join_distribution_type", "automatic") + .setSystemProperty("join_reordering_strategy", "COST_BASED") + .build(); + LocalQueryRunner queryRunner = queryRunnerWithFakeNodeCountForStats(session, 4); + statsCalculator = queryRunner.getStatsCalculator(); + costCalculator = queryRunner.getEstimatedExchangesCostCalculator(); + tester = new RuleTester(queryRunner); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(tester); + tester = null; + } + + @Test + public void testKeepsOutputSymbols() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(5000) + .addSymbolStatistics(ImmutableMap.of( + new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), + new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT), p.symbol("A2", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A2", BIGINT)), + Optional.empty())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0, "A2", 1)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testReplicatesAndFlipsWhenOneTableMuchSmaller() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("B1", "A1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("B1", 0)), + values(ImmutableMap.of("A1", 0)) + )); + } + + @Test + public void testRepartitionsWhenRequiredBySession() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .setSystemProperty("join_distribution_type", "REPARTITIONED") + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("B1", "A1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("B1", 0)), + values(ImmutableMap.of("A1", 0)) + )); + } + + @Test + public void testRepartitionsWhenBothTablesEqual() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testReplicatesWhenRequiredBySession() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .setSystemProperty("join_distribution_type", "REPLICATED") + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("A1", 0)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testDoesNotFireForCrossJoin() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireWithNoStats() + { + StatsCalculator testingStatsCalculator = new UnknownStatsCalculator(); + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.tableScan(ImmutableList.of(p.symbol("A1", BIGINT)), ImmutableMap.of(p.symbol("A1", BIGINT), new TestingColumnHandle("A1"))), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT)), + Optional.empty())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireForNonDeterministicFilter() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), statsCalculator, costCalculator)) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.of(new ComparisonExpression(ComparisonExpressionType.LESS_THAN, p.symbol("A1", BIGINT).toSymbolReference(), new FunctionCall(QualifiedName.of("random"), ImmutableList.of()))))) + .doesNotFire(); + } + + @Test + public void testPredicatesPushedDown() + { + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesB"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(5) + .addSymbolStatistics(ImmutableMap.of( + new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), + new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesC"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) + .on(p -> + p.join( + INNER, + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)), + ImmutableList.of(), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)), + Optional.empty()), + p.values(new PlanNodeId("valuesC"), p.symbol("C1", BIGINT)), + ImmutableList.of( + new JoinNode.EquiJoinClause(p.symbol("B2", BIGINT), p.symbol("C1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT)), + Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1", BIGINT).toSymbolReference(), p.symbol("B1", BIGINT).toSymbolReference())))) + .matches( + join( + INNER, + ImmutableList.of(equiJoinClause("C1", "B2")), + values("C1"), + join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + values("A1"), + values("B1", "B2")) + ) + ); + } + + private static class UnknownStatsCalculator + implements StatsCalculator + { + @Override + public PlanNodeStatsEstimate calculateStats( + PlanNode planNode, + Lookup lookup, + Session session, + Map types) + { + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.buildFrom(UNKNOWN_STATS); + planNode.getOutputSymbols() + .forEach(symbol -> statsBuilder.addSymbolStatistics(symbol, SymbolStatsEstimate.UNKNOWN_STATS)); + return statsBuilder.build(); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 64ed14842aabc..e4dafc7e55bfa 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -42,6 +42,7 @@ import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -90,9 +91,14 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) } public ValuesNode values(Symbol... columns) + { + return values(idAllocator.getNextId(), columns); + } + + public ValuesNode values(PlanNodeId id, Symbol... columns) { return new ValuesNode( - idAllocator.getNextId(), + id, ImmutableList.copyOf(columns), ImmutableList.of()); } @@ -320,6 +326,11 @@ public ExchangeNode exchange(Consumer exchangeBuilderConsumer) return exchangeBuilder.build(); } + public JoinNode join(JoinNode.Type type, PlanNode left, PlanNode right, List criteria, List outputSymbols, Optional filter) + { + return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter, Optional.empty(), Optional.empty(), Optional.empty()); + } + public class ExchangeBuilder { private ExchangeNode.Type type = ExchangeNode.Type.GATHER; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 2b79a9d079926..f125cc50c1499 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -39,25 +39,35 @@ import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static org.testng.Assert.fail; public class RuleAssert { private final Metadata metadata; + private StatsCalculator statsCalculator; + private final CostCalculator costCalculator; private Session session; private final Rule rule; private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private Map symbols; + private Lookup lookup; private PlanNode plan; private final TransactionManager transactionManager; private final AccessControl accessControl; - private final StatsCalculator statsCalculator; - private final CostCalculator costCalculator; - - public RuleAssert(Metadata metadata, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl, StatsCalculator statsCalculator, CostCalculator costCalculator) + private Memo memo; + + public RuleAssert( + Metadata metadata, + Session session, + Rule rule, + TransactionManager transactionManager, + AccessControl accessControl, + StatsCalculator statsCalculator, + CostCalculator costCalculator) { this.metadata = metadata; this.session = session; @@ -81,6 +91,13 @@ public RuleAssert withSession(Session session) return this; } + public RuleAssert withStatsCalculator(StatsCalculator statsCalculator) + { + checkState(lookup == null, "lookup has been set"); + this.statsCalculator = statsCalculator; + return this; + } + public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); @@ -88,6 +105,8 @@ public RuleAssert on(Function planProvider) PlanBuilder builder = new PlanBuilder(idAllocator, metadata); plan = planProvider.apply(builder); symbols = builder.getSymbols(); + memo = new Memo(idAllocator, plan); + lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator); return this; } @@ -143,8 +162,6 @@ public void matches(PlanMatchPattern pattern) private RuleApplication applyRule() { SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); - Memo memo = new Memo(idAllocator, plan); - Lookup lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator); if (!rule.getPattern().matches(plan)) { return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java index 40a6321583d92..f1589980d0a02 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java @@ -59,7 +59,7 @@ public class TestReorderJoins public TestReorderJoins() { - super(ImmutableMap.of(SystemSessionProperties.REORDER_JOINS, "true")); + super(ImmutableMap.of(SystemSessionProperties.JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS")); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java index 0235bba4f0cba..f29253a8c2475 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java @@ -21,6 +21,6 @@ public class TestUnionWithReplicatedJoin { public TestUnionWithReplicatedJoin() { - super(ImmutableMap.of(SystemSessionProperties.DISTRIBUTED_JOIN, "false")); + super(ImmutableMap.of(SystemSessionProperties.JOIN_DISTRIBUTION_TYPE, "replicated")); } } diff --git a/presto-product-tests/conf/presto/etc/config.properties b/presto-product-tests/conf/presto/etc/config.properties index 44aef5efc782e..a7e918b66a97f 100644 --- a/presto-product-tests/conf/presto/etc/config.properties +++ b/presto-product-tests/conf/presto/etc/config.properties @@ -38,8 +38,6 @@ plugin.bundles=\ ../../../presto-sqlserver/pom.xml presto.version=testversion -distributed-joins-enabled=true query.max-memory-per-node=1GB query.max-memory=1GB redistribute-writes=false -reorder-joins=true diff --git a/presto-product-tests/conf/presto/etc/multinode-master.properties b/presto-product-tests/conf/presto/etc/multinode-master.properties index 5be52ed2cb2d4..98d6602e08447 100644 --- a/presto-product-tests/conf/presto/etc/multinode-master.properties +++ b/presto-product-tests/conf/presto/etc/multinode-master.properties @@ -15,4 +15,3 @@ query.max-memory=1GB query.max-memory-per-node=512MB discovery-server.enabled=true discovery.uri=http://presto-master:8080 -reorder-joins=true diff --git a/presto-product-tests/conf/presto/etc/singlenode.properties b/presto-product-tests/conf/presto/etc/singlenode.properties index d0e2cda196b10..a2e66146a4eb8 100644 --- a/presto-product-tests/conf/presto/etc/singlenode.properties +++ b/presto-product-tests/conf/presto/etc/singlenode.properties @@ -15,4 +15,3 @@ query.max-memory=2GB query.max-memory-per-node=1GB discovery-server.enabled=true discovery.uri=http://presto-master:8080 -reorder-joins=true diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java index 5be87ee48944b..025394cd81b56 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java @@ -14,6 +14,7 @@ package com.facebook.presto.tests.jdbc; import com.facebook.presto.jdbc.PrestoConnection; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.teradata.tempto.BeforeTestWithContext; import com.teradata.tempto.ProductTest; import com.teradata.tempto.Requirement; @@ -52,8 +53,6 @@ import static com.teradata.tempto.internal.convention.SqlResultDescriptor.sqlResultDescriptorForResource; import static com.teradata.tempto.query.QueryExecutor.defaultQueryExecutor; import static com.teradata.tempto.query.QueryExecutor.query; -import static java.lang.Boolean.FALSE; -import static java.lang.Boolean.TRUE; import static java.util.Locale.CHINESE; import static org.assertj.core.api.Assertions.assertThat; @@ -247,13 +246,14 @@ public void testSqlEscapeFunctions() public void testSessionProperties() throws SQLException { - final String distributedJoin = "distributed_join"; + final String joinDistributionType = "join_distribution_type"; + final String defaultValue = new FeaturesConfig().getJoinDistributionType().name(); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString()); - setSessionProperty(connection, distributedJoin, FALSE.toString()); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(FALSE.toString()); - resetSessionProperty(connection, distributedJoin); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString()); + assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue); + setSessionProperty(connection, joinDistributionType, "REPLICATED"); + assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo("REPLICATED"); + resetSessionProperty(connection, joinDistributionType); + assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue); } private QueryResult queryResult(Statement statement, String query) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 6f683b3ef7578..7a767c2264b21 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -171,11 +171,10 @@ private static TestingPrestoServer createTestingPrestoServer(URI discoveryUri, b .put("compiler.interpreter-enabled", "false") .put("task.max-index-memory", "16kB") // causes index joins to fault load .put("datasources", "system") - .put("distributed-index-joins-enabled", "true") .put("optimizer.optimize-mixed-distinct-aggregations", "true"); if (coordinator) { propertiesBuilder.put("node-scheduler.include-coordinator", "true"); - propertiesBuilder.put("distributed-joins-enabled", "true"); + propertiesBuilder.put("join-distribution-type", "REPARTITIONED"); } HashMap properties = new HashMap<>(propertiesBuilder.build()); properties.putAll(extraProperties);