calculatedStats = rule.calculate(node, lookup, session, types);
+ if (calculatedStats.isPresent()) {
+ return normalize(node, calculatedStats.get());
+ }
+ }
+ return PlanNodeStatsEstimate.UNKNOWN_STATS;
+ }
+
+        private PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate)
+        {
+            // Feed the estimate through each normalizer in the iteration order of
+            // {@code normalizers}; every step receives the result of the previous one.
+            PlanNodeStatsEstimate normalized = estimate;
+            for (Normalizer step : normalizers) {
+                normalized = step.normalize(node, normalized, types);
+            }
+            return normalized;
+        }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
index f5d1f87743a21..f13394a78e1df 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
@@ -11,28 +11,58 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.facebook.presto.cost;
import com.facebook.presto.Session;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.PlanNode;
-import com.facebook.presto.sql.planner.plan.PlanNodeId;
+import com.google.inject.BindingAnnotation;
+import java.lang.annotation.Retention;
+import java.lang.annotation.Target;
import java.util.Map;
+import static java.lang.annotation.ElementType.PARAMETER;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+
/**
* Interface of cost calculator.
- *
- * It's responsibility is to provide approximation of cost of execution of plan node.
- * Example implementations may be based on table statistics or data samples.
+ *
+ * Computes estimated cost of executing given PlanNode.
+ * Implementation may use lookup to compute needed traits for self/source nodes.
*/
public interface CostCalculator
{
- Map calculateCostForPlan(Session session, Map types, PlanNode planNode);
+ /**
+ * Estimates the cost of executing {@code planNode}.
+ * NOTE(review): presumably this excludes the cost of the node's sources, because
+ * calculateCumulativeCost adds the sources' cost separately - confirm with implementations.
+ */
+ PlanNodeCostEstimate calculateCost(
+ PlanNode planNode,
+ Lookup lookup,
+ Session session,
+ Map types);
- default PlanNodeCost calculateCostForNode(Session session, Map types, PlanNode planNode)
+ /**
+ * Estimates the cost of {@code planNode} together with all of its sources.
+ *
+ * The node's own cost comes from calculateCost; the cumulative cost of every source is
+ * obtained through {@code lookup} and summed starting from ZERO_COST.
+ */
+ default PlanNodeCostEstimate calculateCumulativeCost(
+ PlanNode planNode,
+ Lookup lookup,
+ Session session,
+ Map types)
{
- return calculateCostForPlan(session, types, planNode).get(planNode.getId());
+ PlanNodeCostEstimate cost = calculateCost(planNode, lookup, session, types);
+
+ // Leaf nodes have nothing to add, so their own cost is returned unchanged.
+ if (!planNode.getSources().isEmpty()) {
+ PlanNodeCostEstimate childrenCost = planNode.getSources().stream()
+ .map(child -> lookup.getCumulativeCost(child, session, types))
+ .reduce(PlanNodeCostEstimate.ZERO_COST, PlanNodeCostEstimate::add);
+
+ return cost.add(childrenCost);
+ }
+
+ return cost;
}
+
+ // Binding annotation for injected parameters, retained at runtime.
+ // NOTE(review): presumably selects the CostCalculator variant that estimates exchange
+ // costs - confirm at the injection sites.
+ @BindingAnnotation
+ @Target({PARAMETER})
+ @Retention(RUNTIME)
+ @interface EstimatedExchanges {}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java
new file mode 100644
index 0000000000000..be0e312ae22a0
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.LimitNode;
+import com.facebook.presto.sql.planner.plan.OutputNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.TableScanNode;
+import com.facebook.presto.sql.planner.plan.ValuesNode;
+
+import javax.annotation.concurrent.ThreadSafe;
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.ZERO_COST;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.cpuCost;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Simple implementation of {@link CostCalculator}. It assumes that exchange nodes are already
+ * in the plan, so exchange cost is only charged when an {@link ExchangeNode} is visited.
+ */
+@ThreadSafe
+public class CostCalculatorUsingExchanges
+        implements CostCalculator
+{
+    private final int numberOfNodes;
+
+    /**
+     * Sizes the calculator with the number of active nodes that {@code nodeManager} reports
+     * when this constructor runs; the value is captured once and not refreshed afterwards.
+     */
+    @Inject
+    public CostCalculatorUsingExchanges(InternalNodeManager nodeManager)
+    {
+        this(nodeManager.getAllNodes().getActiveNodes().size());
+    }
+
+    /**
+     * @param numberOfNodes node count used as the multiplier for replicated joins and
+     * for replicating exchanges
+     */
+    public CostCalculatorUsingExchanges(int numberOfNodes)
+    {
+        this.numberOfNodes = numberOfNodes;
+    }
+
+    /**
+     * Estimates the cost of {@code planNode} by dispatching on its node type.
+     *
+     * @return the estimate, or {@code UNKNOWN_COST} for node types without a dedicated rule
+     */
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        CostEstimator costEstimator = new CostEstimator(
+                session,
+                types,
+                lookup,
+                numberOfNodes);
+
+        return planNode.accept(costEstimator, null);
+    }
+
+    /**
+     * Visitor holding the per-node-type cost formulas. It is a static nested class because it
+     * carries its own copy of every value it needs and never touches the enclosing instance.
+     */
+    private static class CostEstimator
+            extends PlanVisitor<PlanNodeCostEstimate, Void>
+    {
+        private final Session session;
+        private final Map<Symbol, Type> types;
+        private final Lookup lookup;
+        private final int numberOfNodes;
+
+        public CostEstimator(Session session, Map<Symbol, Type> types, Lookup lookup, int numberOfNodes)
+        {
+            this.session = requireNonNull(session, "session is null");
+            this.types = requireNonNull(types, "types is null");
+            this.lookup = requireNonNull(lookup, "lookup is null");
+            this.numberOfNodes = numberOfNodes;
+        }
+
+        /** Fallback for node types that have no dedicated formula. */
+        @Override
+        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context)
+        {
+            return UNKNOWN_COST;
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitOutput(OutputNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        /** CPU cost equals the estimated size of the filter's input. */
+        @Override
+        public PlanNodeCostEstimate visitFilter(FilterNode node, Void context)
+        {
+            return cpuCost(getStats(node.getSource()).getOutputSizeInBytes());
+        }
+
+        /** CPU cost equals the estimated size of the projection's output. */
+        @Override
+        public PlanNodeCostEstimate visitProject(ProjectNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes());
+        }
+
+        /** CPU cost is the input size, memory cost is the aggregation's output size. */
+        @Override
+        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context)
+        {
+            PlanNodeStatsEstimate aggregationStats = getStats(node);
+            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
+            return PlanNodeCostEstimate.builder()
+                    .setCpuCost(sourceStats.getOutputSizeInBytes())
+                    .setMemoryCost(aggregationStats.getOutputSizeInBytes())
+                    .setNetworkCost(0)
+                    .build();
+        }
+
+        /** A join without an explicit distribution type is costed as partitioned. */
+        @Override
+        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node,
+                    node.getLeft(),
+                    node.getRight(),
+                    node.getDistributionType().orElse(JoinNode.DistributionType.PARTITIONED).equals(JoinNode.DistributionType.REPLICATED));
+        }
+
+        /**
+         * Shared formula for joins and semi-joins. When {@code replicated} is true the build
+         * side is counted once per node, in both the CPU and the memory component.
+         */
+        private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated)
+        {
+            int numberOfNodesMultiplier = replicated ? numberOfNodes : 1;
+
+            PlanNodeStatsEstimate probeStats = getStats(probe);
+            PlanNodeStatsEstimate buildStats = getStats(build);
+            PlanNodeStatsEstimate outputStats = getStats(join);
+
+            double cpuCost = probeStats.getOutputSizeInBytes() +
+                    buildStats.getOutputSizeInBytes() * numberOfNodesMultiplier +
+                    outputStats.getOutputSizeInBytes();
+
+            double memoryCost = buildStats.getOutputSizeInBytes() * numberOfNodesMultiplier;
+
+            return PlanNodeCostEstimate.builder()
+                    .setCpuCost(cpuCost)
+                    .setMemoryCost(memoryCost)
+                    .setNetworkCost(0)
+                    .build();
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context)
+        {
+            return calculateExchangeCost(numberOfNodes, getStats(node), node.getType(), node.getScope());
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes()); // TODO: add network cost, based on input size in bytes?
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitValues(ValuesNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        /** A semi-join without an explicit distribution type is costed as partitioned. */
+        @Override
+        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node,
+                    node.getSource(),
+                    node.getFilteringSource(),
+                    node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED).equals(SemiJoinNode.DistributionType.REPLICATED));
+        }
+
+        /** CPU cost equals the estimated size of the limit's output. */
+        @Override
+        public PlanNodeCostEstimate visitLimit(LimitNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes());
+        }
+
+        private PlanNodeStatsEstimate getStats(PlanNode node)
+        {
+            return lookup.getStats(node, session, types);
+        }
+    }
+
+    /**
+     * Estimates the cost of an exchange from the statistics of the data it moves.
+     *
+     * <ul>
+     * <li>GATHER: network cost equals the output size.</li>
+     * <li>REPARTITION: network and CPU cost both equal the output size.</li>
+     * <li>REPLICATE: network cost equals the output size times {@code numberOfNodes}.</li>
+     * </ul>
+     * A LOCAL exchange has its network cost reset to zero. Memory cost is always zero.
+     *
+     * @throws UnsupportedOperationException if {@code type} is none of the types listed above
+     */
+    public static PlanNodeCostEstimate calculateExchangeCost(int numberOfNodes, PlanNodeStatsEstimate exchangeStats, ExchangeNode.Type type, ExchangeNode.Scope scope)
+    {
+        double network = 0;
+        double cpu = 0;
+
+        switch (type) {
+            case GATHER:
+                network = exchangeStats.getOutputSizeInBytes();
+                break;
+            case REPARTITION:
+                network = exchangeStats.getOutputSizeInBytes();
+                cpu = exchangeStats.getOutputSizeInBytes();
+                break;
+            case REPLICATE:
+                network = exchangeStats.getOutputSizeInBytes() * numberOfNodes;
+                break;
+            default:
+                throw new UnsupportedOperationException(format("Unsupported type [%s] of the exchange", type));
+        }
+
+        // Local exchanges do not cross the network, whatever their type.
+        if (scope.equals(ExchangeNode.Scope.LOCAL)) {
+            network = 0;
+        }
+
+        return PlanNodeCostEstimate.builder()
+                .setNetworkCost(network)
+                .setCpuCost(cpu)
+                .setMemoryCost(0)
+                .build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java
new file mode 100644
index 0000000000000..4241dcaa7be55
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+
+import javax.annotation.concurrent.ThreadSafe;
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeCostEstimate.ZERO_COST;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Wrapper around a {@link CostCalculator} that adds an estimated exchange cost for
+ * aggregation, join and semi-join nodes on top of the wrapped calculator's result.
+ * NOTE(review): presumably meant for plans in which exchange nodes have not been added yet.
+ */
+@ThreadSafe
+public class CostCalculatorWithEstimatedExchanges
+        implements CostCalculator
+{
+    private final CostCalculator costCalculator;
+    private final int numberOfNodes;
+
+    /**
+     * Sizes the calculator with the number of active nodes that {@code nodeManager} reports
+     * when this constructor runs; the value is captured once and not refreshed afterwards.
+     */
+    @Inject
+    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, InternalNodeManager nodeManager)
+    {
+        this(costCalculator, nodeManager.getAllNodes().getActiveNodes().size());
+    }
+
+    /**
+     * @param costCalculator calculator whose result the estimated exchange cost is added to
+     * @param numberOfNodes node count passed on to the exchange cost formula
+     */
+    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, int numberOfNodes)
+    {
+        this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+        this.numberOfNodes = numberOfNodes;
+    }
+
+    /**
+     * Returns the wrapped calculator's cost for {@code planNode} plus the estimated cost of
+     * the exchanges attributed to that node.
+     */
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        ExchangeCostEstimator exchangeCostEstimator = new ExchangeCostEstimator(
+                session,
+                types,
+                lookup,
+                numberOfNodes);
+        PlanNodeCostEstimate estimatedExchangeCost = planNode.accept(exchangeCostEstimator, null);
+
+        return costCalculator.calculateCost(planNode, lookup, session, types).add(estimatedExchangeCost);
+    }
+
+    /**
+     * Visitor producing the estimated exchange cost per node type. It is a static nested class
+     * because it carries its own copy of every value it needs and never touches the enclosing
+     * instance.
+     */
+    private static class ExchangeCostEstimator
+            extends PlanVisitor<PlanNodeCostEstimate, Void>
+    {
+        private final Session session;
+        private final Map<Symbol, Type> types;
+        private final Lookup lookup;
+        private final int numberOfNodes;
+
+        public ExchangeCostEstimator(Session session, Map<Symbol, Type> types, Lookup lookup, int numberOfNodes)
+        {
+            this.session = requireNonNull(session, "session is null");
+            this.types = requireNonNull(types, "types is null");
+            this.lookup = requireNonNull(lookup, "lookup is null");
+            this.numberOfNodes = numberOfNodes;
+        }
+
+        /** Node types without a dedicated rule get no additional exchange cost. */
+        @Override
+        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        /** Charges a remote repartition of the aggregation's input. */
+        @Override
+        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context)
+        {
+            return CostCalculatorUsingExchanges.calculateExchangeCost(
+                    numberOfNodes,
+                    getStats(node.getSource()),
+                    REPARTITION,
+                    REMOTE);
+        }
+
+        /** A join without an explicit distribution type is costed as partitioned. */
+        @Override
+        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node.getLeft(),
+                    node.getRight(),
+                    node.getDistributionType().orElse(JoinNode.DistributionType.PARTITIONED).equals(JoinNode.DistributionType.REPLICATED));
+        }
+
+        /** A semi-join without an explicit distribution type is costed as partitioned. */
+        @Override
+        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node.getSource(),
+                    node.getFilteringSource(),
+                    node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED).equals(SemiJoinNode.DistributionType.REPLICATED));
+        }
+
+        /**
+         * Replicated: a remote REPLICATE exchange of the build side only.
+         * Otherwise: remote REPARTITION exchanges of both the probe and the build side.
+         */
+        private PlanNodeCostEstimate calculateJoinCost(PlanNode probe, PlanNode build, boolean replicated)
+        {
+            if (replicated) {
+                return CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(build),
+                        REPLICATE,
+                        REMOTE);
+            }
+            else {
+                PlanNodeCostEstimate probeCost = CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(probe),
+                        REPARTITION,
+                        REMOTE);
+                PlanNodeCostEstimate buildCost = CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(build),
+                        REPARTITION,
+                        REMOTE);
+                return probeCost.add(buildCost);
+            }
+        }
+
+        private PlanNodeStatsEstimate getStats(PlanNode node)
+        {
+            return lookup.getStats(node, session, types);
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java
new file mode 100644
index 0000000000000..b37e8f32f3051
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.google.common.annotations.VisibleForTesting;
+
+import javax.inject.Inject;
+
+import java.util.Comparator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Compares plan cost estimates by a weighted sum of their CPU, memory and network components.
+ */
+public class CostComparator
+{
+    private final double cpuWeight;
+    private final double memoryWeight;
+    private final double networkWeight;
+
+    /** Reads the three component weights from {@code featuresConfig}. */
+    @Inject
+    public CostComparator(FeaturesConfig featuresConfig)
+    {
+        this(featuresConfig.getCpuCostWeight(), featuresConfig.getMemoryCostWeight(), featuresConfig.getNetworkCostWeight());
+    }
+
+    /**
+     * @throws IllegalArgumentException if any weight fails the {@code >= 0} check
+     * (this also rejects NaN)
+     */
+    @VisibleForTesting
+    public CostComparator(double cpuWeight, double memoryWeight, double networkWeight)
+    {
+        checkArgument(cpuWeight >= 0, "cpuWeight can not be negative");
+        checkArgument(memoryWeight >= 0, "memoryWeight can not be negative");
+        checkArgument(networkWeight >= 0, "networkWeight can not be negative");
+        this.cpuWeight = cpuWeight;
+        this.memoryWeight = memoryWeight;
+        this.networkWeight = networkWeight;
+    }
+
+    /** Returns a comparator that delegates to {@link #compare} with {@code session} bound. */
+    public Comparator<PlanNodeCostEstimate> forSession(Session session)
+    {
+        return (left, right) -> this.compare(session, left, right);
+    }
+
+    /**
+     * Orders two estimates by their weighted total cost.
+     *
+     * @return the result of {@code Double.compare} on the weighted totals of {@code left}
+     * and {@code right}
+     * @throws NullPointerException if any argument is null
+     * @throws IllegalArgumentException if either estimate has unknown components
+     */
+    public int compare(Session session, PlanNodeCostEstimate left, PlanNodeCostEstimate right)
+    {
+        requireNonNull(session, "session can not be null");
+        requireNonNull(left, "left can not be null");
+        requireNonNull(right, "right can not be null");
+        checkArgument(!left.hasUnknownComponents() && !right.hasUnknownComponents(), "cannot compare unknown costs");
+
+        return Double.compare(weightedCost(left), weightedCost(right));
+    }
+
+    /** Collapses the three cost components into one number using the configured weights. */
+    private double weightedCost(PlanNodeCostEstimate cost)
+    {
+        return cost.getCpuCost() * cpuWeight
+                + cost.getMemoryCost() * memoryWeight
+                + cost.getNetworkCost() * networkWeight;
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java b/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java
new file mode 100644
index 0000000000000..23eefc5c0c235
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.metadata.FunctionRegistry;
+import com.facebook.presto.metadata.Signature;
+import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.type.BigintType;
+import com.facebook.presto.spi.type.BooleanType;
+import com.facebook.presto.spi.type.DecimalType;
+import com.facebook.presto.spi.type.DoubleType;
+import com.facebook.presto.spi.type.IntegerType;
+import com.facebook.presto.spi.type.RealType;
+import com.facebook.presto.spi.type.SmallintType;
+import com.facebook.presto.spi.type.TinyintType;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.spi.type.VarcharType;
+import com.facebook.presto.sql.planner.ExpressionInterpreter;
+import io.airlift.slice.Slice;
+
+import java.util.OptionalDouble;
+
+import static java.util.Collections.singletonList;
+
+/**
+ * Helper functions used while calculating statistics.
+ * Its main purpose is mapping values of a {@link Type} domain onto the double domain,
+ * which is used for range comparisons during stats computations.
+ */
+public class DomainConverter
+{
+    private final Type type;
+    private final FunctionRegistry functionRegistry;
+    private final ConnectorSession session;
+
+    public DomainConverter(Type type, FunctionRegistry functionRegistry, ConnectorSession session)
+    {
+        this.type = type;
+        this.functionRegistry = functionRegistry;
+        this.session = session;
+    }
+
+    /** Converts {@code object} to unbounded varchar using the coercion registered for this type. */
+    public Slice castToVarchar(Object object)
+    {
+        return (Slice) coerceTo(VarcharType.createUnboundedVarcharType(), object);
+    }
+
+    /**
+     * Converts {@code object} to a double using the coercion registered for this type.
+     *
+     * @return the converted value, or empty when this converter's type has no supported
+     * translation to double
+     */
+    public OptionalDouble translateToDouble(Object object)
+    {
+        if (isDoubleTranslationSupported(type)) {
+            return OptionalDouble.of((double) coerceTo(DoubleType.DOUBLE, object));
+        }
+        return OptionalDouble.empty();
+    }
+
+    /** Looks up the coercion from this converter's type to {@code target} and applies it to {@code value}. */
+    private Object coerceTo(Type target, Object value)
+    {
+        Signature coercion = functionRegistry.getCoercion(type, target);
+        ScalarFunctionImplementation implementation = functionRegistry.getScalarFunctionImplementation(coercion);
+        return ExpressionInterpreter.invoke(session, implementation, singletonList(value));
+    }
+
+    /** Only decimal, floating point, integral and boolean types can be translated to double. */
+    private boolean isDoubleTranslationSupported(Type type)
+    {
+        if (type instanceof DecimalType) {
+            return true;
+        }
+        return DoubleType.DOUBLE.equals(type)
+                || RealType.REAL.equals(type)
+                || BigintType.BIGINT.equals(type)
+                || IntegerType.INTEGER.equals(type)
+                || SmallintType.SMALLINT.equals(type)
+                || TinyintType.TINYINT.equals(type)
+                || BooleanType.BOOLEAN.equals(type);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java
new file mode 100644
index 0000000000000..0bd95ffbbc596
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Stats rule for {@link EnforceSingleRowNode}: the estimated output is always exactly one row.
+ */
+public class EnforceSingleRowStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(EnforceSingleRowNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Returns an estimate with an output row count of 1 and no other statistics set.
+     * None of the arguments are inspected.
+     */
+    @Override
+    public Optional<PlanNodeStatsEstimate> calculate(PlanNode node, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        return Optional.of(
+                PlanNodeStatsEstimate.builder()
+                        .setOutputRowCount(1)
+                        .build());
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
new file mode 100644
index 0000000000000..c8d4b20d88ff1
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom;
+import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS;
+import static com.google.common.base.Predicates.not;
+
+/**
+ * Normalizer that aligns the symbols covered by a stats estimate with the output symbols
+ * of the plan node the estimate belongs to.
+ */
+public class EnsureStatsMatchOutput
+        implements ComposableStatsCalculator.Normalizer
+{
+    /**
+     * Builds a new estimate from {@code estimate} in which every output symbol of {@code node}
+     * that lacks known statistics is given {@code UNKNOWN_STATS}, and statistics for symbols
+     * that {@code node} does not output are removed. {@code types} is not used.
+     */
+    @Override
+    public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map<Symbol, Type> types)
+    {
+        PlanNodeStatsEstimate.Builder builder = buildFrom(estimate);
+
+        // Output symbols the estimate knows nothing about: mark them explicitly as unknown.
+        node.getOutputSymbols().stream()
+                .filter(not(estimate.getSymbolsWithKnownStatistics()::contains))
+                .forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS));
+
+        // Statistics for symbols the node does not output: drop them.
+        estimate.getSymbolsWithKnownStatistics().stream()
+                .filter(not(node.getOutputSymbols()::contains))
+                .forEach(builder::removeSymbolStatistics);
+
+        return builder.build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java
new file mode 100644
index 0000000000000..f71444cdbcd55
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStats;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+/**
+ * Stats rule for {@link ExchangeNode}: the statistics of all sources are renamed to the
+ * exchange's output symbols and added together.
+ *
+ * NOTE: this rule is still work in progress.
+ */
+public class ExchangeStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(ExchangeNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Sums the statistics of every source of the exchange after mapping each source's input
+     * symbols onto the exchange's output symbols.
+     *
+     * @throws IllegalStateException if the exchange has no sources
+     */
+    @Override
+    public Optional<PlanNodeStatsEstimate> calculate(PlanNode node, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        ExchangeNode exchangeNode = (ExchangeNode) node;
+        // TODO: decide whether the partitioning scheme should influence the estimate
+
+        Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
+        for (int i = 0; i < node.getSources().size(); i++) {
+            PlanNode source = node.getSources().get(i);
+            PlanNodeStatsEstimate sourceStats = lookup.getStats(source, session, types);
+
+            // The i-th input symbol list belongs to the i-th source.
+            PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, exchangeNode.getInputs().get(i), exchangeNode.getOutputSymbols());
+
+            if (estimate.isPresent()) {
+                estimate = Optional.of(addStats(estimate.get(), sourceStatsWithMappedSymbols));
+            }
+            else {
+                estimate = Optional.of(sourceStatsWithMappedSymbols);
+            }
+        }
+
+        checkState(estimate.isPresent(), "exchange node has no sources");
+        return estimate;
+    }
+
+    /**
+     * Copies the row count of {@code estimate} and re-keys its symbol statistics so that the
+     * statistics of {@code inputs.get(i)} are stored under {@code outputs.get(i)}.
+     *
+     * @throws IllegalArgumentException if the two symbol lists differ in size
+     */
+    private static PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, List<Symbol> inputs, List<Symbol> outputs)
+    {
+        checkArgument(inputs.size() == outputs.size(), "Inputs does not match outputs");
+        PlanNodeStatsEstimate.Builder mapped = PlanNodeStatsEstimate.builder()
+                .setOutputRowCount(estimate.getOutputRowCount());
+
+        for (int i = 0; i < inputs.size(); i++) {
+            mapped.addSymbolStatistics(outputs.get(i), estimate.getSymbolStatistics(inputs.get(i)));
+        }
+
+        return mapped.build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java
new file mode 100644
index 0000000000000..67b97fca260bd
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java
@@ -0,0 +1,271 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.LiteralInterpreter;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.tree.AstVisitor;
+import com.facebook.presto.sql.tree.BetweenPredicate;
+import com.facebook.presto.sql.tree.BooleanLiteral;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.InListExpression;
+import com.facebook.presto.sql.tree.InPredicate;
+import com.facebook.presto.sql.tree.IsNotNullPredicate;
+import com.facebook.presto.sql.tree.IsNullPredicate;
+import com.facebook.presto.sql.tree.Literal;
+import com.facebook.presto.sql.tree.LogicalBinaryExpression;
+import com.facebook.presto.sql.tree.NotExpression;
+import com.facebook.presto.sql.tree.SymbolReference;
+
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.ComparisonStatsCalculator.comparisonSymbolToLiteralStats;
+import static com.facebook.presto.cost.ComparisonStatsCalculator.comparisonSymbolToSymbolStats;
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStats;
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.differenceInNonRangeStats;
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.differenceInStats;
+import static com.facebook.presto.cost.SymbolStatsEstimate.buildFrom;
+import static com.facebook.presto.sql.ExpressionUtils.and;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN_OR_EQUAL;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN_OR_EQUAL;
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.Double.NaN;
+import static java.lang.Double.isInfinite;
+import static java.lang.Double.min;
+import static java.lang.String.format;
+
+/**
+ * Estimates the statistics of the rows that survive a filter predicate.
+ *
+ * The predicate is walked by {@code FilterExpressionStatsCalculatingVisitor}; any
+ * expression shape it has no dedicated handling for is estimated with
+ * {@link #filterStatsForUnknownExpression(PlanNodeStatsEstimate)}.
+ */
+public class FilterStatsCalculator
+{
+    private final Metadata metadata;
+
+    @Inject
+    public FilterStatsCalculator(Metadata metadata)
+    {
+        this.metadata = metadata;
+    }
+
+    /**
+     * Applies {@code predicate} to {@code statsEstimate}.
+     *
+     * @param statsEstimate statistics of the rows the predicate is applied to
+     * @param predicate the filter expression
+     * @param session session used when literals are evaluated
+     * @param types symbol types, consulted when a literal compared with a symbol is converted to a double
+     * @return the estimate for the rows passing the predicate
+     */
+    public PlanNodeStatsEstimate filterStats(
+            PlanNodeStatsEstimate statsEstimate,
+            Expression predicate,
+            Session session,
+            Map types)
+    {
+        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(predicate);
+    }
+
+    /**
+     * Fallback estimate for predicates that cannot be analysed: halves the output row count.
+     */
+    public static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate inputStatistics)
+    {
+        return inputStatistics.mapOutputRowCount(size -> size * 0.5);
+    }
+
+    /**
+     * Visitor computing, for each visited expression, the estimate obtained by filtering
+     * {@code input} with that expression.
+     */
+    private class FilterExpressionStatsCalculatingVisitor
+            extends AstVisitor
+    {
+        private final PlanNodeStatsEstimate input;
+        private final Session session;
+        private final Map types;
+
+        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, Map types)
+        {
+            this.input = input;
+            this.session = session;
+            this.types = types;
+        }
+
+        /**
+         * Default for expression kinds without a dedicated visit method: the unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context)
+        {
+            return filterForUnknownExpression();
+        }
+
+        private PlanNodeStatsEstimate filterForUnknownExpression()
+        {
+            return filterStatsForUnknownExpression(input);
+        }
+
+        /**
+         * Estimate for a predicate that is always false: zero rows, and for every symbol with
+         * known statistics a copy with NaN low/high values, zero distinct values and a NaN nulls fraction.
+         */
+        private PlanNodeStatsEstimate filterForFalseExpression()
+        {
+            PlanNodeStatsEstimate.Builder falseStatsBuilder = PlanNodeStatsEstimate.builder();
+
+            input.getSymbolsWithKnownStatistics().forEach(
+                    symbol ->
+                            falseStatsBuilder.addSymbolStatistics(symbol,
+                                    buildFrom(input.getSymbolStatistics(symbol))
+                                            .setLowValue(NaN)
+                                            .setHighValue(NaN)
+                                            .setDistinctValuesCount(0.0)
+                                            .setNullsFraction(NaN).build()));
+
+            return falseStatsBuilder.setOutputRowCount(0.0).build();
+        }
+
+        /**
+         * NOT x: the input estimate minus the estimate of x, via {@code differenceInStats}.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context)
+        {
+            return differenceInStats(input, process(node.getValue()));
+        }
+
+        /**
+         * AND is estimated by filtering with the left operand and then filtering that result
+         * with the right operand. OR is estimated as left + right minus the AND estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context)
+        {
+            PlanNodeStatsEstimate leftStats = process(node.getLeft());
+            // The right operand applied to rows that already passed the left operand.
+            PlanNodeStatsEstimate andStats = new FilterExpressionStatsCalculatingVisitor(leftStats, session, types).process(node.getRight());
+
+            switch (node.getType()) {
+                case AND:
+                    return andStats;
+                case OR: {
+                    // Only OR needs the right operand evaluated against the unfiltered input.
+                    // Computing it unconditionally made every AND level process its right
+                    // subtree twice, which is exponential for right-nested conjunctions.
+                    PlanNodeStatsEstimate rightStats = process(node.getRight());
+                    return differenceInNonRangeStats(addStats(leftStats, rightStats), andStats);
+                }
+                default:
+                    checkState(false, format("Unimplemented logical binary operator expression %s", node.getType()));
+                    return PlanNodeStatsEstimate.UNKNOWN_STATS;
+            }
+        }
+
+        /**
+         * TRUE leaves the input untouched; any other boolean literal yields the always-false estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context)
+        {
+            if (node.equals(BooleanLiteral.TRUE_LITERAL)) {
+                return input;
+            }
+            else {
+                return filterForFalseExpression();
+            }
+        }
+
+        /**
+         * {@code symbol IS NOT NULL}: scales the row count by the non-null fraction of the symbol
+         * and sets the symbol's nulls fraction to zero. Other operands use the unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context)
+        {
+            if (node.getValue() instanceof SymbolReference) {
+                Symbol symbol = Symbol.from(node.getValue());
+                SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol);
+                return input.mapOutputRowCount(rowCount -> rowCount * (1 - symbolStatsEstimate.getNullsFraction()))
+                        .mapSymbolColumnStatistics(symbol, statsEstimate -> statsEstimate.mapNullsFraction(x -> 0.0));
+            }
+            return visitExpression(node, context);
+        }
+
+        /**
+         * {@code symbol IS NULL}: scales the row count by the symbol's nulls fraction and replaces the
+         * symbol's statistics with "all nulls" (nulls fraction 1, NaN range, zero distinct values).
+         * Other operands use the unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context)
+        {
+            if (node.getValue() instanceof SymbolReference) {
+                Symbol symbol = Symbol.from(node.getValue());
+                SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol);
+                return input.mapOutputRowCount(rowCount -> rowCount * symbolStatsEstimate.getNullsFraction())
+                        .mapSymbolColumnStatistics(symbol, statsEstimate ->
+                                SymbolStatsEstimate.builder().setNullsFraction(1.0)
+                                        .setLowValue(NaN)
+                                        .setHighValue(NaN)
+                                        .setDistinctValuesCount(0.0).build());
+            }
+            return visitExpression(node, context);
+        }
+
+        /**
+         * {@code symbol BETWEEN literal AND literal} is rewritten into a conjunction of two
+         * comparisons. Any other operand shape uses the unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context)
+        {
+            if (!(node.getValue() instanceof SymbolReference) || !(node.getMin() instanceof Literal) || !(node.getMax() instanceof Literal)) {
+                return visitExpression(node, context);
+            }
+
+            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from((SymbolReference) node.getValue()));
+            Expression leftComparison;
+            Expression rightComparison;
+
+            // We want to do the heuristic cut (infinite range to finite range) as early as possible
+            // and then do the filtering on a finite range.
+            if (isInfinite(valueStats.getLowValue())) {
+                leftComparison = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
+                rightComparison = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
+            }
+            else {
+                rightComparison = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
+                leftComparison = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
+            }
+
+            // We rely on AND being processed left to right.
+            return process(and(leftComparison, rightComparison));
+        }
+
+        /**
+         * {@code symbol IN (v1, v2, ...)}: sums the estimates of {@code symbol = v} over the list,
+         * starting from the always-false estimate. The row count is capped at the number of
+         * non-null input rows; the symbol's nulls fraction is set to zero and its distinct values
+         * count is capped at the input's. Other shapes use the unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
+        {
+            if (!(node.getValueList() instanceof InListExpression) || !(node.getValue() instanceof SymbolReference)) {
+                return visitExpression(node, context);
+            }
+
+            InListExpression inList = (InListExpression) node.getValueList();
+            PlanNodeStatsEstimate statsSum = inList.getValues().stream()
+                    .map(inValue -> process(new ComparisonExpression(EQUAL, node.getValue(), inValue)))
+                    .reduce(filterForFalseExpression(),
+                            PlanNodeStatsEstimateMath::addStats);
+
+            Symbol inValueSymbol = Symbol.from(node.getValue());
+            SymbolStatsEstimate symbolStat = input.getSymbolStatistics(inValueSymbol);
+            double notNullValuesBeforeIn = input.getOutputRowCount() * (1 - symbolStat.getNullsFraction());
+
+            return statsSum.mapOutputRowCount(rowCount -> min(rowCount, notNullValuesBeforeIn))
+                    .mapSymbolColumnStatistics(inValueSymbol,
+                            symbolStats ->
+                                    symbolStats.mapNullsFraction(x -> 0.0)
+                                            .mapDistinctValuesCount(distinctValues ->
+                                                    min(distinctValues, symbolStat.getDistinctValuesCount())));
+        }
+
+        /**
+         * Handles symbol-to-symbol and symbol-to-literal comparisons (the literal may be on either
+         * side; the operator is flipped when it is on the left). Everything else uses the
+         * unknown-expression estimate.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context)
+        {
+            if (node.getLeft() instanceof SymbolReference && node.getRight() instanceof SymbolReference) {
+                return comparisonSymbolToSymbolStats(input,
+                        Symbol.from(node.getLeft()),
+                        Symbol.from(node.getRight()),
+                        node.getType());
+            }
+            else if (node.getLeft() instanceof SymbolReference && node.getRight() instanceof Literal) {
+                Symbol symbol = Symbol.from(node.getLeft());
+                return comparisonSymbolToLiteralStats(input,
+                        symbol,
+                        doubleValueFromLiteral(types.get(symbol), (Literal) node.getRight()),
+                        node.getType());
+            }
+            else if (node.getLeft() instanceof Literal && node.getRight() instanceof SymbolReference) {
+                Symbol symbol = Symbol.from(node.getRight());
+                return comparisonSymbolToLiteralStats(input,
+                        symbol,
+                        doubleValueFromLiteral(types.get(symbol), (Literal) node.getLeft()),
+                        node.getType().flip());
+            }
+            else {
+                return filterForUnknownExpression();
+            }
+        }
+
+        /**
+         * Evaluates {@code literal} and converts the value to a double through a {@code DomainConverter}
+         * built for {@code type}; yields NaN when the converter returns no value.
+         */
+        private double doubleValueFromLiteral(Type type, Literal literal)
+        {
+            Object literalValue = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), literal);
+            DomainConverter domainConverter = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession());
+            return domainConverter.translateToDouble(literalValue).orElse(NaN);
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java
new file mode 100644
index 0000000000000..0044b9f6d67ca
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Stats rule for {@link FilterNode}: filters the source's estimate with the node's
+ * predicate using a {@link FilterStatsCalculator}.
+ */
+public class FilterStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(FilterNode.class);
+
+    private final FilterStatsCalculator filterStatsCalculator;
+
+    public FilterStatsRule(FilterStatsCalculator filterStatsCalculator)
+    {
+        this.filterStatsCalculator = filterStatsCalculator;
+    }
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Looks up the statistics of the filter's source and applies the filter predicate to them.
+     *
+     * @return the filtered estimate; never empty
+     */
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        FilterNode filter = (FilterNode) node;
+        return Optional.of(
+                filterStatsCalculator.filterStats(
+                        lookup.getStats(filter.getSource(), session, types),
+                        filter.getPredicate(),
+                        session,
+                        types));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java
new file mode 100644
index 0000000000000..f07ff86e7c473
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.sql.planner.plan.IntersectNode;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.intersect;
+
+/**
+ * Set-operation stats rule matching {@link IntersectNode}.
+ *
+ * Supplies {@code PlanNodeStatsEstimateMath.intersect} as the pairwise operation.
+ * NOTE(review): how and in what order the base class applies {@code operate} is
+ * defined in AbstractSetOperationStatsRule, which is not visible here.
+ */
+public class IntersectStatsRule
+ extends AbstractSetOperationStatsRule
+{
+ private static final Pattern PATTERN = Pattern.matchByClass(IntersectNode.class);
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ /**
+ * Combines two estimates by delegating to {@code intersect(first, second)}.
+ */
+ @Override
+ protected PlanNodeStatsEstimate operate(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second)
+ {
+ return intersect(first, second);
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java
new file mode 100644
index 0000000000000..204ab30f3f2bb
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.tree.ComparisonExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.google.common.annotations.VisibleForTesting;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS;
+import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
+import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
+import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
+import static com.facebook.presto.util.MoreMath.rangeMax;
+import static com.facebook.presto.util.MoreMath.rangeMin;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static com.google.common.collect.Sets.difference;
+import static java.lang.Double.NaN;
+
+/**
+ * Stats rule for {@link JoinNode}.
+ *
+ * An inner join is estimated as a cross join filtered by the join criteria and the
+ * optional filter. Outer joins add to the inner-join estimate the "anti-join" rows of
+ * the preserved side(s), i.e. the rows estimated to have no match.
+ */
+public class JoinStatsRule
+ implements ComposableStatsCalculator.Rule
+{
+ private static final Pattern PATTERN = Pattern.matchByClass(JoinNode.class);
+
+ private final FilterStatsCalculator filterStatsCalculator;
+
+ public JoinStatsRule(FilterStatsCalculator filterStatsCalculator)
+ {
+ this.filterStatsCalculator = filterStatsCalculator;
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ /**
+ * Looks up the statistics of both join inputs and dispatches on the join type.
+ *
+ * @return the join estimate for INNER, LEFT, RIGHT and FULL joins; empty for any other join type
+ */
+ @Override
+ public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+ {
+ JoinNode joinNode = (JoinNode) node;
+
+ PlanNodeStatsEstimate leftStats = lookup.getStats(joinNode.getLeft(), session, types);
+ PlanNodeStatsEstimate rightStats = lookup.getStats(joinNode.getRight(), session, types);
+
+ switch (joinNode.getType()) {
+ case INNER:
+ return Optional.of(computeInnerJoinStats(joinNode, leftStats, rightStats, session, types));
+ case LEFT:
+ return Optional.of(computeLeftJoinStats(joinNode, leftStats, rightStats, session, types));
+ case RIGHT:
+ return Optional.of(computeRightJoinStats(joinNode, leftStats, rightStats, session, types));
+ case FULL:
+ return Optional.of(computeFullJoinStats(joinNode, leftStats, rightStats, session, types));
+ default:
+ return Optional.empty();
+ }
+ }
+
+ // FULL join = LEFT join estimate plus the right side's anti-join rows.
+ private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types)
+ {
+ PlanNodeStatsEstimate rightAntiJoinStats = calculateAntiJoinStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
+ return addAntiJoinStats(computeLeftJoinStats(node, leftStats, rightStats, session, types), rightAntiJoinStats, getRightJoinSymbols(node));
+ }
+
+ // LEFT join = INNER join estimate plus the left side's anti-join rows.
+ private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types)
+ {
+ PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, leftStats, rightStats, session, types);
+ PlanNodeStatsEstimate leftAntiJoinStats = calculateAntiJoinStats(node.getFilter(), node.getCriteria(), leftStats, rightStats);
+ return addAntiJoinStats(innerJoinStats, leftAntiJoinStats, getLeftJoinSymbols(node));
+ }
+
+ // RIGHT join = INNER join estimate plus the right side's anti-join rows; the criteria are
+ // flipped so that the right side plays the "left" role in calculateAntiJoinStats.
+ private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types)
+ {
+ PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, leftStats, rightStats, session, types);
+ PlanNodeStatsEstimate rightAntiJoinStats = calculateAntiJoinStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
+ return addAntiJoinStats(innerJoinStats, rightAntiJoinStats, getRightJoinSymbols(node));
+ }
+
+ // INNER join: turn every equi-join clause into "left = right", conjoin them with the join
+ // filter (TRUE when absent) and apply the resulting predicate to the cross-join estimate.
+ private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types)
+ {
+ List comparisons = node.getCriteria().stream()
+ .map(criteria -> new ComparisonExpression(EQUAL, criteria.getLeft().toSymbolReference(), criteria.getRight().toSymbolReference()))
+ .collect(toImmutableList());
+ Expression predicate = combineConjuncts(combineConjuncts(comparisons), node.getFilter().orElse(TRUE_LITERAL));
+ PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats);
+ return filterStatsCalculator.filterStats(crossJoinStats, predicate, session, types);
+ }
+
+ /**
+ * Estimates the rows of {@code leftStats} that have no match in {@code rightStats} under the
+ * given equi-join clauses. Starts from {@code leftStats} and narrows it once per clause.
+ *
+ * @return {@code UNKNOWN_STATS} when a filter is present or when a distinct values count
+ * of either side of a clause is NaN; the narrowed estimate otherwise
+ */
+ @VisibleForTesting
+ PlanNodeStatsEstimate calculateAntiJoinStats(Optional filter, List criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats)
+ {
+ // TODO: add support for non-equality conditions (e.g: <=, !=, >)
+ if (filter.isPresent()) {
+ // non-equi filters are not supported
+ return UNKNOWN_STATS;
+ }
+
+ PlanNodeStatsEstimate outputStats = leftStats;
+
+ for (EquiJoinClause clause : criteria) {
+ SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(clause.getLeft());
+ SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(clause.getRight());
+
+ StatisticRange rightRange = StatisticRange.from(rightColumnStats);
+ StatisticRange antiRange = StatisticRange.from(leftColumnStats)
+ .subtract(rightRange);
+
+ // TODO: use NDVs from left and right StatisticRange when they are fixed
+ double leftNDV = leftColumnStats.getDistinctValuesCount();
+ double rightNDV = rightColumnStats.getDistinctValuesCount();
+
+ // Both comparisons below are false when either NDV is NaN, so NaN falls through to the final else.
+ if (leftNDV > rightNDV) {
+ // Surviving non-null share: the values fraction scaled by the share of distinct values left unmatched.
+ double selectedRangeFraction = leftColumnStats.getValuesFraction() * (leftNDV - rightNDV) / leftNDV;
+ // Null rows are kept as well, so the row count scales by both shares together.
+ double scaleFactor = selectedRangeFraction + leftColumnStats.getNullsFraction();
+ double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
+ outputStats = outputStats.mapSymbolColumnStatistics(clause.getLeft(), columnStats ->
+ SymbolStatsEstimate.buildFrom(columnStats)
+ .setLowValue(antiRange.getLow())
+ .setHighValue(antiRange.getHigh())
+ .setNullsFraction(newLeftNullsFraction)
+ .setDistinctValuesCount(leftNDV - rightNDV)
+ .build());
+ outputStats = outputStats.mapOutputRowCount(rowCount -> rowCount * scaleFactor);
+ }
+ else if (leftNDV <= rightNDV) {
+ // only null values are left
+ outputStats = outputStats.mapSymbolColumnStatistics(clause.getLeft(), columnStats ->
+ SymbolStatsEstimate.buildFrom(columnStats)
+ .setLowValue(NaN)
+ .setHighValue(NaN)
+ .setNullsFraction(1.0)
+ .setDistinctValuesCount(0.0)
+ .build());
+ outputStats = outputStats.mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
+ }
+ else {
+ // either leftNDV or rightNDV is NaN
+ return UNKNOWN_STATS;
+ }
+ }
+
+ return outputStats;
+ }
+
+ /**
+ * Merges anti-join rows into a join estimate.
+ *
+ * The row counts are added. For symbols known to both estimates the nulls fraction becomes the
+ * row-count-weighted average, the value range is combined with {@code rangeMin}/{@code rangeMax},
+ * and the distinct values counts are added only for symbols in {@code joinSymbols}. Symbols known
+ * only to {@code joinStats} count every anti-join row as null.
+ *
+ * NOTE(review): when both row counts are 0, {@code totalRowCount} is 0 and the nulls
+ * fractions computed below evaluate to NaN (0.0 / 0.0) - confirm this is intended.
+ *
+ * @throws IllegalStateException if {@code antiJoinStats} knows a symbol that {@code joinStats} does not
+ */
+ @VisibleForTesting
+ PlanNodeStatsEstimate addAntiJoinStats(PlanNodeStatsEstimate joinStats, PlanNodeStatsEstimate antiJoinStats, Set joinSymbols)
+ {
+ checkState(joinStats.getSymbolsWithKnownStatistics().containsAll(antiJoinStats.getSymbolsWithKnownStatistics()));
+
+ double joinOutputRowCount = joinStats.getOutputRowCount();
+ double antiJoinOutputRowCount = antiJoinStats.getOutputRowCount();
+ double totalRowCount = joinOutputRowCount + antiJoinOutputRowCount;
+ PlanNodeStatsEstimate outputStats = joinStats.mapOutputRowCount(rowCount -> rowCount + antiJoinOutputRowCount);
+
+ for (Symbol symbol : antiJoinStats.getSymbolsWithKnownStatistics()) {
+ outputStats = outputStats.mapSymbolColumnStatistics(symbol, joinColumnStats -> {
+ SymbolStatsEstimate antiJoinColumnStats = antiJoinStats.getSymbolStatistics(symbol);
+ // weighted average
+ double newNullsFraction = (joinColumnStats.getNullsFraction() * joinOutputRowCount + antiJoinColumnStats.getNullsFraction() * antiJoinOutputRowCount) / totalRowCount;
+ double distinctValues;
+ if (joinSymbols.contains(symbol)) {
+ distinctValues = joinColumnStats.getDistinctValuesCount() + antiJoinColumnStats.getDistinctValuesCount();
+ }
+ else {
+ distinctValues = joinColumnStats.getDistinctValuesCount();
+ }
+ return SymbolStatsEstimate.buildFrom(joinColumnStats)
+ .setLowValue(rangeMin(joinColumnStats.getLowValue(), antiJoinColumnStats.getLowValue()))
+ .setHighValue(rangeMax(joinColumnStats.getHighValue(), antiJoinColumnStats.getHighValue()))
+ .setDistinctValuesCount(distinctValues)
+ .setNullsFraction(newNullsFraction)
+ .build();
+ });
+ }
+
+ // add nulls to columns that don't exist in the anti-join stats: every anti-join row counts as a null for them
+ for (Symbol symbol : difference(joinStats.getSymbolsWithKnownStatistics(), antiJoinStats.getSymbolsWithKnownStatistics())) {
+ outputStats = outputStats.mapSymbolColumnStatistics(symbol, joinColumnStats ->
+ joinColumnStats.mapNullsFraction(nullsFraction -> (nullsFraction * joinOutputRowCount + antiJoinOutputRowCount) / totalRowCount));
+ }
+
+ return outputStats;
+ }
+
+ // Cross join: the row count is the product of both sides' row counts; symbol statistics are
+ // copied unchanged for the output symbols of the left and right children.
+ private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats)
+ {
+ PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
+
+ node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol)));
+ node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol)));
+
+ return builder.build();
+ }
+
+ // Left-hand symbols of all equi-join clauses.
+ private Set getLeftJoinSymbols(JoinNode node)
+ {
+ return node.getCriteria().stream()
+ .map(EquiJoinClause::getLeft)
+ .collect(toImmutableSet());
+ }
+
+ // Right-hand symbols of all equi-join clauses.
+ private Set getRightJoinSymbols(JoinNode node)
+ {
+ return node.getCriteria().stream()
+ .map(EquiJoinClause::getRight)
+ .collect(toImmutableSet());
+ }
+
+ // The join criteria with the left and right symbol of every clause swapped.
+ private List flippedCriteria(JoinNode node)
+ {
+ return node.getCriteria().stream()
+ .map(criteria -> new JoinNode.EquiJoinClause(criteria.getRight(), criteria.getLeft()))
+ .collect(toImmutableList());
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java
new file mode 100644
index 0000000000000..26175be3459b8
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.LimitNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Stats rule for {@link LimitNode}: caps the source's estimated row count at the limit.
+ */
+public class LimitStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Returns the source estimate with its output row count capped at the limit.
+     *
+     * The source's symbol statistics are carried over rather than dropped. This relies on
+     * {@code mapOutputRowCount} replacing only the row count, which is how the other
+     * rules in this package chain it with symbol-statistics updates.
+     *
+     * @param node the matched plan node; cast to {@link LimitNode}
+     * @param lookup used to obtain the statistics of the limit's source
+     * @param session passed through to {@code lookup}
+     * @param types passed through to {@code lookup}
+     * @return the capped estimate; never empty
+     */
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        LimitNode limitNode = (LimitNode) node;
+        PlanNodeStatsEstimate sourceStats = lookup.getStats(limitNode.getSource(), session, types);
+
+        // The source already produces fewer rows than the limit: the limit changes nothing.
+        if (sourceStats.getOutputRowCount() < limitNode.getCount()) {
+            return Optional.of(sourceStats);
+        }
+
+        // Also reached when the source row count is NaN (the comparison above is then false),
+        // in which case the limit itself becomes the estimate.
+        // TODO special handling for NaN?
+        double limit = limitNode.getCount();
+        return Optional.of(sourceStats.mapOutputRowCount(rowCount -> limit));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java
new file mode 100644
index 0000000000000..4a079e376698c
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.OutputNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Stats rule for {@link OutputNode}: the output node's estimate is exactly the estimate of its source.
+ */
+public class OutputStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(OutputNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Passes the source's statistics through unchanged.
+     *
+     * @return the source estimate; never empty
+     */
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        PlanNode source = ((OutputNode) node).getSource();
+        PlanNodeStatsEstimate sourceStats = lookup.getStats(source, session, types);
+        return Optional.of(sourceStats);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java
deleted file mode 100644
index c30eaa90538f3..0000000000000
--- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.facebook.presto.cost;
-
-import com.facebook.presto.spi.statistics.Estimate;
-
-import java.util.Objects;
-import java.util.function.Function;
-
-import static com.facebook.presto.spi.statistics.Estimate.unknownValue;
-import static java.util.Objects.requireNonNull;
-
-public class PlanNodeCost
-{
- public static final PlanNodeCost UNKNOWN_COST = PlanNodeCost.builder().build();
-
- private final Estimate outputRowCount;
- private final Estimate outputSizeInBytes;
-
- private PlanNodeCost(Estimate outputRowCount, Estimate outputSizeInBytes)
- {
- this.outputRowCount = requireNonNull(outputRowCount, "outputRowCount can not be null");
- this.outputSizeInBytes = requireNonNull(outputSizeInBytes, "outputSizeInBytes can not be null");
- }
-
- public Estimate getOutputRowCount()
- {
- return outputRowCount;
- }
-
- public Estimate getOutputSizeInBytes()
- {
- return outputSizeInBytes;
- }
-
- public PlanNodeCost mapOutputRowCount(Function mappingFunction)
- {
- return buildFrom(this).setOutputRowCount(outputRowCount.map(mappingFunction)).build();
- }
-
- public PlanNodeCost mapOutputSizeInBytes(Function mappingFunction)
- {
- return buildFrom(this).setOutputSizeInBytes(outputRowCount.map(mappingFunction)).build();
- }
-
- @Override
- public String toString()
- {
- return "PlanNodeCost{outputRowCount=" + outputRowCount + ", outputSizeInBytes=" + outputSizeInBytes + '}';
- }
-
- @Override
- public boolean equals(Object o)
- {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- PlanNodeCost that = (PlanNodeCost) o;
- return Objects.equals(outputRowCount, that.outputRowCount) &&
- Objects.equals(outputSizeInBytes, that.outputSizeInBytes);
- }
-
- @Override
- public int hashCode()
- {
- return Objects.hash(outputRowCount, outputSizeInBytes);
- }
-
- public static Builder builder()
- {
- return new Builder();
- }
-
- public static Builder buildFrom(PlanNodeCost other)
- {
- return builder().setOutputRowCount(other.getOutputRowCount())
- .setOutputSizeInBytes(other.getOutputSizeInBytes());
- }
-
- public static final class Builder
- {
- private Estimate outputRowCount = unknownValue();
- private Estimate outputSizeInBytes = unknownValue();
-
- public Builder setOutputRowCount(Estimate outputRowCount)
- {
- this.outputRowCount = outputRowCount;
- return this;
- }
-
- public Builder setOutputSizeInBytes(Estimate outputSizeInBytes)
- {
- this.outputSizeInBytes = outputSizeInBytes;
- return this;
- }
-
- public PlanNodeCost build()
- {
- return new PlanNodeCost(outputRowCount, outputSizeInBytes);
- }
- }
-}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java
new file mode 100644
index 0000000000000..2a1cb29139553
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import java.util.Objects;
+import java.util.Optional;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.Double.NaN;
+import static java.lang.Double.POSITIVE_INFINITY;
+import static java.lang.Double.isNaN;
+
+public class PlanNodeCostEstimate
+{
+ public static final PlanNodeCostEstimate INFINITE_COST = new PlanNodeCostEstimate(POSITIVE_INFINITY, POSITIVE_INFINITY, POSITIVE_INFINITY);
+ public static final PlanNodeCostEstimate UNKNOWN_COST = new PlanNodeCostEstimate(NaN, NaN, NaN);
+ public static final PlanNodeCostEstimate ZERO_COST = new PlanNodeCostEstimate(0, 0, 0);
+
+ private final double cpuCost;
+ private final double memoryCost;
+ private final double networkCost;
+
+ private PlanNodeCostEstimate(double cpuCost, double memoryCost, double networkCost)
+ {
+ checkArgument(isNaN(cpuCost) || cpuCost >= 0, "cpuCost cannot be negative");
+ checkArgument(isNaN(memoryCost) || memoryCost >= 0, "memoryCost cannot be negative");
+ checkArgument(isNaN(networkCost) || networkCost >= 0, "networkCost cannot be negative");
+ this.cpuCost = cpuCost;
+ this.memoryCost = memoryCost;
+ this.networkCost = networkCost;
+ }
+
+ /**
+ * Returns CPU component of the cost. Unknown value is represented by {@link Double#NaN}
+ */
+ public double getCpuCost()
+ {
+ return cpuCost;
+ }
+
+ /**
+ * Returns memory component of the cost. Unknown value is represented by {@link Double#NaN}
+ */
+ public double getMemoryCost()
+ {
+ return memoryCost;
+ }
+
+ /**
+ * Returns network component of the cost. Unknown value is represented by {@link Double#NaN}
+ */
+ public double getNetworkCost()
+ {
+ return networkCost;
+ }
+
+ /**
+ * Returns true if this cost has unknown components.
+ */
+ public boolean hasUnknownComponents()
+ {
+ return isNaN(cpuCost) || isNaN(memoryCost) || isNaN(networkCost);
+ }
+
+ @Override
+ public String toString()
+ {
+ return toStringHelper(this)
+ .add("cpuCost", cpuCost)
+ .add("memoryCost", memoryCost)
+ .add("networkCost", networkCost)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ PlanNodeCostEstimate that = (PlanNodeCostEstimate) o;
+ return Double.compare(that.cpuCost, cpuCost) == 0 &&
+ Double.compare(that.memoryCost, memoryCost) == 0 &&
+ Double.compare(that.networkCost, networkCost) == 0;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(cpuCost, memoryCost, networkCost);
+ }
+
+ public PlanNodeCostEstimate add(PlanNodeCostEstimate other)
+ {
+ return new PlanNodeCostEstimate(
+ cpuCost + other.cpuCost,
+ memoryCost + other.memoryCost,
+ networkCost + other.networkCost);
+ }
+
+ public static PlanNodeCostEstimate networkCost(double networkCost)
+ {
+ return builder().setCpuCost(0).setMemoryCost(0).setNetworkCost(networkCost).build();
+ }
+
+ public static PlanNodeCostEstimate cpuCost(double cpuCost)
+ {
+ return builder().setCpuCost(cpuCost).setMemoryCost(0).setNetworkCost(0).build();
+ }
+
+ public static PlanNodeCostEstimate memoryCost(double memoryCost)
+ {
+ return builder().setCpuCost(0).setMemoryCost(memoryCost).setNetworkCost(0).build();
+ }
+
+ public static Builder builder()
+ {
+ return new Builder();
+ }
+
+ public static final class Builder
+ {
+ private Optional cpuCost = Optional.empty();
+ private Optional memoryCost = Optional.empty();
+ private Optional networkCost = Optional.empty();
+
+ public Builder setFrom(PlanNodeCostEstimate otherStatistics)
+ {
+ return setCpuCost(otherStatistics.getCpuCost())
+ .setMemoryCost(otherStatistics.getMemoryCost())
+ .setNetworkCost(otherStatistics.getNetworkCost());
+ }
+
+ public Builder setCpuCost(double cpuCost)
+ {
+ checkState(!this.cpuCost.isPresent(), "cpuCost already set");
+ this.cpuCost = Optional.of(cpuCost);
+ return this;
+ }
+
+ public Builder setMemoryCost(double memoryCost)
+ {
+ checkState(!this.memoryCost.isPresent(), "memoryCost already set");
+ this.memoryCost = Optional.of(memoryCost);
+ return this;
+ }
+
+ public Builder setNetworkCost(double networkCost)
+ {
+ checkState(!this.networkCost.isPresent(), "networkCost already set");
+ this.networkCost = Optional.of(networkCost);
+ return this;
+ }
+
+ public PlanNodeCostEstimate build()
+ {
+ checkState(cpuCost.isPresent(), "cpuCost not set");
+ checkState(memoryCost.isPresent(), "memoryCost not set");
+ checkState(networkCost.isPresent(), "networkCost not set");
+ return new PlanNodeCostEstimate(cpuCost.get(), memoryCost.get(), networkCost.get());
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java
new file mode 100644
index 0000000000000..c01e64e8f355d
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.sql.planner.Symbol;
+import org.pcollections.HashTreePMap;
+import org.pcollections.PMap;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.lang.Double.NaN;
+import static java.lang.Double.isNaN;
+
+public class PlanNodeStatsEstimate
+{
+ public static final PlanNodeStatsEstimate UNKNOWN_STATS = PlanNodeStatsEstimate.builder().build();
+ public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10;
+
+ private final double outputRowCount;
+ private final PMap symbolStatistics;
+
+ private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics)
+ {
+ checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative");
+ this.outputRowCount = outputRowCount;
+ this.symbolStatistics = symbolStatistics;
+ }
+
+ /**
+ * Returns estimated number of rows.
+ * Unknown value is represented by {@link Double#NaN}
+ */
+ public double getOutputRowCount()
+ {
+ return outputRowCount;
+ }
+
+ /**
+ * Returns estimated data size.
+ * Unknown value is represented by {@link Double#NaN}
+ */
+ public double getOutputSizeInBytes()
+ {
+ if (isNaN(outputRowCount)) {
+ return Double.NaN;
+ }
+ double outputSizeInBytes = 0;
+ for (Map.Entry entry : symbolStatistics.entrySet()) {
+ outputSizeInBytes += getOutputSizeForSymbol(entry.getValue());
+ }
+ return outputSizeInBytes;
+ }
+
+ private double getOutputSizeForSymbol(SymbolStatsEstimate symbolStatistics)
+ {
+ double averageRowSize = symbolStatistics.getAverageRowSize();
+ if (isNaN(averageRowSize)) {
+            // TODO take into consideration data type of column
+ return outputRowCount * DEFAULT_DATA_SIZE_PER_COLUMN;
+ }
+ return outputRowCount * averageRowSize;
+ }
+
+ public PlanNodeStatsEstimate mapOutputRowCount(Function mappingFunction)
+ {
+ return buildFrom(this).setOutputRowCount(mappingFunction.apply(outputRowCount)).build();
+ }
+
+    public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction)
+    {
+        return buildFrom(this)
+                .addSymbolStatistics(symbol, mappingFunction.apply(getSymbolStatistics(symbol))) // untracked symbols map from UNKNOWN_STATS rather than null
+                .build();
+    }
+
+ public SymbolStatsEstimate getSymbolStatistics(Symbol symbol)
+ {
+ return symbolStatistics.getOrDefault(symbol, SymbolStatsEstimate.UNKNOWN_STATS);
+ }
+
+ public Set getSymbolsWithKnownStatistics()
+ {
+ return symbolStatistics.keySet();
+ }
+
+ @Override
+ public String toString()
+ {
+ return toStringHelper(this)
+ .add("outputRowCount", outputRowCount)
+ .add("symbolStatistics", symbolStatistics)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ PlanNodeStatsEstimate that = (PlanNodeStatsEstimate) o;
+ return Double.compare(that.outputRowCount, outputRowCount) == 0 &&
+ Objects.equals(symbolStatistics, that.symbolStatistics);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(outputRowCount, symbolStatistics);
+ }
+
+ public static Builder builder()
+ {
+ return new Builder();
+ }
+
+ public static Builder buildFrom(PlanNodeStatsEstimate other)
+ {
+ return new Builder(other.getOutputRowCount(), other.symbolStatistics);
+ }
+
+ public static final class Builder
+ {
+ private double outputRowCount;
+ private PMap symbolStatistics;
+
+ public Builder()
+ {
+ this(NaN, HashTreePMap.empty());
+ }
+
+ private Builder(double outputRowCount, PMap symbolStatistics)
+ {
+ this.outputRowCount = outputRowCount;
+ this.symbolStatistics = symbolStatistics;
+ }
+
+ public Builder setOutputRowCount(double outputRowCount)
+ {
+ this.outputRowCount = outputRowCount;
+ return this;
+ }
+
+ public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics)
+ {
+ symbolStatistics = symbolStatistics.plus(symbol, statistics);
+ return this;
+ }
+
+ public Builder addSymbolStatistics(Map symbolStatistics)
+ {
+ this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics);
+ return this;
+ }
+
+ public Builder removeSymbolStatistics(Symbol symbol)
+ {
+ symbolStatistics = symbolStatistics.minus(symbol);
+ return this;
+ }
+
+ public PlanNodeStatsEstimate build()
+ {
+ return new PlanNodeStatsEstimate(outputRowCount, symbolStatistics);
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java
new file mode 100644
index 0000000000000..99d328a6eb464
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.sql.planner.Symbol;
+
+import java.util.HashSet;
+import java.util.stream.Stream;
+
+import static com.facebook.presto.cost.AggregationStatsRule.groupBy;
+import static com.facebook.presto.util.MoreMath.min;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Collections.emptyMap;
+
+public class PlanNodeStatsEstimateMath
+{
+ private PlanNodeStatsEstimateMath()
+ {
+ }
+
+ private interface SubtractRangeStrategy
+ {
+ StatisticRange range(StatisticRange leftRange, StatisticRange rightRange);
+ }
+
+ public static PlanNodeStatsEstimate differenceInStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
+ {
+ return differenceInStatsWithRangeStrategy(left, right, StatisticRange::subtract);
+ }
+
+ public static PlanNodeStatsEstimate differenceInNonRangeStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
+ {
+ return differenceInStatsWithRangeStrategy(left, right, ((leftRange, rightRange) -> leftRange));
+ }
+
+ private static PlanNodeStatsEstimate differenceInStatsWithRangeStrategy(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right, SubtractRangeStrategy strategy)
+ {
+ PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
+ double newRowCount = left.getOutputRowCount() - right.getOutputRowCount();
+
+ Stream.concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream())
+ .forEach(symbol -> {
+ statsBuilder.addSymbolStatistics(symbol,
+ subtractColumnStats(left.getSymbolStatistics(symbol),
+ left.getOutputRowCount(),
+ right.getSymbolStatistics(symbol),
+ right.getOutputRowCount(),
+ newRowCount,
+ strategy));
+ });
+
+ return statsBuilder.setOutputRowCount(newRowCount).build();
+ }
+
+ private static SymbolStatsEstimate subtractColumnStats(SymbolStatsEstimate leftStats,
+ double leftRowCount,
+ SymbolStatsEstimate rightStats,
+ double rightRowCount,
+ double newRowCount,
+ SubtractRangeStrategy strategy)
+ {
+ StatisticRange leftRange = StatisticRange.from(leftStats);
+ StatisticRange rightRange = StatisticRange.from(rightStats);
+
+ double nullsCountLeft = leftStats.getNullsFraction() * leftRowCount;
+ double nullsCountRight = rightStats.getNullsFraction() * rightRowCount;
+ double totalSizeLeft = leftRowCount * leftStats.getAverageRowSize();
+ double totalSizeRight = rightRowCount * rightStats.getAverageRowSize();
+ StatisticRange range = strategy.range(leftRange, rightRange);
+
+ return SymbolStatsEstimate.builder()
+ .setDistinctValuesCount(leftStats.getDistinctValuesCount() - rightStats.getDistinctValuesCount())
+ .setHighValue(range.getHigh())
+ .setLowValue(range.getLow())
+ .setAverageRowSize((totalSizeLeft - totalSizeRight) / newRowCount)
+ .setNullsFraction((nullsCountLeft - nullsCountRight) / newRowCount)
+ .build();
+ }
+
+ public static PlanNodeStatsEstimate addStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
+ {
+ PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
+ double newRowCount = left.getOutputRowCount() + right.getOutputRowCount();
+
+ Stream.concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream())
+ .forEach(symbol -> {
+ statsBuilder.addSymbolStatistics(symbol,
+ addColumnStats(left.getSymbolStatistics(symbol),
+ left.getOutputRowCount(),
+ right.getSymbolStatistics(symbol),
+ right.getOutputRowCount(), newRowCount));
+ });
+
+ return statsBuilder.setOutputRowCount(newRowCount).build();
+ }
+
+ private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, double leftRows, SymbolStatsEstimate rightStats, double rightRows, double newRowCount)
+ {
+ StatisticRange leftRange = StatisticRange.from(leftStats);
+ StatisticRange rightRange = StatisticRange.from(rightStats);
+
+ StatisticRange sum = leftRange.add(rightRange);
+ double nullsCountRight = rightStats.getNullsFraction() * rightRows;
+ double nullsCountLeft = leftStats.getNullsFraction() * leftRows;
+ double totalSizeLeft = leftRows * leftStats.getAverageRowSize();
+ double totalSizeRight = rightRows * rightStats.getAverageRowSize();
+
+ return SymbolStatsEstimate.builder()
+ .setStatisticsRange(sum)
+ .setAverageRowSize((totalSizeLeft + totalSizeRight) / newRowCount) // FIXME, weights to average. left and right should be equal in most cases anyway
+ .setNullsFraction((nullsCountLeft + nullsCountRight) / newRowCount)
+ .build();
+ }
+
+ public static PlanNodeStatsEstimate intersect(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
+ {
+ checkArgument(new HashSet<>(left.getSymbolsWithKnownStatistics()).equals(new HashSet<>(right.getSymbolsWithKnownStatistics())));
+
+ PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
+
+ for (Symbol symbol : left.getSymbolsWithKnownStatistics()) {
+ SymbolStatsEstimate leftSymbolStats = left.getSymbolStatistics(symbol);
+ SymbolStatsEstimate rightSymbolStats = right.getSymbolStatistics(symbol);
+ StatisticRange leftRange = StatisticRange.from(leftSymbolStats);
+ StatisticRange rightRange = StatisticRange.from(rightSymbolStats);
+ StatisticRange intersection = leftRange.intersect(rightRange);
+
+ statsBuilder.addSymbolStatistics(
+ symbol,
+ SymbolStatsEstimate.builder()
+ .setStatisticsRange(intersection)
+                            // it does not matter how many nulls are preserved, the interesting point is whether there are nulls on both sides or not
+ // this will be normalized later by groupBy
+ .setNullsFraction(min(leftSymbolStats.getNullsFraction(), rightSymbolStats.getNullsFraction()))
+ .build());
+ }
+
+ PlanNodeStatsEstimate intermediateResult = statsBuilder.build();
+ return groupBy(intermediateResult, intermediateResult.getSymbolsWithKnownStatistics(), emptyMap());
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java
new file mode 100644
index 0000000000000..8e7e59f898a83
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.tree.Expression;
+
+import java.util.Map;
+import java.util.Optional;
+
+public class ProjectStatsRule
+ implements ComposableStatsCalculator.Rule
+{
+ private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class);
+
+ private final ScalarStatsCalculator scalarStatsCalculator;
+
+ public ProjectStatsRule(ScalarStatsCalculator scalarStatsCalculator)
+ {
+ this.scalarStatsCalculator = scalarStatsCalculator;
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN;
+ }
+
+ @Override
+ public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+ {
+ ProjectNode projectNode = (ProjectNode) node;
+
+ PlanNodeStatsEstimate sourceStats = lookup.getStats(projectNode.getSource(), session, types);
+ // TODO handle output size in bytes
+ PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(sourceStats.getOutputRowCount());
+
+ for (Map.Entry entry : projectNode.getAssignments().entrySet()) {
+ calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types));
+ }
+ return Optional.of(calculatedStats.build());
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java
new file mode 100644
index 0000000000000..930c169e42619
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.type.DecimalType;
+import com.facebook.presto.spi.type.StandardTypes;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.spi.type.TypeSignature;
+import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
+import com.facebook.presto.sql.analyzer.Scope;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
+import com.facebook.presto.sql.tree.AstVisitor;
+import com.facebook.presto.sql.tree.Cast;
+import com.facebook.presto.sql.tree.CoalesceExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.Literal;
+import com.facebook.presto.sql.tree.Node;
+import com.facebook.presto.sql.tree.NullLiteral;
+import com.facebook.presto.sql.tree.SymbolReference;
+import com.google.common.collect.ImmutableList;
+
+import javax.inject.Inject;
+
+import java.util.Map;
+import java.util.OptionalDouble;
+
+import static com.facebook.presto.sql.planner.LiteralInterpreter.evaluate;
+import static com.facebook.presto.util.MoreMath.max;
+import static com.facebook.presto.util.MoreMath.min;
+import static java.lang.Double.isFinite;
+import static java.lang.Double.isNaN;
+import static java.lang.Math.abs;
+import static java.util.Objects.requireNonNull;
+
+public class ScalarStatsCalculator
+{
+ private final Metadata metadata;
+
+ @Inject
+ public ScalarStatsCalculator(Metadata metadata)
+ {
+ this.metadata = requireNonNull(metadata, "metadata can not be null");
+ }
+
+ public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, Map types)
+ {
+ return new Visitor(inputStatistics, session).process(scalarExpression);
+ }
+
+ private class Visitor
+ extends AstVisitor
+ {
+ private final PlanNodeStatsEstimate input;
+ private final Session session;
+
+ Visitor(PlanNodeStatsEstimate input, Session session)
+ {
+ this.input = input;
+ this.session = session;
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitNode(Node node, Void context)
+ {
+ return SymbolStatsEstimate.UNKNOWN_STATS;
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context)
+ {
+ return input.getSymbolStatistics(Symbol.from(node));
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context)
+ {
+ return SymbolStatsEstimate.builder()
+ .setDistinctValuesCount(0)
+ .setNullsFraction(1)
+ .build();
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitLiteral(Literal node, Void context)
+ {
+ Object value = evaluate(metadata, session.toConnectorSession(), node);
+ Type type = ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableList.of()).analyze(node, Scope.create());
+ OptionalDouble doubleValue = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession()).translateToDouble(value);
+ SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
+ .setNullsFraction(0)
+ .setDistinctValuesCount(1);
+
+ if (doubleValue.isPresent()) {
+ estimate.setLowValue(doubleValue.getAsDouble());
+ estimate.setHighValue(doubleValue.getAsDouble());
+ }
+ return estimate.build();
+ }
+
+ protected SymbolStatsEstimate visitCast(Cast node, Void context)
+ {
+ SymbolStatsEstimate sourceStats = process(node.getExpression());
+ TypeSignature targetType = TypeSignature.parseTypeSignature(node.getType());
+
+ // todo - make this general postprocessing rule.
+ double distinctValuesCount = sourceStats.getDistinctValuesCount();
+ double lowValue = sourceStats.getLowValue();
+ double highValue = sourceStats.getHighValue();
+
+ if (isIntegralType(targetType)) {
+ // todo handle low/high value changes if range gets narrower due to cast (e.g. BIGINT -> SMALLINT)
+ if (isFinite(lowValue)) {
+ lowValue = Math.round(lowValue);
+ }
+ if (isFinite(highValue)) {
+ highValue = Math.round(highValue);
+ }
+ if (isFinite(lowValue) && isFinite(highValue)) {
+ double integersInRange = highValue - lowValue + 1;
+ if (!isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
+ distinctValuesCount = integersInRange;
+ }
+ }
+ }
+
+ return SymbolStatsEstimate.builder()
+ .setNullsFraction(sourceStats.getNullsFraction())
+ .setLowValue(lowValue)
+ .setHighValue(highValue)
+ .setDistinctValuesCount(distinctValuesCount)
+ .build();
+ }
+
+ private boolean isIntegralType(TypeSignature targetType)
+ {
+ switch (targetType.getBase()) {
+ case StandardTypes.BIGINT:
+ case StandardTypes.INTEGER:
+ case StandardTypes.SMALLINT:
+ case StandardTypes.TINYINT:
+ return true;
+ case StandardTypes.DECIMAL:
+ DecimalType decimalType = (DecimalType) metadata.getType(targetType);
+ return decimalType.getScale() == 0;
+ default:
+ return false;
+ }
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
+ {
+ requireNonNull(node, "node is null");
+ SymbolStatsEstimate left = process(node.getLeft());
+ SymbolStatsEstimate right = process(node.getRight());
+
+ SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
+ .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
+ .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction())
+ // TODO make a generic rule which cap NDV for all kind of expressions to rows count and range length (if finite)
+ .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount()));
+
+ double leftLow = left.getLowValue();
+ double leftHigh = left.getHighValue();
+ double rightLow = right.getLowValue();
+ double rightHigh = right.getHighValue();
+ if (node.getType() == ArithmeticBinaryExpression.Type.DIVIDE && rightLow < 0 && rightHigh > 0) {
+ result.setLowValue(Double.NEGATIVE_INFINITY)
+ .setHighValue(Double.POSITIVE_INFINITY);
+ }
+ else if (node.getType() == ArithmeticBinaryExpression.Type.MODULUS) {
+ double maxDivisor = max(abs(rightLow), abs(rightHigh));
+ if (leftHigh <= 0) {
+ result.setLowValue(max(-maxDivisor, leftLow))
+ .setHighValue(0);
+ }
+ else if (leftLow >= 0) {
+ result.setLowValue(0)
+ .setHighValue(min(maxDivisor, leftHigh));
+ }
+ else {
+ result.setLowValue(max(-maxDivisor, leftLow))
+ .setHighValue(min(maxDivisor, leftHigh));
+ }
+ }
+ else {
+ double v1 = operate(node.getType(), leftLow, rightLow);
+ double v2 = operate(node.getType(), leftLow, rightHigh);
+ double v3 = operate(node.getType(), leftHigh, rightLow);
+ double v4 = operate(node.getType(), leftHigh, rightHigh);
+ double lowValue = min(v1, v2, v3, v4);
+ double highValue = max(v1, v2, v3, v4);
+
+ result.setLowValue(lowValue)
+ .setHighValue(highValue);
+ }
+
+ return result.build();
+ }
+
+ private double operate(ArithmeticBinaryExpression.Type type, double left, double right)
+ {
+ switch (type) {
+ case ADD:
+ return left + right;
+ case SUBTRACT:
+ return left - right;
+ case MULTIPLY:
+ return left * right;
+ case DIVIDE:
+ return left / right;
+ case MODULUS:
+ return left % right;
+ default:
+ throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Type: " + type);
+ }
+ }
+
+ @Override
+ protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context)
+ {
+ requireNonNull(node, "node is null");
+ SymbolStatsEstimate result = null;
+ for (Expression operand : node.getOperands()) {
+ SymbolStatsEstimate operandEstimates = process(operand);
+ if (result != null) {
+ result = estimateCoalesce(result, operandEstimates);
+ }
+ else {
+ result = operandEstimates;
+ }
+ }
+ return requireNonNull(result, "result is null");
+ }
+
+        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right)
+        {
+            // Question to reviewer: do you have a method to check if fraction is empty or saturated?
+            if (left.getNullsFraction() == 0) { // left has no nulls, so right never contributes
+                return left;
+            }
+            else if (left.getNullsFraction() == 1.0) { // left is all nulls, so the result is taken from right
+                return right;
+            }
+            else {
+                return SymbolStatsEstimate.builder()
+                        .setLowValue(min(left.getLowValue(), right.getLowValue()))
+                        .setHighValue(max(left.getHighValue(), right.getHighValue())) // upper bound uses right's high value, not its low value
+                        .setDistinctValuesCount(left.getDistinctValuesCount() +
+                                min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction()))
+                        .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
+                        // TODO check if dataSize estimation method is correct
+                        .setAverageRowSize(max(left.getAverageRowSize(), right.getAverageRowSize()))
+                        .build();
+            }
+        }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java
new file mode 100644
index 0000000000000..b43a918af86ee
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.SystemSessionProperties;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.google.inject.BindingAnnotation;
+
+import javax.inject.Inject;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.Target;
+import java.util.Map;
+
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.PARAMETER;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+import static java.util.Objects.requireNonNull;
+
+public class SelectingStatsCalculator
+ implements StatsCalculator
+{
+ private final StatsCalculator oldStatsCalculator;
+ private final StatsCalculator newStatsCalculator;
+
+ @Inject
+ public SelectingStatsCalculator(@Old StatsCalculator oldStatsCalculator, @New StatsCalculator newStatsCalculator)
+ {
+ this.oldStatsCalculator = requireNonNull(oldStatsCalculator, "oldStatsCalculator can not be null");
+ this.newStatsCalculator = requireNonNull(newStatsCalculator, "newStatsCalculator can not be null");
+ }
+
+ @Override
+ public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types)
+ {
+ if (SystemSessionProperties.isUseNewStatsCalculator(session)) {
+ return newStatsCalculator.calculateStats(planNode, lookup, session, types);
+ }
+ else {
+ return oldStatsCalculator.calculateStats(planNode, lookup, session, types);
+ }
+ }
+
+ @BindingAnnotation
+ @Target({PARAMETER, METHOD})
+ @Retention(RUNTIME)
+ public @interface Old {}
+
+ @BindingAnnotation
+ @Target({PARAMETER, METHOD})
+ @Retention(RUNTIME)
+ public @interface New {}
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java
new file mode 100644
index 0000000000000..9c53eb8d284a1
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import java.util.Objects;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.lang.Double.NaN;
+import static java.lang.Double.isFinite;
+import static java.lang.Double.isInfinite;
+import static java.lang.Double.isNaN;
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+
+/**
+ * Range of values a symbol can take, described by a low bound, a high bound and
+ * the number of distinct values inside it. The empty range is encoded as
+ * (NaN, NaN); see {@link #empty()} and {@link #isEmpty()}.
+ */
+public class StatisticRange
+{
+    // Fixed overlap fractions returned by overlapPercentWith when infinite bounds rule out a length ratio.
+    private static final double INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.25;
+    private static final double INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.5;
+
+    private final double low;
+    private final double high;
+    private final double distinctValues;
+
+    /**
+     * @param low lower bound; must not exceed {@code high}, unless both bounds are NaN
+     * @param high upper bound
+     * @param distinctValues number of distinct values; must be non-negative or NaN
+     * @throws IllegalStateException if the bounds or the distinct values count break these rules
+     */
+    public StatisticRange(double low, double high, double distinctValues)
+    {
+        checkState(low <= high || (isNaN(low) && isNaN(high)), "Low must be smaller or equal to high or range must be empty (NaN, NaN)");
+        checkState(distinctValues >= 0 || isNaN(distinctValues), "Distinct values count cannot be negative");
+        this.low = low;
+        this.high = high;
+        this.distinctValues = distinctValues;
+    }
+
+    /** Returns the empty range: (NaN, NaN) with zero distinct values. */
+    public static StatisticRange empty()
+    {
+        return new StatisticRange(NaN, NaN, 0);
+    }
+
+    /** Builds a range from the low value, high value and distinct values count of {@code estimate}. */
+    public static StatisticRange from(SymbolStatsEstimate estimate)
+    {
+        return new StatisticRange(estimate.getLowValue(), estimate.getHighValue(), estimate.getDistinctValuesCount());
+    }
+
+    public double getLow()
+    {
+        return low;
+    }
+
+    public double getHigh()
+    {
+        return high;
+    }
+
+    public double getDistinctValuesCount()
+    {
+        return distinctValues;
+    }
+
+    /** Returns {@code high - low}, which is NaN for an empty range. */
+    public double length()
+    {
+        return high - low;
+    }
+
+    /** Returns true when both bounds are NaN. */
+    public boolean isEmpty()
+    {
+        return isNaN(low) && isNaN(high);
+    }
+
+    /**
+     * Estimates the fraction of this range that is covered by {@code other}.
+     *
+     * @return 0.0 when either range is empty or the ranges do not intersect; 1.0 when
+     * the ranges are equal; {@code 1 / distinctValues} when they meet in a single point;
+     * a fixed heuristic factor when infinite bounds are involved; otherwise the length
+     * of the intersection divided by the length of this range
+     */
+    public double overlapPercentWith(StatisticRange other)
+    {
+        // Emptiness is checked before equality so that an empty range always reports zero
+        // overlap, even against itself or another empty range (equals() treats NaN as equal to NaN).
+        if (isEmpty() || other.isEmpty()) {
+            return 0.0; // zero is better than NaN as it will behave properly for calculating row count
+        }
+
+        if (this.equals(other)) {
+            return 1.0;
+        }
+
+        double lengthOfIntersect = min(high, other.high) - max(low, other.low);
+        if (isInfinite(lengthOfIntersect)) {
+            return INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR;
+        }
+        if (lengthOfIntersect == 0) {
+            return 1 / distinctValues;
+        }
+        if (lengthOfIntersect < 0) {
+            return 0;
+        }
+        if (isInfinite(length()) && isFinite(lengthOfIntersect)) {
+            return INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR;
+        }
+        if (lengthOfIntersect > 0) {
+            return lengthOfIntersect / length();
+        }
+
+        // Only reachable when lengthOfIntersect is NaN.
+        return NaN;
+    }
+
+    /** Scales each range's distinct values count by its overlap with the other range and returns the larger result, ignoring NaN. */
+    private double overlappingDistinctValues(StatisticRange other)
+    {
+        double overlapPercentOfLeft = overlapPercentWith(other);
+        double overlapPercentOfRight = other.overlapPercentWith(this);
+        double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues;
+        double overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues;
+
+        return maxExcludeNaN(overlapDistinctValuesLeft, overlapDistinctValuesRight);
+    }
+
+    /**
+     * Returns the range between the larger low bound and the smaller high bound, or
+     * {@link #empty()} when those bounds do not form a range.
+     */
+    public StatisticRange intersect(StatisticRange other)
+    {
+        double newLow = max(low, other.low);
+        double newHigh = min(high, other.high);
+        if (newLow <= newHigh) {
+            return new StatisticRange(newLow, newHigh, overlappingDistinctValues(other));
+        }
+        return empty();
+    }
+
+    /** Returns the range spanning both ranges (NaN bounds are ignored) with the distinct values counts summed. */
+    public StatisticRange add(StatisticRange other)
+    {
+        double newDistinctValues = distinctValues + other.distinctValues;
+        return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues);
+    }
+
+    /**
+     * Trims from this range the part shared with {@code rightRange}, but only where the
+     * intersection touches one of this range's bounds; the result has (NaN, NaN) bounds
+     * when the trimmed bounds cross.
+     */
+    public StatisticRange subtract(StatisticRange rightRange)
+    {
+        StatisticRange intersect = intersect(rightRange);
+        double newLow = getLow();
+        double newHigh = getHigh();
+        if (intersect.getLow() == getLow()) {
+            newLow = intersect.getHigh();
+        }
+        if (intersect.getHigh() == getHigh()) {
+            newHigh = intersect.getLow();
+        }
+        if (newLow > newHigh) {
+            newLow = NaN;
+            newHigh = NaN;
+        }
+
+        return new StatisticRange(newLow, newHigh, max(getDistinctValuesCount(), rightRange.getDistinctValuesCount()) - intersect.getDistinctValuesCount());
+    }
+
+    /** Smaller of the two values; when one of them is NaN the other is returned. */
+    private static double minExcludeNaN(double v1, double v2)
+    {
+        if (isNaN(v1)) {
+            return v2;
+        }
+        if (isNaN(v2)) {
+            return v1;
+        }
+        return min(v1, v2);
+    }
+
+    /** Larger of the two values; when one of them is NaN the other is returned. */
+    private static double maxExcludeNaN(double v1, double v2)
+    {
+        if (isNaN(v1)) {
+            return v2;
+        }
+        if (isNaN(v2)) {
+            return v1;
+        }
+        return max(v1, v2);
+    }
+
+    @Override
+    public boolean equals(Object obj)
+    {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof StatisticRange)) {
+            return false;
+        }
+        StatisticRange other = (StatisticRange) obj;
+        // Double.compare agrees with the boxed comparison behind Objects.hash in hashCode():
+        // NaN equals NaN (two empty ranges are equal) and 0.0 differs from -0.0.
+        return Double.compare(low, other.low) == 0 &&
+                Double.compare(high, other.high) == 0 &&
+                Double.compare(distinctValues, other.distinctValues) == 0;
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(low, high, distinctValues);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java
new file mode 100644
index 0000000000000..894b94887121d
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+
+/**
+ * Interface of stats calculator.
+ *
+ * Obtains estimated stats for the output produced by the given PlanNode.
+ * Implementations may use lookup to compute needed traits for self/source nodes.
+ */
+public interface StatsCalculator
+{
+ PlanNodeStatsEstimate calculateStats(
+ PlanNode node,
+ Lookup lookup,
+ Session session,
+ Map types);
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java
new file mode 100644
index 0000000000000..d6b62a95567bb
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import java.util.Objects;
+import java.util.function.Function;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.lang.Double.NaN;
+import static java.lang.Double.isNaN;
+
+/**
+ * Estimated statistics of a single symbol: value range, nulls fraction, average
+ * row size and distinct values count. An empty value range is encoded as (NaN, NaN).
+ */
+public class SymbolStatsEstimate
+{
+    // Built from the Builder defaults: unbounded range, every other statistic NaN.
+    public static final SymbolStatsEstimate UNKNOWN_STATS = SymbolStatsEstimate.builder().build();
+
+    // for now we support only types which map to real domain naturally and keep low/high value as double in stats.
+    private final double lowValue;
+    private final double highValue;
+    private final double nullsFraction;
+    private final double averageRowSize;
+    private final double distinctValuesCount;
+
+    /**
+     * @throws IllegalArgumentException if {@code lowValue} is greater than {@code highValue},
+     * or if exactly one of the two is NaN
+     */
+    public SymbolStatsEstimate(double lowValue, double highValue, double nullsFraction, double averageRowSize, double distinctValuesCount)
+    {
+        checkArgument(lowValue <= highValue || (isNaN(lowValue) && isNaN(highValue)), "lowValue must be less than or equal to highValue or both values have to be NaN");
+        this.lowValue = lowValue;
+        this.highValue = highValue;
+        this.nullsFraction = nullsFraction;
+        this.averageRowSize = averageRowSize;
+        this.distinctValuesCount = distinctValuesCount;
+    }
+
+    public double getLowValue()
+    {
+        return lowValue;
+    }
+
+    public double getHighValue()
+    {
+        return highValue;
+    }
+
+    /** Returns true when both range bounds are NaN. */
+    public boolean hasEmptyRange()
+    {
+        return isNaN(lowValue) && isNaN(highValue);
+    }
+
+    /** Returns the nulls fraction; an empty value range is reported as 1.0 regardless of the stored value. */
+    public double getNullsFraction()
+    {
+        if (hasEmptyRange()) {
+            return 1.0;
+        }
+        return nullsFraction;
+    }
+
+    /** Returns the value range and distinct values count as a {@link StatisticRange}. */
+    public StatisticRange statisticRange()
+    {
+        return new StatisticRange(lowValue, highValue, distinctValuesCount);
+    }
+
+    /** Returns the complement of {@link #getNullsFraction()}. */
+    public double getValuesFraction()
+    {
+        // Derived from the getter rather than the raw field so that both fractions agree
+        // (and sum to 1.0) when the value range is empty.
+        return 1.0 - getNullsFraction();
+    }
+
+    public double getAverageRowSize()
+    {
+        return averageRowSize;
+    }
+
+    public double getDistinctValuesCount()
+    {
+        return distinctValuesCount;
+    }
+
+    /** Returns an estimate built from this one with the low value replaced by {@code mappingFunction} applied to it. */
+    public SymbolStatsEstimate mapLowValue(Function<Double, Double> mappingFunction)
+    {
+        return buildFrom(this).setLowValue(mappingFunction.apply(lowValue)).build();
+    }
+
+    /** Returns an estimate built from this one with the high value replaced by {@code mappingFunction} applied to it. */
+    public SymbolStatsEstimate mapHighValue(Function<Double, Double> mappingFunction)
+    {
+        return buildFrom(this).setHighValue(mappingFunction.apply(highValue)).build();
+    }
+
+    /** Returns an estimate built from this one with the nulls fraction replaced by {@code mappingFunction} applied to the stored value. */
+    public SymbolStatsEstimate mapNullsFraction(Function<Double, Double> mappingFunction)
+    {
+        return buildFrom(this).setNullsFraction(mappingFunction.apply(nullsFraction)).build();
+    }
+
+    /** Returns an estimate built from this one with the distinct values count replaced by {@code mappingFunction} applied to it. */
+    public SymbolStatsEstimate mapDistinctValuesCount(Function<Double, Double> mappingFunction)
+    {
+        return buildFrom(this).setDistinctValuesCount(mappingFunction.apply(distinctValuesCount)).build();
+    }
+
+    @Override
+    public boolean equals(Object o)
+    {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        SymbolStatsEstimate that = (SymbolStatsEstimate) o;
+        // Double.compare treats NaN as equal to NaN, so estimates holding NaN statistics can still be equal.
+        return Double.compare(that.lowValue, lowValue) == 0 &&
+                Double.compare(that.highValue, highValue) == 0 &&
+                Double.compare(that.nullsFraction, nullsFraction) == 0 &&
+                Double.compare(that.averageRowSize, averageRowSize) == 0 &&
+                Double.compare(that.distinctValuesCount, distinctValuesCount) == 0;
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount);
+    }
+
+    public static Builder builder()
+    {
+        return new Builder();
+    }
+
+    /** Returns a builder pre-populated from the getters of {@code other}. */
+    public static Builder buildFrom(SymbolStatsEstimate other)
+    {
+        return builder()
+                .setLowValue(other.getLowValue())
+                .setHighValue(other.getHighValue())
+                .setNullsFraction(other.getNullsFraction())
+                .setAverageRowSize(other.getAverageRowSize())
+                .setDistinctValuesCount(other.getDistinctValuesCount());
+    }
+
+    /** Mutable builder; defaults are an unbounded range with all other statistics NaN. */
+    public static final class Builder
+    {
+        private double lowValue = Double.NEGATIVE_INFINITY;
+        private double highValue = Double.POSITIVE_INFINITY;
+        private double nullsFraction = NaN;
+        private double averageRowSize = NaN;
+        private double distinctValuesCount = NaN;
+
+        /** Sets the low value, high value and distinct values count from {@code range}. */
+        public Builder setStatisticsRange(StatisticRange range)
+        {
+            return setLowValue(range.getLow())
+                    .setHighValue(range.getHigh())
+                    .setDistinctValuesCount(range.getDistinctValuesCount());
+        }
+
+        public Builder setLowValue(double lowValue)
+        {
+            this.lowValue = lowValue;
+            return this;
+        }
+
+        public Builder setHighValue(double highValue)
+        {
+            this.highValue = highValue;
+            return this;
+        }
+
+        public Builder setNullsFraction(double nullsFraction)
+        {
+            this.nullsFraction = nullsFraction;
+            return this;
+        }
+
+        public Builder setAverageRowSize(double averageRowSize)
+        {
+            this.averageRowSize = averageRowSize;
+            return this;
+        }
+
+        public Builder setDistinctValuesCount(double distinctValuesCount)
+        {
+            this.distinctValuesCount = distinctValuesCount;
+            return this;
+        }
+
+        public SymbolStatsEstimate build()
+        {
+            return new SymbolStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount);
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java
new file mode 100644
index 0000000000000..713e9c757a308
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.Constraint;
+import com.facebook.presto.spi.predicate.TupleDomain;
+import com.facebook.presto.spi.statistics.ColumnStatistics;
+import com.facebook.presto.spi.statistics.TableStatistics;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.DomainTranslator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.TableScanNode;
+import com.facebook.presto.sql.tree.BooleanLiteral;
+import com.facebook.presto.sql.tree.Expression;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalDouble;
+
+import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS;
+import static java.lang.Double.NEGATIVE_INFINITY;
+import static java.lang.Double.POSITIVE_INFINITY;
+import static java.util.Objects.requireNonNull;
+
+public class TableScanStatsRule
+ implements ComposableStatsCalculator.Rule
+{
+ private static final Pattern PATTERN = Pattern.matchByClass(TableScanNode.class);
+
+ private final Metadata metadata;
+
+ public TableScanStatsRule(Metadata metadata)
+ {
+ this.metadata = requireNonNull(metadata, "metadata can not be null"); // fails fast with a NullPointerException on null
+ }
+
+ @Override
+ public Pattern getPattern()
+ {
+ return PATTERN; // class-level constant built with Pattern.matchByClass(TableScanNode.class)
+ }
+
+ @Override
+ public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+ {
+ TableScanNode tableScanNode = (TableScanNode) node;
+ // Build the constraint from an always-true predicate; getConstraint is defined outside this view.
+ Constraint constraint = getConstraint(tableScanNode, BooleanLiteral.TRUE_LITERAL, session, types);
+ // Fetch the statistics of the scanned table from metadata under that constraint.
+ TableStatistics tableStatistics = metadata.getTableStatistics(session, tableScanNode.getTable(), constraint);
+ Map outputSymbolStats = new HashMap<>();
+ // Estimate every assigned symbol; a column missing from the statistics map falls back to UNKNOWN_STATS.
+ for (Map.Entry entry : tableScanNode.getAssignments().entrySet()) {
+ Symbol symbol = entry.getKey();
+ Type symbolType = types.get(symbol);
+ Optional columnStatistics = Optional.ofNullable(tableStatistics.getColumnStatistics().get(entry.getValue()));
+ outputSymbolStats.put(symbol, columnStatistics.map(statistics -> toSymbolStatistics(tableStatistics, statistics, session, symbolType)).orElse(UNKNOWN_STATS));
+ }
+ // The output row count is the table's row count, taken as-is.
+ return Optional.of(PlanNodeStatsEstimate.builder()
+ .setOutputRowCount(tableStatistics.getRowCount().getValue())
+ .addSymbolStatistics(outputSymbolStats)
+ .build());
+ }
+
+ private SymbolStatsEstimate toSymbolStatistics(TableStatistics tableStatistics, ColumnStatistics columnStatistics, Session session, Type type)
+ {
+ DomainConverter domainConverter = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession());
+ // Translate the column's single-range statistics into a SymbolStatsEstimate; bounds fall back to -/+infinity when asDouble yields no value.
+ return SymbolStatsEstimate.builder()
+ .setLowValue(asDouble(columnStatistics.getOnlyRangeColumnStatistics().getLowValue(), domainConverter).orElse(NEGATIVE_INFINITY))
+ .setHighValue(asDouble(columnStatistics.getOnlyRangeColumnStatistics().getHighValue(), domainConverter).orElse(POSITIVE_INFINITY))
+ .setNullsFraction( // nulls fraction divided by (nulls fraction + fraction of the only range)
+ columnStatistics.getNullsFraction().getValue()
+ / (columnStatistics.getNullsFraction().getValue() + columnStatistics.getOnlyRangeColumnStatistics().getFraction().getValue()))
+ .setDistinctValuesCount(columnStatistics.getOnlyRangeColumnStatistics().getDistinctValuesCount().getValue())
+ .setAverageRowSize(columnStatistics.getOnlyRangeColumnStatistics().getDataSize().getValue() / tableStatistics.getRowCount().getValue()) // NOTE(review): a zero table row count makes this division yield NaN or infinity; confirm that is acceptable
+ .build();
+ }
+
+ private OptionalDouble asDouble(Optional