diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index ee4a479960145..fb18dca97994e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -22,20 +22,21 @@ import org.apache.logging.log4j.Level import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sample, Sort} +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.NullOrdering -import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @DockerTest -private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite { +private[v2] trait V2JDBCTest + extends DataSourcePushdownTestUtils + with DockerIntegrationFunSuite + with SharedSparkSession { import testImplicits._ val catalogName: String @@ -468,56 +469,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def supportsTableSample: Boolean = false - private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { - val sample = df.queryExecution.optimizedPlan.collect { - case s: Sample => s - } - if (pushed) { - assert(sample.isEmpty) - } else { - assert(sample.nonEmpty) - } - } - - protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { - val filter = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - if (pushed) { - assert(filter.isEmpty) - } else { - assert(filter.nonEmpty) - } - } - - protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val limit = df.queryExecution.optimizedPlan.collect { - case l: LocalLimit => l - case g: GlobalLimit => g - } - if (pushed) { - assert(limit.isEmpty) - } else { - assert(limit.nonEmpty) - } - } - - private def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit == limit) - } - } - } - - private def checkColumnPruned(df: DataFrame, col: String): Unit = { - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq(col))) - } - test("SPARK-48172: Test CONTAINS") { val df1 = spark.sql( s""" @@ -841,39 +792,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val sorts = df.queryExecution.optimizedPlan.collect { - 
case s: Sort => s - } - - if (pushed) { - assert(sorts.isEmpty) - } else { - assert(sorts.nonEmpty) - } - } - - private def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val offsets = df.queryExecution.optimizedPlan.collect { - case o: Offset => o - } - - if (pushed) { - assert(offsets.isEmpty) - } else { - assert(offsets.nonEmpty) - } - } - - private def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.offset == offset) - } - } - } - gridTest("simple scan")(partitioningEnabledTestCase) { partitioningEnabled => val (tableOptions, partitionInfo) = getTableOptions("employee", partitioningEnabled) val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + @@ -1028,27 +946,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkAggregateRemoved(df: DataFrame): Unit = { - val aggregates = df.queryExecution.optimizedPlan.collect { - case agg: Aggregate => agg - } - assert(aggregates.isEmpty) - } - - private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _, _, _) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions.length == 1) - assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) - assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) - } - } - protected def caseConvert(tableName: String): String = tableName Seq(true, false).foreach { isDistinct => diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala new file mode 100644 index 0000000000000..ecc0c5489bceb --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2.join + +import java.sql.Connection +import java.util.Locale + +import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} +import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.tags.DockerTest + +/** + * The following are the steps to test this: + * + * 1. Choose to use a prebuilt image or build Oracle database in a container + * - The documentation on how to build Oracle RDBMS in a container is at + * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md + * - Official Oracle container images can be found at https://container-registry.oracle.com + * - Trustable and streamlined Oracle Database Free images can be found on Docker Hub at + * https://hub.docker.com/r/gvenzl/oracle-free + * see also https://github.com/gvenzl/oci-oracle-free + * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing + * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-free:latest + * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1 + * 4. Start docker: sudo service docker start + * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME + * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.v2.OracleIntegrationSuite" + * + * A sequence of commands to build the Oracle Database Free container image: + * $ git clone https://github.com/oracle/docker-images.git + * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles0 + * $ ./buildContainerImage.sh -v 23.4.0 -f + * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:23.4.0-free + * + * This procedure has been validated with Oracle Database Free version 23.4.0, + * and with Oracle Express Edition versions 18.4.0 and 21.4.0 + */ +@DockerTest +class OracleJoinPushdownIntegrationSuite + extends DockerJDBCIntegrationSuite + with JDBCV2JoinPushdownIntegrationSuiteBase { + override val namespace: String = "SYSTEM" + + override val db = new OracleDatabaseOnDocker + + override val url = db.getJdbcUrl(dockerIp, externalPort) + + override val jdbcDialect: JdbcDialect = OracleDialect() + + override val integerType = DataTypes.createDecimalType(10, 0) + + override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT) + + override def schemaPreparation(): Unit = {} + + // This method comes from DockerJDBCIntegrationSuite + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala index 97ec093a0e297..d0f53fa20e1aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala @@ -208,7 +208,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) { // If join has been pushed down, reuse join query as a subquery. Otherwise, fallback to // what is provided in options. - private def tableOrQuery = joinQuery.getOrElse(options.tableOrQuery) + protected final def tableOrQuery: String = joinQuery.getOrElse(options.tableOrQuery) /** * Build the final SQL query that following dialect's SQL syntax. 
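Note on the JdbcSQLQueryBuilder change above: widening tableOrQuery from private to protected final lets dialect-specific builders (such as the Oracle one in the next hunk) compose their SELECT on top of the pushed-down join subquery instead of always reading options.tableOrQuery. A minimal sketch of how a subclass is expected to use it, not part of this patch: the MyDialectSQLQueryBuilder name is hypothetical, while hintClause, columnList, tableSampleClause, whereClause, groupByClause and orderByClause are the builder members already referenced by the Oracle override below.

package org.apache.spark.sql.jdbc

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

// Hypothetical dialect builder: because tableOrQuery is now protected final,
// the FROM clause transparently becomes the pushed-down join subquery when one
// exists, and falls back to options.tableOrQuery otherwise.
class MyDialectSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
  extends JdbcSQLQueryBuilder(dialect, options) {

  override def build(): String = {
    // Simplified sketch: real dialects also wrap limit/offset, as OracleDialect does.
    s"SELECT $hintClause$columnList FROM $tableOrQuery" +
      s" $tableSampleClause $whereClause $groupByClause $orderByClause"
  }
}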
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index a9f6a727a7241..0c9c84f3f3e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -233,7 +233,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N extends JdbcSQLQueryBuilder(dialect, options) { override def build(): String = { - val selectStmt = s"SELECT $hintClause$columnList FROM ${options.tableOrQuery}" + + val selectStmt = s"SELECT $hintClause$columnList FROM $tableOrQuery" + s" $tableSampleClause $whereClause $groupByClause $orderByClause" val finalSelectStmt = if (limit > 0) { if (offset > 0) { @@ -268,6 +268,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N override def supportsHint: Boolean = true + override def supportsJoin: Boolean = true + override def classifyException( e: Throwable, condition: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala new file mode 100644 index 0000000000000..2816eb79f1f82 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { + protected val supportsSamplePushdown: Boolean = true + + protected val supportsFilterPushdown: Boolean = true + + protected val supportsLimitPushdown: Boolean = true + + protected val supportsAggregatePushdown: Boolean = true + + protected val supportsSortPushdown: Boolean = true + + protected val supportsOffsetPushdown: Boolean = true + + protected val supportsColumnPruning: Boolean = true + + protected val supportsJoinPushdown: Boolean = true + + + protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { + if (supportsSamplePushdown) { + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } + } + } + + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + if (supportsFilterPushdown) { + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } + } + } + + protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { + if (supportsLimitPushdown) { + val limit = df.queryExecution.optimizedPlan.collect { + case l: LocalLimit => l + case g: GlobalLimit => g + } + if (pushed) { + assert(limit.isEmpty) + } else { + assert(limit.nonEmpty) + } + } + } + + protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { + if (supportsLimitPushdown) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.limit == limit) + } + } + } + } + + protected def checkColumnPruned(df: DataFrame, col: String): Unit = { + if (supportsColumnPruning) { + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq(col))) + } + } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + if (supportsAggregatePushdown) { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + } + + protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { + if (supportsAggregatePushdown) { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _, _, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.exists { expr => + expr.isInstanceOf[GeneralAggregateFunc] && + expr.asInstanceOf[GeneralAggregateFunc].name() == funcName + }) + } + } + } + + protected def checkSortRemoved( + df: DataFrame, + pushed: Boolean = true): Unit = { + if (supportsSortPushdown) { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + + if (pushed) { + assert(sorts.isEmpty) + } else 
{ + assert(sorts.nonEmpty) + } + } + } + + protected def checkOffsetRemoved( + df: DataFrame, + pushed: Boolean = true): Unit = { + if (supportsOffsetPushdown) { + val offsets = df.queryExecution.optimizedPlan.collect { + case o: Offset => o + } + + if (pushed) { + assert(offsets.isEmpty) + } else { + assert(offsets.nonEmpty) + } + } + } + + protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { + if (supportsOffsetPushdown) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.offset == offset) + } + } + } + } + + protected def checkJoinNotPushed(df: DataFrame): Unit = { + if (supportsJoinPushdown) { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.nonEmpty, "Join should not be pushed down") + } + } + + protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { + if (supportsJoinPushdown) { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.isEmpty, "Join should be pushed down") + if (expectedTables.nonEmpty) { + checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") + } + } + } + + protected def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { + withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) + } + } + } + + /** + * Check if the output schema of dataframe {@code df} is same as {@code schema}. There is one + * limitation: if expected schema name is empty, assertion on same names will be skipped. + *
+ * For example, it is not really possible to use {@code checkPrunedColumns} for join pushdown, + * because in case of duplicate names, columns will have random UUID suffixes. For this reason, + * the best we can do is test that the size is same, and other fields beside names do match. + */ + protected def checkPrunedColumnsDataTypeAndNullability( + df: DataFrame, + schema: StructType): Unit = { + if (supportsColumnPruning) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + val dfSchema = v1.readSchema() + + assert(dfSchema.length == schema.length) + dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) => + if (f2.name.nonEmpty) { + assert(f1.name == f2.name) + } + assert(f1.dataType == f2.dataType) + assert(f1.nullable == f2.nullable) + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala deleted file mode 100644 index b77e905fea5d0..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import org.apache.spark.SparkConf -import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, Join, Sort} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation -import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.util.Utils - -class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with ExplainSuiteHelper { - val tempDir = Utils.createTempDir() - val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" - - override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) - .set("spark.sql.catalog.h2.url", url) - .set("spark.sql.catalog.h2.driver", "org.h2.Driver") - .set("spark.sql.catalog.h2.pushDownAggregate", "true") - .set("spark.sql.catalog.h2.pushDownLimit", "true") - .set("spark.sql.catalog.h2.pushDownOffset", "true") - .set("spark.sql.catalog.h2.pushDownJoin", "true") - - private def withConnection[T](f: Connection => T): T = { - val conn = DriverManager.getConnection(url, new Properties()) - try { - f(conn) - } finally { - conn.close() - } - } - - override def beforeAll(): Unit = { - super.beforeAll() - Utils.classForName("org.h2.Driver") - withConnection { conn => - conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"people\" (name TEXT(32) NOT NULL, id INTEGER NOT NULL)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate() - - // scalastyle:off - conn.prepareStatement( - "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate() - // scalastyle:on - conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate() - conn.prepareStatement( - """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() - conn.prepareStatement( - """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" 
INTEGER)""").executeUpdate() - } - } - - override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) - super.afterAll() - } - - private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { - withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) - } - } - } - - // Conditionless joins are not supported in join pushdown - test("Test that 2-way join without condition should not have join pushed down") { - val sqlQuery = "SELECT * FROM h2.test.employee a, h2.test.employee b" - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - // Conditionless joins are not supported in join pushdown - test("Test that multi-way join without condition should not have join pushed down") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a, - |h2.test.employee b, - |h2.test.employee c - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - test("Test self join with condition") { - val sqlQuery = "SELECT * FROM h2.test.employee a JOIN h2.test.employee b ON a.dept = b.dept + 1" - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test multi-way self join with conditions") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - assert(!rows.isEmpty) - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test self join with column pruning") { - val sqlQuery = """ - |SELECT a.dept + 2, b.dept, b.salary FROM - |h2.test.employee a JOIN h2.test.employee b - |ON a.dept = b.dept + 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") 
- checkAnswer(df, rows) - } - } - - test("Test 2-way join with column pruning - different tables") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a JOIN h2.test.people b - |ON a.dept = b.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.people]") - checkPushedInfo(df, - "PushedFilters: [DEPT IS NOT NULL, ID IS NOT NULL, DEPT = ID]") - checkAnswer(df, rows) - } - } - - test("Test multi-way self join with column pruning") { - val sqlQuery = """ - |SELECT a.dept, b.*, c.dept, c.salary + a.salary - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test aliases not supported in join pushdown") { - val sqlQuery = """ - |SELECT a.dept, bc.* - |FROM h2.test.employee a - |JOIN ( - | SELECT b.*, c.dept AS c_dept, c.salary AS c_salary - | FROM h2.test.employee b - | JOIN h2.test.employee c ON c.dept = b.dept - 1 - |) bc ON bc.dept = a.dept + 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - test("Test join with dataframe with duplicated columns") { - val df1 = sql("SELECT dept FROM h2.test.employee") - val df2 = sql("SELECT dept, dept FROM h2.test.employee") - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - df1.join(df2, "dept").collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val joinDf = df1.join(df2, "dept") - val joinNodes = joinDf.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(joinDf, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(joinDf, rows) - } - } - - test("Test aggregate on top of 2-way self join") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), min(a.dept) - |FROM h2.test.employee a - |JOIN h2.test.employee b ON a.dept = b.dept + 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val aggNodes = df.queryExecution.optimizedPlan.collect { - case a: Aggregate => a - } - - assert(joinNodes.isEmpty) - assert(aggNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, 
rows) - } - } - - test("Test aggregate on top of multi-way self join") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2) - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val aggNodes = df.queryExecution.optimizedPlan.collect { - case a: Aggregate => a - } - - assert(joinNodes.isEmpty) - assert(aggNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test sort limit on top of join is pushed down") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), a.dept, b.dept - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |GROUP BY a.dept, b.dept - |ORDER BY a.dept - |LIMIT 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf( - SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val sortNodes = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } - - val limitNodes = df.queryExecution.optimizedPlan.collect { - case l: GlobalLimit => l - } - - assert(joinNodes.isEmpty) - assert(sortNodes.isEmpty) - assert(limitNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala new file mode 100644 index 0000000000000..244ae40c48a9d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import org.apache.spark.SparkConf +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.JdbcDialect +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} + +trait JDBCV2JoinPushdownIntegrationSuiteBase + extends QueryTest + with SharedSparkSession + with DataSourcePushdownTestUtils { + val catalogName: String = "join_pushdown_catalog" + val namespace: String = "join_schema" + val url: String + + val joinTableName1: String = "join_table_1" + val joinTableName2: String = "join_table_2" + + val jdbcDialect: JdbcDialect + + override def sparkConf: SparkConf = super.sparkConf + .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) + .set(SQLConf.ANSI_ENABLED.key, "true") + .set(s"spark.sql.catalog.$catalogName.url", url) + .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") + .set(s"spark.sql.catalog.$catalogName.caseSensitive", "false") + + private def catalogAndNamespace = s"$catalogName.${caseConvert(namespace)}" + private def casedJoinTableName1 = caseConvert(joinTableName1) + private def casedJoinTableName2 = caseConvert(joinTableName2) + + def qualifyTableName(tableName: String): String = { + val fullyQualifiedCasedNamespace = jdbcDialect.quoteIdentifier(caseConvert(namespace)) + val fullyQualifiedCasedTableName = jdbcDialect.quoteIdentifier(caseConvert(tableName)) + s"$fullyQualifiedCasedNamespace.$fullyQualifiedCasedTableName" + } + + def quoteSchemaName(schemaName: String): String = + jdbcDialect.quoteIdentifier(caseConvert(namespace)) + + private lazy val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) + + private lazy val fullyQualifiedTableName2: String = qualifyTableName(joinTableName2) + + protected def getJDBCTypeString(dt: DataType): String = { + JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() + } + + protected def caseConvert(identifier: String): String = identifier + + protected def withConnection[T](f: Connection => T): T = { + val conn = DriverManager.getConnection(url, new Properties()) + try { + f(conn) + } finally { + conn.close() + } + } + + protected val integerType = DataTypes.IntegerType + + protected val stringType = DataTypes.StringType + + protected val decimalType = DataTypes.createDecimalType(10, 2) + + /** + * This method should cover the following: + * + */ + def dataPreparation(): Unit = { + schemaPreparation() + tablePreparation() + fillJoinTables() + } + + def schemaPreparation(): Unit = { + withConnection {conn => + conn + .prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${quoteSchemaName(namespace)}") + .executeUpdate() + } + } + + def tablePreparation(): Unit = { + withConnection{ conn => + conn.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName1 ( + | ID ${getJDBCTypeString(integerType)}, + | AMOUNT ${getJDBCTypeString(decimalType)}, + | ADDRESS ${getJDBCTypeString(stringType)} + |)""".stripMargin + 
).executeUpdate() + + conn.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName2 ( + | ID ${getJDBCTypeString(integerType)}, + | NEXT_ID ${getJDBCTypeString(integerType)}, + | SALARY ${getJDBCTypeString(decimalType)}, + | SURNAME ${getJDBCTypeString(stringType)} + |)""".stripMargin + ).executeUpdate() + } + } + + private val random = new java.util.Random(42) + + private val table1Data = (1 to 100).map { i => + val id = i % 11 + val amount = BigDecimal.valueOf(random.nextDouble() * 10000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val address = s"address_$i" + (id, amount, address) + } + + private val table2Data = (1 to 100).map { i => + val id = (i % 17) + val next_id = (id + 1) % 17 + val salary = BigDecimal.valueOf(random.nextDouble() * 50000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val surname = s"surname_$i" + (id, next_id, salary, surname) + } + + def fillJoinTables(): Unit = { + withConnection { conn => + val insertStmt1 = conn.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName1 (id, amount, address) VALUES (?, ?, ?)" + ) + table1Data.foreach { case (id, amount, address) => + insertStmt1.setInt(1, id) + insertStmt1.setBigDecimal(2, amount.bigDecimal) + insertStmt1.setString(3, address) + insertStmt1.addBatch() + } + insertStmt1.executeBatch() + insertStmt1.close() + + val insertStmt2 = conn.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName2 (id, next_id, salary, surname) VALUES (?, ?, ?, ?)" + ) + table2Data.foreach { case (id, next_id, salary, surname) => + insertStmt2.setInt(1, id) + insertStmt2.setInt(2, next_id) + insertStmt2.setBigDecimal(3, salary.bigDecimal) + insertStmt2.setString(4, surname) + insertStmt2.addBatch() + } + insertStmt2.executeBatch() + insertStmt2.close() + + } + } + + // Condition-less joins are not supported in join pushdown + test("Test that 2-way join without condition should not have join pushed down") { + val sqlQuery = + s""" + |SELECT * FROM + |$catalogAndNamespace.$casedJoinTableName1 a, + |$catalogAndNamespace.$casedJoinTableName1 b + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + // Condition-less joins are not supported in join pushdown + test("Test that multi-way join without condition should not have join pushed down") { + val sqlQuery = s""" + |SELECT * FROM + |$catalogAndNamespace.$casedJoinTableName1 a, + |$catalogAndNamespace.$casedJoinTableName1 b, + |$catalogAndNamespace.$casedJoinTableName1 c + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test self join with condition") { + val sqlQuery = s""" + |SELECT * FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b + |ON a.id = b.id + 1""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + 
s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test multi-way self join with conditions") { + val sqlQuery = s""" + |SELECT * FROM + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + assert(!rows.isEmpty) + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test self join with column pruning") { + val sqlQuery = s""" + |SELECT a.id + 2, b.id, b.amount FROM + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b + |ON a.id = b.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + val expectedSchemaWithoutNames = StructType( + Seq( + StructField("", integerType), // ID + StructField("", integerType), // NEXT_ID + StructField(caseConvert("amount"), decimalType) // AMOUNT + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test 2-way join with column pruning - different tables") { + val sqlQuery = s""" + |SELECT a.id, b.next_id FROM + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName2 b + |ON a.id = b.next_id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + val expectedSchemaWithoutNames = StructType( + Seq( + StructField(caseConvert("id"), integerType), // ID + StructField(caseConvert("next_id"), integerType) // NEXT_ID + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkPushedInfo(df, + s"PushedFilters: [${caseConvert("id")} IS NOT NULL, " + + s"${caseConvert("next_id")} IS NOT NULL, " + + s"${caseConvert("id")} = ${caseConvert("next_id")}]") + checkAnswer(df, rows) + } + } + + test("Test multi-way self join with column pruning") { + val sqlQuery = s""" + |SELECT a.id, b.*, c.id, c.amount + a.amount + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + val expectedSchemaWithoutNames = StructType( + Seq( + 
StructField("", integerType), // ID_UUID + StructField("", decimalType), // AMOUNT_UUID + StructField("", integerType), // ID_UUID + StructField("", decimalType), // AMOUNT_UUID + StructField(caseConvert("address"), stringType), // ADDRESS + StructField(caseConvert("id"), integerType), // ID + StructField(caseConvert("amount"), decimalType) // AMOUNT + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") + checkAnswer(df, rows) + } + } + + test("Test aliases not supported in join pushdown") { + val sqlQuery = s""" + |SELECT a.id, bc.* + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN ( + | SELECT b.*, c.id AS c_id, c.amount AS c_amount + | FROM $catalogAndNamespace.$casedJoinTableName1 b + | JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 + |) bc ON bc.id = a.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with dataframe with duplicated columns") { + val df1 = sql(s"SELECT id FROM $catalogAndNamespace.$casedJoinTableName1") + val df2 = sql(s"SELECT id, id FROM $catalogAndNamespace.$casedJoinTableName1") + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + df1.join(df2, "id").collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val joinDf = df1.join(df2, "id") + + checkJoinPushed( + joinDf, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(joinDf, rows) + } + } + + test("Test aggregate on top of 2-way self join") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), min(a.id) + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON a.id = b.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkAggregateRemoved(df) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + + checkAnswer(df, rows) + } + } + + test("Test aggregate on top of multi-way self join") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), min(a.id), min(c.id - 2) + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + + s" $catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") + checkAnswer(df, rows) + } + } + + test("Test sort limit on top of 
join is pushed down") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), a.id, b.id + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |GROUP BY a.id, b.id + |ORDER BY a.id + |LIMIT 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf( + SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkSortRemoved(df) + checkLimitRemoved(df) + + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test join with additional filters") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname, t1.amount, t2.salary + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id + |WHERE t1.amount > 5000 AND t2.salary > 25000 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkFilterPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with complex condition") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname, t1.amount + t2.salary as total + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |JOIN $catalogAndNamespace.$casedJoinTableName2 t2 + |ON t1.id = t2.id AND t1.amount > 1000 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } + + test("Test left outer join should not be pushed down") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |LEFT JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test right outer join should not be pushed down") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |RIGHT JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala new file mode 100644 index 0000000000000..026fcf3126302 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.util.Locale + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils +import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +class JDBCV2JoinPushdownSuite + extends QueryTest + with SharedSparkSession + with ExplainSuiteHelper + with DataSourcePushdownTestUtils + with JDBCV2JoinPushdownIntegrationSuiteBase { + val tempDir = Utils.createTempDir() + override val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" + + override val jdbcDialect: JdbcDialect = H2Dialect() + + override def sparkConf: SparkConf = super.sparkConf + .set(s"spark.sql.catalog.$catalogName.driver", "org.h2.Driver") + + override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT) + + override def beforeAll(): Unit = { + Utils.classForName("org.h2.Driver") + super.beforeAll() + dataPreparation() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } +}
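Taken together, the H2 suite above and the Oracle docker suite earlier in this diff show the intended extension pattern: a dialect supplies the URL, JdbcDialect and identifier casing, while JDBCV2JoinPushdownIntegrationSuiteBase owns the schema, data and tests, and DataSourcePushdownTestUtils supplies the assertions. For dialects that cannot push every operator down, the supports* flags relax the corresponding checks instead of forcing test rewrites. A minimal sketch under stated assumptions, not part of this patch: the suite name and the premise that sort pushdown is unsupported are invented for illustration, and the rest reuses the H2 setup from JDBCV2JoinPushdownSuite.

package org.apache.spark.sql.jdbc.v2

import java.util.Locale

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

// Hypothetical suite: same H2-backed setup as JDBCV2JoinPushdownSuite, but it
// opts out of the sort pushdown assertion via DataSourcePushdownTestUtils, so
// checkSortRemoved becomes a no-op while join/limit/aggregate checks still run.
class NoSortPushdownJoinPushdownSuite
  extends QueryTest
  with SharedSparkSession
  with JDBCV2JoinPushdownIntegrationSuiteBase {

  private val tempDir = Utils.createTempDir()
  override val url: String =
    s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"

  override val jdbcDialect: JdbcDialect = H2Dialect()

  override def sparkConf: SparkConf = super.sparkConf
    .set(s"spark.sql.catalog.$catalogName.driver", "org.h2.Driver")

  // H2 stores unquoted identifiers in upper case.
  override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT)

  // Assumed limitation for illustration: this dialect does not push sorts down.
  override protected val supportsSortPushdown: Boolean = false

  override def beforeAll(): Unit = {
    Utils.classForName("org.h2.Driver")
    super.beforeAll()
    dataPreparation() // creates the schema, the two join tables and their rows
  }

  override def afterAll(): Unit = {
    Utils.deleteRecursively(tempDir)
    super.afterAll()
  }
}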