diff --git a/pom.xml b/pom.xml
index e733d6a8c3d60..5d09e062b9a53 100644
--- a/pom.xml
+++ b/pom.xml
@@ -588,6 +588,12 @@
42.0.0
+
+ org.pcollections
+ pcollections
+ 2.1.2
+
+
org.antlr
antlr4-runtime
diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml
index aff146e980a21..be416992230d2 100644
--- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml
+++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml
@@ -2,7 +2,7 @@ datasource: presto
query-names: presto/tpcds/${query}.sql
runs: 6
prewarm-runs: 2
-before-execution: sleep-4s, presto/session_set_reorder_joins.sql
+before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql
frequency: 7
database: hive
tpcds_small: tpcds_10gb_orc
@@ -12,22 +12,22 @@ variables:
1:
query: q01,q06,q14_1,q39_1,q39_2,q47,q57,q67,q81
schema: ${tpcds_small}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
2:
query: q02,q03,q04,q05,q07,q09,q10,q11,q13,q14_2,q16,q17,q19,q22,q23_1,q23_2,q24_1,q24_2,q25,q28,q29,q30,q31,q32,q33,q35,q37,q38,q42,q43,q44,q46,q48,q49,q50,q51,q52,q53,q54,q55,q56,q58,q59,q60,q61,q63,q65,q66,q68,q69,q70,q71,q72,q74,q75,q77,q78,q80,q82,q88,q89,q94,q95
schema: ${tpcds_medium}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
3:
# query not passing quick enough without reordering
query: q18,q64
schema: ${tpcds_medium}
- reorder_joins: true
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS
4:
query: q08,q12,q15,q20,q21,q26,q27,q34,q36,q40,q41,q45,q62,q73,q76,q79,q83,q84,q85,q86,q87,q90,q91,q92,q93,q96,q97,q98,q99
schema: ${tpcds_large}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
5:
# extra runs with reordering on 1tb schema (too slow without reordering on 1tb). For 100g we keep both runs, with and without reordering
query: q03,q37,q42,q43,q52,q53
schema: ${tpcds_large}
- reorder_joins: true
+ join_reordering_strategy: ELIMINATE_CROSS_JOIN
diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml
index 3a19db13125a7..2463762311a10 100644
--- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml
+++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml
@@ -2,7 +2,7 @@ datasource: presto
query-names: presto/tpch/${query}.sql
runs: 6
prewarm-runs: 2
-before-execution: sleep-4s, presto/session_set_reorder_joins.sql
+before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql
frequency: 7
database: hive
tpch_small: tpch_10gb_orc
@@ -14,28 +14,28 @@ variables:
# queries too slow to run on 100gb without reordering
query: q2, q8, q9
schema: ${tpch_small}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
2:
# queries too slow to run on 100gb without reordering
query: q8, q9
schema: ${tpch_medium}
- reorder_joins: true
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS
3:
# queries too slow to run on 100gb without reordering
query: q2
schema: ${tpch_large}
- reorder_joins: true
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS
4:
# queries too slow to run on 1tb
query: q3, q4, q5, q7, q17, q18, q21
schema: ${tpch_medium}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
5:
query: q10, q11, q12, q13, q14, q15, q16, q19, q20, q22
schema: ${tpch_large}
- reorder_joins: true, false
+ join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE
6:
# queries without joins
query: q1, q6
schema: ${tpch_large}
- reorder_joins: false
+ join_reordering_strategy: NONE
diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql
new file mode 100644
index 0000000000000..a832822eda53a
--- /dev/null
+++ b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql
@@ -0,0 +1 @@
+SET SESSION join_reordering_strategy='${join_reordering_strategy}'
diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql
deleted file mode 100644
index 43d4a7faa4705..0000000000000
--- a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql
+++ /dev/null
@@ -1 +0,0 @@
-SET SESSION reorder_joins='${reorder_joins}'
diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java
index 1b4ebb9a0d725..a15ac6f661ea9 100644
--- a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java
+++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java
@@ -143,27 +143,27 @@ public void testUpdateSessionParameters()
ClientSession session = options.toClientSession();
SqlParser sqlParser = new SqlParser();
- ImmutableMap existingProperties = ImmutableMap.of("query_max_memory", "10GB", "distributed_join", "true");
+ ImmutableMap existingProperties = ImmutableMap.of("query_max_memory", "10GB", "join_distribution_type", "repartitioned");
ImmutableMap preparedStatements = ImmutableMap.of("my_query", "select * from foo");
session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog.test_schema"), session, existingProperties, preparedStatements);
assertEquals(session.getCatalog(), "test_catalog");
assertEquals(session.getSchema(), "test_schema");
assertEquals(session.getProperties().get("query_max_memory"), "10GB");
- assertEquals(session.getProperties().get("distributed_join"), "true");
+ assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned");
assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo");
session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_schema_b"), session, existingProperties, preparedStatements);
assertEquals(session.getCatalog(), "test_catalog");
assertEquals(session.getSchema(), "test_schema_b");
assertEquals(session.getProperties().get("query_max_memory"), "10GB");
- assertEquals(session.getProperties().get("distributed_join"), "true");
+ assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned");
assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo");
session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog_2.test_schema"), session, existingProperties, preparedStatements);
assertEquals(session.getCatalog(), "test_catalog_2");
assertEquals(session.getSchema(), "test_schema");
assertEquals(session.getProperties().get("query_max_memory"), "10GB");
- assertEquals(session.getProperties().get("distributed_join"), "true");
+ assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned");
assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo");
}
}
diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst
index 6284c52f07bf1..2b79bcc5562a1 100644
--- a/presto-docs/src/main/sphinx/admin/properties.rst
+++ b/presto-docs/src/main/sphinx/admin/properties.rst
@@ -6,26 +6,34 @@ This section describes the most important config properties that
may be used to tune Presto or alter its behavior when required.
.. contents::
- :local:
+:local:
:backlinks: none
- :depth: 1
+ :depth: 1
General Properties
------------------
-``distributed-joins-enabled``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+``join-distribution-type``
+^^^^^^^^^^^^^^^^^^^^^^^^^^
- * **Type:** ``boolean``
- * **Default value:** ``true``
-
- Use hash distributed joins instead of broadcast joins. Distributed joins
- require redistributing both tables using a hash of the join key. This can
- be slower (sometimes substantially) than broadcast joins, but allows much
- larger joins. Broadcast joins require that the tables on the right side of
- the join after filtering fit in memory on each node, whereas distributed joins
- only need to fit in distributed memory across all nodes. This can also be
- specified on a per-query basis using the ``distributed_join`` session property.
+ * **Type:** ``string``
+ * **Allowed values:** ``AUTOMATIC``, ``REPARTITIONED``, ``REPLICATED``
+ * **Default value:** ``REPARTITIONED``
+
+ The type of distributed join to use. When set to ``REPARTITIONED``, presto will
+ use hash distributed joins. When set to ``REPLICATED``, it will broadcast the
+ right table to all nodes in the cluster that have data from the left table.
+ Repartitioned joins require redistributing both tables using a hash of the join key.
+ This can be slower (sometimes substantially) than broadcast joins, but allows much
+ larger joins. In particular broadcast joins will be faster if the right table is
+ much smaller than the left. However, broadcast joins require that the tables on the right
+ side of the join after filtering fit in memory on each node, whereas distributed joins
+ only need to fit in distributed memory across all nodes. When set to ``AUTOMATIC``,
+ Presto will make a cost based decision as to which distribution type is optimal.
+ It will also consider switching the left and right inputs to the join. In ``AUTOMATIC``
+ mode, Presto will default to replicated joins if no cost could be computed, such as if
+ the tables do not have statistics. This can also be specified on a per-query basis using
+ the ``join_distribution_type`` session property.
``redistribute-writes``
^^^^^^^^^^^^^^^^^^^^^^^
@@ -368,6 +376,22 @@ Optimizer Properties
using the ``push_table_write_through_union`` session property.
+``optimizer.join-reordering-strategy``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ * **Type:** ``string``
+ * **Allowed values:** ``COST_BASED``, ``ELIMINATE_CROSS_JOINS``, ``NONE``
+ * **Default value:** ``ELIMINATE_CROSS_JOINS``
+
+ The join reordering strategy to use. ``NONE`` maintains the order the tables are listed in the
+ query. ``ELIMINATE_CROSS_JOINS`` reorders joins to eliminate cross joins where possible and
+ otherwise maintains the original query order. When reordering joins it also strives to maintain the
+ original table order as much as possible. ``COST_BASED`` enumerates possible orders and uses
+ statistics-based cost estimation to determine the least cost order. If stats are not available or if
+ for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. This can
+ also be specified on a per-query basis using the ``join_reordering_strategy`` session property.
+
+
Regular Expression Function Properties
--------------------------------------
diff --git a/presto-main/etc/config.properties b/presto-main/etc/config.properties
index 44d7c148de4e0..30c32d6f1fcb3 100644
--- a/presto-main/etc/config.properties
+++ b/presto-main/etc/config.properties
@@ -41,5 +41,4 @@ plugin.bundles=\
../presto-postgresql/pom.xml
presto.version=testversion
-distributed-joins-enabled=true
node-scheduler.include-coordinator=true
diff --git a/presto-main/pom.xml b/presto-main/pom.xml
index dd535dce447c9..0aedfbf788d72 100644
--- a/presto-main/pom.xml
+++ b/presto-main/pom.xml
@@ -262,6 +262,11 @@
jgrapht-core
+
+ org.pcollections
+ pcollections
+
+
org.apache.bval
diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 3ed749a98a314..02d717269eefc 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -20,6 +20,8 @@
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.session.PropertyMetadata;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
+import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
@@ -27,6 +29,7 @@
import javax.inject.Inject;
import java.util.List;
+import java.util.stream.Stream;
import static com.facebook.presto.spi.session.PropertyMetadata.booleanSessionProperty;
import static com.facebook.presto.spi.session.PropertyMetadata.integerSessionProperty;
@@ -37,11 +40,12 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.units.DataSize.succinctBytes;
import static java.lang.String.format;
+import static java.util.stream.Collectors.joining;
public final class SystemSessionProperties
{
public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation";
- public static final String DISTRIBUTED_JOIN = "distributed_join";
+ public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type";
public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join";
public static final String HASH_PARTITION_COUNT = "hash_partition_count";
public static final String PREFER_STREAMING_OPERATORS = "prefer_streaming_operators";
@@ -58,7 +62,7 @@ public final class SystemSessionProperties
public static final String DICTIONARY_AGGREGATION = "dictionary_aggregation";
public static final String PLAN_WITH_TABLE_NODE_PARTITIONING = "plan_with_table_node_partitioning";
public static final String COLOCATED_JOIN = "colocated_join";
- public static final String REORDER_JOINS = "reorder_joins";
+ public static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy";
public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node";
public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval";
public static final String OPTIMIZE_METADATA_QUERIES = "optimize_metadata_queries";
@@ -101,11 +105,18 @@ public SystemSessionProperties(
"Compute hash codes for distribution, joins, and aggregations early in query plan",
featuresConfig.isOptimizeHashGeneration(),
false),
- booleanSessionProperty(
- DISTRIBUTED_JOIN,
- "Use a distributed join instead of a broadcast join",
- featuresConfig.isDistributedJoinsEnabled(),
- false),
+ new PropertyMetadata<>(
+ JOIN_DISTRIBUTION_TYPE,
+ format("The join method to use. Options are %s",
+ Stream.of(JoinDistributionType.values())
+ .map(FeaturesConfig.JoinDistributionType::name)
+ .collect(joining(","))),
+ VARCHAR,
+ JoinDistributionType.class,
+ featuresConfig.getJoinDistributionType(),
+ false,
+ value -> JoinDistributionType.valueOf(((String) value).toUpperCase()),
+ JoinDistributionType::name),
booleanSessionProperty(
DISTRIBUTED_INDEX_JOIN,
"Distribute index joins on join keys instead of executing inline",
@@ -240,11 +251,18 @@ public SystemSessionProperties(
"Experimental: Adapt plan to pre-partitioned tables",
true,
false),
- booleanSessionProperty(
- REORDER_JOINS,
- "Experimental: Reorder joins to optimize plan",
- featuresConfig.isJoinReorderingEnabled(),
- false),
+ new PropertyMetadata<>(
+ JOIN_REORDERING_STRATEGY,
+ format("The join reordering strategy to use. Options are %s",
+ Stream.of(JoinReorderingStrategy.values())
+ .map(FeaturesConfig.JoinReorderingStrategy::name)
+ .collect(joining(","))),
+ VARCHAR,
+ JoinReorderingStrategy.class,
+ featuresConfig.getJoinReorderingStrategy(),
+ false,
+ value -> JoinReorderingStrategy.valueOf(((String) value).toUpperCase()),
+ JoinReorderingStrategy::name),
booleanSessionProperty(
FAST_INEQUALITY_JOINS,
"Use faster handling of inequality join if it is possible",
@@ -347,11 +365,6 @@ public static boolean isOptimizeHashGenerationEnabled(Session session)
return session.getSystemProperty(OPTIMIZE_HASH_GENERATION, Boolean.class);
}
- public static boolean isDistributedJoinEnabled(Session session)
- {
- return session.getSystemProperty(DISTRIBUTED_JOIN, Boolean.class);
- }
-
public static boolean isDistributedIndexJoinEnabled(Session session)
{
return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class);
@@ -427,9 +440,9 @@ public static boolean isFastInequalityJoin(Session session)
return session.getSystemProperty(FAST_INEQUALITY_JOINS, Boolean.class);
}
- public static boolean isJoinReorderingEnabled(Session session)
+ public static JoinReorderingStrategy getJoinReorderingStrategy(Session session)
{
- return session.getSystemProperty(REORDER_JOINS, Boolean.class);
+ return session.getSystemProperty(JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.class);
}
public static boolean isColocatedJoinEnabled(Session session)
@@ -515,4 +528,9 @@ public static boolean isUseNewStatsCalculator(Session session)
{
return session.getSystemProperty(USE_NEW_STATS_CALCULATOR, Boolean.class);
}
+
+ public static JoinDistributionType getJoinDistributionType(Session session)
+ {
+ return session.getSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.class);
+ }
}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java
new file mode 100644
index 0000000000000..4d57c0cc7eeac
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class CachingCostCalculator
+ implements CostCalculator
+{
+ private final CostCalculator costCalculator;
+ private final Map costs = new HashMap<>();
+
+ public CachingCostCalculator(CostCalculator costCalculator)
+ {
+ this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+ }
+
+ @Override
+ public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map types)
+ {
+ if (!costs.containsKey(planNode)) {
+ // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback
+ PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(planNode, lookup, session, types);
+ requireNonNull(costs, "computed cost can not be null");
+ checkState(costs.put(planNode, cost) == null, "cost for " + planNode + " already computed");
+ }
+ return costs.get(planNode);
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
new file mode 100644
index 0000000000000..80d5947c06fa9
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class CachingStatsCalculator
+ implements StatsCalculator
+{
+ private final StatsCalculator statsCalculator;
+ private final Map stats = new HashMap<>();
+
+ public CachingStatsCalculator(StatsCalculator statsCalculator)
+ {
+ this.statsCalculator = statsCalculator;
+ }
+
+ @Override
+ public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types)
+ {
+ if (!stats.containsKey(planNode)) {
+ // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback
+ PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types);
+ requireNonNull(stats, "computed stats can not be null");
+ checkState(stats.put(planNode, statsEstimate) == null, "statistics for " + planNode + " already computed");
+ }
+ return stats.get(planNode);
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
index abccad58b89fc..784996cf30bf0 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
@@ -21,8 +21,6 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.inject.BindingAnnotation;
-import javax.annotation.concurrent.ThreadSafe;
-
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.Map;
@@ -36,7 +34,6 @@
* Computes estimated cost of executing given PlanNode.
* Implementation may use lookup to compute needed traits for self/source nodes.
*/
-@ThreadSafe
public interface CostCalculator
{
PlanNodeCostEstimate calculateCost(
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
index 94bc8e96a5718..c8d4b20d88ff1 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
@@ -18,9 +18,10 @@
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNode;
-import java.util.HashMap;
import java.util.Map;
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom;
+import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS;
import static com.google.common.base.Predicates.not;
public class EnsureStatsMatchOutput
@@ -29,16 +30,16 @@ public class EnsureStatsMatchOutput
@Override
public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types)
{
- Map symbolSymbolStats = new HashMap<>();
- estimate.getSymbolsWithKnownStatistics().stream()
- .filter(node.getOutputSymbols()::contains)
- .forEach(symbol -> symbolSymbolStats.put(symbol, estimate.getSymbolStatistics(symbol)));
+ PlanNodeStatsEstimate.Builder builder = buildFrom(estimate);
node.getOutputSymbols().stream()
.filter(not(estimate.getSymbolsWithKnownStatistics()::contains))
- .filter(not(symbolSymbolStats::containsKey))
- .forEach(symbol -> symbolSymbolStats.put(symbol, SymbolStatsEstimate.UNKNOWN_STATS));
+ .forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS));
+
+ estimate.getSymbolsWithKnownStatistics().stream()
+ .filter(not(node.getOutputSymbols()::contains))
+ .forEach(builder::removeSymbolStatistics);
- return PlanNodeStatsEstimate.buildFrom(estimate).setSymbolStatistics(symbolSymbolStats).build();
+ return builder.build();
}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java
new file mode 100644
index 0000000000000..8a95329713d91
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanNodeId;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class JoinNodeCachingStatsCalculator
+ implements StatsCalculator
+{
+ private final StatsCalculator statsCalculator;
+ private final Map stats = new HashMap<>();
+
+ public JoinNodeCachingStatsCalculator(StatsCalculator statsCalculator)
+ {
+ this.statsCalculator = statsCalculator;
+ }
+
+ @Override
+ public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types)
+ {
+ if (!(planNode instanceof JoinNode)) {
+ return statsCalculator.calculateStats(planNode, lookup, session, types);
+ }
+
+ PlanNodeId key = planNode.getId();
+ if (!stats.containsKey(key)) {
+ // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback
+ PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types);
+ requireNonNull(stats, "computed stats can not be null");
+ checkState(stats.put(key, statsEstimate) == null, "statistics for " + planNode + " already computed");
+ }
+ return stats.get(key);
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java
index 3d59f90dab49c..cf081a6926c5f 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java
@@ -14,14 +14,13 @@
package com.facebook.presto.cost;
import com.facebook.presto.sql.planner.Symbol;
-import com.google.common.collect.ImmutableMap;
+import org.pcollections.HashTreePMap;
+import org.pcollections.PMap;
-import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
-import java.util.stream.Collectors;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
@@ -34,13 +33,13 @@ public class PlanNodeStatsEstimate
public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10;
private final double outputRowCount;
- private final Map symbolStatistics;
+ private final PMap symbolStatistics;
- private PlanNodeStatsEstimate(double outputRowCount, Map symbolStatistics)
+ private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics)
{
checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative");
this.outputRowCount = outputRowCount;
- this.symbolStatistics = ImmutableMap.copyOf(symbolStatistics);
+ this.symbolStatistics = symbolStatistics;
}
/**
@@ -83,26 +82,15 @@ public PlanNodeStatsEstimate mapOutputRowCount(Function mappingF
public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction)
{
return buildFrom(this)
- .setSymbolStatistics(symbolStatistics.entrySet().stream()
- .collect(Collectors.toMap(
- Map.Entry::getKey,
- e -> {
- if (e.getKey().equals(symbol)) {
- return mappingFunction.apply(e.getValue());
- }
- return e.getValue();
- })))
+ .addSymbolStatistics(symbol, mappingFunction.apply(symbolStatistics.get(symbol)))
.build();
}
public PlanNodeStatsEstimate add(PlanNodeStatsEstimate other)
{
// TODO this is broken (it does not operate on symbol stats at all). Remove or fix
- ImmutableMap.Builder symbolsStatsBuilder = ImmutableMap.builder();
- symbolsStatsBuilder.putAll(symbolStatistics).putAll(other.symbolStatistics); // This may not count all information
-
- PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
- return statsBuilder.setSymbolStatistics(symbolsStatsBuilder.build())
+ return buildFrom(this)
+ .addSymbolStatistics(other.symbolStatistics)
.setOutputRowCount(getOutputRowCount() + other.getOutputRowCount())
.build();
}
@@ -153,14 +141,24 @@ public static Builder builder()
public static Builder buildFrom(PlanNodeStatsEstimate other)
{
- return builder().setOutputRowCount(other.getOutputRowCount())
- .setSymbolStatistics(other.symbolStatistics);
+ return new Builder(other.getOutputRowCount(), other.symbolStatistics);
}
public static final class Builder
{
- private double outputRowCount = NaN;
- private Map symbolStatistics = new HashMap<>();
+ private double outputRowCount;
+ private PMap symbolStatistics;
+
+ public Builder()
+ {
+ this(NaN, HashTreePMap.empty());
+ }
+
+ private Builder(double outputRowCount, PMap symbolStatistics)
+ {
+ this.outputRowCount = outputRowCount;
+ this.symbolStatistics = symbolStatistics;
+ }
public Builder setOutputRowCount(double outputRowCount)
{
@@ -168,15 +166,21 @@ public Builder setOutputRowCount(double outputRowCount)
return this;
}
- public Builder setSymbolStatistics(Map symbolStatistics)
+ public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics)
{
- this.symbolStatistics = new HashMap<>(symbolStatistics);
+ symbolStatistics = symbolStatistics.plus(symbol, statistics);
return this;
}
- public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics)
+ public Builder addSymbolStatistics(Map symbolStatistics)
+ {
+ this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics);
+ return this;
+ }
+
+ public Builder removeSymbolStatistics(Symbol symbol)
{
- this.symbolStatistics.put(symbol, statistics);
+ symbolStatistics = symbolStatistics.minus(symbol);
return this;
}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java
index 234bc6729bb29..713e9c757a308 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java
@@ -77,7 +77,7 @@ public Optional calculate(PlanNode node, Lookup lookup, S
return Optional.of(PlanNodeStatsEstimate.builder()
.setOutputRowCount(tableStatistics.getRowCount().getValue())
- .setSymbolStatistics(outputSymbolStats)
+ .addSymbolStatistics(outputSymbolStats)
.build());
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java
index 4ecc694ec6a00..012ea3a6b51db 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java
@@ -26,7 +26,6 @@
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.SymbolReference;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
@@ -103,7 +102,17 @@ public static Expression binaryExpression(LogicalBinaryExpression.Type type, Col
{
requireNonNull(type, "type is null");
requireNonNull(expressions, "expressions is null");
- Preconditions.checkArgument(!expressions.isEmpty(), "expressions is empty");
+
+ if (expressions.isEmpty()) {
+ switch (type) {
+ case AND:
+ return TRUE_LITERAL;
+ case OR:
+ return FALSE_LITERAL;
+ default:
+ throw new IllegalArgumentException("Unsupported LogicalBinaryExpression type");
+ }
+ }
// Build balanced tree for efficient recursive processing that
// preserves the evaluation order of the input expressions.
@@ -309,7 +318,8 @@ public static Expression normalize(Expression expression)
public static Expression rewriteIdentifiersToSymbolReferences(Expression expression)
{
- return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() {
+ return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter()
+ {
@Override
public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter treeRewriter)
{
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
index 4c7b3a0d9a99b..b5b3205dd0683 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
@@ -27,8 +27,11 @@
import java.nio.file.Paths;
import java.util.List;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI;
import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MINUTES;
@DefunctConfig({
@@ -42,11 +45,11 @@ public class FeaturesConfig
private double cpuCostWeight = 0.75;
private double memoryCostWeight = 0;
private double networkCostWeight = 0.25;
+
private boolean distributedIndexJoinsEnabled;
- private boolean distributedJoinsEnabled = true;
private boolean colocatedJoinsEnabled;
private boolean fastInequalityJoins = true;
- private boolean reorderJoins = true;
+ private JoinReorderingStrategy joinReorderingStrategy = ELIMINATE_CROSS_JOINS;
private boolean redistributeWrites = true;
private boolean optimizeMetadataQueries;
private boolean optimizeHashGeneration = true;
@@ -59,6 +62,7 @@ public class FeaturesConfig
private boolean legacyMapSubscript;
private boolean newMapBlock = true;
private boolean optimizeMixedDistinctAggregations;
+ private JoinDistributionType joinDistributionType = REPARTITIONED;
private boolean dictionaryAggregation;
private boolean resourceGroups;
@@ -77,6 +81,30 @@ public class FeaturesConfig
private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata
+ public enum JoinReorderingStrategy
+ {
+ ELIMINATE_CROSS_JOINS,
+ COST_BASED,
+ NONE
+ }
+
+ public enum JoinDistributionType
+ {
+ AUTOMATIC,
+ REPLICATED,
+ REPARTITIONED;
+
+ public boolean canRepartition()
+ {
+ return this == REPARTITIONED || this == AUTOMATIC;
+ }
+
+ public boolean canReplicate()
+ {
+ return this == REPLICATED || this == AUTOMATIC;
+ }
+ }
+
public double getCpuCostWeight()
{
return cpuCostWeight;
@@ -137,11 +165,6 @@ public FeaturesConfig setDistributedIndexJoinsEnabled(boolean distributedIndexJo
return this;
}
- public boolean isDistributedJoinsEnabled()
- {
- return distributedJoinsEnabled;
- }
-
@Config("deprecated.legacy-array-agg")
public FeaturesConfig setLegacyArrayAgg(boolean legacyArrayAgg)
{
@@ -190,13 +213,6 @@ public boolean isNewMapBlock()
return newMapBlock;
}
- @Config("distributed-joins-enabled")
- public FeaturesConfig setDistributedJoinsEnabled(boolean distributedJoinsEnabled)
- {
- this.distributedJoinsEnabled = distributedJoinsEnabled;
- return this;
- }
-
public boolean isColocatedJoinsEnabled()
{
return colocatedJoinsEnabled;
@@ -223,16 +239,16 @@ public boolean isFastInequalityJoins()
return fastInequalityJoins;
}
- public boolean isJoinReorderingEnabled()
+ public JoinReorderingStrategy getJoinReorderingStrategy()
{
- return reorderJoins;
+ return joinReorderingStrategy;
}
- @Config("reorder-joins")
- @ConfigDescription("Experimental: Reorder joins to optimize plan")
- public FeaturesConfig setJoinReorderingEnabled(boolean reorderJoins)
+ @Config("optimizer.join-reordering-strategy")
+ @ConfigDescription("The strategy to use for reordering joins")
+ public FeaturesConfig setJoinReorderingStrategy(JoinReorderingStrategy joinReorderingStrategy)
{
- this.reorderJoins = reorderJoins;
+ this.joinReorderingStrategy = joinReorderingStrategy;
return this;
}
@@ -490,4 +506,16 @@ public boolean isUseNewStatsCalculator()
{
return useNewStatsCalculator;
}
+
+ @Config("join-distribution-type")
+ public FeaturesConfig setJoinDistributionType(JoinDistributionType joinDistributionType)
+ {
+ this.joinDistributionType = requireNonNull(joinDistributionType, "joinDistributionType is null");
+ return this;
+ }
+
+ public JoinDistributionType getJoinDistributionType()
+ {
+ return joinDistributionType;
+ }
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index a1f4c44c28340..c5d07fec5a087 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -55,6 +55,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
+import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant;
import com.facebook.presto.sql.planner.iterative.rule.SingleMarkDistinctToGroupBy;
import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications;
@@ -322,13 +323,32 @@ public PlanOptimizers(
ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins()), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again
ImmutableSet.of(new EliminateCrossJoins())
),
+
new PredicatePushDown(metadata, sqlParser),
new IterativeOptimizer(
stats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new PushDownTableConstraints(metadata, sqlParser))),
- projectionPushDown);
+ projectionPushDown,
+ new PruneUnreferencedOutputs(),
+ new IterativeOptimizer(
+ stats,
+ statsCalculator,
+ estimatedExchangesCostCalculator,
+ ImmutableSet.of(new RemoveRedundantIdentityProjections())
+ ),
+
+ // Because ReorderJoins runs only once,
+ // PredicatePushDown, PruneUnreferenedOutputpus and RemoveRedundantIdentityProjections
+ // need to run beforehand in order to produce an optimal join order
+ // It also needs to run after EliminateCrossJoins so that its chosen order doesn't get undone.
+ new IterativeOptimizer(
+ stats,
+ statsCalculator,
+ estimatedExchangesCostCalculator,
+ ImmutableSet.of(new ReorderJoins(costComparator, statsCalculator, costCalculator))
+ ));
if (featuresConfig.isOptimizeSingleDistinct()) {
builder.add(
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java
index 2252155b6697f..d76bea4dd65bf 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java
@@ -15,6 +15,8 @@
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
+import com.facebook.presto.cost.CachingCostCalculator;
+import com.facebook.presto.cost.CachingStatsCalculator;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.matching.MatchingEngine;
@@ -81,7 +83,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types
}
Memo memo = new Memo(idAllocator, plan);
- Lookup lookup = new MemoBasedLookup(memo, statsCalculator, costCalculator);
+ Lookup lookup = Lookup.from(memo::resolve, new CachingStatsCalculator(statsCalculator), new CachingCostCalculator(costCalculator));
Duration timeout = SystemSessionProperties.getOptimizerTimeout(session);
exploreGroup(memo.getRootGroup(), new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session));
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java
index 65b2304952fcf..d8e5146690f1c 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java
@@ -102,7 +102,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types)
{
- return costCalculator.calculateCumulativeCost(node, this, session, types);
+ return costCalculator.calculateCumulativeCost(resolve(node), this, session, types);
}
};
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java
deleted file mode 100644
index b718fdfc5dedd..0000000000000
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.facebook.presto.sql.planner.iterative;
-
-import com.facebook.presto.Session;
-import com.facebook.presto.cost.CostCalculator;
-import com.facebook.presto.cost.PlanNodeCostEstimate;
-import com.facebook.presto.cost.PlanNodeStatsEstimate;
-import com.facebook.presto.cost.StatsCalculator;
-import com.facebook.presto.spi.type.Type;
-import com.facebook.presto.sql.planner.Symbol;
-import com.facebook.presto.sql.planner.plan.PlanNode;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static com.google.common.base.Preconditions.checkState;
-import static java.util.Objects.requireNonNull;
-
-public class MemoBasedLookup
- implements Lookup
-{
- private final Memo memo;
- private final Map stats = new HashMap<>();
- private final Map costs = new HashMap<>();
- private final StatsCalculator statsCalculator;
- private final CostCalculator costCalculator;
-
- public MemoBasedLookup(Memo memo, StatsCalculator statsCalculator, CostCalculator costCalculator)
- {
- this.memo = requireNonNull(memo, "memo can not be null");
- this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
- this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
- }
-
- @Override
- public PlanNode resolve(PlanNode node)
- {
- if (node instanceof GroupReference) {
- return memo.getNode(((GroupReference) node).getGroupId());
- }
- return node;
- }
-
- // todo[LO] maybe lookup passed to stats/cost calculator should be constrained so only
- // methods for obtaining traits and only for self and sources would be allowed?
-
- @Override
- public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types)
- {
- PlanNode key = resolve(planNode);
- if (!stats.containsKey(key)) {
- // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback
- PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(key, this, session, types);
- requireNonNull(stats, "computed stats can not be null");
- checkState(stats.put(key, statsEstimate) == null, "statistics for " + key + " already computed");
- }
- return stats.get(key);
- }
-
- @Override
- public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types)
- {
- PlanNode key = resolve(planNode);
- if (!costs.containsKey(key)) {
- // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback
- PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(key, this, session, types);
- requireNonNull(costs, "computed cost can not be null");
- checkState(costs.put(key, cost) == null, "cost for " + key + " already computed");
- }
- return costs.get(key);
- }
-}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java
index 2452f75ae7a40..a4e7907a8fff0 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java
@@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
-import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
@@ -41,6 +40,9 @@
import java.util.PriorityQueue;
import java.util.Set;
+import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
@@ -65,7 +67,8 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
return Optional.empty();
}
- if (!SystemSessionProperties.isJoinReorderingEnabled(session)) {
+ // we run this for cost_based reordering also for cases when some of the tables do not have statistics
+ if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && getJoinReorderingStrategy(session) != COST_BASED) {
return Optional.empty();
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java
new file mode 100644
index 0000000000000..ea394fd43b8ef
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.tree.Expression;
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.facebook.presto.sql.ExpressionUtils.and;
+import static com.facebook.presto.sql.planner.DeterminismEvaluator.isDeterministic;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
+import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class represents a set of inner joins that can be executed in any order.
+ */
+class MultiJoinNode
+{
+ private static final int JOIN_LIMIT = 10;
+
+ private final List sources;
+ private final Expression filter;
+ private final List outputSymbols;
+
+ public MultiJoinNode(List sources, Expression filter, List outputSymbols)
+ {
+ this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null"));
+ this.filter = requireNonNull(filter, "filter is null");
+ this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null"));
+
+ List inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList());
+ checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols");
+ }
+
+ public Expression getFilter()
+ {
+ return filter;
+ }
+
+ public List getSources()
+ {
+ return sources;
+ }
+
+ public List getOutputSymbols()
+ {
+ return outputSymbols;
+ }
+
+ static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup)
+ {
+ return new MultiJoinNodeBuilder(joinNode, lookup).toMultiJoinNode();
+ }
+
+ private static class MultiJoinNodeBuilder
+ {
+ private final List sources = new ArrayList<>();
+ private final List filters = new ArrayList<>();
+ private final List outputSymbols;
+ private final Lookup lookup;
+
+ MultiJoinNodeBuilder(JoinNode node, Lookup lookup)
+ {
+ requireNonNull(node, "node is null");
+ checkState(node.getType() == INNER, "join type must be INNER");
+ this.outputSymbols = node.getOutputSymbols();
+ this.lookup = requireNonNull(lookup, "lookup is null");
+ flattenNode(node);
+ }
+
+ private void flattenNode(PlanNode node)
+ {
+ PlanNode resolved = lookup.resolve(node);
+ if (resolved instanceof JoinNode && sources.size() < JOIN_LIMIT) {
+ JoinNode joinNode = (JoinNode) resolved;
+ if (joinNode.getType() == INNER && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL))) {
+ flattenNode(joinNode.getLeft());
+ flattenNode(joinNode.getRight());
+ joinNode.getCriteria().stream()
+ .map(JoinNode.EquiJoinClause::toExpression)
+ .forEach(filters::add);
+ joinNode.getFilter().ifPresent(filters::add);
+ return;
+ }
+ }
+ sources.add(node);
+ }
+
+ MultiJoinNode toMultiJoinNode()
+ {
+ return new MultiJoinNode(sources, and(filters), outputSymbols);
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
new file mode 100644
index 0000000000000..f6dee21226964
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
@@ -0,0 +1,416 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.cost.CachingCostCalculator;
+import com.facebook.presto.cost.CachingStatsCalculator;
+import com.facebook.presto.cost.CostCalculator;
+import com.facebook.presto.cost.CostComparator;
+import com.facebook.presto.cost.JoinNodeCachingStatsCalculator;
+import com.facebook.presto.cost.PlanNodeCostEstimate;
+import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.planner.EqualityInference;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.SymbolsExtractor;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.SymbolReference;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Ordering;
+import com.google.common.collect.Sets;
+import io.airlift.log.Logger;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
+import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.INFINITE_COST;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST;
+import static com.facebook.presto.sql.ExpressionUtils.and;
+import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED;
+import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference;
+import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode;
+import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT;
+import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT;
+import static com.facebook.presto.sql.planner.plan.Assignments.identity;
+import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
+import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
+import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Predicates.in;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.util.Objects.requireNonNull;
+import static java.util.stream.StreamSupport.stream;
+
+public class ReorderJoins
+ implements Rule
+{
+ private static final Logger log = Logger.get(ReorderJoins.class);
+
+ private final CostComparator costComparator;
+ private final StatsCalculator statsCalculator;
+ private final CostCalculator costCalculator;
+
+ public ReorderJoins(CostComparator costComparator, StatsCalculator statsCalculator, CostCalculator costCalculator)
+ {
+ this.costComparator = requireNonNull(costComparator, "costComparator is null");
+ this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
+ this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+ }
+
+ @Override
+ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+ {
+ if (!(node instanceof JoinNode) || getJoinReorderingStrategy(session) != COST_BASED) {
+ return Optional.empty();
+ }
+
+ JoinNode joinNode = (JoinNode) node;
+ // We check that join distribution type is absent because we only want to do this transformation once (reordered joins will have distribution type already set).
+ // We check determinisitic filters because we can't reorder joins with non-deterministic filters
+ if (!(joinNode.getType() == INNER) || joinNode.getDistributionType().isPresent()) {
+ return Optional.empty();
+ }
+
+ MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, lookup);
+ if (multiJoinNode.getSources().size() < 2) {
+ return Optional.empty();
+ }
+
+ Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(statsCalculator)), new CachingCostCalculator(costCalculator));
+ JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols());
+ return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? Optional.empty() : result.getPlanNode();
+ }
+
+ @VisibleForTesting
+ static class JoinEnumerator
+ {
+ private final Map, JoinEnumerationResult> memo = new HashMap<>();
+ private final PlanNodeIdAllocator idAllocator;
+ private final Session session;
+ private final Ordering resultOrdering;
+ private final EqualityInference allInference;
+ private final Expression allFilter;
+ private final SymbolAllocator symbolAllocator;
+ private final Lookup lookup;
+
+ @VisibleForTesting
+ JoinEnumerator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator)
+ {
+ requireNonNull(idAllocator, "idAllocator is null");
+ requireNonNull(symbolAllocator, "symbolAllocator is null");
+ requireNonNull(session, "session is null");
+ requireNonNull(lookup, "lookup is null");
+ requireNonNull(filter, "filter is null");
+ requireNonNull(costComparator, "costComparator is null");
+ this.idAllocator = idAllocator;
+ this.symbolAllocator = symbolAllocator;
+ this.session = session;
+ this.lookup = lookup;
+ this.resultOrdering = getResultOrdering(costComparator, session);
+ this.allInference = createEqualityInference(filter);
+ this.allFilter = filter;
+ }
+
+ private static Ordering getResultOrdering(CostComparator costComparator, Session session)
+ {
+ return new Ordering()
+ {
+ @Override
+ public int compare(JoinEnumerationResult result1, JoinEnumerationResult result2)
+ {
+ return costComparator.compare(session, result1.cost, result2.cost);
+ }
+ };
+ }
+
+ private JoinEnumerationResult chooseJoinOrder(List sources, List outputSymbols)
+ {
+ Set multiJoinKey = ImmutableSet.copyOf(sources);
+ JoinEnumerationResult bestResult = memo.get(multiJoinKey);
+ if (bestResult == null) {
+ checkState(sources.size() > 1, "sources size is less than or equal to one");
+ ImmutableList.Builder resultBuilder = ImmutableList.builder();
+ Set> partitions = generatePartitions(sources.size()).collect(toImmutableSet());
+ for (Set partition : partitions) {
+ JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputSymbols, partition);
+ if (result.cost.isUnknown()) {
+ memo.put(multiJoinKey, result);
+ return result;
+ }
+ if (!result.cost.equals(INFINITE_COST)) {
+ resultBuilder.add(result);
+ }
+ }
+
+ List results = resultBuilder.build();
+ if (results.isEmpty()) {
+ memo.put(multiJoinKey, INFINITE_COST_RESULT);
+ return INFINITE_COST_RESULT;
+ }
+
+ bestResult = resultOrdering.min(resultBuilder.build());
+ memo.put(multiJoinKey, bestResult);
+ }
+ if (bestResult.planNode.isPresent()) {
+ log.debug("Least cost join was: " + bestResult.planNode.get().toString());
+ }
+ return bestResult;
+ }
+
+ /**
+ * This method generates all the ways of dividing totalNodes into two sets
+ * each containing at least one node. It will generate one set for each
+ * possible partitioning. The other partition is implied in the absent values.
+ * In order not to generate the inverse of any set, we always include the 0th
+ * node in our sets.
+ *
+ * @param totalNodes
+ * @return A set of sets each of which defines a partitioning of totalNodes
+ */
+ @VisibleForTesting
+ static Stream> generatePartitions(int totalNodes)
+ {
+ checkArgument(totalNodes >= 2, "totalNodes must be greater than or equal to 2");
+ Set numbers = IntStream.range(0, totalNodes)
+ .boxed()
+ .collect(toImmutableSet());
+ return Sets.powerSet(numbers).stream()
+ .filter(subSet -> subSet.contains(0))
+ .filter(subSet -> subSet.size() < numbers.size());
+ }
+
+ JoinEnumerationResult createJoinAccordingToPartitioning(List sources, List outputSymbols, Set partitioning)
+ {
+ Set leftSources = partitioning.stream()
+ .map(sources::get)
+ .collect(toImmutableSet());
+ Set rightSources = Sets.difference(ImmutableSet.copyOf(sources), ImmutableSet.copyOf(leftSources));
+ return createJoin(leftSources, rightSources, outputSymbols);
+ }
+
+ private JoinEnumerationResult createJoin(Set leftSources, Set rightSources, List outputSymbols)
+ {
+ Set leftSymbols = leftSources.stream()
+ .flatMap(node -> node.getOutputSymbols().stream())
+ .collect(toImmutableSet());
+ Set rightSymbols = rightSources.stream()
+ .flatMap(node -> node.getOutputSymbols().stream())
+ .collect(toImmutableSet());
+ ImmutableList.Builder joinPredicatesBuilder = ImmutableList.builder();
+
+ // add join conjucts that were not used for inference
+ stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false)
+ .map(conjuct -> allInference.rewriteExpression(conjuct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)))
+ .filter(Objects::nonNull)
+ // filter expressions that contain only left or right symbols
+ .filter(conjuct -> allInference.rewriteExpression(conjuct, leftSymbols::contains) == null)
+ .filter(conjuct -> allInference.rewriteExpression(conjuct, rightSymbols::contains) == null)
+ .forEach(joinPredicatesBuilder::add);
+
+ // create equality inference on available symbols
+ // TODO: make generateEqualitiesPartitionedBy take left and right scope
+ List joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities();
+ EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[joinEqualities.size()]));
+ joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities());
+
+ List joinPredicates = joinPredicatesBuilder.build();
+ List joinConditions = joinPredicates.stream()
+ .filter(JoinEnumerator::isJoinEqualityCondition)
+ .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols))
+ .collect(toImmutableList());
+ if (joinConditions.isEmpty()) {
+ return INFINITE_COST_RESULT;
+ }
+ List joinFilters = joinPredicates.stream()
+ .filter(predicate -> !isJoinEqualityCondition(predicate))
+ .collect(toImmutableList());
+
+ Set requiredJoinSymbols = ImmutableSet.builder()
+ .addAll(outputSymbols)
+ .addAll(SymbolsExtractor.extractUnique(joinPredicates))
+ .build();
+
+ JoinEnumerationResult leftResult = getJoinSource(
+ idAllocator,
+ ImmutableList.copyOf(leftSources),
+ requiredJoinSymbols.stream().filter(leftSymbols::contains).collect(toImmutableList()));
+ if (leftResult.cost.isUnknown()) {
+ return UNKNOWN_COST_RESULT;
+ }
+ if (leftResult.cost.equals(INFINITE_COST)) {
+ return INFINITE_COST_RESULT;
+ }
+ PlanNode left = leftResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present"));
+ JoinEnumerationResult rightResult = getJoinSource(
+ idAllocator,
+ ImmutableList.copyOf(rightSources),
+ requiredJoinSymbols.stream()
+ .filter(rightSymbols::contains)
+ .collect(toImmutableList()));
+ if (rightResult.cost.isUnknown()) {
+ return UNKNOWN_COST_RESULT;
+ }
+ if (rightResult.cost.equals(INFINITE_COST)) {
+ return INFINITE_COST_RESULT;
+ }
+ PlanNode right = rightResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present"));
+
+ // sort output symbols so that the left input symbols are first
+ List sortedOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream())
+ .filter(outputSymbols::contains)
+ .collect(toImmutableList());
+
+ // Cross joins can't filter symbols as part of the join
+ // If we're doing a cross join, use all output symbols from the inputs and add a project node
+ // on top
+ List joinOutputSymbols = sortedOutputSymbols;
+ if (joinConditions.isEmpty() && joinFilters.isEmpty()) {
+ joinOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream())
+ .collect(toImmutableList());
+ }
+
+ JoinEnumerationResult result = setJoinNodeProperties(new JoinNode(
+ idAllocator.getNextId(),
+ INNER,
+ left,
+ right,
+ joinConditions,
+ joinOutputSymbols,
+ joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty()));
+
+ if (!joinOutputSymbols.equals(sortedOutputSymbols)) {
+ PlanNode resultNode = new ProjectNode(idAllocator.getNextId(), result.planNode.get(), identity(sortedOutputSymbols));
+ result = new JoinEnumerationResult(lookup.getCumulativeCost(resultNode, session, symbolAllocator.getTypes()), Optional.of(resultNode));
+ }
+
+ return result;
+ }
+
+ private JoinEnumerationResult getJoinSource(PlanNodeIdAllocator idAllocator, List nodes, List outputSymbols)
+ {
+ PlanNode planNode;
+ if (nodes.size() == 1) {
+ planNode = getOnlyElement(nodes);
+ ImmutableList.Builder predicates = ImmutableList.builder();
+ predicates.addAll(allInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities());
+ stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false)
+ .map(conjuct -> allInference.rewriteExpression(conjuct, outputSymbols::contains))
+ .filter(Objects::nonNull)
+ .forEach(predicates::add);
+ Expression filter = combineConjuncts(predicates.build());
+ if (!(TRUE_LITERAL).equals(filter)) {
+ planNode = new FilterNode(idAllocator.getNextId(), planNode, filter);
+ }
+ return new JoinEnumerationResult(lookup.getCumulativeCost(planNode, session, symbolAllocator.getTypes()), Optional.of(planNode));
+ }
+ return chooseJoinOrder(nodes, outputSymbols);
+ }
+
+ private static boolean isJoinEqualityCondition(Expression expression)
+ {
+ return expression instanceof ComparisonExpression
+ && ((ComparisonExpression) expression).getType() == EQUAL
+ && ((ComparisonExpression) expression).getLeft() instanceof SymbolReference
+ && ((ComparisonExpression) expression).getRight() instanceof SymbolReference;
+ }
+
+ private static JoinNode.EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftSymbols)
+ {
+ Symbol leftSymbol = Symbol.from(equality.getLeft());
+ Symbol rightSymbol = Symbol.from(equality.getRight());
+ JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(leftSymbol, rightSymbol);
+ return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip();
+ }
+
+ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
+ {
+ List possibleJoinNodes = new ArrayList<>();
+ FeaturesConfig.JoinDistributionType joinDistributionType = getJoinDistributionType(session);
+ if (joinDistributionType.canRepartition() && !joinNode.isCrossJoin()) {
+ JoinNode node = joinNode.withDistributionType(PARTITIONED);
+ possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ node = node.flipChildren();
+ possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ }
+ if (joinDistributionType.canReplicate()) {
+ JoinNode node = joinNode.withDistributionType(REPLICATED);
+ possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ node = node.flipChildren();
+ possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ }
+ if (possibleJoinNodes.stream().anyMatch(result -> result.cost.isUnknown())) {
+ return UNKNOWN_COST_RESULT;
+ }
+ return resultOrdering.min(possibleJoinNodes);
+ }
+ }
+
+ @VisibleForTesting
+ static class JoinEnumerationResult
+ {
+ static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(UNKNOWN_COST, Optional.empty());
+ static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(INFINITE_COST, Optional.empty());
+
+ private final Optional planNode;
+ private final PlanNodeCostEstimate cost;
+
+ private JoinEnumerationResult(PlanNodeCostEstimate cost, Optional planNode)
+ {
+ this.cost = requireNonNull(cost);
+ this.planNode = requireNonNull(planNode);
+ checkArgument(cost.isUnknown() || cost.equals(INFINITE_COST) || planNode.isPresent(), "planNode must be present if cost is known");
+ }
+
+ public Optional getPlanNode()
+ {
+ return planNode;
+ }
+
+ public PlanNodeCostEstimate getCost()
+ {
+ return cost;
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java
index cb912fbbb0925..424a8f4f31354 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java
@@ -27,7 +27,7 @@
import java.util.Map;
import java.util.Optional;
-import static com.facebook.presto.SystemSessionProperties.isDistributedJoinEnabled;
+import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
@@ -111,9 +111,12 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context)
private JoinNode.DistributionType getTargetJoinDistributionType(JoinNode node)
{
+ if (node.getDistributionType().isPresent()) {
+ return node.getDistributionType().get();
+ }
// The implementation of full outer join only works if the data is hash partitioned. See LookupJoinOperators#buildSideOuterJoinUnvisitedPositions
JoinNode.Type type = node.getType();
- if (type == RIGHT || type == FULL || (isDistributedJoinEnabled(session) && !mustBroadcastJoin(node))) {
+ if (type == RIGHT || type == FULL || (isRepartitionedJoinEnabled(session) && !mustBroadcastJoin(node))) {
return JoinNode.DistributionType.PARTITIONED;
}
@@ -132,11 +135,16 @@ private static boolean isCrossJoin(JoinNode node)
private SemiJoinNode.DistributionType getTargetSemiJoinDistributionType(boolean isDeleteQuery)
{
- if (isDistributedJoinEnabled(session) && !isDeleteQuery) {
+ if (isRepartitionedJoinEnabled(session) && !isDeleteQuery) {
return SemiJoinNode.DistributionType.PARTITIONED;
}
return SemiJoinNode.DistributionType.REPLICATED;
}
+
+ private static boolean isRepartitionedJoinEnabled(Session session)
+ {
+ return getJoinDistributionType(session).canRepartition();
+ }
}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java
index 89a4eccfc674c..57f29e4dad791 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java
@@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
-import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
@@ -28,6 +27,9 @@
import java.util.Map;
import java.util.Objects;
+import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.buildJoinTree;
import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder;
import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder;
@@ -47,7 +49,7 @@ public PlanNode optimize(
SymbolAllocator symbolAllocator,
PlanNodeIdAllocator idAllocator)
{
- if (!SystemSessionProperties.isJoinReorderingEnabled(session)) {
+ if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && getJoinReorderingStrategy(session) != COST_BASED) {
return plan;
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java
index 3ecc8862ae35b..0eb4615f145fc 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java
@@ -28,10 +28,14 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
+import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.facebook.presto.sql.planner.SortExpressionExtractor.extractSortExpression;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
@@ -100,6 +104,58 @@ public JoinNode(@JsonProperty("id") PlanNodeId id,
checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin");
}
+ public JoinNode flipChildren()
+ {
+ return new JoinNode(
+ getId(),
+ flipType(type),
+ right,
+ left,
+ flipJoinCriteria(criteria),
+ flipOutputSymbols(getOutputSymbols(), left, right),
+ filter,
+ rightHashSymbol,
+ leftHashSymbol,
+ distributionType);
+ }
+
+ private static Type flipType(Type type)
+ {
+ switch (type) {
+ case INNER:
+ return INNER;
+ case FULL:
+ return FULL;
+ case LEFT:
+ return RIGHT;
+ case RIGHT:
+ return LEFT;
+ default:
+ throw new IllegalStateException("No inverse defined for join type: " + type);
+ }
+ }
+
+ private static List flipJoinCriteria(List joinCriteria)
+ {
+ return joinCriteria.stream()
+ .map(EquiJoinClause::flip)
+ .collect(toImmutableList());
+ }
+
+ private static List flipOutputSymbols(List outputSymbols, PlanNode left, PlanNode right)
+ {
+ List leftSymbols = outputSymbols.stream()
+ .filter(symbol -> left.getOutputSymbols().contains(symbol))
+ .collect(Collectors.toList());
+ List rightSymbols = outputSymbols.stream()
+ .filter(symbol -> right.getOutputSymbols().contains(symbol))
+ .collect(Collectors.toList());
+ return ImmutableList.builder()
+ .addAll(rightSymbols)
+ .addAll(leftSymbols)
+ .build();
+ }
+
public enum DistributionType
{
PARTITIONED,
@@ -230,6 +286,11 @@ public PlanNode replaceChildren(List newChildren)
return new JoinNode(getId(), type, newLeft, newRight, criteria, newOutputSymbols, filter, leftHashSymbol, rightHashSymbol, distributionType);
}
+ public JoinNode withDistributionType(DistributionType distributionType)
+ {
+ return new JoinNode(getId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.of(distributionType));
+ }
+
public boolean isCrossJoin()
{
return criteria.isEmpty() && !filter.isPresent() && type == INNER;
@@ -264,6 +325,11 @@ public ComparisonExpression toExpression()
return new ComparisonExpression(ComparisonExpressionType.EQUAL, left.toSymbolReference(), right.toSymbolReference());
}
+ public EquiJoinClause flip()
+ {
+ return new EquiJoinClause(right, left);
+ }
+
@Override
public boolean equals(Object obj)
{
diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
index d938c37b8e7fa..268336a2e88b6 100644
--- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
+++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
@@ -251,15 +251,16 @@ public LocalQueryRunner(Session defaultSession)
new FeaturesConfig()
.setOptimizeMixedDistinctAggregations(true)
.setIterativeOptimizerEnabled(true),
- false);
+ false,
+ 1);
}
public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig)
{
- this(defaultSession, featuresConfig, false);
+ this(defaultSession, featuresConfig, false, 1);
}
- private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction)
+ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction, int nodeCountForStats)
{
requireNonNull(defaultSession, "defaultSession is null");
checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction");
@@ -379,14 +380,19 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
new CoefficientBasedStatsCalculator(metadata),
ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata)));
this.costCalculator = new CostCalculatorUsingExchanges(getNodeCount());
- this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, getNodeCount());
+ this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, nodeCountForStats);
this.lookup = new StatelessLookup(statsCalculator, costCalculator);
}
public static LocalQueryRunner queryRunnerWithInitialTransaction(Session defaultSession)
{
checkArgument(!defaultSession.getTransactionId().isPresent(), "Already in transaction!");
- return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true);
+ return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true, 1);
+ }
+
+ public static LocalQueryRunner queryRunnerWithFakeNodeCountForStats(Session defaultSession, int nodeCount)
+ {
+ return new LocalQueryRunner(defaultSession, new FeaturesConfig(), false, nodeCount);
}
@Override
diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java
new file mode 100644
index 0000000000000..0bf9e06ac13d8
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.testing;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.cost.PlanNodeStatsEstimate;
+import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanNodeId;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map;
+
+import static java.util.Objects.requireNonNull;
+
+public class TestingStatsCalculator
+ implements StatsCalculator
+{
+ private final StatsCalculator statsCalculator;
+ private final Map stats;
+
+ public TestingStatsCalculator(StatsCalculator statsCalculator, Map stats)
+ {
+ this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
+ this.stats = ImmutableMap.copyOf(requireNonNull(stats, "stats is null"));
+ }
+
+ @Override
+ public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types)
+ {
+ if (stats.containsKey(planNode.getId())) {
+ return stats.get(planNode.getId());
+ }
+
+ return statsCalculator.calculateStats(planNode, lookup, session, types);
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java
index 7f8b8e0a6c32a..f516618b1563e 100644
--- a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java
+++ b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java
@@ -27,8 +27,8 @@
import java.util.Locale;
-import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN;
import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
+import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO;
@@ -59,7 +59,7 @@ public void testCreateSession()
.put(PRESTO_TIME_ZONE, "Asia/Taipei")
.put(PRESTO_CLIENT_INFO, "client-info")
.put(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB")
- .put(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43")
+ .put(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43")
.put(PRESTO_PREPARED_STATEMENT, "query1=select * from foo,query2=select * from bar")
.build(),
"testRemote");
@@ -82,7 +82,7 @@ public void testCreateSession()
assertEquals(session.getClientInfo().get(), "client-info");
assertEquals(session.getSystemProperties(), ImmutableMap.builder()
.put(QUERY_MAX_MEMORY, "1GB")
- .put(DISTRIBUTED_JOIN, "true")
+ .put(JOIN_DISTRIBUTION_TYPE, "repartitioned")
.put(HASH_PARTITION_COUNT, "43")
.build());
assertEquals(session.getPreparedStatements(), ImmutableMap.builder()
diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java
index 3fad76a15ea32..a2ddec819a837 100644
--- a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java
+++ b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java
@@ -35,8 +35,8 @@
import java.net.URI;
import java.util.List;
-import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN;
import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
+import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO;
@@ -134,7 +134,7 @@ public void testQuery()
.setHeader(PRESTO_SCHEMA, "schema")
.setHeader(PRESTO_CLIENT_INFO, "{\"clientVersion\":\"testVersion\"}")
.addHeader(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB")
- .addHeader(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43")
+ .addHeader(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43")
.addHeader(PRESTO_PREPARED_STATEMENT, "foo=select * from bar")
.build();
@@ -146,7 +146,7 @@ public void testQuery()
// verify session properties
assertEquals(queryInfo.getSession().getSystemProperties(), ImmutableMap.builder()
.put(QUERY_MAX_MEMORY, "1GB")
- .put(DISTRIBUTED_JOIN, "true")
+ .put(JOIN_DISTRIBUTION_TYPE, "repartitioned")
.put(HASH_PARTITION_COUNT, "43")
.build());
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index 780c48507c263..bb4db2ba43b46 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -21,6 +21,10 @@
import java.util.Map;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPLICATED;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
+import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE;
import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI;
import static com.facebook.presto.sql.analyzer.RegexLibrary.RE2J;
import static io.airlift.configuration.testing.ConfigAssertions.assertDeprecatedEquivalence;
@@ -40,10 +44,10 @@ public void testDefaults()
.setNetworkCostWeight(0.25)
.setResourceGroupsEnabled(false)
.setDistributedIndexJoinsEnabled(false)
- .setDistributedJoinsEnabled(true)
+ .setJoinDistributionType(REPARTITIONED)
.setFastInequalityJoins(true)
.setColocatedJoinsEnabled(false)
- .setJoinReorderingEnabled(true)
+ .setJoinReorderingStrategy(ELIMINATE_CROSS_JOINS)
.setRedistributeWrites(true)
.setOptimizeMetadataQueries(false)
.setOptimizeHashGeneration(true)
@@ -86,10 +90,10 @@ public void testExplicitPropertyMappings()
.put("deprecated.legacy-map-subscript", "true")
.put("deprecated.new-map-block", "false")
.put("distributed-index-joins-enabled", "true")
- .put("distributed-joins-enabled", "false")
+ .put("join-distribution-type", "REPLICATED")
.put("fast-inequality-joins", "false")
.put("colocated-joins-enabled", "true")
- .put("reorder-joins", "false")
+ .put("optimizer.join-reordering-strategy", "NONE")
.put("redistribute-writes", "false")
.put("optimizer.optimize-metadata-queries", "true")
.put("optimizer.optimize-hash-generation", "false")
@@ -122,10 +126,10 @@ public void testExplicitPropertyMappings()
.put("deprecated.legacy-map-subscript", "true")
.put("deprecated.new-map-block", "false")
.put("distributed-index-joins-enabled", "true")
- .put("distributed-joins-enabled", "false")
+ .put("join-distribution-type", "REPLICATED")
.put("fast-inequality-joins", "false")
.put("colocated-joins-enabled", "true")
- .put("reorder-joins", "false")
+ .put("optimizer.join-reordering-strategy", "NONE")
.put("redistribute-writes", "false")
.put("optimizer.optimize-metadata-queries", "true")
.put("optimizer.optimize-hash-generation", "false")
@@ -155,10 +159,10 @@ public void testExplicitPropertyMappings()
.setIterativeOptimizerEnabled(false)
.setIterativeOptimizerTimeout(new Duration(10, SECONDS))
.setDistributedIndexJoinsEnabled(true)
- .setDistributedJoinsEnabled(false)
+ .setJoinDistributionType(REPLICATED)
.setFastInequalityJoins(false)
.setColocatedJoinsEnabled(true)
- .setJoinReorderingEnabled(false)
+ .setJoinReorderingStrategy(NONE)
.setRedistributeWrites(false)
.setOptimizeMetadataQueries(true)
.setOptimizeHashGeneration(false)
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java
index a78c856247075..6d002c104dd7d 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java
@@ -37,12 +37,14 @@ final class JoinMatcher
private final JoinNode.Type joinType;
private final List> equiCriteria;
private final Optional filter;
+ private final Optional distributionType;
- JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter)
+ JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter, Optional distributionType)
{
this.joinType = requireNonNull(joinType, "joinType is null");
this.equiCriteria = requireNonNull(equiCriteria, "equiCriteria is null");
this.filter = requireNonNull(filter, "filter can not be null");
+ this.distributionType = requireNonNull(distributionType, "distribtuionType cannot be null");
}
@Override
@@ -81,6 +83,10 @@ public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Ses
}
}
+ if (distributionType.isPresent() && !distributionType.equals(joinNode.getDistributionType())) {
+ return NO_MATCH;
+ }
+
/*
* Have to use order-independent comparison; there are no guarantees what order
* the equi criteria will have after planning and optimizing.
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java
index 9e39e36ec9506..ce7c03dabffef 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java
@@ -266,12 +266,18 @@ public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, PlanMatchPattern left, PlanMatchPattern right)
+ {
+ return join(joinType, expectedEquiCriteria, expectedFilter, Optional.empty(), left, right);
+ }
+
+ public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, Optional distributionType, PlanMatchPattern left, PlanMatchPattern right)
{
return node(JoinNode.class, left, right).with(
new JoinMatcher(
joinType,
expectedEquiCriteria,
- expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate)))));
+ expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))),
+ distributionType));
}
public static PlanMatchPattern exchange(PlanMatchPattern... sources)
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java
new file mode 100644
index 0000000000000..d6fb0ea483d0a
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tpch.TpchConnectorFactory;
+import com.google.common.collect.ImmutableMap;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.openjdk.jmh.runner.options.VerboseMode;
+
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.String.format;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.openjdk.jmh.annotations.Mode.AverageTime;
+import static org.openjdk.jmh.annotations.Scope.Thread;
+
+@State(Thread)
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Fork(3)
+@Warmup(iterations = 10)
+@Measurement(iterations = 10)
+public class BenchmarkReorderJoinsConnectedGraph
+{
+ @Benchmark
+ public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo)
+ {
+ return benchmarkInfo.getQueryRunner().execute(benchmarkInfo.getQuery());
+ }
+
+ @State(Thread)
+ public static class BenchmarkInfo
+ {
+ @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"})
+ private String joinReorderingStrategy;
+
+ @Param({"2", "4", "6", "8", "10"})
+ private int numberOfTables;
+
+ private String query;
+ private LocalQueryRunner queryRunner;
+
+ @Setup
+ public void setup()
+ {
+ checkState(numberOfTables >= 2, "numberOfTables must be >= 2");
+ Session session = testSessionBuilder()
+ .setSystemProperty("join_reordering_strategy", joinReorderingStrategy)
+ .setSystemProperty("join_distribution_type", "AUTOMATIC")
+ .setCatalog("tpch")
+ .setSchema("tiny")
+ .build();
+ queryRunner = new LocalQueryRunner(session);
+ queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("EXPLAIN SELECT * FROM NATION n1");
+ for (int i = 2; i <= numberOfTables; i++) {
+ stringBuilder.append(format(" JOIN nation n%s on n%s.nationkey = n%s.nationkey", i, i - 1, i));
+ }
+ query = stringBuilder.toString();
+ }
+
+ public String getQuery()
+ {
+ return query;
+ }
+
+ public QueryRunner getQueryRunner()
+ {
+ return queryRunner;
+ }
+
+ @TearDown
+ public void tearDown()
+ {
+ queryRunner.close();
+ }
+ }
+
+ public static void main(String[] args)
+ throws RunnerException
+ {
+ Options options = new OptionsBuilder()
+ .verbosity(VerboseMode.NORMAL)
+ .include(".*" + BenchmarkReorderJoinsConnectedGraph.class.getSimpleName() + ".*")
+ .build();
+
+ new Runner(options).run();
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java
new file mode 100644
index 0000000000000..13c1644c17855
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tpch.TpchConnectorFactory;
+import com.google.common.collect.ImmutableMap;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.openjdk.jmh.runner.options.VerboseMode;
+
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.openjdk.jmh.annotations.Mode.AverageTime;
+import static org.openjdk.jmh.annotations.Scope.Thread;
+
+@State(Thread)
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Fork(3)
+@Warmup(iterations = 10)
+@Measurement(iterations = 10)
+public class BenchmarkReorderJoinsLinearGraph
+{
+ @Benchmark
+ public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo)
+ {
+ return benchmarkInfo.getQueryRunner().execute(
+ "EXPLAIN SELECT * FROM " +
+ "nation n1 JOIN nation n2 ON n1.nationkey = n2.nationkey " +
+ "JOIN nation n3 on n2.comment = n3.comment " +
+ "JOIN nation n4 on n3.name = n4.name " +
+ "JOIN region r1 on n4.regionkey = r1.regionkey " +
+ "JOIN region r2 on r2.name = r2.name " +
+ "JOIN region r3 on r3.comment = r2.comment " +
+ "join region r4 on r4.regionkey = r3.regionkey");
+ }
+
+ @State(Thread)
+ public static class BenchmarkInfo
+ {
+ @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"})
+ private String joinReorderingStrategy;
+
+ private LocalQueryRunner queryRunner;
+
+ @Setup
+ public void setup()
+ {
+ Session session = testSessionBuilder()
+ .setSystemProperty("join_reordering_strategy", joinReorderingStrategy)
+ .setSystemProperty("join_distribution_type", "AUTOMATIC")
+ .setCatalog("tpch")
+ .setSchema("tiny")
+ .build();
+ queryRunner = new LocalQueryRunner(session);
+ queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
+ }
+
+ public QueryRunner getQueryRunner()
+ {
+ return queryRunner;
+ }
+
+ @TearDown
+ public void tearDown()
+ {
+ queryRunner.close();
+ }
+ }
+
+ public static void main(String[] args)
+ throws RunnerException
+ {
+ Options options = new OptionsBuilder()
+ .verbosity(VerboseMode.NORMAL)
+ .include(".*" + BenchmarkReorderJoinsLinearGraph.class.getSimpleName() + ".*")
+ .build();
+
+ new Runner(options).run();
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java
index e46826bfb567e..3609523de8c5b 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java
@@ -35,7 +35,7 @@
import java.util.Optional;
import java.util.function.Function;
-import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS;
+import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
@@ -60,7 +60,7 @@ public class TestEliminateCrossJoins
public void testEliminateCrossJoin()
{
tester().assertThat(new EliminateCrossJoins())
- .setSystemProperty(REORDER_JOINS, "true")
+ .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS")
.on(crossJoinAndJoin(INNER))
.matches(
join(INNER,
@@ -79,7 +79,7 @@ public void testEliminateCrossJoin()
public void testRetainOutgoingGroupReferences()
{
tester().assertThat(new EliminateCrossJoins())
- .setSystemProperty(REORDER_JOINS, "true")
+ .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS")
.on(crossJoinAndJoin(INNER))
.matches(
node(JoinNode.class,
@@ -96,7 +96,7 @@ public void testRetainOutgoingGroupReferences()
public void testDoNotReorderOuterJoin()
{
tester().assertThat(new EliminateCrossJoins())
- .setSystemProperty(REORDER_JOINS, "true")
+ .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS")
.on(crossJoinAndJoin(JoinNode.Type.LEFT))
.doesNotFire();
}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java
new file mode 100644
index 0000000000000..c2449777969c1
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.cost.CostComparator;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult;
+import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.Set;
+
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions;
+import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static io.airlift.testing.Closeables.closeAllRuntimeException;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+
+public class TestJoinEnumerator
+{
+ private LocalQueryRunner queryRunner;
+
+ @BeforeClass
+ public void setUp()
+ {
+ queryRunner = new LocalQueryRunner(testSessionBuilder().build());
+ }
+
+ @AfterClass(alwaysRun = true)
+ public void tearDown()
+ {
+ closeAllRuntimeException(queryRunner);
+ queryRunner = null;
+ }
+
+ @Test
+ public void testGeneratePartitions()
+ {
+ Set> partitions = generatePartitions(4).collect(toImmutableSet());
+ assertEquals(partitions,
+ ImmutableSet.of(
+ ImmutableSet.of(0),
+ ImmutableSet.of(0, 1),
+ ImmutableSet.of(0, 2),
+ ImmutableSet.of(0, 3),
+ ImmutableSet.of(0, 1, 2),
+ ImmutableSet.of(0, 1, 3),
+ ImmutableSet.of(0, 2, 3)));
+
+ partitions = generatePartitions(3).collect(toImmutableSet());
+ assertEquals(partitions,
+ ImmutableSet.of(
+ ImmutableSet.of(0),
+ ImmutableSet.of(0, 1),
+ ImmutableSet.of(0, 2)));
+ }
+
+ @Test
+ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin()
+ {
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata());
+ Symbol a1 = planBuilder.symbol("A1", BIGINT);
+ Symbol b1 = planBuilder.symbol("B1", BIGINT);
+ MultiJoinNode multiJoinNode = new MultiJoinNode(
+ ImmutableList.of(planBuilder.values(a1), planBuilder.values(b1)),
+ TRUE_LITERAL,
+ ImmutableList.of(a1, b1));
+ ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator(
+ idAllocator,
+ new SymbolAllocator(),
+ queryRunner.getDefaultSession(),
+ queryRunner.getLookup(),
+ multiJoinNode.getFilter(),
+ new CostComparator(1, 1, 1));
+ JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols(), ImmutableSet.of(0));
+ assertFalse(actual.getPlanNode().isPresent());
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java
index 7bf5a19317a4d..0aed522491b14 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java
@@ -24,7 +24,6 @@
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Memo;
-import com.facebook.presto.sql.planner.iterative.MemoBasedLookup;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.testing.LocalQueryRunner;
@@ -60,7 +59,7 @@ public void testResolvesGroupReferenceNode()
PlanNode plan = node(source);
Memo memo = new Memo(idAllocator, plan);
- MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1));
+ Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1));
PlanNode memoSource = Iterables.getOnlyElement(memo.getNode(memo.getRootGroup()).getSources());
checkState(memoSource instanceof GroupReference, "expected GroupReference");
assertEquals(lookup.resolve(memoSource), source);
@@ -71,7 +70,7 @@ public void testComputesStatsAndResolvesNodes()
{
PlanNode plan = node(node(node()));
Memo memo = new Memo(idAllocator, plan);
- MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1));
+ Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1));
PlanNodeStatsEstimate actualStats = lookup.getStats(memo.getNode(memo.getRootGroup()), queryRunner.getDefaultSession(), ImmutableMap.of());
PlanNodeStatsEstimate expectedStats = PlanNodeStatsEstimate.builder().setOutputRowCount(3).build();
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java
new file mode 100644
index 0000000000000..1f9d8d402de64
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.ValuesNode;
+import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.LongLiteral;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.testng.annotations.Test;
+
+import java.util.Optional;
+
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.facebook.presto.sql.ExpressionUtils.and;
+import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts;
+import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT;
+import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Type.ADD;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.NOT_EQUAL;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static org.testng.Assert.assertEquals;
+
+public class TestMultiJoinNodeBuilder
+{
+ private final LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder().build());
+
+ @Test(expectedExceptions = IllegalStateException.class)
+ public void testDoesNotFireForOuterJoins()
+ {
+ PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), queryRunner.getMetadata());
+ JoinNode outerJoin = p.join(
+ JoinNode.Type.FULL,
+ p.values(p.symbol("A1", BIGINT)),
+ p.values(p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty());
+ toMultiJoinNode(outerJoin, queryRunner.getLookup());
+ }
+
+ @Test
+ public void testDoesNotConvertNestedOuterJoins()
+ {
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata());
+ Symbol a1 = planBuilder.symbol("A1", BIGINT);
+ Symbol b1 = planBuilder.symbol("B1", BIGINT);
+ Symbol c1 = planBuilder.symbol("C1", BIGINT);
+ JoinNode leftJoin = planBuilder.join(
+ LEFT,
+ planBuilder.values(a1),
+ planBuilder.values(b1),
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)),
+ ImmutableList.of(a1, b1),
+ Optional.empty());
+ ValuesNode valuesC = planBuilder.values(c1);
+ JoinNode joinNode = planBuilder.join(
+ INNER,
+ leftJoin,
+ valuesC,
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)),
+ ImmutableList.of(a1, b1, c1),
+ Optional.empty());
+
+ MultiJoinNode expected = new MultiJoinNode(ImmutableList.of(leftJoin, valuesC), new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()), ImmutableList.of(a1, b1, c1));
+ assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected);
+ }
+
+ @Test
+ public void testRetainsOutputSymbols()
+ {
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata());
+ Symbol a1 = planBuilder.symbol("A1", BIGINT);
+ Symbol b1 = planBuilder.symbol("B1", BIGINT);
+ Symbol b2 = planBuilder.symbol("B2", BIGINT);
+ Symbol c1 = planBuilder.symbol("C1", BIGINT);
+ Symbol c2 = planBuilder.symbol("C2", BIGINT);
+ ValuesNode valuesA = planBuilder.values(a1);
+ ValuesNode valuesB = planBuilder.values(b1, b2);
+ ValuesNode valuesC = planBuilder.values(c1, c2);
+ JoinNode joinNode = planBuilder.join(
+ INNER,
+ valuesA,
+ planBuilder.join(
+ INNER,
+ valuesB,
+ valuesC,
+ ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)),
+ ImmutableList.of(
+ b1,
+ b2,
+ c1,
+ c2),
+ Optional.empty()),
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)),
+ ImmutableList.of(a1, b1),
+ Optional.empty());
+ MultiJoinNode expected = new MultiJoinNode(
+ ImmutableList.of(valuesA, valuesB, valuesC),
+ and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference())),
+ ImmutableList.of(a1, b1));
+ assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected);
+ }
+
+ @Test
+ public void testCombinesCriteriaAndFilters()
+ {
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata());
+ Symbol a1 = planBuilder.symbol("A1", BIGINT);
+ Symbol b1 = planBuilder.symbol("B1", BIGINT);
+ Symbol b2 = planBuilder.symbol("B2", BIGINT);
+ Symbol c1 = planBuilder.symbol("C1", BIGINT);
+ Symbol c2 = planBuilder.symbol("C2", BIGINT);
+ ValuesNode valuesA = planBuilder.values(a1);
+ ValuesNode valuesB = planBuilder.values(b1, b2);
+ ValuesNode valuesC = planBuilder.values(c1, c2);
+ Expression bcFilter = and(
+ new ComparisonExpression(GREATER_THAN, c2.toSymbolReference(), new LongLiteral("0")),
+ new ComparisonExpression(NOT_EQUAL, c2.toSymbolReference(), new LongLiteral("7")),
+ new ComparisonExpression(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference()));
+ ComparisonExpression abcFilter = new ComparisonExpression(
+ LESS_THAN,
+ new ArithmeticBinaryExpression(ADD, a1.toSymbolReference(), c1.toSymbolReference()),
+ b1.toSymbolReference());
+ JoinNode joinNode = planBuilder.join(
+ INNER,
+ valuesA,
+ planBuilder.join(
+ INNER,
+ valuesB,
+ valuesC,
+ ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)),
+ ImmutableList.of(
+ b1,
+ b2,
+ c1,
+ c2),
+ Optional.of(bcFilter)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)),
+ ImmutableList.of(a1, b1, b2, c1, c2),
+ Optional.of(abcFilter));
+ MultiJoinNode expected = new MultiJoinNode(
+ ImmutableList.of(valuesA, valuesB, valuesC),
+ and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter),
+ ImmutableList.of(a1, b1, b2, c1, c2));
+ assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected);
+ }
+
+ @Test
+ public void testConvertsBushyTrees()
+ {
+ PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
+ PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata());
+ Symbol a1 = planBuilder.symbol("A1", BIGINT);
+ Symbol b1 = planBuilder.symbol("B1", BIGINT);
+ Symbol c1 = planBuilder.symbol("C1", BIGINT);
+ Symbol d1 = planBuilder.symbol("D1", BIGINT);
+ Symbol d2 = planBuilder.symbol("D2", BIGINT);
+ Symbol e1 = planBuilder.symbol("E1", BIGINT);
+ Symbol e2 = planBuilder.symbol("E2", BIGINT);
+ ValuesNode valuesA = planBuilder.values(a1);
+ ValuesNode valuesB = planBuilder.values(b1);
+ ValuesNode valuesC = planBuilder.values(c1);
+ ValuesNode valuesD = planBuilder.values(d1, d2);
+ ValuesNode valuesE = planBuilder.values(e1, e2);
+ JoinNode joinNode = planBuilder.join(
+ INNER,
+ planBuilder.join(
+ INNER,
+ planBuilder.join(
+ INNER,
+ valuesA,
+ valuesB,
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)),
+ ImmutableList.of(a1, b1),
+ Optional.empty()),
+ valuesC,
+ ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)),
+ ImmutableList.of(a1, b1, c1),
+ Optional.empty()),
+ planBuilder.join(
+ INNER,
+ valuesD,
+ valuesE,
+ ImmutableList.of(
+ new JoinNode.EquiJoinClause(d1, e1),
+ new JoinNode.EquiJoinClause(d2, e2)),
+ ImmutableList.of(
+ d1,
+ d2,
+ e1,
+ e2),
+ Optional.empty()),
+ ImmutableList.of(new JoinNode.EquiJoinClause(b1, e1)),
+ ImmutableList.of(
+ a1,
+ b1,
+ c1,
+ d1,
+ d2,
+ e1,
+ e2),
+ Optional.empty());
+ MultiJoinNode expected = new MultiJoinNode(
+ ImmutableList.of(valuesA, valuesB, valuesC, valuesD, valuesE),
+ and(
+ new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()),
+ new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()),
+ new ComparisonExpression(EQUAL, d1.toSymbolReference(), e1.toSymbolReference()),
+ new ComparisonExpression(EQUAL, d2.toSymbolReference(), e2.toSymbolReference()),
+ new ComparisonExpression(EQUAL, b1.toSymbolReference(), e1.toSymbolReference())),
+ ImmutableList.of(a1, b1, c1, d1, d2, e1, e2));
+ assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected);
+ }
+
+ private static void assertMultijoinEquals(MultiJoinNode actual, MultiJoinNode expected)
+ {
+ assertEquals(ImmutableSet.copyOf(actual.getSources()), ImmutableSet.copyOf(expected.getSources()));
+ assertEquals(ImmutableSet.copyOf(extractConjuncts(actual.getFilter())), ImmutableSet.copyOf(extractConjuncts(expected.getFilter())));
+ assertEquals(actual.getOutputSymbols(), expected.getOutputSymbols());
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java
new file mode 100644
index 0000000000000..c64c901c3184b
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java
@@ -0,0 +1,382 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.cost.CostCalculator;
+import com.facebook.presto.cost.CostComparator;
+import com.facebook.presto.cost.PlanNodeStatsEstimate;
+import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.cost.SymbolStatsEstimate;
+import com.facebook.presto.spi.TestingColumnHandle;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanNodeId;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.ComparisonExpressionType;
+import com.facebook.presto.sql.tree.FunctionCall;
+import com.facebook.presto.sql.tree.QualifiedName;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.testing.TestingStatsCalculator;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS;
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
+import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
+import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED;
+import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
+import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithFakeNodeCountForStats;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static io.airlift.testing.Closeables.closeAllRuntimeException;
+
+public class TestReorderJoins
+{
+ private RuleTester tester;
+ private StatsCalculator statsCalculator;
+ private CostCalculator costCalculator;
+
+ @BeforeClass
+ public void setUp()
+ {
+ Session session = testSessionBuilder()
+ .setCatalog("local")
+ .setSchema("tiny")
+ .setSystemProperty("join_distribution_type", "automatic")
+ .setSystemProperty("join_reordering_strategy", "COST_BASED")
+ .build();
+ LocalQueryRunner queryRunner = queryRunnerWithFakeNodeCountForStats(session, 4);
+ statsCalculator = queryRunner.getStatsCalculator();
+ costCalculator = queryRunner.getEstimatedExchangesCostCalculator();
+ tester = new RuleTester(queryRunner);
+ }
+
+ @AfterClass(alwaysRun = true)
+ public void tearDown()
+ {
+ closeAllRuntimeException(tester);
+ tester = null;
+ }
+
+ @Test
+ public void testKeepsOutputSymbols()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(5000)
+ .addSymbolStatistics(ImmutableMap.of(
+ new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100),
+ new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT), p.symbol("A2", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A2", BIGINT)),
+ Optional.empty()))
+ .matches(join(
+ INNER,
+ ImmutableList.of(equiJoinClause("A1", "B1")),
+ Optional.empty(),
+ Optional.of(PARTITIONED),
+ values(ImmutableMap.of("A1", 0, "A2", 1)),
+ values(ImmutableMap.of("B1", 0))
+ ));
+ }
+
+ @Test
+ public void testReplicatesAndFlipsWhenOneTableMuchSmaller()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(100)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty()))
+ .matches(join(
+ INNER,
+ ImmutableList.of(equiJoinClause("B1", "A1")),
+ Optional.empty(),
+ Optional.of(REPLICATED),
+ values(ImmutableMap.of("B1", 0)),
+ values(ImmutableMap.of("A1", 0))
+ ));
+ }
+
+ @Test
+ public void testRepartitionsWhenRequiredBySession()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(100)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty()))
+ .setSystemProperty("join_distribution_type", "REPARTITIONED")
+ .matches(join(
+ INNER,
+ ImmutableList.of(equiJoinClause("B1", "A1")),
+ Optional.empty(),
+ Optional.of(PARTITIONED),
+ values(ImmutableMap.of("B1", 0)),
+ values(ImmutableMap.of("A1", 0))
+ ));
+ }
+
+ @Test
+ public void testRepartitionsWhenBothTablesEqual()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty()))
+ .matches(join(
+ INNER,
+ ImmutableList.of(equiJoinClause("A1", "B1")),
+ Optional.empty(),
+ Optional.of(PARTITIONED),
+ values(ImmutableMap.of("A1", 0)),
+ values(ImmutableMap.of("B1", 0))
+ ));
+ }
+
+ @Test
+ public void testReplicatesWhenRequiredBySession()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty()))
+ .setSystemProperty("join_distribution_type", "REPLICATED")
+ .matches(join(
+ INNER,
+ ImmutableList.of(equiJoinClause("A1", "B1")),
+ Optional.empty(),
+ Optional.of(REPLICATED),
+ values(ImmutableMap.of("A1", 0)),
+ values(ImmutableMap.of("B1", 0))
+ ));
+ }
+
+ @Test
+ public void testDoesNotFireForCrossJoin()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build(),
+ new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.empty()))
+ .doesNotFire();
+ }
+
+ @Test
+ public void testDoesNotFireWithNoStats()
+ {
+ StatsCalculator testingStatsCalculator = new UnknownStatsCalculator();
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.tableScan(ImmutableList.of(p.symbol("A1", BIGINT)), ImmutableMap.of(p.symbol("A1", BIGINT), new TestingColumnHandle("A1"))),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT)),
+ Optional.empty()))
+ .doesNotFire();
+ }
+
+ @Test
+ public void testDoesNotFireForNonDeterministicFilter()
+ {
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), statsCalculator, costCalculator))
+ .on(p ->
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)),
+ ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)),
+ Optional.of(new ComparisonExpression(ComparisonExpressionType.LESS_THAN, p.symbol("A1", BIGINT).toSymbolReference(), new FunctionCall(QualifiedName.of("random"), ImmutableList.of())))))
+ .doesNotFire();
+ }
+
+ @Test
+ public void testPredicatesPushedDown()
+ {
+ StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of(
+ new PlanNodeId("valuesA"),
+ PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(10)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10)))
+ .build(),
+ new PlanNodeId("valuesB"),
+ PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(5)
+ .addSymbolStatistics(ImmutableMap.of(
+ new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10),
+ new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10)))
+ .build(),
+ new PlanNodeId("valuesC"),
+ PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(1000)
+ .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
+ .build()));
+
+ tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator))
+ .withStatsCalculator(testingStatsCalculator)
+ .on(p ->
+ p.join(
+ INNER,
+ p.join(
+ INNER,
+ p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)),
+ p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)),
+ ImmutableList.of(),
+ ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)),
+ Optional.empty()),
+ p.values(new PlanNodeId("valuesC"), p.symbol("C1", BIGINT)),
+ ImmutableList.of(
+ new JoinNode.EquiJoinClause(p.symbol("B2", BIGINT), p.symbol("C1", BIGINT))),
+ ImmutableList.of(p.symbol("A1", BIGINT)),
+ Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1", BIGINT).toSymbolReference(), p.symbol("B1", BIGINT).toSymbolReference()))))
+ .matches(
+ join(
+ INNER,
+ ImmutableList.of(equiJoinClause("C1", "B2")),
+ values("C1"),
+ join(
+ INNER,
+ ImmutableList.of(equiJoinClause("A1", "B1")),
+ values("A1"),
+ values("B1", "B2"))
+ )
+ );
+ }
+
+ private static class UnknownStatsCalculator
+ implements StatsCalculator
+ {
+ @Override
+ public PlanNodeStatsEstimate calculateStats(
+ PlanNode planNode,
+ Lookup lookup,
+ Session session,
+ Map types)
+ {
+ PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.buildFrom(UNKNOWN_STATS);
+ planNode.getOutputSymbols()
+ .forEach(symbol -> statsBuilder.addSymbolStatistics(symbol, SymbolStatsEstimate.UNKNOWN_STATS));
+ return statsBuilder.build();
+ }
+ }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
index 64ed14842aabc..e4dafc7e55bfa 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
@@ -42,6 +42,7 @@
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
@@ -90,9 +91,14 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata)
}
public ValuesNode values(Symbol... columns)
+ {
+ return values(idAllocator.getNextId(), columns);
+ }
+
+ public ValuesNode values(PlanNodeId id, Symbol... columns)
{
return new ValuesNode(
- idAllocator.getNextId(),
+ id,
ImmutableList.copyOf(columns),
ImmutableList.of());
}
@@ -320,6 +326,11 @@ public ExchangeNode exchange(Consumer exchangeBuilderConsumer)
return exchangeBuilder.build();
}
+ public JoinNode join(JoinNode.Type type, PlanNode left, PlanNode right, List criteria, List outputSymbols, Optional filter)
+ {
+ return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter, Optional.empty(), Optional.empty(), Optional.empty());
+ }
+
public class ExchangeBuilder
{
private ExchangeNode.Type type = ExchangeNode.Type.GATHER;
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
index 2b79a9d079926..f125cc50c1499 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
@@ -39,25 +39,35 @@
import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.fail;
public class RuleAssert
{
private final Metadata metadata;
+ private StatsCalculator statsCalculator;
+ private final CostCalculator costCalculator;
private Session session;
private final Rule rule;
private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
private Map symbols;
+ private Lookup lookup;
private PlanNode plan;
private final TransactionManager transactionManager;
private final AccessControl accessControl;
- private final StatsCalculator statsCalculator;
- private final CostCalculator costCalculator;
-
- public RuleAssert(Metadata metadata, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl, StatsCalculator statsCalculator, CostCalculator costCalculator)
+ private Memo memo;
+
+ public RuleAssert(
+ Metadata metadata,
+ Session session,
+ Rule rule,
+ TransactionManager transactionManager,
+ AccessControl accessControl,
+ StatsCalculator statsCalculator,
+ CostCalculator costCalculator)
{
this.metadata = metadata;
this.session = session;
@@ -81,6 +91,13 @@ public RuleAssert withSession(Session session)
return this;
}
+ public RuleAssert withStatsCalculator(StatsCalculator statsCalculator)
+ {
+ checkState(lookup == null, "lookup has been set");
+ this.statsCalculator = statsCalculator;
+ return this;
+ }
+
public RuleAssert on(Function planProvider)
{
checkArgument(plan == null, "plan has already been set");
@@ -88,6 +105,8 @@ public RuleAssert on(Function planProvider)
PlanBuilder builder = new PlanBuilder(idAllocator, metadata);
plan = planProvider.apply(builder);
symbols = builder.getSymbols();
+ memo = new Memo(idAllocator, plan);
+ lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator);
return this;
}
@@ -143,8 +162,6 @@ public void matches(PlanMatchPattern pattern)
private RuleApplication applyRule()
{
SymbolAllocator symbolAllocator = new SymbolAllocator(symbols);
- Memo memo = new Memo(idAllocator, plan);
- Lookup lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator);
if (!rule.getPattern().matches(plan)) {
return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty());
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java
index 40a6321583d92..f1589980d0a02 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java
@@ -59,7 +59,7 @@ public class TestReorderJoins
public TestReorderJoins()
{
- super(ImmutableMap.of(SystemSessionProperties.REORDER_JOINS, "true"));
+ super(ImmutableMap.of(SystemSessionProperties.JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS"));
}
@Test
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java
index 0235bba4f0cba..f29253a8c2475 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java
@@ -21,6 +21,6 @@ public class TestUnionWithReplicatedJoin
{
public TestUnionWithReplicatedJoin()
{
- super(ImmutableMap.of(SystemSessionProperties.DISTRIBUTED_JOIN, "false"));
+ super(ImmutableMap.of(SystemSessionProperties.JOIN_DISTRIBUTION_TYPE, "replicated"));
}
}
diff --git a/presto-product-tests/conf/presto/etc/config.properties b/presto-product-tests/conf/presto/etc/config.properties
index 44aef5efc782e..a7e918b66a97f 100644
--- a/presto-product-tests/conf/presto/etc/config.properties
+++ b/presto-product-tests/conf/presto/etc/config.properties
@@ -38,8 +38,6 @@ plugin.bundles=\
../../../presto-sqlserver/pom.xml
presto.version=testversion
-distributed-joins-enabled=true
query.max-memory-per-node=1GB
query.max-memory=1GB
redistribute-writes=false
-reorder-joins=true
diff --git a/presto-product-tests/conf/presto/etc/multinode-master.properties b/presto-product-tests/conf/presto/etc/multinode-master.properties
index 5be52ed2cb2d4..98d6602e08447 100644
--- a/presto-product-tests/conf/presto/etc/multinode-master.properties
+++ b/presto-product-tests/conf/presto/etc/multinode-master.properties
@@ -15,4 +15,3 @@ query.max-memory=1GB
query.max-memory-per-node=512MB
discovery-server.enabled=true
discovery.uri=http://presto-master:8080
-reorder-joins=true
diff --git a/presto-product-tests/conf/presto/etc/singlenode.properties b/presto-product-tests/conf/presto/etc/singlenode.properties
index d0e2cda196b10..a2e66146a4eb8 100644
--- a/presto-product-tests/conf/presto/etc/singlenode.properties
+++ b/presto-product-tests/conf/presto/etc/singlenode.properties
@@ -15,4 +15,3 @@ query.max-memory=2GB
query.max-memory-per-node=1GB
discovery-server.enabled=true
discovery.uri=http://presto-master:8080
-reorder-joins=true
diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java
index 5be87ee48944b..025394cd81b56 100644
--- a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java
+++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java
@@ -14,6 +14,7 @@
package com.facebook.presto.tests.jdbc;
import com.facebook.presto.jdbc.PrestoConnection;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.teradata.tempto.BeforeTestWithContext;
import com.teradata.tempto.ProductTest;
import com.teradata.tempto.Requirement;
@@ -52,8 +53,6 @@
import static com.teradata.tempto.internal.convention.SqlResultDescriptor.sqlResultDescriptorForResource;
import static com.teradata.tempto.query.QueryExecutor.defaultQueryExecutor;
import static com.teradata.tempto.query.QueryExecutor.query;
-import static java.lang.Boolean.FALSE;
-import static java.lang.Boolean.TRUE;
import static java.util.Locale.CHINESE;
import static org.assertj.core.api.Assertions.assertThat;
@@ -247,13 +246,14 @@ public void testSqlEscapeFunctions()
public void testSessionProperties()
throws SQLException
{
- final String distributedJoin = "distributed_join";
+ final String joinDistributionType = "join_distribution_type";
+ final String defaultValue = new FeaturesConfig().getJoinDistributionType().name();
- assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString());
- setSessionProperty(connection, distributedJoin, FALSE.toString());
- assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(FALSE.toString());
- resetSessionProperty(connection, distributedJoin);
- assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString());
+ assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue);
+ setSessionProperty(connection, joinDistributionType, "REPLICATED");
+ assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo("REPLICATED");
+ resetSessionProperty(connection, joinDistributionType);
+ assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue);
}
private QueryResult queryResult(Statement statement, String query)
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java
index 6f683b3ef7578..7a767c2264b21 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java
@@ -171,11 +171,10 @@ private static TestingPrestoServer createTestingPrestoServer(URI discoveryUri, b
.put("compiler.interpreter-enabled", "false")
.put("task.max-index-memory", "16kB") // causes index joins to fault load
.put("datasources", "system")
- .put("distributed-index-joins-enabled", "true")
.put("optimizer.optimize-mixed-distinct-aggregations", "true");
if (coordinator) {
propertiesBuilder.put("node-scheduler.include-coordinator", "true");
- propertiesBuilder.put("distributed-joins-enabled", "true");
+ propertiesBuilder.put("join-distribution-type", "REPARTITIONED");
}
HashMap properties = new HashMap<>(propertiesBuilder.build());
properties.putAll(extraProperties);