diff --git a/README.md b/README.md index 7df5380c68c48..acfe40570b3da 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,13 @@ See the [User Manual](https://prestodb.io/docs/current/) for deployment instruct Presto is a standard Maven project. Simply run the following command from the project root directory: - mvn clean install + ./mvnw clean install On the first build, Maven will download all the dependencies from the internet and cache them in the local repository (`~/.m2/repository`), which can take a considerable amount of time. Subsequent builds will be faster. Presto has a comprehensive set of unit tests that can take several minutes to run. You can disable the tests when building: - mvn clean install -DskipTests + ./mvnw clean install -DskipTests ## Running Presto in your IDE diff --git a/pom.xml b/pom.xml index e733d6a8c3d60..5d09e062b9a53 100644 --- a/pom.xml +++ b/pom.xml @@ -588,6 +588,12 @@ 42.0.0 + + org.pcollections + pcollections + 2.1.2 + + org.antlr antlr4-runtime diff --git a/presto-docs/src/main/sphinx/index.rst b/presto-docs/src/main/sphinx/index.rst index 8a91e003b3506..78ec8dfd91a4a 100644 --- a/presto-docs/src/main/sphinx/index.rst +++ b/presto-docs/src/main/sphinx/index.rst @@ -15,6 +15,7 @@ Presto Documentation language sql migration + optimizer develop release diff --git a/presto-docs/src/main/sphinx/optimizer.rst b/presto-docs/src/main/sphinx/optimizer.rst new file mode 100644 index 0000000000000..483df3d2375c2 --- /dev/null +++ b/presto-docs/src/main/sphinx/optimizer.rst @@ -0,0 +1,9 @@ +*************** +Query Optimizer +*************** + +.. 
toctree:: + :maxdepth: 1 + + optimizer/statistics + optimizer/cost-in-explain diff --git a/presto-docs/src/main/sphinx/optimizer/cost-in-explain.rst b/presto-docs/src/main/sphinx/optimizer/cost-in-explain.rst new file mode 100644 index 0000000000000..80d2f60821b68 --- /dev/null +++ b/presto-docs/src/main/sphinx/optimizer/cost-in-explain.rst @@ -0,0 +1,37 @@ +=============== +Cost in EXPLAIN +=============== + +During planning, the cost associated with each node of the plan is computed based on the root table statistics +for the tables in the query. This calculated cost is printed as part of the output of an ``EXPLAIN`` statement. + +Cost information is displayed in the plan tree using the format ``{rows: XX, bytes: XX}``. ``rows`` refers to the +expected number of rows output by each plan node during execution. ``bytes`` refers to the expected size of the +data output by each plan node in bytes. If any of the values is not known, a ``?`` is printed. + +For example: + +.. code-block:: none + + presto:default> EXPLAIN SELECT comment FROM nation_with_column_stats WHERE nationkey > 3 + + - Output[comment] => [comment:varchar(152)] {rows: ?, bytes: ?} + - RemoteExchange[GATHER] => comment:varchar(152) {rows: 12, bytes: ?} + - ScanFilterProject[table = hive:hive:default:nation_with_column_stats, + originalConstraint = ("nationkey" > BIGINT '3'), + filterPredicate = ("nationkey" > BIGINT '3')] => [comment:varchar(152)] {rows: 25, bytes: ?}/{rows: 12, bytes: ?}/{rows: 12, bytes: ?} + LAYOUT: hive + nationkey := HiveColumnHandle{clientId=hive, name=nationkey, hiveType=bigint, hiveColumnIndex=0, columnType=REGULAR} + comment := HiveColumnHandle{clientId=hive, name=comment, hiveType=varchar(152), hiveColumnIndex=3, columnType=REGULAR} + +Generally there is only one cost printed for each plan node. 
+However, when a ``Scan`` operator is combined with a ``Filter`` and/or ``Project`` operator, then multiple cost structures will be printed, +each corresponding to an individual logical part of the combined meta-operator. +For example, for a ``ScanFilterProject`` operator three cost structures will be printed: + + * the first will correspond to the ``Scan`` part of the operator + * the second will correspond to the ``Filter`` part of the operator + * the third will correspond to the ``Project`` part of the operator + +The estimated cost is also printed in ``EXPLAIN ANALYZE``, in addition to the actual runtime statistics. + diff --git a/presto-docs/src/main/sphinx/optimizer/statistics.rst b/presto-docs/src/main/sphinx/optimizer/statistics.rst new file mode 100644 index 0000000000000..8b4db88c2387c --- /dev/null +++ b/presto-docs/src/main/sphinx/optimizer/statistics.rst @@ -0,0 +1,146 @@ +================ +Table Statistics +================ + +Presto supports statistics-based optimizations for queries. For a query to take advantage of these optimizations, +Presto must have statistical information for the tables in that query. + +Table statistics are provided to the query planner by connectors. +Currently the only connector that supports statistics is the :doc:`/connector/hive`. + +Table Layouts +------------- + +Statistics are exposed to the query planner by a table layout. A table layout represents a subset of a table's data +and contains information about the organizational properties of that data (like sort order and bucketing). + +The number of table layouts available for a table and the details of those table layouts are specific to each connector. +Using the Hive connector as an example: + +* Non-partitioned tables have just one table layout representing all data in the table +* Partitioned tables have a family of table layouts. Each set of partitions to be scanned represents one table layout. 
+ Presto will try to pick a table layout consisting of the smallest number of partitions based on filtering predicates + from the query. + +Available Statistics +-------------------- + +Currently, the following statistics are available in Presto: + + * For the table: + + * **row count**: the total number of rows for the table layout + + * For each column in a table: + + * **data size**: the data size that needs to be read + * **nulls fraction**: the fraction of null values + * **distinct value count**: the number of distinct values + * **low value**: the smallest value in the column + * **high value**: the largest value in the column + + +The set of statistics available for a particular query depends on the connector being used and can also vary by table or +even by table layout. For example, the Hive connector does not currently provide statistics on data size. + +Displaying Table Statistics +--------------------------- + +Table statistics can be displayed via the Presto SQL interface using the ``SHOW STATS`` command. +There are two flavors of the command: + + * ``SHOW STATS FOR <table>`` will show statistics for the table layout representing all data in the table + * ``SHOW STATS FOR (SELECT <column_list> FROM <table> WHERE <filtering_condition>)`` + will show statistics for the table layout of the given table representing a subset of data after applying the given filtering + condition. Both the column list and the filtering condition used in the ``WHERE`` clause can reference table columns. + +In both cases, the ``SHOW STATS`` command outputs two types of rows. +For each column in the table there is a row with ``column_name`` equal to the name of that column. +These rows expose column-related statistics for a table (data size, nulls fraction, distinct values count, low value, high value). +Additionally, there is one row with ``NULL`` as the ``column_name``. This row contains table-layout-wide statistics - for now just the row count. + +For example: + +.. 
code-block:: none + + presto:default> SHOW STATS FOR nation; + + column_name | data_size | distinct_values_count | nulls_fraction | row_count | low_value | high_value + -------------+-----------+-----------------------+----------------+-----------+--------------------+-------------------- + regionkey | NULL | 5.0 | 0.0 | NULL | 0 | 4 + name | NULL | 25.0 | 0.0 | NULL | ALGERIA | VIETNAM + comment | NULL | 25.0 | 0.0 | NULL | haggle. carefu... | y final package... + nationkey | NULL | 25.0 | 0.0 | NULL | 0 | 24 + NULL | NULL | NULL | NULL | 25.0 | NULL | NULL + (5 rows) + + + presto:default> SHOW STATS FOR (SELECT * FROM nation WHERE nationkey > 10); + + column_name | data_size | distinct_values_count | nulls_fraction | row_count | low_value | high_value + -------------+-----------+-----------------------+----------------+-----------+--------------------+-------------------- + regionkey | NULL | 5.0 | 0.0 | NULL | 0 | 4 + name | NULL | 9.0 | 0.0 | NULL | IRAN | VIETNAM + comment | NULL | 14.0 | 0.0 | NULL | pending excuse... | y final package... + nationkey | NULL | 3.0 | 0.0 | NULL | 10 | 24 + NULL | NULL | NULL | NULL | 25.0 | NULL | NULL + (5 rows) + +If the provided ``SELECT`` filters out all of the partitions (all table layouts), +then ``SHOW STATS`` returns no statistics, which is represented as in the example below. + +.. code-block:: none + + presto:default> SHOW STATS FOR (SELECT * FROM nation WHERE nationkey > 999); + + column_name + ------------- + NULL + (1 row) + +Note that currently providing a ``column_list`` instead of ``*`` in the ``SELECT`` does not influence the output table. + +For example: + +.. 
code-block:: none + + presto:default> SHOW STATS FOR (SELECT comment FROM nation WHERE nationkey > 10); + + column_name | data_size | distinct_values_count | nulls_fraction | row_count | low_value | high_value + -------------+-----------+-----------------------+----------------+-----------+--------------------+-------------------- + regionkey | NULL | 5.0 | 0.0 | NULL | 0 | 4 + name | NULL | 9.0 | 0.0 | NULL | IRAN | VIETNAM + comment | NULL | 14.0 | 0.0 | NULL | pending excuse... | y final package... + nationkey | NULL | 3.0 | 0.0 | NULL | 10 | 24 + NULL | NULL | NULL | NULL | 25.0 | NULL | NULL + (5 rows) + + +Updating Statistics For Hive Tables +----------------------------------- + +For the Hive connector, Presto uses the statistics that are managed by Hive and exposed via the Hive metastore API. +Depending on the Hive configuration, table statistics may not be updated automatically. + +If statistics are not updated automatically, the user needs to trigger a statistics update via the Hive CLI. + +The following command can be used in the Hive CLI to update table statistics for non-partitioned table ``t``:: + + hive> ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS + +For partitioned tables, partitioning information must be specified in the command. +Assuming table ``t`` has two partitioning keys ``a`` and ``b``, the following command would +update the table statistics for all partitions:: + + hive> ANALYZE TABLE t PARTITION (a, b) COMPUTE STATISTICS FOR COLUMNS + +It is also possible to update statistics for just a subset of partitions. 
+This command will update statistics for all partitions for which partitioning key ``a`` is equal to ``1``:: + + hive> ANALYZE TABLE t PARTITION (a=1, b) COMPUTE STATISTICS FOR COLUMNS + +And this command will update statistics for just one partition:: + + hive> ANALYZE TABLE t PARTITION (a=1, b=5) COMPUTE STATISTICS FOR COLUMNS + +For documentation on Hive's statistics mechanism see https://cwiki.apache.org/confluence/display/Hive/StatsDev diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java index 706986fe71188..0c8f92a38fd1a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java @@ -169,6 +169,6 @@ public HiveMetadata create() partitionUpdateCodec, typeTranslator, prestoVersion, - new MetastoreHiveStatisticsProvider(typeManager, metastore)); + new MetastoreHiveStatisticsProvider(typeManager, metastore, timeZone)); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java index 6b69fef2c656f..431084ba76e6e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java @@ -27,15 +27,27 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.statistics.ColumnStatistics; import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.RangeColumnStatistics; import 
com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.joda.time.DateTimeZone; import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.time.LocalDate; import java.util.Collection; +import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -49,7 +61,18 @@ import java.util.stream.DoubleStream; import static com.facebook.presto.hive.HiveSessionProperties.isStatisticsEnabled; +import static com.facebook.presto.spi.predicate.Utils.nativeValueToBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.RealType.REAL; +import static com.facebook.presto.spi.type.SmallintType.SMALLINT; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -59,11 +82,13 @@ public class MetastoreHiveStatisticsProvider { private final TypeManager typeManager; private final SemiTransactionalHiveMetastore metastore; + private final DateTimeZone timeZone; - public MetastoreHiveStatisticsProvider(TypeManager typeManager, SemiTransactionalHiveMetastore metastore) + public 
MetastoreHiveStatisticsProvider(TypeManager typeManager, SemiTransactionalHiveMetastore metastore, DateTimeZone timeZone) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.metastore = requireNonNull(metastore, "metastore is null"); + this.timeZone = requireNonNull(timeZone, "timeZone is null"); } @Override @@ -75,7 +100,8 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab Map partitionStatistics = getPartitionsStatistics((HiveTableHandle) tableHandle, hivePartitions, tableColumns.keySet()); TableStatistics.Builder tableStatistics = TableStatistics.builder(); - tableStatistics.setRowCount(calculateRowsCount(partitionStatistics)); + Estimate rowCount = calculateRowsCount(partitionStatistics); + tableStatistics.setRowCount(rowCount); for (Map.Entry columnEntry : tableColumns.entrySet()) { String columnName = columnEntry.getKey(); HiveColumnHandle hiveColumnHandle = (HiveColumnHandle) columnEntry.getValue(); @@ -83,19 +109,134 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab continue; } ColumnStatistics.Builder columnStatistics = ColumnStatistics.builder(); + RangeColumnStatistics.Builder rangeStatistics = RangeColumnStatistics.builder(); + + List lowValueCandidates = ImmutableList.of(); + List highValueCandidates = ImmutableList.of(); + + Type prestoType = typeManager.getType(hiveColumnHandle.getTypeSignature()); + Estimate nullsFraction; if (hiveColumnHandle.isPartitionKey()) { - columnStatistics.setDistinctValuesCount(countDistinctPartitionKeys(hiveColumnHandle, hivePartitions)); - columnStatistics.setNullsCount(calculateNullsCountForPartitioningKey(hiveColumnHandle, hivePartitions, partitionStatistics)); + rangeStatistics.setDistinctValuesCount(countDistinctPartitionKeys(hiveColumnHandle, hivePartitions)); + nullsFraction = calculateNullsFractionForPartitioningKey(hiveColumnHandle, hivePartitions, partitionStatistics); + if (isLowHighSupportedForType(prestoType)) { + lowValueCandidates = 
hivePartitions.stream() + .map(HivePartition::getKeys) + .map(keys -> keys.get(hiveColumnHandle)) + .filter(value -> !value.isNull()) + .map(NullableValue::getValue) + .collect(toImmutableList()); + highValueCandidates = lowValueCandidates; + } } else { - columnStatistics.setDistinctValuesCount(calculateDistinctValuesCount(partitionStatistics, columnName)); - columnStatistics.setNullsCount(calculateNullsCount(partitionStatistics, columnName)); + rangeStatistics.setDistinctValuesCount(calculateDistinctValuesCount(partitionStatistics, columnName)); + nullsFraction = calculateNullsFraction(partitionStatistics, columnName, rowCount); + + // TODO[lo] Maybe we do not want to expose high/low value if it is based on too small fraction of + // partitions. And return unknown if most of the partitions we are working with do not have + // statistics computed. + + if (isLowHighSupportedForType(prestoType)) { + lowValueCandidates = partitionStatistics.values().stream() + .map(PartitionStatistics::getColumnStatistics) + .filter(stats -> stats.containsKey(columnName)) + .map(stats -> stats.get(columnName)) + .map(HiveColumnStatistics::getLowValue) + .filter(Optional::isPresent) + .map(Optional::get) + .map(value -> highLowValueAsPrestoType(value, prestoType)) + .collect(toImmutableList()); + + highValueCandidates = partitionStatistics.values().stream() + .map(PartitionStatistics::getColumnStatistics) + .filter(stats -> stats.containsKey(columnName)) + .map(stats -> stats.get(columnName)) + .map(HiveColumnStatistics::getHighValue) + .filter(Optional::isPresent) + .map(Optional::get) + .map(value -> highLowValueAsPrestoType(value, prestoType)) + .collect(toImmutableList()); + } } + columnStatistics.setNullsFraction(nullsFraction); + rangeStatistics.setFraction(nullsFraction.map(value -> 1.0 - value)); + + Comparator comparator = (leftValue, rightValue) -> { + Block leftBlock = nativeValueToBlock(prestoType, leftValue); + Block rightBlock = nativeValueToBlock(prestoType, rightValue); 
+ return prestoType.compareTo(leftBlock, 0, rightBlock, 0); + }; + rangeStatistics.setLowValue(lowValueCandidates.stream().min(comparator)); + rangeStatistics.setHighValue(highValueCandidates.stream().max(comparator)); + + columnStatistics.addRange(rangeStatistics.build()); tableStatistics.setColumnStatistics(hiveColumnHandle, columnStatistics.build()); } return tableStatistics.build(); } + private boolean isLowHighSupportedForType(Type type) + { + if (type instanceof DecimalType) { + return true; + } + if (type.equals(TINYINT) + || type.equals(SMALLINT) + || type.equals(INTEGER) + || type.equals(BIGINT) + || type.equals(REAL) + || type.equals(DOUBLE) + || type.equals(DATE) + || type.equals(TIMESTAMP)) { + return true; + } + return false; + } + + private Object highLowValueAsPrestoType(Object value, Type prestoType) + { + checkArgument(isLowHighSupportedForType(prestoType), "Unsupported type " + prestoType); + requireNonNull(value, "high/low value cannot be null"); + + if (prestoType.equals(BIGINT) + || prestoType.equals(INTEGER) + || prestoType.equals(SMALLINT) + || prestoType.equals(TINYINT)) { + checkArgument(value instanceof Long, "expected Long value but got " + value.getClass()); + return value; + } + else if (prestoType.equals(DOUBLE)) { + checkArgument(value instanceof Double, "expected Double value but got " + value.getClass()); + return value; + } + else if (prestoType.equals(REAL)) { + checkArgument(value instanceof Double, "expected Double value but got " + value.getClass()); + return (long) floatToRawIntBits((float) (double) value); /* widen to long so the boxed result is a Long, like the other integral branches; NOTE(review): assumes REAL's native Java type is long - confirm against RealType */ + } + else if (prestoType.equals(DATE)) { + checkArgument(value instanceof LocalDate, "expected LocalDate value but got " + value.getClass()); + return ((LocalDate) value).toEpochDay(); + } + else if (prestoType.equals(TIMESTAMP)) { + checkArgument(value instanceof Long, "expected Long value but got " + value.getClass()); + return timeZone.convertLocalToUTC((long) value * 1000, false); + } + else if (prestoType instanceof 
DecimalType) { + checkArgument(value instanceof BigDecimal, "expected BigDecimal value but got " + value.getClass()); + BigInteger unscaled = Decimals.rescale((BigDecimal) value, (DecimalType) prestoType).unscaledValue(); + if (Decimals.isShortDecimal(prestoType)) { + return unscaled.longValueExact(); + } + else { + return Decimals.encodeUnscaledValue(unscaled); + } + } + else { + throw new IllegalArgumentException("Unsupported presto type " + prestoType); + } + } + private Estimate calculateRowsCount(Map partitionStatistics) { List knownPartitionRowCounts = partitionStatistics.values().stream() @@ -130,9 +271,9 @@ private Estimate calculateDistinctValuesCount(Map s DoubleStream::max); } - private Estimate calculateNullsCount(Map statisticsByPartitionName, String column) + private Estimate calculateNullsFraction(Map statisticsByPartitionName, String column, Estimate totalRowsCount) { - return summarizePartitionStatistics( + Estimate totalNullsCount = summarizePartitionStatistics( statisticsByPartitionName.values(), column, columnStatistics -> { @@ -144,11 +285,10 @@ private Estimate calculateNullsCount(Map statistics } }, nullsCountStream -> { - double totalNullsCount = 0; + double nullsCount = 0; long partitionsWithStatisticsCount = 0; for (PrimitiveIterator.OfDouble nullsCountIterator = nullsCountStream.iterator(); nullsCountIterator.hasNext(); ) { - double nullsCount = nullsCountIterator.nextDouble(); - totalNullsCount += nullsCount; + nullsCount += nullsCountIterator.nextDouble(); partitionsWithStatisticsCount++; } @@ -157,9 +297,17 @@ private Estimate calculateNullsCount(Map statistics } else { int allPartitionsCount = statisticsByPartitionName.size(); - return OptionalDouble.of(allPartitionsCount / partitionsWithStatisticsCount * totalNullsCount); + return OptionalDouble.of((double) allPartitionsCount / partitionsWithStatisticsCount * nullsCount); /* promote to double first: int/long division would truncate the extrapolation factor (e.g. 3 / 2 == 1) */ } }); + + if (totalNullsCount.isValueUnknown() || totalRowsCount.isValueUnknown()) { + return Estimate.unknownValue(); + } 
+ if (totalRowsCount.getValue() == 0.0) { + return new Estimate(0.0); + } + return new Estimate(totalNullsCount.getValue() / totalRowsCount.getValue()); } private Estimate countDistinctPartitionKeys(HiveColumnHandle partitionColumn, List partitions) @@ -171,7 +319,7 @@ private Estimate countDistinctPartitionKeys(HiveColumnHandle partitionColumn, Li .count()); } - private Estimate calculateNullsCountForPartitioningKey(HiveColumnHandle partitionColumn, List partitions, Map partitionStatistics) + private Estimate calculateNullsFractionForPartitioningKey(HiveColumnHandle partitionColumn, List partitions, Map partitionStatistics) { OptionalDouble rowsPerPartition = partitionStatistics.values().stream() .map(PartitionStatistics::getRowCount) @@ -183,11 +331,16 @@ private Estimate calculateNullsCountForPartitioningKey(HiveColumnHandle partitio return Estimate.unknownValue(); } - return new Estimate(partitions.stream() + double estimatedTotalRowsCount = rowsPerPartition.getAsDouble() * partitions.size(); + if (estimatedTotalRowsCount == 0.0) { + return Estimate.zeroValue(); + } + double estimatedNullsCount = partitions.stream() .filter(partition -> partition.getKeys().get(partitionColumn).isNull()) .map(HivePartition::getPartitionId) .mapToLong(partitionId -> partitionStatistics.get(partitionId).getRowCount().orElse((long) rowsPerPartition.getAsDouble())) - .sum()); + .sum(); + return new Estimate(estimatedNullsCount / estimatedTotalRowsCount); } private Estimate summarizePartitionStatistics( diff --git a/presto-main/pom.xml b/presto-main/pom.xml index dd535dce447c9..0aedfbf788d72 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -262,6 +262,11 @@ jgrapht-core + + org.pcollections + pcollections + + org.apache.bval diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index c693ec2b28e63..3ed749a98a314 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -35,6 +35,7 @@ import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; public final class SystemSessionProperties @@ -73,6 +74,7 @@ public final class SystemSessionProperties public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations"; public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join"; public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; + public static final String USE_NEW_STATS_CALCULATOR = "use_new_stats_calculator"; private final List> sessionProperties; @@ -188,12 +190,16 @@ public SystemSessionProperties( Duration::toString), new PropertyMetadata<>( QUERY_MAX_MEMORY, - "Maximum amount of distributed memory a query can use", + "Maximum amount of distributed memory a query can use (will be capped by global config property)", VARCHAR, DataSize.class, memoryManagerConfig.getMaxQueryMemory(), true, - value -> DataSize.valueOf((String) value), + value -> { + long sessionValue = DataSize.valueOf((String) value).toBytes(); + long configValue = memoryManagerConfig.getMaxQueryMemory().toBytes(); + return succinctBytes(Math.min(configValue, sessionValue)); + }, DataSize::toString), booleanSessionProperty( RESOURCE_OVERCOMMIT, @@ -318,7 +324,12 @@ public SystemSessionProperties( PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "Push partial aggregations below joins", false, - false)); + false), + booleanSessionProperty( + USE_NEW_STATS_CALCULATOR, + "Use new experimental statistics calculator", + featuresConfig.isUseNewStatsCalculator(), + true)); } public List> 
getSessionProperties() @@ -499,4 +510,9 @@ public static boolean isPushAggregationThroughJoin(Session session) { return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class); } + + public static boolean isUseNewStatsCalculator(Session session) + { + return session.getSystemProperty(USE_NEW_STATS_CALCULATOR, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/AbstractSetOperationStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/AbstractSetOperationStatsRule.java new file mode 100644 index 0000000000000..51dad43569605 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/AbstractSetOperationStatsRule.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SetOperationNode; +import com.google.common.collect.ListMultimap; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; + +abstract class AbstractSetOperationStatsRule + implements ComposableStatsCalculator.Rule +{ + @Override + public final Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + SetOperationNode unionNode = (SetOperationNode) node; + + Optional estimate = Optional.empty(); + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + PlanNodeStatsEstimate sourceStats = lookup.getStats(source, session, types); + + PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, unionNode.getSymbolMapping(), i); + + if (estimate.isPresent()) { + estimate = Optional.of(operate(estimate.get(), sourceStatsWithMappedSymbols)); + } + else { + estimate = Optional.of(sourceStatsWithMappedSymbols); + } + } + + checkState(estimate.isPresent()); + return estimate; + } + + private PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, ListMultimap mapping, int index) + { + PlanNodeStatsEstimate.Builder mapped = PlanNodeStatsEstimate.builder() + .setOutputRowCount(estimate.getOutputRowCount()); + + mapping.keySet().stream() + .forEach(symbol -> mapped.addSymbolStatistics(symbol, estimate.getSymbolStatistics(mapping.get(symbol).get(index)))); + + return mapped.build(); + } + + protected abstract PlanNodeStatsEstimate operate(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second); +} diff --git 
a/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java
new file mode 100644
index 0000000000000..0f2386630beff
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/AggregationStatsRule.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.util.MoreMath.isPositiveOrNan;
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.util.Objects.requireNonNull;
+
+public class AggregationStatsRule // stats rule matched against AggregationNode; estimates GROUP BY output statistics
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); // NOTE(review): generic type arguments look stripped throughout this listing (e.g. "Map types") - confirm against the original file
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        AggregationNode aggregationNode = (AggregationNode) node;
+
+        if (aggregationNode.getGroupingSets().size() != 1) { // only a single grouping set is estimated; otherwise this rule offers no estimate
+            return Optional.empty();
+        }
+
+        return Optional.of(groupBy(
+                lookup.getStats(aggregationNode.getSource(), session, types), // stats of the aggregation input
+                getOnlyElement(aggregationNode.getGroupingSets()),
+                aggregationNode.getAggregations()));
+    }
+
+    public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate input, Collection groupBySymbols, Map aggregations) // derives output stats from input stats, the grouping keys and the aggregation outputs
+    {
+        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
+        for (Symbol groupBySymbol : groupBySymbols) { // pass 1: per-key symbol statistics with an adjusted nulls fraction
+            SymbolStatsEstimate symbolStatistics = input.getSymbolStatistics(groupBySymbol);
+            result.addSymbolStatistics(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> {
+                if (isPositiveOrNan(nullsFraction)) { // presumably "nulls present or unknown" - see MoreMath.isPositiveOrNan
+                    double distinctValuesCount = symbolStatistics.getDistinctValuesCount();
+                    return 1.0 / (distinctValuesCount + 1); // NULL counts as one group among (distinct values + 1)
+                }
+                return 0.0; // otherwise no NULL group is assumed for this key
+            }));
+        }
+
+        double rowsCount = 1; // stays 1 when there are no grouping keys
+        for (Symbol groupBySymbol : groupBySymbols) { // pass 2: output rows = product over keys of (distinct values + optional NULL group)
+            SymbolStatsEstimate symbolStatistics = input.getSymbolStatistics(groupBySymbol);
+            int nullRow = isPositiveOrNan(symbolStatistics.getNullsFraction()) ? 1 : 0;
+            rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow;
+        }
+        result.setOutputRowCount(rowsCount);
+
+        for (Map.Entry aggregationEntry : aggregations.entrySet()) { // one symbol-stats entry per aggregation output
+            result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), input));
+        }
+
+        return result.build();
+    }
+
+    private static SymbolStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats) // placeholder: validates arguments, then reports unknown stats for every aggregation
+    {
+        requireNonNull(aggregation, "aggregation is null");
+        requireNonNull(sourceStats, "sourceStats is null");
+
+        // TODO implement simple aggregations like: min, max, count, sum
+        return SymbolStatsEstimate.UNKNOWN_STATS;
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java
new file mode 100644
index 0000000000000..4d57c0cc7eeac
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class CachingCostCalculator // CostCalculator decorator that memoizes the delegate's result per PlanNode
+        implements CostCalculator
+{
+    private final CostCalculator costCalculator; // delegate that performs the actual computation
+    private final Map<PlanNode, PlanNodeCostEstimate> costs = new HashMap<>(); // cache keyed by plan node; plain HashMap, no synchronization
+
+    public CachingCostCalculator(CostCalculator costCalculator)
+    {
+        this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+    }
+
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types) // returns the cached estimate, computing and storing it on first request
+    {
+        if (!costs.containsKey(planNode)) {
+            // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback
+            PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(planNode, lookup, session, types); // NOTE(review): delegates to calculateCumulativeCost, not calculateCost - confirm this is intended
+            requireNonNull(cost, "computed cost can not be null"); // check the computed value; the cache map field itself is never null
+            checkState(costs.put(planNode, cost) == null, "cost for " + planNode + " already computed"); // a prior entry means the delegate re-entered for the same node
+        }
+        return costs.get(planNode);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
new file mode 100644
index 0000000000000..80d5947c06fa9
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class CachingStatsCalculator // StatsCalculator decorator that memoizes the delegate's result per PlanNode
+        implements StatsCalculator
+{
+    private final StatsCalculator statsCalculator; // delegate that performs the actual computation
+    private final Map<PlanNode, PlanNodeStatsEstimate> stats = new HashMap<>(); // cache keyed by plan node; plain HashMap, no synchronization
+
+    public CachingStatsCalculator(StatsCalculator statsCalculator)
+    {
+        this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); // fail fast on a null delegate
+    }
+
+    @Override
+    public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types) // returns the cached estimate, computing and storing it on first request
+    {
+        if (!stats.containsKey(planNode)) {
+            // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback
+            PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types);
+            requireNonNull(statsEstimate, "computed stats can not be null"); // check the computed value; the cache map field itself is never null
+            checkState(stats.put(planNode, statsEstimate) == null, "statistics for " + planNode + " already computed"); // a prior entry means the delegate re-entered for the same node
+        }
+        return stats.get(planNode);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToOutputRowsCount.java b/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToOutputRowsCount.java
new file mode 100644
index 0000000000000..b20367ac8187b
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToOutputRowsCount.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+
+import static java.lang.Double.isNaN;
+import static java.lang.Math.min;
+import static java.util.Objects.requireNonNull;
+
+public class CapDistinctValuesCountToOutputRowsCount // normalizer that caps every symbol's distinct-values count at the estimated output row count
+        implements ComposableStatsCalculator.Normalizer
+{
+    @Override
+    public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types)
+    {
+        requireNonNull(node, "node is null");
+        requireNonNull(estimate, "estimate is null");
+        requireNonNull(types, "types is null");
+
+        double outputRowCount = estimate.getOutputRowCount();
+        if (isNaN(outputRowCount)) { // unknown row count: nothing to cap against, return the estimate unchanged
+            return estimate;
+        }
+        for (Symbol symbol : estimate.getSymbolsWithKnownStatistics()) { // each iteration replaces estimate with an updated copy for one symbol
+            estimate = estimate.mapSymbolColumnStatistics(
+                    symbol,
+                    symbolStatsEstimate -> symbolStatsEstimate.mapDistinctValuesCount(
+                            distinctValuesCount -> {
+                                if (!isNaN(distinctValuesCount)) {
+                                    return min(distinctValuesCount, outputRowCount); // cap at the output row count
+                                }
+                                return distinctValuesCount; // NaN (unknown) distinct counts are left untouched
+                            }));
+        }
+        return estimate;
+    }
+}
diff --git
a/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToTypeDomainRangeLength.java b/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToTypeDomainRangeLength.java
new file mode 100644
index 0000000000000..3ae9cc45cfded
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CapDistinctValuesCountToTypeDomainRangeLength.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.spi.type.BigintType;
+import com.facebook.presto.spi.type.BooleanType;
+import com.facebook.presto.spi.type.DecimalType;
+import com.facebook.presto.spi.type.IntegerType;
+import com.facebook.presto.spi.type.SmallintType;
+import com.facebook.presto.spi.type.TinyintType;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+
+import static java.lang.Double.NaN;
+import static java.lang.Double.isNaN;
+import static java.lang.Double.min;
+import static java.lang.Math.floor;
+import static java.lang.Math.pow;
+
+public class CapDistinctValuesCountToTypeDomainRangeLength // normalizer that caps distinct-values counts at the number of values the symbol's [low, high] range can hold
+        implements ComposableStatsCalculator.Normalizer
+{
+    @Override
+    public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types)
+    {
+        for (Symbol symbol : estimate.getSymbolsWithKnownStatistics()) {
+            double domainLength = calculateDomainLength(symbol, estimate, types);
+            if (isNaN(domainLength)) { // no countable domain for this symbol: leave its statistics alone
+                continue;
+            }
+            estimate = estimate.mapSymbolColumnStatistics(
+                    symbol,
+                    symbolStatsEstimate -> symbolStatsEstimate.mapDistinctValuesCount(
+                            distinctValuesCount -> {
+                                if (!isNaN(distinctValuesCount)) {
+                                    return min(distinctValuesCount, domainLength); // cap at the domain length
+                                }
+                                return distinctValuesCount; // NaN (unknown) distinct counts are left untouched
+                            }));
+        }
+
+        return estimate;
+    }
+
+    private double calculateDomainLength(Symbol symbol, PlanNodeStatsEstimate estimate, Map types) // number of representable values in the symbol's range, or NaN when the type is not discrete
+    {
+        SymbolStatsEstimate symbolStatistics = estimate.getSymbolStatistics(symbol);
+
+        if (symbolStatistics.statisticRange().length() == 0) { // zero-length range: a single value, checked before the type lookup
+            return 1;
+        }
+
+        Type type = types.get(symbol);
+        if (!isDiscrete(type)) {
+            return NaN;
+        }
+
+        double length = symbolStatistics.getHighValue() - symbolStatistics.getLowValue();
+        if (type instanceof DecimalType) {
+            length *= pow(10, ((DecimalType) type).getScale()); // count steps of 10^-scale rather than whole units
+        }
+        return floor(length + 1); // both range ends are included
+    }
+
+    private boolean isDiscrete(Type type) // types this normalizer treats as having a countable domain
+    {
+        return type.equals(IntegerType.INTEGER) ||
+                type.equals(BigintType.BIGINT) ||
+                type.equals(SmallintType.SMALLINT) ||
+                type.equals(TinyintType.TINYINT) ||
+                type.equals(BooleanType.BOOLEAN) ||
+                type instanceof DecimalType;
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java
deleted file mode 100644
index f2326e379bcda..0000000000000
--- a/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java
+++ /dev/null
@@ -1,279 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.facebook.presto.cost; - -import com.facebook.presto.Session; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.ColumnHandle; -import com.facebook.presto.spi.Constraint; -import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.statistics.Estimate; -import com.facebook.presto.spi.statistics.TableStatistics; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DomainTranslator; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.FilterNode; -import com.facebook.presto.sql.planner.plan.JoinNode; -import com.facebook.presto.sql.planner.plan.LimitNode; -import com.facebook.presto.sql.planner.plan.OutputNode; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; -import com.facebook.presto.sql.planner.plan.PlanVisitor; -import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.planner.plan.SemiJoinNode; -import com.facebook.presto.sql.planner.plan.TableScanNode; -import com.facebook.presto.sql.planner.plan.ValuesNode; -import com.facebook.presto.sql.tree.BooleanLiteral; -import com.facebook.presto.sql.tree.Expression; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - -import 
java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import static com.facebook.presto.cost.PlanNodeCost.UNKNOWN_COST; - -/** - * Simple implementation of CostCalculator. It make many arbitrary decisions (e.g filtering selectivity, join matching). - * It serves POC purpose. To be replaced with more advanced implementation. - */ -@ThreadSafe -public class CoefficientBasedCostCalculator - implements CostCalculator -{ - private static final Double FILTER_COEFFICIENT = 0.5; - private static final Double JOIN_MATCHING_COEFFICIENT = 2.0; - - // todo some computation for outputSizeInBytes - - private final Metadata metadata; - - @Inject - public CoefficientBasedCostCalculator(Metadata metadata) - { - this.metadata = metadata; - } - - @Override - public Map calculateCostForPlan(Session session, Map types, PlanNode planNode) - { - Visitor visitor = new Visitor(session, types); - planNode.accept(visitor, null); - return ImmutableMap.copyOf(visitor.getCosts()); - } - - private class Visitor - extends PlanVisitor - { - private final Session session; - private final Map costs; - private final Map types; - - public Visitor(Session session, Map types) - { - this.costs = new HashMap<>(); - this.session = session; - this.types = ImmutableMap.copyOf(types); - } - - public Map getCosts() - { - return ImmutableMap.copyOf(costs); - } - - @Override - protected PlanNodeCost visitPlan(PlanNode node, Void context) - { - visitSources(node); - costs.put(node.getId(), UNKNOWN_COST); - return UNKNOWN_COST; - } - - @Override - public PlanNodeCost visitOutput(OutputNode node, Void context) - { - return copySourceCost(node); - } - - @Override - public PlanNodeCost visitFilter(FilterNode node, Void context) - { - PlanNodeCost sourceCost; - if (node.getSource() instanceof TableScanNode) { - sourceCost = visitTableScanWithPredicate((TableScanNode) node.getSource(), node.getPredicate()); - } - else { - sourceCost = visitSource(node); - } - - final 
double filterCoefficient = FILTER_COEFFICIENT; - PlanNodeCost filterCost = sourceCost - .mapOutputRowCount(value -> value * filterCoefficient); - costs.put(node.getId(), filterCost); - return filterCost; - } - - @Override - public PlanNodeCost visitProject(ProjectNode node, Void context) - { - return copySourceCost(node); - } - - @Override - public PlanNodeCost visitJoin(JoinNode node, Void context) - { - List sourceCosts = visitSources(node); - PlanNodeCost leftCost = sourceCosts.get(0); - PlanNodeCost rightCost = sourceCosts.get(1); - - PlanNodeCost.Builder joinCost = PlanNodeCost.builder(); - if (!leftCost.getOutputRowCount().isValueUnknown() && !rightCost.getOutputRowCount().isValueUnknown()) { - double rowCount = Math.max(leftCost.getOutputRowCount().getValue(), rightCost.getOutputRowCount().getValue()) * JOIN_MATCHING_COEFFICIENT; - joinCost.setOutputRowCount(new Estimate(rowCount)); - } - - costs.put(node.getId(), joinCost.build()); - return joinCost.build(); - } - - @Override - public PlanNodeCost visitExchange(ExchangeNode node, Void context) - { - List sourceCosts = visitSources(node); - Estimate rowCount = new Estimate(0); - for (PlanNodeCost sourceCost : sourceCosts) { - if (sourceCost.getOutputRowCount().isValueUnknown()) { - rowCount = Estimate.unknownValue(); - } - else { - rowCount = rowCount.map(value -> value + sourceCost.getOutputRowCount().getValue()); - } - } - - PlanNodeCost exchangeCost = PlanNodeCost.builder() - .setOutputRowCount(rowCount) - .build(); - costs.put(node.getId(), exchangeCost); - return exchangeCost; - } - - @Override - public PlanNodeCost visitTableScan(TableScanNode node, Void context) - { - return visitTableScanWithPredicate(node, BooleanLiteral.TRUE_LITERAL); - } - - private PlanNodeCost visitTableScanWithPredicate(TableScanNode node, Expression predicate) - { - Constraint constraint = getConstraint(node, predicate); - - TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint); - 
PlanNodeCost tableScanCost = PlanNodeCost.builder() - .setOutputRowCount(tableStatistics.getRowCount()) - .build(); - - costs.put(node.getId(), tableScanCost); - return tableScanCost; - } - - private Constraint getConstraint(TableScanNode node, Expression predicate) - { - DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( - metadata, - session, - predicate, - types); - - TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() - .transform(node.getAssignments()::get) - .intersect(node.getCurrentConstraint()); - - return new Constraint<>(simplifiedConstraint, bindings -> true); - } - - @Override - public PlanNodeCost visitValues(ValuesNode node, Void context) - { - Estimate valuesCount = new Estimate(node.getRows().size()); - PlanNodeCost valuesCost = PlanNodeCost.builder() - .setOutputRowCount(valuesCount) - .build(); - costs.put(node.getId(), valuesCost); - return valuesCost; - } - - @Override - public PlanNodeCost visitEnforceSingleRow(EnforceSingleRowNode node, Void context) - { - visitSources(node); - PlanNodeCost nodeCost = PlanNodeCost.builder() - .setOutputRowCount(new Estimate(1.0)) - .build(); - costs.put(node.getId(), nodeCost); - return nodeCost; - } - - @Override - public PlanNodeCost visitSemiJoin(SemiJoinNode node, Void context) - { - visitSources(node); - PlanNodeCost sourceStatitics = costs.get(node.getSource().getId()); - PlanNodeCost semiJoinCost = sourceStatitics.mapOutputRowCount(rowCount -> rowCount * JOIN_MATCHING_COEFFICIENT); - costs.put(node.getId(), semiJoinCost); - return semiJoinCost; - } - - @Override - public PlanNodeCost visitLimit(LimitNode node, Void context) - { - PlanNodeCost sourceCost = visitSource(node); - PlanNodeCost.Builder limitCost = PlanNodeCost.builder(); - if (sourceCost.getOutputRowCount().getValue() < node.getCount()) { - limitCost.setOutputRowCount(sourceCost.getOutputRowCount()); - } - else { - limitCost.setOutputRowCount(new Estimate(node.getCount())); - } - 
costs.put(node.getId(), limitCost.build()); - return limitCost.build(); - } - - private PlanNodeCost copySourceCost(PlanNode node) - { - PlanNodeCost sourceCost = visitSource(node); - costs.put(node.getId(), sourceCost); - return sourceCost; - } - - private List visitSources(PlanNode node) - { - return node.getSources().stream() - .map(source -> source.accept(this, null)) - .collect(Collectors.toList()); - } - - private PlanNodeCost visitSource(PlanNode node) - { - return Iterables.getOnlyElement(visitSources(node)); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedStatsCalculator.java new file mode 100644 index 0000000000000..65f0fb546d1ad --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedStatsCalculator.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.Constraint;
+import com.facebook.presto.spi.predicate.TupleDomain;
+import com.facebook.presto.spi.statistics.TableStatistics;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.DomainTranslator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.LimitNode;
+import com.facebook.presto.sql.planner.plan.OutputNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.TableScanNode;
+import com.facebook.presto.sql.planner.plan.ValuesNode;
+import com.facebook.presto.sql.tree.BooleanLiteral;
+import com.google.common.collect.ImmutableMap;
+
+import javax.annotation.concurrent.ThreadSafe;
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS;
+
+/**
+ * Simple implementation of StatsCalculator. It makes many arbitrary decisions (e.g. filtering selectivity, join matching).
+ * It serves POC purpose. To be replaced with more advanced implementation.
+ */
+@ThreadSafe
+public class CoefficientBasedStatsCalculator
+        implements StatsCalculator
+{
+    private static final Double FILTER_COEFFICIENT = 0.5; // multiplier applied to the source row count for any filter
+    private static final Double JOIN_MATCHING_COEFFICIENT = 2.0; // multiplier applied to join and semi-join row counts
+
+    // todo some computation for outputSizeInBytes
+
+    private final Metadata metadata;
+
+    @Inject
+    public CoefficientBasedStatsCalculator(Metadata metadata)
+    {
+        this.metadata = metadata;
+    }
+
+    @Override
+    public PlanNodeStatsEstimate calculateStats(
+            PlanNode node,
+            Lookup lookup,
+            Session session,
+            Map types)
+    {
+        Visitor visitor = new Visitor(lookup, session, types); // a fresh visitor per call
+        return node.accept(visitor, null);
+    }
+
+    private class Visitor // computes stats for one node; source stats are obtained through lookup rather than by visiting children
+            extends PlanVisitor
+    {
+        private final Lookup lookup;
+        private final Session session;
+        private final Map types;
+
+        public Visitor(Lookup lookup, Session session, Map types)
+        {
+            this.lookup = lookup;
+            this.session = session;
+            this.types = ImmutableMap.copyOf(types);
+        }
+
+        private PlanNodeStatsEstimate lookupStats(PlanNode sourceNode)
+        {
+            return lookup.getStats(sourceNode, session, types);
+        }
+
+        @Override
+        protected PlanNodeStatsEstimate visitPlan(PlanNode node, Void context)
+        {
+            // TODO: Explicitly visit GroupReference and throw an IllegalArgumentException
+            // this can only be done once we get rid of the StatelessLookup
+            return UNKNOWN_STATS; // fallback for node types without a dedicated visit method
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitOutput(OutputNode node, Void context)
+        {
+            return lookupStats(node.getSource()); // source stats passed through unchanged
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitFilter(FilterNode node, Void context)
+        {
+            PlanNodeStatsEstimate sourceStats = lookupStats(node.getSource());
+            return sourceStats.mapOutputRowCount(value -> value * FILTER_COEFFICIENT); // the predicate itself is not inspected
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitProject(ProjectNode node, Void context)
+        {
+            return lookupStats(node.getSource()); // source stats passed through unchanged
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitJoin(JoinNode node, Void context)
+        {
+            PlanNodeStatsEstimate leftStats = lookupStats(node.getLeft());
+            PlanNodeStatsEstimate rightStats = lookupStats(node.getRight());
+
+            PlanNodeStatsEstimate.Builder joinStats = PlanNodeStatsEstimate.builder();
+            double rowCount = Math.max(leftStats.getOutputRowCount(), rightStats.getOutputRowCount()) * JOIN_MATCHING_COEFFICIENT; // larger input times the join coefficient
+            joinStats.setOutputRowCount(rowCount);
+            return joinStats.build();
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitExchange(ExchangeNode node, Void context)
+        {
+            double rowCount = 0;
+            for (int i = 0; i < node.getSources().size(); i++) { // sum of all sources' row counts
+                PlanNodeStatsEstimate childStats = lookupStats(node.getSources().get(i));
+                rowCount = rowCount + childStats.getOutputRowCount();
+            }
+
+            return PlanNodeStatsEstimate.builder()
+                    .setOutputRowCount(rowCount)
+                    .build();
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitTableScan(TableScanNode node, Void context)
+        {
+            Constraint constraint = getConstraint(node);
+
+            TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint); // row count comes from the table statistics returned by metadata
+            return PlanNodeStatsEstimate.builder()
+                    .setOutputRowCount(tableStatistics.getRowCount().getValue())
+                    .build();
+        }
+
+        private Constraint getConstraint(TableScanNode node) // constraint built from a TRUE predicate intersected with the scan's current constraint
+        {
+            DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate(
+                    metadata,
+                    session,
+                    BooleanLiteral.TRUE_LITERAL, // presumably yields an unrestricted tuple domain - TODO confirm
+                    types);
+
+            TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain()
+                    .transform(node.getAssignments()::get)
+                    .intersect(node.getCurrentConstraint());
+
+            return new Constraint<>(simplifiedConstraint, bindings -> true);
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitValues(ValuesNode node, Void context)
+        {
+            int valuesCount = node.getRows().size(); // exact: one output row per literal row
+            return PlanNodeStatsEstimate.builder()
+                    .setOutputRowCount(valuesCount)
+                    .build();
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context)
+        {
+            return PlanNodeStatsEstimate.builder()
+                    .setOutputRowCount(1) // always estimated as one row; the source is not consulted
+                    .build();
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitSemiJoin(SemiJoinNode node, Void context)
+        {
+            PlanNodeStatsEstimate sourceStats = lookupStats(node.getSource());
+            return sourceStats.mapOutputRowCount(rowCount -> rowCount * JOIN_MATCHING_COEFFICIENT);
+        }
+
+        @Override
+        public PlanNodeStatsEstimate visitLimit(LimitNode node, Void context)
+        {
+            PlanNodeStatsEstimate sourceStats = lookupStats(node.getSource());
+            PlanNodeStatsEstimate.Builder limitStats = PlanNodeStatsEstimate.builder();
+            if (sourceStats.getOutputRowCount() < node.getCount()) { // false when the source count is NaN, so the limit is used in that case
+                limitStats.setOutputRowCount(sourceStats.getOutputRowCount());
+            }
+            else {
+                limitStats.setOutputRowCount(node.getCount());
+            }
+            return limitStats.build();
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java
new file mode 100644
index 0000000000000..8dfa58bb486f7
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.tree.ComparisonExpressionType;
+
+import static com.facebook.presto.cost.FilterStatsCalculator.filterStatsForUnknownExpression;
+import static com.facebook.presto.cost.SymbolStatsEstimate.buildFrom;
+import static java.lang.Double.NEGATIVE_INFINITY;
+import static java.lang.Double.NaN;
+import static java.lang.Double.POSITIVE_INFINITY;
+import static java.lang.Double.isNaN;
+import static java.lang.Math.max;
+
+public class ComparisonStatsCalculator // static helpers estimating the effect of a comparison predicate on plan-node statistics
+{
+    private ComparisonStatsCalculator()
+    {}
+
+    public static PlanNodeStatsEstimate comparisonSymbolToLiteralStats(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            double doubleLiteral,
+            ComparisonExpressionType type) // entry point for "symbol <op> literal"; dispatches on the comparison type
+    {
+        switch (type) {
+            case EQUAL:
+                return symbolToLiteralEquality(inputStatistics, symbol, doubleLiteral);
+            case NOT_EQUAL:
+                return symbolToLiteralNonEquality(inputStatistics, symbol, doubleLiteral);
+            case LESS_THAN:
+            case LESS_THAN_OR_EQUAL: // strict and non-strict bounds share one estimate
+                return symbolToLiteralLessThan(inputStatistics, symbol, doubleLiteral);
+            case GREATER_THAN:
+            case GREATER_THAN_OR_EQUAL:
+                return symbolToLiteralGreaterThan(inputStatistics, symbol, doubleLiteral);
+            case IS_DISTINCT_FROM:
+            default:
+                return filterStatsForUnknownExpression(inputStatistics); // unsupported comparison types use the generic estimate
+        }
+    }
+
+    private static PlanNodeStatsEstimate symbolToLiteralRangeComparison(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            StatisticRange literalRange) // shared estimate: intersect the symbol's range with the literal's range
+    {
+        SymbolStatsEstimate symbolStats = inputStatistics.getSymbolStatistics(symbol);
+
+        StatisticRange range = StatisticRange.from(symbolStats);
+        StatisticRange intersectRange = range.intersect(literalRange);
+
+        double filterFactor = range.overlapPercentWith(intersectRange); // value passed on as the selectivity of the predicate
+        SymbolStatsEstimate symbolNewEstimate =
+                SymbolStatsEstimate.builder()
+                        .setAverageRowSize(symbolStats.getAverageRowSize())
+                        .setStatisticsRange(intersectRange)
+                        .setNullsFraction(0.0).build(); // the filtered symbol is estimated to contain no NULLs
+
+        return inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - symbolStats.getNullsFraction()) * rowCount) // selectivity applies to the non-null fraction only
+                .mapSymbolColumnStatistics(symbol, oldStats -> symbolNewEstimate);
+    }
+
+    private static PlanNodeStatsEstimate symbolToLiteralEquality(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            double literal)
+    {
+        return symbolToLiteralRangeComparison(inputStatistics, symbol, new StatisticRange(literal, literal, 1)); // literal modelled as a single-point range with one distinct value
+    }
+
+    private static PlanNodeStatsEstimate symbolToLiteralNonEquality(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            double literal)
+    {
+        SymbolStatsEstimate symbolStats = inputStatistics.getSymbolStatistics(symbol);
+
+        StatisticRange range = StatisticRange.from(symbolStats);
+        StatisticRange intersectRange = range.intersect(new StatisticRange(literal, literal, 1));
+
+        double filterFactor = 1 - range.overlapPercentWith(intersectRange); // complement of the equality selectivity
+
+        return inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - symbolStats.getNullsFraction()) * rowCount)
+                .mapSymbolColumnStatistics(symbol, oldStats -> buildFrom(oldStats)
+                        .setNullsFraction(0.0)
+                        .setDistinctValuesCount(max(oldStats.getDistinctValuesCount() - 1, 0)) // one value is excluded, never below zero
+                        .setAverageRowSize(oldStats.getAverageRowSize())
+                        .build());
+    }
+
+    private static PlanNodeStatsEstimate symbolToLiteralLessThan(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            double literal)
+    {
+        return symbolToLiteralRangeComparison(inputStatistics, symbol, new StatisticRange(NEGATIVE_INFINITY, literal, NaN)); // literal range (-inf, literal] with NaN distinct count
+    }
+
+    private static PlanNodeStatsEstimate symbolToLiteralGreaterThan(PlanNodeStatsEstimate inputStatistics,
+            Symbol symbol,
+            double literal)
+    {
+        return symbolToLiteralRangeComparison(inputStatistics, symbol, new StatisticRange(literal, POSITIVE_INFINITY, NaN)); // literal range [literal, +inf) with NaN distinct count
+    }
+
+    public static PlanNodeStatsEstimate comparisonSymbolToSymbolStats(PlanNodeStatsEstimate inputStatistics,
+            Symbol left,
+            Symbol right,
+            ComparisonExpressionType type) // entry point for "symbol <op> symbol"; only EQUAL and NOT_EQUAL have dedicated estimates
+    {
+        switch (type) {
+            case EQUAL:
+                return symbolToSymbolEquality(inputStatistics, left, right);
+            case NOT_EQUAL:
+                return symbolToSymbolNonEquality(inputStatistics, left, right);
+            case LESS_THAN:
+            case LESS_THAN_OR_EQUAL:
+            case GREATER_THAN:
+            case GREATER_THAN_OR_EQUAL:
+            case IS_DISTINCT_FROM:
+            default:
+                return filterStatsForUnknownExpression(inputStatistics);
+        }
+    }
+
+    private static PlanNodeStatsEstimate symbolToSymbolEquality(PlanNodeStatsEstimate inputStatistics,
+            Symbol left,
+            Symbol right)
+    {
+        SymbolStatsEstimate leftStats = inputStatistics.getSymbolStatistics(left);
+        SymbolStatsEstimate rightStats = inputStatistics.getSymbolStatistics(right);
+
+        if (isNaN(leftStats.getDistinctValuesCount()) || isNaN(rightStats.getDistinctValuesCount())) {
+            return filterStatsForUnknownExpression(inputStatistics); // result was previously discarded; without distinct counts the estimate below is NaN-based
+        }
+
+        StatisticRange leftRange = StatisticRange.from(leftStats);
+        StatisticRange rightRange = StatisticRange.from(rightStats);
+
+        StatisticRange intersect = leftRange.intersect(rightRange); // after the filter both symbols are confined to the common range
+
+        SymbolStatsEstimate newRightStats = buildFrom(rightStats)
+                .setNullsFraction(0)
+                .setStatisticsRange(intersect)
+                .build();
+        SymbolStatsEstimate newLeftStats = buildFrom(leftStats)
+                .setNullsFraction(0)
+                .setStatisticsRange(intersect)
+                .build();
+
+        double nullsFilterFactor = (1 - leftStats.getNullsFraction()) * (1 - rightStats.getNullsFraction()); // fraction of rows where both sides are non-null
+        double filterFactor = 1 / max(leftRange.getDistinctValuesCount(), rightRange.getDistinctValuesCount()); // one over the larger distinct-values count
+
+        return inputStatistics.mapOutputRowCount(size -> size * filterFactor * nullsFilterFactor)
+                .mapSymbolColumnStatistics(left, oldLeftStats -> newLeftStats)
+                .mapSymbolColumnStatistics(right, oldRightStats -> newRightStats);
+    }
+
+    private static PlanNodeStatsEstimate symbolToSymbolNonEquality(PlanNodeStatsEstimate inputStatistics,
+            Symbol left,
+            Symbol right)
+    {
+        return PlanNodeStatsEstimateMath.differenceInStats(inputStatistics, symbolToSymbolEquality(inputStatistics, left, right)); // input stats minus the equality estimate
+    }
+}
diff --git
a/presto-main/src/main/java/com/facebook/presto/cost/ComposableStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ComposableStatsCalculator.java
new file mode 100644
index 0000000000000..b79fae3259f1e
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ComposableStatsCalculator.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Matchable;
+import com.facebook.presto.matching.MatchingEngine;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Stats calculator assembled from a set of {@link Rule}s and a list of {@link Normalizer}s.
+ * The estimate of the first candidate rule that produces one is passed through every
+ * normalizer, in list order, before being returned.
+ */
+public class ComposableStatsCalculator
+        implements StatsCalculator
+{
+    private final MatchingEngine rules;
+    private final List normalizers;
+
+    /**
+     * @param rules rules registered with the matching engine used to find candidates for a node
+     * @param normalizers normalizers applied, in order, to every calculated estimate
+     */
+    public ComposableStatsCalculator(Set rules, List normalizers)
+    {
+        this.rules = MatchingEngine.builder()
+                .register(rules)
+                .build();
+        // Immutable copy, so later changes to the caller's list are not observed here.
+        this.normalizers = ImmutableList.copyOf(normalizers);
+    }
+
+    /**
+     * Calculates statistics for {@code planNode} by dispatching it to a fresh visitor.
+     */
+    @Override
+    public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types)
+    {
+        Visitor visitor = new Visitor(lookup, session, types);
+        return planNode.accept(visitor, null);
+    }
+
+    /**
+     * A single stats rule. An empty result means the rule could not produce an
+     * estimate, in which case the next candidate rule is tried.
+     */
+    public interface Rule extends Matchable
+    {
+        Optional calculate(PlanNode node, Lookup lookup, Session session, Map types);
+    }
+
+    /**
+     * Post-processing step applied to an estimate produced by a {@link Rule}.
+     */
+    public interface Normalizer
+    {
+        PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types);
+    }
+
+    private class Visitor
+            extends PlanVisitor
+    {
+        private final Lookup lookup;
+        private final Session session;
+        private final Map types;
+
+        public Visitor(Lookup lookup, Session session, Map types)
+        {
+            this.lookup = lookup;
+            this.session = session;
+            this.types = ImmutableMap.copyOf(types);
+        }
+
+        /**
+         * Tries the candidate rules for {@code node} in iteration order and returns the
+         * normalized estimate of the first rule that yields one, or
+         * {@code PlanNodeStatsEstimate.UNKNOWN_STATS} when none does.
+         */
+        @Override
+        protected PlanNodeStatsEstimate visitPlan(PlanNode node, Void context)
+        {
+            Iterator ruleIterator = rules.getCandidates(node).iterator();
+            while (ruleIterator.hasNext()) {
+                Rule rule = ruleIterator.next();
+                Optional calculatedStats = rule.calculate(node, lookup, session, types);
+                if (calculatedStats.isPresent()) {
+                    return normalize(node, calculatedStats.get());
+                }
+            }
+            return PlanNodeStatsEstimate.UNKNOWN_STATS;
+        }
+
+        /**
+         * Feeds the estimate through each normalizer; every normalizer sees the
+         * output of the previous one.
+         */
+        private PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate)
+        {
+            for (Normalizer normalizer : normalizers) {
+                estimate = normalizer.normalize(node, estimate, types);
+            }
+            return estimate;
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
index f5d1f87743a21..f13394a78e1df 100644
--- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java
@@ -11,28 +11,58 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ + package com.facebook.presto.cost; import com.facebook.presto.Session; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.google.inject.BindingAnnotation; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; import java.util.Map; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + /** * Interface of cost calculator. - * - * It's responsibility is to provide approximation of cost of execution of plan node. - * Example implementations may be based on table statistics or data samples. + *

+ * Computes estimated cost of executing given PlanNode. + * Implementation may use lookup to compute needed traits for self/source nodes. */ public interface CostCalculator { - Map calculateCostForPlan(Session session, Map types, PlanNode planNode); + PlanNodeCostEstimate calculateCost( + PlanNode planNode, + Lookup lookup, + Session session, + Map types); - default PlanNodeCost calculateCostForNode(Session session, Map types, PlanNode planNode) + default PlanNodeCostEstimate calculateCumulativeCost( + PlanNode planNode, + Lookup lookup, + Session session, + Map types) { - return calculateCostForPlan(session, types, planNode).get(planNode.getId()); + PlanNodeCostEstimate cost = calculateCost(planNode, lookup, session, types); + + if (!planNode.getSources().isEmpty()) { + PlanNodeCostEstimate childrenCost = planNode.getSources().stream() + .map(child -> lookup.getCumulativeCost(child, session, types)) + .reduce(PlanNodeCostEstimate.ZERO_COST, PlanNodeCostEstimate::add); + + return cost.add(childrenCost); + } + + return cost; } + + @BindingAnnotation + @Target({PARAMETER}) + @Retention(RUNTIME) + @interface EstimatedExchanges {} } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java new file mode 100644 index 0000000000000..be0e312ae22a0 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java @@ -0,0 +1,238 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.LimitNode;
+import com.facebook.presto.sql.planner.plan.OutputNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.TableScanNode;
+import com.facebook.presto.sql.planner.plan.ValuesNode;
+
+import javax.annotation.concurrent.ThreadSafe;
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.ZERO_COST;
+import static com.facebook.presto.cost.PlanNodeCostEstimate.cpuCost;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Simple implementation of CostCalculator. It assumes that ExchangeNodes are already in the plan.
+ */
+@ThreadSafe
+public class CostCalculatorUsingExchanges
+        implements CostCalculator
+{
+    private final int numberOfNodes;
+
+    /**
+     * Uses the number of active nodes reported by {@code nodeManager}. The value is
+     * read once, here, and stored; it is not refreshed afterwards.
+     */
+    @Inject
+    public CostCalculatorUsingExchanges(InternalNodeManager nodeManager)
+    {
+        this(nodeManager.getAllNodes().getActiveNodes().size());
+    }
+
+    /**
+     * @param numberOfNodes multiplier applied to replicated build sides and REPLICATE exchanges
+     */
+    public CostCalculatorUsingExchanges(int numberOfNodes)
+    {
+        this.numberOfNodes = numberOfNodes;
+    }
+
+    /**
+     * Computes the cost of {@code planNode} by dispatching it to a fresh {@code CostEstimator}.
+     */
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map types)
+    {
+        CostEstimator costEstimator = new CostEstimator(
+                session,
+                types,
+                lookup,
+                numberOfNodes);
+
+        return planNode.accept(costEstimator, null);
+    }
+
+    /**
+     * Per-node-type cost formulas, all expressed in terms of estimated output size in bytes.
+     */
+    private class CostEstimator
+            extends PlanVisitor
+    {
+        private final Session session;
+        private final Map types;
+        private final Lookup lookup;
+        private final int numberOfNodes;
+
+        public CostEstimator(Session session, Map types, Lookup lookup, int numberOfNodes)
+        {
+            this.session = requireNonNull(session, "session is null");
+            this.types = requireNonNull(types, "types is null");
+            // NOTE(review): lookup is the only reference argument that is not null-checked here.
+            this.lookup = lookup;
+            this.numberOfNodes = numberOfNodes;
+        }
+
+        // Fallback for visit methods not overridden below: the cost is unknown.
+        @Override
+        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context)
+        {
+            return UNKNOWN_COST;
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitOutput(OutputNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        // CPU cost is the size of the filter's input (the source's output).
+        @Override
+        public PlanNodeCostEstimate visitFilter(FilterNode node, Void context)
+        {
+            return cpuCost(getStats(node.getSource()).getOutputSizeInBytes());
+        }
+
+        // CPU cost is the size of the projection's own output.
+        @Override
+        public PlanNodeCostEstimate visitProject(ProjectNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes());
+        }
+
+        // CPU cost is the input size, memory cost is the aggregation's output size; no network cost.
+        @Override
+        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context)
+        {
+            PlanNodeStatsEstimate aggregationStats = getStats(node);
+            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
+            return PlanNodeCostEstimate.builder()
+                    .setCpuCost(sourceStats.getOutputSizeInBytes())
+                    .setMemoryCost(aggregationStats.getOutputSizeInBytes())
+                    .setNetworkCost(0)
+                    .build();
+        }
+
+        // A join without an explicit distribution type is costed as PARTITIONED.
+        @Override
+        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node,
+                    node.getLeft(),
+                    node.getRight(),
+                    node.getDistributionType().orElse(JoinNode.DistributionType.PARTITIONED).equals(JoinNode.DistributionType.REPLICATED));
+        }
+
+        /**
+         * Cost shared by joins and semi-joins. CPU cost is probe size plus build size plus
+         * join output size; memory cost is the build size. When {@code replicated} is true
+         * the build size is multiplied by the number of nodes in both terms.
+         */
+        private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated)
+        {
+            int numberOfNodesMultiplier = replicated ? numberOfNodes : 1;
+
+            PlanNodeStatsEstimate probeStats = getStats(probe);
+            PlanNodeStatsEstimate buildStats = getStats(build);
+            PlanNodeStatsEstimate outputStats = getStats(join);
+
+            double cpuCost = probeStats.getOutputSizeInBytes() +
+                    buildStats.getOutputSizeInBytes() * numberOfNodesMultiplier +
+                    outputStats.getOutputSizeInBytes();
+
+            double memoryCost = buildStats.getOutputSizeInBytes() * numberOfNodesMultiplier;
+
+            return PlanNodeCostEstimate.builder()
+                    .setCpuCost(cpuCost)
+                    .setMemoryCost(memoryCost)
+                    .setNetworkCost(0)
+                    .build();
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context)
+        {
+            return calculateExchangeCost(numberOfNodes, getStats(node), node.getType(), node.getScope());
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes()); // TODO: add network cost, based on input size in bytes?
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitValues(ValuesNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        // Semi-join is costed like a join: source is the probe side, filtering source is the build side.
+        @Override
+        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node,
+                    node.getSource(),
+                    node.getFilteringSource(),
+                    node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED).equals(SemiJoinNode.DistributionType.REPLICATED));
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitLimit(LimitNode node, Void context)
+        {
+            return cpuCost(getStats(node).getOutputSizeInBytes());
+        }
+
+        private PlanNodeStatsEstimate getStats(PlanNode node)
+        {
+            return lookup.getStats(node, session, types);
+        }
+    }
+
+    /**
+     * Computes the cost of an exchange from the statistics of the exchanged data.
+     * <ul>
+     * <li>GATHER: network cost equals the output size</li>
+     * <li>REPARTITION: network and CPU cost both equal the output size</li>
+     * <li>REPLICATE: network cost equals the output size times {@code numberOfNodes}</li>
+     * </ul>
+     * For a LOCAL scope the network cost is reset to 0 while the CPU cost is kept.
+     * Memory cost is always 0.
+     *
+     * @throws UnsupportedOperationException for any other exchange type
+     */
+    public static PlanNodeCostEstimate calculateExchangeCost(int numberOfNodes, PlanNodeStatsEstimate exchangeStats, ExchangeNode.Type type, ExchangeNode.Scope scope)
+    {
+        double network = 0;
+        double cpu = 0;
+
+        switch (type) {
+            case GATHER:
+                network = exchangeStats.getOutputSizeInBytes();
+                break;
+            case REPARTITION:
+                network = exchangeStats.getOutputSizeInBytes();
+                cpu = exchangeStats.getOutputSizeInBytes();
+                break;
+            case REPLICATE:
+                network = exchangeStats.getOutputSizeInBytes() * numberOfNodes;
+                break;
+            default:
+                throw new UnsupportedOperationException(format("Unsupported type [%s] of the exchange", type));
+        }
+
+        if (scope.equals(ExchangeNode.Scope.LOCAL)) {
+            network = 0;
+        }
+
+        return PlanNodeCostEstimate.builder()
+                .setNetworkCost(network)
+                .setCpuCost(cpu)
+                .setMemoryCost(0)
+                .build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java
new file mode 100644
index 0000000000000..4241dcaa7be55
--- /dev/null
+++
b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.PlanVisitor;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+
+import javax.annotation.concurrent.ThreadSafe;
+import javax.inject.Inject;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeCostEstimate.ZERO_COST;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
+import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This is a wrapper class around CostCalculator that estimates ExchangeNodes cost.
+ */
+@ThreadSafe
+public class CostCalculatorWithEstimatedExchanges
+        implements CostCalculator
+{
+    private final CostCalculator costCalculator;
+    private final int numberOfNodes;
+
+    /**
+     * Uses the number of active nodes reported by {@code nodeManager}, read once at construction.
+     */
+    @Inject
+    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, InternalNodeManager nodeManager)
+    {
+        this(costCalculator, nodeManager.getAllNodes().getActiveNodes().size());
+    }
+
+    /**
+     * @param costCalculator delegate that computes the cost of the node itself
+     * @param numberOfNodes node count passed to the exchange cost formula
+     */
+    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, int numberOfNodes)
+    {
+        this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+        this.numberOfNodes = numberOfNodes;
+    }
+
+    /**
+     * Returns the delegate's cost for {@code planNode} plus the estimated cost of the
+     * exchanges computed by {@code ExchangeCostEstimator} for that node.
+     */
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map types)
+    {
+        ExchangeCostEstimator exchangeCostEstimator = new ExchangeCostEstimator(
+                session,
+                types,
+                lookup,
+                numberOfNodes);
+        PlanNodeCostEstimate estimatedExchangeCost = planNode.accept(exchangeCostEstimator, null);
+
+        return costCalculator.calculateCost(planNode, lookup, session, types).add(estimatedExchangeCost);
+    }
+
+    /**
+     * Estimates exchange costs for aggregations, joins and semi-joins; every other
+     * node type handled by {@code visitPlan} contributes zero cost.
+     */
+    private class ExchangeCostEstimator
+            extends PlanVisitor
+    {
+        private final Session session;
+        private final Map types;
+        private final Lookup lookup;
+        private final int numberOfNodes;
+
+        public ExchangeCostEstimator(Session session, Map types, Lookup lookup, int numberOfNodes)
+        {
+            this.session = requireNonNull(session, "session is null");
+            this.types = requireNonNull(types, "types is null");
+            // NOTE(review): lookup is the only reference argument that is not null-checked here.
+            this.lookup = lookup;
+            this.numberOfNodes = numberOfNodes;
+        }
+
+        @Override
+        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context)
+        {
+            return ZERO_COST;
+        }
+
+        // An aggregation is charged a remote repartition of its source.
+        @Override
+        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context)
+        {
+            return CostCalculatorUsingExchanges.calculateExchangeCost(
+                    numberOfNodes,
+                    getStats(node.getSource()),
+                    REPARTITION,
+                    REMOTE);
+        }
+
+        // A join without an explicit distribution type is costed as PARTITIONED.
+        @Override
+        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node.getLeft(),
+                    node.getRight(),
+                    node.getDistributionType().orElse(JoinNode.DistributionType.PARTITIONED).equals(JoinNode.DistributionType.REPLICATED));
+        }
+
+        @Override
+        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
+        {
+            return calculateJoinCost(
+                    node.getSource(),
+                    node.getFilteringSource(),
+                    node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED).equals(SemiJoinNode.DistributionType.REPLICATED));
+        }
+
+        /**
+         * Replicated: cost of a remote REPLICATE exchange of the build side only.
+         * Otherwise: sum of remote REPARTITION exchanges of the probe and build sides.
+         */
+        private PlanNodeCostEstimate calculateJoinCost(PlanNode probe, PlanNode build, boolean replicated)
+        {
+            if (replicated) {
+                return CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(build),
+                        REPLICATE,
+                        REMOTE);
+            }
+            else {
+                PlanNodeCostEstimate probeCost = CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(probe),
+                        REPARTITION,
+                        REMOTE);
+                PlanNodeCostEstimate buildCost = CostCalculatorUsingExchanges.calculateExchangeCost(
+                        numberOfNodes,
+                        getStats(build),
+                        REPARTITION,
+                        REMOTE);
+                return probeCost.add(buildCost);
+            }
+        }
+
+        private PlanNodeStatsEstimate getStats(PlanNode node)
+        {
+            return lookup.getStats(node, session, types);
+        }
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java
new file mode 100644
index 0000000000000..b37e8f32f3051
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CostComparator.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.google.common.annotations.VisibleForTesting;
+
+import javax.inject.Inject;
+
+import java.util.Comparator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Orders {@link PlanNodeCostEstimate}s by a weighted sum of their CPU, memory and
+ * network components.
+ */
+public class CostComparator
+{
+    private final double cpuWeight;
+    private final double memoryWeight;
+    private final double networkWeight;
+
+    /**
+     * Takes the three weights from {@code featuresConfig}.
+     */
+    @Inject
+    public CostComparator(FeaturesConfig featuresConfig)
+    {
+        this(featuresConfig.getCpuCostWeight(), featuresConfig.getMemoryCostWeight(), featuresConfig.getNetworkCostWeight());
+    }
+
+    /**
+     * @throws IllegalArgumentException if any weight is negative
+     */
+    @VisibleForTesting
+    public CostComparator(double cpuWeight, double memoryWeight, double networkWeight)
+    {
+        checkArgument(cpuWeight >= 0, "cpuWeight can not be negative");
+        checkArgument(memoryWeight >= 0, "memoryWeight can not be negative");
+        checkArgument(networkWeight >= 0, "networkWeight can not be negative");
+        this.cpuWeight = cpuWeight;
+        this.memoryWeight = memoryWeight;
+        this.networkWeight = networkWeight;
+    }
+
+    /**
+     * Returns a {@link Comparator} view of {@link #compare} bound to {@code session}.
+     */
+    public Comparator forSession(Session session)
+    {
+        return (left, right) -> this.compare(session, left, right);
+    }
+
+    /**
+     * Compares two costs by {@code cpu * cpuWeight + memory * memoryWeight + network * networkWeight}.
+     * The session is only null-checked here; it does not influence the result.
+     *
+     * @return the result of {@link Double#compare} on the two weighted sums
+     * @throws NullPointerException if any argument is null
+     * @throws IllegalArgumentException if either cost has unknown components
+     */
+    public int compare(Session session, PlanNodeCostEstimate left, PlanNodeCostEstimate right)
+    {
+        requireNonNull(session, "session can not be null");
+        requireNonNull(left, "left can not be null");
+        requireNonNull(right, "right can not be null");
+        checkArgument(!left.hasUnknownComponents() && !right.hasUnknownComponents(), "cannot compare unknown costs");
+        double leftCost = left.getCpuCost() * cpuWeight
+                + left.getMemoryCost() * memoryWeight
+                + left.getNetworkCost() * networkWeight;
+
+        double rightCost = right.getCpuCost() * cpuWeight
+                + right.getMemoryCost() * memoryWeight
+                + right.getNetworkCost() * networkWeight;
+
+        return Double.compare(leftCost, rightCost);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java b/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java
new file mode 100644
index 0000000000000..23eefc5c0c235
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/DomainConverter.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.metadata.FunctionRegistry;
+import com.facebook.presto.metadata.Signature;
+import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.type.BigintType;
+import com.facebook.presto.spi.type.BooleanType;
+import com.facebook.presto.spi.type.DecimalType;
+import com.facebook.presto.spi.type.DoubleType;
+import com.facebook.presto.spi.type.IntegerType;
+import com.facebook.presto.spi.type.RealType;
+import com.facebook.presto.spi.type.SmallintType;
+import com.facebook.presto.spi.type.TinyintType;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.spi.type.VarcharType;
+import com.facebook.presto.sql.planner.ExpressionInterpreter;
+import io.airlift.slice.Slice;
+
+import java.util.OptionalDouble;
+
+import static java.util.Collections.singletonList;
+
+/**
+ * A set of functions used in the process of calculating stats.
+ * It is mostly for mapping Type domain to double domain which is used for range comparisons
+ * during stats computations.
+ */
+public class DomainConverter
+{
+    private final Type type;
+    private final FunctionRegistry functionRegistry;
+    private final ConnectorSession session;
+
+    /**
+     * @param type type of the values passed to the conversion methods
+     * @param functionRegistry registry used to resolve the coercion functions
+     * @param session session the coercion functions are invoked with
+     */
+    public DomainConverter(Type type, FunctionRegistry functionRegistry, ConnectorSession session)
+    {
+        this.type = type;
+        this.functionRegistry = functionRegistry;
+        this.session = session;
+    }
+
+    /**
+     * Converts {@code object} to varchar by invoking the coercion from this converter's
+     * type to unbounded varchar.
+     */
+    public Slice castToVarchar(Object object)
+    {
+        Signature castSignature = functionRegistry.getCoercion(type, VarcharType.createUnboundedVarcharType());
+        ScalarFunctionImplementation castImplementation = functionRegistry.getScalarFunctionImplementation(castSignature);
+        return (Slice) ExpressionInterpreter.invoke(session, castImplementation, singletonList(object));
+    }
+
+    /**
+     * Converts {@code object} to a double by invoking the coercion from this converter's
+     * type to DOUBLE.
+     *
+     * @return the converted value, or empty when the type is not one of those accepted
+     *         by {@code isDoubleTranslationSupported}
+     */
+    public OptionalDouble translateToDouble(Object object)
+    {
+        if (!isDoubleTranslationSupported(type)) {
+            return OptionalDouble.empty();
+        }
+        Signature castSignature = functionRegistry.getCoercion(type, DoubleType.DOUBLE);
+        ScalarFunctionImplementation castImplementation = functionRegistry.getScalarFunctionImplementation(castSignature);
+        // NOTE(review): the cast unboxes the invocation result; presumably it is never null here - confirm.
+        return OptionalDouble.of((double) ExpressionInterpreter.invoke(session, castImplementation, singletonList(object)));
+    }
+
+    // Supported: any decimal type, plus DOUBLE, REAL, BIGINT, INTEGER, SMALLINT, TINYINT and BOOLEAN.
+    private boolean isDoubleTranslationSupported(Type type)
+    {
+        return type instanceof DecimalType
+                || DoubleType.DOUBLE.equals(type)
+                || RealType.REAL.equals(type)
+                || BigintType.BIGINT.equals(type)
+                || IntegerType.INTEGER.equals(type)
+                || SmallintType.SMALLINT.equals(type)
+                || TinyintType.TINYINT.equals(type)
+                || BooleanType.BOOLEAN.equals(type);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java
new file mode 100644
index 0000000000000..0bd95ffbbc596
--- /dev/null
+++
b/presto-main/src/main/java/com/facebook/presto/cost/EnforceSingleRowStatsRule.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Stats rule for {@link EnforceSingleRowNode}: the estimate is always exactly one output row.
+ */
+public class EnforceSingleRowStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(EnforceSingleRowNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Returns an estimate with an output row count of 1. None of the arguments is
+     * consulted and no symbol statistics are set on the estimate.
+     */
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        return Optional.of(
+                PlanNodeStatsEstimate.builder()
+                        .setOutputRowCount(1)
+                        .build());
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
new file mode 100644
index 0000000000000..c8d4b20d88ff1
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may
not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom;
+import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS;
+import static com.google.common.base.Predicates.not;
+
+/**
+ * Normalizer that makes the symbols carried by an estimate line up with the output
+ * symbols of the plan node the estimate belongs to.
+ */
+public class EnsureStatsMatchOutput
+        implements ComposableStatsCalculator.Normalizer
+{
+    /**
+     * Returns a copy of {@code estimate} in which every output symbol of {@code node}
+     * without known statistics gets {@code UNKNOWN_STATS}, and statistics for symbols
+     * that {@code node} does not output are removed. {@code types} is not used.
+     */
+    @Override
+    public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types)
+    {
+        PlanNodeStatsEstimate.Builder builder = buildFrom(estimate);
+
+        // Output symbols the estimate knows nothing about: record them as unknown.
+        node.getOutputSymbols().stream()
+                .filter(not(estimate.getSymbolsWithKnownStatistics()::contains))
+                .forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS));
+
+        // Symbols with statistics that the node does not output: drop them.
+        estimate.getSymbolsWithKnownStatistics().stream()
+                .filter(not(node.getOutputSymbols()::contains))
+                .forEach(builder::removeSymbolStatistics);
+
+        return builder.build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java
new file mode 100644
index 0000000000000..f71444cdbcd55
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the
License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStats;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+// WIP
+/**
+ * Stats rule for {@link ExchangeNode}: the estimate is the sum, via {@code addStats},
+ * of the statistics of all sources after their symbols have been renamed to the
+ * exchange's output symbols.
+ */
+public class ExchangeStatsRule
+        implements ComposableStatsCalculator.Rule
+{
+    private static final Pattern PATTERN = Pattern.matchByClass(ExchangeNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Combines the statistics of every source of the exchange.
+     *
+     * @return the combined estimate; never empty
+     * @throws IllegalStateException if the node has no sources
+     */
+    @Override
+    public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types)
+    {
+        ExchangeNode exchangeNode = (ExchangeNode) node;
+        // QUESTION should I check partitioning schema?
+
+        Optional estimate = Optional.empty();
+        for (int i = 0; i < node.getSources().size(); i++) {
+            PlanNode source = node.getSources().get(i);
+            PlanNodeStatsEstimate sourceStats = lookup.getStats(source, session, types);
+
+            // The i-th input symbol list belongs to the i-th source.
+            PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, exchangeNode.getInputs().get(i), exchangeNode.getOutputSymbols());
+
+            if (estimate.isPresent()) {
+                estimate = Optional.of(addStats(estimate.get(), sourceStatsWithMappedSymbols));
+            }
+            else {
+                estimate = Optional.of(sourceStatsWithMappedSymbols);
+            }
+        }
+
+        checkState(estimate.isPresent());
+        return estimate;
+    }
+
+    /**
+     * Builds a new estimate that keeps the output row count of {@code estimate} and
+     * stores the statistics of {@code inputs.get(i)} under {@code outputs.get(i)}.
+     *
+     * @throws IllegalArgumentException if the two lists differ in size
+     */
+    private PlanNodeStatsEstimate mapToOutputSymbols(PlanNodeStatsEstimate estimate, List inputs, List outputs)
+    {
+        checkArgument(inputs.size() == outputs.size(), "Inputs does not match outputs");
+        PlanNodeStatsEstimate.Builder mapped = PlanNodeStatsEstimate.builder()
+                .setOutputRowCount(estimate.getOutputRowCount());
+
+        for (int i = 0; i < inputs.size(); i++) {
+            mapped.addSymbolStatistics(outputs.get(i), estimate.getSymbolStatistics(inputs.get(i)));
+        }
+
+        return mapped.build();
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java
new file mode 100644
index 0000000000000..67b97fca260bd
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java
@@ -0,0 +1,271 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.LiteralInterpreter; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.AstVisitor; +import com.facebook.presto.sql.tree.BetweenPredicate; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InListExpression; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; +import com.facebook.presto.sql.tree.Literal; +import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.SymbolReference; + +import javax.inject.Inject; + +import java.util.Map; + +import static com.facebook.presto.cost.ComparisonStatsCalculator.comparisonSymbolToLiteralStats; +import static com.facebook.presto.cost.ComparisonStatsCalculator.comparisonSymbolToSymbolStats; +import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStats; +import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.differenceInNonRangeStats; +import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.differenceInStats; +import static com.facebook.presto.cost.SymbolStatsEstimate.buildFrom; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN_OR_EQUAL; +import static 
com.google.common.base.Preconditions.checkState; +import static java.lang.Double.NaN; +import static java.lang.Double.isInfinite; +import static java.lang.Double.min; +import static java.lang.String.format; + +public class FilterStatsCalculator +{ + private final Metadata metadata; + + @Inject + public FilterStatsCalculator(Metadata metadata) + { + this.metadata = metadata; + } + + public PlanNodeStatsEstimate filterStats( + PlanNodeStatsEstimate statsEstimate, + Expression predicate, + Session session, + Map types) + { + return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(predicate); + } + + public static PlanNodeStatsEstimate filterStatsForUnknownExpression(PlanNodeStatsEstimate inputStatistics) + { + return inputStatistics.mapOutputRowCount(size -> size * 0.5); + } + + private class FilterExpressionStatsCalculatingVisitor + extends AstVisitor + { + private final PlanNodeStatsEstimate input; + private final Session session; + private final Map types; + + FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, Map types) + { + this.input = input; + this.session = session; + this.types = types; + } + + @Override + protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) + { + return filterForUnknownExpression(); + } + + private PlanNodeStatsEstimate filterForUnknownExpression() + { + return filterStatsForUnknownExpression(input); + } + + private PlanNodeStatsEstimate filterForFalseExpression() + { + PlanNodeStatsEstimate.Builder falseStatsBuilder = PlanNodeStatsEstimate.builder(); + + input.getSymbolsWithKnownStatistics().forEach( + symbol -> + falseStatsBuilder.addSymbolStatistics(symbol, + buildFrom(input.getSymbolStatistics(symbol)) + .setLowValue(NaN) + .setHighValue(NaN) + .setDistinctValuesCount(0.0) + .setNullsFraction(NaN).build())); + + return falseStatsBuilder.setOutputRowCount(0.0).build(); + } + + @Override + protected PlanNodeStatsEstimate 
visitNotExpression(NotExpression node, Void context) + { + return differenceInStats(input, process(node.getValue())); + } + + @Override + protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) + { + PlanNodeStatsEstimate leftStats = process(node.getLeft()); + PlanNodeStatsEstimate rightStats = process(node.getRight()); + PlanNodeStatsEstimate andStats = new FilterExpressionStatsCalculatingVisitor(leftStats, session, types).process(node.getRight()); + + switch (node.getType()) { + case AND: + return andStats; + case OR: + return differenceInNonRangeStats(addStats(leftStats, rightStats), andStats); + default: + checkState(false, format("Unimplemented logical binary operator expression %s", node.getType())); + return PlanNodeStatsEstimate.UNKNOWN_STATS; + } + } + + @Override + protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) + { + if (node.equals(BooleanLiteral.TRUE_LITERAL)) { + return input; + } + else { + return filterForFalseExpression(); + } + } + + @Override + protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) + { + if (node.getValue() instanceof SymbolReference) { + Symbol symbol = Symbol.from(node.getValue()); + SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol); + return input.mapOutputRowCount(rowCount -> rowCount * (1 - symbolStatsEstimate.getNullsFraction())) + .mapSymbolColumnStatistics(symbol, statsEstimate -> statsEstimate.mapNullsFraction(x -> 0.0)); + } + return visitExpression(node, context); + } + + @Override + protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) + { + if (node.getValue() instanceof SymbolReference) { + Symbol symbol = Symbol.from(node.getValue()); + SymbolStatsEstimate symbolStatsEstimate = input.getSymbolStatistics(symbol); + return input.mapOutputRowCount(rowCount -> rowCount * symbolStatsEstimate.getNullsFraction()) + 
.mapSymbolColumnStatistics(symbol, statsEstimate -> + SymbolStatsEstimate.builder().setNullsFraction(1.0) + .setLowValue(NaN) + .setHighValue(NaN) + .setDistinctValuesCount(0.0).build()); + } + return visitExpression(node, context); + } + + @Override + protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) + { + if (!(node.getValue() instanceof SymbolReference) || !(node.getMin() instanceof Literal) || !(node.getMax() instanceof Literal)) { + return visitExpression(node, context); + } + + SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from((SymbolReference) node.getValue())); + Expression leftComparison; + Expression rightComparison; + + // We want to do heuristic cut (infinite range to finite range) ASAP and then do filtering on finite range. + if (isInfinite(valueStats.getLowValue())) { + leftComparison = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()); + rightComparison = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()); + } + else { + rightComparison = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()); + leftComparison = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()); + } + + // we rely on AND being processed left to right + return process(and(leftComparison, rightComparison)); + } + + @Override + protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) + { + if (!(node.getValueList() instanceof InListExpression) || !(node.getValue() instanceof SymbolReference)) { + return visitExpression(node, context); + } + + InListExpression inList = (InListExpression) node.getValueList(); + PlanNodeStatsEstimate statsSum = inList.getValues().stream() + .map(inValue -> process(new ComparisonExpression(EQUAL, node.getValue(), inValue))) + .reduce(filterForFalseExpression(), + PlanNodeStatsEstimateMath::addStats); + + Symbol inValueSymbol = Symbol.from(node.getValue()); + 
SymbolStatsEstimate symbolStat = input.getSymbolStatistics(inValueSymbol); + double notNullValuesBeforeIn = input.getOutputRowCount() * (1 - symbolStat.getNullsFraction()); + + return statsSum.mapOutputRowCount(rowCount -> min(rowCount, notNullValuesBeforeIn)) + .mapSymbolColumnStatistics(inValueSymbol, + symbolStats -> + symbolStats.mapNullsFraction(x -> 0.0) + .mapDistinctValuesCount(distinctValues -> + min(distinctValues, input.getSymbolStatistics(inValueSymbol).getDistinctValuesCount()))); + } + + @Override + protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) + { + if (node.getLeft() instanceof SymbolReference && node.getRight() instanceof SymbolReference) { + return comparisonSymbolToSymbolStats(input, + Symbol.from(node.getLeft()), + Symbol.from(node.getRight()), + node.getType() + ); + } + else if (node.getLeft() instanceof SymbolReference && node.getRight() instanceof Literal) { + Symbol symbol = Symbol.from(node.getLeft()); + return comparisonSymbolToLiteralStats(input, + symbol, + doubleValueFromLiteral(types.get(symbol), (Literal) node.getRight()), + node.getType() + ); + } + else if (node.getLeft() instanceof Literal && node.getRight() instanceof SymbolReference) { + Symbol symbol = Symbol.from(node.getRight()); + return comparisonSymbolToLiteralStats(input, + symbol, + doubleValueFromLiteral(types.get(symbol), (Literal) node.getLeft()), + node.getType().flip() + ); + } + else { + return filterStatsForUnknownExpression(input); + } + } + + private double doubleValueFromLiteral(Type type, Literal literal) + { + Object literalValue = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), literal); + DomainConverter domainConverter = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession()); + return domainConverter.translateToDouble(literalValue).orElse(NaN); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java 
b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java new file mode 100644 index 0000000000000..0044b9f6d67ca --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsRule.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Map; +import java.util.Optional; + +public class FilterStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(FilterNode.class); + + private final FilterStatsCalculator filterStatsCalculator; + + public FilterStatsRule(FilterStatsCalculator filterStatsCalculator) + { + this.filterStatsCalculator = filterStatsCalculator; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + FilterNode filterNode = (FilterNode) node; + PlanNodeStatsEstimate sourceStats = lookup.getStats(filterNode.getSource(), session, types); + return Optional.of(filterStatsCalculator.filterStats(sourceStats, 
filterNode.getPredicate(), session, types)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java new file mode 100644 index 0000000000000..f07ff86e7c473 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.plan.IntersectNode; + +import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.intersect; + +public class IntersectStatsRule + extends AbstractSetOperationStatsRule +{ + private static final Pattern PATTERN = Pattern.matchByClass(IntersectNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + protected PlanNodeStatsEstimate operate(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second) + { + return intersect(first, second); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java new file mode 100644 index 0000000000000..204ab30f3f2bb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java @@ -0,0 +1,242 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.annotations.VisibleForTesting; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.util.MoreMath.rangeMax; +import static com.facebook.presto.util.MoreMath.rangeMin; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.difference; +import static java.lang.Double.NaN; + +public class JoinStatsRule + implements ComposableStatsCalculator.Rule +{ + private static 
final Pattern PATTERN = Pattern.matchByClass(JoinNode.class); + + private final FilterStatsCalculator filterStatsCalculator; + + public JoinStatsRule(FilterStatsCalculator filterStatsCalculator) + { + this.filterStatsCalculator = filterStatsCalculator; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + JoinNode joinNode = (JoinNode) node; + + PlanNodeStatsEstimate leftStats = lookup.getStats(joinNode.getLeft(), session, types); + PlanNodeStatsEstimate rightStats = lookup.getStats(joinNode.getRight(), session, types); + + switch (joinNode.getType()) { + case INNER: + return Optional.of(computeInnerJoinStats(joinNode, leftStats, rightStats, session, types)); + case LEFT: + return Optional.of(computeLeftJoinStats(joinNode, leftStats, rightStats, session, types)); + case RIGHT: + return Optional.of(computeRightJoinStats(joinNode, leftStats, rightStats, session, types)); + case FULL: + return Optional.of(computeFullJoinStats(joinNode, leftStats, rightStats, session, types)); + default: + return Optional.empty(); + } + } + + private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types) + { + PlanNodeStatsEstimate rightAntiJoinStats = calculateAntiJoinStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats); + return addAntiJoinStats(computeLeftJoinStats(node, leftStats, rightStats, session, types), rightAntiJoinStats, getRightJoinSymbols(node)); + } + + private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types) + { + PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, leftStats, rightStats, session, types); + PlanNodeStatsEstimate leftAntiJoinStats = calculateAntiJoinStats(node.getFilter(), node.getCriteria(), 
leftStats, rightStats); + return addAntiJoinStats(innerJoinStats, leftAntiJoinStats, getLeftJoinSymbols(node)); + } + + private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types) + { + PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, leftStats, rightStats, session, types); + PlanNodeStatsEstimate rightAntiJoinStats = calculateAntiJoinStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats); + return addAntiJoinStats(innerJoinStats, rightAntiJoinStats, getRightJoinSymbols(node)); + } + + private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, Session session, Map types) + { + List comparisons = node.getCriteria().stream() + .map(criteria -> new ComparisonExpression(EQUAL, criteria.getLeft().toSymbolReference(), criteria.getRight().toSymbolReference())) + .collect(toImmutableList()); + Expression predicate = combineConjuncts(combineConjuncts(comparisons), node.getFilter().orElse(TRUE_LITERAL)); + PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats); + return filterStatsCalculator.filterStats(crossJoinStats, predicate, session, types); + } + + @VisibleForTesting + PlanNodeStatsEstimate calculateAntiJoinStats(Optional filter, List criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) + { + // TODO: add support for non-equality conditions (e.g: <=, !=, >) + if (filter.isPresent()) { + // non-equi filters are not supported + return UNKNOWN_STATS; + } + + PlanNodeStatsEstimate outputStats = leftStats; + + for (EquiJoinClause clause : criteria) { + SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(clause.getLeft()); + SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(clause.getRight()); + + StatisticRange rightRange = StatisticRange.from(rightColumnStats); + StatisticRange 
antiRange = StatisticRange.from(leftColumnStats) + .subtract(rightRange); + + // TODO: use NDVs from left and right StatisticRange when they are fixed + double leftNDV = leftColumnStats.getDistinctValuesCount(); + double rightNDV = rightColumnStats.getDistinctValuesCount(); + + if (leftNDV > rightNDV) { + double selectedRangeFraction = leftColumnStats.getValuesFraction() * (leftNDV - rightNDV) / leftNDV; + double scaleFactor = selectedRangeFraction + leftColumnStats.getNullsFraction(); + double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor; + outputStats = outputStats.mapSymbolColumnStatistics(clause.getLeft(), columnStats -> + SymbolStatsEstimate.buildFrom(columnStats) + .setLowValue(antiRange.getLow()) + .setHighValue(antiRange.getHigh()) + .setNullsFraction(newLeftNullsFraction) + .setDistinctValuesCount(leftNDV - rightNDV) + .build()); + outputStats = outputStats.mapOutputRowCount(rowCount -> rowCount * scaleFactor); + } + else if (leftNDV <= rightNDV) { + // only null values are left + outputStats = outputStats.mapSymbolColumnStatistics(clause.getLeft(), columnStats -> + SymbolStatsEstimate.buildFrom(columnStats) + .setLowValue(NaN) + .setHighValue(NaN) + .setNullsFraction(1.0) + .setDistinctValuesCount(0.0) + .build()); + outputStats = outputStats.mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction()); + } + else { + // either leftNDV or rightNDV is NaN + return UNKNOWN_STATS; + } + } + + return outputStats; + } + + @VisibleForTesting + PlanNodeStatsEstimate addAntiJoinStats(PlanNodeStatsEstimate joinStats, PlanNodeStatsEstimate antiJoinStats, Set joinSymbols) + { + checkState(joinStats.getSymbolsWithKnownStatistics().containsAll(antiJoinStats.getSymbolsWithKnownStatistics())); + + double joinOutputRowCount = joinStats.getOutputRowCount(); + double antiJoinOutputRowCount = antiJoinStats.getOutputRowCount(); + double totalRowCount = joinOutputRowCount + antiJoinOutputRowCount; + PlanNodeStatsEstimate outputStats 
= joinStats.mapOutputRowCount(rowCount -> rowCount + antiJoinOutputRowCount); + + for (Symbol symbol : antiJoinStats.getSymbolsWithKnownStatistics()) { + outputStats = outputStats.mapSymbolColumnStatistics(symbol, joinColumnStats -> { + SymbolStatsEstimate antiJoinColumnStats = antiJoinStats.getSymbolStatistics(symbol); + // weighted average + double newNullsFraction = (joinColumnStats.getNullsFraction() * joinOutputRowCount + antiJoinColumnStats.getNullsFraction() * antiJoinOutputRowCount) / totalRowCount; + double distinctValues; + if (joinSymbols.contains(symbol)) { + distinctValues = joinColumnStats.getDistinctValuesCount() + antiJoinColumnStats.getDistinctValuesCount(); + } + else { + distinctValues = joinColumnStats.getDistinctValuesCount(); + } + return SymbolStatsEstimate.buildFrom(joinColumnStats) + .setLowValue(rangeMin(joinColumnStats.getLowValue(), antiJoinColumnStats.getLowValue())) + .setHighValue(rangeMax(joinColumnStats.getHighValue(), antiJoinColumnStats.getHighValue())) + .setDistinctValuesCount(distinctValues) + .setNullsFraction(newNullsFraction) + .build(); + }); + } + + // add nulls to columns that don't exist in right stats + for (Symbol symbol : difference(joinStats.getSymbolsWithKnownStatistics(), antiJoinStats.getSymbolsWithKnownStatistics())) { + outputStats = outputStats.mapSymbolColumnStatistics(symbol, joinColumnStats -> + joinColumnStats.mapNullsFraction(nullsFraction -> (nullsFraction * joinOutputRowCount + antiJoinOutputRowCount) / totalRowCount)); + } + + return outputStats; + } + + private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) + { + PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() + .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount()); + + node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol))); + 
node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol))); + + return builder.build(); + } + + private Set getLeftJoinSymbols(JoinNode node) + { + return node.getCriteria().stream() + .map(EquiJoinClause::getLeft) + .collect(toImmutableSet()); + } + + private Set getRightJoinSymbols(JoinNode node) + { + return node.getCriteria().stream() + .map(EquiJoinClause::getRight) + .collect(toImmutableSet()); + } + + private List flippedCriteria(JoinNode node) + { + return node.getCriteria().stream() + .map(criteria -> new JoinNode.EquiJoinClause(criteria.getRight(), criteria.getLeft())) + .collect(toImmutableList()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java new file mode 100644 index 0000000000000..26175be3459b8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/LimitStatsRule.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Map; +import java.util.Optional; + +public class LimitStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + LimitNode limitNode = (LimitNode) node; + + PlanNodeStatsEstimate sourceStats = lookup.getStats(limitNode.getSource(), session, types); + PlanNodeStatsEstimate.Builder limitCost = PlanNodeStatsEstimate.builder(); + // TODO special handling for NaN? + if (sourceStats.getOutputRowCount() < limitNode.getCount()) { + limitCost.setOutputRowCount(sourceStats.getOutputRowCount()); + } + else { + limitCost.setOutputRowCount(limitNode.getCount()); + } + return Optional.of(limitCost.build()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java new file mode 100644 index 0000000000000..4a079e376698c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/OutputStatsRule.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.OutputNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Map; +import java.util.Optional; + +public class OutputStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(OutputNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + OutputNode outputNode = (OutputNode) node; + return Optional.of(lookup.getStats(outputNode.getSource(), session, types)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java deleted file mode 100644 index c30eaa90538f3..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cost; - -import com.facebook.presto.spi.statistics.Estimate; - -import java.util.Objects; -import java.util.function.Function; - -import static com.facebook.presto.spi.statistics.Estimate.unknownValue; -import static java.util.Objects.requireNonNull; - -public class PlanNodeCost -{ - public static final PlanNodeCost UNKNOWN_COST = PlanNodeCost.builder().build(); - - private final Estimate outputRowCount; - private final Estimate outputSizeInBytes; - - private PlanNodeCost(Estimate outputRowCount, Estimate outputSizeInBytes) - { - this.outputRowCount = requireNonNull(outputRowCount, "outputRowCount can not be null"); - this.outputSizeInBytes = requireNonNull(outputSizeInBytes, "outputSizeInBytes can not be null"); - } - - public Estimate getOutputRowCount() - { - return outputRowCount; - } - - public Estimate getOutputSizeInBytes() - { - return outputSizeInBytes; - } - - public PlanNodeCost mapOutputRowCount(Function mappingFunction) - { - return buildFrom(this).setOutputRowCount(outputRowCount.map(mappingFunction)).build(); - } - - public PlanNodeCost mapOutputSizeInBytes(Function mappingFunction) - { - return buildFrom(this).setOutputSizeInBytes(outputRowCount.map(mappingFunction)).build(); - } - - @Override - public String toString() - { - return "PlanNodeCost{outputRowCount=" + outputRowCount + ", outputSizeInBytes=" + outputSizeInBytes + '}'; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - 
PlanNodeCost that = (PlanNodeCost) o; - return Objects.equals(outputRowCount, that.outputRowCount) && - Objects.equals(outputSizeInBytes, that.outputSizeInBytes); - } - - @Override - public int hashCode() - { - return Objects.hash(outputRowCount, outputSizeInBytes); - } - - public static Builder builder() - { - return new Builder(); - } - - public static Builder buildFrom(PlanNodeCost other) - { - return builder().setOutputRowCount(other.getOutputRowCount()) - .setOutputSizeInBytes(other.getOutputSizeInBytes()); - } - - public static final class Builder - { - private Estimate outputRowCount = unknownValue(); - private Estimate outputSizeInBytes = unknownValue(); - - public Builder setOutputRowCount(Estimate outputRowCount) - { - this.outputRowCount = outputRowCount; - return this; - } - - public Builder setOutputSizeInBytes(Estimate outputSizeInBytes) - { - this.outputSizeInBytes = outputSizeInBytes; - return this; - } - - public PlanNodeCost build() - { - return new PlanNodeCost(outputRowCount, outputSizeInBytes); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java new file mode 100644 index 0000000000000..2a1cb29139553 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCostEstimate.java @@ -0,0 +1,179 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; +import static java.lang.Double.isNaN; + +public class PlanNodeCostEstimate +{ + public static final PlanNodeCostEstimate INFINITE_COST = new PlanNodeCostEstimate(POSITIVE_INFINITY, POSITIVE_INFINITY, POSITIVE_INFINITY); + public static final PlanNodeCostEstimate UNKNOWN_COST = new PlanNodeCostEstimate(NaN, NaN, NaN); + public static final PlanNodeCostEstimate ZERO_COST = new PlanNodeCostEstimate(0, 0, 0); + + private final double cpuCost; + private final double memoryCost; + private final double networkCost; + + private PlanNodeCostEstimate(double cpuCost, double memoryCost, double networkCost) + { + checkArgument(isNaN(cpuCost) || cpuCost >= 0, "cpuCost cannot be negative"); + checkArgument(isNaN(memoryCost) || memoryCost >= 0, "memoryCost cannot be negative"); + checkArgument(isNaN(networkCost) || networkCost >= 0, "networkCost cannot be negative"); + this.cpuCost = cpuCost; + this.memoryCost = memoryCost; + this.networkCost = networkCost; + } + + /** + * Returns CPU component of the cost. Unknown value is represented by {@link Double#NaN} + */ + public double getCpuCost() + { + return cpuCost; + } + + /** + * Returns memory component of the cost. Unknown value is represented by {@link Double#NaN} + */ + public double getMemoryCost() + { + return memoryCost; + } + + /** + * Returns network component of the cost. Unknown value is represented by {@link Double#NaN} + */ + public double getNetworkCost() + { + return networkCost; + } + + /** + * Returns true if this cost has unknown components. 
+ */ + public boolean hasUnknownComponents() + { + return isNaN(cpuCost) || isNaN(memoryCost) || isNaN(networkCost); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("cpuCost", cpuCost) + .add("memoryCost", memoryCost) + .add("networkCost", networkCost) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PlanNodeCostEstimate that = (PlanNodeCostEstimate) o; + return Double.compare(that.cpuCost, cpuCost) == 0 && + Double.compare(that.memoryCost, memoryCost) == 0 && + Double.compare(that.networkCost, networkCost) == 0; + } + + @Override + public int hashCode() + { + return Objects.hash(cpuCost, memoryCost, networkCost); + } + + public PlanNodeCostEstimate add(PlanNodeCostEstimate other) + { + return new PlanNodeCostEstimate( + cpuCost + other.cpuCost, + memoryCost + other.memoryCost, + networkCost + other.networkCost); + } + + public static PlanNodeCostEstimate networkCost(double networkCost) + { + return builder().setCpuCost(0).setMemoryCost(0).setNetworkCost(networkCost).build(); + } + + public static PlanNodeCostEstimate cpuCost(double cpuCost) + { + return builder().setCpuCost(cpuCost).setMemoryCost(0).setNetworkCost(0).build(); + } + + public static PlanNodeCostEstimate memoryCost(double memoryCost) + { + return builder().setCpuCost(0).setMemoryCost(memoryCost).setNetworkCost(0).build(); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Optional cpuCost = Optional.empty(); + private Optional memoryCost = Optional.empty(); + private Optional networkCost = Optional.empty(); + + public Builder setFrom(PlanNodeCostEstimate otherStatistics) + { + return setCpuCost(otherStatistics.getCpuCost()) + .setMemoryCost(otherStatistics.getMemoryCost()) + .setNetworkCost(otherStatistics.getNetworkCost()); + } + + public Builder 
setCpuCost(double cpuCost) + { + checkState(!this.cpuCost.isPresent(), "cpuCost already set"); + this.cpuCost = Optional.of(cpuCost); + return this; + } + + public Builder setMemoryCost(double memoryCost) + { + checkState(!this.memoryCost.isPresent(), "memoryCost already set"); + this.memoryCost = Optional.of(memoryCost); + return this; + } + + public Builder setNetworkCost(double networkCost) + { + checkState(!this.networkCost.isPresent(), "networkCost already set"); + this.networkCost = Optional.of(networkCost); + return this; + } + + public PlanNodeCostEstimate build() + { + checkState(cpuCost.isPresent(), "cpuCost not set"); + checkState(memoryCost.isPresent(), "memoryCost not set"); + checkState(networkCost.isPresent(), "networkCost not set"); + return new PlanNodeCostEstimate(cpuCost.get(), memoryCost.get(), networkCost.get()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java new file mode 100644 index 0000000000000..c01e64e8f355d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -0,0 +1,186 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import org.pcollections.HashTreePMap; +import org.pcollections.PMap; + +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Double.NaN; +import static java.lang.Double.isNaN; + +public class PlanNodeStatsEstimate +{ + public static final PlanNodeStatsEstimate UNKNOWN_STATS = PlanNodeStatsEstimate.builder().build(); + public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10; + + private final double outputRowCount; + private final PMap symbolStatistics; + + private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics) + { + checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); + this.outputRowCount = outputRowCount; + this.symbolStatistics = symbolStatistics; + } + + /** + * Returns estimated number of rows. + * Unknown value is represented by {@link Double#NaN} + */ + public double getOutputRowCount() + { + return outputRowCount; + } + + /** + * Returns estimated data size. 
+ * Unknown value is represented by {@link Double#NaN} + */ + public double getOutputSizeInBytes() + { + if (isNaN(outputRowCount)) { + return Double.NaN; + } + double outputSizeInBytes = 0; + for (Map.Entry entry : symbolStatistics.entrySet()) { + outputSizeInBytes += getOutputSizeForSymbol(entry.getValue()); + } + return outputSizeInBytes; + } + + private double getOutputSizeForSymbol(SymbolStatsEstimate symbolStatistics) + { + double averageRowSize = symbolStatistics.getAverageRowSize(); + if (isNaN(averageRowSize)) { + // TODO take into consderation data type of column + return outputRowCount * DEFAULT_DATA_SIZE_PER_COLUMN; + } + return outputRowCount * averageRowSize; + } + + public PlanNodeStatsEstimate mapOutputRowCount(Function mappingFunction) + { + return buildFrom(this).setOutputRowCount(mappingFunction.apply(outputRowCount)).build(); + } + + public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction) + { + return buildFrom(this) + .addSymbolStatistics(symbol, mappingFunction.apply(symbolStatistics.get(symbol))) + .build(); + } + + public SymbolStatsEstimate getSymbolStatistics(Symbol symbol) + { + return symbolStatistics.getOrDefault(symbol, SymbolStatsEstimate.UNKNOWN_STATS); + } + + public Set getSymbolsWithKnownStatistics() + { + return symbolStatistics.keySet(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("outputRowCount", outputRowCount) + .add("symbolStatistics", symbolStatistics) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PlanNodeStatsEstimate that = (PlanNodeStatsEstimate) o; + return Double.compare(that.outputRowCount, outputRowCount) == 0 && + Objects.equals(symbolStatistics, that.symbolStatistics); + } + + @Override + public int hashCode() + { + return Objects.hash(outputRowCount, symbolStatistics); + } + + public static Builder 
builder() + { + return new Builder(); + } + + public static Builder buildFrom(PlanNodeStatsEstimate other) + { + return new Builder(other.getOutputRowCount(), other.symbolStatistics); + } + + public static final class Builder + { + private double outputRowCount; + private PMap symbolStatistics; + + public Builder() + { + this(NaN, HashTreePMap.empty()); + } + + private Builder(double outputRowCount, PMap symbolStatistics) + { + this.outputRowCount = outputRowCount; + this.symbolStatistics = symbolStatistics; + } + + public Builder setOutputRowCount(double outputRowCount) + { + this.outputRowCount = outputRowCount; + return this; + } + + public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) + { + symbolStatistics = symbolStatistics.plus(symbol, statistics); + return this; + } + + public Builder addSymbolStatistics(Map symbolStatistics) + { + this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics); + return this; + } + + public Builder removeSymbolStatistics(Symbol symbol) + { + symbolStatistics = symbolStatistics.minus(symbol); + return this; + } + + public PlanNodeStatsEstimate build() + { + return new PlanNodeStatsEstimate(outputRowCount, symbolStatistics); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java new file mode 100644 index 0000000000000..99d328a6eb464 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; + +import java.util.HashSet; +import java.util.stream.Stream; + +import static com.facebook.presto.cost.AggregationStatsRule.groupBy; +import static com.facebook.presto.util.MoreMath.min; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Collections.emptyMap; + +public class PlanNodeStatsEstimateMath +{ + private PlanNodeStatsEstimateMath() + { + } + + private interface SubtractRangeStrategy + { + StatisticRange range(StatisticRange leftRange, StatisticRange rightRange); + } + + public static PlanNodeStatsEstimate differenceInStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + { + return differenceInStatsWithRangeStrategy(left, right, StatisticRange::subtract); + } + + public static PlanNodeStatsEstimate differenceInNonRangeStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + { + return differenceInStatsWithRangeStrategy(left, right, ((leftRange, rightRange) -> leftRange)); + } + + private static PlanNodeStatsEstimate differenceInStatsWithRangeStrategy(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right, SubtractRangeStrategy strategy) + { + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); + double newRowCount = left.getOutputRowCount() - right.getOutputRowCount(); + + Stream.concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream()) + .forEach(symbol -> { + statsBuilder.addSymbolStatistics(symbol, + 
subtractColumnStats(left.getSymbolStatistics(symbol), + left.getOutputRowCount(), + right.getSymbolStatistics(symbol), + right.getOutputRowCount(), + newRowCount, + strategy)); + }); + + return statsBuilder.setOutputRowCount(newRowCount).build(); + } + + private static SymbolStatsEstimate subtractColumnStats(SymbolStatsEstimate leftStats, + double leftRowCount, + SymbolStatsEstimate rightStats, + double rightRowCount, + double newRowCount, + SubtractRangeStrategy strategy) + { + StatisticRange leftRange = StatisticRange.from(leftStats); + StatisticRange rightRange = StatisticRange.from(rightStats); + + double nullsCountLeft = leftStats.getNullsFraction() * leftRowCount; + double nullsCountRight = rightStats.getNullsFraction() * rightRowCount; + double totalSizeLeft = leftRowCount * leftStats.getAverageRowSize(); + double totalSizeRight = rightRowCount * rightStats.getAverageRowSize(); + StatisticRange range = strategy.range(leftRange, rightRange); + + return SymbolStatsEstimate.builder() + .setDistinctValuesCount(leftStats.getDistinctValuesCount() - rightStats.getDistinctValuesCount()) + .setHighValue(range.getHigh()) + .setLowValue(range.getLow()) + .setAverageRowSize((totalSizeLeft - totalSizeRight) / newRowCount) + .setNullsFraction((nullsCountLeft - nullsCountRight) / newRowCount) + .build(); + } + + public static PlanNodeStatsEstimate addStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + { + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); + double newRowCount = left.getOutputRowCount() + right.getOutputRowCount(); + + Stream.concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream()) + .forEach(symbol -> { + statsBuilder.addSymbolStatistics(symbol, + addColumnStats(left.getSymbolStatistics(symbol), + left.getOutputRowCount(), + right.getSymbolStatistics(symbol), + right.getOutputRowCount(), newRowCount)); + }); + + return statsBuilder.setOutputRowCount(newRowCount).build(); + 
} + + private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, double leftRows, SymbolStatsEstimate rightStats, double rightRows, double newRowCount) + { + StatisticRange leftRange = StatisticRange.from(leftStats); + StatisticRange rightRange = StatisticRange.from(rightStats); + + StatisticRange sum = leftRange.add(rightRange); + double nullsCountRight = rightStats.getNullsFraction() * rightRows; + double nullsCountLeft = leftStats.getNullsFraction() * leftRows; + double totalSizeLeft = leftRows * leftStats.getAverageRowSize(); + double totalSizeRight = rightRows * rightStats.getAverageRowSize(); + + return SymbolStatsEstimate.builder() + .setStatisticsRange(sum) + .setAverageRowSize((totalSizeLeft + totalSizeRight) / newRowCount) // FIXME, weights to average. left and right should be equal in most cases anyway + .setNullsFraction((nullsCountLeft + nullsCountRight) / newRowCount) + .build(); + } + + public static PlanNodeStatsEstimate intersect(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + { + checkArgument(new HashSet<>(left.getSymbolsWithKnownStatistics()).equals(new HashSet<>(right.getSymbolsWithKnownStatistics()))); + + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); + + for (Symbol symbol : left.getSymbolsWithKnownStatistics()) { + SymbolStatsEstimate leftSymbolStats = left.getSymbolStatistics(symbol); + SymbolStatsEstimate rightSymbolStats = right.getSymbolStatistics(symbol); + StatisticRange leftRange = StatisticRange.from(leftSymbolStats); + StatisticRange rightRange = StatisticRange.from(rightSymbolStats); + StatisticRange intersection = leftRange.intersect(rightRange); + + statsBuilder.addSymbolStatistics( + symbol, + SymbolStatsEstimate.builder() + .setStatisticsRange(intersection) + // it does matter how many nulls are preserved, the intersting point is the fact if there are nulls both sides or not + // this will be normalized later by groupBy + 
.setNullsFraction(min(leftSymbolStats.getNullsFraction(), rightSymbolStats.getNullsFraction())) + .build()); + } + + PlanNodeStatsEstimate intermediateResult = statsBuilder.build(); + return groupBy(intermediateResult, intermediateResult.getSymbolsWithKnownStatistics(), emptyMap()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java new file mode 100644 index 0000000000000..8e7e59f898a83 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Map; +import java.util.Optional; + +public class ProjectStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); + + private final ScalarStatsCalculator scalarStatsCalculator; + + public ProjectStatsRule(ScalarStatsCalculator scalarStatsCalculator) + { + this.scalarStatsCalculator = scalarStatsCalculator; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + ProjectNode projectNode = (ProjectNode) node; + + PlanNodeStatsEstimate sourceStats = lookup.getStats(projectNode.getSource(), session, types); + // TODO handle output size in bytes + PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(sourceStats.getOutputRowCount()); + + for (Map.Entry entry : projectNode.getAssignments().entrySet()) { + calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types)); + } + return Optional.of(calculatedStats.build()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java new file mode 100644 index 0000000000000..930c169e42619 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.analyzer.ExpressionAnalyzer; +import com.facebook.presto.sql.analyzer.Scope; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.AstVisitor; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.CoalesceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Literal; +import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.Map; +import java.util.OptionalDouble; + +import static com.facebook.presto.sql.planner.LiteralInterpreter.evaluate; +import static com.facebook.presto.util.MoreMath.max; +import static com.facebook.presto.util.MoreMath.min; +import static java.lang.Double.isFinite; +import static java.lang.Double.isNaN; +import static java.lang.Math.abs; +import static java.util.Objects.requireNonNull; + +public class ScalarStatsCalculator +{ + private 
final Metadata metadata; + + @Inject + public ScalarStatsCalculator(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata can not be null"); + } + + public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, Map types) + { + return new Visitor(inputStatistics, session).process(scalarExpression); + } + + private class Visitor + extends AstVisitor + { + private final PlanNodeStatsEstimate input; + private final Session session; + + Visitor(PlanNodeStatsEstimate input, Session session) + { + this.input = input; + this.session = session; + } + + @Override + protected SymbolStatsEstimate visitNode(Node node, Void context) + { + return SymbolStatsEstimate.UNKNOWN_STATS; + } + + @Override + protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) + { + return input.getSymbolStatistics(Symbol.from(node)); + } + + @Override + protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) + { + return SymbolStatsEstimate.builder() + .setDistinctValuesCount(0) + .setNullsFraction(1) + .build(); + } + + @Override + protected SymbolStatsEstimate visitLiteral(Literal node, Void context) + { + Object value = evaluate(metadata, session.toConnectorSession(), node); + Type type = ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableList.of()).analyze(node, Scope.create()); + OptionalDouble doubleValue = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession()).translateToDouble(value); + SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder() + .setNullsFraction(0) + .setDistinctValuesCount(1); + + if (doubleValue.isPresent()) { + estimate.setLowValue(doubleValue.getAsDouble()); + estimate.setHighValue(doubleValue.getAsDouble()); + } + return estimate.build(); + } + + protected SymbolStatsEstimate visitCast(Cast node, Void context) + { + SymbolStatsEstimate sourceStats = 
process(node.getExpression()); + TypeSignature targetType = TypeSignature.parseTypeSignature(node.getType()); + + // todo - make this general postprocessing rule. + double distinctValuesCount = sourceStats.getDistinctValuesCount(); + double lowValue = sourceStats.getLowValue(); + double highValue = sourceStats.getHighValue(); + + if (isIntegralType(targetType)) { + // todo handle low/high value changes if range gets narrower due to cast (e.g. BIGINT -> SMALLINT) + if (isFinite(lowValue)) { + lowValue = Math.round(lowValue); + } + if (isFinite(highValue)) { + highValue = Math.round(highValue); + } + if (isFinite(lowValue) && isFinite(highValue)) { + double integersInRange = highValue - lowValue + 1; + if (!isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) { + distinctValuesCount = integersInRange; + } + } + } + + return SymbolStatsEstimate.builder() + .setNullsFraction(sourceStats.getNullsFraction()) + .setLowValue(lowValue) + .setHighValue(highValue) + .setDistinctValuesCount(distinctValuesCount) + .build(); + } + + private boolean isIntegralType(TypeSignature targetType) + { + switch (targetType.getBase()) { + case StandardTypes.BIGINT: + case StandardTypes.INTEGER: + case StandardTypes.SMALLINT: + case StandardTypes.TINYINT: + return true; + case StandardTypes.DECIMAL: + DecimalType decimalType = (DecimalType) metadata.getType(targetType); + return decimalType.getScale() == 0; + default: + return false; + } + } + + @Override + protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + { + requireNonNull(node, "node is null"); + SymbolStatsEstimate left = process(node.getLeft()); + SymbolStatsEstimate right = process(node.getRight()); + + SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder() + .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) + .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * 
right.getNullsFraction()) + // TODO make a generic rule which cap NDV for all kind of expressions to rows count and range length (if finite) + .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount())); + + double leftLow = left.getLowValue(); + double leftHigh = left.getHighValue(); + double rightLow = right.getLowValue(); + double rightHigh = right.getHighValue(); + if (node.getType() == ArithmeticBinaryExpression.Type.DIVIDE && rightLow < 0 && rightHigh > 0) { + result.setLowValue(Double.NEGATIVE_INFINITY) + .setHighValue(Double.POSITIVE_INFINITY); + } + else if (node.getType() == ArithmeticBinaryExpression.Type.MODULUS) { + double maxDivisor = max(abs(rightLow), abs(rightHigh)); + if (leftHigh <= 0) { + result.setLowValue(max(-maxDivisor, leftLow)) + .setHighValue(0); + } + else if (leftLow >= 0) { + result.setLowValue(0) + .setHighValue(min(maxDivisor, leftHigh)); + } + else { + result.setLowValue(max(-maxDivisor, leftLow)) + .setHighValue(min(maxDivisor, leftHigh)); + } + } + else { + double v1 = operate(node.getType(), leftLow, rightLow); + double v2 = operate(node.getType(), leftLow, rightHigh); + double v3 = operate(node.getType(), leftHigh, rightLow); + double v4 = operate(node.getType(), leftHigh, rightHigh); + double lowValue = min(v1, v2, v3, v4); + double highValue = max(v1, v2, v3, v4); + + result.setLowValue(lowValue) + .setHighValue(highValue); + } + + return result.build(); + } + + private double operate(ArithmeticBinaryExpression.Type type, double left, double right) + { + switch (type) { + case ADD: + return left + right; + case SUBTRACT: + return left - right; + case MULTIPLY: + return left * right; + case DIVIDE: + return left / right; + case MODULUS: + return left % right; + default: + throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Type: " + type); + } + } + + @Override + protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void 
context) + { + requireNonNull(node, "node is null"); + SymbolStatsEstimate result = null; + for (Expression operand : node.getOperands()) { + SymbolStatsEstimate operandEstimates = process(operand); + if (result != null) { + result = estimateCoalesce(result, operandEstimates); + } + else { + result = operandEstimates; + } + } + return requireNonNull(result, "result is null"); + } + + private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) + { + // Question to reviewer: do you have a method to check if fraction is empty or saturated? + if (left.getNullsFraction() == 0) { + return left; + } + else if (left.getNullsFraction() == 1.0) { + return right; + } + else { + return SymbolStatsEstimate.builder() + .setLowValue(min(left.getLowValue(), right.getLowValue())) + .setHighValue(max(left.getHighValue(), right.getLowValue())) + .setDistinctValuesCount(left.getDistinctValuesCount() + + min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction())) + .setNullsFraction(left.getNullsFraction() * right.getNullsFraction()) + // TODO check if dataSize estimatation method is correct + .setAverageRowSize(max(left.getAverageRowSize(), right.getAverageRowSize())) + .build(); + } + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java new file mode 100644 index 0000000000000..b43a918af86ee --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/SelectingStatsCalculator.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.inject.BindingAnnotation; + +import javax.inject.Inject; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Map; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static java.util.Objects.requireNonNull; + +public class SelectingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator oldStatsCalculator; + private final StatsCalculator newStatsCalculator; + + @Inject + public SelectingStatsCalculator(@Old StatsCalculator oldStatsCalculator, @New StatsCalculator newStatsCalculator) + { + this.oldStatsCalculator = requireNonNull(oldStatsCalculator, "oldStatsCalculator can not be null"); + this.newStatsCalculator = requireNonNull(newStatsCalculator, "newStatsCalculator can not be null"); + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map types) + { + if (SystemSessionProperties.isUseNewStatsCalculator(session)) { + return newStatsCalculator.calculateStats(planNode, lookup, session, types); + } + else { + return 
oldStatsCalculator.calculateStats(planNode, lookup, session, types); + } + } + + @BindingAnnotation + @Target({PARAMETER, METHOD}) + @Retention(RUNTIME) + public @interface Old {} + + @BindingAnnotation + @Target({PARAMETER, METHOD}) + @Retention(RUNTIME) + public @interface New {} +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java new file mode 100644 index 0000000000000..9c53eb8d284a1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkState; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; +import static java.lang.Double.isInfinite; +import static java.lang.Double.isNaN; +import static java.lang.Math.max; +import static java.lang.Math.min; + +public class StatisticRange +{ + private static final double INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.25; + private static final double INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.5; + + private final double low; + private final double high; + private final double distinctValues; + + public StatisticRange(double low, double high, double distinctValues) + { + checkState(low <= high || (isNaN(low) && isNaN(high)), "Low must be smaller or equal to high or range must be empty (NaN, NaN)"); + checkState(distinctValues >= 0 || isNaN(distinctValues), "Distinct values count cannot be negative"); + this.low = low; + this.high = high; + this.distinctValues = distinctValues; + } + + public static StatisticRange empty() + { + return new StatisticRange(NaN, NaN, 0); + } + + public static StatisticRange from(SymbolStatsEstimate estimate) + { + return new StatisticRange(estimate.getLowValue(), estimate.getHighValue(), estimate.getDistinctValuesCount()); + } + + public double getLow() + { + return low; + } + + public double getHigh() + { + return high; + } + + public double getDistinctValuesCount() + { + return distinctValues; + } + + public double length() + { + return high - low; + } + + public boolean isEmpty() + { + return isNaN(low) && isNaN(high); + } + + public double overlapPercentWith(StatisticRange other) + { + if (this.equals(other)) { + return 1.0; + } + + if (isEmpty() || other.isEmpty()) { + return 0.0; // zero is better than NaN as it will behave properly for calculating row count + } + + double lengthOfIntersect = min(high, other.high) - max(low, other.low); 
+ if (isInfinite(lengthOfIntersect)) { + return INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR; + } + if (lengthOfIntersect == 0) { + return 1 / distinctValues; + } + if (lengthOfIntersect < 0) { + return 0; + } + if (isInfinite(length()) && isFinite(lengthOfIntersect)) { + return INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR; + } + if (lengthOfIntersect > 0) { + return lengthOfIntersect / length(); + } + + return NaN; + } + + private double overlappingDistinctValues(StatisticRange other) + { + double overlapPercentOfLeft = overlapPercentWith(other); + double overlapPercentOfRight = other.overlapPercentWith(this); + double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues; + double overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues; + + return maxExcludeNaN(overlapDistinctValuesLeft, overlapDistinctValuesRight); + } + + public StatisticRange intersect(StatisticRange other) + { + double newLow = max(low, other.low); + double newHigh = min(high, other.high); + if (newLow <= newHigh) { + return new StatisticRange(newLow, newHigh, overlappingDistinctValues(other)); + } + return empty(); + } + + public StatisticRange add(StatisticRange other) + { + double newDistinctValues = distinctValues + other.distinctValues; + return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues); + } + + public StatisticRange subtract(StatisticRange rightRange) + { + StatisticRange intersect = intersect(rightRange); + double newLow = getLow(); + double newHigh = getHigh(); + if (intersect.getLow() == getLow()) { + newLow = intersect.getHigh(); + } + if (intersect.getHigh() == getHigh()) { + newHigh = intersect.getLow(); + } + if (newLow > newHigh) { + newLow = NaN; + newHigh = NaN; + } + + return new StatisticRange(newLow, newHigh, max(getDistinctValuesCount(), rightRange.getDistinctValuesCount()) - intersect.getDistinctValuesCount()); + } + + private static double 
minExcludeNaN(double v1, double v2) + { + if (isNaN(v1)) { + return v2; + } + if (isNaN(v2)) { + return v1; + } + return min(v1, v2); + } + + private static double maxExcludeNaN(double v1, double v2) + { + if (isNaN(v1)) { + return v2; + } + if (isNaN(v2)) { + return v1; + } + return max(v1, v2); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (!(obj instanceof StatisticRange)) { + return false; + } + StatisticRange other = (StatisticRange) obj; + return low == other.low && + high == other.high && + distinctValues == other.distinctValues; + } + + @Override + public int hashCode() + { + return Objects.hash(low, high, distinctValues); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java new file mode 100644 index 0000000000000..894b94887121d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsCalculator.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Map; + +/** + * Interface of cost calculator. 
+ * + * Obtains estimated stats for output produced by given PlanNode + * Implementation may use lookup to compute needed traits for self/source nodes. + */ +public interface StatsCalculator +{ + PlanNodeStatsEstimate calculateStats( + PlanNode node, + Lookup lookup, + Session session, + Map types); +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java new file mode 100644 index 0000000000000..d6b62a95567bb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/SymbolStatsEstimate.java @@ -0,0 +1,195 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import java.util.Objects; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Double.NaN; +import static java.lang.Double.isNaN; + +public class SymbolStatsEstimate +{ + public static final SymbolStatsEstimate UNKNOWN_STATS = SymbolStatsEstimate.builder().build(); + + // for now we support only types which map to real domain naturally and keep low/high value as double in stats. 
+ private final double lowValue; + private final double highValue; + private final double nullsFraction; + private final double averageRowSize; + private final double distinctValuesCount; + + public SymbolStatsEstimate(double lowValue, double highValue, double nullsFraction, double averageRowSize, double distinctValuesCount) + { + checkArgument(lowValue <= highValue || (isNaN(lowValue) && isNaN(highValue)), "lowValue must be less than or equal to highValue or both values have to be NaN"); + this.lowValue = lowValue; + this.highValue = highValue; + this.nullsFraction = nullsFraction; + this.averageRowSize = averageRowSize; + this.distinctValuesCount = distinctValuesCount; + } + + public double getLowValue() + { + return lowValue; + } + + public double getHighValue() + { + return highValue; + } + + public boolean hasEmptyRange() + { + return isNaN(lowValue) && isNaN(highValue); + } + + public double getNullsFraction() + { + if (hasEmptyRange()) { + return 1.0; + } + return nullsFraction; + } + + public StatisticRange statisticRange() + { + return new StatisticRange(lowValue, highValue, distinctValuesCount); + } + + public double getValuesFraction() + { + return 1.0 - nullsFraction; + } + + public double getAverageRowSize() + { + return averageRowSize; + } + + public double getDistinctValuesCount() + { + return distinctValuesCount; + } + + public SymbolStatsEstimate mapLowValue(Function mappingFunction) + { + return buildFrom(this).setLowValue(mappingFunction.apply(lowValue)).build(); + } + + public SymbolStatsEstimate mapHighValue(Function mappingFunction) + { + return buildFrom(this).setHighValue(mappingFunction.apply(highValue)).build(); + } + + public SymbolStatsEstimate mapNullsFraction(Function mappingFunction) + { + return buildFrom(this).setNullsFraction(mappingFunction.apply(nullsFraction)).build(); + } + + public SymbolStatsEstimate mapDistinctValuesCount(Function mappingFunction) + { + return 
buildFrom(this).setDistinctValuesCount(mappingFunction.apply(distinctValuesCount)).build(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SymbolStatsEstimate that = (SymbolStatsEstimate) o; + return Double.compare(that.nullsFraction, nullsFraction) == 0 && + Double.compare(that.averageRowSize, averageRowSize) == 0 && + Double.compare(that.distinctValuesCount, distinctValuesCount) == 0 && + Objects.equals(lowValue, that.lowValue) && + Objects.equals(highValue, that.highValue); + } + + @Override + public int hashCode() + { + return Objects.hash(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount); + } + + public static Builder builder() + { + return new Builder(); + } + + public static Builder buildFrom(SymbolStatsEstimate other) + { + return builder() + .setLowValue(other.getLowValue()) + .setHighValue(other.getHighValue()) + .setNullsFraction(other.getNullsFraction()) + .setAverageRowSize(other.getAverageRowSize()) + .setDistinctValuesCount(other.getDistinctValuesCount()); + } + + public static final class Builder + { + private double lowValue = Double.NEGATIVE_INFINITY; + private double highValue = Double.POSITIVE_INFINITY; + private double nullsFraction = NaN; + private double averageRowSize = NaN; + private double distinctValuesCount = NaN; + + public Builder setStatisticsRange(StatisticRange range) + { + return setLowValue(range.getLow()) + .setHighValue(range.getHigh()) + .setDistinctValuesCount(range.getDistinctValuesCount()); + } + + public Builder setLowValue(double lowValue) + { + this.lowValue = lowValue; + return this; + } + + public Builder setHighValue(double highValue) + { + this.highValue = highValue; + return this; + } + + public Builder setNullsFraction(double nullsFraction) + { + this.nullsFraction = nullsFraction; + return this; + } + + public Builder setAverageRowSize(double averageRowSize) + { + 
this.averageRowSize = averageRowSize; + return this; + } + + public Builder setDistinctValuesCount(double distinctValuesCount) + { + this.distinctValuesCount = distinctValuesCount; + return this; + } + + public SymbolStatsEstimate build() + { + return new SymbolStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java new file mode 100644 index 0000000000000..713e9c757a308 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.DomainTranslator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.Expression; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; + +import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; +import static java.util.Objects.requireNonNull; + +public class TableScanStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(TableScanNode.class); + + private final Metadata metadata; + + public TableScanStatsRule(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata can not be null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + TableScanNode tableScanNode = (TableScanNode) node; + + Constraint constraint = getConstraint(tableScanNode, BooleanLiteral.TRUE_LITERAL, session, types); + + TableStatistics tableStatistics = metadata.getTableStatistics(session, tableScanNode.getTable(), constraint); + Map 
outputSymbolStats = new HashMap<>(); + + for (Map.Entry entry : tableScanNode.getAssignments().entrySet()) { + Symbol symbol = entry.getKey(); + Type symbolType = types.get(symbol); + Optional columnStatistics = Optional.ofNullable(tableStatistics.getColumnStatistics().get(entry.getValue())); + outputSymbolStats.put(symbol, columnStatistics.map(statistics -> toSymbolStatistics(tableStatistics, statistics, session, symbolType)).orElse(UNKNOWN_STATS)); + } + + return Optional.of(PlanNodeStatsEstimate.builder() + .setOutputRowCount(tableStatistics.getRowCount().getValue()) + .addSymbolStatistics(outputSymbolStats) + .build()); + } + + private SymbolStatsEstimate toSymbolStatistics(TableStatistics tableStatistics, ColumnStatistics columnStatistics, Session session, Type type) + { + DomainConverter domainConverter = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession()); + + return SymbolStatsEstimate.builder() + .setLowValue(asDouble(columnStatistics.getOnlyRangeColumnStatistics().getLowValue(), domainConverter).orElse(NEGATIVE_INFINITY)) + .setHighValue(asDouble(columnStatistics.getOnlyRangeColumnStatistics().getHighValue(), domainConverter).orElse(POSITIVE_INFINITY)) + .setNullsFraction( + columnStatistics.getNullsFraction().getValue() + / (columnStatistics.getNullsFraction().getValue() + columnStatistics.getOnlyRangeColumnStatistics().getFraction().getValue())) + .setDistinctValuesCount(columnStatistics.getOnlyRangeColumnStatistics().getDistinctValuesCount().getValue()) + .setAverageRowSize(columnStatistics.getOnlyRangeColumnStatistics().getDataSize().getValue() / tableStatistics.getRowCount().getValue()) + .build(); + } + + private OptionalDouble asDouble(Optional optionalValue, DomainConverter domainConverter) + { + return optionalValue.map(domainConverter::translateToDouble).orElseGet(OptionalDouble::empty); + } + + private Constraint getConstraint(TableScanNode node, Expression predicate, Session session, Map types) + { + 
DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( + metadata, + session, + predicate, + types); + + TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() + .transform(node.getAssignments()::get) + .intersect(node.getCurrentConstraint()); + + return new Constraint<>(simplifiedConstraint, bindings -> true); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java new file mode 100644 index 0000000000000..a34e1dc5cb9d1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.plan.UnionNode; + +import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStats; + +public class UnionStatsRule + extends AbstractSetOperationStatsRule +{ + private static final Pattern PATTERN = Pattern.matchByClass(UnionNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + protected PlanNodeStatsEstimate operate(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second) + { + return addStats(first, second); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java new file mode 100644 index 0000000000000..a56f92a8fad83 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + +import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.stream.Collectors.toList; + +public class ValuesStatsRule + implements ComposableStatsCalculator.Rule +{ + private static final Pattern PATTERN = Pattern.matchByClass(ValuesNode.class); + + private final Metadata metadata; + + public ValuesStatsRule(Metadata metadata) + { + this.metadata = metadata; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional calculate(PlanNode node, Lookup lookup, Session session, Map types) + { + ValuesNode valuesNode = (ValuesNode) node; + + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); + statsBuilder.setOutputRowCount(valuesNode.getRows().size()); + + for (int symbolId = 0; symbolId < valuesNode.getOutputSymbols().size(); ++symbolId) { + Symbol symbol = valuesNode.getOutputSymbols().get(symbolId); + List symbolValues = getSymbolValues(valuesNode, symbolId, session, types.get(symbol)); + statsBuilder.addSymbolStatistics(symbol, buildSymbolStatistics(symbolValues, session, types.get(symbol))); + } + + return 
Optional.of(statsBuilder.build()); + } + + private List getSymbolValues(ValuesNode valuesNode, int symbolId, Session session, Type symbolType) + { + if (UNKNOWN.equals(symbolType)) { + // special casing for UNKNOWN as evaluateConstantExpression does not handle that + return IntStream.range(0, valuesNode.getRows().size()) + .mapToObj(rowId -> null) + .collect(toList()); + } + return valuesNode.getRows().stream() + .map(row -> row.get(symbolId)) + .map(expression -> evaluateConstantExpression(expression, symbolType, metadata, session, ImmutableList.of())) + .collect(toList()); + } + + private SymbolStatsEstimate buildSymbolStatistics(List values, Session session, Type type) + { + DomainConverter domainConverter = new DomainConverter(type, metadata.getFunctionRegistry(), session.toConnectorSession()); + + List nonNullValues = values.stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + + if (nonNullValues.isEmpty()) { + return SymbolStatsEstimate.builder() + .setLowValue(Double.NaN) + .setHighValue(Double.NaN) + .setNullsFraction(values.isEmpty() ? 
0.0 : 1.0) + .setDistinctValuesCount(0.0) + .build(); + } + else { + double[] valuesAsDoubles = nonNullValues.stream() + .map(domainConverter::translateToDouble) + .filter(OptionalDouble::isPresent) + .mapToDouble(OptionalDouble::getAsDouble) + .toArray(); + + double lowValue = DoubleStream.of(valuesAsDoubles).min().orElse(Double.NEGATIVE_INFINITY); + double highValue = DoubleStream.of(valuesAsDoubles).max().orElse(Double.POSITIVE_INFINITY); + double valuesCount = values.size(); + double nonNullValuesCount = nonNullValues.size(); + long distinctValuesCount = nonNullValues.stream().distinct().count(); + + return SymbolStatsEstimate.builder() + .setNullsFraction((valuesCount - nonNullValuesCount) / valuesCount) + .setLowValue(lowValue) + .setHighValue(highValue) + .setDistinctValuesCount(distinctValuesCount) + .build(); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 87067b1fdb609..39f4c57413e18 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -18,7 +18,6 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.execution.scheduler.ExecutionPolicy; import com.facebook.presto.execution.scheduler.NodeScheduler; @@ -51,6 +50,7 @@ import com.facebook.presto.sql.planner.PlanOptimizers; import com.facebook.presto.sql.planner.StageExecutionPlan; import com.facebook.presto.sql.planner.SubPlan; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.tree.Explain; import 
com.facebook.presto.sql.tree.Expression; @@ -105,7 +105,7 @@ public final class SqlQueryExecution private final FailureDetector failureDetector; private final QueryExplainer queryExplainer; - private final CostCalculator costCalculator; + private final Lookup lookup; private final AtomicReference queryScheduler = new AtomicReference<>(); private final AtomicReference queryPlan = new AtomicReference<>(); private final NodeTaskMap nodeTaskMap; @@ -125,7 +125,7 @@ public SqlQueryExecution(QueryId queryId, SplitManager splitManager, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, - CostCalculator costCalculator, + Lookup lookup, List planOptimizers, RemoteTaskFactory remoteTaskFactory, LocationFactory locationFactory, @@ -146,7 +146,7 @@ public SqlQueryExecution(QueryId queryId, this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.locationFactory = requireNonNull(locationFactory, "locationFactory is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -308,7 +308,7 @@ private PlanRoot doAnalyzeQuery() // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, costCalculator); + LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, lookup); Plan plan = logicalPlanner.plan(analysis); queryPlan.set(plan); @@ -571,7 +571,7 @@ public static class SqlQueryExecutionFactory 
private final SplitManager splitManager; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; - private final CostCalculator costCalculator; + private final Lookup lookup; private final List planOptimizers; private final RemoteTaskFactory remoteTaskFactory; private final TransactionManager transactionManager; @@ -592,7 +592,7 @@ public static class SqlQueryExecutionFactory SplitManager splitManager, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, - CostCalculator costCalculator, + Lookup lookup, PlanOptimizers planOptimizers, RemoteTaskFactory remoteTaskFactory, TransactionManager transactionManager, @@ -623,7 +623,7 @@ public static class SqlQueryExecutionFactory this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.executionPolicies = requireNonNull(executionPolicies, "schedulerPolicies is null"); - this.costCalculator = requireNonNull(costCalculator, "cost calculator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); this.planOptimizers = planOptimizers.get(); } @@ -647,7 +647,7 @@ public SqlQueryExecution createQueryExecution(QueryId queryId, String query, Ses splitManager, nodePartitioningManager, nodeScheduler, - costCalculator, + lookup, planOptimizers, remoteTaskFactory, locationFactory, diff --git a/presto-main/src/main/java/com/facebook/presto/matching/Matchable.java b/presto-main/src/main/java/com/facebook/presto/matching/Matchable.java new file mode 100644 index 0000000000000..a25c0b7374b38 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/matching/Matchable.java @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.matching; + +public interface Matchable +{ + Pattern getPattern(); +} diff --git a/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java b/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java new file mode 100644 index 0000000000000..52347773844fd --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.matching; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Streams; +import com.google.common.collect.TreeTraverser; + +import java.util.Iterator; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.facebook.presto.util.MoreLists.asList; +import static java.util.Arrays.asList; + +public class MatchingEngine +{ + private final ListMultimap matchablesByClass; + + private MatchingEngine(ListMultimap matchablesByClass) + { + this.matchablesByClass = ImmutableListMultimap.copyOf(matchablesByClass); + } + + public Stream getCandidates(Object object) + { + return Streams.stream(ancestors(object.getClass())) + .flatMap(clazz -> matchablesByClass.get(clazz).stream()); + } + + private static Iterator ancestors(Class clazz) + { + return TreeTraverser.using( + (Class n) -> ImmutableList.builder() + .addAll(asList(Optional.ofNullable(n.getSuperclass()))) + .addAll(asList(n.getInterfaces())) + .build()) + .preOrderTraversal(clazz) + .iterator(); + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private final ImmutableListMultimap.Builder matchablesByClass = ImmutableListMultimap.builder(); + + public Builder register(Set matchables) + { + matchables.forEach(this::register); + return this; + } + + public Builder register(T matchable) + { + Pattern pattern = matchable.getPattern(); + if (pattern instanceof Pattern.MatchByClass) { + matchablesByClass.put(((Pattern.MatchByClass) pattern).getObjectClass(), matchable); + } + else { + throw new IllegalArgumentException("Unexpected Pattern: " + pattern); + } + return this; + } + + public MatchingEngine build() + { + return new MatchingEngine(matchablesByClass.build()); + } + } +} diff --git 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java b/presto-main/src/main/java/com/facebook/presto/matching/Pattern.java similarity index 57% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java rename to presto-main/src/main/java/com/facebook/presto/matching/Pattern.java index a9ceb13168720..1968dc44a3f51 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java +++ b/presto-main/src/main/java/com/facebook/presto/matching/Pattern.java @@ -12,57 +12,55 @@ * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.sql.planner.plan.PlanNode; +package com.facebook.presto.matching; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; public abstract class Pattern { - private static final Pattern ANY_NODE = new MatchNodeClass(PlanNode.class); + private static final Pattern ANY = new MatchByClass(Object.class); private Pattern() {} - public abstract boolean matches(PlanNode node); + public abstract boolean matches(Object object); public static Pattern any() { - return ANY_NODE; + return ANY; } - public static Pattern node(Class nodeClass) + public static Pattern matchByClass(Class objectClass) { - return new MatchNodeClass(nodeClass); + return new MatchByClass(objectClass); } - static class MatchNodeClass + static class MatchByClass extends Pattern { - private final Class nodeClass; + private final Class objectClass; - MatchNodeClass(Class nodeClass) + MatchByClass(Class objectClass) { - this.nodeClass = requireNonNull(nodeClass, "nodeClass is null"); + this.objectClass = requireNonNull(objectClass, "objectClass is null"); } - Class getNodeClass() + Class getObjectClass() { - return nodeClass; + return objectClass; } @Override - public boolean matches(PlanNode node) + public boolean matches(Object object) { - return nodeClass.isInstance(node); + 
return objectClass.isInstance(object); } @Override public String toString() { return toStringHelper(this) - .add("nodeClass", nodeClass) + .add("objectClass", objectClass) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java b/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java index fa052d492c6d1..932f16ba27aad 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java @@ -157,8 +157,7 @@ public synchronized void process(Iterable queries) long totalBytes = 0; for (QueryExecution query : queries) { long bytes = query.getTotalMemoryReservation(); - DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession()); - long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes()); + long queryMemoryLimit = getQueryMaxMemory(query.getSession()).toBytes(); totalBytes += bytes; if (resourceOvercommit(query.getSession()) && outOfMemory) { // If a query has requested resource overcommit, only kill it if the cluster has run out of memory diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java index f53679de0a47d..4b2fd96b735aa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryPerformanceFetcher; import com.facebook.presto.execution.StageId; @@ -23,6 +22,7 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import 
com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -45,16 +45,16 @@ public static class ExplainAnalyzeOperatorFactory private final PlanNodeId planNodeId; private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; - private final CostCalculator costCalculator; + private final Lookup lookup; private boolean closed; - public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) + public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, Lookup lookup) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); } @Override @@ -68,7 +68,7 @@ public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, ExplainAnalyzeOperator.class.getSimpleName()); - return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata, costCalculator); + return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata, lookup); } @Override @@ -80,23 +80,23 @@ public void close() @Override public OperatorFactory duplicate() { - return new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata, costCalculator); + return 
new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata, lookup); } } private final OperatorContext operatorContext; private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; - private final CostCalculator costCalculator; + private final Lookup lookup; private boolean finishing; private boolean outputConsumed; - public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) + public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, Lookup lookup) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); } @Override @@ -151,7 +151,7 @@ public Page getOutput() return null; } - String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, costCalculator, operatorContext.getSession()); + String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, lookup, operatorContext.getSession()); BlockBuilder builder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1); VARCHAR.writeString(builder, plan); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index bd88e4c5020bc..8c4a3d2ed8f5f 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -21,8 +21,33 @@ import com.facebook.presto.client.NodeVersion; import 
com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.connector.system.SystemConnectorModule; -import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.AggregationStatsRule; +import com.facebook.presto.cost.CapDistinctValuesCountToOutputRowsCount; +import com.facebook.presto.cost.CapDistinctValuesCountToTypeDomainRangeLength; +import com.facebook.presto.cost.CoefficientBasedStatsCalculator; +import com.facebook.presto.cost.ComposableStatsCalculator; import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostCalculator.EstimatedExchanges; +import com.facebook.presto.cost.CostCalculatorUsingExchanges; +import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.EnforceSingleRowStatsRule; +import com.facebook.presto.cost.EnsureStatsMatchOutput; +import com.facebook.presto.cost.ExchangeStatsRule; +import com.facebook.presto.cost.FilterStatsCalculator; +import com.facebook.presto.cost.FilterStatsRule; +import com.facebook.presto.cost.IntersectStatsRule; +import com.facebook.presto.cost.JoinStatsRule; +import com.facebook.presto.cost.LimitStatsRule; +import com.facebook.presto.cost.OutputStatsRule; +import com.facebook.presto.cost.ProjectStatsRule; +import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.SelectingStatsCalculator; +import com.facebook.presto.cost.SelectingStatsCalculator.New; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.TableScanStatsRule; +import com.facebook.presto.cost.UnionStatsRule; +import com.facebook.presto.cost.ValuesStatsRule; import com.facebook.presto.event.query.QueryMonitor; import com.facebook.presto.event.query.QueryMonitorConfig; import com.facebook.presto.execution.LocationFactory; @@ -114,6 +139,8 @@ import com.facebook.presto.sql.planner.LocalExecutionPlanner; import 
com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.PlanOptimizers; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.StatelessLookup; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.transaction.ForTransactionManager; @@ -186,10 +213,10 @@ protected void setup(Binder binder) if (serverConfig.isCoordinator()) { install(new CoordinatorModule()); - binder.bind(new TypeLiteral>(){}).toProvider(QueryPerformanceFetcherProvider.class).in(Scopes.SINGLETON); + binder.bind(new TypeLiteral>() {}).toProvider(QueryPerformanceFetcherProvider.class).in(Scopes.SINGLETON); } else { - binder.bind(new TypeLiteral>(){}).toInstance(Optional.empty()); + binder.bind(new TypeLiteral>() {}).toInstance(Optional.empty()); // Install no-op resource group manager on workers, since only coordinators manage resource groups. binder.bind(ResourceGroupManager.class).to(NoOpResourceGroupManager.class).in(Scopes.SINGLETON); @@ -349,7 +376,16 @@ protected void setup(Binder binder) binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); // statistics calculator - binder.bind(CostCalculator.class).to(CoefficientBasedCostCalculator.class).in(Scopes.SINGLETON); + binder.bind(CostComparator.class).in(Scopes.SINGLETON); + binder.bind(CostCalculator.class).to(CostCalculatorUsingExchanges.class).in(Scopes.SINGLETON); + binder.bind(CostCalculator.class) + .annotatedWith(EstimatedExchanges.class) + .to(CostCalculatorWithEstimatedExchanges.class).in(Scopes.SINGLETON); + binder.bind(Lookup.class).to(StatelessLookup.class).in(Scopes.SINGLETON); + binder.bind(StatsCalculator.class).annotatedWith(SelectingStatsCalculator.Old.class).to(CoefficientBasedStatsCalculator.class).in(Scopes.SINGLETON); + binder.bind(StatsCalculator.class).to(SelectingStatsCalculator.class).in(Scopes.SINGLETON); + 
binder.bind(FilterStatsCalculator.class).in(Scopes.SINGLETON); + binder.bind(ScalarStatsCalculator.class).in(Scopes.SINGLETON); // type binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); @@ -447,6 +483,33 @@ protected void setup(Binder binder) binder.bind(ExecutorCleanup.class).in(Scopes.SINGLETON); } + @Provides + @Singleton + @New + public static StatsCalculator createNewStatsCalculator(Metadata metadata, FilterStatsCalculator filterStatsCalculator, ScalarStatsCalculator scalarStatsCalculator) + { + ImmutableSet.Builder rules = ImmutableSet.builder(); + rules.add(new OutputStatsRule()); + rules.add(new TableScanStatsRule(metadata)); + rules.add(new ValuesStatsRule(metadata)); + rules.add(new LimitStatsRule()); + rules.add(new EnforceSingleRowStatsRule()); + rules.add(new ExchangeStatsRule()); + rules.add(new ProjectStatsRule(scalarStatsCalculator)); + rules.add(new FilterStatsRule(filterStatsCalculator)); + rules.add(new JoinStatsRule(filterStatsCalculator)); + rules.add(new AggregationStatsRule()); + rules.add(new UnionStatsRule()); + rules.add(new IntersectStatsRule()); + + ImmutableList.Builder normalizers = ImmutableList.builder(); + normalizers.add(new EnsureStatsMatchOutput()); + normalizers.add(new CapDistinctValuesCountToOutputRowsCount()); + normalizers.add(new CapDistinctValuesCountToTypeDomainRangeLength()); + + return new ComposableStatsCalculator(rules.build(), normalizers.build()); + } + @Provides @Singleton @ForExchange diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 40909d6069a1a..ae2a5fc8d5aa6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -15,7 +15,6 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.connector.ConnectorManager; 
-import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.eventlistener.EventListenerManager; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.TaskManager; @@ -37,6 +36,7 @@ import com.facebook.presto.spi.Plugin; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.testing.ProcedureTester; import com.facebook.presto.testing.TestingAccessControlManager; import com.facebook.presto.testing.TestingEventListenerManager; @@ -101,7 +101,7 @@ public class TestingPrestoServer private final CatalogManager catalogManager; private final TransactionManager transactionManager; private final Metadata metadata; - private final CostCalculator costCalculator; + private final Lookup lookup; private final TestingAccessControlManager accessControl; private final ProcedureTester procedureTester; private final Optional resourceGroupManager; @@ -251,7 +251,7 @@ public TestingPrestoServer(boolean coordinator, catalogManager = injector.getInstance(CatalogManager.class); transactionManager = injector.getInstance(TransactionManager.class); metadata = injector.getInstance(Metadata.class); - costCalculator = injector.getInstance(CostCalculator.class); + lookup = injector.getInstance(Lookup.class); accessControl = injector.getInstance(TestingAccessControlManager.class); procedureTester = injector.getInstance(ProcedureTester.class); splitManager = injector.getInstance(SplitManager.class); @@ -349,9 +349,9 @@ public Metadata getMetadata() return metadata; } - public CostCalculator getCostCalculator() + public Lookup getLookup() { - return costCalculator; + return lookup; } public TestingAccessControlManager getAccessControl() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 
8d8fb01bf6eef..4c7b3a0d9a99b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -39,6 +39,9 @@ }) public class FeaturesConfig { + private double cpuCostWeight = 0.75; + private double memoryCostWeight = 0; + private double networkCostWeight = 0.25; private boolean distributedIndexJoinsEnabled; private boolean distributedJoinsEnabled = true; private boolean colocatedJoinsEnabled; @@ -70,9 +73,46 @@ public class FeaturesConfig private double spillMaxUsedSpaceThreshold = 0.9; private boolean iterativeOptimizerEnabled = true; private boolean pushAggregationThroughJoin = true; + private boolean useNewStatsCalculator = true; private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata + public double getCpuCostWeight() + { + return cpuCostWeight; + } + + @Config("cpu-cost-weight") + public FeaturesConfig setCpuCostWeight(double cpuCostWeight) + { + this.cpuCostWeight = cpuCostWeight; + return this; + } + + public double getMemoryCostWeight() + { + return memoryCostWeight; + } + + @Config("memory-cost-weight") + public FeaturesConfig setMemoryCostWeight(double memoryCostWeight) + { + this.memoryCostWeight = memoryCostWeight; + return this; + } + + public double getNetworkCostWeight() + { + return networkCostWeight; + } + + @Config("network-cost-weight") + public FeaturesConfig setNetworkCostWeight(double networkCostWeight) + { + this.networkCostWeight = networkCostWeight; + return this; + } + public boolean isResourceGroupsEnabled() { return resourceGroups; @@ -438,4 +478,16 @@ public FeaturesConfig setPushAggregationThroughJoin(boolean value) this.pushAggregationThroughJoin = value; return this; } + + @Config("experimental.use-new-stats-calculator") + public FeaturesConfig setUseNewStatsCalculator(boolean useNewStatsCalculator) + { + 
this.useNewStatsCalculator = useNewStatsCalculator; + return this; + } + + public boolean isUseNewStatsCalculator() + { + return useNewStatsCalculator; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java index 28236a161f95e..88159c5631e65 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.DataDefinitionTask; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -25,6 +24,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.PlanOptimizers; import com.facebook.presto.sql.planner.SubPlan; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; import com.facebook.presto.sql.tree.ExplainType.Type; @@ -46,7 +46,7 @@ public class QueryExplainer private final Metadata metadata; private final AccessControl accessControl; private final SqlParser sqlParser; - private final CostCalculator costCalculator; + private final Lookup lookup; private final Map, DataDefinitionTask> dataDefinitionTask; @Inject @@ -55,14 +55,14 @@ public QueryExplainer( Metadata metadata, AccessControl accessControl, SqlParser sqlParser, - CostCalculator costCalculator, + Lookup lookup, Map, DataDefinitionTask> dataDefinitionTask) { this(planOptimizers.get(), metadata, accessControl, sqlParser, - costCalculator, + lookup, dataDefinitionTask); } @@ -71,14 +71,14 @@ public QueryExplainer( Metadata metadata, AccessControl accessControl, SqlParser 
sqlParser, - CostCalculator costCalculator, + Lookup lookup, Map, DataDefinitionTask> dataDefinitionTask) { this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); this.dataDefinitionTask = ImmutableMap.copyOf(requireNonNull(dataDefinitionTask, "dataDefinitionTask is null")); } @@ -98,10 +98,10 @@ public String getPlan(Session session, Statement statement, Type planType, List< switch (planType) { case LOGICAL: Plan plan = getLogicalPlan(session, statement, parameters); - return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, costCalculator, session); + return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, lookup, session); case DISTRIBUTED: SubPlan subPlan = getDistributedPlan(session, statement, parameters); - return PlanPrinter.textDistributedPlan(subPlan, metadata, costCalculator, session); + return PlanPrinter.textDistributedPlan(subPlan, metadata, lookup, session); } throw new IllegalArgumentException("Unhandled plan type: " + planType); } @@ -138,7 +138,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List queryPerformanceFetcher; @@ -252,6 +255,7 @@ public class LocalExecutionPlanner public LocalExecutionPlanner( Metadata metadata, SqlParser sqlParser, + StatsCalculator statsCalculator, CostCalculator costCalculator, Optional queryPerformanceFetcher, PageSourceProvider pageSourceProvider, @@ -278,6 +282,7 @@ public LocalExecutionPlanner( this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + 
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); @@ -608,7 +613,12 @@ public PhysicalOperation visitExplainAnalyze(ExplainAnalyzeNode node, LocalExecu checkState(queryPerformanceFetcher.isPresent(), "ExplainAnalyze can only run on coordinator"); PhysicalOperation source = node.getSource().accept(this, context); - OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata, costCalculator); + OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory( + context.getNextOperatorId(), + node.getId(), + queryPerformanceFetcher.get(), + metadata, + new StatelessLookup(statsCalculator, costCalculator)); return new PhysicalOperation(operatorFactory, makeLayout(node), source); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 8634cfc6ce685..d7520bd9d933c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -15,8 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.NewTableLayout; import com.facebook.presto.metadata.QualifiedObjectName; @@ -32,6 +30,7 @@ import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.parser.SqlParser; +import 
com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; @@ -39,7 +38,6 @@ import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -92,28 +90,28 @@ public enum Stage private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; private final SqlParser sqlParser; - private final CostCalculator costCalculator; + private final Lookup lookup; public LogicalPlanner(Session session, List planOptimizers, PlanNodeIdAllocator idAllocator, Metadata metadata, SqlParser sqlParser, - CostCalculator costCalculator) + Lookup lookup) { requireNonNull(session, "session is null"); requireNonNull(planOptimizers, "planOptimizers is null"); requireNonNull(idAllocator, "idAllocator is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(sqlParser, "sqlParser is null"); - requireNonNull(costCalculator, "costCalculator is null"); + requireNonNull(lookup, "lookup is null"); this.session = session; this.planOptimizers = planOptimizers; this.idAllocator = idAllocator; this.metadata = metadata; this.sqlParser = sqlParser; - this.costCalculator = costCalculator; + this.lookup = lookup; } public Plan plan(Analysis analysis) @@ -139,9 +137,7 @@ public Plan plan(Analysis analysis, Stage stage) PlanSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes()); } - Map planNodeCosts = costCalculator.calculateCostForPlan(session, symbolAllocator.getTypes(), root); - - return new Plan(root, 
symbolAllocator.getTypes(), planNodeCosts); + return new Plan(root, symbolAllocator.getTypes(), lookup, session); } public PlanNode planStatement(Analysis analysis, Statement statement) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java index 8a92461e3289a..d85f624dbc725 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java @@ -13,31 +13,38 @@ */ package com.facebook.presto.sql.planner; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableMap; +import java.util.List; import java.util.Map; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; public class Plan { private final PlanNode root; private final Map types; - private final Map planNodeCosts; + private final Map planNodeStats; - public Plan(PlanNode root, Map types, Map planNodeCosts) + public Plan(PlanNode root, Map types, Lookup lookup, Session session) { requireNonNull(root, "root is null"); requireNonNull(types, "types is null"); - requireNonNull(planNodeCosts, "planNodeCosts is null"); + requireNonNull(lookup, "lookup is null"); this.root = root; this.types = ImmutableMap.copyOf(types); - this.planNodeCosts = planNodeCosts; + this.planNodeStats = getPlanNodes(root) + .stream() + .collect(toImmutableMap(PlanNode::getId, node -> lookup.getStats(node, session, types))); } public PlanNode getRoot() @@ -50,8 +57,13 @@ public Map 
getTypes() return types; } - public Map getPlanNodeCosts() + public Map getPlanNodeStats() { - return planNodeCosts; + return planNodeStats; + } + + private List getPlanNodes(PlanNode root) + { + return searchFrom(root, Lookup.from(groupReference -> groupReference)).findAll(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 7617a2a419f74..a1f4c44c28340 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostCalculator.EstimatedExchanges; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; @@ -41,6 +45,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns; import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; +import com.facebook.presto.sql.planner.iterative.rule.PushDownTableConstraints; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughProject; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughSemiJoin; @@ -97,6 +102,8 @@ import java.util.List; import java.util.Set; +import static java.util.Objects.requireNonNull; + public class PlanOptimizers { private final List optimizers; @@ -104,9 +111,25 @@ public class PlanOptimizers private final MBeanExporter exporter; @Inject - public PlanOptimizers(Metadata metadata, SqlParser 
sqlParser, FeaturesConfig featuresConfig, MBeanExporter exporter) + public PlanOptimizers( + Metadata metadata, + SqlParser sqlParser, + FeaturesConfig featuresConfig, + MBeanExporter exporter, + CostComparator costComparator, + StatsCalculator statsCalculator, + CostCalculator costCalculator, + @EstimatedExchanges CostCalculator estimatedExchangesCostCalculator) { - this(metadata, sqlParser, featuresConfig, false, exporter); + this(metadata, + sqlParser, + featuresConfig, + false, + exporter, + costComparator, + statsCalculator, + costCalculator, + estimatedExchangesCostCalculator); } @PostConstruct @@ -121,8 +144,18 @@ public void destroy() stats.unexport(exporter); } - public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig featuresConfig, boolean forceSingleNode, MBeanExporter exporter) + public PlanOptimizers( + Metadata metadata, + SqlParser sqlParser, + FeaturesConfig featuresConfig, + boolean forceSingleNode, + MBeanExporter exporter, + CostComparator costComparator, + StatsCalculator statsCalculator, + CostCalculator costCalculator, + CostCalculator estimatedExchangesCostCalculator) { + requireNonNull(statsCalculator, "statsCalculator can not be null"); this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -142,12 +175,16 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea IterativeOptimizer inlineProjections = new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( new InlineProjections(), new RemoveRedundantIdentityProjections())); IterativeOptimizer projectionPushDown = new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableList.of(new ProjectionPushDown()), ImmutableSet.of( new PushProjectionThroughUnion(), @@ -158,6 +195,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new CanonicalizeExpressions(), new IterativeOptimizer( stats, + 
statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.builder() .addAll(predicatePushDownRules) .addAll(columnPruningRules) @@ -177,6 +216,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea ), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( new ImplementFilteredAggregations(), new ImplementBernoulliSampleAsFilter())), @@ -184,6 +225,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new UnaliasSymbolReferences(), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of(new RemoveRedundantIdentityProjections()) ), new SetFlatteningOptimizer(), @@ -193,9 +236,14 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea inlineProjections, new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of(new TransformExistsApplyToLateralNode(metadata.getFunctionRegistry()))), new TransformQuantifiedComparisonApplyToLateralJoin(metadata), - new IterativeOptimizer(stats, + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableList.of( new RemoveUnreferencedScalarLateralNodes(), new TransformUncorrelatedLateralToJoin(), @@ -208,10 +256,14 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea ), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableList.of(new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry())), ImmutableSet.of(new com.facebook.presto.sql.planner.iterative.rule.TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry()))), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( new TransformCorrelatedInPredicateToJoin(), // must be run after PruneUnreferencedOutputs new ImplementFilteredAggregations()) @@ -219,9 +271,16 
@@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new TransformCorrelatedNoAggregationSubqueryToJoin(), new TransformCorrelatedSingleRowSubqueryToProject(), new PredicatePushDown(metadata, sqlParser), + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new PushDownTableConstraints(metadata, sqlParser))), new PruneUnreferencedOutputs(), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( new RemoveRedundantIdentityProjections(), new PushAggregationThroughOuterJoin()) @@ -234,10 +293,14 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new IndexJoinOptimizer(metadata), // Run this after projections and filters have been fully simplified and pushed down new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of(new SimplifyCountOverConstant())), new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( // add UnaliasSymbolReferences when it's ported new RemoveRedundantIdentityProjections(), @@ -247,21 +310,32 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of(new RemoveRedundantIdentityProjections()) ), new MetadataQueryOptimizer(metadata), new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins()), // This can pull up Filter and 
Project nodes from between Joins, so we need to push them down again ImmutableSet.of(new EliminateCrossJoins()) ), new PredicatePushDown(metadata, sqlParser), + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new PushDownTableConstraints(metadata, sqlParser))), projectionPushDown); if (featuresConfig.isOptimizeSingleDistinct()) { builder.add( new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of(new SingleMarkDistinctToGroupBy())), new PruneUnreferencedOutputs()); } @@ -269,6 +343,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea builder.add(new OptimizeMixedDistinctAggregations(metadata)); builder.add(new IterativeOptimizer( stats, + statsCalculator, + estimatedExchangesCostCalculator, ImmutableSet.of( new CreatePartialTopN(), new PushTopNThroughUnion()))); @@ -284,6 +360,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea builder.add( new IterativeOptimizer( stats, + statsCalculator, + costCalculator, ImmutableSet.of(new RemoveEmptyDelete()) // Run RemoveEmptyDelete after table scan is removed by PickLayout/AddExchanges )); @@ -294,6 +372,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea builder.add(new PruneUnreferencedOutputs()); builder.add(new IterativeOptimizer( stats, + statsCalculator, + costCalculator, ImmutableSet.of(new RemoveRedundantIdentityProjections()))); // Optimizers above this don't understand local exchanges, so be careful moving this. 
@@ -304,6 +384,8 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea builder.add(new PartialAggregationPushDown(metadata.getFunctionRegistry())); builder.add(new IterativeOptimizer( stats, + statsCalculator, + costCalculator, ImmutableSet.of( new AddIntermediateAggregations(), new RemoveRedundantIdentityProjections()))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java index 612836aee46de..d76bea4dd65bf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -15,6 +15,11 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.cost.CachingCostCalculator; +import com.facebook.presto.cost.CachingStatsCalculator; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.matching.MatchingEngine; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -41,22 +46,26 @@ public class IterativeOptimizer implements PlanOptimizer { private final List legacyRules; - private final RuleStore ruleStore; + private final MatchingEngine ruleStore; private final StatsRecorder stats; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; - public IterativeOptimizer(StatsRecorder stats, Set rules) + public IterativeOptimizer(StatsRecorder stats, StatsCalculator statsCalculator, CostCalculator costCalculator, Set rules) { - this(stats, ImmutableList.of(), rules); + this(stats, statsCalculator, costCalculator, ImmutableList.of(), rules); } - public IterativeOptimizer(StatsRecorder stats, List 
legacyRules, Set newRules) + public IterativeOptimizer(StatsRecorder stats, StatsCalculator statsCalculator, CostCalculator costCalculator, List legacyRules, Set newRules) { this.legacyRules = ImmutableList.copyOf(legacyRules); - this.ruleStore = RuleStore.builder() + this.ruleStore = MatchingEngine.builder() .register(newRules) .build(); this.stats = stats; + this.statsCalculator = statsCalculator; + this.costCalculator = costCalculator; stats.registerAll(newRules); } @@ -74,7 +83,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } Memo memo = new Memo(idAllocator, plan); - Lookup lookup = Lookup.from(memo::resolve); + Lookup lookup = Lookup.from(memo::resolve, new CachingStatsCalculator(statsCalculator), new CachingCostCalculator(costCalculator)); Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); exploreGroup(memo.getRootGroup(), new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 7305e501e05b1..65b2304952fcf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -13,10 +13,20 @@ */ package com.facebook.presto.sql.planner.iterative; +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; +import java.util.Map; import java.util.function.Function; +import static com.facebook.presto.cost.PlanNodeCostEstimate.INFINITE_COST; +import static 
com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; import static com.google.common.base.Verify.verify; public interface Lookup @@ -24,32 +34,76 @@ public interface Lookup /** * Resolves a node by materializing GroupReference nodes * representing symbolic references to other nodes. - * * If the node is not a GroupReference, it returns the * argument as is. */ PlanNode resolve(PlanNode node); + PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types); + + PlanNodeCostEstimate getCumulativeCost(PlanNode node, Session session, Map types); + /** * A Lookup implementation that does not perform lookup. It satisfies contract * by rejecting {@link GroupReference}-s. */ static Lookup noLookup() { - return node -> { - verify(!(node instanceof GroupReference), "Unexpected GroupReference"); - return node; + return new Lookup() + { + @Override + public PlanNode resolve(PlanNode node) + { + verify(!(node instanceof GroupReference), "Unexpected GroupReference"); + return node; + } + + @Override + public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) + { + return UNKNOWN_STATS; + } + + @Override + public PlanNodeCostEstimate getCumulativeCost(PlanNode node, Session session, Map types) + { + return INFINITE_COST; + } }; } static Lookup from(Function resolver) { - return node -> { - if (node instanceof GroupReference) { - return resolver.apply((GroupReference) node); + return from(resolver, + (planNode, lookup, session, types) -> UNKNOWN_STATS, + (planNode, lookup, session, types) -> INFINITE_COST); + } + + static Lookup from(Function resolver, StatsCalculator statsCalculator, CostCalculator costCalculator) + { + return new Lookup() + { + @Override + public PlanNode resolve(PlanNode node) + { + if (node instanceof GroupReference) { + return resolver.apply((GroupReference) node); + } + + return node; } - return node; + @Override + public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) + { + return 
statsCalculator.calculateStats(resolve(node), this, session, types); + } + + @Override + public PlanNodeCostEstimate getCumulativeCost(PlanNode node, Session session, Map types) + { + return costCalculator.calculateCumulativeCost(node, this, session, types); + } }; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java index 4d2776bf068bf..17e17c59751d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java @@ -14,13 +14,15 @@ package com.facebook.presto.sql.planner.iterative; import com.facebook.presto.Session; +import com.facebook.presto.matching.Matchable; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.Optional; -public interface Rule +public interface Rule extends Matchable { /** * Returns a pattern to which plan nodes this rule applies. diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStats.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStats.java index 60d5f6bb6a113..70ee469eb94cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStats.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStats.java @@ -11,6 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package com.facebook.presto.sql.planner.iterative; import io.airlift.stats.TimeDistribution; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java deleted file mode 100644 index f190a5656092d..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import com.google.common.collect.Streams; - -import java.util.Iterator; -import java.util.Set; -import java.util.stream.Stream; - -public class RuleStore -{ - private final ListMultimap, Rule> rulesByClass; - - private RuleStore(ListMultimap, Rule> rulesByClass) - { - this.rulesByClass = ImmutableListMultimap.copyOf(rulesByClass); - } - - public Stream getCandidates(PlanNode planNode) - { - return Streams.stream(ancestors(planNode.getClass())) - .flatMap(clazz -> rulesByClass.get(clazz).stream()); - } - - private static Iterator> ancestors(Class planNodeClass) - { - return new AbstractIterator>() { - private Class current = planNodeClass; - - @Override - protected Class computeNext() - { - if (!PlanNode.class.isAssignableFrom(current)) { - return endOfData(); - } - - Class result = (Class) current; - current = current.getSuperclass(); - - return result; - } - }; - } - - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private final ImmutableListMultimap.Builder, Rule> rulesByClass = ImmutableListMultimap.builder(); - - public Builder register(Set newRules) - { - newRules.forEach(this::register); - return this; - } - - public Builder register(Rule newRule) - { - Pattern pattern = newRule.getPattern(); - if (pattern instanceof Pattern.MatchNodeClass) { - rulesByClass.put(((Pattern.MatchNodeClass) pattern).getNodeClass(), newRule); - } - else { - throw new IllegalArgumentException("Unexpected Pattern: " + pattern); - } - return this; - } - - public RuleStore build() - { - return new RuleStore(rulesByClass.build()); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/StatelessLookup.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/StatelessLookup.java new file mode 100644 index 0000000000000..ea27c76a356a5 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/StatelessLookup.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.Map; + +import static com.google.common.base.Verify.verify; +import static java.util.Objects.requireNonNull; + +// TODO: remove. 
Eventually all uses of StatelessLookup should be replaced with the Lookup specific to the plan +@ThreadSafe +public class StatelessLookup + implements Lookup +{ + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; + + @Inject + public StatelessLookup(StatsCalculator statsCalculator, CostCalculator costCalculator) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + } + + @Override + public PlanNode resolve(PlanNode node) + { + verify(!(node instanceof GroupReference), "Unexpected GroupReference"); + return node; + } + + @Override + public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types) + { + return statsCalculator.calculateStats( + planNode, + this, + session, + types); + } + + @Override + public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types) + { + return costCalculator.calculateCumulativeCost( + planNode, + this, + session, + types); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index 89fdb4e3c0d16..672bb2ff47c21 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -22,7 +23,6 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import 
com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -67,7 +67,7 @@ public class AddIntermediateAggregations implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java index 82364de8c05c4..1729c360795a6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TopNNode; @@ -31,7 +31,7 @@ public class CreatePartialTopN implements Rule { - private static final Pattern PATTERN = Pattern.node(TopNNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(TopNNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 744925d4f4adb..2452f75ae7a40 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -15,11 +15,11 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; @@ -50,7 +50,7 @@ public class EliminateCrossJoins implements Rule { - private static final Pattern PATTERN = Pattern.node(JoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(JoinNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java index b379303dfe779..5faad997791ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import 
com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +29,7 @@ public class EvaluateZeroLimit implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java index ecf5b95424fc8..c2ed5a1dc6bee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -32,7 +32,7 @@ public class EvaluateZeroSample implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(SampleNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java index 9705352c30c15..575d81749a798 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -50,7 +50,7 @@ public class ImplementBernoulliSampleAsFilter implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(SampleNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index e1e54119393ca..c3c75cd5c0d97 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import 
com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -57,7 +57,7 @@ public class ImplementFilteredAggregations implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index 8df1175dbfe9e..26921dba5d9e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -14,13 +14,13 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -48,7 +48,7 @@ public class InlineProjections implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = 
Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java index 38d3724c176d3..86056449398a3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -31,7 +31,7 @@ public class MergeAdjacentWindows implements Rule { - private static final Pattern PATTERN = Pattern.node(WindowNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(WindowNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java index c70cbed39298d..e5f3f8dc90329 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; 
import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +29,7 @@ public class MergeFilters implements Rule { - private static final Pattern PATTERN = Pattern.node(FilterNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(FilterNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java index 616dd88cdad4b..b618f04cea0d7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; @@ -29,7 +29,7 @@ public class MergeLimitWithDistinct implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java index f6d466acc5644..6ddcd0d7c8cc5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +29,7 @@ public class MergeLimitWithSort implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java index cf91fbacfd17f..87c828695d496 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import 
com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -28,7 +28,7 @@ public class MergeLimitWithTopN implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java index 06e9ef9ccf65c..51bae60a194a9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -27,7 +27,7 @@ public class MergeLimits implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java index 
a0aa799dbee89..22f2cf680e26a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -35,7 +35,7 @@ public class PruneCrossJoinColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java index 15fc5b8ee128f..3edd114e57b1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java @@ -14,12 +14,12 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import 
com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -36,7 +36,7 @@ public class PruneJoinChildrenColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(JoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(JoinNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java index f2d4276ea28ca..00debf924e362 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,7 +37,7 @@ public class PruneJoinColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java index 9f4766e7a5985..db1bd1391d05d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,7 +37,7 @@ public class PruneMarkDistinctColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java index 16f1cb34dcfda..58b176f6e678e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import 
com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -37,7 +37,7 @@ public class PruneSemiJoinColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java index 071a5965756bb..3e397ee34c452 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -35,7 +35,7 @@ public class PruneSemiJoinFilteringSourceColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(SemiJoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(SemiJoinNode.class); @Override public Pattern 
getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java index 3e654efc9a942..a6e7f1614cf38 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -36,7 +36,7 @@ public class PruneTableScanColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java index ac85cfe0b4ff6..52ef3e8ba3237 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import 
com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -38,7 +38,7 @@ public class PruneValuesColumns implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 073ebbb74f048..a7c89e58f3b9b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -14,12 +14,12 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -86,7 +86,7 @@ public class PushAggregationThroughOuterJoin implements Rule { - private static final Pattern 
PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownTableConstraints.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownTableConstraints.java new file mode 100644 index 0000000000000..7ad14b97b935f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownTableConstraints.java @@ -0,0 +1,162 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableLayoutResult; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.DomainTranslator; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; +import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; +import static com.facebook.presto.sql.ExpressionUtils.stripDeterministicConjuncts; +import static com.facebook.presto.sql.ExpressionUtils.stripNonDeterministicConjuncts; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class PushDownTableConstraints + implements Rule +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public PushDownTableConstraints(Metadata metadata, 
SqlParser sqlParser) + { + this.metadata = metadata; + this.sqlParser = sqlParser; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof FilterNode)) { + return Optional.empty(); + } + + FilterNode filter = (FilterNode) node; + PlanNode source = lookup.resolve(filter.getSource()); + if (!(source instanceof TableScanNode)) { + return Optional.empty(); + } + + Expression predicate = filter.getPredicate(); + // don't include non-deterministic predicates + Expression deterministicPredicate = stripNonDeterministicConjuncts(predicate); + DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( + metadata, + session, + deterministicPredicate, + symbolAllocator.getTypes()); + + TableScanNode tableScan = (TableScanNode) source; + TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() + .transform(tableScan.getAssignments()::get) + .intersect(tableScan.getCurrentConstraint()); + + Map assignments = ImmutableBiMap.copyOf(tableScan.getAssignments()).inverse(); + + // Layouts will be returned in order of the connector's preference + List layouts = metadata.getLayouts( + session, tableScan.getTable(), + new Constraint<>(simplifiedConstraint, bindings -> true), + Optional.of(tableScan.getOutputSymbols().stream() + .map(tableScan.getAssignments()::get) + .collect(toImmutableSet()))); + + // Filter out layouts that cannot supply all the required columns + layouts = layouts.stream() + .filter(layoutHasAllNeededOutputs(tableScan)) + .collect(toImmutableList()); + + if (layouts.isEmpty()) { + return Optional.empty(); + } + + // At this point we have no way to choose between possible layouts, just take the first one + TableLayoutResult layout = layouts.get(0); + + PlanNode rewrittenPlan = new TableScanNode( + tableScan.getId(), + tableScan.getTable(), + tableScan.getOutputSymbols(), + tableScan.getAssignments(), 
+ Optional.of(layout.getLayout().getHandle()), + simplifiedConstraint.intersect(layout.getLayout().getPredicate()), + Optional.ofNullable(tableScan.getOriginalConstraint()).orElse(predicate)); + + Expression resultingPredicate = combineConjuncts( + DomainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), + stripDeterministicConjuncts(predicate), + decomposedPredicate.getRemainingExpression()); + if (!BooleanLiteral.TRUE_LITERAL.equals(resultingPredicate)) { + rewrittenPlan = new FilterNode(idAllocator.getNextId(), rewrittenPlan, resultingPredicate); + } + + if (!planChanged(rewrittenPlan, filter, lookup, session, symbolAllocator)) { + return Optional.empty(); + } + return Optional.of(rewrittenPlan); + } + + private Predicate layoutHasAllNeededOutputs(TableScanNode node) + { + return layout -> { + List columnHandles = Lists.transform(node.getOutputSymbols(), node.getAssignments()::get); + return !layout.getLayout().getColumns().isPresent() + || layout.getLayout().getColumns().get().containsAll(columnHandles); + }; + } + + private boolean planChanged(PlanNode rewrittenPlan, FilterNode oldPlan, Lookup lookup, Session session, SymbolAllocator symbolAllocator) + { + if (!(rewrittenPlan instanceof FilterNode)) { + return true; + } + + FilterNode rewrittenFilter = (FilterNode) rewrittenPlan; + if (!new ExpressionEquivalence(metadata, sqlParser).areExpressionsEquivalent(session, rewrittenFilter.getPredicate(), oldPlan.getPredicate(), symbolAllocator.getTypes())) { + if (!ImmutableSet.copyOf(extractConjuncts(rewrittenFilter.getPredicate())).equals(ImmutableSet.copyOf(extractConjuncts(oldPlan.getPredicate())))) { + return true; + } + } + + TableScanNode oldTableScan = (TableScanNode) lookup.resolve(oldPlan.getSource()); + TableScanNode rewrittenTableScan = (TableScanNode) lookup.resolve(rewrittenFilter.getSource()); + return !rewrittenTableScan.getCurrentConstraint().equals(oldTableScan.getCurrentConstraint()); + } +} diff --git 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java index 9f2120ac06939..d7c346cc8311f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -30,7 +30,7 @@ public class PushLimitThroughMarkDistinct implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java index 562413ee3c2a8..dfcbcb871828c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import 
com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,7 +30,7 @@ public class PushLimitThroughProject implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java index 8953b38b0da05..8ba08abcbfa61 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,7 +30,7 @@ public class PushLimitThroughSemiJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LimitNode.class); @Override public Pattern getPattern() diff --git 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 543846469a342..b8604b291f9f7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -21,7 +22,6 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -71,7 +71,7 @@ public class PushProjectionThroughExchange implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index e203eff472a1c..63e3a54df7163 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -14,13 +14,13 @@ 
package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -39,7 +39,7 @@ public class PushProjectionThroughUnion implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java index 280cde9c19e92..876a29a9ef8fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import 
com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,7 +37,7 @@ public class PushTopNThroughUnion implements Rule { - private static final Pattern PATTERN = Pattern.node(TopNNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(TopNNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java index 6f6527da4fa6f..1a77ce76eb927 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -52,7 +52,7 @@ public class RemoveEmptyDelete implements Rule { - private static final Pattern PATTERN = Pattern.node(TableFinishNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(TableFinishNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java index bc7d079f1cc1a..fc729758c2f2d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java +++ 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -30,7 +30,7 @@ public class RemoveFullSample implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(SampleNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java index 5b0605650aa04..a94fca2630673 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -31,7 
+31,7 @@ public class RemoveRedundantIdentityProjections implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ProjectNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java index d2e950f67017a..75d1866d51e42 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,7 +30,7 @@ public class RemoveUnreferencedScalarLateralNodes implements Rule { - private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LateralJoinNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 46d94c84adea2..062b04ed61273 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -14,13 +14,13 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -45,7 +45,7 @@ public class SimplifyCountOverConstant implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java index e841e849931a5..574abd4903bb3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import 
com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -53,7 +53,7 @@ public class SingleMarkDistinctToGroupBy implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(AggregationNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java index 460f60612200c..dff48289b73f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -32,7 +32,7 @@ public class SwapAdjacentWindowsBySpecifications implements Rule { - private static final Pattern PATTERN = Pattern.node(WindowNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(WindowNode.class); 
@Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 1586de8483a69..c1750270d6529 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -21,7 +22,6 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin; import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin; @@ -95,7 +95,7 @@ public class TransformCorrelatedInPredicateToJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(ApplyNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ApplyNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java index 9ddf7ce43f6ee..67e4ac6fa495c 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.ScalarAggregationToJoinRewriter; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -37,7 +37,7 @@ public class TransformCorrelatedScalarAggregationToJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LateralJoinNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 3b0e82e8a4230..f3f851bea365f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -14,13 +14,13 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import 
com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -59,7 +59,7 @@ public class TransformExistsApplyToLateralNode implements Rule { - private static final Pattern PATTERN = Pattern.node(ApplyNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(ApplyNode.class); private static final QualifiedName COUNT = QualifiedName.of("count"); private static final FunctionCall COUNT_CALL = new FunctionCall(COUNT, ImmutableList.of()); private final Signature countSignature; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 1bb5fb31b23cb..59b2ae73d6e46 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import 
com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -59,7 +59,7 @@ public class TransformUncorrelatedInPredicateSubqueryToSemiJoin @Override public Pattern getPattern() { - return Pattern.node(ApplyNode.class); + return Pattern.matchByClass(ApplyNode.class); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java index c3533093a5f5d..b258dab1dcc74 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -30,7 +30,7 @@ public class TransformUncorrelatedLateralToJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + private static final Pattern PATTERN = Pattern.matchByClass(LateralJoinNode.class); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index d3971ca45189a..d80e1f70f472a 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.planPrinter; import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.execution.StageInfo; import com.facebook.presto.execution.StageStats; import com.facebook.presto.metadata.Metadata; @@ -30,7 +30,6 @@ import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.FunctionInvoker; import com.facebook.presto.sql.planner.Partitioning; @@ -39,6 +38,7 @@ import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -113,7 +113,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.facebook.presto.cost.PlanNodeCost.UNKNOWN_COST; import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.DomainUtils.simplifyDomain; @@ -126,6 +125,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.isFinite; +import static 
java.lang.Double.isNaN; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -136,38 +136,36 @@ public class PlanPrinter private final Metadata metadata; private final Optional> stats; - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session sesion) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session sesion) { - this(plan, types, metadata, costCalculator, sesion, 0); + this(plan, types, metadata, lookup, sesion, 0); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session session, int indent) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); requireNonNull(metadata, "metadata is null"); - requireNonNull(costCalculator, "costCalculator is null"); + requireNonNull(lookup, "lookup is null"); this.metadata = metadata; this.stats = Optional.empty(); - Map costs = costCalculator.calculateCostForPlan(session, types, plan); - Visitor visitor = new Visitor(types, costs, session); + Visitor visitor = new Visitor(types, lookup, session); plan.accept(visitor, indent); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session session, Map stats, int indent) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); requireNonNull(metadata, "metadata is null"); - requireNonNull(costCalculator, "costCalculator is null"); + requireNonNull(lookup, "lookup is null"); this.metadata = metadata; this.stats = Optional.of(stats); - Map costs = costCalculator.calculateCostForPlan(session, types, plan); - Visitor visitor = new Visitor(types, 
costs, session); + Visitor visitor = new Visitor(types, lookup, session); plan.accept(visitor, indent); } @@ -177,22 +175,22 @@ public String toString() return output.toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session session) { - return new PlanPrinter(plan, types, metadata, costCalculator, session).toString(); + return new PlanPrinter(plan, types, metadata, lookup, session).toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session session, int indent) { - return new PlanPrinter(plan, types, metadata, costCalculator, session, indent).toString(); + return new PlanPrinter(plan, types, metadata, lookup, session, indent).toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Lookup lookup, Session session, Map stats, int indent) { - return new PlanPrinter(plan, types, metadata, costCalculator, session, stats, indent).toString(); + return new PlanPrinter(plan, types, metadata, lookup, session, stats, indent).toString(); } - public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, CostCalculator costCalculator, Session session) + public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, Lookup lookup, Session session) { StringBuilder builder = new StringBuilder(); List allStages = outputStageInfo.getSubStages().stream() @@ -200,23 +198,23 @@ public static String textDistributedPlan(StageInfo outputStageInfo, Metadata 
met .collect(toImmutableList()); for (StageInfo stageInfo : allStages) { Map aggregatedStats = aggregatePlanNodeStats(stageInfo); - builder.append(formatFragment(metadata, costCalculator, session, stageInfo.getPlan(), Optional.of(stageInfo.getStageStats()), Optional.of(aggregatedStats))); + builder.append(formatFragment(metadata, lookup, session, stageInfo.getPlan(), Optional.of(stageInfo.getStageStats()), Optional.of(aggregatedStats))); } return builder.toString(); } - public static String textDistributedPlan(SubPlan plan, Metadata metadata, CostCalculator costCalculator, Session session) + public static String textDistributedPlan(SubPlan plan, Metadata metadata, Lookup lookup, Session session) { StringBuilder builder = new StringBuilder(); for (PlanFragment fragment : plan.getAllFragments()) { - builder.append(formatFragment(metadata, costCalculator, session, fragment, Optional.empty(), Optional.empty())); + builder.append(formatFragment(metadata, lookup, session, fragment, Optional.empty(), Optional.empty())); } return builder.toString(); } - private static String formatFragment(Metadata metadata, CostCalculator costCalculator, Session session, PlanFragment fragment, Optional stageStats, Optional> planNodeStats) + private static String formatFragment(Metadata metadata, Lookup lookup, Session session, PlanFragment fragment, Optional stageStats, Optional> planNodeStats) { StringBuilder builder = new StringBuilder(); builder.append(format("Fragment %s [%s]\n", @@ -264,11 +262,11 @@ private static String formatFragment(Metadata metadata, CostCalculator costCalcu } if (stageStats.isPresent()) { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, planNodeStats.get(), 1)) + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, lookup, session, planNodeStats.get(), 1)) .append("\n"); } else { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, 
costCalculator, session, 1)) + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, lookup, session, 1)) .append("\n"); } @@ -453,14 +451,14 @@ private class Visitor extends PlanVisitor { private final Map types; - private final Map costs; + private final Lookup lookup; private final Session session; @SuppressWarnings("AssignmentToCollectionOrArrayFieldFromParameter") - public Visitor(Map types, Map costs, Session session) + public Visitor(Map types, Lookup lookup, Session session) { this.types = types; - this.costs = costs; + this.lookup = lookup; this.session = session; } @@ -468,7 +466,7 @@ public Visitor(Map types, Map costs, Ses public Void visitExplainAnalyze(ExplainAnalyzeNode node, Integer indent) { print(indent, "- ExplainAnalyze => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -495,7 +493,7 @@ public Void visitJoin(JoinNode node, Integer indent) } node.getSortExpression().ifPresent(expression -> print(indent + 2, "SortExpression[%s]", expression)); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getLeft().accept(this, indent + 1); node.getRight().accept(this, indent + 1); @@ -511,7 +509,7 @@ public Void visitSemiJoin(SemiJoinNode node, Integer indent) node.getFilteringSourceJoinSymbol(), formatHash(node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getSource().accept(this, indent + 1); node.getFilteringSource().accept(this, indent + 1); @@ -523,7 +521,7 @@ public Void visitSemiJoin(SemiJoinNode node, Integer indent) public Void visitIndexSource(IndexSourceNode node, Integer indent) { print(indent, "- 
IndexSource[%s, lookup = %s] => [%s]", node.getIndexHandle(), node.getLookupSymbols(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputSymbols().contains(entry.getKey())) { @@ -548,7 +546,7 @@ public Void visitIndexJoin(IndexJoinNode node, Integer indent) Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getProbeSource().accept(this, indent + 1); node.getIndexSource().accept(this, indent + 1); @@ -560,7 +558,7 @@ public Void visitIndexJoin(IndexJoinNode node, Integer indent) public Void visitLimit(LimitNode node, Integer indent) { print(indent, "- Limit%s[%s] => [%s]", node.isPartial() ? 
"Partial" : "", node.getCount(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -573,7 +571,7 @@ public Void visitDistinctLimit(DistinctLimitNode node, Integer indent) node.getLimit(), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -591,7 +589,7 @@ public Void visitAggregation(AggregationNode node, Integer indent) } print(indent, "- Aggregate%s%s%s => [%s]", type, key, formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry entry : node.getAggregations().entrySet()) { @@ -617,7 +615,7 @@ public Void visitGroupId(GroupIdNode node, Integer indent) .collect(Collectors.toList()); print(indent, "- GroupId%s => [%s]", inputGroupingSetSymbols, formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry mapping : node.getGroupingSetMappings().entrySet()) { @@ -639,7 +637,7 @@ public Void visitMarkDistinct(MarkDistinctNode node, Integer indent) formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -687,7 +685,7 @@ public Void visitWindow(WindowNode node, Integer indent) } print(indent, "- Window[%s]%s => [%s]", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); 
printStats(indent + 2, node.getId()); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { @@ -718,7 +716,7 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Integer indent) node.getMaxRowCountPerPartition(), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); print(indent + 2, "%s := %s", node.getRowNumberSymbol(), "row_number()"); @@ -742,7 +740,7 @@ public Void visitRowNumber(RowNumberNode node, Integer indent) Joiner.on(", ").join(args), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); print(indent + 2, "%s := %s", node.getRowNumberSymbol(), "row_number()"); @@ -754,7 +752,7 @@ public Void visitTableScan(TableScanNode node, Integer indent) { TableHandle table = node.getTable(); print(indent, "- TableScan[%s, originalConstraint = %s] => [%s]", table, node.getOriginalConstraint(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); printTableScanInfo(node, indent); @@ -765,7 +763,7 @@ public Void visitTableScan(TableScanNode node, Integer indent) public Void visitValues(ValuesNode node, Integer indent) { print(indent, "- Values => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (List row : node.getRows()) { print(indent + 2, "(" + Joiner.on(", ").join(row) + ")"); @@ -845,7 +843,7 @@ private Void visitScanFilterAndProjectInfo( format = operatorName + format; print(indent, format, arguments); - printCost(indent + 2, + printPlanNodesStatsAndCost(indent + 2, Stream.of(scanNode, filterNode, projectNode) .filter(Optional::isPresent) 
.map(Optional::get) @@ -913,7 +911,7 @@ private void printTableScanInfo(TableScanNode node, int indent) public Void visitUnnest(UnnestNode node, Integer indent) { print(indent, "- Unnest [replicate=%s, unnest=%s] => [%s]", formatOutputs(node.getReplicateSymbols()), formatOutputs(node.getUnnestSymbols().keySet()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -923,7 +921,7 @@ public Void visitUnnest(UnnestNode node, Integer indent) public Void visitOutput(OutputNode node, Integer indent) { print(indent, "- Output[%s] => [%s]", Joiner.on(", ").join(node.getColumnNames()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); @@ -942,7 +940,7 @@ public Void visitTopN(TopNNode node, Integer indent) Iterable keys = Iterables.transform(node.getOrderBy(), input -> input + " " + node.getOrderings().get(input)); print(indent, "- TopN[%s by (%s)] => [%s]", node.getCount(), Joiner.on(", ").join(keys), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -953,7 +951,7 @@ public Void visitSort(SortNode node, Integer indent) Iterable keys = Iterables.transform(node.getOrderBy(), input -> input + " " + node.getOrderings().get(input)); print(indent, "- Sort[%s] => [%s]", Joiner.on(", ").join(keys), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -962,7 +960,7 @@ public Void visitSort(SortNode node, Integer indent) public 
Void visitRemoteSource(RemoteSourceNode node, Integer indent) { print(indent, "- RemoteSource[%s] => [%s]", Joiner.on(',').join(node.getSourceFragmentIds()), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return null; @@ -972,7 +970,7 @@ public Void visitRemoteSource(RemoteSourceNode node, Integer indent) public Void visitUnion(UnionNode node, Integer indent) { print(indent, "- Union => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -982,7 +980,7 @@ public Void visitUnion(UnionNode node, Integer indent) public Void visitIntersect(IntersectNode node, Integer indent) { print(indent, "- Intersect => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -992,7 +990,7 @@ public Void visitIntersect(IntersectNode node, Integer indent) public Void visitExcept(ExceptNode node, Integer indent) { print(indent, "- Except => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1002,7 +1000,7 @@ public Void visitExcept(ExceptNode node, Integer indent) public Void visitTableWriter(TableWriterNode node, Integer indent) { print(indent, "- TableWriter => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); @@ -1017,7 +1015,7 @@ public Void visitTableWriter(TableWriterNode node, Integer indent) 
public Void visitTableFinish(TableFinishNode node, Integer indent) { print(indent, "- TableCommit[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1027,7 +1025,7 @@ public Void visitTableFinish(TableFinishNode node, Integer indent) public Void visitSample(SampleNode node, Integer indent) { print(indent, "- Sample[%s: %s] => [%s]", node.getSampleType(), node.getSampleRatio(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1052,7 +1050,7 @@ public Void visitExchange(ExchangeNode node, Integer indent) formatHash(node.getPartitioningScheme().getHashColumn()), formatOutputs(node.getOutputSymbols())); } - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1062,7 +1060,7 @@ public Void visitExchange(ExchangeNode node, Integer indent) public Void visitDelete(DeleteNode node, Integer indent) { print(indent, "- Delete[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1072,7 +1070,7 @@ public Void visitDelete(DeleteNode node, Integer indent) public Void visitMetadataDelete(MetadataDeleteNode node, Integer indent) { print(indent, "- MetadataDelete[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1082,7 +1080,7 @@ public Void visitMetadataDelete(MetadataDeleteNode 
node, Integer indent) public Void visitEnforceSingleRow(EnforceSingleRowNode node, Integer indent) { print(indent, "- Scalar => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1092,7 +1090,7 @@ public Void visitEnforceSingleRow(EnforceSingleRowNode node, Integer indent) public Void visitAssignUniqueId(AssignUniqueId node, Integer indent) { print(indent, "- AssignUniqueId => [%s]", formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1110,7 +1108,7 @@ public Void visitGroupReference(GroupReference node, Integer indent) public Void visitApply(ApplyNode node, Integer indent) { print(indent, "- Apply[%s] => [%s]", node.getCorrelation(), formatOutputs(node.getOutputSymbols())); - printCost(indent + 2, node); + printPlanNodesStatsAndCost(indent + 2, node); printStats(indent + 2, node.getId()); printAssignments(node.getSubqueryAssignments(), indent + 4); @@ -1221,32 +1219,45 @@ private String formatDomain(Domain domain) return "[" + Joiner.on(", ").join(parts.build()) + "]"; } - private void printCost(int indent, PlanNode... nodes) + private void printPlanNodesStatsAndCost(int indent, PlanNode... 
nodes) { - if (Arrays.stream(nodes).anyMatch(this::isKnownCost)) { - String costString = Joiner.on("/").join(Arrays.stream(nodes) - .map(this::formatCost) + if (Arrays.stream(nodes).anyMatch(this::isKnownPlanNodeStatsOrCost)) { + String formattedStatsAndCost = Joiner.on("/").join(Arrays.stream(nodes) + .map(this::formatPlanNodeStatsAndCost) .collect(toImmutableList())); - print(indent, "Cost: %s", costString); + print(indent, "Cost: %s", formattedStatsAndCost); } } - private boolean isKnownCost(PlanNode node) + private boolean isKnownPlanNodeStatsOrCost(PlanNode node) { - return !UNKNOWN_COST.equals(costs.getOrDefault(node.getId(), UNKNOWN_COST)); + return !PlanNodeCostEstimate.UNKNOWN_COST.equals(lookup.getCumulativeCost(node, session, types)) + || !PlanNodeStatsEstimate.UNKNOWN_STATS.equals(lookup.getStats(node, session, types)); } - private String formatCost(PlanNode node) + private String formatPlanNodeStatsAndCost(PlanNode node) { - PlanNodeCost cost = costs.getOrDefault(node.getId(), UNKNOWN_COST); - Estimate outputRowCount = cost.getOutputRowCount(); - Estimate outputSizeInBytes = cost.getOutputSizeInBytes(); - return String.format("{rows: %s, bytes: %s}", - outputRowCount.isValueUnknown() ? "?" : String.valueOf((long) outputRowCount.getValue()), - outputSizeInBytes.isValueUnknown() ? "?" : succinctBytes((long) outputSizeInBytes.getValue())); + PlanNodeStatsEstimate stats = lookup.getStats(node, session, types); + PlanNodeCostEstimate cost = lookup.getCumulativeCost(node, session, types); + return String.format("{rows: %s, bytes: %s, cpu: %s, memory: %s, network: %s}", + formatEstimate(stats.getOutputRowCount()), + formatEstimateAsDataSize(stats.getOutputSizeInBytes()), + formatEstimate(cost.getCpuCost()), + formatEstimateAsDataSize(cost.getMemoryCost()), + formatEstimate(cost.getNetworkCost())); } } + private static String formatEstimate(double value) + { + return isNaN(value) ? "?" 
: String.valueOf(value); + } + + private static String formatEstimateAsDataSize(double value) + { + return isNaN(value) ? "?" : succinctBytes((long) value).toString(); + } + private static String formatHash(Optional... hashes) { List symbols = Arrays.stream(hashes) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java index c5e482abbd1f8..975350f88603f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.rewrite; import com.facebook.presto.Session; +import com.facebook.presto.cost.DomainConverter; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.QualifiedObjectName; import com.facebook.presto.metadata.TableHandle; @@ -23,7 +24,9 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.statistics.ColumnStatistics; import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.RangeColumnStatistics; import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.QueryUtil; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.SemanticException; @@ -65,6 +68,9 @@ import java.util.TreeSet; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static com.facebook.presto.spi.statistics.RangeColumnStatistics.FRACTION_STATISTICS_KEY; +import static com.facebook.presto.spi.type.StandardTypes.DOUBLE; import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; import static com.facebook.presto.sql.QueryUtil.aliased; import static 
com.facebook.presto.sql.QueryUtil.selectAll; @@ -73,10 +79,12 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.sortedCopyOf; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; public class ShowStatsRewrite @@ -85,6 +93,12 @@ public class ShowStatsRewrite private static final List> ALLOWED_SHOW_STATS_WHERE_EXPRESSION_TYPES = ImmutableList.of( Literal.class, Identifier.class, ComparisonExpression.class, LogicalBinaryExpression.class, NotExpression.class, IsNullPredicate.class, IsNotNullPredicate.class); + private static final String COLUMN_NAME_COLUMN = "column_name"; + private static final String NULLS_FRACTION_COLUMN = "nulls_fraction"; + private static final String LOW_VALUE_COLUMN = "low_value"; + private static final String HIGH_VALUE_COLUMN = "high_value"; + private static final int MAX_LOW_HIGH_LENGTH = 32; + @Override public Statement rewrite(Session session, Metadata metadata, SqlParser parser, Optional queryExplainer, Statement node, List parameters, AccessControl accessControl) { @@ -195,16 +209,16 @@ private Node rewriteShowStats(ShowStats node, Table table, Constraint statisticsNames = findUniqueStatisticsNames(tableStatistics); - List resultColumnNames = buildColumnsNames(statisticsNames); - List selectItems = buildSelectItems(resultColumnNames); - Map columnNames = getStatisticsColumnNames(tableStatistics, node, table.getName()); - - List resultRows = buildStatisticsRows(tableStatistics, columnNames, statisticsNames); + List statsColumnNames = 
buildStatsColumnsNames(statisticsNames); + List selectItems = buildSelectItems(statsColumnNames); + Map tableColumnNames = getStatisticsColumnNames(tableStatistics, tableHandle); + Map tableColumnTypes = getStatisticsColumnTypes(tableStatistics, tableHandle); + List resultRows = buildStatisticsRows(tableStatistics, tableColumnNames, tableColumnTypes, statsColumnNames); return simpleQuery(selectAll(selectItems), aliased(new Values(resultRows), "table_stats_for_" + table.getName(), - resultColumnNames)); + statsColumnNames)); } private static void check(boolean condition, ShowStats node, String message) @@ -239,15 +253,20 @@ private Constraint getConstraint(QuerySpecification specification) return new Constraint<>(scanNode.get().getCurrentConstraint(), bindings -> true); } - private Map getStatisticsColumnNames(TableStatistics statistics, ShowStats node, QualifiedName tableName) + private Map getStatisticsColumnNames(TableStatistics statistics, TableHandle tableHandle) { - TableHandle tableHandle = getTableHandle(node, tableName); - return statistics.getColumnStatistics() .keySet().stream() .collect(toMap(identity(), column -> metadata.getColumnMetadata(session, tableHandle, column).getName())); } + private Map getStatisticsColumnTypes(TableStatistics statistics, TableHandle tableHandle) + { + return statistics.getColumnStatistics() + .keySet().stream() + .collect(toMap(identity(), column -> metadata.getColumnMetadata(session, tableHandle, column).getType())); + } + private TableHandle getTableHandle(ShowStats node, QualifiedName table) { QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, node, table); @@ -260,23 +279,23 @@ private static List findUniqueStatisticsNames(TableStatistics tableStati TreeSet statisticsKeys = new TreeSet<>(); statisticsKeys.addAll(tableStatistics.getTableStatistics().keySet()); for (ColumnStatistics columnStats : tableStatistics.getColumnStatistics().values()) { - 
statisticsKeys.addAll(columnStats.getStatistics().keySet()); + statisticsKeys.addAll(columnStats.getOnlyRangeColumnStatistics().getStatistics().keySet()); } return unmodifiableList(new ArrayList(statisticsKeys)); } - static List buildStatisticsRows(TableStatistics tableStatistics, Map columnNames, List statisticsNames) + List buildStatisticsRows(TableStatistics tableStatistics, Map sourceColumnNames, Map sourceColumnTypes, List statsColumnNames) { ImmutableList.Builder rowsBuilder = ImmutableList.builder(); // Stats for columns for (Map.Entry columnStats : tableStatistics.getColumnStatistics().entrySet()) { - Map columnStatisticsValues = columnStats.getValue().getStatistics(); - rowsBuilder.add(createStatsRow(Optional.of(columnNames.get(columnStats.getKey())), statisticsNames, columnStatisticsValues)); + ColumnHandle columnHandle = columnStats.getKey(); + rowsBuilder.add(createColumnStatsRow(sourceColumnNames.get(columnHandle), sourceColumnTypes.get(columnHandle), columnStats.getValue(), statsColumnNames)); } // Stats for whole table - rowsBuilder.add(createStatsRow(Optional.empty(), statisticsNames, tableStatistics.getTableStatistics())); + rowsBuilder.add(createTableStatsRow(statsColumnNames, tableStatistics)); return rowsBuilder.build(); } @@ -286,33 +305,99 @@ static List buildSelectItems(List columnNames) return columnNames.stream().map(QueryUtil::unaliasedName).collect(toImmutableList()); } - static List buildColumnsNames(List statisticsNames) + static List buildStatsColumnsNames(List statisticsNames) { ImmutableList.Builder columnNamesBuilder = ImmutableList.builder(); - columnNamesBuilder.add("column_name"); - columnNamesBuilder.addAll(statisticsNames); + columnNamesBuilder.add(COLUMN_NAME_COLUMN); + columnNamesBuilder.addAll(sortedCopyOf( + ImmutableList.builder() + .addAll(statisticsNames + .stream() + // we do not want to include "fraction" in show stats output if we do not have histograms + .filter(name -> !name.equals(FRACTION_STATISTICS_KEY)) + 
.collect(toList())) + .add(NULLS_FRACTION_COLUMN) + .build())); + columnNamesBuilder.add(LOW_VALUE_COLUMN); + columnNamesBuilder.add(HIGH_VALUE_COLUMN); return columnNamesBuilder.build(); } - private static Row createStatsRow(Optional columnName, List statisticsNames, Map columnStatisticsValues) + private Row createColumnStatsRow(String columnName, Type columnType, ColumnStatistics columnStatistics, List statsColumnNames) { ImmutableList.Builder rowValues = ImmutableList.builder(); - Expression columnNameExpression = columnName.map(name -> (Expression) new StringLiteral(name)).orElse(new Cast(new NullLiteral(), VARCHAR)); + RangeColumnStatistics rangeStatistics = columnStatistics.getOnlyRangeColumnStatistics(); + Map statisticsValues = rangeStatistics.getStatistics(); + for (String statColumnName : statsColumnNames) { + switch (statColumnName) { + case COLUMN_NAME_COLUMN: + rowValues.add(new StringLiteral(columnName)); + break; + case LOW_VALUE_COLUMN: + rowValues.add((lowHighAsLiteral(columnType, rangeStatistics.getLowValue()))); + break; + case HIGH_VALUE_COLUMN: + rowValues.add(lowHighAsLiteral(columnType, rangeStatistics.getHighValue())); + break; + case NULLS_FRACTION_COLUMN: + rowValues.add(createStatisticValueOrNull(columnStatistics.getNullsFraction())); + break; + default: + rowValues.add(createStatisticValueOrNull(statisticsValues, statColumnName)); + break; + } + } + return new Row(rowValues.build()); + } - rowValues.add(columnNameExpression); - for (String statName : statisticsNames) { - rowValues.add(createStatisticValueOrNull(columnStatisticsValues, statName)); + private Expression lowHighAsLiteral(Type valueType, Optional value) + { + if (!value.isPresent()) { + return new Cast(new NullLiteral(), VARCHAR); + } + + DomainConverter domainConverter = new DomainConverter(valueType, metadata.getFunctionRegistry(), session.toConnectorSession()); + String stringValue = domainConverter.castToVarchar(value.get()).toStringUtf8(); + if (stringValue.length() > 
MAX_LOW_HIGH_LENGTH) { + stringValue = stringValue.substring(0, MAX_LOW_HIGH_LENGTH) + "..."; + } + return new StringLiteral(stringValue); + } + + private static Expression createTableStatsRow(List columnNames, TableStatistics tableStatistics) + { + ImmutableList.Builder rowValues = ImmutableList.builder(); + Map statisticsValues = tableStatistics.getTableStatistics(); + for (String columnName : columnNames) { + switch (columnName) { + case COLUMN_NAME_COLUMN: + case LOW_VALUE_COLUMN: + case HIGH_VALUE_COLUMN: + rowValues.add(new Cast(new NullLiteral(), VARCHAR)); + break; + case NULLS_FRACTION_COLUMN: + rowValues.add(new Cast(new NullLiteral(), DOUBLE)); + break; + default: + rowValues.add(createStatisticValueOrNull(statisticsValues, columnName)); + break; + } } return new Row(rowValues.build()); } private static Expression createStatisticValueOrNull(Map columnStatisticsValues, String statName) { - if (columnStatisticsValues.containsKey(statName) && !columnStatisticsValues.get(statName).isValueUnknown()) { - return new DoubleLiteral(Double.toString(columnStatisticsValues.get(statName).getValue())); + return createStatisticValueOrNull(columnStatisticsValues.getOrDefault(statName, unknownValue())); + } + + private static Expression createStatisticValueOrNull(Estimate estimate) + { + if (!estimate.isValueUnknown()) { + return new DoubleLiteral(Double.toString(estimate.getValue())); } else { - return new NullLiteral(); + return new Cast(new NullLiteral(), DOUBLE); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 300e072d46c18..d938c37b8e7fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -29,8 +29,15 @@ import com.facebook.presto.connector.system.SchemaPropertiesSystemTable; import 
com.facebook.presto.connector.system.TablePropertiesSystemTable; import com.facebook.presto.connector.system.TransactionsSystemTable; -import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CoefficientBasedStatsCalculator; import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostCalculatorUsingExchanges; +import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.FilterStatsCalculator; +import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.SelectingStatsCalculator; +import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.execution.CommitTask; import com.facebook.presto.execution.CreateTableTask; import com.facebook.presto.execution.CreateViewTask; @@ -85,6 +92,7 @@ import com.facebook.presto.operator.project.InterpretedPageProjection; import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.operator.project.PageProjection; +import com.facebook.presto.server.ServerMainModule; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorPageSource; @@ -123,6 +131,8 @@ import com.facebook.presto.sql.planner.PlanOptimizers; import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.StatelessLookup; import com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -211,7 +221,6 @@ public class LocalQueryRunner private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final MetadataManager metadata; - private final CostCalculator costCalculator; 
private final TestingAccessControlManager accessControl; private final SplitManager splitManager; private final BlockEncodingSerde blockEncodingSerde; @@ -221,6 +230,10 @@ public class LocalQueryRunner private final PageSinkManager pageSinkManager; private final TransactionManager transactionManager; private final SpillerFactory spillerFactory; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; + private final CostCalculator estimatedExchangesCostCalculator; + private final Lookup lookup; private final ExpressionCompiler expressionCompiler; private final JoinFilterFunctionCompiler joinFilterFunctionCompiler; @@ -286,7 +299,6 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new SchemaPropertyManager(), new TablePropertyManager(), transactionManager); - this.costCalculator = new CoefficientBasedCostCalculator(metadata); this.accessControl = new TestingAccessControlManager(transactionManager); this.pageSourceManager = new PageSourceManager(); @@ -363,6 +375,12 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, SpillerStats spillerStats = new SpillerStats(); this.spillerFactory = new GenericSpillerFactory(new FileSingleStreamSpillerFactory(blockEncodingSerde, spillerStats, featuresConfig)); + this.statsCalculator = new SelectingStatsCalculator( + new CoefficientBasedStatsCalculator(metadata), + ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata))); + this.costCalculator = new CostCalculatorUsingExchanges(getNodeCount()); + this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, getNodeCount()); + this.lookup = new StatelessLookup(statsCalculator, costCalculator); } public static LocalQueryRunner queryRunnerWithInitialTransaction(Session defaultSession) @@ -404,11 +422,26 @@ public Metadata getMetadata() } @Override + public Lookup getLookup() + { + 
return lookup; + } + + public StatsCalculator getStatsCalculator() + { + return statsCalculator; + } + public CostCalculator getCostCalculator() { return costCalculator; } + public CostCalculator getEstimatedExchangesCostCalculator() + { + return estimatedExchangesCostCalculator; + } + @Override public TestingAccessControlManager getAccessControl() { @@ -566,7 +599,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, Plan plan = createPlan(session, sql); if (printPlan) { - System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, costCalculator, session)); + System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, lookup, session)); } SubPlan subplan = PlanFragmenter.createSubPlans(session, metadata, plan); @@ -577,6 +610,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, sqlParser, + statsCalculator, costCalculator, Optional.empty(), pageSourceManager, @@ -680,7 +714,16 @@ public List getPlanOptimizers(boolean forceSingleNode) FeaturesConfig featuresConfig = new FeaturesConfig() .setDistributedIndexJoinsEnabled(false) .setOptimizeHashGeneration(true); - return new PlanOptimizers(metadata, sqlParser, featuresConfig, forceSingleNode, new MBeanExporter(new TestingMBeanServer())).get(); + return new PlanOptimizers( + metadata, + sqlParser, + featuresConfig, + forceSingleNode, + new MBeanExporter(new TestingMBeanServer()), + new CostComparator(featuresConfig), + statsCalculator, + costCalculator, + estimatedExchangesCostCalculator).get(); } public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers) @@ -708,11 +751,11 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List List> listOfListsCopy(List> lists) .collect(toImmutableList()); } + public static List asList(Optional optional) + { + return 
optional.map(ImmutableList::of).orElseGet(ImmutableList::of); + } + private MoreLists() {} } diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java new file mode 100644 index 0000000000000..917532d768e69 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.util; + +import java.util.stream.DoubleStream; + +import static java.lang.Double.isNaN; + +public final class MoreMath +{ + private MoreMath() + { + } + + /** + * See http://floating-point-gui.de/errors/comparison/ + */ + public static boolean nearlyEqual(double a, double b, double epsilon) + { + double absA = Math.abs(a); + double absB = Math.abs(b); + double diff = Math.abs(a - b); + + if (a == b) { // shortcut, handles infinities + return true; + } + else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) { + // a or b is zero or both are extremely close to it + // relative error is less meaningful here + return diff < (epsilon * Double.MIN_NORMAL); + } + else { // use relative error + return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon; + } + } + + /** + * See http://floating-point-gui.de/errors/comparison/ + */ + public static boolean nearlyEqual(float a, float b, float epsilon) + { + float absA = Math.abs(a); + float absB = Math.abs(b); + float diff = Math.abs(a - b); 
+ + if (a == b) { // shortcut, handles infinities + return true; + } + else if (a == 0 || b == 0 || diff < Float.MIN_NORMAL) { + // a or b is zero or both are extremely close to it + // relative error is less meaningful here + return diff < (epsilon * Float.MIN_NORMAL); + } + else { // use relative error + return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; + } + } + + public static double min(double... values) + { + return DoubleStream.of(values) + .min() + .getAsDouble(); + } + + public static double max(double... values) + { + return DoubleStream.of(values) + .max() + .getAsDouble(); + } + + public static double rangeMin(double left, double right) + { + if (isNaN(left)) { + return right; + } + else if (isNaN(right)) { + return left; + } + return min(left, right); + } + + public static double rangeMax(double left, double right) + { + if (isNaN(left)) { + return right; + } + else if (isNaN(right)) { + return left; + } + return max(left, right); + } + + public static boolean isPositiveOrNan(double value) + { + return value > 0 || isNaN(value); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/EstimateAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/EstimateAssertion.java new file mode 100644 index 0000000000000..62f3e2d238716 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/EstimateAssertion.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.util.MoreMath; + +import static java.lang.Double.isNaN; + +public final class EstimateAssertion +{ + private static final double TOLERANCE = 0.0000001; + + private EstimateAssertion() + { + } + + public static void assertEstimateEquals(double actual, double expected, String messageFormat, Object... messageObjects) + { + if (isNaN(actual) && isNaN(expected)) { + return; + } + + if (!MoreMath.nearlyEqual(actual, expected, TOLERANCE)) { + throw new AssertionError(String.format(messageFormat, messageObjects) + String.format(", expected [%f], but got [%f]", expected, actual)); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java new file mode 100644 index 0000000000000..c7e1d2b6daf11 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/PlanNodeStatsAssertion.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; + +import java.util.function.Consumer; + +import static com.facebook.presto.cost.EstimateAssertion.assertEstimateEquals; +import static com.google.common.collect.Sets.union; +import static org.testng.Assert.assertTrue; + +public class PlanNodeStatsAssertion +{ + private final PlanNodeStatsEstimate actual; + + private PlanNodeStatsAssertion(PlanNodeStatsEstimate actual) + { + this.actual = actual; + } + + public static PlanNodeStatsAssertion assertThat(PlanNodeStatsEstimate actual) + { + return new PlanNodeStatsAssertion(actual); + } + + public PlanNodeStatsAssertion outputRowsCount(double expected) + { + assertEstimateEquals(actual.getOutputRowCount(), expected, "outputRowsCount mismatch"); + return this; + } + + public PlanNodeStatsAssertion outputRowsCountUnknown() + { + assertTrue(Double.isNaN(actual.getOutputRowCount()), "expected unknown outputRowsCount but got " + actual.getOutputRowCount()); + return this; + } + + public PlanNodeStatsAssertion symbolStats(String symbolName, Consumer symbolStatsAssertionConsumer) + { + return symbolStats(new Symbol(symbolName), symbolStatsAssertionConsumer); + } + + public PlanNodeStatsAssertion symbolStats(Symbol symbol, Consumer columnAssertionConsumer) + { + SymbolStatsAssertion columnAssertion = SymbolStatsAssertion.assertThat(actual.getSymbolStatistics(symbol)); + columnAssertionConsumer.accept(columnAssertion); + return this; + } + + public PlanNodeStatsAssertion symbolStatsUnknown(String symbolName) + { + return symbolStatsUnknown(new Symbol(symbolName)); + } + + public PlanNodeStatsAssertion symbolStatsUnknown(Symbol symbol) + { + return symbolStats(symbol, + columnStats -> columnStats + .lowValueUnknown() + .highValueUnknown() + .nullsFractionUnknown() + .distinctValuesCountUnknown()); + } + + public PlanNodeStatsAssertion equalTo(PlanNodeStatsEstimate expected) + { + assertEstimateEquals(actual.getOutputRowCount(), 
expected.getOutputRowCount(), "outputRowCount mismatch"); + + for (Symbol symbol : union(expected.getSymbolsWithKnownStatistics(), actual.getSymbolsWithKnownStatistics())) { + assertSymbolStatsEqual(symbol, actual.getSymbolStatistics(symbol), expected.getSymbolStatistics(symbol)); + } + return this; + } + + private void assertSymbolStatsEqual(Symbol symbol, SymbolStatsEstimate actual, SymbolStatsEstimate expected) + { + assertEstimateEquals(actual.getNullsFraction(), expected.getNullsFraction(), "nullsFraction mismatch for " + symbol.getName()); + assertEstimateEquals(actual.getLowValue(), expected.getLowValue(), "lowValue mismatch for " + symbol.getName()); + assertEstimateEquals(actual.getHighValue(), expected.getHighValue(), "highValue mismatch for " + symbol.getName()); + assertEstimateEquals(actual.getDistinctValuesCount(), expected.getDistinctValuesCount(), "distinct values count mismatch for " + symbol.getName()); + assertEstimateEquals(actual.getAverageRowSize(), expected.getAverageRowSize(), "average row size mismatch for " + symbol.getName()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorAssertion.java new file mode 100644 index 0000000000000..94bc8362cb60d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorAssertion.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import static com.facebook.presto.cost.PlanNodeCostEstimate.INFINITE_COST; +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class StatsCalculatorAssertion +{ + private final StatsCalculator statsCalculator; + private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + private final Metadata metadata; + private final Session session; + private PlanNode planNode; + private Map sourcesStats; + private Map types; + + public StatsCalculatorAssertion(StatsCalculator statsCalculator, Metadata metadata, Session session) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator can not be null"); + this.metadata = requireNonNull(metadata, "metadata can not be null"); + this.session = requireNonNull(session, "sesssion can not be null"); + } + + public StatsCalculatorAssertion on(Function planProvider) + { + PlanBuilder planBuilder = new PlanBuilder(idAllocator, metadata); + this.planNode = planProvider.apply(planBuilder); + this.sourcesStats = new HashMap<>(); + this.planNode.getSources().forEach(child -> sourcesStats.put(child, UNKNOWN_STATS)); + this.types = planBuilder.getSymbols(); + return this; + } + + public StatsCalculatorAssertion 
withSourceStats(Consumer sourceStatsBuilderConsumer) + { + checkPlanNodeSet(); + checkState(planNode.getSources().size() == 1, "expected single source"); + return withSourceStats(0, sourceStatsBuilderConsumer); + } + + public StatsCalculatorAssertion withSourceStats(PlanNodeStatsEstimate sourceStats) + { + checkPlanNodeSet(); + checkState(planNode.getSources().size() == 1, "expected single source"); + return withSourceStats(0, sourceStats); + } + + public StatsCalculatorAssertion withSourceStats(int sourceIndex, Consumer sourceStatsBuilderConsumer) + { + PlanNodeStatsEstimate.Builder sourceStatsBuilder = PlanNodeStatsEstimate.builder(); + sourceStatsBuilderConsumer.accept(sourceStatsBuilder); + return withSourceStats(sourceIndex, sourceStatsBuilder.build()); + } + + public StatsCalculatorAssertion withSourceStats(int sourceIndex, PlanNodeStatsEstimate sourceStats) + { + checkPlanNodeSet(); + checkArgument(sourceIndex < planNode.getSources().size(), "invalid sourceIndex %s; planNode has %s sources", sourceIndex, planNode.getSources().size()); + sourcesStats.put(planNode.getSources().get(sourceIndex), sourceStats); + return this; + } + + public StatsCalculatorAssertion check(Consumer statisticsAssertionConsumer) + { + PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, mockLookup(), session, types); + statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate)); + return this; + } + + private void checkPlanNodeSet() + { + checkState(planNode != null, "tested planNode not set yet"); + } + + private Lookup mockLookup() + { + return new Lookup() + { + @Override + public PlanNode resolve(PlanNode node) + { + throw new UnsupportedOperationException(); + } + + @Override + public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) + { + checkArgument(sourcesStats.containsKey(node), "stats not found for source %s", node); + return sourcesStats.get(node); + } + + @Override + public PlanNodeCostEstimate 
getCumulativeCost(PlanNode node, Session session, Map types) + { + return INFINITE_COST; + } + }; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java new file mode 100644 index 0000000000000..189b0b54df150 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/StatsCalculatorTester.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; + +import java.util.function.Function; + +import static com.facebook.presto.SystemSessionProperties.USE_NEW_STATS_CALCULATOR; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class StatsCalculatorTester + implements AutoCloseable +{ + private final StatsCalculator statsCalculator; + private final Metadata metadata; + private final Session session; + private final LocalQueryRunner queryRunner; + + public StatsCalculatorTester() + { + this(createQueryRunner()); + } + + private StatsCalculatorTester(LocalQueryRunner queryRunner) + { + this.statsCalculator = queryRunner.getStatsCalculator(); + this.session = queryRunner.getDefaultSession(); + this.metadata = queryRunner.getMetadata(); + this.queryRunner = queryRunner; + } + + private static LocalQueryRunner createQueryRunner() + { + Session session = testSessionBuilder() + .setSystemProperty(USE_NEW_STATS_CALCULATOR, "true") + .build(); + + LocalQueryRunner queryRunner = new LocalQueryRunner(session); + queryRunner.createCatalog(session.getCatalog().get(), + new TpchConnectorFactory(1), + ImmutableMap.of()); + return queryRunner; + } + + public StatsCalculatorAssertion assertStatsFor(Function planProvider) + { + return new StatsCalculatorAssertion(statsCalculator, metadata, session).on(planProvider); + } + + @Override + public void close() + { + queryRunner.close(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java b/presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java new file mode 100644 index 
0000000000000..0856a9ad94d54 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/SymbolStatsAssertion.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import static com.facebook.presto.cost.EstimateAssertion.assertEstimateEquals; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; +import static java.lang.Double.isNaN; +import static org.testng.Assert.assertTrue; + +public class SymbolStatsAssertion +{ + private final SymbolStatsEstimate statistics; + + private SymbolStatsAssertion(SymbolStatsEstimate statistics) + { + this.statistics = statistics; + } + + public static SymbolStatsAssertion assertThat(SymbolStatsEstimate actual) + { + return new SymbolStatsAssertion(actual); + } + + public SymbolStatsAssertion nullsFraction(double expected) + { + // we bind nullsFraction and nonNullsFraction together + assertEstimateEquals(statistics.getNullsFraction(), expected, "nullsFraction mismatch"); + return this; + } + + public SymbolStatsAssertion nullsFractionUnknown() + { + // we bind nullsFraction and nonNullsFraction together + assertTrue(isNaN(statistics.getNullsFraction()), "expected unknown nullsFraction but got " + statistics.getNullsFraction()); + return this; + } + + public SymbolStatsAssertion lowValue(double expected) + { + assertEstimateEquals(statistics.getLowValue(), expected, "lowValue mismatch"); + return this; + } + 
+ public SymbolStatsAssertion lowValueUnknown() + { + return lowValue(NEGATIVE_INFINITY); + } + + public SymbolStatsAssertion highValue(double expected) + { + assertEstimateEquals(statistics.getHighValue(), expected, "highValue mismatch"); + return this; + } + + public SymbolStatsAssertion highValueUnknown() + { + return highValue(POSITIVE_INFINITY); + } + + public SymbolStatsAssertion emptyRange() + { + assertTrue(isNaN(statistics.getLowValue()) && isNaN(statistics.getHighValue()), + "expected empty range (NaN, NaN) but got (" + statistics.getLowValue() + ", " + statistics.getHighValue() + ") instead"); + return this; + } + + public SymbolStatsAssertion unknownRange() + { + return lowValueUnknown() + .highValueUnknown(); + } + + public SymbolStatsAssertion distinctValuesCount(double expected) + { + assertEstimateEquals(statistics.getDistinctValuesCount(), expected, "distinctValuesCount mismatch"); + return this; + } + + public SymbolStatsAssertion distinctValuesCountUnknown() + { + assertTrue(isNaN(statistics.getDistinctValuesCount()), "expected unknown distinctValuesCount but got " + statistics.getDistinctValuesCount()); + return this; + } + + public SymbolStatsAssertion averageRowSize(double expected) + { + assertEstimateEquals(statistics.getAverageRowSize(), expected, "average row size mismatch"); + return this; + } + + public SymbolStatsAssertion dataSizeUnknown() + { + assertTrue(isNaN(statistics.getAverageRowSize()), "expected unknown dataSize but got " + statistics.getAverageRowSize()); + return this; + } + + public SymbolStatsAssertion isEqualTo(SymbolStatsEstimate expected) + { + return nullsFraction(expected.getNullsFraction()) + .lowValue(expected.getLowValue()) + .highValue(expected.getHighValue()) + .distinctValuesCount(expected.getDistinctValuesCount()) + .averageRowSize(expected.getAverageRowSize()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java 
b/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java new file mode 100644 index 0000000000000..96d40555d7716 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestAggregationStatsRule.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.function.Consumer; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; + +public class TestAggregationStatsRule +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testAggregationWhenAllStatisticsAreKnown() + { + Consumer outputRowCountAndZStatsAreCalculated = check -> check + .outputRowsCount(15) + .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion.lowValue(10) + .highValue(15) + .distinctValuesCount(4) + .nullsFraction(0.2)); + testAggregation(SymbolStatsEstimate.builder().setLowValue(10).setHighValue(15).setDistinctValuesCount(4).setNullsFraction(0.1).build()) + 
.check(outputRowCountAndZStatsAreCalculated); + testAggregation(SymbolStatsEstimate.builder().setLowValue(10).setHighValue(15).setDistinctValuesCount(4).build()) + .check(outputRowCountAndZStatsAreCalculated); + + Consumer outputRowsCountAndZStatsAreNotFullyCalculated = check -> check + .outputRowsCountUnknown() + .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion.lowValue(10) + .highValue(15) + .distinctValuesCountUnknown() + .nullsFractionUnknown()); + testAggregation(SymbolStatsEstimate.builder().setLowValue(10).setHighValue(15).setNullsFraction(0.1).build()) + .check(outputRowsCountAndZStatsAreNotFullyCalculated); + testAggregation(SymbolStatsEstimate.builder().setLowValue(10).setHighValue(15).build()) + .check(outputRowsCountAndZStatsAreNotFullyCalculated); + } + + private StatsCalculatorAssertion testAggregation(SymbolStatsEstimate zStats) + { + return tester.assertStatsFor(pb -> pb + .aggregation(ab -> ab + .addAggregation(pb.symbol("sum", BIGINT), expression("sum(x)"), ImmutableList.of(BIGINT)) + .addAggregation(pb.symbol("count", BIGINT), expression("count()"), ImmutableList.of()) + .addAggregation(pb.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT)) + .addGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) + .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) + .withSourceStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setDistinctValuesCount(3) + .setNullsFraction(0) + .build()) + .addSymbolStatistics(new Symbol("z"), zStats) + .build()) + .check(check -> check + .symbolStats("sum", symbolStatsAssertion -> symbolStatsAssertion.lowValueUnknown() + .highValueUnknown() + 
.distinctValuesCountUnknown() + .nullsFractionUnknown()) + .symbolStats("count", symbolStatsAssertion -> symbolStatsAssertion.lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown()) + .symbolStats("count_on_x", symbolStatsAssertion -> symbolStatsAssertion.lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown()) + .symbolStats("x", symbolStatsAssertion -> symbolStatsAssertion.lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown()) + .symbolStats("y", symbolStatsAssertion -> symbolStatsAssertion.lowValue(0).highValue(3).distinctValuesCount(3).nullsFraction(0))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToOutputRowsCount.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToOutputRowsCount.java new file mode 100644 index 0000000000000..3c13c55888dbf --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToOutputRowsCount.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.cost.SymbolStatsAssertion.assertThat; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; + +public class TestCapDistinctValuesCountToOutputRowsCount +{ + private static final ValuesNode NODE = new ValuesNode(new PlanNodeId("1"), emptyList(), emptyList()); + private static final Map TYPES = emptyMap(); + private static final Symbol A = new Symbol("a"); + private static final Symbol B = new Symbol("b"); + private static final Symbol C = new Symbol("c"); + + @Test + public void tesOutputRowCountIsKnown() + { + PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(A, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) + .addSymbolStatistics(B, SymbolStatsEstimate.builder().setDistinctValuesCount(5).build()) + .addSymbolStatistics(C, SymbolStatsEstimate.builder().build()) + .build(); + + assertThat(normalize(estimate).getSymbolStatistics(A)).distinctValuesCount(10); + assertThat(normalize(estimate).getSymbolStatistics(B)).distinctValuesCount(5); + assertThat(normalize(estimate).getSymbolStatistics(C)).distinctValuesCountUnknown(); + } + + @Test + public void testOutputRowCountIsNotKnown() + { + PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(A, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) + .build(); + + assertThat(normalize(estimate).getSymbolStatistics(A)).distinctValuesCount(20); + } + + private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate estimate) + { + CapDistinctValuesCountToOutputRowsCount normalizer = new 
CapDistinctValuesCountToOutputRowsCount(); + return normalizer.normalize(NODE, estimate, TYPES); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToTypeDomainRangeLength.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToTypeDomainRangeLength.java new file mode 100644 index 0000000000000..f5ef70c81c457 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCapDistinctValuesCountToTypeDomainRangeLength.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.testing.TestingConnectorSession; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.cost.SymbolStatsAssertion.assertThat; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.SmallintType.SMALLINT; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; +import static java.util.Collections.emptyList; + +public class TestCapDistinctValuesCountToTypeDomainRangeLength +{ + private final TypeManager typeManager = new TypeRegistry(); + private final FunctionRegistry functionRegistry = new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + private final ConnectorSession session = new TestingConnectorSession(emptyList()); + + @Test + public void test() + { + Symbol bool = new Symbol("bool"); + Symbol bool2 = new Symbol("bool2"); + Symbol tinyint = new Symbol("tinyint"); + Symbol smallint = new Symbol("smallint"); + Symbol integer = new Symbol("integer"); + Symbol integer2 = new Symbol("integer2"); + Symbol integer3 = new Symbol("integer3"); + Symbol bigint = new Symbol("bigint"); + Symbol decimal = new Symbol("decimal"); + Symbol decimal2 = new Symbol("decimal2"); + 
Symbol decimal3 = new Symbol("decimal3"); + Symbol double1 = new Symbol("double1"); + Symbol double2 = new Symbol("double2"); + + DecimalType decimalType = DecimalType.createDecimalType(10, 2); + PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(bool, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(true, BOOLEAN)) + .setHighValue(asStatsValue(true, BOOLEAN)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(bool2, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(false, BOOLEAN)) + .setHighValue(asStatsValue(true, BOOLEAN)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(tinyint, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, TINYINT)) + .setHighValue(asStatsValue(5, TINYINT)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(smallint, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, SMALLINT)) + .setHighValue(asStatsValue(5, SMALLINT)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(integer, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, INTEGER)) + .setHighValue(asStatsValue(5, INTEGER)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(integer2, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, INTEGER)) + .setHighValue(asStatsValue(5, INTEGER)) + .setDistinctValuesCount(3) + .build()) + .addSymbolStatistics(integer3, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, INTEGER)) + .setHighValue(asStatsValue(5, INTEGER)) + .build()) + .addSymbolStatistics(bigint, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, BIGINT)) + .setHighValue(asStatsValue(5, BIGINT)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(decimal, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(1, decimalType)) + .setHighValue(asStatsValue(1, decimalType)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(decimal2, SymbolStatsEstimate.builder() + 
.setLowValue(asStatsValue(101, decimalType)) + .setHighValue(asStatsValue(103, decimalType)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(decimal3, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(100, decimalType)) + .setHighValue(asStatsValue(200, decimalType)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(double1, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(10.1, DOUBLE)) + .setHighValue(asStatsValue(10.2, DOUBLE)) + .setDistinctValuesCount(10) + .build()) + .addSymbolStatistics(double2, SymbolStatsEstimate.builder() + .setLowValue(asStatsValue(10.0, DOUBLE)) + .setHighValue(asStatsValue(10.0, DOUBLE)) + .setDistinctValuesCount(10) + .build()) + .build(); + + Map types = ImmutableMap.builder() + .put(bool, BOOLEAN) + .put(bool2, BOOLEAN) + .put(tinyint, TINYINT) + .put(smallint, SMALLINT) + .put(integer, INTEGER) + .put(integer2, INTEGER) + .put(integer3, INTEGER) + .put(bigint, BIGINT) + .put(decimal, decimalType) + .put(decimal2, decimalType) + .put(decimal3, decimalType) + .put(double1, DOUBLE) + .put(double2, DOUBLE) + .build(); + + ComposableStatsCalculator.Normalizer normalizer = new CapDistinctValuesCountToTypeDomainRangeLength(); + PlanNodeStatsEstimate normalized = normalizer.normalize(null, estimate, types); + + assertThat(normalized.getSymbolStatistics(bool)).distinctValuesCount(1); + assertThat(normalized.getSymbolStatistics(bool2)).distinctValuesCount(2); + assertThat(normalized.getSymbolStatistics(tinyint)).distinctValuesCount(5); + assertThat(normalized.getSymbolStatistics(smallint)).distinctValuesCount(5); + assertThat(normalized.getSymbolStatistics(smallint)).distinctValuesCount(5); + assertThat(normalized.getSymbolStatistics(integer)).distinctValuesCount(5); + assertThat(normalized.getSymbolStatistics(integer2)).distinctValuesCount(3); + assertThat(normalized.getSymbolStatistics(integer3)).distinctValuesCountUnknown(); + 
assertThat(normalized.getSymbolStatistics(bigint)).distinctValuesCount(5); + assertThat(normalized.getSymbolStatistics(decimal)).distinctValuesCount(1); + assertThat(normalized.getSymbolStatistics(decimal2)).distinctValuesCount(3); + assertThat(normalized.getSymbolStatistics(decimal3)).distinctValuesCount(10); + assertThat(normalized.getSymbolStatistics(double1)).distinctValuesCount(10); + assertThat(normalized.getSymbolStatistics(double2)).distinctValuesCount(1); + } + + private double asStatsValue(Object value, Type type) + { + return new DomainConverter(type, functionRegistry, session).translateToDouble(value).orElse(Double.NaN); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java new file mode 100644 index 0000000000000..4492b3e34ac78 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java @@ -0,0 +1,469 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.NOT_EQUAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; + +@Test(singleThreaded = true) +public class TestComparisonStatsCalculator +{ + private FilterStatsCalculator filterStatsCalculator; + private Session session; + private PlanNodeStatsEstimate standardInputStatistics; + private Map types; + + @BeforeMethod + public void setUp() + throws Exception + { + session = testSessionBuilder().build(); + filterStatsCalculator = new FilterStatsCalculator(MetadataManager.createTestMetadataManager()); + + SymbolStatsEstimate xStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(40.0) + .setLowValue(-10.0) + .setHighValue(10.0) + .setNullsFraction(0.25) + .build(); + SymbolStatsEstimate yStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(20.0) + .setLowValue(0.0) + .setHighValue(5.0) + 
.setNullsFraction(0.5) + .build(); + SymbolStatsEstimate zStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(5.0) + .setLowValue(-100.0) + .setHighValue(100.0) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate leftOpenStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(15.0) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate rightOpenStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(-15.0) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate unknownRangeStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate emptyRangeStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(0.0) + .setLowValue(NaN) + .setHighValue(NaN) + .setNullsFraction(NaN) + .build(); + standardInputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), xStats) + .addSymbolStatistics(new Symbol("y"), yStats) + .addSymbolStatistics(new Symbol("z"), zStats) + .addSymbolStatistics(new Symbol("leftOpen"), leftOpenStats) + .addSymbolStatistics(new Symbol("rightOpen"), rightOpenStats) + .addSymbolStatistics(new Symbol("unknownRange"), unknownRangeStats) + .addSymbolStatistics(new Symbol("emptyRange"), emptyRangeStats) + .setOutputRowCount(1000.0) + .build(); + + types = ImmutableMap.builder() + .put(new Symbol("x"), DoubleType.DOUBLE) + .put(new Symbol("y"), DoubleType.DOUBLE) + .put(new Symbol("z"), DoubleType.DOUBLE) + .put(new Symbol("leftOpen"), DoubleType.DOUBLE) + .put(new Symbol("rightOpen"), DoubleType.DOUBLE) + .put(new Symbol("unknownRange"), DoubleType.DOUBLE) + .put(new Symbol("emptyRange"), DoubleType.DOUBLE) + 
.build(); + } + + private PlanNodeStatsAssertion assertCalculate(Expression comparisonExpression) + { + return PlanNodeStatsAssertion.assertThat(filterStatsCalculator.filterStats(standardInputStatistics, comparisonExpression, session, types)); + } + + @Test + public void symbolToLiteralEqualStats() + { + // Simple case + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("y"), new DoubleLiteral("2.5"))) + .outputRowsCount(25.0) // all rows minus nulls divided by distinct values count + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(2.5) + .highValue(2.5) + .nullsFraction(0.0); + }); + + // Literal on the edge of symbol range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("x"), new DoubleLiteral("10.0"))) + .outputRowsCount(18.75) // all rows minus nulls divided by distinct values count + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Literal out of symbol range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("y"), new DoubleLiteral("10.0"))) + .outputRowsCount(0.0) // all rows minus nulls divided by distinct values count + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + + // Literal in left open range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("leftOpen"), new DoubleLiteral("2.5"))) + .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count + .symbolStats("leftOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(2.5) + .highValue(2.5) + .nullsFraction(0.0); + }); + + // Literal in right open range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5"))) + 
.outputRowsCount(18.0) // all rows minus nulls divided by distinct values count + .symbolStats("rightOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(-2.5) + .highValue(-2.5) + .nullsFraction(0.0); + }); + + // Literal in unknown range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(18.0) // all rows minus nulls divided by distinct values count + .symbolStats("unknownRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(0.0) + .highValue(0.0) + .nullsFraction(0.0); + }); + + // Literal in empty range + assertCalculate(new ComparisonExpression(EQUAL, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(0.0) + .symbolStats("emptyRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void symbolToLiteralNotEqualStats() + { + // Simple case + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("y"), new DoubleLiteral("2.5"))) + .outputRowsCount(475.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(19.0) + .lowValue(0.0) + .highValue(5.0) + .nullsFraction(0.0); + }); + + // Literal on the edge of symbol range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("x"), new DoubleLiteral("10.0"))) + .outputRowsCount(731.25) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(39.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Literal out of symbol range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("y"), new 
DoubleLiteral("10.0"))) + .outputRowsCount(500.0) // all rows minus nulls + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(19.0) + .lowValue(0.0) + .highValue(5.0) + .nullsFraction(0.0); + }); + + // Literal in left open range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("leftOpen"), new DoubleLiteral("2.5"))) + .outputRowsCount(882.0) // all rows minus nulls multiplied by ((distinct values - 1) / distinct values) + .symbolStats("leftOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(49.0) + .lowValueUnknown() + .highValue(15.0) + .nullsFraction(0.0); + }); + + // Literal in right open range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5"))) + .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count + .symbolStats("rightOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(49.0) + .lowValue(-15.0) + .highValueUnknown() + .nullsFraction(0.0); + }); + + // Literal in unknown range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(882.0) // all rows minus nulls divided by distinct values count + .symbolStats("unknownRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(49.0) + .lowValueUnknown() + .highValueUnknown() + .nullsFraction(0.0); + }); + + // Literal in empty range + assertCalculate(new ComparisonExpression(NOT_EQUAL, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(0.0) + .symbolStats("emptyRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void symbolToLiteralLessThanStats() + { + // Simple case + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("y"), new 
DoubleLiteral("2.5"))) + .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(10.0) + .lowValue(0.0) + .highValue(2.5) + .nullsFraction(0.0); + }); + + // Literal on the edge of symbol range (whole range included) + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new DoubleLiteral("10.0"))) + .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(40.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Literal on the edge of symbol range (whole range excluded) + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new DoubleLiteral("-10.0"))) + .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(-10.0) + .highValue(-10.0) + .nullsFraction(0.0); + }); + + // Literal range out of symbol range + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("y"), new DoubleLiteral("-10.0"))) + .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + + // Literal in left open range + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("leftOpen"), new DoubleLiteral("0.0"))) + .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) + .symbolStats("leftOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(25.0) //(50% heuristic) + .lowValueUnknown() + .highValue(0.0) + .nullsFraction(0.0); + }); + + // Literal in right open range + 
assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("rightOpen"), new DoubleLiteral("0.0"))) + .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) + .symbolStats("rightOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(12.5) //(25% heuristic) + .lowValue(-15.0) + .highValue(0.0) + .nullsFraction(0.0); + }); + + // Literal in unknown range + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) + .symbolStats("unknownRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(25.0) // (50% heuristic) + .lowValueUnknown() + .highValue(0.0) + .nullsFraction(0.0); + }); + + // Literal in empty range + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(0.0) + .symbolStats("emptyRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void symbolToLiteralGreaterThanStats() + { + // Simple case + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("y"), new DoubleLiteral("2.5"))) + .outputRowsCount(250.0) // all rows minus nulls times range coverage (50%) + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(10.0) + .lowValue(2.5) + .highValue(5.0) + .nullsFraction(0.0); + }); + + // Literal on the edge of symbol range (whole range included) + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new DoubleLiteral("-10.0"))) + .outputRowsCount(750.0) // all rows minus nulls times range coverage (100%) + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(40.0) + .lowValue(-10.0) + .highValue(10.0) + 
.nullsFraction(0.0); + }); + + // Literal on the edge of symbol range (whole range excluded) + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new DoubleLiteral("10.0"))) + .outputRowsCount(18.75) // all rows minus nulls divided by NDV (one value from edge is included as approximation) + .symbolStats("x", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(1.0) + .lowValue(10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Literal range out of symbol range + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("y"), new DoubleLiteral("10.0"))) + .outputRowsCount(0.0) // all rows minus nulls times range coverage (0%) + .symbolStats("y", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + + // Literal in left open range + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("leftOpen"), new DoubleLiteral("0.0"))) + .outputRowsCount(225.0) // all rows minus nulls times range coverage (25% - heuristic) + .symbolStats("leftOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(12.5) //(25% heuristic) + .lowValue(0.0) + .highValue(15.0) + .nullsFraction(0.0); + }); + + // Literal in right open range + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("rightOpen"), new DoubleLiteral("0.0"))) + .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) + .symbolStats("rightOpen", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(25.0) //(50% heuristic) + .lowValue(0.0) + .highValueUnknown() + .nullsFraction(0.0); + }); + + // Literal in unknown range + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("unknownRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(450.0) // all rows minus nulls times range coverage (50% - heuristic) + 
.symbolStats("unknownRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(25.0) // (50% heuristic) + .lowValue(0.0) + .highValueUnknown() + .nullsFraction(0.0); + }); + + // Literal in empty range + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("emptyRange"), new DoubleLiteral("0.0"))) + .outputRowsCount(0.0) + .symbolStats("emptyRange", symbolAssert -> { + symbolAssert.averageRowSize(4.0) + .distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java new file mode 100644 index 0000000000000..736059fcd3526 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -0,0 +1,447 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST; +import static com.facebook.presto.cost.PlanNodeCostEstimate.cpuCost; +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static 
com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Double.isNaN; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestCostCalculator +{ + private static final int NUMBER_OF_NODES = 10; + private final CostCalculator costCalculatorUsingExchanges = new CostCalculatorUsingExchanges(NUMBER_OF_NODES); + private final CostCalculator costCalculatorWithEstimatedExchanges = new CostCalculatorWithEstimatedExchanges(costCalculatorUsingExchanges, NUMBER_OF_NODES); + private Session session = testSessionBuilder().build(); + + @Test + public void testTableScan() + { + TableScanNode tableScan = tableScan("ts", "orderkey"); + + assertCost( + tableScan, + ImmutableMap.of(), + ImmutableMap.of("ts", statsEstimate(1000))) + .cpu(1000) + .memory(0) + .network(0); + assertCostEstimatedExchanges( + tableScan, + ImmutableMap.of(), + ImmutableMap.of("ts", statsEstimate(1000))) + .cpu(1000) + .memory(0) + .network(0); + + assertUnknownCostForUnknownStats(tableScan); + } + + @Test + public void testProject() + { + PlanNode project = project("project", tableScan("ts", "orderkey"), "string", new Cast(new SymbolReference("orderkey"), "STRING")); + Map costs = ImmutableMap.of("ts", cpuCost(1000)); + Map stats = ImmutableMap.of("project", statsEstimate(4000), "ts", statsEstimate(1000)); + + assertCost( + project, + costs, + stats) + .cpu(1000 + 4000) + .memory(0) + .network(0); + + assertCostEstimatedExchanges( + project, + costs, + stats) + .cpu(1000 + 4000) + .memory(0) + .network(0); + + assertUnknownCostForUnknownStats(project); + } + + @Test + public void testRepartitionedJoin() + { + JoinNode join = 
join("join", + tableScan("ts1", "orderkey"), + tableScan("ts2", "orderkey_0"), + JoinNode.DistributionType.PARTITIONED, + "orderkey", + "orderkey_0"); + + Map costs = ImmutableMap.of( + "ts1", cpuCost(6000), + "ts2", cpuCost(1000)); + + Map stats = ImmutableMap.of( + "join", statsEstimate(12000), + "ts1", statsEstimate(6000), + "ts2", statsEstimate(1000)); + + assertCost( + join, + costs, + stats) + .cpu(12000 + 6000 + 1000 + 6000 + 1000); + + assertCostEstimatedExchanges( + join, + costs, + stats) + .cpu(12000 + 6000 + 1000 + 6000 + 1000 + 6000 + 1000); + + assertUnknownCostForUnknownStats(join); + } + + @Test + public void testReplicatedJoin() + { + JoinNode join = join("join", + tableScan("ts1", "orderkey"), + tableScan("ts2", "orderkey_0"), + JoinNode.DistributionType.REPLICATED, + "orderkey", + "orderkey_0"); + + Map costs = ImmutableMap.of( + "ts1", cpuCost(6000), + "ts2", cpuCost(1000)); + + Map stats = ImmutableMap.of( + "join", statsEstimate(12000), + "ts1", statsEstimate(6000), + "ts2", statsEstimate(1000)); + + assertCost( + join, + costs, + stats) + .cpu(12000 + 6000 + 10000 + 6000 + 1000); + assertCostEstimatedExchanges( + join, + costs, + stats) + .cpu(12000 + 6000 + 10000 + 6000 + 1000); + + assertUnknownCostForUnknownStats(join); + } + + @Test + public void testAggregation() + { + AggregationNode aggregationNode = aggregation("agg", + tableScan("ts", "orderkey")); + + Map costs = ImmutableMap.of("ts", cpuCost(6000)); + Map stats = ImmutableMap.of( + "ts", statsEstimate(6000), + "agg", statsEstimate(8)); + + assertCost(aggregationNode, costs, stats) + .cpu(6000 + 6000); + assertCostEstimatedExchanges(aggregationNode, costs, stats) + .cpu(6000 + 6000 + 6000); + + assertUnknownCostForUnknownStats(aggregationNode); + } + + private CostAssertionBuilder assertCost( + PlanNode node, + Map costs, + Map stats) + { + return new CostAssertionBuilder(costCalculatorUsingExchanges.calculateCumulativeCost( + node, + new FixedLookup(costs, stats), + session, + 
ImmutableMap.of())); + } + + private CostAssertionBuilder assertCostEstimatedExchanges( + PlanNode node, + Map costs, + Map stats) + { + return new CostAssertionBuilder(costCalculatorWithEstimatedExchanges.calculateCumulativeCost( + node, + new FixedLookup(costs, stats), + session, + ImmutableMap.of())); + } + + private void assertUnknownCostForUnknownStats(PlanNode planNode) + { + new CostAssertionBuilder(costCalculatorUsingExchanges.calculateCumulativeCost( + planNode, + new FixedLookup(id -> UNKNOWN_COST, id -> UNKNOWN_STATS), + session, + ImmutableMap.of())) + .hasUnknownComponents(); + new CostAssertionBuilder(costCalculatorWithEstimatedExchanges.calculateCumulativeCost( + planNode, + new FixedLookup(id -> UNKNOWN_COST, id -> UNKNOWN_STATS), + session, + ImmutableMap.of())) + .hasUnknownComponents(); + } + + private static class CostAssertionBuilder + { + private final PlanNodeCostEstimate actual; + + public CostAssertionBuilder(PlanNodeCostEstimate actual) + { + this.actual = requireNonNull(actual, "actual is null"); + } + + public CostAssertionBuilder network(double value) + { + assertEquals(actual.getNetworkCost(), value, 0.1); + return this; + } + + public CostAssertionBuilder networkUnknown() + { + assertIsNaN(actual.getNetworkCost()); + return this; + } + + public CostAssertionBuilder cpu(double value) + { + assertEquals(actual.getCpuCost(), value, 0.1); + return this; + } + + public CostAssertionBuilder cpuUnknown() + { + assertIsNaN(actual.getCpuCost()); + return this; + } + + public CostAssertionBuilder memory(double value) + { + assertEquals(actual.getMemoryCost(), value, 0.1); + return this; + } + + public CostAssertionBuilder memoryUnknown() + { + assertIsNaN(actual.getMemoryCost()); + return this; + } + + public CostAssertionBuilder hasUnknownComponents() + { + assertTrue(actual.hasUnknownComponents()); + return this; + } + + private void assertIsNaN(double value) + { + assertTrue(isNaN(value), "Expected NaN got " + value); + } + } + + private 
static PlanNodeStatsEstimate statsEstimate(int outputSizeInBytes) + { + double rowCount = Math.max(outputSizeInBytes / 8, 1); + + return PlanNodeStatsEstimate.builder() + .setOutputRowCount(rowCount) + .addSymbolStatistics( + new Symbol("s"), + SymbolStatsEstimate.builder() + .setAverageRowSize(outputSizeInBytes / rowCount) + .build()) + .build(); + } + + private TableScanNode tableScan(String id, String... symbols) + { + List symbolsList = Arrays.stream(symbols).map(Symbol::new).collect(toImmutableList()); + ImmutableMap.Builder assignments = ImmutableMap.builder(); + + for (Symbol symbol : symbolsList) { + assignments.put(symbol, new TpchColumnHandle("orderkey", BIGINT)); + } + + return new TableScanNode( + new PlanNodeId(id), + new TableHandle(new ConnectorId("tpch"), new TpchTableHandle("local", "orders", 1.0)), + symbolsList, + assignments.build(), + Optional.empty(), + TupleDomain.none(), + null); + } + + private PlanNode project(String id, PlanNode source, String symbol, Expression expression) + { + return new ProjectNode( + new PlanNodeId(id), + source, + Assignments.of(new Symbol(symbol), expression)); + } + + private String symbol(String name) + { + return name; + } + + private AggregationNode aggregation(String id, PlanNode source) + { + AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation( + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), + new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)), + Optional.empty()); + + return new AggregationNode( + new PlanNodeId(id), + source, + ImmutableMap.of(new Symbol("count"), aggregation), + ImmutableList.of(source.getOutputSymbols()), + AggregationNode.Step.FINAL, + Optional.empty(), + Optional.empty()); + } + + private JoinNode join(String planNodeId, PlanNode left, PlanNode right, String... 
symbols) + { + return join(planNodeId, left, right, JoinNode.DistributionType.PARTITIONED, symbols); + } + + /** + * EquiJoinClause is created from symbols in form of: + * symbol[0] = symbol[1] AND symbol[2] = symbol[3] AND ... + */ + private JoinNode join(String planNodeId, PlanNode left, PlanNode right, JoinNode.DistributionType distributionType, String... symbols) + { + checkArgument(symbols.length % 2 == 0); + ImmutableList.Builder criteria = ImmutableList.builder(); + + for (int i = 0; i < symbols.length; i += 2) { + criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbols[i]), new Symbol(symbols[i + 1]))); + } + + return new JoinNode( + new PlanNodeId(planNodeId), + JoinNode.Type.INNER, + left, + right, + criteria.build(), + ImmutableList.builder() + .addAll(left.getOutputSymbols()) + .addAll(right.getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(distributionType)); + } + + private ValuesNode values(int planNodeId, String... 
symbols) + { + return new ValuesNode( + new PlanNodeId(Integer.toString(planNodeId)), + Arrays.stream(symbols) + .map(Symbol::new) + .collect(toImmutableList()), + ImmutableList.of()); + } + + private class FixedLookup + implements Lookup + { + private Function costs; + private Function stats; + + public FixedLookup(Function costs, Function stats) + { + this.costs = costs; + this.stats = stats; + } + + public FixedLookup(Map costs, Map stats) + { + this(costs::get, stats::get); + } + + @Override + public PlanNode resolve(PlanNode node) + { + throw new UnsupportedOperationException(); + } + + @Override + public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) + { + return stats.apply(node.getId().toString()); + } + + @Override + public PlanNodeCostEstimate getCumulativeCost(PlanNode node, Session session, Map types) + { + return costs.apply(node.getId().toString()); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostComparator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostComparator.java new file mode 100644 index 0000000000000..bb09b1a35b79e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostComparator.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import org.testng.annotations.Test; + +import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST; +import static com.facebook.presto.cost.PlanNodeCostEstimate.ZERO_COST; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +public class TestCostComparator +{ + @Test + public void testCpuWeight() + { + new CostComparisonAssertion(1.0, 0.0, 0.0) + .smaller(200, 200, 200) + .larger(1000, 100, 100) + .assertCompare(); + } + + @Test + public void testMemoryWeight() + { + new CostComparisonAssertion(0.0, 1.0, 0.0) + .smaller(200, 200, 200) + .larger(100, 1000, 100) + .assertCompare(); + } + + @Test + public void testNetworkWeight() + { + new CostComparisonAssertion(0.0, 0.0, 1.0) + .smaller(200, 200, 200) + .larger(100, 100, 1000) + .assertCompare(); + } + + @Test + public void testAllWeights() + { + new CostComparisonAssertion(1.0, 1.0, 1.0) + .smaller(333, 333, 333) + .larger(200, 300, 500) + .assertCompare(); + + new CostComparisonAssertion(1.0, 1000.0, 1.0) + .smaller(300, 299, 300) + .larger(100, 300, 100) + .assertCompare(); + } + + @Test + public void testUnknownCost() + { + CostComparator costComparator = new CostComparator(1.0, 1.0, 1.0); + Session session = testSessionBuilder().build(); + assertThrows(IllegalArgumentException.class, () -> costComparator.compare(session, ZERO_COST, UNKNOWN_COST)); + assertThrows(IllegalArgumentException.class, () -> costComparator.compare(session, UNKNOWN_COST, ZERO_COST)); + assertThrows(IllegalArgumentException.class, () -> costComparator.compare(session, UNKNOWN_COST, UNKNOWN_COST)); + } + + private static class CostComparisonAssertion + { + private final PlanNodeCostEstimate.Builder smaller = PlanNodeCostEstimate.builder(); + private final PlanNodeCostEstimate.Builder larger = PlanNodeCostEstimate.builder(); + 
private final CostComparator costComparator; + private Session session = testSessionBuilder().build(); + + public CostComparisonAssertion(double cpuWeight, double memoryWeight, double networkWeight) + { + costComparator = new CostComparator(cpuWeight, memoryWeight, networkWeight); + } + + public void assertCompare() + { + assertTrue(costComparator.compare(session, smaller.build(), larger.build()) < 0, + "smaller < larger is false"); + + assertTrue(costComparator.compare(session, larger.build(), smaller.build()) > 0, + "larger > smaller is false"); + } + + public CostComparisonAssertion smaller(double cpu, double memory, double network) + { + smaller.setCpuCost(cpu).setMemoryCost(memory).setNetworkCost(network); + return this; + } + + public CostComparisonAssertion larger(double cpu, double memory, double network) + { + larger.setCpuCost(cpu).setMemoryCost(memory).setNetworkCost(network); + return this; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestEnsureStatsMatchOutput.java b/presto-main/src/test/java/com/facebook/presto/cost/TestEnsureStatsMatchOutput.java new file mode 100644 index 0000000000000..8084aa9c326df --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestEnsureStatsMatchOutput.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.cost.SymbolStatsAssertion.assertThat; +import static java.util.Collections.emptyList; +import static org.testng.Assert.assertEquals; + +public class TestEnsureStatsMatchOutput +{ + @Test + public void test() + { + Symbol a = new Symbol("a"); + Symbol b = new Symbol("b"); + Symbol c = new Symbol("c"); + + PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(a, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) + .addSymbolStatistics(b, SymbolStatsEstimate.builder().setDistinctValuesCount(20).build()) + .build(); + + ComposableStatsCalculator.Normalizer normalizer = new EnsureStatsMatchOutput(); + PlanNode node = new ValuesNode(new PlanNodeId(""), ImmutableList.of(a, c), emptyList()); + PlanNodeStatsEstimate normalized = normalizer.normalize(node, estimate, null); + + assertEquals(normalized.getSymbolsWithKnownStatistics(), ImmutableList.of(a, c)); + assertThat(normalized.getSymbolStatistics(a)).distinctValuesCount(20); + assertThat(normalized.getSymbolStatistics(c)).distinctValuesCountUnknown(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java new file mode 100644 index 0000000000000..2d7a052bc8545 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestExchangeStatsRule.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static java.util.Collections.emptyList; + +public class TestExchangeStatsRule +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testExchange() + { + // test cases origins from TestUnionStatsRule + // i11, i21 have separated low/high ranges and known all stats, unknown distinct values count + // i12, i22 have overlapping low/high ranges and known all stats, unknown nulls fraction + // i13, i23 have some unknown range stats + // i14, i24 have the same stats + + tester.assertStatsFor(pb -> pb + .exchange(exchangeBuilder -> exchangeBuilder + .addInputsSet(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT)) + .addInputsSet(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT)) + .fixedHashDistributionParitioningScheme( + ImmutableList.of(pb.symbol("o1", BIGINT), pb.symbol("o2", BIGINT), pb.symbol("o3", BIGINT), pb.symbol("o4", BIGINT)), + emptyList()) + .addSource(pb.values(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), 
pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT))) + .addSource(pb.values(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT))))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .addSymbolStatistics(new Symbol("i12"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setDistinctValuesCount(4) + .setNullsFraction(0) + .build()) + .addSymbolStatistics(new Symbol("i13"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addSymbolStatistics(new Symbol("i14"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(11) + .setHighValue(20) + .setNullsFraction(0.4) + .build()) + .addSymbolStatistics(new Symbol("i22"), SymbolStatsEstimate.builder() + .setLowValue(2) + .setHighValue(7) + .setDistinctValuesCount(3) + .build()) + .addSymbolStatistics(new Symbol("i23"), SymbolStatsEstimate.builder() + .setDistinctValuesCount(6) + .setNullsFraction(0.2) + .build()) + .addSymbolStatistics(new Symbol("i24"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .build()) + .check(check -> check + .outputRowsCount(30) + .symbolStats("o1", assertion -> assertion + .lowValue(1) + .highValue(20) + .distinctValuesCountUnknown() + .nullsFraction(0.3666666)) + .symbolStats("o2", assertion -> assertion + .lowValue(0) + .highValue(7) + .distinctValuesCount(7) + .nullsFractionUnknown()) + 
.symbolStats("o3", assertion -> assertion + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(10.0) + .nullsFraction(0.1666667)) + .symbolStats("o4", assertion -> assertion + .lowValue(10) + .highValue(15) + .distinctValuesCount(6) + .nullsFraction(0.1))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java new file mode 100644 index 0000000000000..394c368a0fda9 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java @@ -0,0 +1,506 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.BetweenPredicate; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.ComparisonExpressionType; +import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InListExpression; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; +import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.ExpressionUtils.or; +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; + +@Test(singleThreaded = true) +public class TestFilterStatsCalculator +{ + private FilterStatsCalculator statsCalculator; + private PlanNodeStatsEstimate standardInputStatistics; + private Map standardTypes; + private Session session; + + @BeforeMethod + public void setUp() + throws Exception + { + SymbolStatsEstimate xStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(40.0) + 
.setLowValue(-10.0) + .setHighValue(10.0) + .setNullsFraction(0.25) + .build(); + SymbolStatsEstimate yStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(20.0) + .setLowValue(0.0) + .setHighValue(5.0) + .setNullsFraction(0.5) + .build(); + SymbolStatsEstimate zStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(5.0) + .setLowValue(-100.0) + .setHighValue(100.0) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate leftOpenStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(15.0) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate rightOpenStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(-15.0) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate unknownRangeStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(50.0) + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.1) + .build(); + SymbolStatsEstimate emptyRangeStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(0.0) + .setLowValue(NaN) + .setHighValue(NaN) + .setNullsFraction(NaN) + .build(); + standardInputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), xStats) + .addSymbolStatistics(new Symbol("y"), yStats) + .addSymbolStatistics(new Symbol("z"), zStats) + .addSymbolStatistics(new Symbol("leftOpen"), leftOpenStats) + .addSymbolStatistics(new Symbol("rightOpen"), rightOpenStats) + .addSymbolStatistics(new Symbol("unknownRange"), unknownRangeStats) + .addSymbolStatistics(new Symbol("emptyRange"), emptyRangeStats) + .setOutputRowCount(1000.0) + .build(); + + standardTypes = ImmutableMap.builder() + .put(new Symbol("x"), DoubleType.DOUBLE) + .put(new Symbol("y"), DoubleType.DOUBLE) + 
.put(new Symbol("z"), DoubleType.DOUBLE) + .put(new Symbol("leftOpen"), DoubleType.DOUBLE) + .put(new Symbol("rightOpen"), DoubleType.DOUBLE) + .put(new Symbol("unknownRange"), DoubleType.DOUBLE) + .put(new Symbol("emptyRange"), DoubleType.DOUBLE).build(); + + session = testSessionBuilder().build(); + statsCalculator = new FilterStatsCalculator(MetadataManager.createTestMetadataManager()); + } + + public PlanNodeStatsAssertion assertExpression(Expression expression) + { + return PlanNodeStatsAssertion.assertThat(statsCalculator.filterStats(standardInputStatistics, + expression, + session, + standardTypes)); + } + + @Test + public void testBooleanLiteralStas() + { + assertExpression(TRUE_LITERAL).equalTo(standardInputStatistics); + + assertExpression(FALSE_LITERAL).outputRowsCount(0.0) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("y", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("z", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("leftOpen", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("rightOpen", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("emptyRange", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }) + .symbolStats("unknownRange", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void testOrStats() + { + Expression leftExpression = new ComparisonExpression(ComparisonExpressionType.LESS_THAN, new SymbolReference("x"), new DoubleLiteral("0.0")); + Expression rightExpression = new ComparisonExpression(ComparisonExpressionType.LESS_THAN, new SymbolReference("x"), new 
DoubleLiteral("-7.5")); + + assertExpression(or(leftExpression, rightExpression)) + .outputRowsCount(375) + .symbolStats(new Symbol("x"), symbolAssert -> + symbolAssert.averageRowSize(4.0) + .lowValue(-10.0) + .highValue(0.0) + .distinctValuesCount(20.0) + .nullsFraction(0.0) + ); + + Expression leftExpressionSingleValue = new ComparisonExpression(ComparisonExpressionType.EQUAL, new SymbolReference("x"), new DoubleLiteral("0.0")); + Expression rightExpressionSingleValue = new ComparisonExpression(ComparisonExpressionType.EQUAL, new SymbolReference("x"), new DoubleLiteral("-7.5")); + + assertExpression(or(leftExpressionSingleValue, rightExpressionSingleValue)) + .outputRowsCount(37.5) + .symbolStats(new Symbol("x"), symbolAssert -> + symbolAssert.averageRowSize(4.0) + .lowValue(-7.5) + .highValue(0.0) + .distinctValuesCount(2.0) + .nullsFraction(0.0) + ); + } + + @Test + public void testAndStats() + { + Expression leftExpression = new ComparisonExpression(ComparisonExpressionType.LESS_THAN, new SymbolReference("x"), new DoubleLiteral("0.0")); + Expression rightExpression = new ComparisonExpression(ComparisonExpressionType.GREATER_THAN, new SymbolReference("x"), new DoubleLiteral("-7.5")); + + assertExpression(and(leftExpression, rightExpression)) + .outputRowsCount(281.25) + .symbolStats(new Symbol("x"), symbolAssert -> + symbolAssert.averageRowSize(4.0) + .lowValue(-7.5) + .highValue(0.0) + .distinctValuesCount(15.0) + .nullsFraction(0.0) + ); + } + + @Test + public void testNotStats() + { + Expression innerExpression = new ComparisonExpression(ComparisonExpressionType.LESS_THAN, new SymbolReference("x"), new DoubleLiteral("0.0")); + + assertExpression(new NotExpression(innerExpression)) + .outputRowsCount(625) // FIXME - nulls shouldn't be restored + .symbolStats(new Symbol("x"), symbolAssert -> + symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(10.0) + .distinctValuesCount(20.0) + .nullsFraction(0.4) // FIXME - nulls shouldn't be restored + ); + } + 
+ @Test + public void testIsNullFilter() + { + Expression isNullPredicate = new IsNullPredicate(new SymbolReference("x")); + assertExpression(isNullPredicate) + .outputRowsCount(250.0) + .symbolStats(new Symbol("x"), symbolStats -> { + symbolStats.distinctValuesCount(0) + .emptyRange() + .nullsFraction(1.0); + }); + + Expression isNullEmptyRangePredicate = new IsNullPredicate(new SymbolReference("emptyRange")); + assertExpression(isNullEmptyRangePredicate) + .outputRowsCount(1000.0) + .symbolStats(new Symbol("emptyRange"), symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void testIsNotNullFilter() + { + Expression isNotNullPredicate = new IsNotNullPredicate(new SymbolReference("x")); + assertExpression(isNotNullPredicate) + .outputRowsCount(750.0) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(40.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + Expression isNotNullEmptyRangePredicate = new IsNotNullPredicate(new SymbolReference("emptyRange")); + assertExpression(isNotNullEmptyRangePredicate) + .outputRowsCount(0.0) + .symbolStats("emptyRange", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + } + + @Test + public void testBetweenOperatorFilter() + { + // Only right side cut + Expression betweenPredicateRightCut = new BetweenPredicate(new SymbolReference("x"), new DoubleLiteral("7.5"), new DoubleLiteral("12.0")); + assertExpression(betweenPredicateRightCut) + .outputRowsCount(93.75) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(5.0) + .lowValue(7.5) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Only left side cut + Expression betweenPredicateLeftCut = new BetweenPredicate(new SymbolReference("x"), new DoubleLiteral("-12.0"), new DoubleLiteral("-7.5")); + assertExpression(betweenPredicateLeftCut) + .outputRowsCount(93.75) + .symbolStats("x", symbolStats -> { + 
symbolStats.distinctValuesCount(5.0) + .lowValue(-10) + .highValue(-7.5) + .nullsFraction(0.0); + }); + + // Both sides cut + Expression betweenPredicateBothSidesCut = new BetweenPredicate(new SymbolReference("x"), new DoubleLiteral("-2.5"), new DoubleLiteral("2.5")); + assertExpression(betweenPredicateBothSidesCut) + .outputRowsCount(187.5) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(10.0) + .lowValue(-2.5) + .highValue(2.5) + .nullsFraction(0.0); + }); + + // Both sides cut unknownRange + Expression betweenPredicateBothSidesCutUnknownRange = new BetweenPredicate(new SymbolReference("unknownRange"), new DoubleLiteral("2.72"), new DoubleLiteral("3.14")); + assertExpression(betweenPredicateBothSidesCutUnknownRange) + .outputRowsCount(112.5) + .symbolStats("unknownRange", symbolStats -> { + symbolStats.distinctValuesCount(6.25) + .lowValue(2.72) + .highValue(3.14) + .nullsFraction(0.0); + }); + + // Left side open, cut on open side + Expression betweenPredicateCutOnLeftOpenSide = new BetweenPredicate(new SymbolReference("leftOpen"), new DoubleLiteral("-10.0"), new DoubleLiteral("10.0")); + assertExpression(betweenPredicateCutOnLeftOpenSide) + .outputRowsCount(180.0) + .symbolStats("leftOpen", symbolStats -> { + symbolStats.distinctValuesCount(10.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Right side open, cut on open side + Expression betweenPredicateCutOnRightOpenSide = new BetweenPredicate(new SymbolReference("rightOpen"), new DoubleLiteral("-10.0"), new DoubleLiteral("10.0")); + assertExpression(betweenPredicateCutOnRightOpenSide) + .outputRowsCount(180.0) + .symbolStats("rightOpen", symbolStats -> { + symbolStats.distinctValuesCount(10.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0); + }); + + // Filter all + Expression betweenPredicateFilterAll = new BetweenPredicate(new SymbolReference("y"), new DoubleLiteral("27.5"), new DoubleLiteral("107.0")); + 
assertExpression(betweenPredicateFilterAll) + .outputRowsCount(0.0) + .symbolStats("y", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + + // Filter nothing + Expression betweenPredicateFilterNothing = new BetweenPredicate(new SymbolReference("y"), new DoubleLiteral("-100.0"), new DoubleLiteral("100.0")); + assertExpression(betweenPredicateFilterNothing) + .outputRowsCount(500.0) + .symbolStats("y", symbolStats -> { + symbolStats.distinctValuesCount(20.0) + .lowValue(0.0) + .highValue(5.0) + .nullsFraction(0.0); + }); + + // Filter non exact match + Expression betweenPredicateFilterNothingExact = new BetweenPredicate(new SymbolReference("z"), new DoubleLiteral("-100.0"), new DoubleLiteral("100.0")); + assertExpression(betweenPredicateFilterNothingExact) + .outputRowsCount(900.0) + .symbolStats("z", symbolStats -> { + symbolStats.distinctValuesCount(5.0) + .lowValue(-100.0) + .highValue(100.0) + .nullsFraction(0.0); + }); + } + + @Test + public void testInPredicateFilter() + { + // One value in range + Expression singleValueInIn = new InPredicate(new SymbolReference("x"), new InListExpression(ImmutableList.of(new DoubleLiteral("7.5")))); + assertExpression(singleValueInIn) + .outputRowsCount(18.75) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(1.0) + .lowValue(7.5) + .highValue(7.5) + .nullsFraction(0.0); + }); + + // Multiple values in range + Expression multipleValuesInIn = new InPredicate(new SymbolReference("x"), new InListExpression( + ImmutableList.of(new DoubleLiteral("1.5"), + new DoubleLiteral("2.5"), + new DoubleLiteral("7.5")))); + assertExpression(multipleValuesInIn) + .outputRowsCount(56.25) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(3.0) + .lowValue(1.5) + .highValue(7.5) + .nullsFraction(0.0); + }); + + // Multiple values some in some out of range + Expression multipleValuesInInSomeOutOfRange = new InPredicate(new SymbolReference("x"), new 
InListExpression( + ImmutableList.of(new DoubleLiteral("-42.0"), + new DoubleLiteral("1.5"), + new DoubleLiteral("2.5"), + new DoubleLiteral("7.5"), + new DoubleLiteral("314.0")))); + assertExpression(multipleValuesInInSomeOutOfRange) + .outputRowsCount(56.25) + .symbolStats("x", symbolStats -> { + symbolStats.distinctValuesCount(3.0) + .lowValue(1.5) + .highValue(7.5) + .nullsFraction(0.0); + }); + + // Multiple values in unknown range + Expression multipleValuesInUnknownRange = new InPredicate(new SymbolReference("unknownRange"), new InListExpression( + ImmutableList.of(new DoubleLiteral("-42.0"), + new DoubleLiteral("1.5"), + new DoubleLiteral("2.5"), + new DoubleLiteral("7.5"), + new DoubleLiteral("314.0")))); + assertExpression(multipleValuesInUnknownRange) + .outputRowsCount(90.0) + .symbolStats("unknownRange", symbolStats -> { + symbolStats.distinctValuesCount(5.0) + .lowValue(-42.0) + .highValue(314.0) + .nullsFraction(0.0); + }); + + // No value in range + Expression noValuesInRange = new InPredicate(new SymbolReference("y"), new InListExpression( + ImmutableList.of(new DoubleLiteral("-42.0"), + new DoubleLiteral("6.0"), + new DoubleLiteral("31.1341"), + new DoubleLiteral("-0.000000002"), + new DoubleLiteral("314.0")))); + assertExpression(noValuesInRange) + .outputRowsCount(0.0) + .symbolStats("y", symbolStats -> { + symbolStats.distinctValuesCount(0.0) + .emptyRange() + .nullsFraction(1.0); + }); + + // More values in range than distinct values + Expression ndvOverflowInIn = new InPredicate(new SymbolReference("z"), new InListExpression( + ImmutableList.of(new DoubleLiteral("-1.0"), + new DoubleLiteral("3.14"), + new DoubleLiteral("0.0"), + new DoubleLiteral("1.0"), + new DoubleLiteral("2.0"), + new DoubleLiteral("3.0"), + new DoubleLiteral("4.0"), + new DoubleLiteral("5.0"), + new DoubleLiteral("6.0"), + new DoubleLiteral("7.0"), + new DoubleLiteral("8.0"), + new DoubleLiteral("-2.0")))); + assertExpression(ndvOverflowInIn) + .outputRowsCount(900.0) + 
.symbolStats("z", symbolStats -> { + symbolStats.distinctValuesCount(5.0) + .lowValue(-2.0) + .highValue(8.0) + .nullsFraction(0.0); + }); + + // Values in weird order + Expression ndvOverflowInNotSortedValues = new InPredicate(new SymbolReference("z"), new InListExpression( + ImmutableList.of(new DoubleLiteral("-1.0"), + new DoubleLiteral("1.0"), + new DoubleLiteral("0.0")))); + assertExpression(ndvOverflowInNotSortedValues) + .outputRowsCount(540.0) + .symbolStats("z", symbolStats -> { + symbolStats.distinctValuesCount(3.0) + .lowValue(-1.0) + .highValue(1.0) + .nullsFraction(0.0); + }); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestIntersectStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestIntersectStatsRule.java new file mode 100644 index 0000000000000..c44f7d126bbc3 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestIntersectStatsRule.java @@ -0,0 +1,271 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; + +public class TestIntersectStatsRule +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testIntersectWhenRangesAreOverlapping() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(10) + .setDistinctValuesCount(8) + .setNullsFraction(0.2) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(5) + .setHighValue(15) + .setNullsFraction(0.4) + .setDistinctValuesCount(4) + .build()) + .build()) + .check(check -> check + .outputRowsCount(5) + .symbolStats("o1", assertion -> assertion + .lowValue(5) + .highValue(10) + .dataSizeUnknown() + .distinctValuesCount(4) + .nullsFraction(0.2))); + } + + @Test + public void testIntersectWhenRangesAreSeparated() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + 
pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(11) + .setHighValue(20) + .setNullsFraction(0.4) + .build()) + .build()) + .check(check -> check + .outputRowsCount(1) + .symbolStats("o1", assertion -> assertion + .emptyRange() + .dataSizeUnknown() + .distinctValuesCount(0) + .nullsFraction(1))); + } + + @Test + public void testIntersectWhenLeftHasUnknownRange() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(11) + .setHighValue(20) + .setNullsFraction(0.4) + .build()) + .build()) + .check(check -> check + .outputRowsCount(2.25) + .symbolStats("o1", assertion -> assertion + .lowValue(11) + .highValue(20) + .dataSizeUnknown() + .distinctValuesCount(1.25) + .nullsFraction(0.44444444))); + } + + @Test + public void 
testIntersectWhenRightIsUnknown() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.UNKNOWN_STATS) + .build()) + .check(check -> check + .outputRowsCount(6) + .symbolStats("o1", assertion -> assertion + .lowValue(1) + .highValue(10) + .distinctValuesCount(5) + .nullsFraction(0.166666667))); + } + + @Test + public void testIntersectWhenNullFractionsAreUnknown() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(10) + .setDistinctValuesCount(8) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(5) + .setHighValue(15) + .setDistinctValuesCount(4) + .build()) + .build()) + .check(check -> check + .outputRowsCount(5) + .symbolStats("o1", assertion -> assertion + .lowValue(5) + .highValue(10) + 
.dataSizeUnknown() + .distinctValuesCount(4) + .nullsFraction(0.2))); + } + + @Test + public void testIntersectWithoutNulls() + { + tester.assertStatsFor(pb -> pb + .intersect( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT)), + pb.values(pb.symbol("i21", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .build(), + ImmutableList.of(pb.symbol("o1", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(10) + .setDistinctValuesCount(8) + .setNullsFraction(0) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(5) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0) + .build()) + .build()) + .check(check -> check + .outputRowsCount(4) + .symbolStats("o1", assertion -> assertion + .lowValue(5) + .highValue(10) + .dataSizeUnknown() + // TODO DVC should be 2 as right side has two times lower values density + .distinctValuesCount(4) + .nullsFraction(0))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java new file mode 100644 index 0000000000000..c5d2c88c979c7 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java @@ -0,0 +1,220 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.cost.PlanNodeStatsAssertion.assertThat; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; +import static java.lang.Double.NaN; + +public class TestJoinStatsRule +{ + private static final String LEFT_JOIN_COLUMN = "left_join_column"; + private static final String RIGHT_JOIN_COLUMN = "right_join_column"; + private static final String LEFT_OTHER_COLUMN = "left_column"; + private static final String RIGHT_OTHER_COLUMN = "right_column"; + + private static final double LEFT_ROWS_COUNT = 500.0; + private static final double RIGHT_ROWS_COUNT = 1000.0; + private static final double TOTAL_ROWS_COUNT = LEFT_ROWS_COUNT + RIGHT_ROWS_COUNT; + private static final 
double LEFT_JOIN_COLUMN_NULLS = 0.3; + private static final double LEFT_JOIN_COLUMN_NON_NULLS = 0.7; + private static final int LEFT_JOIN_COLUMN_NDV = 20; + private static final double RIGHT_JOIN_COLUMN_NULLS = 0.6; + private static final double RIGHT_JOIN_COLUMN_NON_NULLS = 0.4; + private static final int RIGHT_JOIN_COLUMN_NDV = 15; + + private static final SymbolStatistics LEFT_OTHER_COLUMN_STATS = + symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, 0.42, 1); + private static final SymbolStatistics RIGHT_OTHER_COLUMN_STATS = + symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, 0.24, 1); + private static final PlanNodeStatsEstimate LEFT_STATS = planNodeStats(LEFT_ROWS_COUNT, + symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS, LEFT_JOIN_COLUMN_NDV), + LEFT_OTHER_COLUMN_STATS); + private static final PlanNodeStatsEstimate RIGHT_STATS = planNodeStats(RIGHT_ROWS_COUNT, + symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, RIGHT_JOIN_COLUMN_NULLS, RIGHT_JOIN_COLUMN_NDV), + RIGHT_OTHER_COLUMN_STATS); + + private static final JoinStatsRule JOIN_STATS_RULE = new JoinStatsRule(new FilterStatsCalculator(createTestMetadataManager())); + + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + throws Exception + { + tester = new StatsCalculatorTester(); + } + + @Test + public void testStatsForInnerJoin() + throws Exception + { + double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS; + PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, + symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), + LEFT_OTHER_COLUMN_STATS, RIGHT_OTHER_COLUMN_STATS); + + assertJoinStats(INNER, LEFT_STATS, RIGHT_STATS, innerJoinStats); + } + + @Test + public void testStatsForLeftAntiJoin() + { + PlanNodeStatsEstimate antiJoinStats = planNodeStats(LEFT_ROWS_COUNT * 
(LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), + symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 5.0, LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4), 5), + LEFT_OTHER_COLUMN_STATS); + + assertThat(JOIN_STATS_RULE.calculateAntiJoinStats( + Optional.empty(), + ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol(LEFT_JOIN_COLUMN), new Symbol(RIGHT_JOIN_COLUMN))), + LEFT_STATS, RIGHT_STATS)).equalTo(antiJoinStats); + } + + @Test + public void testStatsForRightAntiJoin() + { + PlanNodeStatsEstimate antiJoinStats = planNodeStats(RIGHT_ROWS_COUNT * RIGHT_JOIN_COLUMN_NULLS, + symbolStatistics(RIGHT_JOIN_COLUMN, NaN, NaN, 1.0, 0), + RIGHT_OTHER_COLUMN_STATS); + + assertThat(JOIN_STATS_RULE.calculateAntiJoinStats( + Optional.empty(), + ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol(RIGHT_JOIN_COLUMN), new Symbol(LEFT_JOIN_COLUMN))), + RIGHT_STATS, LEFT_STATS)).equalTo(antiJoinStats); + } + + @Test + public void testStatsForLeftAndRightJoin() + { + double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS; + double antiJoinRowCount = LEFT_ROWS_COUNT * (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4); + double antiJoinColumnNulls = LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4); + double totalRowCount = innerJoinRowCount + antiJoinRowCount; + + PlanNodeStatsEstimate leftJoinStats = planNodeStats( + totalRowCount, + symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, antiJoinColumnNulls * antiJoinRowCount / totalRowCount, LEFT_JOIN_COLUMN_NDV), + LEFT_OTHER_COLUMN_STATS, + symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, antiJoinRowCount / totalRowCount, RIGHT_JOIN_COLUMN_NDV), + symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * innerJoinRowCount + antiJoinRowCount) / totalRowCount, 1)); + + assertJoinStats(LEFT, LEFT_STATS, RIGHT_STATS, leftJoinStats); + assertJoinStats(RIGHT, RIGHT_JOIN_COLUMN, 
RIGHT_OTHER_COLUMN, LEFT_JOIN_COLUMN, LEFT_OTHER_COLUMN, RIGHT_STATS, LEFT_STATS, leftJoinStats); + } + + @Test + public void testStatsForFullJoin() + { + double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS; + double leftAntiJoinRowCount = LEFT_ROWS_COUNT * (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4); + double leftAntiJoinColumnNulls = LEFT_JOIN_COLUMN_NULLS / (LEFT_JOIN_COLUMN_NULLS + LEFT_JOIN_COLUMN_NON_NULLS / 4); + double rightAntiJoinRowCount = RIGHT_ROWS_COUNT * RIGHT_JOIN_COLUMN_NULLS; + double rightAntiJoinColumnNulls = 1.0; + double totalRowCount = innerJoinRowCount + leftAntiJoinRowCount + rightAntiJoinRowCount; + + PlanNodeStatsEstimate leftJoinStats = planNodeStats( + totalRowCount, + symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (leftAntiJoinColumnNulls * leftAntiJoinRowCount + rightAntiJoinRowCount) / totalRowCount, LEFT_JOIN_COLUMN_NDV), + symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * (innerJoinRowCount + leftAntiJoinRowCount) + rightAntiJoinRowCount) / totalRowCount, 1), + symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, (rightAntiJoinColumnNulls * rightAntiJoinRowCount + leftAntiJoinRowCount) / totalRowCount, RIGHT_JOIN_COLUMN_NDV), + symbolStatistics(RIGHT_OTHER_COLUMN, 24, 24, (0.24 * (innerJoinRowCount + rightAntiJoinRowCount) + leftAntiJoinRowCount) / totalRowCount, 1)); + + assertJoinStats(FULL, LEFT_STATS, RIGHT_STATS, leftJoinStats); + } + + @Test + public void testAddAntiJoinStats() + { + PlanNodeStatsEstimate statsToAdd = planNodeStats(RIGHT_ROWS_COUNT, + symbolStatistics(LEFT_JOIN_COLUMN, -5.0, 5.0, 0.2, 5)); + + PlanNodeStatsEstimate addedStats = planNodeStats(TOTAL_ROWS_COUNT, + symbolStatistics(LEFT_JOIN_COLUMN, -5.0, 20.0, (LEFT_ROWS_COUNT * LEFT_JOIN_COLUMN_NULLS + RIGHT_ROWS_COUNT * 0.2) / TOTAL_ROWS_COUNT, 25), + symbolStatistics(LEFT_OTHER_COLUMN, 42, 42, (0.42 * LEFT_ROWS_COUNT + RIGHT_ROWS_COUNT) / TOTAL_ROWS_COUNT, 
1)); + + assertThat(JOIN_STATS_RULE.addAntiJoinStats(LEFT_STATS, statsToAdd, ImmutableSet.of(new Symbol(LEFT_JOIN_COLUMN)))).equalTo(addedStats); + } + + private void assertJoinStats(JoinNode.Type joinType, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate resultStats) + { + assertJoinStats(joinType, LEFT_JOIN_COLUMN, LEFT_OTHER_COLUMN, RIGHT_JOIN_COLUMN, RIGHT_OTHER_COLUMN, leftStats, rightStats, resultStats); + } + + private void assertJoinStats(JoinNode.Type joinType, String leftJoinColumn, String leftOtherColumn, String rightJoinColumn, String rightOtherColumn, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate resultStats) + { + tester.assertStatsFor(pb -> { + Symbol leftJoinColumnSymbol = pb.symbol(leftJoinColumn, BIGINT); + Symbol rightJoinColumnSymbol = pb.symbol(rightJoinColumn, DOUBLE); + Symbol leftOtherColumnSymbol = pb.symbol(leftOtherColumn, BIGINT); + Symbol rightOtherColumnSymbol = pb.symbol(rightOtherColumn, DOUBLE); + return pb + .join(joinType, pb.values(leftJoinColumnSymbol, leftOtherColumnSymbol), + pb.values(rightJoinColumnSymbol, rightOtherColumnSymbol), + new EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)); + }).withSourceStats(0, leftStats) + .withSourceStats(1, rightStats) + .check(stats -> stats.equalTo(resultStats)); + } + + private static PlanNodeStatsEstimate planNodeStats(double rowCount, SymbolStatistics... 
symbolStatistics) + { + PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder() + .setOutputRowCount(rowCount); + for (SymbolStatistics symbolStatistic : symbolStatistics) { + builder.addSymbolStatistics(symbolStatistic.symbol, symbolStatistic.estimate); + } + return builder.build(); + } + + private static SymbolStatistics symbolStatistics(String symbolName, double low, double high, double nullsFraction, double ndv) + { + return new SymbolStatistics( + new Symbol(symbolName), + SymbolStatsEstimate.builder() + .setLowValue(low) + .setHighValue(high) + .setNullsFraction(nullsFraction) + .setDistinctValuesCount(ndv) + .build()); + } + + private static class SymbolStatistics + { + final Symbol symbol; + final SymbolStatsEstimate estimate; + + SymbolStatistics(Symbol symbol, SymbolStatsEstimate estimate) + { + this.symbol = symbol; + this.estimate = estimate; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java new file mode 100644 index 0000000000000..04f2ac6b980aa --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestOutputNodeStats.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static java.lang.Double.POSITIVE_INFINITY; + +public class TestOutputNodeStats +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testStatsForOutputNode() + throws Exception + { + PlanNodeStatsEstimate stats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics( + new Symbol("a"), + SymbolStatsEstimate.builder() + .setNullsFraction(0.3) + .setLowValue(1) + .setHighValue(30) + .setDistinctValuesCount(20) + .build()) + .addSymbolStatistics( + new Symbol("b"), + SymbolStatsEstimate.builder() + .setNullsFraction(0.6) + .setLowValue(13.5) + .setHighValue(POSITIVE_INFINITY) + .setDistinctValuesCount(40) + .build()) + .build(); + + tester.assertStatsFor(pb -> pb + .output(outputBuilder -> { + Symbol a = pb.symbol("a", BIGINT); + Symbol b = pb.symbol("b", DOUBLE); + outputBuilder + .source(pb.values(a, b)) + .column(a, "a1") + .column(a, "a2") + .column(b, "b"); + })) + .withSourceStats(stats) + .check(outputStats -> outputStats.equalTo(stats)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java new file mode 100644 index 0000000000000..350cf7e17d686 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java @@ -0,0 +1,384 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; +import static java.util.Collections.emptyMap; + +public class TestScalarStatsCalculator +{ + private ScalarStatsCalculator calculator; + private Session session; + private final SqlParser sqlParser = new SqlParser(); + + @BeforeMethod + public void setUp() + throws Exception + { + 
calculator = new ScalarStatsCalculator(MetadataManager.createTestMetadataManager()); + session = testSessionBuilder().build(); + } + + @Test + public void testLiteral() + { + assertCalculate(new DoubleLiteral("7.5")) + .distinctValuesCount(1.0) + .lowValue(7.5) + .highValue(7.5) + .nullsFraction(0.0); + + assertCalculate(new DecimalLiteral("75.5")) + .distinctValuesCount(1.0) + .lowValue(75.5) + .highValue(75.5) + .nullsFraction(0.0); + + assertCalculate(new StringLiteral("blah")) + .distinctValuesCount(1.0) + .lowValueUnknown() + .highValueUnknown() + .nullsFraction(0.0); + + assertCalculate(new NullLiteral()) + .distinctValuesCount(0.0) + .lowValueUnknown() + .highValueUnknown() + .nullsFraction(1.0); + } + + @Test + public void testSymbolReference() + { + SymbolStatsEstimate xStats = SymbolStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(2.0) + .build(); + PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), xStats) + .build(); + + assertCalculate(expression("x"), inputStatistics).isEqualTo(xStats); + assertCalculate(expression("y"), inputStatistics).isEqualTo(SymbolStatsEstimate.UNKNOWN_STATS); + } + + @Test + public void testCastDoubleToBigint() + { + Map types = ImmutableMap.of(new Symbol("a"), DOUBLE); + + PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .setNullsFraction(0.3) + .setLowValue(1.6) + .setHighValue(17.3) + .setDistinctValuesCount(10) + .setAverageRowSize(2.0) + .build()) + .build(); + + assertCalculate(new Cast(new SymbolReference("a"), "bigint"), inputStatistics, types) + .lowValue(2.0) + .highValue(17.0) + .distinctValuesCount(10) + .nullsFraction(0.3) + .dataSizeUnknown(); + } + + @Test + public void testCastDoubleToShortRange() + { + Map types = ImmutableMap.of(new Symbol("a"), DOUBLE); + + 
PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .setNullsFraction(0.3) + .setLowValue(1.6) + .setHighValue(3.3) + .setDistinctValuesCount(10) + .setAverageRowSize(2.0) + .build()) + .build(); + + assertCalculate(new Cast(new SymbolReference("a"), "bigint"), inputStatistics, types) + .lowValue(2.0) + .highValue(3.0) + .distinctValuesCount(2) + .nullsFraction(0.3) + .dataSizeUnknown(); + } + + @Test + public void testCastDoubleToShortRangeUnknownDistinctValuesCount() + { + Map types = ImmutableMap.of(new Symbol("a"), DOUBLE); + + PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .setNullsFraction(0.3) + .setLowValue(1.6) + .setHighValue(3.3) + .setAverageRowSize(2.0) + .build()) + .build(); + + assertCalculate(new Cast(new SymbolReference("a"), "bigint"), inputStatistics, types) + .lowValue(2.0) + .highValue(3.0) + .distinctValuesCountUnknown() + .nullsFraction(0.3) + .dataSizeUnknown(); + } + + @Test + public void testCastBigintToDouble() + { + Map types = ImmutableMap.of(new Symbol("a"), BIGINT); + + PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("a"), SymbolStatsEstimate.builder() + .setNullsFraction(0.3) + .setLowValue(2.0) + .setHighValue(10.0) + .setDistinctValuesCount(4) + .setAverageRowSize(2.0) + .build()) + .build(); + + assertCalculate(new Cast(new SymbolReference("a"), "double"), inputStatistics, types) + .lowValue(2.0) + .highValue(10.0) + .distinctValuesCount(4) + .nullsFraction(0.3) + .dataSizeUnknown(); + } + + @Test + public void testCastUnknown() + { + Map types = ImmutableMap.of(new Symbol("a"), DOUBLE); + assertCalculate(new Cast(new SymbolReference("a"), "bigint"), UNKNOWN_STATS, types) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown() + 
.dataSizeUnknown(); + } + + private SymbolStatsAssertion assertCalculate(Expression scalarExpression) + { + return assertCalculate(scalarExpression, UNKNOWN_STATS); + } + + private SymbolStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics) + { + return assertCalculate(scalarExpression, inputStatistics, emptyMap()); + } + + private SymbolStatsAssertion assertCalculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Map types) + { + return SymbolStatsAssertion.assertThat(calculator.calculate(scalarExpression, inputStatistics, session, types)); + } + + @Test + public void testNonDivideArithmeticBinaryExpression() + { + PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(2.0) + .build()) + .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .setAverageRowSize(2.0) + .build()) + .setOutputRowCount(10) + .build(); + + assertCalculate(expression("x + y"), relationStats) + .distinctValuesCount(10.0) + .lowValue(-3.0) + .highValue(15.0) + .nullsFraction(0.28) + .averageRowSize(2.0); + + assertCalculate(expression("x - y"), relationStats) + .distinctValuesCount(10.0) + .lowValue(-6.0) + .highValue(12.0) + .nullsFraction(0.28) + .averageRowSize(2.0); + + assertCalculate(expression("x * y"), relationStats) + .distinctValuesCount(10.0) + .lowValue(-20.0) + .highValue(50.0) + .nullsFraction(0.28) + .averageRowSize(2.0); + } + + @Test + public void testDivideArithmeticBinaryExpression() + { + assertCalculate(expression("x / y"), xyStats(-11, -3, -5, -4)).lowValue(0.6).highValue(2.75); + assertCalculate(expression("x / y"), xyStats(-11, -3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + 
assertCalculate(expression("x / y"), xyStats(-11, -3, 4, 5)).lowValue(-2.75).highValue(-0.6); + + assertCalculate(expression("x / y"), xyStats(-11, 0, -5, -4)).lowValue(0).highValue(2.75); + assertCalculate(expression("x / y"), xyStats(-11, 0, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(expression("x / y"), xyStats(-11, 0, 4, 5)).lowValue(-2.75).highValue(0); + + assertCalculate(expression("x / y"), xyStats(-11, 3, -5, -4)).lowValue(-0.75).highValue(2.75); + assertCalculate(expression("x / y"), xyStats(-11, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(expression("x / y"), xyStats(-11, 3, 4, 5)).lowValue(-2.75).highValue(0.75); + + assertCalculate(expression("x / y"), xyStats(0, 3, -5, -4)).lowValue(-0.75).highValue(0); + assertCalculate(expression("x / y"), xyStats(0, 3, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(expression("x / y"), xyStats(0, 3, 4, 5)).lowValue(0).highValue(0.75); + + assertCalculate(expression("x / y"), xyStats(3, 11, -5, -4)).lowValue(-2.75).highValue(-0.6); + assertCalculate(expression("x / y"), xyStats(3, 11, -5, 4)).lowValue(NEGATIVE_INFINITY).highValue(POSITIVE_INFINITY); + assertCalculate(expression("x / y"), xyStats(3, 11, 4, 5)).lowValue(0.6).highValue(2.75); + } + + @Test + public void testModulusArithmeticBinaryExpression() + { + // negative + assertCalculate(expression("x % y"), xyStats(-1, 0, -6, -4)).lowValue(-1).highValue(0); + assertCalculate(expression("x % y"), xyStats(-5, 0, -6, -4)).lowValue(-5).highValue(0); + assertCalculate(expression("x % y"), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(expression("x % y"), xyStats(-8, 0, -6, -4)).lowValue(-6).highValue(0); + assertCalculate(expression("x % y"), xyStats(-8, 0, -6, 4)).lowValue(-6).highValue(0); + assertCalculate(expression("x % y"), xyStats(-8, 0, -6, 6)).lowValue(-6).highValue(0); + assertCalculate(expression("x % y"), 
xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + assertCalculate(expression("x % y"), xyStats(-1, 0, 4, 6)).lowValue(-1).highValue(0); + assertCalculate(expression("x % y"), xyStats(-5, 0, 4, 6)).lowValue(-5).highValue(0); + assertCalculate(expression("x % y"), xyStats(-8, 0, 4, 6)).lowValue(-6).highValue(0); + + // positive + assertCalculate(expression("x % y"), xyStats(0, 5, -6, -4)).lowValue(0).highValue(5); + assertCalculate(expression("x % y"), xyStats(0, 8, -6, -4)).lowValue(0).highValue(6); + assertCalculate(expression("x % y"), xyStats(0, 1, -6, 4)).lowValue(0).highValue(1); + assertCalculate(expression("x % y"), xyStats(0, 5, -6, 4)).lowValue(0).highValue(5); + assertCalculate(expression("x % y"), xyStats(0, 8, -6, 4)).lowValue(0).highValue(6); + assertCalculate(expression("x % y"), xyStats(0, 1, 4, 6)).lowValue(0).highValue(1); + assertCalculate(expression("x % y"), xyStats(0, 5, 4, 6)).lowValue(0).highValue(5); + assertCalculate(expression("x % y"), xyStats(0, 8, 4, 6)).lowValue(0).highValue(6); + + // mix + assertCalculate(expression("x % y"), xyStats(-1, 1, -6, -4)).lowValue(-1).highValue(1); + assertCalculate(expression("x % y"), xyStats(-1, 5, -6, -4)).lowValue(-1).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 1, -6, -4)).lowValue(-5).highValue(1); + assertCalculate(expression("x % y"), xyStats(-5, 5, -6, -4)).lowValue(-5).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 8, -6, -4)).lowValue(-5).highValue(6); + assertCalculate(expression("x % y"), xyStats(-8, 5, -6, -4)).lowValue(-6).highValue(5); + assertCalculate(expression("x % y"), xyStats(-8, 8, -6, -4)).lowValue(-6).highValue(6); + assertCalculate(expression("x % y"), xyStats(-1, 1, -6, 4)).lowValue(-1).highValue(1); + assertCalculate(expression("x % y"), xyStats(-1, 5, -6, 4)).lowValue(-1).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 1, -6, 4)).lowValue(-5).highValue(1); + assertCalculate(expression("x % y"), xyStats(-5, 5, -6, 
4)).lowValue(-5).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 8, -6, 4)).lowValue(-5).highValue(6); + assertCalculate(expression("x % y"), xyStats(-8, 5, -6, 4)).lowValue(-6).highValue(5); + assertCalculate(expression("x % y"), xyStats(-8, 8, -6, 4)).lowValue(-6).highValue(6); + assertCalculate(expression("x % y"), xyStats(-1, 1, 4, 6)).lowValue(-1).highValue(1); + assertCalculate(expression("x % y"), xyStats(-1, 5, 4, 6)).lowValue(-1).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 1, 4, 6)).lowValue(-5).highValue(1); + assertCalculate(expression("x % y"), xyStats(-5, 5, 4, 6)).lowValue(-5).highValue(5); + assertCalculate(expression("x % y"), xyStats(-5, 8, 4, 6)).lowValue(-5).highValue(6); + assertCalculate(expression("x % y"), xyStats(-8, 5, 4, 6)).lowValue(-6).highValue(5); + assertCalculate(expression("x % y"), xyStats(-8, 8, 4, 6)).lowValue(-6).highValue(6); + } + + private PlanNodeStatsEstimate xyStats(double lowX, double highX, double lowY, double highY) + { + return PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .setLowValue(lowX) + .setHighValue(highX) + .build()) + .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .setLowValue(lowY) + .setHighValue(highY) + .build()) + .build(); + } + + @Test + public void testCoalesceExpression() + { + PlanNodeStatsEstimate relationStats = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(2.0) + .build()) + .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .setAverageRowSize(2.0) + .build()) + .setOutputRowCount(10) + .build(); + + assertCalculate(expression("coalesce(x, y)"), relationStats) + .distinctValuesCount(5) + .lowValue(-2) + 
.highValue(10) + .nullsFraction(0.02) + .averageRowSize(2.0); + } + + private Expression expression(String sqlExpression) + { + return rewriteIdentifiersToSymbolReferences(sqlParser.createExpression(sqlExpression)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestStatsCalculator.java similarity index 62% rename from presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java rename to presto-main/src/test/java/com/facebook/presto/cost/TestStatsCalculator.java index 5641fd9bfa2cf..e130de7f53199 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestStatsCalculator.java @@ -13,76 +13,64 @@ */ package com.facebook.presto.cost; -import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.sql.planner.LogicalPlanner; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.assertions.PlanAssert; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; -import static com.facebook.presto.spi.statistics.Estimate.unknownValue; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -public class TestCoefficientBasedCostCalculator +public class TestStatsCalculator { private final LocalQueryRunner queryRunner; - private final CostCalculator costCalculator; - public 
TestCoefficientBasedCostCalculator() + public TestStatsCalculator() { this.queryRunner = new LocalQueryRunner(testSessionBuilder() .setCatalog("local") .setSchema("tiny") .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel + .setSystemProperty("use_new_stats_calculator", "true") .build()); queryRunner.createCatalog( queryRunner.getDefaultSession().getCatalog().get(), new TpchConnectorFactory(1, true), ImmutableMap.of()); - - costCalculator = new CoefficientBasedCostCalculator(queryRunner.getMetadata()); } @Test - public void testCostCalculatorUsesLayout() + public void testStatsCalculatorUsesLayout() { assertPlan("SELECT orderstatus FROM orders WHERE orderstatus = 'P'", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, anyTree( - node(FilterNode.class, - node(TableScanNode.class) - .withCost(PlanNodeCost.builder() - .setOutputRowCount(new Estimate(385.0)) - .setOutputSizeInBytes(unknownValue()) - .build())))); + node(TableScanNode.class) + .withStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(385.0) + .build()))); assertPlan("SELECT orderstatus FROM orders WHERE orderkey = 42", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, anyTree( - node(FilterNode.class, - node(TableScanNode.class) - .withCost(PlanNodeCost.builder() - .setOutputRowCount(new Estimate(0.0)) - .setOutputSizeInBytes(unknownValue()) - .build())))); - } - - private void assertPlan(String sql, PlanMatchPattern pattern) - { - assertPlan(sql, LogicalPlanner.Stage.CREATED, pattern); + node(TableScanNode.class) + .withStats(PlanNodeStatsEstimate.builder() + .setOutputRowCount(0) + .build()))); } private void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern) { queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, stage); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), costCalculator, actualPlan, pattern); + 
PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java new file mode 100644 index 0000000000000..4635fa7cc8a1e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestUnionStatsRule.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; + +public class TestUnionStatsRule +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testUnion() + { + // test cases + // i11, i21 have separated low/high ranges and known all stats, unknown distinct values count + // i12, i22 have overlapping low/high ranges and known all stats, unknown nulls fraction + // i13, i23 have some unknown range stats + // i14, i24 have the same stats + // i15, i25 one has stats, other contains only nulls + + tester.assertStatsFor(pb -> pb + .union( + ImmutableList.of( + pb.values(pb.symbol("i11", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i14", BIGINT), pb.symbol("i15", BIGINT)), + pb.values(pb.symbol("i21", BIGINT), pb.symbol("i22", BIGINT), pb.symbol("i23", BIGINT), pb.symbol("i24", BIGINT), pb.symbol("i25", BIGINT))), + ImmutableListMultimap.builder() + .putAll(pb.symbol("o1", BIGINT), pb.symbol("i11", BIGINT), pb.symbol("i21", BIGINT)) + .putAll(pb.symbol("o2", BIGINT), pb.symbol("i12", BIGINT), pb.symbol("i22", BIGINT)) + .putAll(pb.symbol("o3", BIGINT), pb.symbol("i13", BIGINT), pb.symbol("i23", BIGINT)) + .putAll(pb.symbol("o4", BIGINT), pb.symbol("i14", BIGINT), pb.symbol("i24", BIGINT)) + .putAll(pb.symbol("o5", BIGINT), pb.symbol("i15", BIGINT), pb.symbol("i25", BIGINT)) + 
.build(), + ImmutableList.of(pb.symbol("o1", BIGINT), pb.symbol("o2", BIGINT), pb.symbol("o3", BIGINT), pb.symbol("o4", BIGINT), pb.symbol("o5", BIGINT)))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol("i11"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setDistinctValuesCount(5) + .setNullsFraction(0.3) + .build()) + .addSymbolStatistics(new Symbol("i12"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setDistinctValuesCount(4) + .setNullsFraction(0) + .build()) + .addSymbolStatistics(new Symbol("i13"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addSymbolStatistics(new Symbol("i14"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addSymbolStatistics(new Symbol("i15"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(20) + .addSymbolStatistics(new Symbol("i21"), SymbolStatsEstimate.builder() + .setLowValue(11) + .setHighValue(20) + .setNullsFraction(0.4) + .build()) + .addSymbolStatistics(new Symbol("i22"), SymbolStatsEstimate.builder() + .setLowValue(2) + .setHighValue(7) + .setDistinctValuesCount(3) + .build()) + .addSymbolStatistics(new Symbol("i23"), SymbolStatsEstimate.builder() + .setDistinctValuesCount(6) + .setNullsFraction(0.2) + .build()) + .addSymbolStatistics(new Symbol("i24"), SymbolStatsEstimate.builder() + .setLowValue(10) + .setHighValue(15) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addSymbolStatistics(new Symbol("i25"), SymbolStatsEstimate.builder() + .setNullsFraction(1) + .build()) + .build()) + .check(check -> check + .outputRowsCount(30) + 
.symbolStats("o1", assertion -> assertion + .lowValue(1) + .highValue(20) + .dataSizeUnknown() + .nullsFraction(0.3666666)) + .symbolStats("o2", assertion -> assertion + .lowValue(0) + .highValue(7) + .distinctValuesCount(7.0) + .nullsFractionUnknown()) + .symbolStats("o3", assertion -> assertion + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(10.0) + .nullsFraction(0.1666667)) + .symbolStats("o4", assertion -> assertion + .lowValue(10) + .highValue(15) + .distinctValuesCount(6.0) + .nullsFraction(0.1)) + .symbolStats("o5", assertion -> assertion + .lowValue(NEGATIVE_INFINITY) + .highValue(POSITIVE_INFINITY) + .distinctValuesCountUnknown() + .nullsFraction(0.7))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java new file mode 100644 index 0000000000000..b57398656e9ca --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestValuesNodeStats.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.sql.planner.Symbol; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static java.lang.Double.NaN; + +public class TestValuesNodeStats +{ + private StatsCalculatorTester tester; + + @BeforeMethod + public void setUp() + { + tester = new StatsCalculatorTester(); + } + + @AfterMethod + public void tearDown() + { + tester.close(); + tester = null; + } + + @Test + public void testStatsForValuesNode() + throws Exception + { + tester.assertStatsFor(pb -> pb + .values(ImmutableList.of(pb.symbol("a", BIGINT), pb.symbol("b", DOUBLE)), + ImmutableList.of( + ImmutableList.of(expression("3+3"), expression("13.5")), + ImmutableList.of(expression("55"), expression("null")), + ImmutableList.of(expression("6"), expression("13.5"))))) + .check(outputStats -> outputStats.equalTo( + PlanNodeStatsEstimate.builder() + .setOutputRowCount(3) + .addSymbolStatistics( + new Symbol("a"), + SymbolStatsEstimate.builder() + .setNullsFraction(0) + .setLowValue(6) + .setHighValue(55) + .setDistinctValuesCount(2) + .build()) + .addSymbolStatistics( + new Symbol("b"), + SymbolStatsEstimate.builder() + .setNullsFraction(0.33333333333333333) + .setLowValue(13.5) + .setHighValue(13.5) + .setDistinctValuesCount(1) + .build()) + .build())); + } + + @Test + public void testStatsForValuesNodeWithJustNulls() + throws Exception + { + PlanNodeStatsEstimate nullAStats = PlanNodeStatsEstimate.builder() + .setOutputRowCount(1) + .addSymbolStatistics( + new Symbol("a"), + SymbolStatsEstimate.builder() + .setLowValue(NaN) + 
.setHighValue(NaN) + .setNullsFraction(1.0) + .setDistinctValuesCount(0.0) + .build()) + .build(); + + tester.assertStatsFor(pb -> pb + .values(ImmutableList.of(pb.symbol("a", BIGINT)), + ImmutableList.of( + ImmutableList.of(expression("3 + null"))))) + .check(outputStats -> outputStats.equalTo(nullAStats)); + + tester.assertStatsFor(pb -> pb + .values(ImmutableList.of(pb.symbol("a", BIGINT)), + ImmutableList.of( + ImmutableList.of(expression("null"))))) + .check(outputStats -> outputStats.equalTo(nullAStats)); + + tester.assertStatsFor(pb -> pb + .values(ImmutableList.of(pb.symbol("a", UNKNOWN)), + ImmutableList.of( + ImmutableList.of(expression("null"))))) + .check(outputStats -> outputStats.equalTo(nullAStats)); + } + + @Test + public void testStatsForEmptyValues() + throws Exception + { + tester.assertStatsFor(pb -> pb + .values(ImmutableList.of(pb.symbol("a", BIGINT)), + ImmutableList.of())) + .check(outputStats -> outputStats.equalTo( + PlanNodeStatsEstimate.builder() + .setOutputRowCount(0) + .addSymbolStatistics( + new Symbol("a"), + SymbolStatsEstimate.builder() + .setLowValue(NaN) + .setHighValue(NaN) + .setNullsFraction(0.0) + .setDistinctValuesCount(0.0) + .build()) + .build())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index 5bf137c6caad0..4204a9d64d387 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -18,7 +18,11 @@ import com.facebook.presto.TaskSource; import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CoefficientBasedStatsCalculator; +import com.facebook.presto.cost.CostCalculatorUsingExchanges; +import com.facebook.presto.cost.FilterStatsCalculator; 
+import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.SelectingStatsCalculator; import com.facebook.presto.execution.TestSqlTaskManager.MockExchangeClientSupplier; import com.facebook.presto.execution.scheduler.LegacyNetworkTopology; import com.facebook.presto.execution.scheduler.NodeScheduler; @@ -31,6 +35,7 @@ import com.facebook.presto.operator.LookupJoinOperators; import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.operator.index.IndexJoinLookupStats; +import com.facebook.presto.server.ServerMainModule; import com.facebook.presto.spi.block.TestingBlockEncodingSerde; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.predicate.TupleDomain; @@ -127,7 +132,10 @@ public static LocalExecutionPlanner createTestingPlanner() return new LocalExecutionPlanner( metadata, new SqlParser(), - new CoefficientBasedCostCalculator(metadata), + new SelectingStatsCalculator( + new CoefficientBasedStatsCalculator(metadata), + ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata))), + new CostCalculatorUsingExchanges(1), Optional.empty(), pageSourceManager, new IndexManager(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 4e66a6ed77a53..780c48507c263 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -35,6 +35,9 @@ public class TestFeaturesConfig public void testDefaults() { assertRecordedDefaults(ConfigAssertions.recordDefaults(FeaturesConfig.class) + .setCpuCostWeight(0.75) + .setMemoryCostWeight(0) + .setNetworkCostWeight(0.25) .setResourceGroupsEnabled(false) .setDistributedIndexJoinsEnabled(false) .setDistributedJoinsEnabled(true) 
@@ -64,13 +67,17 @@ public void testDefaults() .setIterativeOptimizerTimeout(new Duration(3, MINUTES)) .setExchangeCompressionEnabled(false) .setEnableIntermediateAggregations(false) - .setPushAggregationThroughJoin(true)); + .setPushAggregationThroughJoin(true) + .setUseNewStatsCalculator(true)); } @Test public void testExplicitPropertyMappings() { Map propertiesLegacy = new ImmutableMap.Builder() + .put("cpu-cost-weight", "0.4") + .put("memory-cost-weight", "0.3") + .put("network-cost-weight", "0.2") .put("experimental.resource-groups-enabled", "true") .put("experimental.iterative-optimizer-enabled", "false") .put("experimental.iterative-optimizer-timeout", "10s") @@ -101,8 +108,12 @@ public void testExplicitPropertyMappings() .put("experimental.spiller-max-used-space-threshold", "0.8") .put("exchange.compression-enabled", "true") .put("optimizer.enable-intermediate-aggregations", "true") + .put("experimental.use-new-stats-calculator", "false") .build(); Map properties = new ImmutableMap.Builder() + .put("cpu-cost-weight", "0.4") + .put("memory-cost-weight", "0.3") + .put("network-cost-weight", "0.2") .put("experimental.resource-groups-enabled", "true") .put("experimental.iterative-optimizer-enabled", "false") .put("experimental.iterative-optimizer-timeout", "10s") @@ -133,9 +144,13 @@ public void testExplicitPropertyMappings() .put("experimental.spiller-max-used-space-threshold", "0.8") .put("exchange.compression-enabled", "true") .put("optimizer.enable-intermediate-aggregations", "true") + .put("experimental.use-new-stats-calculator", "false") .build(); FeaturesConfig expected = new FeaturesConfig() + .setCpuCostWeight(0.4) + .setMemoryCostWeight(0.3) + .setNetworkCostWeight(0.2) .setResourceGroupsEnabled(true) .setIterativeOptimizerEnabled(false) .setIterativeOptimizerTimeout(new Duration(10, SECONDS)) @@ -165,7 +180,8 @@ public void testExplicitPropertyMappings() .setSpillMaxUsedSpaceThreshold(0.8) .setLegacyOrderBy(true) .setExchangeCompressionEnabled(true) 
- .setEnableIntermediateAggregations(true); + .setEnableIntermediateAggregations(true) + .setUseNewStatsCalculator(false); assertFullMapping(properties, expected); assertDeprecatedEquivalence(FeaturesConfig.class, properties, propertiesLegacy); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java index abc57531342b7..347fde45570e7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -56,7 +56,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); AggregationNode aggregationNode = (AggregationNode) node; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java index 49cb8097ad655..af3f64cf52ced 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java +++ 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -52,7 +52,7 @@ public boolean shapeMatches(PlanNode node) * higher up the tree. */ @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { Optional symbol = matcher.getAssignedSymbol(node, session, metadata, symbolAliases); if (symbol.isPresent() && alias.isPresent()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index af575a145dc20..62f4b7f349832 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -116,7 +116,7 @@ protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPatte { queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, stage); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } @@ -126,7 +126,7 @@ protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMat List optimizers = 
ImmutableList.of( new UnaliasSymbolReferences(), new PruneUnreferencedOutputs(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections()))); + new IterativeOptimizer(new StatsRecorder(), queryRunner.getStatsCalculator(), queryRunner.getCostCalculator(), ImmutableSet.of(new RemoveRedundantIdentityProjections()))); assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java index fd9b9bbb66561..bf03df389842b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -48,7 +48,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); return new MatchResult(getActual.apply(node).equals(getExpectedSymbols(node, session, metadata, symbolAliases))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java index 6c794a8dc8acf..5aa6929da6310 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -46,7 +46,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState( shapeMatches(node), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java index 9e5771c530a26..68116741e2863 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -46,7 +46,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, 
PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java index 69877793f041b..76fd7092271bc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -41,7 +41,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java index ea4dd92b99b2a..3e8bb8685bcb7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java +++ 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -49,7 +49,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index d73a005c8b22e..a78c856247075 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -57,7 +57,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, 
PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java index 6d9573bbe7f93..9132c286f41dc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -43,7 +43,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node)); return MatchResult.match(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java index d79f38a45f98d..fb59eb35fd300 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import 
com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -52,7 +52,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); MarkDistinctNode markDistinctNode = (MarkDistinctNode) node; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java index 106e5a6ce9637..1d0e0758b382d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -63,11 +63,11 @@ public interface Matcher * node if shapeMatches didn't return true for the same node. 
* * @param node The node to apply the matching tests to - * @param planNodeCost The computed cost of plan node + * @param stats The computed stats of plan node * @param session The session information for the query * @param metadata The metadata for the query * @param symbolAliases The SymbolAliases containing aliases from the nodes sources * @return a MatchResult with information about the success of the match */ - MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases); + MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java index a74219a22c441..85f281b8b834e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -40,7 +40,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); return match(); diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java index f10fd79b83102..1f2c249e4207e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -45,7 +45,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { int i = 0; for (String alias : aliases) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index 386fb13b7a66e..cef0ba1ba3777 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.iterative.Lookup; @@ -28,18 +27,13 @@ public final class PlanAssert { private PlanAssert() {} - public static void assertPlan(Session session, 
Metadata metadata, CostCalculator costCalculator, Plan actual, PlanMatchPattern pattern) + public static void assertPlan(Session session, Metadata metadata, Lookup lookup, Plan actual, PlanMatchPattern pattern) { - assertPlan(session, metadata, costCalculator, actual, Lookup.noLookup(), pattern); - } - - public static void assertPlan(Session session, Metadata metadata, CostCalculator costCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern) - { - MatchResult matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata, actual.getPlanNodeCosts(), lookup), pattern); + MatchResult matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata, lookup, actual.getTypes()), pattern); if (!matches.isMatch()) { - String formattedPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, costCalculator, session); + String formattedPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, lookup, session); PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup); - String resolvedFormattedPlan = textLogicalPlan(resolvedPlan, actual.getTypes(), metadata, costCalculator, session); + String resolvedFormattedPlan = textLogicalPlan(resolvedPlan, actual.getTypes(), metadata, lookup, session); throw new AssertionError(format( "Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n] which resolves to [\n\n%s\n]", pattern, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 733e83e24475b..9e39e36ec9506 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -14,7 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; 
+import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.predicate.Domain; @@ -409,12 +409,12 @@ List shapeMatches(PlanNode node) return states.build(); } - MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Session session, Metadata metadata, SymbolAliases symbolAliases) { SymbolAliases.Builder newAliases = SymbolAliases.builder(); for (Matcher matcher : matchers) { - MatchResult matchResult = matcher.detailMatches(node, planNodeCost, session, metadata, symbolAliases); + MatchResult matchResult = matcher.detailMatches(node, stats, session, metadata, symbolAliases); if (!matchResult.isMatch()) { return NO_MATCH; } @@ -496,9 +496,9 @@ public PlanMatchPattern withExactAssignments(Collection return this; } - public PlanMatchPattern withCost(PlanNodeCost cost) + public PlanMatchPattern withStats(PlanNodeStatsEstimate stats) { - matchers.add(new PlanCostMatcher(cost)); + matchers.add(new PlanStatsMatcher(stats)); return this; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java index 42f5dc64aa6dd..932ae4e35f531 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java @@ -14,15 +14,14 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; -import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Symbol; import 
com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -40,15 +39,15 @@ final class PlanMatchingVisitor { private final Metadata metadata; private final Session session; - private final Map planCost; private final Lookup lookup; + private final Map types; - PlanMatchingVisitor(Session session, Metadata metadata, Map planCost, Lookup lookup) + PlanMatchingVisitor(Session session, Metadata metadata, Lookup lookup, Map types) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.planCost = requireNonNull(planCost, "planCost is null"); this.lookup = requireNonNull(lookup, "lookup is null"); + this.types = requireNonNull(types, "types is null"); } @Override @@ -123,7 +122,7 @@ protected MatchResult visitPlan(PlanNode node, PlanMatchPattern pattern) // Try upMatching this node with the the aliases gathered from the source nodes. 
SymbolAliases allSourceAliases = sourcesMatch.getAliases(); - MatchResult matchResult = pattern.detailMatches(node, planCost.get(node.getId()), session, metadata, allSourceAliases); + MatchResult matchResult = pattern.detailMatches(node, lookup.getStats(node, session, types), session, metadata, allSourceAliases); if (matchResult.isMatch()) { checkState(result == NO_MATCH, format("Ambiguous match on node %s", node)); result = match(allSourceAliases.withNewAliases(matchResult.getAliases())); @@ -148,7 +147,7 @@ private MatchResult matchLeaf(PlanNode node, PlanMatchPattern pattern, List { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java similarity index 65% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java index 6097f46a177f8..ba6d81c259609 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java @@ -15,6 +15,8 @@ package com.facebook.presto.sql.planner.iterative; import com.facebook.presto.Session; +import com.facebook.presto.matching.MatchingEngine; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.DummyMetadata; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -34,19 +36,19 @@ import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; -public class TestRuleStore +public class TestMatchingEngine { private final PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), new DummyMetadata()); @Test - public void test() + public void testWithPlanNodeHierarchy() { - Rule projectRule1 = new 
NoOpRule(Pattern.node(ProjectNode.class)); - Rule projectRule2 = new NoOpRule(Pattern.node(ProjectNode.class)); - Rule filterRule = new NoOpRule(Pattern.node(FilterNode.class)); + Rule projectRule1 = new NoOpRule(Pattern.matchByClass(ProjectNode.class)); + Rule projectRule2 = new NoOpRule(Pattern.matchByClass(ProjectNode.class)); + Rule filterRule = new NoOpRule(Pattern.matchByClass(FilterNode.class)); Rule anyRule = new NoOpRule(Pattern.any()); - RuleStore ruleStore = RuleStore.builder() + MatchingEngine matchingEngine = MatchingEngine.builder() .register(projectRule1) .register(projectRule2) .register(filterRule) @@ -58,16 +60,40 @@ public void test() ValuesNode valuesNode = planBuilder.values(); assertEquals( - ruleStore.getCandidates(projectNode).collect(toList()), + matchingEngine.getCandidates(projectNode).collect(toList()), ImmutableList.of(projectRule1, projectRule2, anyRule)); assertEquals( - ruleStore.getCandidates(filterNode).collect(toList()), + matchingEngine.getCandidates(filterNode).collect(toList()), ImmutableList.of(filterRule, anyRule)); assertEquals( - ruleStore.getCandidates(valuesNode).collect(toList()), + matchingEngine.getCandidates(valuesNode).collect(toList()), ImmutableList.of(anyRule)); } + @Test + public void testInterfacesHierarchy() + { + Rule a = new NoOpRule(Pattern.matchByClass(A.class)); + Rule b = new NoOpRule(Pattern.matchByClass(B.class)); + Rule ab = new NoOpRule(Pattern.matchByClass(AB.class)); + + MatchingEngine matchingEngine = MatchingEngine.builder() + .register(a) + .register(b) + .register(ab) + .build(); + + assertEquals( + matchingEngine.getCandidates(new A() {}).collect(toList()), + ImmutableList.of(a)); + assertEquals( + matchingEngine.getCandidates(new B() {}).collect(toList()), + ImmutableList.of(b)); + assertEquals( + matchingEngine.getCandidates(new AB()).collect(toList()), + ImmutableList.of(ab, a, b)); + } + private static class NoOpRule implements Rule { @@ -98,4 +124,8 @@ public String toString() .toString(); 
} } + + private interface A {} + private interface B {} + private static class AB implements A, B {} } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushDownTableConstraints.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushDownTableConstraints.java new file mode 100644 index 0000000000000..5d4a0e7ff7bbf --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushDownTableConstraints.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedTableScan; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; +import static io.airlift.slice.Slices.utf8Slice; + +public class TestPushDownTableConstraints +{ + @Test + public void test() + { + Session session = testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel + .build(); + + LocalQueryRunner queryRunner = new LocalQueryRunner(session); + queryRunner.createCatalog( + session.getCatalog().get(), + new TpchConnectorFactory(1), + ImmutableMap.of()); + + Map tableScanConstraint = ImmutableMap.builder() + .put("orderstatus", Domain.singleValue(createVarcharType(1), utf8Slice("P"))) + .build(); + + new RuleTester(queryRunner).assertThat(new PushDownTableConstraints(queryRunner.getMetadata(), 
new SqlParser())) + .on(p -> + p.filter(expression("orderstatus = 'P'"), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), + ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))) + ))) + .matches(filter("orderstatus = 'P'", + constrainedTableScan("orders", tableScanConstraint, ImmutableMap.of("orderstatus", "orderstatus")))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java index 47b164d7e7b78..e7ef0fd1dc2c0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -30,7 +30,7 @@ public void testRemoveUnreferencedInput() tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) .on(p -> p.lateral( emptyList(), - p.values(new Symbol("x")), + p.values(p.symbol("x", BigintType.BIGINT)), p.values(emptyList(), ImmutableList.of(emptyList())))) .matches(values("x")); } @@ -42,7 +42,7 @@ public void testRemoveUnreferencedSubquery() .on(p -> p.lateral( emptyList(), p.values(emptyList(), ImmutableList.of(emptyList())), - p.values(new Symbol("x")))) + p.values(p.symbol("x", BigintType.BIGINT)))) .matches(values("x")); } diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 140be79ee2708..1a6e0ff0bfe26 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -36,10 +36,12 @@ import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -424,6 +426,11 @@ public UnionNode union(List sources, ListMultimap) sources, outputsToInputs, outputs); } + public IntersectNode intersect(List sources, ListMultimap outputsToInputs, List outputs) + { + return new IntersectNode(idAllocator.getNextId(), (List) sources, outputsToInputs, outputs); + } + public Symbol symbol(String name) { return symbol(name, BIGINT); @@ -457,6 +464,54 @@ public WindowNode window(WindowNode.Specification specification, Map outputBuilderConsumer) + { + OutputBuilder outputBuilder = new OutputBuilder(); + outputBuilderConsumer.accept(outputBuilder); + return outputBuilder.build(); + } + + public OutputNode output(PlanNode source, List columnNames, List symbols) + { + checkArgument(columnNames.size() == symbols.size(), "columnNames and outputs size do not match"); + OutputBuilder outputBuilder = new 
OutputBuilder(); + outputBuilder.source(source); + for (int columnIndex = 0; columnIndex < columnNames.size(); ++columnIndex) { + outputBuilder.column(symbols.get(columnIndex), columnNames.get(columnIndex)); + } + return outputBuilder.build(); + } + + public class OutputBuilder + { + private PlanNode source; + private List columnNames = new ArrayList<>(); + private List outputs = new ArrayList<>(); + + public OutputBuilder source(PlanNode source) + { + this.source = source; + return this; + } + + public OutputBuilder column(Symbol symbol) + { + return column(symbol, symbol.getName()); + } + + public OutputBuilder column(Symbol symbol, String columnName) + { + outputs.add(symbol); + columnNames.add(columnName); + return this; + } + + protected OutputNode build() + { + return new OutputNode(idAllocator.getNextId(), source, columnNames, outputs); + } + } + public static Expression expression(String sql) { return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index e078d4d636764..2b79a9d079926 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.type.Type; @@ -28,7 +28,6 @@ import com.facebook.presto.sql.planner.iterative.Memo; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; -import 
com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -46,7 +45,6 @@ public class RuleAssert { private final Metadata metadata; - private final CostCalculator costCalculator; private Session session; private final Rule rule; @@ -56,15 +54,18 @@ public class RuleAssert private PlanNode plan; private final TransactionManager transactionManager; private final AccessControl accessControl; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; - public RuleAssert(Metadata metadata, CostCalculator costCalculator, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl) + public RuleAssert(Metadata metadata, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl, StatsCalculator statsCalculator, CostCalculator costCalculator) { this.metadata = metadata; - this.costCalculator = costCalculator; this.session = session; this.rule = rule; this.transactionManager = transactionManager; this.accessControl = accessControl; + this.statsCalculator = statsCalculator; + this.costCalculator = costCalculator; } public RuleAssert setSystemProperty(String key, String value) @@ -98,7 +99,7 @@ public void doesNotFire() fail(String.format( "Expected %s to not fire for:\n%s", rule.getClass().getName(), - inTransaction(session -> PlanPrinter.textLogicalPlan(plan, ruleApplication.types, metadata, costCalculator, session, 2)))); + inTransaction(session -> PlanPrinter.textLogicalPlan(plan, ruleApplication.types, metadata, ruleApplication.lookup, session, 2)))); } } @@ -111,7 +112,7 @@ public void matches(PlanMatchPattern pattern) fail(String.format( "%s did not fire for:\n%s", rule.getClass().getName(), - formatPlan(plan, types))); + formatPlan(plan, types, ruleApplication.lookup))); } PlanNode actual = 
ruleApplication.getResult(); @@ -120,7 +121,7 @@ public void matches(PlanMatchPattern pattern) fail(String.format( "%s: rule fired but return the original plan:\n%s", rule.getClass().getName(), - formatPlan(plan, types))); + formatPlan(plan, types, ruleApplication.lookup))); } if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { @@ -134,8 +135,7 @@ public void matches(PlanMatchPattern pattern) } inTransaction(session -> { - Map planNodeCosts = costCalculator.calculateCostForPlan(session, types, actual); - assertPlan(session, metadata, costCalculator, new Plan(actual, types, planNodeCosts), ruleApplication.lookup, pattern); + assertPlan(session, metadata, ruleApplication.lookup, new Plan(actual, types, ruleApplication.lookup, session), pattern); return null; }); } @@ -144,7 +144,7 @@ private RuleApplication applyRule() { SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); Memo memo = new Memo(idAllocator, plan); - Lookup lookup = Lookup.from(memo::resolve); + Lookup lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator); if (!rule.getPattern().matches(plan)) { return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); @@ -155,9 +155,9 @@ private RuleApplication applyRule() return new RuleApplication(lookup, symbolAllocator.getTypes(), result); } - private String formatPlan(PlanNode plan, Map types) + private String formatPlan(PlanNode plan, Map types, Lookup lookup) { - return inTransaction(session -> PlanPrinter.textLogicalPlan(plan, types, metadata, costCalculator, session, 2)); + return inTransaction(session -> PlanPrinter.textLogicalPlan(plan, types, metadata, lookup, session, 2)); } private T inTransaction(Function transactionSessionConsumer) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index 
aa41b2ed8c152..40bfe95a31281 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -15,7 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; import com.facebook.presto.sql.planner.iterative.Rule; @@ -35,7 +34,6 @@ public class RuleTester public static final ConnectorId CONNECTOR_ID = new ConnectorId(CATALOG_ID); private final Metadata metadata; - private final CostCalculator costCalculator; private final Session session; private final LocalQueryRunner queryRunner; private final TransactionManager transactionManager; @@ -43,26 +41,43 @@ public class RuleTester public RuleTester() { - session = testSessionBuilder() + this(createQueryRunner()); + } + + private static LocalQueryRunner createQueryRunner() + { + Session session = testSessionBuilder() .setCatalog(CATALOG_ID) .setSchema("tiny") .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel .build(); - queryRunner = new LocalQueryRunner(session); + LocalQueryRunner queryRunner = new LocalQueryRunner(session); queryRunner.createCatalog(session.getCatalog().get(), new TpchConnectorFactory(1), ImmutableMap.of()); + return queryRunner; + } + public RuleTester(LocalQueryRunner queryRunner) + { + this.queryRunner = queryRunner; + this.session = queryRunner.getDefaultSession(); this.metadata = queryRunner.getMetadata(); - this.costCalculator = queryRunner.getCostCalculator(); this.transactionManager = queryRunner.getTransactionManager(); this.accessControl = queryRunner.getAccessControl(); } public RuleAssert assertThat(Rule rule) { - return new RuleAssert(metadata, costCalculator, session, rule, transactionManager, accessControl); + 
return new RuleAssert( + metadata, + session, + rule, + transactionManager, + accessControl, + queryRunner.getStatsCalculator(), + queryRunner.getEstimatedExchangesCostCalculator()); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index 3db490c3443bd..668a406091585 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -92,12 +92,12 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern new UnaliasSymbolReferences(), new AddExchanges(queryRunner.getMetadata(), new SqlParser()), new PruneUnreferencedOutputs(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())) + new IterativeOptimizer(new StatsRecorder(), queryRunner.getStatsCalculator(), queryRunner.getCostCalculator(), ImmutableSet.of(new RemoveRedundantIdentityProjections())) ); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index 3150ac96759d7..fdf2089eb4069 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ 
-485,14 +485,18 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter LocalQueryRunner queryRunner = getQueryRunner(); List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of( - new RemoveRedundantIdentityProjections(), - new SwapAdjacentWindowsBySpecifications(), - new MergeAdjacentWindows())), + new IterativeOptimizer( + new StatsRecorder(), + queryRunner.getStatsCalculator(), + queryRunner.getEstimatedExchangesCostCalculator(), + ImmutableSet.of( + new RemoveRedundantIdentityProjections(), + new SwapAdjacentWindowsBySpecifications(), + new MergeAdjacentWindows())), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java index 6f5c35f3bcdf9..77eaf448e09ba 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java @@ -140,13 +140,13 @@ public void assertUnitPlan(String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), + new IterativeOptimizer(new StatsRecorder(), queryRunner.getStatsCalculator(), 
queryRunner.getEstimatedExchangesCostCalculator(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), new OptimizeMixedDistinctAggregations(queryRunner.getMetadata()), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java index a5c5f0866b32f..40a6321583d92 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java @@ -111,7 +111,7 @@ public void testEliminateCrossJoinPreserveFilters() anyTree( join(INNER, ImmutableList.of(equiJoinClause("P_PARTKEY", "L_PARTKEY")), anyTree(PART_TABLESCAN), - anyTree(filter("L_RETURNFLAG = 'R'", LINEITEM_WITH_RETURNFLAG_TABLESCAN)))), + anyTree(filter("'R' = L_RETURNFLAG", LINEITEM_WITH_RETURNFLAG_TABLESCAN)))), anyTree(filter("O_SHIPPRIORITY >= 10", ORDERS_WITH_SHIPPRIORITY_TABLESCAN))))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index 543bc29c34dfd..8946b247499c0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -274,13 +274,15 @@ private void assertUnitPlan(@Language("SQL") String sql, 
PlanMatchPattern patter List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), new IterativeOptimizer(new StatsRecorder(), + queryRunner.getStatsCalculator(), + queryRunner.getEstimatedExchangesCostCalculator(), ImmutableSet.of( new RemoveRedundantIdentityProjections(), new SwapAdjacentWindowsBySpecifications())), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getLookup(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java index 54ef7f6fff88b..a4debb09d0265 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java @@ -129,7 +129,7 @@ public void assertPlan(String sql, PlanMatchPattern pattern) List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), new PruneUnreferencedOutputs(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), + new IterativeOptimizer(new StatsRecorder(), getQueryRunner().getStatsCalculator(), getQueryRunner().getEstimatedExchangesCostCalculator(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), new SetFlatteningOptimizer()); assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java 
b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java index 25c076b886e7b..e0012b928cb53 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java @@ -20,6 +20,8 @@ import com.teradata.tempto.Requires; import com.teradata.tempto.configuration.Configuration; import com.teradata.tempto.fulfillment.table.MutableTableRequirement; +import com.teradata.tempto.fulfillment.table.hive.HiveTableDefinition; +import com.teradata.tempto.fulfillment.table.hive.InlineDataSource; import com.teradata.tempto.query.QueryExecutor; import org.testng.annotations.Test; @@ -62,6 +64,14 @@ public Requirement getRequirements(Configuration configuration) private static final String ALL_TYPES_TABLE_NAME = "all_types"; private static final String EMPTY_ALL_TYPES_TABLE_NAME = "empty_all_types"; + private static final HiveTableDefinition ALL_TYPES_TABLE = HiveTableDefinition.like(ALL_HIVE_SIMPLE_TYPES_TEXTFILE) + .setDataSource(InlineDataSource.createStringDataSource( + "all_analyzable_types", + "", + "121|32761|2147483641|9223372036854775801|123.341|234.561|344.671|345.671|2015-05-10 12:15:31.123456|2015-05-09|ela ma kota|ela ma kot|ela ma |false|cGllcyBiaW5hcm55|\n" + + "127|32767|2147483647|9223372036854775807|123.345|234.567|345.678|345.678|2015-05-10 12:15:35.123456|2015-05-10|ala ma kota|ala ma kot|ala ma |true|a290IGJpbmFybnk=|\n")) + .build(); + private static final class AllTypesTable implements RequirementsProvider { @@ -69,8 +79,8 @@ private static final class AllTypesTable public Requirement getRequirements(Configuration configuration) { return Requirements.compose( - mutableTable(ALL_HIVE_SIMPLE_TYPES_TEXTFILE, ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.LOADED), - mutableTable(ALL_HIVE_SIMPLE_TYPES_TEXTFILE, EMPTY_ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.CREATED)); + 
mutableTable(ALL_TYPES_TABLE, ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.LOADED), + mutableTable(ALL_TYPES_TABLE, EMPTY_ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.CREATED)); } } @@ -85,33 +95,33 @@ public void testStatisticsForUnpartitionedTable() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, null, null, null), - row("n_name", null, null, null, null), - row("n_regionkey", null, null, null, null), - row("n_comment", null, null, null, null), - row(null, null, null, null, anyOf(null, 0.0))); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) + row("n_nationkey", null, null, null, null, null, null), + row("n_name", null, null, null, null, null, null), + row("n_regionkey", null, null, null, null, null, null), + row("n_comment", null, null, null, null, null, null), + row(null, null, null, null, anyOf(null, 0.0), null, null)); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) // basic analysis onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, null, null, null), - row("n_name", null, null, null, null), - row("n_regionkey", null, null, null, null), - row("n_comment", null, null, null, null), - row(null, null, null, null, 25.0)); + row("n_nationkey", null, null, null, null, null, null), + row("n_name", null, null, null, null, null, null), + row("n_regionkey", null, null, null, null, null, null), + row("n_comment", null, null, null, null, null, null), + row(null, null, null, null, 25.0, null, null)); // column analysis onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, 19.0, 0.0, null), - row("n_name", null, 24.0, 0.0, null), - row("n_regionkey", null, 5.0, 0.0, null), - row("n_comment", null, 31.0, 0.0, null), 
- row(null, null, null, null, 25.0)); + row("n_nationkey", null, 19.0, 0.0, null, "0", "24"), + row("n_name", null, 24.0, 0.0, null, null, null), + row("n_regionkey", null, 5.0, 0.0, null, "0", "4"), + row("n_comment", null, 31.0, 0.0, null, null, null), + row(null, null, null, null, 25.0, null, null)); } @Test(groups = {HIVE_CONNECTOR}) @@ -127,118 +137,118 @@ public void testStatisticsForPartitionedTable() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 3.0, null, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 3.0, null, null, "1", "3"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, null, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, null, null, "1", "1"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, null, null, null)); // basic analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 15.0)); + row("p_nationkey", null, null, null, null, null, null), + 
row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, null, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, null, null, "2", "2"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, null, null, null)); // basic analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 15.0)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null)); 
assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null)); // column analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null), - row("p_name", null, 6.0, 0.0, null), - row("p_regionkey", null, 3.0, 0.0, null), - row("p_comment", null, 1.0, 0.0, null), - row(null, null, null, null, 15.0)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), + row("p_name", null, 6.0, 0.0, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), + row("p_comment", null, 1.0, 0.0, null, null, null), + row(null, null, null, null, 15.0, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null), - row("p_name", null, 6.0, 0.0, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, 1.0, 0.0, null), - 
row(null, null, null, null, 5.0)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), + row("p_name", null, 6.0, 0.0, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), + row("p_comment", null, 1.0, 0.0, null, null, null), + row(null, null, null, null, 5.0, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null), - row("p_name", null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, null, null, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), + row("p_comment", null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null)); // column analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null), - row("p_name", null, 6.0, 0.0, null), - row("p_regionkey", null, 3.0, 0.0, null), - row("p_comment", null, 1.0, 0.0, null), - row(null, null, null, null, 15.0)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), + row("p_name", null, 6.0, 0.0, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), + row("p_comment", null, 1.0, 0.0, null, null, null), + row(null, null, null, null, 15.0, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null), - row("p_name", null, 6.0, 0.0, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, 1.0, 0.0, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), + row("p_name", null, 6.0, 0.0, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), + row("p_comment", null, 1.0, 0.0, null, null, null), + 
row(null, null, null, null, 5.0, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, 4.0, 0.0, null), - row("p_name", null, 6.0, 0.0, null), - row("p_regionkey", null, 1.0, 0.0, null), - row("p_comment", null, 1.0, 0.0, null), - row(null, null, null, null, 5.0)); + row("p_nationkey", null, 4.0, 0.0, null, "8", "21"), + row("p_name", null, 6.0, 0.0, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), + row("p_comment", null, 1.0, 0.0, null, null, null), + row(null, null, null, null, 5.0, null, null)); } @Test(groups = {HIVE_CONNECTOR, SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -249,42 +259,42 @@ public void testStatisticsForAllDataTypes() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null), - row("c_smallint", null, null, null, null), - row("c_int", null, null, null, null), - row("c_bigint", null, null, null, null), - row("c_float", null, null, null, null), - row("c_double", null, null, null, null), - row("c_decimal", null, null, null, null), - row("c_decimal_w_params", null, null, null, null), - row("c_timestamp", null, null, null, null), - row("c_date", null, null, null, null), - row("c_string", null, null, null, null), - row("c_varchar", null, null, null, null), - row("c_char", null, null, null, null), - row("c_boolean", null, null, null, null), - row("c_binary", null, null, null, null), - row(null, null, null, null, 1.0)); + row("c_tinyint", null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null), + 
row("c_decimal_w_params", null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null), + row(null, null, null, null, 2.0, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 1.0, 0.0, null), - row("c_smallint", null, 1.0, 0.0, null), - row("c_int", null, 2.0, 0.0, null), - row("c_bigint", null, 1.0, 0.0, null), - row("c_float", null, 1.0, 0.0, null), - row("c_double", null, 1.0, 0.0, null), - row("c_decimal", null, 1.0, 0.0, null), - row("c_decimal_w_params", null, 1.0, 0.0, null), - row("c_timestamp", null, 1.0, 0.0, null), - row("c_date", null, 2.0, 0.0, null), - row("c_string", null, 1.0, 0.0, null), - row("c_varchar", null, 1.0, 0.0, null), - row("c_char", null, 1.0, 0.0, null), - row("c_boolean", null, 1.0, 0.0, null), - row("c_binary", null, null, 0.0, null), - row(null, null, null, null, 1.0)); + row("c_tinyint", null, 2.0, 0.0, null, "121", "127"), + row("c_smallint", null, 2.0, 0.0, null, "32761", "32767"), + row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647"), + row("c_bigint", null, 2.0, 0.0, null, "9223372036854775801", "9223372036854775807"), + row("c_float", null, 2.0, 0.0, null, "123.341", "123.345"), + row("c_double", null, 1.0, 0.0, null, "234.561", "234.567"), + row("c_decimal", null, 2.0, 0.0, null, "345", "346"), + row("c_decimal_w_params", null, 2.0, 0.0, null, "345.67100", "345.67800"), + row("c_timestamp", null, 2.0, 0.0, null, "2015-05-10 12:15:31.000", "2015-05-10 12:15:35.000"), + row("c_date", null, 3.0, 0.0, null, 
"2015-05-09", "2015-05-10"), + row("c_string", null, 2.0, 0.0, null, null, null), + row("c_varchar", null, 2.0, 0.0, null, null, null), + row("c_char", null, 2.0, 0.0, null, null, null), + row("c_boolean", null, 2.0, 0.0, null, null, null), + row("c_binary", null, null, 0.0, null, null, null), + row(null, null, null, null, 2.0, null, null)); } @Test(groups = {HIVE_CONNECTOR, SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -295,42 +305,42 @@ public void testStatisticsForAllDataTypesNoData() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null), - row("c_smallint", null, null, null, null), - row("c_int", null, null, null, null), - row("c_bigint", null, null, null, null), - row("c_float", null, null, null, null), - row("c_double", null, null, null, null), - row("c_decimal", null, null, null, null), - row("c_decimal_w_params", null, null, null, null), - row("c_timestamp", null, null, null, null), - row("c_date", null, null, null, null), - row("c_string", null, null, null, null), - row("c_varchar", null, null, null, null), - row("c_char", null, null, null, null), - row("c_boolean", null, null, null, null), - row("c_binary", null, null, null, null), - row(null, null, null, null, 0.0)); + row("c_tinyint", null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null), + 
row("c_varchar", null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null), + row(null, null, null, null, 0.0, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 0.0, 0.0, null), - row("c_smallint", null, 0.0, 0.0, null), - row("c_int", null, 0.0, 0.0, null), - row("c_bigint", null, 0.0, 0.0, null), - row("c_float", null, 0.0, 0.0, null), - row("c_double", null, 0.0, 0.0, null), - row("c_decimal", null, 0.0, 0.0, null), - row("c_decimal_w_params", null, 0.0, 0.0, null), - row("c_timestamp", null, 0.0, 0.0, null), - row("c_date", null, 0.0, 0.0, null), - row("c_string", null, 0.0, 0.0, null), - row("c_varchar", null, 0.0, 0.0, null), - row("c_char", null, 0.0, 0.0, null), - row("c_boolean", null, 0.0, 0.0, null), - row("c_binary", null, null, 0.0, null), - row(null, null, null, null, 0.0)); + row("c_tinyint", null, 0.0, 0.0, null, null, null), + row("c_smallint", null, 0.0, 0.0, null, null, null), + row("c_int", null, 0.0, 0.0, null, null, null), + row("c_bigint", null, 0.0, 0.0, null, null, null), + row("c_float", null, 0.0, 0.0, null, null, null), + row("c_double", null, 0.0, 0.0, null, null, null), + row("c_decimal", null, 0.0, 0.0, null, null, null), + row("c_decimal_w_params", null, 0.0, 0.0, null, null, null), + row("c_timestamp", null, 0.0, 0.0, null, null, null), + row("c_date", null, 0.0, 0.0, null, null, null), + row("c_string", null, 0.0, 0.0, null, null, null), + row("c_varchar", null, 0.0, 0.0, null, null, null), + row("c_char", null, 0.0, 0.0, null, null, null), + row("c_boolean", null, 0.0, 0.0, null, null, null), + row("c_binary", null, null, 0.0, null, null, null), + row(null, null, null, null, 0.0, null, null)); } @Test(groups = {HIVE_CONNECTOR, 
SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -343,42 +353,42 @@ public void testStatisticsForAllDataTypesOnlyNulls() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null), - row("c_smallint", null, null, null, null), - row("c_int", null, null, null, null), - row("c_bigint", null, null, null, null), - row("c_float", null, null, null, null), - row("c_double", null, null, null, null), - row("c_decimal", null, null, null, null), - row("c_decimal_w_params", null, null, null, null), - row("c_timestamp", null, null, null, null), - row("c_date", null, null, null, null), - row("c_string", null, null, null, null), - row("c_varchar", null, null, null, null), - row("c_char", null, null, null, null), - row("c_boolean", null, null, null, null), - row("c_binary", null, null, null, null), - row(null, null, null, null, 1.0)); + row("c_tinyint", null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null), + row(null, null, null, null, 1.0, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); 
assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 1.0, 1.0, null), - row("c_smallint", null, 1.0, 1.0, null), - row("c_int", null, 1.0, 1.0, null), - row("c_bigint", null, 1.0, 1.0, null), - row("c_float", null, 1.0, 1.0, null), - row("c_double", null, 1.0, 1.0, null), - row("c_decimal", null, 1.0, 1.0, null), - row("c_decimal_w_params", null, 1.0, 1.0, null), - row("c_timestamp", null, 1.0, 1.0, null), - row("c_date", null, 1.0, 1.0, null), - row("c_string", null, 1.0, 1.0, null), - row("c_varchar", null, 1.0, 1.0, null), - row("c_char", null, 1.0, 1.0, null), - row("c_boolean", null, 0.0, 1.0, null), - row("c_binary", null, null, 1.0, null), - row(null, null, null, null, 1.0)); + row("c_tinyint", null, 1.0, 1.0, null, null, null), + row("c_smallint", null, 1.0, 1.0, null, null, null), + row("c_int", null, 1.0, 1.0, null, null, null), + row("c_bigint", null, 1.0, 1.0, null, null, null), + row("c_float", null, 1.0, 1.0, null, null, null), + row("c_double", null, 1.0, 1.0, null, null, null), + row("c_decimal", null, 1.0, 1.0, null, null, null), + row("c_decimal_w_params", null, 1.0, 1.0, null, null, null), + row("c_timestamp", null, 1.0, 1.0, null, null, null), + row("c_date", null, 1.0, 1.0, null, null, null), + row("c_string", null, 1.0, 1.0, null, null, null), + row("c_varchar", null, 1.0, 1.0, null, null, null), + row("c_char", null, 1.0, 1.0, null, null, null), + row("c_boolean", null, 0.0, 1.0, null, null, null), + row("c_binary", null, null, 1.0, null, null, null), + row(null, null, null, null, 1.0, null, null)); } private static QueryExecutor onHive() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Constraint.java b/presto-spi/src/main/java/com/facebook/presto/spi/Constraint.java index 59b945ceecf77..b993eec7039db 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Constraint.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Constraint.java @@ -31,6 +31,11 @@ public static 
Constraint alwaysTrue() return new Constraint<>(TupleDomain.all(), bindings -> true); } + public static Constraint alwaysFalse() + { + return new Constraint<>(TupleDomain.none(), bindings -> false); + } + public Constraint(TupleDomain summary, Predicate> predicate) { requireNonNull(summary, "summary is null"); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java index 113ec2e264815..1b5d847c0d535 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java @@ -11,56 +11,53 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.facebook.presto.spi.statistics; -import java.util.HashMap; -import java.util.Map; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; import static com.facebook.presto.spi.statistics.Estimate.unknownValue; -import static java.util.Collections.unmodifiableMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; public final class ColumnStatistics { - private final Map statistics; - private static final String DATA_SIZE_STATISTIC_KEY = "data_size"; - private static final String NULLS_COUNT_STATISTIC_KEY = "nulls_count"; - private static final String DISTINCT_VALUES_STATITIC_KEY = "distinct_values_count"; + private static final List SINGLE_UNKNOWN_RANGE_STATISTICS = singletonList(RangeColumnStatistics.builder().build()); + public static final ColumnStatistics UNKNOWN_COLUMN_STATISTICS = ColumnStatistics.builder().build(); - private ColumnStatistics(Estimate dataSize, Estimate nullsCount, Estimate distinctValuesCount) - { - requireNonNull(dataSize, "dataSize can not be 
null"); - statistics = createStatisticsMap(dataSize, nullsCount, distinctValuesCount); - } + private final List rangeColumnStatistics; + private final Estimate nullsFraction; - private static Map createStatisticsMap(Estimate dataSize, Estimate nullsCount, Estimate distinctValuesCount) + private ColumnStatistics(Estimate nullsFraction, List rangeColumnStatistics) { - Map statistics = new HashMap<>(); - statistics.put(DATA_SIZE_STATISTIC_KEY, dataSize); - statistics.put(NULLS_COUNT_STATISTIC_KEY, nullsCount); - statistics.put(DISTINCT_VALUES_STATITIC_KEY, distinctValuesCount); - return unmodifiableMap(statistics); - } + this.nullsFraction = requireNonNull(nullsFraction, "nullsFraction can not be null"); + requireNonNull(rangeColumnStatistics, "rangeColumnStatistics can not be null"); + if (rangeColumnStatistics.size() > 1) { + // todo add support for multiple ranges. + throw new IllegalArgumentException("Statistics for multiple ranges are not supported"); + } + if (rangeColumnStatistics.isEmpty()) { + rangeColumnStatistics = SINGLE_UNKNOWN_RANGE_STATISTICS; + } - public Estimate getDataSize() - { - return statistics.get(DATA_SIZE_STATISTIC_KEY); - } + if (nullsFraction.isValueUnknown() ^ rangeColumnStatistics.get(0).getFraction().isValueUnknown()) { + throw new IllegalArgumentException("All or none fraction/nullsFraction must be set"); + } - public Estimate getNullsCount() - { - return statistics.get(NULLS_COUNT_STATISTIC_KEY); + this.rangeColumnStatistics = unmodifiableList(rangeColumnStatistics); } - public Estimate getDistinctValuesCount() + public RangeColumnStatistics getOnlyRangeColumnStatistics() { - return statistics.get(DISTINCT_VALUES_STATITIC_KEY); + return rangeColumnStatistics.get(0); } - public Map getStatistics() + public Estimate getNullsFraction() { - return statistics; + return nullsFraction; } public static Builder builder() @@ -70,31 +67,48 @@ public static Builder builder() public static final class Builder { - private Estimate dataSize = 
unknownValue(); - private Estimate nullsCount = unknownValue(); - private Estimate distinctValuesCount = unknownValue(); + private Estimate nullsFraction = unknownValue(); + private List rangeColumnStatistics = new ArrayList<>(); + + public Builder setNullsFraction(Estimate nullsFraction) + { + this.nullsFraction = nullsFraction; + return this; + } + + public Builder addRange(Consumer rangeBuilderConsumer) + { + RangeColumnStatistics.Builder rangeBuilder = RangeColumnStatistics.builder(); + rangeBuilderConsumer.accept(rangeBuilder); + addRange(rangeBuilder.build()); + return this; + } - public Builder setDataSize(Estimate dataSize) + public Builder addRange(Object lowValue, Object highValue, Consumer rangeBuilderConsumer) { - this.dataSize = requireNonNull(dataSize, "dataSize can not be null"); + RangeColumnStatistics.Builder rangeBuilder = RangeColumnStatistics.builder(); + rangeBuilder.setLowValue(Optional.of(lowValue)); + rangeBuilder.setHighValue(Optional.of(highValue)); + rangeBuilderConsumer.accept(rangeBuilder); + addRange(rangeBuilder.build()); return this; } - public Builder setNullsCount(Estimate nullsCount) + public Builder addRange(RangeColumnStatistics rangeColumnStatistics) { - this.nullsCount = nullsCount; + this.rangeColumnStatistics.add(rangeColumnStatistics); return this; } - public Builder setDistinctValuesCount(Estimate distinctValuesCount) + public Builder clearRanges() { - this.distinctValuesCount = distinctValuesCount; + rangeColumnStatistics.clear(); return this; } public ColumnStatistics build() { - return new ColumnStatistics(dataSize, nullsCount, distinctValuesCount); + return new ColumnStatistics(nullsFraction, rangeColumnStatistics); } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java index 27a490699fde0..821fb9dbb2ece 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java +++ 
b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java @@ -34,6 +34,11 @@ public static final Estimate unknownValue() return new Estimate(UNKNOWN_VALUE); } + public static final Estimate zeroValue() + { + return new Estimate(0); + } + public Estimate(double value) { this.value = value; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RangeColumnStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RangeColumnStatistics.java new file mode 100644 index 0000000000000..9781b4d146d78 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/RangeColumnStatistics.java @@ -0,0 +1,139 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.statistics; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +public class RangeColumnStatistics +{ + public static final String DATA_SIZE_STATISTIC_KEY = "data_size"; + public static final String FRACTION_STATISTICS_KEY = "fraction"; + public static final String DISTINCT_VALUES_STATISTICS_KEY = "distinct_values_count"; + + private final Optional lowValue; + private final Optional highValue; + private final Map statistics; + + public RangeColumnStatistics( + Optional lowValue, + Optional highValue, + Estimate fraction, + Estimate dataSize, + Estimate distinctValuesCount) + { + this.lowValue = requireNonNull(lowValue, "lowValue can not be null"); + this.highValue = requireNonNull(highValue, "highValue can not be null"); + requireNonNull(fraction, "fraction can not be null"); + requireNonNull(dataSize, "dataSize can not be null"); + requireNonNull(distinctValuesCount, "distinctValuesCount can not be null"); + this.statistics = createStatisticsMap(dataSize, fraction, distinctValuesCount); + } + + private static Map createStatisticsMap( + Estimate dataSize, + Estimate fraction, + Estimate distinctValuesCount) + { + Map statistics = new HashMap<>(); + statistics.put(FRACTION_STATISTICS_KEY, fraction); + statistics.put(DATA_SIZE_STATISTIC_KEY, dataSize); + statistics.put(DISTINCT_VALUES_STATISTICS_KEY, distinctValuesCount); + return unmodifiableMap(statistics); + } + + public Optional getLowValue() + { + return lowValue; + } + + public Optional getHighValue() + { + return highValue; + } + + public Estimate getDataSize() + { + return statistics.get(DATA_SIZE_STATISTIC_KEY); + } + + public Estimate getFraction() + { + return statistics.get(FRACTION_STATISTICS_KEY); + } + + public Estimate getDistinctValuesCount() + { + return 
statistics.get(DISTINCT_VALUES_STATISTICS_KEY); + } + + public Map getStatistics() + { + return statistics; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Optional lowValue = Optional.empty(); + private Optional highValue = Optional.empty(); + private Estimate dataSize = unknownValue(); + private Estimate fraction = unknownValue(); + private Estimate distinctValuesCount = unknownValue(); + + public Builder setLowValue(Optional lowValue) + { + this.lowValue = lowValue; + return this; + } + + public Builder setHighValue(Optional highValue) + { + this.highValue = highValue; + return this; + } + + public Builder setFraction(Estimate fraction) + { + this.fraction = fraction; + return this; + } + + public Builder setDataSize(Estimate dataSize) + { + this.dataSize = requireNonNull(dataSize, "dataSize can not be null"); + return this; + } + + public Builder setDistinctValuesCount(Estimate distinctValuesCount) + { + this.distinctValuesCount = distinctValuesCount; + return this; + } + + public RangeColumnStatistics build() + { + return new RangeColumnStatistics(lowValue, highValue, fraction, dataSize, distinctValuesCount); + } + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index e8557d08b56ed..40c65b4e36ddc 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -5846,6 +5846,9 @@ public void testInformationSchemaFiltering() assertQuery( "SELECT table_name FROM information_schema.tables WHERE table_name = 'orders' LIMIT 1", "SELECT 'orders' table_name"); + assertQuery( + "SELECT table_name FROM information_schema.columns WHERE data_type = 'bigint' AND table_name = 'customer' and column_name = 'custkey' LIMIT 1", + "SELECT 'customer' table_name"); } 
@Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index a9cfeb63ee691..2d07cca3feb91 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -14,9 +14,16 @@ package com.facebook.presto.tests; import com.facebook.presto.Session; -import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CoefficientBasedStatsCalculator; import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostCalculatorUsingExchanges; +import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.FilterStatsCalculator; +import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.SelectingStatsCalculator; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.server.ServerMainModule; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -60,7 +67,6 @@ public abstract class AbstractTestQueryFramework private QueryRunner queryRunner; private H2QueryRunner h2QueryRunner; private SqlParser sqlParser; - private CostCalculator costCalculator; protected AbstractTestQueryFramework(QueryRunnerSupplier supplier) { @@ -74,7 +80,6 @@ public void init() queryRunner = queryRunnerSupplier.get(); h2QueryRunner = new H2QueryRunner(); sqlParser = new SqlParser(); - costCalculator = new CoefficientBasedCostCalculator(queryRunner.getMetadata()); } @AfterClass(alwaysRun = true) @@ -295,13 +300,25 @@ private QueryExplainer getQueryExplainer() Metadata metadata = queryRunner.getMetadata(); FeaturesConfig featuresConfig = new 
FeaturesConfig().setOptimizeHashGeneration(true); boolean forceSingleNode = queryRunner.getNodeCount() == 1; - List optimizers = new PlanOptimizers(metadata, sqlParser, featuresConfig, forceSingleNode, new MBeanExporter(new TestingMBeanServer())).get(); + CostCalculator costCalculator = new CostCalculatorUsingExchanges(queryRunner.getNodeCount()); + List optimizers = new PlanOptimizers( + metadata, + sqlParser, + featuresConfig, + forceSingleNode, + new MBeanExporter(new TestingMBeanServer()), + new CostComparator(featuresConfig), + new SelectingStatsCalculator( + new CoefficientBasedStatsCalculator(metadata), + ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata))), + costCalculator, + new CostCalculatorWithEstimatedExchanges(costCalculator, queryRunner.getNodeCount())).get(); return new QueryExplainer( optimizers, metadata, queryRunner.getAccessControl(), sqlParser, - costCalculator, + queryRunner.getLookup(), ImmutableMap.of()); } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 17144487a9791..6f683b3ef7578 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -15,7 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.metadata.AllNodes; @@ -29,6 +28,7 @@ import com.facebook.presto.spi.QueryId; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.testing.MaterializedResult; import 
com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingAccessControlManager; @@ -229,9 +229,9 @@ public Metadata getMetadata() } @Override - public CostCalculator getCostCalculator() + public Lookup getLookup() { - return coordinator.getCostCalculator(); + return coordinator.getLookup(); } @Override diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java b/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java index 2a03128eb4af2..6e7985bfa391f 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java @@ -62,7 +62,7 @@ private String getPlanText(Session session, String sql) { return localQueryRunner.inTransaction(session, transactionSession -> { Plan plan = localQueryRunner.createPlan(transactionSession, sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED); - return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), localQueryRunner.getMetadata(), localQueryRunner.getCostCalculator(), transactionSession); + return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), localQueryRunner.getMetadata(), localQueryRunner.getLookup(), transactionSession); }); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java index 1271f4c3d235c..69d1319692d77 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java @@ -15,7 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.QualifiedObjectName; @@ 
-24,6 +23,7 @@ import com.facebook.presto.spi.Node; import com.facebook.presto.spi.Plugin; import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingAccessControlManager; @@ -134,9 +134,9 @@ public Metadata getMetadata() } @Override - public CostCalculator getCostCalculator() + public Lookup getLookup() { - return server.getCostCalculator(); + return server.getLookup(); } @Override diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java index bf3d0205af4de..dc49bb4b15fa5 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java @@ -13,25 +13,45 @@ */ package com.facebook.presto.tests.statistics; -import com.facebook.presto.cost.PlanNodeCost; -import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.testing.MaterializedRow; -import java.util.function.Function; +import java.util.Objects; +import java.util.Optional; -public enum Metric +public abstract class Metric { - OUTPUT_ROW_COUNT(PlanNodeCost::getOutputRowCount), - OUTPUT_SIZE_BYTES(PlanNodeCost::getOutputSizeInBytes); + public abstract Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext); - private final Function extractor; + public abstract Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext); - Metric(Function extractor) + public abstract String getComputingAggregationSql(); + + public abstract String getName(); + + // name based equality + @Override + public boolean equals(Object o) + { + if 
(this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Metric metric = (Metric) o; + return Objects.equals(getName(), metric.getName()); + } + + @Override + public int hashCode() { - this.extractor = extractor; + return Objects.hash(getName()); } - Estimate getValue(PlanNodeCost cost) + @Override + public String toString() { - return extractor.apply(cost); + return getName(); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java index 053b758602b1f..6b70e2f84aeed 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java @@ -13,90 +13,120 @@ */ package com.facebook.presto.tests.statistics; -import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.execution.StageInfo; -import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.QueryId; import com.facebook.presto.sql.planner.Plan; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; -import com.facebook.presto.sql.planner.planPrinter.PlanNodeStats; -import com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.OutputNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedRow; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.util.List; -import 
java.util.Map; import java.util.Optional; -import java.util.function.BinaryOperator; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import java.util.Set; +import java.util.function.Function; -import static com.facebook.presto.execution.StageInfo.getAllStages; -import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.util.MoreMaps.mergeMaps; -import static com.google.common.collect.Maps.transformValues; -import static java.util.Arrays.asList; +import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; public class MetricComparator { - private final List metrics = asList(Metric.values()); - private final double tolerance = 0.1; - - public List getMetricComparisons(Plan queryPlan, StageInfo outputStageInfo) + public Set> getMetricComparisons(String query, QueryRunner runner, Set> metrics) { - return metrics.stream().flatMap(metric -> { - Map estimates = queryPlan.getPlanNodeCosts(); - Map actuals = extractActualCosts(outputStageInfo); - return estimates.entrySet().stream().map(entry -> { - // todo refactor to stay in PlanNodeId domain ???? 
- PlanNode node = planNodeForId(queryPlan, entry.getKey()); - PlanNodeCost estimate = entry.getValue(); - Optional execution = Optional.ofNullable(actuals.get(node.getId())); - return createMetricComparison(metric, node, estimate, execution); - }); - }).collect(Collectors.toList()); + if (runner instanceof DistributedQueryRunner) { + return getMetricComparisonsDistributed(query, (DistributedQueryRunner) runner, metrics); + } + else if (runner instanceof LocalQueryRunner) { + return getMetricComparisonsLocal(query, (LocalQueryRunner) runner, metrics); + } + else { + throw new IllegalArgumentException("only local and distributed runner supported"); + } } - private PlanNode planNodeForId(Plan queryPlan, PlanNodeId id) + private Set> getMetricComparisonsDistributed(String query, DistributedQueryRunner runner, Set> metrics) { - return searchFrom(queryPlan.getRoot()) - .where(node -> node.getId().equals(id)) - .findOnlyElement(); + String queryId = runner.executeWithQueryId(runner.getDefaultSession(), query).getQueryId(); + Plan queryPlan = runner.getQueryPlan(new QueryId(queryId)); + StageInfo stageInfo = runner.getQueryInfo(new QueryId(queryId)).getOutputStage().get(); + OutputNode outputNode = (OutputNode) stageInfo.getPlan().getRoot(); + + return getMetricComparisons(query, runner, queryPlan, outputNode, metrics); } - private Map extractActualCosts(StageInfo outputStageInfo) + private Set> getMetricComparisonsLocal(String query, LocalQueryRunner runner, Set> metrics) { - Stream> stagesStatsStream = - getAllStages(Optional.of(outputStageInfo)).stream() - .map(PlanNodeStatsSummarizer::aggregatePlanNodeStats); + Plan queryPlan = inTransaction(runner, (session) -> runner.createPlan(session, query)); + OutputNode outputNode = (OutputNode) queryPlan.getRoot(); + return getMetricComparisons(query, runner, queryPlan, outputNode, metrics); + } - Map mergedStats = mergeStats(stagesStatsStream); - return transformValues(mergedStats, this::toPlanNodeCost); + private T 
inTransaction(QueryRunner runner, Function transactionSessionConsumer) + { + return transaction(runner.getTransactionManager(), runner.getAccessControl()) + .singleStatement() + .execute(runner.getDefaultSession(), session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> runner.getMetadata().getCatalogHandle(session, catalog)); + return transactionSessionConsumer.apply(session); + }); } - private Map mergeStats(Stream> stagesStatsStream) + private Set> getMetricComparisons(String query, QueryRunner runner, Plan queryPlan, OutputNode outputNode, Set> metrics) { - BinaryOperator allowNoDuplicates = (a, b) -> { - throw new IllegalArgumentException("PlanNodeIds must be unique"); - }; - return mergeMaps(stagesStatsStream, allowNoDuplicates); + StatsContext statsContext = buildStatsContext(queryPlan, outputNode); + List> metricsList = ImmutableList.copyOf(metrics); + List> actualValues = getActualValues(metricsList, query, runner, statsContext); + List> estimatedValues = getEstimatedValues(metricsList, queryPlan.getPlanNodeStats().get(outputNode.getId()), statsContext); + + ImmutableSet.Builder> metricComparisons = ImmutableSet.builder(); + for (int i = 0; i < metricsList.size(); ++i) { + metricComparisons.add(new MetricComparison( + outputNode, + metricsList.get(i), + estimatedValues.get(i), + actualValues.get(i))); + } + return metricComparisons.build(); } - private PlanNodeCost toPlanNodeCost(PlanNodeStats operatorStats) + private StatsContext buildStatsContext(Plan queryPlan, OutputNode outputNode) { - return PlanNodeCost.builder() - .setOutputRowCount(new Estimate(operatorStats.getPlanNodeOutputPositions())) - .setOutputSizeInBytes(new Estimate(operatorStats.getPlanNodeOutputDataSize().toBytes())) - .build(); + ImmutableMap.Builder columnSymbols = ImmutableMap.builder(); + for (int columnId = 0; columnId < outputNode.getColumnNames().size(); ++columnId) { + 
columnSymbols.put(outputNode.getColumnNames().get(columnId), outputNode.getOutputSymbols().get(columnId)); + } + return new StatsContext(columnSymbols.build(), queryPlan.getTypes()); } - private MetricComparison createMetricComparison(Metric metric, PlanNode node, PlanNodeCost estimate, Optional execution) + private List> getActualValues(List> metrics, String query, QueryRunner runner, StatsContext statsContext) { - Optional estimatedCost = asOptional(metric.getValue(estimate)); - Optional executionCost = execution.flatMap(e -> asOptional(metric.getValue(e))); - return new MetricComparison(node, metric, estimatedCost, executionCost, tolerance); + String statsQuery = "SELECT " + + metrics.stream().map(Metric::getComputingAggregationSql).collect(joining(",")) + + " FROM (" + query + ")"; + + MaterializedRow actualValuesRow = getOnlyElement(runner.execute(statsQuery).getMaterializedRows()); + + ImmutableList.Builder> actualValues = ImmutableList.builder(); + for (int i = 0; i < metrics.size(); ++i) { + actualValues.add(metrics.get(i).getValueFromAggregationQuery(actualValuesRow, i, statsContext)); + } + return actualValues.build(); } - private Optional asOptional(Estimate estimate) + private List> getEstimatedValues(List> metrics, PlanNodeStatsEstimate outputNodeStatisticsEstimates, StatsContext statsContext) { - return estimate.isValueUnknown() ? 
Optional.empty() : Optional.of(estimate.getValue()); + return metrics.stream() + .map(metric -> metric.getValueFromPlanNodeEstimate(outputNodeStatisticsEstimates, statsContext)) + .collect(toList()); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java index 3ab2b50682d7b..393fefd32e952 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java @@ -21,24 +21,21 @@ import static com.facebook.presto.tests.statistics.MetricComparison.Result.MATCH; import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_BASELINE; import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_ESTIMATE; -import static java.lang.Math.abs; import static java.lang.String.format; -public class MetricComparison +public class MetricComparison { private final PlanNode planNode; private final Metric metric; - private final Optional estimatedCost; - private final Optional executionCost; - private final double tolerance; + private final Optional estimatedValue; + private final Optional actualValue; - public MetricComparison(PlanNode planNode, Metric metric, Optional estimatedCost, Optional executionCost, double tolerance) + public MetricComparison(PlanNode planNode, Metric metric, Optional estimatedValue, Optional actualValue) { this.planNode = planNode; this.metric = metric; - this.estimatedCost = estimatedCost; - this.executionCost = executionCost; - this.tolerance = tolerance; + this.estimatedValue = estimatedValue; + this.actualValue = actualValue; } public Metric getMetric() @@ -54,29 +51,27 @@ public PlanNode getPlanNode() @Override public String toString() { - return format("Metric [%s] - [%s] - estimated: [%s], real: [%s] - plan node: [%s]", - metric, result(), print(estimatedCost), 
print(executionCost), planNode); + return format("Metric [%s] - estimated: [%s], real: [%s] - plan node: [%s]", + metric, print(estimatedValue), print(actualValue), planNode); } - public Result result() + public Result result(MetricComparisonStrategy metricComparisonStrategy) { - return estimatedCost - .map(estimate -> executionCost - .map(execution -> estimateMatchesReality(estimate, execution) ? MATCH : DIFFER) + if (!estimatedValue.isPresent() && !actualValue.isPresent()) { + return MATCH; + } + return estimatedValue + .map(estimate -> actualValue + .map(execution -> metricComparisonStrategy.matches(execution, estimate) ? MATCH : DIFFER) .orElse(NO_BASELINE)) .orElse(NO_ESTIMATE); } - private String print(Optional cost) + private String print(Optional cost) { return cost.map(Object::toString).orElse("UNKNOWN"); } - private boolean estimateMatchesReality(double estimate, double execution) - { - return abs(execution - estimate) / execution < tolerance; - } - public enum Result { NO_ESTIMATE, diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategies.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategies.java new file mode 100644 index 0000000000000..b2bc4d5ab53c2 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategies.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.tests.statistics; + +import com.google.common.collect.Range; + +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; + +public class MetricComparisonStrategies +{ + private MetricComparisonStrategies() {} + + public static MetricComparisonStrategy noError() + { + return absoluteError(0); + } + + public static MetricComparisonStrategy absoluteError(double error) + { + return absoluteError(Range.closed(-error, error)); + } + + public static MetricComparisonStrategy absoluteError(Range errorRange) + { + return (actual, estimate) -> mapRange(errorRange, endpoint -> endpoint + actual) + .contains(estimate); + } + + public static MetricComparisonStrategy defaultTolerance() + { + return relativeError(.1); + } + + public static MetricComparisonStrategy relativeError(double error) + { + return relativeError(Range.closed(-error, error)); + } + + public static MetricComparisonStrategy relativeError(Range errorRange) + { + return (actual, estimate) -> mapRange(errorRange, endpoint -> (endpoint + 1) * actual) + .contains(estimate); + } + + private static Range mapRange(Range range, Function mappingFunction) + { + checkArgument(range.hasLowerBound() && range.hasUpperBound(), "Expected error range to have lower and upper bound"); + return Range.range( + mappingFunction.apply(range.lowerEndpoint()), + range.lowerBoundType(), + mappingFunction.apply(range.upperEndpoint()), + range.upperBoundType()); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategy.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategy.java new file mode 100644 index 0000000000000..e9538ba18d271 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparisonStrategy.java @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except 
in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.tests.statistics; + +public interface MetricComparisonStrategy +{ + boolean matches(T actual, T estimate); +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java new file mode 100644 index 0000000000000..11ed409d3741f --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metrics.java @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests.statistics; + +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.testing.MaterializedRow; + +import java.util.Optional; + +import static java.lang.Double.isInfinite; +import static java.lang.Double.isNaN; + +public final class Metrics +{ + private Metrics() {} + + public static final Metric OUTPUT_ROW_COUNT = new Metric() + { + @Override + public Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) + { + return asOptional(planNodeStatsEstimate.getOutputRowCount()); + } + + @Override + public Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext) + { + return Optional.of(((Number) aggregationQueryResult.getField(fieldId)).doubleValue()); + } + + @Override + public String getComputingAggregationSql() + { + return "count(*)"; + } + + @Override + public String getName() + { + return "OUTPUT_ROW_COUNT"; + } + }; + + public static Metric nullsFraction(String columnName) + { + return new Metric() + { + @Override + public Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) + { + return asOptional(getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getNullsFraction()); + } + + @Override + public Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext) + { + return Optional.of(((Number) aggregationQueryResult.getField(fieldId)).doubleValue()); + } + + @Override + public String getComputingAggregationSql() + { + return "(count(*) filter(where " + columnName + " is null)) / cast(count(*) as double)"; + } + + @Override + public String getName() + { + return "NULLS_FRACTION(" + columnName + ")"; + } + }; + } + + public static Metric distinctValuesCount(String columnName) + { + return new Metric() + { + 
@Override + public Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) + { + return asOptional(getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getDistinctValuesCount()); + } + + @Override + public Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext) + { + return Optional.of(((Number) aggregationQueryResult.getField(fieldId)).doubleValue()); + } + + @Override + public String getComputingAggregationSql() + { + return "count(distinct " + columnName + ")"; + } + + @Override + public String getName() + { + return "DISTINCT_VALUES_COUNT(" + columnName + ")"; + } + }; + } + + public static Metric lowValue(String columnName) + { + return new Metric() + { + @Override + public Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) + { + double lowValue = getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getLowValue(); + if (isInfinite(lowValue)) { + return Optional.empty(); + } + else { + return Optional.of(lowValue); + } + } + + @Override + public Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext) + { + return Optional.ofNullable(aggregationQueryResult.getField(fieldId)).map(value -> ((Number) value).doubleValue()); + } + + @Override + public String getComputingAggregationSql() + { + return "try_cast(min(" + columnName + ") as double)"; + } + + @Override + public String getName() + { + return "LOW_VALUE(" + columnName + ")"; + } + }; + } + + public static Metric highValue(String columnName) + { + return new Metric() + { + @Override + public Optional getValueFromPlanNodeEstimate(PlanNodeStatsEstimate planNodeStatsEstimate, StatsContext statsContext) + { + double highValue = getSymbolStatistics(planNodeStatsEstimate, columnName, statsContext).getHighValue(); + if (isInfinite(highValue)) { + 
return Optional.empty(); + } + else { + return Optional.of(highValue); + } + } + + @Override + public Optional getValueFromAggregationQuery(MaterializedRow aggregationQueryResult, int fieldId, StatsContext statsContext) + { + return Optional.ofNullable(aggregationQueryResult.getField(fieldId)).map(value -> ((Number) value).doubleValue()); + } + + @Override + public String getComputingAggregationSql() + { + return "max(try_cast(" + columnName + " as double))"; + } + + @Override + public String getName() + { + return "HIGH_VALUE(" + columnName + ")"; + } + }; + } + + private static SymbolStatsEstimate getSymbolStatistics(PlanNodeStatsEstimate planNodeStatsEstimate, String columnName, StatsContext statsContext) + { + return planNodeStatsEstimate.getSymbolStatistics(statsContext.getSymbolForColumn(columnName)); + } + + private static Optional asOptional(double value) + { + return isNaN(value) ? Optional.empty() : Optional.of(value); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatisticsAssertion.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatisticsAssertion.java new file mode 100644 index 0000000000000..b5c407ec5ee74 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatisticsAssertion.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.tests.statistics; + +import com.facebook.presto.testing.QueryRunner; +import org.intellij.lang.annotations.Language; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import static com.facebook.presto.tests.statistics.MetricComparison.Result.MATCH; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_BASELINE; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_ESTIMATE; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.noError; +import static com.facebook.presto.tests.statistics.Metrics.distinctValuesCount; +import static com.facebook.presto.tests.statistics.Metrics.highValue; +import static com.facebook.presto.tests.statistics.Metrics.lowValue; +import static com.facebook.presto.tests.statistics.Metrics.nullsFraction; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertTrue; + +public class StatisticsAssertion +{ + private final QueryRunner runner; + + public StatisticsAssertion(QueryRunner runner) + { + this.runner = requireNonNull(runner, "runner is null"); + } + + public void check(@Language("SQL") String query, Consumer checksBuilderConsumer) + { + Checks checks = new Checks(); + checksBuilderConsumer.accept(checks); + checks.run(query, runner); + } + + private static class MetricsCheck + { + public final Metric metric; + public final MetricComparisonStrategy strategy; + public final MetricComparison.Result expectedComparisonResult; + + public MetricsCheck(Metric metric, MetricComparisonStrategy strategy, MetricComparison.Result expectedComparisonResult) + { + this.metric = metric; + this.strategy = strategy; + this.expectedComparisonResult = 
expectedComparisonResult; + } + } + + public static class Checks + { + private final List> checks = new ArrayList<>(); + + public Checks verifyExactColumnStatistics(String columnName) + { + verifyColumnStatistics(columnName, noError()); + return this; + } + + public Checks verifyColumnStatistics(String columnName, MetricComparisonStrategy strategy) + { + estimate(nullsFraction(columnName), strategy); + estimate(distinctValuesCount(columnName), strategy); + estimate(lowValue(columnName), strategy); + estimate(highValue(columnName), strategy); + return this; + } + + public Checks verifyNoColumnStatistics(String columnName) + { + noEstimate(nullsFraction(columnName)); + noEstimate(distinctValuesCount(columnName)); + noEstimate(lowValue(columnName)); + noEstimate(highValue(columnName)); + return this; + } + + public Checks estimate(Metric metric, MetricComparisonStrategy strategy) + { + checks.add(new MetricsCheck<>(metric, strategy, MATCH)); + return this; + } + + public Checks noEstimate(Metric metric) + { + checks.add(new MetricsCheck<>(metric, (actual, estimate) -> true, NO_ESTIMATE)); + return this; + } + + public Checks noBaseline(Metric metric) + { + checks.add(new MetricsCheck<>(metric, (actual, estimate) -> true, NO_BASELINE)); + return this; + } + + void run(@Language("SQL") String query, QueryRunner runner) + { + Set> metrics = checks.stream() + .map(check -> check.metric) + .collect(toImmutableSet()); + Set> metricComparisons = metricComparisons(query, runner, metrics); + for (MetricsCheck check : checks) { + testMetrics(check.metric, metricComparison -> metricComparison.result(check.strategy) == check.expectedComparisonResult, metricComparisons); + } + } + + private Set> metricComparisons(@Language("SQL") String query, QueryRunner queryRunner, Set> metrics) + { + return new MetricComparator().getMetricComparisons(query, queryRunner, metrics); + } + + private Checks testMetrics(Metric metric, Predicate assertCondition, Set> metricComparisons) + { + List 
testMetrics = metricComparisons.stream() + .filter(metricComparison -> metricComparison.getMetric().equals(metric)) + .collect(toImmutableList()); + assertTrue(testMetrics.size() > 0, "No metric found for: " + metric); + assertTrue(testMetrics.stream().allMatch(assertCondition), "Following metrics do not match: " + testMetrics); + return this; + } + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java new file mode 100644 index 0000000000000..00a6d25cfebd1 --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/StatsContext.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests.statistics; + +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; + +public class StatsContext +{ + private final Map columnSymbols; + private final Map symbolTypes; + + public StatsContext(Map columnSymbols, Map symbolTypes) + { + this.columnSymbols = ImmutableMap.copyOf(columnSymbols); + this.symbolTypes = ImmutableMap.copyOf(symbolTypes); + } + + public Symbol getSymbolForColumn(String columnName) + { + checkArgument(columnSymbols.containsKey(columnName), "no symbol found for column '" + columnName + "'"); + return columnSymbols.get(columnName); + } + + public Type getTypeForSymbol(Symbol symbol) + { + checkArgument(symbolTypes.containsKey(symbol), "no type found found for symbol '" + symbol + "'"); + return symbolTypes.get(symbol); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java index b5c5964e1f3dc..87f4a08a51e7f 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java @@ -72,8 +72,15 @@ public void testShowColumnStats() // FIXME Add tests for more complex scenario with more stats MaterializedResult result = computeActual("SHOW STATS FOR nation"); - MaterializedResult expectedStatistics = resultBuilder(getSession(), VARCHAR, DOUBLE) - .row(null, 25.0) + MaterializedResult expectedStatistics = + resultBuilder(getSession(), VARCHAR, DOUBLE, DOUBLE, DOUBLE, DOUBLE, VARCHAR, VARCHAR) + .row("regionkey", null, 5.0, 0.0, null, "0", "4") + .row("name", null, 25.0, 0.0, null, "ALGERIA", "VIETNAM") + .row("comment", null, 25.0, 0.0, null, + " haggle. carefully final deposit...", + "y final packages. 
slow foxes caj...") + .row("nationkey", null, 25.0, 0.0, null, "0", "24") + .row(null, null, null, null, 25.0, null, null) .build(); assertEquals(result, expectedStatistics); diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java index de6b72355c556..a1b7d17cb74f3 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java @@ -13,144 +13,121 @@ */ package com.facebook.presto.tests; -import com.facebook.presto.execution.StageInfo; -import com.facebook.presto.spi.QueryId; -import com.facebook.presto.sql.planner.Plan; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.tests.statistics.Metric; -import com.facebook.presto.tests.statistics.MetricComparator; -import com.facebook.presto.tests.statistics.MetricComparison; +import com.facebook.presto.tests.statistics.StatisticsAssertion; import com.facebook.presto.tpch.ColumnNaming; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; -import com.google.common.io.Resources; +import com.google.common.collect.Range; +import io.airlift.tpch.TpchTable; import org.testng.annotations.Test; -import java.io.IOException; -import java.nio.charset.Charset; -import java.util.List; -import java.util.Map; -import java.util.stream.IntStream; - -import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.tests.statistics.MetricComparison.Result.DIFFER; -import static com.facebook.presto.tests.statistics.MetricComparison.Result.MATCH; -import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_BASELINE; -import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_ESTIMATE; +import static 
com.facebook.presto.tests.statistics.MetricComparisonStrategies.absoluteError; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.defaultTolerance; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.noError; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.relativeError; +import static com.facebook.presto.tests.statistics.Metrics.OUTPUT_ROW_COUNT; import static com.facebook.presto.tests.tpch.TpchQueryRunner.createQueryRunnerWithoutCatalogs; -import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; -import static java.util.stream.Collectors.groupingBy; -import static org.testng.Assert.assertEquals; public class TestTpchDistributedStats { - public static final int NUMBER_OF_TPCH_QUERIES = 22; - - DistributedQueryRunner runner; + private final StatisticsAssertion statisticsAssertion; public TestTpchDistributedStats() throws Exception { - runner = createQueryRunnerWithoutCatalogs(emptyMap(), emptyMap()); + DistributedQueryRunner runner = createQueryRunnerWithoutCatalogs(emptyMap(), emptyMap()); runner.createCatalog("tpch", "tpch", ImmutableMap.of( "tpch.column-naming", ColumnNaming.STANDARD.name() )); + statisticsAssertion = new StatisticsAssertion(runner); } - @Test - void testEstimateForSimpleQuery() + @Test(enabled = false) + void testTableScanStats() { - String queryId = executeQuery("SELECT * FROM NATION"); - - Plan queryPlan = getQueryPlan(queryId); - - MetricComparison rootOutputRowCountComparison = getRootOutputRowCountComparison(queryId, queryPlan); - assertEquals(rootOutputRowCountComparison.result(), MATCH); + TpchTable.getTables() + .forEach(table -> statisticsAssertion.check("SELECT * FROM " + table.getTableName(), + checks -> checks + // TODO use noError + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()))); } - private MetricComparison getRootOutputRowCountComparison(String queryId, Plan 
queryPlan) + @Test(enabled = false) + void testFilter() { - List comparisons = new MetricComparator().getMetricComparisons(queryPlan, getOutputStageInfo(queryId)); - return comparisons.stream() - .filter(comparison -> comparison.getMetric().equals(Metric.OUTPUT_ROW_COUNT)) - .filter(comparison -> comparison.getPlanNode().equals(queryPlan.getRoot())) - .findFirst() - .orElseThrow(() -> new AssertionError("No comparison for root node found")); + String query = "SELECT * FROM lineitem" + + " WHERE l_shipdate <= DATE '1998-12-01' - INTERVAL '90' DAY"; + statisticsAssertion.check(query, + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(Range.closed(-.55, -.45)))); } - /** - * This is a development tool for manual inspection of differences between - * cost estimates and actual execution costs. Its outputs need to be inspected - * manually because at this point no sensible assertions can be formulated - * for the entirety of TPCH queries. - */ @Test(enabled = false) - void testCostEstimatesVsRealityDifferences() + void testJoin() { - IntStream.rangeClosed(1, NUMBER_OF_TPCH_QUERIES) - .filter(i -> i != 15) //query 15 creates a view, which TPCH connector does not support. 
- .forEach(i -> summarizeQuery(i, getTpchQuery(i))); + statisticsAssertion.check("SELECT * FROM part, partsupp WHERE p_partkey = ps_partkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(Range.closed(.95, 1.05)))); } - private String getTpchQuery(int i) + @Test(enabled = false) + void testSetOperations() { - try { - String queryClassPath = "/io/airlift/tpch/queries/q" + i + ".sql"; - return Resources.toString(getClass().getResource(queryClassPath), Charset.defaultCharset()); - } - catch (IOException e) { - throw Throwables.propagate(e); - } + statisticsAssertion.check("SELECT * FROM nation UNION SELECT * FROM nation", + checks -> checks + .noEstimate(OUTPUT_ROW_COUNT)); + + statisticsAssertion.check("SELECT * FROM nation INTERSECT SELECT * FROM nation", + checks -> checks + .noEstimate(OUTPUT_ROW_COUNT)); + + statisticsAssertion.check("SELECT * FROM nation EXCEPT SELECT * FROM nation", + checks -> checks + .noEstimate(OUTPUT_ROW_COUNT)); } - private Plan getQueryPlan(String queryId) + @Test(enabled = false) + void testEnforceSingleRow() { - return runner.getQueryPlan(new QueryId(queryId)); + statisticsAssertion.check("SELECT (SELECT n_regionkey FROM nation WHERE n_name = 'Germany')", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, noError())); } - private void summarizeQuery(int queryNumber, String query) + @Test(enabled = false) + void testValues() { - String queryId = executeQuery(query); - Plan queryPlan = getQueryPlan(queryId); - - List allPlanNodes = searchFrom(queryPlan.getRoot()).findAll(); - - System.out.println(format("Query TPCH [%s] produces [%s] plan nodes.\n", queryNumber, allPlanNodes.size())); - - List comparisons = new MetricComparator().getMetricComparisons(queryPlan, getOutputStageInfo(queryId)); - - Map>> metricSummaries = - comparisons.stream() - .collect(groupingBy(MetricComparison::getMetric, groupingBy(MetricComparison::result))); - - metricSummaries.forEach((metricName, resultSummaries) -> { - 
System.out.println(format("Summary for metric [%s]", metricName)); - outputSummary(resultSummaries, NO_ESTIMATE); - outputSummary(resultSummaries, NO_BASELINE); - outputSummary(resultSummaries, DIFFER); - outputSummary(resultSummaries, MATCH); - System.out.println(); - }); - - System.out.println("Detailed results:\n"); - - comparisons.forEach(System.out::println); + statisticsAssertion.check("VALUES 1", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, noError())); } - private String executeQuery(String query) + @Test(enabled = false) + void testSemiJoin() { - return runner.executeWithQueryId(runner.getDefaultSession(), query).getQueryId(); + statisticsAssertion.check("SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region)", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, noError())); + statisticsAssertion.check("SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region WHERE r_regionkey % 3 = 0)", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, absoluteError(Range.singleton(15.)))); } - private StageInfo getOutputStageInfo(String queryId) + @Test(enabled = false) + void testLimit() { - return runner.getQueryInfo(new QueryId(queryId)).getOutputStage().get(); + statisticsAssertion.check("SELECT * FROM nation LIMIT 10", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, noError())); } - private void outputSummary(Map> resultSummaries, MetricComparison.Result result) + @Test(enabled = false) + void testGroupBy() { - System.out.println(format("[%s]\t-\t[%s]", result, resultSummaries.getOrDefault(result, emptyList()).size())); + String query = "SELECT l_returnflag, l_linestatus FROM lineitem" + + " GROUP BY l_returnflag, l_linestatus"; + statisticsAssertion.check(query, + checks -> checks + .noEstimate(OUTPUT_ROW_COUNT)); } } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchLocalStats.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchLocalStats.java new file mode 100644 index 
0000000000000..bec9d05feb3ab --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchLocalStats.java @@ -0,0 +1,215 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tests.statistics.StatisticsAssertion; +import com.facebook.presto.tpch.ColumnNaming; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; +import static com.facebook.presto.SystemSessionProperties.USE_NEW_STATS_CALCULATOR; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.defaultTolerance; +import static com.facebook.presto.tests.statistics.MetricComparisonStrategies.relativeError; +import static com.facebook.presto.tests.statistics.Metrics.OUTPUT_ROW_COUNT; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestTpchLocalStats +{ + private final StatisticsAssertion statisticsAssertion; + + public TestTpchLocalStats() + throws Exception + { + Session defaultSession = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + 
.setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") + .setSystemProperty(USE_NEW_STATS_CALCULATOR, "true") + .build(); + + LocalQueryRunner runner = new LocalQueryRunner(defaultSession); + runner.createCatalog("tpch", new TpchConnectorFactory(1), + ImmutableMap.of("tpch.column-naming", ColumnNaming.STANDARD.name() + )); + statisticsAssertion = new StatisticsAssertion(runner); + } + + @Test + void testTableScanStats() + { + statisticsAssertion.check("SELECT * FROM nation", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("n_regionkey") + .verifyExactColumnStatistics("n_name") + ); + } + + @Test + void testInnerJoinStats() + { + // cross join + statisticsAssertion.check("SELECT * FROM supplier, nation", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM supplier, nation WHERE n_nationkey <= 12", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyColumnStatistics("n_nationkey", relativeError(0.10)) + .verifyExactColumnStatistics("s_suppkey")); + + // simple equi joins + statisticsAssertion.check("SELECT * FROM supplier, nation WHERE s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM supplier, nation WHERE s_nationkey = n_nationkey AND n_nationkey <= 12", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(0.15)) + .verifyColumnStatistics("s_nationkey", relativeError(0.15)) + .verifyColumnStatistics("n_nationkey", relativeError(0.15))); + + // two joins on 
different keys + statisticsAssertion.check("SELECT * FROM nation, supplier, partsupp WHERE n_nationkey = s_nationkey AND s_suppkey = ps_suppkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("ps_partkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_name")); + statisticsAssertion.check("SELECT * FROM nation, supplier, partsupp WHERE n_nationkey = s_nationkey AND s_suppkey = ps_suppkey AND n_nationkey <= 12", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(0.15)) + .verifyColumnStatistics("ps_partkey", relativeError(0.15)) + .verifyColumnStatistics("n_nationkey", relativeError(0.15)) + .verifyColumnStatistics("s_nationkey", relativeError(0.15))); + } + + @Test + void testLeftJoinStats() + { + // simple equi join + statisticsAssertion.check("SELECT * FROM supplier left join nation on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM supplier left join nation on s_nationkey = n_nationkey AND n_nationkey <= 12", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyColumnStatistics("n_nationkey", relativeError(0.10)) + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM (SELECT * FROM supplier WHERE s_nationkey <= 12) left join nation on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(0.15)) + .verifyColumnStatistics("s_nationkey", relativeError(0.15)) + .verifyColumnStatistics("n_nationkey", relativeError(0.10))); + } + + @Test + void testRightJoinStats() + { + // simple equi join + statisticsAssertion.check("SELECT * FROM nation right join supplier 
on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM nation right join supplier on s_nationkey = n_nationkey AND n_nationkey <= 12", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyColumnStatistics("n_nationkey", relativeError(0.10)) + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM nation right JOIN (SELECT * FROM supplier WHERE s_nationkey <= 12) on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(0.15)) + .verifyColumnStatistics("s_nationkey", relativeError(0.15)) + .verifyColumnStatistics("n_nationkey", relativeError(0.10))); + } + + @Test + void testFullJoinStats() + { + // simple equi join + statisticsAssertion.check("SELECT * FROM nation full join supplier on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM (SELECT * FROM nation WHERE n_nationkey <= 12) full join supplier on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("s_nationkey") + .verifyColumnStatistics("n_nationkey", relativeError(0.10)) + .verifyExactColumnStatistics("s_suppkey")); + statisticsAssertion.check("SELECT * FROM nation full join (SELECT * FROM supplier WHERE s_nationkey <= 12) on s_nationkey = n_nationkey", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, relativeError(0.15)) + .verifyColumnStatistics("s_nationkey", relativeError(0.15)) + .verifyColumnStatistics("n_nationkey", relativeError(0.10))); + } + 
+ @Test + public void testAggregation() + { + statisticsAssertion.check("SELECT count() AS count FROM nation", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyNoColumnStatistics("count")); + + statisticsAssertion.check("SELECT n_name, count() AS count FROM nation GROUP BY n_name", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyNoColumnStatistics("count") + .verifyExactColumnStatistics("n_name")); + + statisticsAssertion.check("SELECT n_name, count() AS count FROM nation, region GROUP BY n_name", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyNoColumnStatistics("count") + .verifyExactColumnStatistics("n_name")); + } + + @Test + public void testUnion() + { + statisticsAssertion.check( + "SELECT * FROM nation " + + "UNION ALL " + + "SELECT * FROM nation " + + "UNION ALL " + + "SELECT * FROM nation ", + checks -> checks + .estimate(OUTPUT_ROW_COUNT, defaultTolerance()) + .verifyExactColumnStatistics("n_nationkey") + .verifyExactColumnStatistics("n_regionkey")); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java index f922deafe1af9..d9c06abc5e87a 100644 --- a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java @@ -17,12 +17,12 @@ import com.facebook.presto.connector.thrift.ThriftPlugin; import com.facebook.presto.connector.thrift.location.HostList; import com.facebook.presto.connector.thrift.server.ThriftTpchService; -import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.QualifiedObjectName; import 
com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingAccessControlManager; @@ -167,9 +167,9 @@ public Metadata getMetadata() } @Override - public CostCalculator getCostCalculator() + public Lookup getLookup() { - return source.getCostCalculator(); + return source.getLookup(); } @Override diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java index dcc40d2eb059f..b45c92eb2aa0c 100644 --- a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java @@ -32,14 +32,12 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.RecordPageSource; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.tpch.TpchMetadata; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import io.airlift.json.JsonCodec; import io.airlift.tpch.TpchColumn; -import io.airlift.tpch.TpchColumnType; import io.airlift.tpch.TpchEntity; import io.airlift.tpch.TpchTable; @@ -117,7 +115,7 @@ public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTabl TpchTable tpchTable = TpchTable.getTable(schemaTableName.getTableName()); List columns = new ArrayList<>(); for (TpchColumn column : tpchTable.getColumns()) { - columns.add(new 
PrestoThriftColumnMetadata(column.getSimplifiedColumnName(), getTypeString(column.getType()), null, false)); + columns.add(new PrestoThriftColumnMetadata(column.getSimplifiedColumnName(), getTypeString(column), null, false)); } return new PrestoThriftNullableTableMetadata(new PrestoThriftTableMetadata(schemaTableName, columns, null)); } @@ -254,7 +252,7 @@ private static ConnectorPageSource createPageSource(TpchT private static List types(String tableName, List columnNames) { TpchTable table = TpchTable.getTable(tableName); - return columnNames.stream().map(name -> getPrestoType(table.getColumn(name).getType())).collect(toList()); + return columnNames.stream().map(name -> getPrestoType(table.getColumn(name))).collect(toList()); } private static double schemaNameToScaleFactor(String schemaName) @@ -268,8 +266,8 @@ private static double schemaNameToScaleFactor(String schemaName) throw new IllegalArgumentException("Schema is not setup: " + schemaName); } - private static String getTypeString(TpchColumnType tpchType) + private static String getTypeString(TpchColumn column) { - return TpchMetadata.getPrestoType(tpchType).getTypeSignature().toString(); + return getPrestoType(column).getTypeSignature().toString(); } } diff --git a/presto-tpch/pom.xml b/presto-tpch/pom.xml index baa0182b7c951..f3c1e12b6985d 100644 --- a/presto-tpch/pom.xml +++ b/presto-tpch/pom.xml @@ -39,6 +39,16 @@ provided + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + + com.fasterxml.jackson.core jackson-annotations diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java index 066dcc83c9a5c..27fcc19637c28 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java @@ -32,13 +32,19 @@ import com.facebook.presto.spi.predicate.Domain; import 
com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.ColumnStatistics; import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.type.BigintType; -import com.facebook.presto.spi.type.DateType; -import com.facebook.presto.spi.type.DoubleType; -import com.facebook.presto.spi.type.IntegerType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.tpch.statistics.ColumnStatisticsData; +import com.facebook.presto.tpch.statistics.StatisticsEstimator; +import com.facebook.presto.tpch.statistics.TableStatisticsData; +import com.facebook.presto.tpch.statistics.TableStatisticsDataRepository; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -59,9 +65,19 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Maps.asMap; +import static io.airlift.tpch.OrderColumn.ORDER_STATUS; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; +import static 
java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toSet; public class TpchMetadata @@ -76,14 +92,18 @@ public class TpchMetadata public static final String ROW_NUMBER_COLUMN_NAME = "row_number"; private static final Set ORDER_STATUS_VALUES = ImmutableSet.of("F", "O", "P"); + private static final Set ORDER_STATUS_VALUES_SLICES = ORDER_STATUS_VALUES.stream() + .map(Slices::utf8Slice) + .collect(toImmutableSet()); private static final Set ORDER_STATUS_NULLABLE_VALUES = ORDER_STATUS_VALUES.stream() - .map(value -> new NullableValue(getPrestoType(OrderColumn.ORDER_STATUS.getType()), Slices.utf8Slice(value))) + .map(value -> new NullableValue(getPrestoType(ORDER_STATUS), Slices.utf8Slice(value))) .collect(toSet()); private final String connectorId; private final Set tableNames; private final boolean predicatePushdownEnabled; private final ColumnNaming columnNaming; + private final StatisticsEstimator statisticsEstimator; public TpchMetadata(String connectorId) { @@ -100,6 +120,15 @@ public TpchMetadata(String connectorId, boolean predicatePushdownEnabled, Column this.connectorId = connectorId; this.predicatePushdownEnabled = predicatePushdownEnabled; this.columnNaming = columnNaming; + this.statisticsEstimator = createStatisticsEstimator(); + } + + private static StatisticsEstimator createStatisticsEstimator() + { + ObjectMapper objectMapper = new ObjectMapper() + .registerModule(new Jdk8Module()); + TableStatisticsDataRepository tableStatisticsDataRepository = new TableStatisticsDataRepository(objectMapper); + return new StatisticsEstimator(tableStatisticsDataRepository); } @Override @@ -211,7 +240,7 @@ private static ConnectorTableMetadata getTableMetadata(String schemaName, TpchTa { ImmutableList.Builder columns = ImmutableList.builder(); for (TpchColumn column : tpchTable.getColumns()) { - columns.add(new ColumnMetadata(columnNaming.getName(column), getPrestoType(column.getType()))); + columns.add(new 
ColumnMetadata(columnNaming.getName(column), getPrestoType(column))); } columns.add(new ColumnMetadata(ROW_NUMBER_COLUMN_NAME, BIGINT, null, true)); @@ -247,58 +276,94 @@ public Map> listTableColumns(ConnectorSess @Override public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint) { - TpchTableHandle table = (TpchTableHandle) tableHandle; - return new TableStatistics(new Estimate(getRowCount(table, Optional.of(constraint.getSummary()))), ImmutableMap.of()); + TpchTableHandle tpchTableHandle = (TpchTableHandle) tableHandle; + String tableName = tpchTableHandle.getTableName(); + TpchTable tpchTable = TpchTable.getTable(tableName); + Map, List> columnValuesRestrictions = getColumnValuesRestrictions(tpchTable, constraint); + TableStatisticsData tableStatisticsData = statisticsEstimator.estimateStats(tpchTable, columnValuesRestrictions, tpchTableHandle.getScaleFactor()); + return toTableStatistics(tableStatisticsData, tpchTableHandle, getColumnHandles(session, tpchTableHandle)); } - private long getRowCount(TpchTableHandle tpchTableHandle, Optional> predicate) + private Map, List> getColumnValuesRestrictions(TpchTable tpchTable, Constraint constraint) { - // todo expose row counts from airlift-tpch instead of hardcoding it here - // todo add stats for columns - String tableName = tpchTableHandle.getTableName(); - double scaleFactor = tpchTableHandle.getScaleFactor(); - switch (tableName.toLowerCase()) { - case "customer": - return (long) (150_000 * scaleFactor); - case "orders": - Set orderStatusValues = predicate.map(tupleDomain -> - ORDER_STATUS_NULLABLE_VALUES.stream() - .filter(convertToPredicate(tupleDomain, OrderColumn.ORDER_STATUS)) - .map(nullableValue -> ((Slice) nullableValue.getValue()).toStringUtf8()) - .collect(toSet())) - .orElse(ORDER_STATUS_VALUES); - - long totalRows = 0L; - if (orderStatusValues.contains("F")) { - totalRows = 729_413; - } - if (orderStatusValues.contains("O")) { - 
totalRows += 732_044; - } - if (orderStatusValues.contains("P")) { - totalRows += 38_543; - } - return (long) (totalRows * scaleFactor); - case "lineitem": - return (long) (6_000_000 * scaleFactor); - case "part": - return (long) (200_000 * scaleFactor); - case "partsupp": - return (long) (800_000 * scaleFactor); - case "supplier": - return (long) (10_000 * scaleFactor); - case "nation": - return 25; - case "region": - return 5; - default: - throw new IllegalArgumentException("unknown tpch table name '" + tableName + "'"); + TupleDomain constraintSummary = constraint.getSummary(); + if (constraintSummary.isAll()) { + return emptyMap(); + } + else if (constraintSummary.isNone()) { + Set> columns = ImmutableSet.copyOf(tpchTable.getColumns()); + return asMap(columns, key -> emptyList()); + } + else { + Map domains = constraintSummary.getDomains().get(); + Optional orderStatusDomain = Optional.ofNullable(domains.get(toColumnHandle(ORDER_STATUS))); + Optional, List>> allowedColumnValues = orderStatusDomain.map(domain -> { + List allowedValues = ORDER_STATUS_VALUES_SLICES.stream() + .filter(domain::includesNullableValue) + .collect(toList()); + return avoidTrivialOrderStatusRestriction(allowedValues); + }); + return allowedColumnValues.orElse(emptyMap()); + } + } + + private Map, List> avoidTrivialOrderStatusRestriction(List allowedValues) + { + if (allowedValues.containsAll(ORDER_STATUS_VALUES_SLICES)) { + return emptyMap(); + } + else { + return ImmutableMap.of(ORDER_STATUS, allowedValues); + } + } + + private TableStatistics toTableStatistics(TableStatisticsData tableStatisticsData, TpchTableHandle tpchTableHandle, Map columnHandles) + { + TableStatistics.Builder builder = TableStatistics.builder() + .setRowCount(new Estimate(tableStatisticsData.getRowCount())); + tableStatisticsData.getColumns().forEach((columnName, stats) -> { + TpchColumnHandle columnHandle = (TpchColumnHandle) getColumnHandle(tpchTableHandle, columnHandles, columnName); + 
builder.setColumnStatistics(columnHandle, toColumnStatistics(stats, columnHandle.getType())); + }); + return builder.build(); + } + + private ColumnHandle getColumnHandle(TpchTableHandle tpchTableHandle, Map columnHandles, String columnName) + { + TpchTable table = TpchTable.getTable(tpchTableHandle.getTableName()); + return columnHandles.get(columnNaming.getName(table.getColumn(columnName))); + } + + private ColumnStatistics toColumnStatistics(ColumnStatisticsData stats, Type columnType) + { + return ColumnStatistics.builder() + .addRange(rangeBuilder -> rangeBuilder + .setDistinctValuesCount(stats.getDistinctValuesCount().map(Estimate::new).orElse(unknownValue())) + .setLowValue(stats.getMin().map(value -> toPrestoValue(value, columnType))) + .setHighValue(stats.getMax().map(value -> toPrestoValue(value, columnType))) + .setFraction(new Estimate(1.0))) + .setNullsFraction(Estimate.zeroValue()) + .build(); + } + + private Object toPrestoValue(Object tpchValue, Type columnType) + { + if (columnType instanceof VarcharType) { + return Slices.utf8Slice((String) tpchValue); + } + if (columnType.equals(BIGINT) || columnType.equals(INTEGER) || columnType.equals(DATE)) { + return ((Number) tpchValue).longValue(); + } + if (columnType.equals(DOUBLE)) { + return ((Number) tpchValue).doubleValue(); } + throw new IllegalArgumentException("unsupported column type " + columnType); } - private TpchColumnHandle toColumnHandle(TpchColumn column) + @VisibleForTesting + TpchColumnHandle toColumnHandle(TpchColumn column) { - return new TpchColumnHandle(columnNaming.getName(column), getPrestoType(column.getType())); + return new TpchColumnHandle(columnNaming.getName(column), getPrestoType(column)); } @Override @@ -381,7 +446,7 @@ private static String scaleFactorSchemaName(double scaleFactor) return "sf" + scaleFactor; } - private static double schemaNameToScaleFactor(String schemaName) + public static double schemaNameToScaleFactor(String schemaName) { if 
(TINY_SCHEMA_NAME.equals(schemaName)) { return TINY_SCALE_FACTOR; @@ -399,17 +464,18 @@ private static double schemaNameToScaleFactor(String schemaName) } } - public static Type getPrestoType(TpchColumnType tpchType) + public static Type getPrestoType(TpchColumn column) { + TpchColumnType tpchType = column.getType(); switch (tpchType.getBase()) { case IDENTIFIER: return BigintType.BIGINT; case INTEGER: - return IntegerType.INTEGER; + return INTEGER; case DATE: - return DateType.DATE; + return DATE; case DOUBLE: - return DoubleType.DOUBLE; + return DOUBLE; case VARCHAR: return createVarcharType((int) (long) tpchType.getPrecision().get()); } diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java index 21e352622da77..b55d93ba6baac 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java @@ -75,10 +75,10 @@ public TpchRecordSet(Iterable table, Iterable> columns, Optiona this.table = table; this.columns = ImmutableList.copyOf(columns); - this.columnTypes = ImmutableList.copyOf(transform(columns, column -> getPrestoType(column.getType()))); + this.columnTypes = ImmutableList.copyOf(transform(columns, TpchMetadata::getPrestoType)); columnHandles = this.columns.stream() - .map(column -> new TpchColumnHandle(column.getColumnName(), getPrestoType(column.getType()))) + .map(column -> new TpchColumnHandle(column.getColumnName(), getPrestoType(column))) .collect(toList()); this.predicate = predicate.map(TpchRecordSet::convertToPredicate); } @@ -135,7 +135,7 @@ public long getReadTimeNanos() @Override public Type getType(int field) { - return getPrestoType(getTpchColumn(field).getType()); + return getPrestoType(getTpchColumn(field)); } @Override diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsData.java 
b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsData.java new file mode 100644 index 0000000000000..ceea8465081db --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsData.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch.statistics; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ColumnStatisticsData +{ + private final Optional distinctValuesCount; + private final Optional nullsCount; + private final Optional min; + private final Optional max; + + @JsonCreator + public ColumnStatisticsData( + @JsonProperty("distinctValuesCount") Optional distinctValuesCount, + @JsonProperty("nullsCount") Optional nullsCount, + @JsonProperty("min") Optional min, + @JsonProperty("max") Optional max) + { + this.distinctValuesCount = requireNonNull(distinctValuesCount); + this.nullsCount = requireNonNull(nullsCount); + this.min = requireNonNull(min); + this.max = requireNonNull(max); + } + + public static ColumnStatisticsData empty() + { + return new ColumnStatisticsData( + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + } + + public static ColumnStatisticsData zero() + { + return new ColumnStatisticsData( + Optional.of(0L), + Optional.of(0L), + 
Optional.empty(), + Optional.empty() + ); + } + + public Optional getDistinctValuesCount() + { + return distinctValuesCount; + } + + public Optional getNullsCount() + { + return nullsCount; + } + + public Optional getMin() + { + return min; + } + + public Optional getMax() + { + return max; + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsRecorder.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsRecorder.java new file mode 100644 index 0000000000000..39a0721a9449f --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/ColumnStatisticsRecorder.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.statistics; + +import java.util.Optional; +import java.util.TreeSet; + +class ColumnStatisticsRecorder +{ + private final TreeSet nonNullValues = new TreeSet<>(); + private long nullsCount = 0; + + public void record(Comparable value) + { + if (value != null) { + nonNullValues.add(value); + } + else { + nullsCount++; + } + } + + public ColumnStatisticsData getRecording() + { + return new ColumnStatisticsData( + Optional.of(getUniqueValuesCount()), + Optional.of(getNullsCount()), + getLowestValue(), + getHighestValue() + ); + } + + private long getUniqueValuesCount() + { + return nonNullValues.size(); + } + + private long getNullsCount() + { + return nullsCount; + } + + private Optional getLowestValue() + { + return nonNullValues.size() > 0 ? Optional.of(nonNullValues.first()) : Optional.empty(); + } + + private Optional getHighestValue() + { + return nonNullValues.size() > 0 ? Optional.of(nonNullValues.last()) : Optional.empty(); + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/StatisticsEstimator.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/StatisticsEstimator.java new file mode 100644 index 0000000000000..70225abfbd95f --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/StatisticsEstimator.java @@ -0,0 +1,236 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.statistics; + +import com.facebook.presto.util.Types; +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import io.airlift.tpch.CustomerColumn; +import io.airlift.tpch.LineItemColumn; +import io.airlift.tpch.OrderColumn; +import io.airlift.tpch.PartColumn; +import io.airlift.tpch.PartSupplierColumn; +import io.airlift.tpch.SupplierColumn; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchTable; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.tpch.util.Optionals.checkPresent; +import static com.facebook.presto.tpch.util.Optionals.combine; +import static com.facebook.presto.tpch.util.Optionals.withBoth; +import static com.facebook.presto.util.Types.checkSameType; +import static com.facebook.presto.util.Types.checkType; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Math.abs; +import static java.lang.String.format; +import static java.util.Optional.empty; + +public class StatisticsEstimator +{ + //These columns are unsupported because scaling them would require sophisticated logic + private static final Set> UNSUPPORTED_COLUMNS = ImmutableSet.of( + CustomerColumn.ACCOUNT_BALANCE, + LineItemColumn.EXTENDED_PRICE, + OrderColumn.TOTAL_PRICE, + PartColumn.RETAIL_PRICE, + PartSupplierColumn.AVAILABLE_QUANTITY, + SupplierColumn.ACCOUNT_BALANCE + ); + + private final TableStatisticsDataRepository tableStatisticsDataRepository; + + public StatisticsEstimator(TableStatisticsDataRepository tableStatisticsDataRepository) + { + this.tableStatisticsDataRepository = tableStatisticsDataRepository; + } + + public TableStatisticsData estimateStats(TpchTable tpchTable, Map, List> 
columnValuesRestrictions, double scaleFactor) + { + TableStatisticsData bigStatistics = readStatistics(tpchTable, columnValuesRestrictions, "sf1"); + TableStatisticsData smallStatistics = readStatistics(tpchTable, columnValuesRestrictions, "tiny"); + double rescalingFactor = smallStatistics.getRowCount() == bigStatistics.getRowCount() ? 1 : scaleFactor; + return new TableStatisticsData( + (long) (bigStatistics.getRowCount() * rescalingFactor), + rescale(tpchTable, bigStatistics, smallStatistics, scaleFactor) + ); + } + + private Map rescale(TpchTable table, TableStatisticsData bigStatistics, TableStatisticsData smallStatistics, double scaleFactor) + { + return bigStatistics.getColumns().entrySet().stream().collect(toImmutableMap( + Map.Entry::getKey, + entry -> { + String columnName = entry.getKey(); + TpchColumn column = table.getColumn(columnName); + ColumnStatisticsData bigColumnStatistics = entry.getValue(); + ColumnStatisticsData smallColumnStatistics = smallStatistics.getColumns().get(columnName); + return rescale(column, bigColumnStatistics, smallColumnStatistics, scaleFactor); + } + )); + } + + private ColumnStatisticsData rescale(TpchColumn column, ColumnStatisticsData big, ColumnStatisticsData small, double scaleFactor) + { + if (UNSUPPORTED_COLUMNS.contains(column)) { + return ColumnStatisticsData.empty(); + } + else { + return rescale(big, small, scaleFactor); + } + } + + private ColumnStatisticsData rescale(ColumnStatisticsData big, ColumnStatisticsData small, double scaleFactor) + { + if (columnDoesNotScale(big, small)) { + return new ColumnStatisticsData( + big.getDistinctValuesCount(), + big.getNullsCount(), + checkPresent(withBoth(big.getMin(), small.getMin(), this::checkClose)), + checkPresent(withBoth(big.getMax(), small.getMax(), this::checkClose)) + ); + } + else { + Function rescale = value -> value.doubleValue() * scaleFactor; + return new ColumnStatisticsData( + big.getDistinctValuesCount().map(rescale).map(Number::longValue), + 
big.getNullsCount().map(rescale).map(Number::longValue), + Types.tryCast(big.getMin(), Number.class), + Types.tryCast(big.getMax(), Number.class).map(rescale)); + } + } + + private boolean columnDoesNotScale(ColumnStatisticsData big, ColumnStatisticsData small) + { + return withBoth(small.getMin(), big.getMin(), this::areClose) + .flatMap(lowerBoundsClose -> + withBoth(small.getMax(), big.getMax(), this::areClose) + .map(upperBoundsClose -> lowerBoundsClose && upperBoundsClose)) + .orElse(false); + } + + private Object checkClose(Object leftValue, Object rightValue) + { + checkArgument( + areClose(leftValue, rightValue), + format("Values must be close to each other, got [%s] and [%s]", leftValue, rightValue)); + return leftValue; + } + + private boolean areClose(Object leftValue, Object rightValue) + { + checkSameType(leftValue, rightValue); + if (leftValue instanceof String) { + return leftValue.equals(rightValue); + } + else { + Number left = checkType(leftValue, Number.class); + Number right = checkType(rightValue, Number.class); + return areClose(left.doubleValue(), right.doubleValue()); + } + } + + private boolean areClose(double left, double right) + { + return abs(right - left) <= abs(left) * 0.01; + } + + private TableStatisticsData readStatistics(TpchTable table, Map, List> columnValuesRestrictions, String schemaName) + { + if (columnValuesRestrictions.isEmpty()) { + return tableStatisticsDataRepository.load(schemaName, table, empty(), empty()); + } + else if (columnValuesRestrictions.values().stream().allMatch(List::isEmpty)) { + return zeroStatistics(table); + } + else { + checkArgument(columnValuesRestrictions.size() <= 1, "Can only estimate stats when at most one column has value restrictions"); + TpchColumn partitionColumn = getOnlyElement(columnValuesRestrictions.keySet()); + List partitionValues = columnValuesRestrictions.get(partitionColumn); + TableStatisticsData result = zeroStatistics(table); + for (Object partitionValue : partitionValues) { + 
Slice value = checkType(partitionValue, Slice.class, "Only string (Slice) partition values supported for now"); + TableStatisticsData tableStatisticsData = tableStatisticsDataRepository + .load(schemaName, table, Optional.of(partitionColumn), Optional.of(value.toStringUtf8())); + result = addPartitionStats(result, tableStatisticsData, partitionColumn); + } + return result; + } + } + + private TableStatisticsData addPartitionStats(TableStatisticsData left, TableStatisticsData right, TpchColumn partitionColumn) + { + return new TableStatisticsData( + left.getRowCount() + right.getRowCount(), + addPartitionStats(left.getColumns(), right.getColumns(), partitionColumn) + ); + } + + private Map addPartitionStats(Map leftColumns, Map rightColumns, TpchColumn partitionColumn) + { + return leftColumns.entrySet().stream().collect(toImmutableMap( + Map.Entry::getKey, + entry -> { + String columnName = entry.getKey(); + ColumnStatisticsData leftStats = entry.getValue(); + ColumnStatisticsData rightStats = rightColumns.get(columnName); + return new ColumnStatisticsData( + combineUniqueValuesCount(partitionColumn, columnName, leftStats, rightStats), + combine(leftStats.getNullsCount(), rightStats.getNullsCount(), (a, b) -> a + b), + combine(leftStats.getMin(), rightStats.getMin(), this::min), + combine(leftStats.getMax(), rightStats.getMax(), this::max) + ); + })); + } + + private Optional combineUniqueValuesCount(TpchColumn partitionColumn, String columnName, ColumnStatisticsData leftStats, ColumnStatisticsData rightStats) + { + //unique values count can't be added between different partitions + //for columns other than the partition column (because almost certainly there are duplicates) + return combine(leftStats.getDistinctValuesCount(), rightStats.getDistinctValuesCount(), (a, b) -> a + b) + .filter(v -> columnName.equals(partitionColumn.getColumnName())); + } + + @SuppressWarnings("unchecked") + private Object min(Object l, Object r) + { + checkSameType(l, r); + Comparable 
left = checkType(l, Comparable.class); + Comparable right = checkType(r, Comparable.class); + return left.compareTo(right) < 0 ? left : right; + } + + @SuppressWarnings("unchecked") + private Object max(Object l, Object r) + { + checkSameType(l, r); + Comparable left = checkType(l, Comparable.class); + Comparable right = checkType(r, Comparable.class); + return left.compareTo(right) > 0 ? left : right; + } + + private TableStatisticsData zeroStatistics(TpchTable table) + { + return new TableStatisticsData(0, table.getColumns().stream().collect(toImmutableMap( + TpchColumn::getColumnName, + column -> ColumnStatisticsData.zero() + ))); + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsData.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsData.java new file mode 100644 index 0000000000000..cb1d1f3c2df70 --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsData.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.statistics; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +public class TableStatisticsData +{ + private final long rowCount; + private final Map columns; + + @JsonCreator + public TableStatisticsData( + @JsonProperty("rowCount") long rowCount, + @JsonProperty("columns") Map columns) + { + this.rowCount = rowCount; + this.columns = ImmutableMap.copyOf(columns); + } + + public long getRowCount() + { + return rowCount; + } + + public Map getColumns() + { + return columns; + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsDataRepository.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsDataRepository.java new file mode 100644 index 0000000000000..ad56395b23267 --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsDataRepository.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.statistics; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchTable; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Optional; + +import static com.facebook.presto.tpch.util.Optionals.withBoth; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; + +public class TableStatisticsDataRepository +{ + private final ObjectMapper objectMapper; + + public TableStatisticsDataRepository(ObjectMapper objectMapper) + { + this.objectMapper = objectMapper; + } + + public void save( + String schemaName, + TpchTable table, + Optional> partitionColumn, + Optional partitionValue, + TableStatisticsData statisticsData) + { + String filename = tableStatisticsDataFilename(table, partitionColumn, partitionValue); + Path path = Paths.get("presto-tpch", "src", "main", "resources", "tpch", "statistics", schemaName, filename + ".json"); + writeStatistics(path, statisticsData); + } + + private void writeStatistics(Path path, TableStatisticsData tableStatisticsData) + { + File file = path.toFile(); + file.getParentFile().mkdirs(); + try { + objectMapper + .writerWithDefaultPrettyPrinter() + .writeValue(file, tableStatisticsData); + try (FileWriter fileWriter = new FileWriter(file, true)) { + fileWriter.append('\n'); + } + } + catch (IOException e) { + throw new RuntimeException("Could not save table statistics data", e); + } + } + + public TableStatisticsData load(String schemaName, TpchTable table, Optional> partitionColumn, Optional partitionValue) + { + String filename = tableStatisticsDataFilename(table, partitionColumn, partitionValue); + String resourcePath = "/tpch/statistics/" + schemaName + "/" + filename + ".json"; + return readStatistics(resourcePath); + } + + private TableStatisticsData 
readStatistics(String resourcePath) + { + URL resource = getClass().getResource(resourcePath); + try { + return objectMapper.readValue(resource, TableStatisticsData.class); + } + catch (Exception e) { + throw new RuntimeException(format("Failed to parse stats from resource [%s]", resourcePath), e); + } + } + + private String tableStatisticsDataFilename(TpchTable table, Optional> partitionColumn, Optional partitionValue) + { + Optional partitionDescription = getPartitionDescription(partitionColumn, partitionValue); + return table.getTableName() + partitionDescription.map(value -> "." + value).orElse(""); + } + + private Optional getPartitionDescription(Optional> partitionColumn, Optional partitionValue) + { + checkArgument(partitionColumn.isPresent() == partitionValue.isPresent()); + return withBoth(partitionColumn, partitionValue, (column, value) -> column.getColumnName() + "." + value); + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsRecorder.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsRecorder.java new file mode 100644 index 0000000000000..57b7820ae211b --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/statistics/TableStatisticsRecorder.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.statistics; + +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchColumnType; +import io.airlift.tpch.TpchEntity; +import io.airlift.tpch.TpchTable; + +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.String.format; +import static java.util.function.Function.identity; + +class TableStatisticsRecorder +{ + public TableStatisticsData recordStatistics(TpchTable tpchTable, Predicate constraint, double scaleFactor) + { + Iterable rows = tpchTable.createGenerator(scaleFactor, 1, 1); + return recordStatistics(rows, tpchTable.getColumns(), constraint); + } + + private TableStatisticsData recordStatistics(Iterable rows, List> columns, Predicate constraint) + { + Map, ColumnStatisticsRecorder> statisticsRecorders = createStatisticsRecorders(columns); + long rowCount = 0; + + for (E row : rows) { + if (constraint.test(row)) { + rowCount++; + for (TpchColumn column : columns) { + Comparable value = getTpchValue(row, column); + statisticsRecorders.get(column).record(value); + } + } + } + + Map columnSampleStatistics = statisticsRecorders.entrySet().stream() + .collect(toImmutableMap( + e -> e.getKey().getColumnName(), + e -> e.getValue().getRecording() + )); + return new TableStatisticsData(rowCount, columnSampleStatistics); + } + + private Map, ColumnStatisticsRecorder> createStatisticsRecorders(List> columns) + { + return columns.stream() + .collect(toImmutableMap(identity(), (column) -> new ColumnStatisticsRecorder())); + } + + private Comparable getTpchValue(E row, TpchColumn column) + { + TpchColumnType.Base baseType = column.getType().getBase(); + switch (baseType) { + case IDENTIFIER: + return column.getIdentifier(row); + case INTEGER: + return column.getInteger(row); + case DATE: + return column.getDate(row); + case DOUBLE: + return column.getDouble(row); + case VARCHAR: + return 
column.getString(row); + } + throw new UnsupportedOperationException(format("Unsupported TPCH base type [%s]", baseType)); + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Optionals.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Optionals.java new file mode 100644 index 0000000000000..40202c5317107 --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Optionals.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tpch.util; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; + +import static com.google.common.base.Preconditions.checkArgument; + +public class Optionals +{ + private Optionals() {} + + public static Optional checkPresent(Optional optional) + { + checkArgument(optional.isPresent(), "Expected a present optional, got empty()"); + return optional; + } + + public static Optional withBoth(Optional left, Optional right, BiFunction binaryFunction) + { + return left.flatMap(l -> right.map(r -> binaryFunction.apply(l, r))); + } + + public static Optional combine(Optional left, Optional right, BinaryOperator combiner) + { + if (left.isPresent() && right.isPresent()) { + return Optional.of(combiner.apply(left.get(), right.get())); + } + else if (left.isPresent()) { + return left; + } + else { + return right; + } + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Types.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Types.java new file mode 100644 index 0000000000000..957956140abe9 --- /dev/null +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/util/Types.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.util; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; + +public class Types +{ + private Types() {} + + public static T checkType(Object object, Class expectedClass) + { + return checkType(object, expectedClass, "Expected an object of type [%s]", expectedClass.getCanonicalName()); + } + + public static T checkType(Object object, Class expectedClass, String messageTemplate, Object... arguments) + { + checkArgument(expectedClass.isInstance(object), messageTemplate, arguments); + return expectedClass.cast(object); + } + + public static void checkSameType(Object left, Object right) + { + Class leftClass = left.getClass(); + Class rightClass = right.getClass(); + String message = "Values must be of same type, got [%s : %s] and [%s : %s]"; + checkArgument(leftClass.equals(rightClass), message, left, leftClass, right, rightClass); + } + + public static Optional tryCast(Optional optional, Class targetClass) + { + return optional + .filter(targetClass::isInstance) + .map(targetClass::cast); + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/customer.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/customer.json new file mode 100644 index 0000000000000..e0faa16e98d62 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/customer.json @@ -0,0 +1,53 @@ +{ + "rowCount" : 150000, + "columns" : { + "c_custkey" : { + "distinctValuesCount" : 150000, + "nullsCount" : 0, + "min" : 1, + "max" : 150000 + }, + "c_name" : { + "distinctValuesCount" : 150000, + "nullsCount" : 0, + "min" : "Customer#000000001", + "max" : "Customer#000150000" + }, + "c_address" : { + "distinctValuesCount" : 150000, + "nullsCount" : 0, + "min" : " 2uZwVhQvwA", + "max" : "zzxGktzXTMKS1BxZlgQ9nqQ" + }, + "c_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "c_phone" : { + "distinctValuesCount" : 150000, + "nullsCount" : 0, + "min" : 
"10-100-106-1617", + "max" : "34-999-618-6881" + }, + "c_acctbal" : { + "distinctValuesCount" : 140187, + "nullsCount" : 0, + "min" : -999.99, + "max" : 9999.99 + }, + "c_mktsegment" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "AUTOMOBILE", + "max" : "MACHINERY" + }, + "c_comment" : { + "distinctValuesCount" : 149968, + "nullsCount" : 0, + "min" : " Tiresias according to the slyly blithe instructions detect quickly at the slyly express courts. express dinos wake ", + "max" : "zzle. blithely regular instructions cajol" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/lineitem.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/lineitem.json new file mode 100644 index 0000000000000..bdb9f733fbfec --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/lineitem.json @@ -0,0 +1,101 @@ +{ + "rowCount" : 6001215, + "columns" : { + "l_orderkey" : { + "distinctValuesCount" : 1500000, + "nullsCount" : 0, + "min" : 1, + "max" : 6000000 + }, + "l_partkey" : { + "distinctValuesCount" : 200000, + "nullsCount" : 0, + "min" : 1, + "max" : 200000 + }, + "l_suppkey" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : 1, + "max" : 10000 + }, + "l_linenumber" : { + "distinctValuesCount" : 7, + "nullsCount" : 0, + "min" : 1, + "max" : 7 + }, + "l_quantity" : { + "distinctValuesCount" : 50, + "nullsCount" : 0, + "min" : 1.0, + "max" : 50.0 + }, + "l_extendedprice" : { + "distinctValuesCount" : 933900, + "nullsCount" : 0, + "min" : 901.0, + "max" : 104949.5 + }, + "l_discount" : { + "distinctValuesCount" : 11, + "nullsCount" : 0, + "min" : 0.0, + "max" : 0.1 + }, + "l_tax" : { + "distinctValuesCount" : 9, + "nullsCount" : 0, + "min" : 0.0, + "max" : 0.08 + }, + "l_returnflag" : { + "distinctValuesCount" : 3, + "nullsCount" : 0, + "min" : "A", + "max" : "R" + }, + "l_linestatus" : { + "distinctValuesCount" : 2, + "nullsCount" : 0, + "min" : "F", + "max" : "O" + }, + "l_shipdate" : { + "distinctValuesCount" : 
2526, + "nullsCount" : 0, + "min" : 8036, + "max" : 10561 + }, + "l_commitdate" : { + "distinctValuesCount" : 2466, + "nullsCount" : 0, + "min" : 8065, + "max" : 10530 + }, + "l_receiptdate" : { + "distinctValuesCount" : 2554, + "nullsCount" : 0, + "min" : 8038, + "max" : 10591 + }, + "l_shipinstruct" : { + "distinctValuesCount" : 4, + "nullsCount" : 0, + "min" : "COLLECT COD", + "max" : "TAKE BACK RETURN" + }, + "l_shipmode" : { + "distinctValuesCount" : 7, + "nullsCount" : 0, + "min" : "AIR", + "max" : "TRUCK" + }, + "l_comment" : { + "distinctValuesCount" : 4580667, + "nullsCount" : 0, + "min" : " Tiresias ", + "max" : "zzle? slyly final platelets sleep quickly. " + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/nation.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/nation.json new file mode 100644 index 0000000000000..3feb6f136d7e4 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/nation.json @@ -0,0 +1,29 @@ +{ + "rowCount" : 25, + "columns" : { + "n_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "n_name" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : "ALGERIA", + "max" : "VIETNAM" + }, + "n_regionkey" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : 0, + "max" : 4 + }, + "n_comment" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : " haggle. carefully final deposits detect slyly agai", + "max" : "y final packages. slow foxes cajole quickly. quickly silent platelets breach ironic accounts. 
unusual pinto be" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.json new file mode 100644 index 0000000000000..6605bfad7852e --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 1500000, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 1500000, + "nullsCount" : 0, + "min" : 1, + "max" : 6000000 + }, + "o_custkey" : { + "distinctValuesCount" : 99996, + "nullsCount" : 0, + "min" : 1, + "max" : 149999 + }, + "o_orderstatus" : { + "distinctValuesCount" : 3, + "nullsCount" : 0, + "min" : "F", + "max" : "P" + }, + "o_totalprice" : { + "distinctValuesCount" : 1464556, + "nullsCount" : 0, + "min" : 857.71, + "max" : 555285.16 + }, + "o_orderdate" : { + "distinctValuesCount" : 2406, + "nullsCount" : 0, + "min" : 8035, + "max" : 10440 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 1482071, + "nullsCount" : 0, + "min" : " Tiresias about the blithely ironic a", + "max" : "zzle? 
furiously ironic instructions among the unusual t" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.F.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.F.json new file mode 100644 index 0000000000000..8d7a848336b9a --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.F.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 729413, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 729413, + "nullsCount" : 0, + "min" : 3, + "max" : 5999975 + }, + "o_custkey" : { + "distinctValuesCount" : 99609, + "nullsCount" : 0, + "min" : 1, + "max" : 149999 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "F", + "max" : "F" + }, + "o_totalprice" : { + "distinctValuesCount" : 720822, + "nullsCount" : 0, + "min" : 866.9, + "max" : 555285.16 + }, + "o_orderdate" : { + "distinctValuesCount" : 1261, + "nullsCount" : 0, + "min" : 8035, + "max" : 9296 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 724600, + "nullsCount" : 0, + "min" : " Tiresias above the carefully ironic packages nag about the pend", + "max" : "zzle; ironic accounts affix slyly regular pinto b" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.O.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.O.json new file mode 100644 index 0000000000000..4dec279848ee7 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.O.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 732044, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 
732044, + "nullsCount" : 0, + "min" : 1, + "max" : 6000000 + }, + "o_custkey" : { + "distinctValuesCount" : 99621, + "nullsCount" : 0, + "min" : 1, + "max" : 149999 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "O", + "max" : "O" + }, + "o_totalprice" : { + "distinctValuesCount" : 723368, + "nullsCount" : 0, + "min" : 857.71, + "max" : 530604.44 + }, + "o_orderdate" : { + "distinctValuesCount" : 1262, + "nullsCount" : 0, + "min" : 9178, + "max" : 10440 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 727175, + "nullsCount" : 0, + "min" : " Tiresias about the blithely ironic a", + "max" : "zzle? furiously ironic instructions among the unusual t" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.P.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.P.json new file mode 100644 index 0000000000000..a265fedfafcb5 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/orders.o_orderstatus.P.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 38543, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 38543, + "nullsCount" : 0, + "min" : 65, + "max" : 5999875 + }, + "o_custkey" : { + "distinctValuesCount" : 31310, + "nullsCount" : 0, + "min" : 2, + "max" : 149998 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "P", + "max" : "P" + }, + "o_totalprice" : { + "distinctValuesCount" : 38515, + "nullsCount" : 0, + "min" : 2933.43, + "max" : 491549.57 + }, + "o_orderdate" : { + "distinctValuesCount" : 120, + "nullsCount" : 0, + "min" : 9178, + "max" : 9297 + }, + 
"o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 38531, + "nullsCount" : 0, + "min" : " Tiresias haggle slyly bli", + "max" : "zzle furiously. bold packa" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/part.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/part.json new file mode 100644 index 0000000000000..c22d18787ce9d --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/part.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 200000, + "columns" : { + "p_partkey" : { + "distinctValuesCount" : 200000, + "nullsCount" : 0, + "min" : 1, + "max" : 200000 + }, + "p_name" : { + "distinctValuesCount" : 199997, + "nullsCount" : 0, + "min" : "almond antique blue royal burnished", + "max" : "yellow white seashell lavender black" + }, + "p_mfgr" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "Manufacturer#1", + "max" : "Manufacturer#5" + }, + "p_brand" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : "Brand#11", + "max" : "Brand#55" + }, + "p_type" : { + "distinctValuesCount" : 150, + "nullsCount" : 0, + "min" : "ECONOMY ANODIZED BRASS", + "max" : "STANDARD POLISHED TIN" + }, + "p_size" : { + "distinctValuesCount" : 50, + "nullsCount" : 0, + "min" : 1, + "max" : 50 + }, + "p_container" : { + "distinctValuesCount" : 40, + "nullsCount" : 0, + "min" : "JUMBO BAG", + "max" : "WRAP PKG" + }, + "p_retailprice" : { + "distinctValuesCount" : 20899, + "nullsCount" : 0, + "min" : 901.0, + "max" : 2098.99 + }, + "p_comment" : { + "distinctValuesCount" : 131753, + "nullsCount" : 0, + "min" : " Tire", + "max" : "zzle. 
quickly si" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/partsupp.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/partsupp.json new file mode 100644 index 0000000000000..dc608b7a6970b --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/partsupp.json @@ -0,0 +1,35 @@ +{ + "rowCount" : 800000, + "columns" : { + "ps_partkey" : { + "distinctValuesCount" : 200000, + "nullsCount" : 0, + "min" : 1, + "max" : 200000 + }, + "ps_suppkey" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : 1, + "max" : 10000 + }, + "ps_availqty" : { + "distinctValuesCount" : 9999, + "nullsCount" : 0, + "min" : 1, + "max" : 9999 + }, + "ps_supplycost" : { + "distinctValuesCount" : 99865, + "nullsCount" : 0, + "min" : 1.0, + "max" : 1000.0 + }, + "ps_comment" : { + "distinctValuesCount" : 799124, + "nullsCount" : 0, + "min" : " Tiresias according to the quiet courts sleep against the ironic, final requests. carefully unusual requests affix fluffily quickly ironic packages. regular ", + "max" : "zzle. unusual decoys detect slyly blithely express frays. furiously ironic packages about the bold accounts are close requests. slowly silent reque" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/region.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/region.json new file mode 100644 index 0000000000000..5b9fb841154ef --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/region.json @@ -0,0 +1,23 @@ +{ + "rowCount" : 5, + "columns" : { + "r_regionkey" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : 0, + "max" : 4 + }, + "r_name" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "AFRICA", + "max" : "MIDDLE EAST" + }, + "r_comment" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "ges. thinly even pinto beans ca", + "max" : "uickly special accounts cajole carefully blithely close requests. 
carefully final asymptotes haggle furiousl" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/sf1/supplier.json b/presto-tpch/src/main/resources/tpch/statistics/sf1/supplier.json new file mode 100644 index 0000000000000..65da24956973f --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/sf1/supplier.json @@ -0,0 +1,47 @@ +{ + "rowCount" : 10000, + "columns" : { + "s_suppkey" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : 1, + "max" : 10000 + }, + "s_name" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : "Supplier#000000001", + "max" : "Supplier#000010000" + }, + "s_address" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : " 9aW1wwnBJJPnCx,nox0MA48Y0zpI1IeVfYZ", + "max" : "zzfDhdtZcvmVzA8rNFU,Yctj1zBN" + }, + "s_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "s_phone" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : "10-102-116-6785", + "max" : "34-998-900-4911" + }, + "s_acctbal" : { + "distinctValuesCount" : 9955, + "nullsCount" : 0, + "min" : -998.22, + "max" : 9999.72 + }, + "s_comment" : { + "distinctValuesCount" : 10000, + "nullsCount" : 0, + "min" : " about the blithely express foxes. bli", + "max" : "zzle furiously. bold accounts haggle furiously ironic excuses. 
fur" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/customer.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/customer.json new file mode 100644 index 0000000000000..89a8e2286bdb3 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/customer.json @@ -0,0 +1,53 @@ +{ + "rowCount" : 1500, + "columns" : { + "c_custkey" : { + "distinctValuesCount" : 1500, + "nullsCount" : 0, + "min" : 1, + "max" : 1500 + }, + "c_name" : { + "distinctValuesCount" : 1500, + "nullsCount" : 0, + "min" : "Customer#000000001", + "max" : "Customer#000001500" + }, + "c_address" : { + "distinctValuesCount" : 1500, + "nullsCount" : 0, + "min" : " ,cIZ,06Kg", + "max" : "zyWvi,SGc,tXTls" + }, + "c_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "c_phone" : { + "distinctValuesCount" : 1500, + "nullsCount" : 0, + "min" : "10-109-430-5638", + "max" : "34-992-529-2023" + }, + "c_acctbal" : { + "distinctValuesCount" : 1499, + "nullsCount" : 0, + "min" : -994.79, + "max" : 9987.71 + }, + "c_mktsegment" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "AUTOMOBILE", + "max" : "MACHINERY" + }, + "c_comment" : { + "distinctValuesCount" : 1500, + "nullsCount" : 0, + "min" : " about the carefully ironic pinto beans. accoun", + "max" : "ymptotes. ironic, unusual notornis wake after the ironic, special deposits. 
blithely fina" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/lineitem.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/lineitem.json new file mode 100644 index 0000000000000..ef2ead026bc73 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/lineitem.json @@ -0,0 +1,101 @@ +{ + "rowCount" : 60175, + "columns" : { + "l_orderkey" : { + "distinctValuesCount" : 15000, + "nullsCount" : 0, + "min" : 1, + "max" : 60000 + }, + "l_partkey" : { + "distinctValuesCount" : 2000, + "nullsCount" : 0, + "min" : 1, + "max" : 2000 + }, + "l_suppkey" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : 1, + "max" : 100 + }, + "l_linenumber" : { + "distinctValuesCount" : 7, + "nullsCount" : 0, + "min" : 1, + "max" : 7 + }, + "l_quantity" : { + "distinctValuesCount" : 50, + "nullsCount" : 0, + "min" : 1.0, + "max" : 50.0 + }, + "l_extendedprice" : { + "distinctValuesCount" : 35921, + "nullsCount" : 0, + "min" : 904.0, + "max" : 94949.5 + }, + "l_discount" : { + "distinctValuesCount" : 11, + "nullsCount" : 0, + "min" : 0.0, + "max" : 0.1 + }, + "l_tax" : { + "distinctValuesCount" : 9, + "nullsCount" : 0, + "min" : 0.0, + "max" : 0.08 + }, + "l_returnflag" : { + "distinctValuesCount" : 3, + "nullsCount" : 0, + "min" : "A", + "max" : "R" + }, + "l_linestatus" : { + "distinctValuesCount" : 2, + "nullsCount" : 0, + "min" : "F", + "max" : "O" + }, + "l_shipdate" : { + "distinctValuesCount" : 2518, + "nullsCount" : 0, + "min" : 8038, + "max" : 10559 + }, + "l_commitdate" : { + "distinctValuesCount" : 2460, + "nullsCount" : 0, + "min" : 8067, + "max" : 10527 + }, + "l_receiptdate" : { + "distinctValuesCount" : 2529, + "nullsCount" : 0, + "min" : 8043, + "max" : 10585 + }, + "l_shipinstruct" : { + "distinctValuesCount" : 4, + "nullsCount" : 0, + "min" : "COLLECT COD", + "max" : "TAKE BACK RETURN" + }, + "l_shipmode" : { + "distinctValuesCount" : 7, + "nullsCount" : 0, + "min" : "AIR", + "max" : "TRUCK" + }, + "l_comment" : { 
+ "distinctValuesCount" : 58616, + "nullsCount" : 0, + "min" : " Tiresias ", + "max" : "zzle: pending i" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/nation.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/nation.json new file mode 100644 index 0000000000000..3feb6f136d7e4 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/nation.json @@ -0,0 +1,29 @@ +{ + "rowCount" : 25, + "columns" : { + "n_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "n_name" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : "ALGERIA", + "max" : "VIETNAM" + }, + "n_regionkey" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : 0, + "max" : 4 + }, + "n_comment" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : " haggle. carefully final deposits detect slyly agai", + "max" : "y final packages. slow foxes cajole quickly. quickly silent platelets breach ironic accounts. 
unusual pinto be" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.json new file mode 100644 index 0000000000000..2eb34e70809e1 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 15000, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 15000, + "nullsCount" : 0, + "min" : 1, + "max" : 60000 + }, + "o_custkey" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : 1, + "max" : 1499 + }, + "o_orderstatus" : { + "distinctValuesCount" : 3, + "nullsCount" : 0, + "min" : "F", + "max" : "P" + }, + "o_totalprice" : { + "distinctValuesCount" : 14996, + "nullsCount" : 0, + "min" : 874.89, + "max" : 466001.28 + }, + "o_orderdate" : { + "distinctValuesCount" : 2401, + "nullsCount" : 0, + "min" : 8035, + "max" : 10440 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 14995, + "nullsCount" : 0, + "min" : " about the accounts. slyly express accounts wa", + "max" : "zzle. 
carefully enticing deposits nag furio" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.F.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.F.json new file mode 100644 index 0000000000000..96f16239eab71 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.F.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 7304, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 7304, + "nullsCount" : 0, + "min" : 3, + "max" : 59975 + }, + "o_custkey" : { + "distinctValuesCount" : 996, + "nullsCount" : 0, + "min" : 1, + "max" : 1499 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "F", + "max" : "F" + }, + "o_totalprice" : { + "distinctValuesCount" : 7303, + "nullsCount" : 0, + "min" : 874.89, + "max" : 408345.74 + }, + "o_orderdate" : { + "distinctValuesCount" : 1213, + "nullsCount" : 0, + "min" : 8035, + "max" : 9277 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 7303, + "nullsCount" : 0, + "min" : " about the blithely ironic requests. fluffily ironic ", + "max" : "zzle final, final dependencies. 
final, final accounts are blith" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.O.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.O.json new file mode 100644 index 0000000000000..7c7658c774a25 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.O.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 7333, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 7333, + "nullsCount" : 0, + "min" : 1, + "max" : 59974 + }, + "o_custkey" : { + "distinctValuesCount" : 998, + "nullsCount" : 0, + "min" : 1, + "max" : 1499 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "O", + "max" : "O" + }, + "o_totalprice" : { + "distinctValuesCount" : 7331, + "nullsCount" : 0, + "min" : 974.04, + "max" : 466001.28 + }, + "o_orderdate" : { + "distinctValuesCount" : 1213, + "nullsCount" : 0, + "min" : 9197, + "max" : 10440 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 1000, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 7333, + "nullsCount" : 0, + "min" : " about the accounts. slyly express accounts wa", + "max" : "zzle. 
carefully enticing deposits nag furio" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.P.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.P.json new file mode 100644 index 0000000000000..33c99f2479410 --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/orders.o_orderstatus.P.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 363, + "columns" : { + "o_orderkey" : { + "distinctValuesCount" : 363, + "nullsCount" : 0, + "min" : 65, + "max" : 60000 + }, + "o_custkey" : { + "distinctValuesCount" : 304, + "nullsCount" : 0, + "min" : 16, + "max" : 1499 + }, + "o_orderstatus" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : "P", + "max" : "P" + }, + "o_totalprice" : { + "distinctValuesCount" : 363, + "nullsCount" : 0, + "min" : 16145.49, + "max" : 376904.18 + }, + "o_orderdate" : { + "distinctValuesCount" : 107, + "nullsCount" : 0, + "min" : 9182, + "max" : 9292 + }, + "o_orderpriority" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "1-URGENT", + "max" : "5-LOW" + }, + "o_clerk" : { + "distinctValuesCount" : 310, + "nullsCount" : 0, + "min" : "Clerk#000000001", + "max" : "Clerk#000001000" + }, + "o_shippriority" : { + "distinctValuesCount" : 1, + "nullsCount" : 0, + "min" : 0, + "max" : 0 + }, + "o_comment" : { + "distinctValuesCount" : 363, + "nullsCount" : 0, + "min" : " according to the final asymptotes. 
carefully silent de", + "max" : "yly pending platelets sleep c" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/part.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/part.json new file mode 100644 index 0000000000000..454c3f4478a2e --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/part.json @@ -0,0 +1,59 @@ +{ + "rowCount" : 2000, + "columns" : { + "p_partkey" : { + "distinctValuesCount" : 2000, + "nullsCount" : 0, + "min" : 1, + "max" : 2000 + }, + "p_name" : { + "distinctValuesCount" : 2000, + "nullsCount" : 0, + "min" : "almond aquamarine mint misty red", + "max" : "yellow white puff orange rosy" + }, + "p_mfgr" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "Manufacturer#1", + "max" : "Manufacturer#5" + }, + "p_brand" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : "Brand#11", + "max" : "Brand#55" + }, + "p_type" : { + "distinctValuesCount" : 150, + "nullsCount" : 0, + "min" : "ECONOMY ANODIZED BRASS", + "max" : "STANDARD POLISHED TIN" + }, + "p_size" : { + "distinctValuesCount" : 50, + "nullsCount" : 0, + "min" : 1, + "max" : 50 + }, + "p_container" : { + "distinctValuesCount" : 40, + "nullsCount" : 0, + "min" : "JUMBO BAG", + "max" : "WRAP PKG" + }, + "p_retailprice" : { + "distinctValuesCount" : 1099, + "nullsCount" : 0, + "min" : 901.0, + "max" : 1900.99 + }, + "p_comment" : { + "distinctValuesCount" : 1959, + "nullsCount" : 0, + "min" : " about the furio", + "max" : "zzle among t" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/partsupp.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/partsupp.json new file mode 100644 index 0000000000000..371690094399b --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/partsupp.json @@ -0,0 +1,35 @@ +{ + "rowCount" : 8000, + "columns" : { + "ps_partkey" : { + "distinctValuesCount" : 2000, + "nullsCount" : 0, + "min" : 1, + "max" : 2000 + }, + "ps_suppkey" : { + 
"distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : 1, + "max" : 100 + }, + "ps_availqty" : { + "distinctValuesCount" : 5497, + "nullsCount" : 0, + "min" : 3, + "max" : 9998 + }, + "ps_supplycost" : { + "distinctValuesCount" : 7665, + "nullsCount" : 0, + "min" : 1.05, + "max" : 999.99 + }, + "ps_comment" : { + "distinctValuesCount" : 8000, + "nullsCount" : 0, + "min" : " about the instructions. carefully final platelets cajole carefully furiously regular requests. carefully regular theodolites along the carefully regular pinto beans haggle about ", + "max" : "zzle blithely about the furiously final foxes. pen" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/region.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/region.json new file mode 100644 index 0000000000000..5b9fb841154ef --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/region.json @@ -0,0 +1,23 @@ +{ + "rowCount" : 5, + "columns" : { + "r_regionkey" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : 0, + "max" : 4 + }, + "r_name" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "AFRICA", + "max" : "MIDDLE EAST" + }, + "r_comment" : { + "distinctValuesCount" : 5, + "nullsCount" : 0, + "min" : "ges. thinly even pinto beans ca", + "max" : "uickly special accounts cajole carefully blithely close requests. 
carefully final asymptotes haggle furiousl" + } + } +} diff --git a/presto-tpch/src/main/resources/tpch/statistics/tiny/supplier.json b/presto-tpch/src/main/resources/tpch/statistics/tiny/supplier.json new file mode 100644 index 0000000000000..497fc66357eea --- /dev/null +++ b/presto-tpch/src/main/resources/tpch/statistics/tiny/supplier.json @@ -0,0 +1,47 @@ +{ + "rowCount" : 100, + "columns" : { + "s_suppkey" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : 1, + "max" : 100 + }, + "s_name" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : "Supplier#000000001", + "max" : "Supplier#000000100" + }, + "s_address" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : " N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ", + "max" : "zyIeWzbbpkTV37vm1nmSGBxSgd2Kp" + }, + "s_nationkey" : { + "distinctValuesCount" : 25, + "nullsCount" : 0, + "min" : 0, + "max" : 24 + }, + "s_phone" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : "10-470-144-1330", + "max" : "34-876-912-6007" + }, + "s_acctbal" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : -966.2, + "max" : 9915.24 + }, + "s_comment" : { + "distinctValuesCount" : 100, + "nullsCount" : 0, + "min" : " across the furiously regular platelets wake even deposits. quickly express she", + "max" : "yly final accounts could are carefully. fluffily ironic instruct" + } + } +} diff --git a/presto-tpch/src/test/java/com/facebook/presto/tpch/EstimateAssertion.java b/presto-tpch/src/test/java/com/facebook/presto/tpch/EstimateAssertion.java new file mode 100644 index 0000000000000..fe68a372f4ea2 --- /dev/null +++ b/presto-tpch/src/test/java/com/facebook/presto/tpch/EstimateAssertion.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch; + +import com.facebook.presto.spi.statistics.Estimate; +import io.airlift.slice.Slice; + +import java.util.Optional; + +import static java.lang.String.format; +import static java.util.Optional.empty; +import static org.testng.Assert.assertEquals; + +class EstimateAssertion +{ + private final double tolerance; + + public EstimateAssertion(double tolerance) + { + this.tolerance = tolerance; + } + + public void assertClose(Estimate actual, Estimate expected, String message) + { + assertClose(toOptional(actual), toOptional(expected), message); + } + + private Optional toOptional(Estimate estimate) + { + return estimate.isValueUnknown() ? 
empty() : Optional.of(estimate.getValue()); + } + + public void assertClose(Optional actual, Optional expected, String message) + { + assertEquals(actual.isPresent(), expected.isPresent(), message); + if (actual.isPresent()) { + Object actualValue = actual.get(); + Object expectedValue = expected.get(); + assertClose(actualValue, expectedValue, message); + } + } + + private void assertClose(Object actual, Object expected, String message) + { + if (actual instanceof Slice) { + assertEquals(actual.getClass(), expected.getClass(), message); + assertEquals(((Slice) actual).toStringUtf8(), ((Slice) expected).toStringUtf8()); + } + else { + double actualDouble = toDouble(actual); + double expectedDouble = toDouble(expected); + assertEquals(actualDouble, expectedDouble, expectedDouble * tolerance, message); + } + } + + private double toDouble(Object object) + { + if (object instanceof Number) { + return ((Number) object).doubleValue(); + } + else { + String message = "Can't compare with tolerance objects of class %s. Use assertEquals."; + throw new UnsupportedOperationException(format(message, object.getClass())); + } + } +} diff --git a/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchMetadata.java b/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchMetadata.java new file mode 100644 index 0000000000000..25c7f7a2d9a8b --- /dev/null +++ b/presto-tpch/src/test/java/com/facebook/presto/tpch/TestTpchMetadata.java @@ -0,0 +1,351 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchTable; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.spi.Constraint.alwaysFalse; +import static com.facebook.presto.spi.Constraint.alwaysTrue; +import static com.facebook.presto.spi.predicate.TupleDomain.fromFixedValues; +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static com.facebook.presto.spi.statistics.Estimate.zeroValue; +import static com.facebook.presto.tpch.TpchRecordSet.convertToPredicate; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.tpch.CustomerColumn.ADDRESS; +import static io.airlift.tpch.CustomerColumn.CUSTOMER_KEY; +import static io.airlift.tpch.CustomerColumn.MARKET_SEGMENT; +import static io.airlift.tpch.CustomerColumn.NAME; +import static io.airlift.tpch.LineItemColumn.COMMIT_DATE; +import static io.airlift.tpch.LineItemColumn.DISCOUNT; +import static io.airlift.tpch.LineItemColumn.EXTENDED_PRICE; +import static io.airlift.tpch.LineItemColumn.LINE_NUMBER; +import static io.airlift.tpch.LineItemColumn.QUANTITY; +import static io.airlift.tpch.LineItemColumn.RECEIPT_DATE; +import static io.airlift.tpch.LineItemColumn.RETURN_FLAG; +import static 
io.airlift.tpch.LineItemColumn.SHIP_DATE; +import static io.airlift.tpch.LineItemColumn.SHIP_INSTRUCTIONS; +import static io.airlift.tpch.LineItemColumn.SHIP_MODE; +import static io.airlift.tpch.LineItemColumn.STATUS; +import static io.airlift.tpch.LineItemColumn.TAX; +import static io.airlift.tpch.NationColumn.NATION_KEY; +import static io.airlift.tpch.OrderColumn.CLERK; +import static io.airlift.tpch.OrderColumn.ORDER_DATE; +import static io.airlift.tpch.OrderColumn.ORDER_KEY; +import static io.airlift.tpch.OrderColumn.ORDER_PRIORITY; +import static io.airlift.tpch.OrderColumn.ORDER_STATUS; +import static io.airlift.tpch.OrderColumn.SHIP_PRIORITY; +import static io.airlift.tpch.OrderColumn.TOTAL_PRICE; +import static io.airlift.tpch.PartColumn.BRAND; +import static io.airlift.tpch.PartColumn.CONTAINER; +import static io.airlift.tpch.PartColumn.MANUFACTURER; +import static io.airlift.tpch.PartColumn.PART_KEY; +import static io.airlift.tpch.PartColumn.RETAIL_PRICE; +import static io.airlift.tpch.PartColumn.SIZE; +import static io.airlift.tpch.PartColumn.TYPE; +import static io.airlift.tpch.PartSupplierColumn.AVAILABLE_QUANTITY; +import static io.airlift.tpch.PartSupplierColumn.COMMENT; +import static io.airlift.tpch.RegionColumn.REGION_KEY; +import static io.airlift.tpch.SupplierColumn.ACCOUNT_BALANCE; +import static io.airlift.tpch.SupplierColumn.SUPPLIER_KEY; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; +import static io.airlift.tpch.TpchTable.PART; +import static io.airlift.tpch.TpchTable.PART_SUPPLIER; +import static io.airlift.tpch.TpchTable.REGION; +import static io.airlift.tpch.TpchTable.SUPPLIER; +import static java.util.Arrays.stream; +import static java.util.Optional.empty; +import static java.util.stream.Collectors.toList; +import static org.testng.Assert.assertEquals; + +public class TestTpchMetadata 
+{ + private static final double TOLERANCE = 0.01; + + private final TpchMetadata tpchMetadata = new TpchMetadata("tpch"); + private final ConnectorSession session = null; + + @Test + public void testTableStats() + { + TpchMetadata.SCHEMA_NAMES.forEach(schema -> { + double scaleFactor = TpchMetadata.schemaNameToScaleFactor(schema); + + testTableStats(schema, REGION, 5); + testTableStats(schema, NATION, 25); + testTableStats(schema, SUPPLIER, 10_000 * scaleFactor); + testTableStats(schema, CUSTOMER, 150_000 * scaleFactor); + testTableStats(schema, PART, 200_000 * scaleFactor); + testTableStats(schema, PART_SUPPLIER, 800_000 * scaleFactor); + testTableStats(schema, ORDERS, 1_500_000 * scaleFactor); + testTableStats(schema, LINE_ITEM, 6_000_000 * scaleFactor); + }); + } + + @Test + public void testTableStatsWithConstraints() + { + TpchMetadata.SCHEMA_NAMES.forEach(schema -> { + double scaleFactor = TpchMetadata.schemaNameToScaleFactor(schema); + + testTableStats(schema, ORDERS, alwaysFalse(), 0); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "NO SUCH STATUS"), 0); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "F"), 730_400 * scaleFactor); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "O"), 733_300 * scaleFactor); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "P"), 38_543 * scaleFactor); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "F", "NO SUCH STATUS"), 730_400 * scaleFactor); + testTableStats(schema, ORDERS, constraint(ORDER_STATUS, "F", "O", "P"), 1_500_000 * scaleFactor); + }); + } + + private void testTableStats(String schema, TpchTable table, double expectedRowCount) + { + testTableStats(schema, table, alwaysTrue(), expectedRowCount); + } + + private void testTableStats(String schema, TpchTable table, Constraint constraint, double expectedRowCount) + { + TpchTableHandle tableHandle = tpchMetadata.getTableHandle(session, new SchemaTableName(schema, table.getTableName())); + TableStatistics 
tableStatistics = tpchMetadata.getTableStatistics(session, tableHandle, constraint); + + double actualRowCountValue = tableStatistics.getRowCount().getValue(); + assertEquals(tableStatistics.getTableStatistics(), ImmutableMap.of("row_count", new Estimate(actualRowCountValue))); + assertEquals(actualRowCountValue, expectedRowCount, expectedRowCount * TOLERANCE); + } + + @Test + public void testColumnStats() + { + TpchMetadata.SCHEMA_NAMES.forEach(schema -> { + double scaleFactor = TpchMetadata.schemaNameToScaleFactor(schema); + + //id columns + testColumnStats(schema, REGION, REGION_KEY, columnStatistics(5, 0, 4)); + testColumnStats(schema, NATION, NATION_KEY, columnStatistics(25, 0, 24)); + testColumnStats(schema, SUPPLIER, SUPPLIER_KEY, columnStatistics(10_000 * scaleFactor, 1, 10_000 * scaleFactor)); + testColumnStats(schema, CUSTOMER, CUSTOMER_KEY, columnStatistics(150_000 * scaleFactor, 1, 150_000 * scaleFactor)); + testColumnStats(schema, PART, PART_KEY, columnStatistics(200_000 * scaleFactor, 1, 200_000 * scaleFactor)); + testColumnStats(schema, ORDERS, ORDER_KEY, columnStatistics(1_500_000 * scaleFactor, 1, 6_000_000 * scaleFactor)); + + //foreign keys to dictionary identifier columns + testColumnStats(schema, NATION, REGION_KEY, columnStatistics(5, 0, 4)); + testColumnStats(schema, SUPPLIER, NATION_KEY, columnStatistics(25, 0, 24)); + + //foreign keys to scalable identifier columns + testColumnStats(schema, PART_SUPPLIER, SUPPLIER_KEY, columnStatistics(10_000 * scaleFactor, 1, 10_000 * scaleFactor)); + testColumnStats(schema, PART_SUPPLIER, PART_KEY, columnStatistics(200_000 * scaleFactor, 1, 200_000 * scaleFactor)); + + //semi-uniquely valued varchar columns + testColumnStats(schema, PART_SUPPLIER, COMMENT, columnStatistics(800_000 * scaleFactor)); + testColumnStats(schema, CUSTOMER, NAME, columnStatistics(150_000 * scaleFactor)); + testColumnStats(schema, CUSTOMER, ADDRESS, columnStatistics(150_000 * scaleFactor)); + testColumnStats(schema, CUSTOMER, 
COMMENT, columnStatistics(150_000 * scaleFactor)); + + //non-scalable columns: + //dictionaries: + testColumnStats(schema, CUSTOMER, MARKET_SEGMENT, columnStatistics(5, "AUTOMOBILE", "MACHINERY")); + testColumnStats(schema, ORDERS, CLERK, columnStatistics(1000, "Clerk#000000001", "Clerk#000001000")); + testColumnStats(schema, ORDERS, ORDER_STATUS, columnStatistics(3, "F", "P")); + testColumnStats(schema, ORDERS, ORDER_PRIORITY, columnStatistics(5, "1-URGENT", "5-LOW")); + testColumnStats(schema, PART, BRAND, columnStatistics(25, "Brand#11", "Brand#55")); + testColumnStats(schema, PART, CONTAINER, columnStatistics(40, "JUMBO BAG", "WRAP PKG")); + testColumnStats(schema, PART, MANUFACTURER, columnStatistics(5, "Manufacturer#1", "Manufacturer#5")); + testColumnStats(schema, PART, SIZE, columnStatistics(50, 1, 50)); + testColumnStats(schema, PART, TYPE, columnStatistics(150, "ECONOMY ANODIZED BRASS", "STANDARD POLISHED TIN")); + testColumnStats(schema, LINE_ITEM, RETURN_FLAG, columnStatistics(3, "A", "R")); + testColumnStats(schema, LINE_ITEM, SHIP_INSTRUCTIONS, columnStatistics(4, "COLLECT COD", "TAKE BACK RETURN")); + testColumnStats(schema, LINE_ITEM, SHIP_MODE, columnStatistics(7, "AIR", "TRUCK")); + testColumnStats(schema, LINE_ITEM, STATUS, columnStatistics(2, "F", "O")); + + //low-valued numeric columns + testColumnStats(schema, ORDERS, SHIP_PRIORITY, columnStatistics(1, 0, 0)); + testColumnStats(schema, LINE_ITEM, LINE_NUMBER, columnStatistics(7, 1, 7)); + testColumnStats(schema, LINE_ITEM, QUANTITY, columnStatistics(50, 1, 50)); + testColumnStats(schema, LINE_ITEM, DISCOUNT, columnStatistics(11, 0, 0.1)); + testColumnStats(schema, LINE_ITEM, TAX, columnStatistics(9, 0, 0.08)); + + //dates: + testColumnStats(schema, ORDERS, ORDER_DATE, columnStatistics(2_400, 8_035, 10_440)); + testColumnStats(schema, LINE_ITEM, COMMIT_DATE, columnStatistics(2_450, 8_035, 10_500)); + testColumnStats(schema, LINE_ITEM, SHIP_DATE, columnStatistics(2_525, 8_035, 10_500)); + 
testColumnStats(schema, LINE_ITEM, RECEIPT_DATE, columnStatistics(2_550, 8_035, 10_500)); + + //AVAILABLE_QUANTITY and all money-related columns have quite visible non-scalable min and max + //but their ndv reaches a plateau for bigger SFs because of the data type used + //for this reason, those can't be estimated easily + testColumnStats(schema, PART_SUPPLIER, AVAILABLE_QUANTITY, unknownStatistics()); + testColumnStats(schema, PART, RETAIL_PRICE, unknownStatistics()); + testColumnStats(schema, LINE_ITEM, EXTENDED_PRICE, unknownStatistics()); + testColumnStats(schema, ORDERS, TOTAL_PRICE, unknownStatistics()); + testColumnStats(schema, SUPPLIER, ACCOUNT_BALANCE, unknownStatistics()); + testColumnStats(schema, CUSTOMER, ACCOUNT_BALANCE, unknownStatistics()); + }); + } + + @Test + public void testColumnStatsWithConstraints() + { + TpchMetadata.SCHEMA_NAMES.forEach(schema -> { + double scaleFactor = TpchMetadata.schemaNameToScaleFactor(schema); + + //value count, min and max are supported for the constrained column + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "F"), columnStatistics(1, "F", "F")); + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "O"), columnStatistics(1, "O", "O")); + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "P"), columnStatistics(1, "P", "P")); + + //only min and max values for non-scaling columns can be estimated for non-constrained columns + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "F"), rangeStatistics(3, 6_000_000 * scaleFactor)); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "O"), rangeStatistics(1, 6_000_000 * scaleFactor)); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "P"), rangeStatistics(65, 6_000_000 * scaleFactor)); + testColumnStats(schema, ORDERS, CLERK, constraint(ORDER_STATUS, "O"), rangeStatistics("Clerk#000000001", "Clerk#000001000")); + testColumnStats(schema, ORDERS, COMMENT, 
constraint(ORDER_STATUS, "O"), unknownStatistics()); + + //nothing can be said for always false constraints + testColumnStats(schema, ORDERS, ORDER_STATUS, alwaysFalse(), columnStatistics(0)); + testColumnStats(schema, ORDERS, ORDER_KEY, alwaysFalse(), columnStatistics(0)); + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "NO SUCH STATUS"), columnStatistics(0)); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "NO SUCH STATUS"), columnStatistics(0)); + + //unmodified stats are returned for the always true constraint + testColumnStats(schema, ORDERS, ORDER_STATUS, alwaysTrue(), columnStatistics(3, "F", "P")); + testColumnStats(schema, ORDERS, ORDER_KEY, alwaysTrue(), columnStatistics(1_500_000 * scaleFactor, 1, 6_000_000 * scaleFactor)); + + //constraints on columns other than ORDER_STATUS are not supported and are ignored + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(CLERK, "NO SUCH CLERK"), columnStatistics(3, "F", "P")); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(CLERK, "Clerk#000000001"), columnStatistics(1_500_000 * scaleFactor, 1, 6_000_000 * scaleFactor)); + + //compound constraints are supported + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "F", "NO SUCH STATUS"), columnStatistics(1, "F", "F")); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "F", "NO SUCH STATUS"), rangeStatistics(3, 6_000_000 * scaleFactor)); + + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "F", "O"), columnStatistics(2, "F", "O")); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "F", "O"), rangeStatistics(1, 6_000_000 * scaleFactor)); + + testColumnStats(schema, ORDERS, ORDER_STATUS, constraint(ORDER_STATUS, "F", "O", "P"), columnStatistics(3, "F", "P")); + testColumnStats(schema, ORDERS, ORDER_KEY, constraint(ORDER_STATUS, "F", "O", "P"), columnStatistics(1_500_000 * scaleFactor, 1, 6_000_000 * scaleFactor)); + }); + } + + 
private void testColumnStats(String schema, TpchTable table, TpchColumn column, ColumnStatistics expectedStatistics) + { + testColumnStats(schema, table, column, alwaysTrue(), expectedStatistics); + } + + private void testColumnStats(String schema, TpchTable table, TpchColumn column, Constraint constraint, ColumnStatistics expected) + { + TpchTableHandle tableHandle = tpchMetadata.getTableHandle(session, new SchemaTableName(schema, table.getTableName())); + TableStatistics tableStatistics = tpchMetadata.getTableStatistics(session, tableHandle, constraint); + ColumnHandle columnHandle = tpchMetadata.getColumnHandles(session, tableHandle).get(column.getSimplifiedColumnName()); + + ColumnStatistics actual = tableStatistics.getColumnStatistics().get(columnHandle); + + EstimateAssertion estimateAssertion = new EstimateAssertion(TOLERANCE); + + estimateAssertion.assertClose( + actual.getOnlyRangeColumnStatistics().getDistinctValuesCount(), + expected.getOnlyRangeColumnStatistics().getDistinctValuesCount(), + "distinctValuesCount-s differ"); + estimateAssertion.assertClose( + actual.getOnlyRangeColumnStatistics().getDataSize(), + expected.getOnlyRangeColumnStatistics().getDataSize(), + "dataSize-s differ"); + estimateAssertion.assertClose( + actual.getNullsFraction(), + expected.getNullsFraction(), + "nullsFraction-s differ"); + estimateAssertion.assertClose( + actual.getOnlyRangeColumnStatistics().getLowValue(), + expected.getOnlyRangeColumnStatistics().getLowValue(), + "lowValue-s differ"); + estimateAssertion.assertClose( + actual.getOnlyRangeColumnStatistics().getHighValue(), + expected.getOnlyRangeColumnStatistics().getHighValue(), + "highValue-s differ"); + } + + private Constraint constraint(TpchColumn column, String... 
values) + { + List> valueDomains = stream(values) + .map(value -> fromFixedValues(valueBinding(column, value))) + .collect(toList()); + TupleDomain domain = TupleDomain.columnWiseUnion(valueDomains); + return new Constraint<>(domain, convertToPredicate(domain)); + } + + private ImmutableMap valueBinding(TpchColumn column, String value) + { + return ImmutableMap.of( + tpchMetadata.toColumnHandle(column), + new NullableValue(TpchMetadata.getPrestoType(column), utf8Slice(value))); + } + + private ColumnStatistics columnStatistics(double distinctValuesCount) + { + return createColumnStatistics(Optional.of(distinctValuesCount), empty(), empty()); + } + + private ColumnStatistics columnStatistics(double distinctValuesCount, String min, String max) + { + return createColumnStatistics(Optional.of(distinctValuesCount), Optional.of(utf8Slice(min)), Optional.of(utf8Slice(max))); + } + + private ColumnStatistics columnStatistics(double distinctValuesCount, double min, double max) + { + return createColumnStatistics(Optional.of(distinctValuesCount), Optional.of(min), Optional.of(max)); + } + + private ColumnStatistics rangeStatistics(String min, String max) + { + return createColumnStatistics(empty(), Optional.of(utf8Slice(min)), Optional.of(utf8Slice(max))); + } + + private ColumnStatistics rangeStatistics(double min, double max) + { + return createColumnStatistics(empty(), Optional.of(min), Optional.of(max)); + } + + private ColumnStatistics unknownStatistics() + { + return createColumnStatistics(empty(), empty(), empty()); + } + + private ColumnStatistics createColumnStatistics(Optional distinctValuesCount, Optional min, Optional max) + { + return ColumnStatistics.builder() + .addRange(rb -> rb + .setDistinctValuesCount(distinctValuesCount.map(Estimate::new).orElse(unknownValue())) + .setLowValue(min) + .setHighValue(max) + .setFraction(new Estimate(1.0))) + .setNullsFraction(zeroValue()) + .build(); + } +} diff --git 
a/presto-tpch/src/test/java/com/facebook/presto/tpch/statistics/RecordTpchTableStatsTool.java b/presto-tpch/src/test/java/com/facebook/presto/tpch/statistics/RecordTpchTableStatsTool.java new file mode 100644 index 0000000000000..d4135ba11e87e --- /dev/null +++ b/presto-tpch/src/test/java/com/facebook/presto/tpch/statistics/RecordTpchTableStatsTool.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tpch.statistics; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.common.collect.ImmutableList; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchEntity; +import io.airlift.tpch.TpchTable; + +import java.io.IOException; +import java.util.Optional; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static com.facebook.presto.tpch.TpchMetadata.schemaNameToScaleFactor; +import static io.airlift.tpch.OrderColumn.ORDER_STATUS; +import static io.airlift.tpch.TpchTable.ORDERS; +import static java.lang.String.format; +import static java.util.Optional.empty; + +/** + * This is a tool used to record statistics for TPCH tables. + *

+ * The results are output to {@code presto-tpch/src/main/resources/tpch/statistics/${schemaName}} directory. + *

+ * The tool is run by invoking its {@code main} method. + */ +public class RecordTpchTableStatsTool +{ + private final TableStatisticsRecorder tableStatisticsRecorder; + private final TableStatisticsDataRepository tableStatisticsDataRepository; + + public RecordTpchTableStatsTool(TableStatisticsRecorder tableStatisticsRecorder, TableStatisticsDataRepository tableStatisticsDataRepository) + { + this.tableStatisticsRecorder = tableStatisticsRecorder; + this.tableStatisticsDataRepository = tableStatisticsDataRepository; + } + + public static void main(String[] args) + throws IOException + { + RecordTpchTableStatsTool tool = new RecordTpchTableStatsTool(new TableStatisticsRecorder(), new TableStatisticsDataRepository(createObjectMapper())); + + ImmutableList.of("tiny", "sf1").forEach(schemaName -> { + TpchTable.getTables() + .forEach(table -> tool.computeAndOutputStatsFor(schemaName, table)); + + Stream.of("F", "O", "P").forEach(partitionValue -> { + tool.computeAndOutputStatsFor(schemaName, ORDERS, ORDER_STATUS, partitionValue); + }); + }); + } + + private static ObjectMapper createObjectMapper() + { + return new ObjectMapper() + .registerModule(new Jdk8Module()); + } + + private void computeAndOutputStatsFor(String schemaName, TpchTable table) + { + computeAndOutputStatsFor(schemaName, table, row -> true, empty(), empty()); + } + + private void computeAndOutputStatsFor(String schemaName, TpchTable table, TpchColumn partitionColumn, String partitionValue) + { + Predicate predicate = row -> partitionColumn.getString(row).equals(partitionValue); + computeAndOutputStatsFor(schemaName, table, predicate, Optional.of(partitionColumn), Optional.of(partitionValue)); + } + + private void computeAndOutputStatsFor(String schemaName, TpchTable table, Predicate predicate, Optional> partitionColumn, Optional partitionValue) + { + double scaleFactor = schemaNameToScaleFactor(schemaName); + + long start = System.nanoTime(); + + TableStatisticsData statisticsData = 
tableStatisticsRecorder.recordStatistics(table, predicate, scaleFactor); + + long duration = (System.nanoTime() - start) / 1_000_000; + System.out.println(format("Finished stats recording for %s[%s] sf %s, took %s ms", table.getTableName(), partitionValue.orElse(""), scaleFactor, duration)); + + tableStatisticsDataRepository.save(schemaName, table, partitionColumn, partitionValue, statisticsData); + } +}