From 160b31b6bb2a1f27fef679cb50418dd9e7d4d6e1 Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Wed, 16 Jul 2025 20:52:42 +0200 Subject: [PATCH 1/9] support join pushdown for oracle --- .../sql/jdbc/v2/V2JDBCPushdownTestUtils.scala | 130 +++++ .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 112 +---- .../JDBCJoinPushdownIntegrationSuite.scala | 454 ++++++++++++++++++ .../OracleJoinPushdownIntegrationSuite.scala | 75 +++ .../spark/sql/jdbc/JdbcSQLQueryBuilder.scala | 2 +- .../apache/spark/sql/jdbc/OracleDialect.scala | 4 +- 6 files changed, 667 insertions(+), 110 deletions(-) create mode 100644 connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala create mode 100644 connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala create mode 100644 connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala new file mode 100644 index 0000000000000..78d1016a0d1f2 --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2 + +import org.apache.spark.sql.{DataFrame} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sample, Sort} +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} + +trait V2JDBCPushdownTestUtils { + protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } + } + + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } + } + + protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { + val limit = df.queryExecution.optimizedPlan.collect { + case l: LocalLimit => l + case g: GlobalLimit => g + } + if (pushed) { + assert(limit.isEmpty) + } else { + assert(limit.nonEmpty) + } + } + + protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.limit == limit) + } + } + } + + protected def checkColumnPruned(df: DataFrame, col: String): Unit = { + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq(col))) + } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + + + protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _, _, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.length == 1) + assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) + assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + } + } + + protected def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + + if (pushed) { + assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) + } + } + + protected def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = { + val offsets = df.queryExecution.optimizedPlan.collect { + case o: Offset => o + } + + if (pushed) { + assert(offsets.isEmpty) + } else { + assert(offsets.nonEmpty) + } + } + + protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.offset == offset) + } + } + } +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index ee4a479960145..683e5ec77ac1c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -22,20 +22,20 @@ import org.apache.logging.log4j.Level import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sample, Sort} import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.NullOrdering -import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @DockerTest -private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite { +private[v2] trait V2JDBCTest + extends SharedSparkSession + with DockerIntegrationFunSuite + with V2JDBCPushdownTestUtils { import testImplicits._ val catalogName: String @@ -468,56 +468,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def supportsTableSample: Boolean = false - private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { - val sample = df.queryExecution.optimizedPlan.collect { - case s: Sample => s - } - if (pushed) { - assert(sample.isEmpty) - } else { - assert(sample.nonEmpty) - } - } - - protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { - val filter = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - if (pushed) { - assert(filter.isEmpty) - } else { - assert(filter.nonEmpty) - } - } - - protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val limit = df.queryExecution.optimizedPlan.collect { - case l: LocalLimit => l - case g: GlobalLimit => g - } - if (pushed) { - assert(limit.isEmpty) - } else { - assert(limit.nonEmpty) - } - } - - private def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit == limit) - } - } - } - - private def checkColumnPruned(df: DataFrame, col: String): Unit = { - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq(col))) - } - test("SPARK-48172: Test CONTAINS") { val df1 = spark.sql( s""" @@ -841,39 +791,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val sorts = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } - - if (pushed) { - assert(sorts.isEmpty) - } else { - assert(sorts.nonEmpty) - } - } - - private def checkOffsetRemoved(df: DataFrame, pushed: Boolean = 
true): Unit = { - val offsets = df.queryExecution.optimizedPlan.collect { - case o: Offset => o - } - - if (pushed) { - assert(offsets.isEmpty) - } else { - assert(offsets.nonEmpty) - } - } - - private def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.offset == offset) - } - } - } - gridTest("simple scan")(partitioningEnabledTestCase) { partitioningEnabled => val (tableOptions, partitionInfo) = getTableOptions("employee", partitioningEnabled) val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + @@ -1028,27 +945,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkAggregateRemoved(df: DataFrame): Unit = { - val aggregates = df.queryExecution.optimizedPlan.collect { - case agg: Aggregate => agg - } - assert(aggregates.isEmpty) - } - - private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _, _, _) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions.length == 1) - assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) - assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) - } - } - protected def caseConvert(tableName: String): String = tableName Seq(true, false).foreach { isDistinct => diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala new file mode 100644 index 0000000000000..3787881b1ec3e --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2.join + +import java.sql.Connection + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.{DockerIntegrationFunSuite, JdbcDialect} +import org.apache.spark.sql.jdbc.v2.V2JDBCPushdownTestUtils +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.tags.DockerTest + +@DockerTest +trait JDBCJoinPushdownIntegrationSuite + extends QueryTest + with SharedSparkSession + with DockerIntegrationFunSuite + with V2JDBCPushdownTestUtils { + val catalogName: String + val namespaceOpt: Option[String] = None + val joinTableName1: String = "join_table_1" + val joinTableName2: String = "join_table_2" + + // Concrete suite must provide the dialect for its DB + def jdbcDialect: JdbcDialect + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + + def fullyQualifiedTableName1: String = namespaceOpt + .map(namespace => s"$namespace.$joinTableName1").getOrElse(joinTableName1) + + def fullyQualifiedTableName2: String = namespaceOpt + .map(namespace => s"$namespace.$joinTableName2").getOrElse(joinTableName2) + + protected def getJDBCTypeString(dt: DataType): String = { + JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() + } + + protected def caseConvert(tableName: String): String = tableName + + def dataPreparation(connection: Connection): Unit = { + tablePreparation(connection) + fillJoinTables(connection) + } + + def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName1 ( + | id ${getJDBCTypeString(DataTypes.IntegerType)}, + | amount ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, + | name ${getJDBCTypeString(DataTypes.StringType)} + |)""".stripMargin).executeUpdate() + + connection.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName2 ( + | id ${getJDBCTypeString(DataTypes.LongType)}, + | salary ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, + | surname ${getJDBCTypeString(DataTypes.StringType)} + |)""".stripMargin).executeUpdate() + } + + def fillJoinTables(connection: Connection): Unit = { + val random = new java.util.Random(42) + val table1Data = (1 to 100).map { i => + val id = i % 11 + val amount = BigDecimal.valueOf(random.nextDouble() * 10000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val name = s"name_$i" + (id, amount, name) + } + val table2Data = (1 to 100).map { i => + val id = (i % 17).toLong + val salary = BigDecimal.valueOf(random.nextDouble() * 50000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val surname = s"surname_$i" + (id, salary, surname) + } + + // Use parameterized queries to handle different data types properly + val insertStmt1 = connection.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName1 (id, amount, name) VALUES (?, ?, ?)" + ) + table1Data.foreach { case (id, amount, name) => + insertStmt1.setInt(1, id) + insertStmt1.setBigDecimal(2, amount.bigDecimal) + insertStmt1.setString(3, name) + insertStmt1.executeUpdate() + } + insertStmt1.close() + + val insertStmt2 = connection.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName2 (id, salary, surname) VALUES (?, ?, 
?)" ) + table2Data.foreach { case (id, salary, surname) => + insertStmt2.setLong(1, id) + insertStmt2.setBigDecimal(2, salary.bigDecimal) + insertStmt2.setString(3, surname) + insertStmt2.executeUpdate() + } + insertStmt2.close() + } + + /** + * Runs the plan and makes sure the plan contains all of the keywords. + */ + protected def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + df.explain(extended = true) + } + val normalizedOutput = output.toString.replaceAll("#\\d+", "#x") + for (key <- keywords) { + assert(normalizedOutput.contains(key), s"Expected keyword '$key' not found in explain output") + } + } + + private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { + withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) + } + } + } + + private def checkJoinNotPushed(df: DataFrame): Unit = { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.nonEmpty, "Join should not be pushed down") + } + + private def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.isEmpty, "Join should be pushed down") + if (expectedTables.nonEmpty) { + checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") + } + } + + test("Test basic inner join pushdown with column pruning") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname, t1.amount, t2.salary + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + // Verify we have non-empty results + assert(rows.nonEmpty, "Join should produce non-empty results") + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } + + + test("Test join with additional filters") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname, t1.amount, t2.salary + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |WHERE t1.amount > 5000 AND t2.salary > 25000 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkFilterPushed(df) + checkAnswer(df, rows) + } + } + + test("Test self join should be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.name as name2 + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + 
withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test join without condition should not be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |CROSS JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with complex condition") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname, t1.amount + t2.salary as total + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 + |ON t1.id = t2.id AND t1.amount > 1000 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } + + test("Test join with limit and order") { + // ORDER BY is used to have same ordering on Spark and database. Otherwise, different results + // could be returned. + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |ORDER BY t1.id, t1.name, t2.surname + |LIMIT 5 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkSortRemoved(df) + checkLimitRemoved(df) + checkAnswer(df, rows) + } + } + + test("Test join with order by") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |ORDER BY t1.id, t1.name, t2.surname + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + // Order without limit is not supported in DSv2 + checkSortRemoved(df, false) + checkAnswer(df, rows) + } + } + + test("Test join with aggregation") { + val sqlQuery = s""" + |SELECT t1.id, COUNT(*), AVG(t1.amount), MAX(t2.salary) + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |GROUP BY t1.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + 
sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAggregateRemoved(df) + checkAnswer(df, rows) + } + } + + test("Test left outer join should not be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |LEFT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test right outer join should not be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |RIGHT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test full outer join should not be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |FULL OUTER JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with subquery should be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, sub.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN ( + | SELECT id, surname FROM $catalogAndNamespace.${caseConvert(joinTableName2)} + | WHERE salary > 25000 + |) sub ON t1.id = sub.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } + + test("Test join with non-equality condition should be pushed down") { + val sqlQuery = s""" + |SELECT t1.id, t1.name, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id > t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } +} diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala new file mode 100644 index 0000000000000..329d8b72e9e62 --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2.join + +import java.util.Locale + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog +import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} +import org.apache.spark.tags.DockerTest + +/** + * The following are the steps to test this: + * + * 1. Choose to use a prebuilt image or build Oracle database in a container + * - The documentation on how to build Oracle RDBMS in a container is at + * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md + * - Official Oracle container images can be found at https://container-registry.oracle.com + * - Trustable and streamlined Oracle Database Free images can be found on Docker Hub at + * https://hub.docker.com/r/gvenzl/oracle-free + * see also https://github.com/gvenzl/oci-oracle-free + * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing + * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-free:latest + * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1 + * 4. Start docker: sudo service docker start + * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME + * 5. 
Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.v2.join.OracleJoinPushdownIntegrationSuite" + * + * A sequence of commands to build the Oracle Database Free container image: + * $ git clone https://github.com/oracle/docker-images.git + * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles + * $ ./buildContainerImage.sh -v 23.4.0 -f + * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:23.4.0-free + * + * This procedure has been validated with Oracle Database Free version 23.4.0, + * and with Oracle Express Edition versions 18.4.0 and 21.4.0 + */ +@DockerTest +class OracleJoinPushdownIntegrationSuite + extends DockerJDBCIntegrationSuite + with JDBCJoinPushdownIntegrationSuite { + override val catalogName: String = "oracle" + + override val namespaceOpt: Option[String] = Some("SYSTEM") + + override val db = new OracleDatabaseOnDocker + + override def sparkConf: SparkConf = super.sparkConf + .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) + .set(s"spark.sql.catalog.$catalogName.url", db.getJdbcUrl(dockerIp, externalPort)) + .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") + + override def jdbcDialect: JdbcDialect = OracleDialect() + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala index 97ec093a0e297..e25cfb9668e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala @@ -208,7 +208,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) { // If join has been pushed down, reuse join query as a subquery. Otherwise, fallback to // what is provided in options. - private def tableOrQuery = joinQuery.getOrElse(options.tableOrQuery) + def tableOrQuery: String = joinQuery.getOrElse(options.tableOrQuery) /** * Build the final SQL query that following dialect's SQL syntax. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index a9f6a727a7241..0c9c84f3f3e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -233,7 +233,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N extends JdbcSQLQueryBuilder(dialect, options) { override def build(): String = { - val selectStmt = s"SELECT $hintClause$columnList FROM ${options.tableOrQuery}" + + val selectStmt = s"SELECT $hintClause$columnList FROM $tableOrQuery" + s" $tableSampleClause $whereClause $groupByClause $orderByClause" val finalSelectStmt = if (limit > 0) { if (offset > 0) { @@ -268,6 +268,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N override def supportsHint: Boolean = true + override def supportsJoin: Boolean = true + override def classifyException( e: Throwable, condition: String, From 3ea316b09d6bf5946e0d270d644755f2a6591ccd Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Thu, 17 Jul 2025 12:21:28 +0200 Subject: [PATCH 2/9] refactor tests --- .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 4 +- .../JDBCJoinPushdownIntegrationSuite.scala | 454 ---------------- .../OracleJoinPushdownIntegrationSuite.scala | 12 +- .../sql/jdbc/JDBCV2JoinPushdownSuite.scala | 390 +------------ .../v2/JDBCJoinPushdownIntegrationSuite.scala | 511 ++++++++++++++++++ .../sql/jdbc/v2/V2JDBCPushdownTestUtils.scala | 41 +- 6 files changed, 580 insertions(+), 832 deletions(-) delete mode 100644 connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala rename {connector/docker-integration-tests => sql/core}/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala (75%) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 683e5ec77ac1c..0c225a0592ace 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -33,9 +33,9 @@ import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCTest - extends SharedSparkSession + extends V2JDBCPushdownTestUtils with DockerIntegrationFunSuite - with V2JDBCPushdownTestUtils { + with SharedSparkSession { import testImplicits._ val catalogName: String diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala deleted file mode 100644 index 3787881b1ec3e..0000000000000 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/JDBCJoinPushdownIntegrationSuite.scala +++ /dev/null @@ -1,454 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc.v2.join - -import java.sql.Connection - -import org.apache.spark.sql.{DataFrame, QueryTest} -import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.jdbc.{DockerIntegrationFunSuite, JdbcDialect} -import org.apache.spark.sql.jdbc.v2.V2JDBCPushdownTestUtils -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DataType, DataTypes} -import org.apache.spark.tags.DockerTest - -@DockerTest -trait JDBCJoinPushdownIntegrationSuite - extends QueryTest - with SharedSparkSession - with DockerIntegrationFunSuite - with V2JDBCPushdownTestUtils { - val catalogName: String - val namespaceOpt: Option[String] = None - val joinTableName1: String = "join_table_1" - val joinTableName2: String = "join_table_2" - - // Concrete suite must provide the dialect for its DB - def jdbcDialect: JdbcDialect - - private def catalogAndNamespace = - namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) - - def fullyQualifiedTableName1: String = namespaceOpt - .map(namespace => s"$namespace.$joinTableName1").getOrElse(joinTableName1) - - def fullyQualifiedTableName2: String = namespaceOpt - .map(namespace => s"$namespace.$joinTableName2").getOrElse(joinTableName2) - - protected def getJDBCTypeString(dt: DataType): String = { - JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() - } - - protected def caseConvert(tableName: String): String = tableName - - def dataPreparation(connection: Connection): Unit = { - tablePreparation(connection) - fillJoinTables(connection) - } - - def tablePreparation(connection: Connection): Unit = { - connection.prepareStatement( - s"""CREATE TABLE $fullyQualifiedTableName1 ( - | id ${getJDBCTypeString(DataTypes.IntegerType)}, - | amount ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, - | name ${getJDBCTypeString(DataTypes.StringType)} - |)""".stripMargin).executeUpdate() - - connection.prepareStatement( - s"""CREATE TABLE $fullyQualifiedTableName2 ( - | id ${getJDBCTypeString(DataTypes.LongType)}, - | salary ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, - | surname ${getJDBCTypeString(DataTypes.StringType)} - |)""".stripMargin).executeUpdate() - } - - def fillJoinTables(connection: Connection): Unit = { - val random = new java.util.Random(42) - val table1Data = (1 to 100).map { i => - val id = i % 11 - val amount = BigDecimal.valueOf(random.nextDouble() * 10000) - .setScale(2, BigDecimal.RoundingMode.HALF_UP) - val name = s"name_$i" - (id, amount, name) - } - val table2Data = (1 to 100).map { i => - val id = (i % 17).toLong - val salary = BigDecimal.valueOf(random.nextDouble() * 50000) - .setScale(2, BigDecimal.RoundingMode.HALF_UP) - val surname = 
s"surname_$i" - (id, salary, surname) - } - - // Use parameterized queries to handle different data types properly - val insertStmt1 = connection.prepareStatement( - s"INSERT INTO $fullyQualifiedTableName1 (id, amount, name) VALUES (?, ?, ?)" - ) - table1Data.foreach { case (id, amount, name) => - insertStmt1.setInt(1, id) - insertStmt1.setBigDecimal(2, amount.bigDecimal) - insertStmt1.setString(3, name) - insertStmt1.executeUpdate() - } - insertStmt1.close() - - val insertStmt2 = connection.prepareStatement( - s"INSERT INTO $fullyQualifiedTableName2 (id, salary, surname) VALUES (?, ?, ?)" - ) - table2Data.foreach { case (id, salary, surname) => - insertStmt2.setLong(1, id) - insertStmt2.setBigDecimal(2, salary.bigDecimal) - insertStmt2.setString(3, surname) - insertStmt2.executeUpdate() - } - insertStmt2.close() - } - - /** - * Runs the plan and makes sure the plans contains all of the keywords. - */ - protected def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { - val output = new java.io.ByteArrayOutputStream() - Console.withOut(output) { - df.explain(extended = true) - } - val normalizedOutput = output.toString.replaceAll("#\\d+", "#x") - for (key <- keywords) { - assert(normalizedOutput.contains(key), s"Expected keyword '$key' not found in explain output") - } - } - - private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { - withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) - } - } - } - - private def checkJoinNotPushed(df: DataFrame): Unit = { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - assert(joinNodes.nonEmpty, "Join should not be pushed down") - } - - private def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - assert(joinNodes.isEmpty, "Join should be pushed down") - if (expectedTables.nonEmpty) { - checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") - } - } - - test("Test basic inner join pushdown with column pruning") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname, t1.amount, t2.salary - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - // Verify we have non-empty results - assert(rows.nonEmpty, "Join should produce non-empty results") - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkAnswer(df, rows) - } - } - - - test("Test join with additional filters") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname, t1.amount, t2.salary - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |WHERE t1.amount > 5000 AND t2.salary > 25000 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, 
- s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkFilterPushed(df) - checkAnswer(df, rows) - } - } - - test("Test self join should be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.name as name2 - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} t2 ON t1.id = t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName1)}" - ) - checkAnswer(df, rows) - } - } - - test("Test join without condition should not be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |CROSS JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinNotPushed(df) - checkAnswer(df, rows) - } - } - - test("Test join with complex condition") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname, t1.amount + t2.salary as total - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 - |ON t1.id = t2.id AND t1.amount > 1000 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkAnswer(df, rows) - } - } - - test("Test join with limit and order") { - // ORDER BY is used to have same ordering on Spark and database. Otherwise, different results - // could be returned. 
- val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |ORDER BY t1.id, t1.name, t2.surname - |LIMIT 5 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkSortRemoved(df) - checkLimitRemoved(df) - checkAnswer(df, rows) - } - } - - test("Test join with order by") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |ORDER BY t1.id, t1.name, t2.surname - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - // Order without limit is not supported in DSv2 - checkSortRemoved(df, false) - checkAnswer(df, rows) - } - } - - test("Test join with aggregation") { - val sqlQuery = s""" - |SELECT t1.id, COUNT(*), AVG(t1.amount), MAX(t2.salary) - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |GROUP BY t1.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkAggregateRemoved(df) - checkAnswer(df, rows) - } - } - - test("Test left outer join should not be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |LEFT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinNotPushed(df) - checkAnswer(df, rows) - } - } - - test("Test right outer join should not be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |RIGHT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinNotPushed(df) - checkAnswer(df, rows) - } - } - - test("Test full outer join should not be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |FULL OUTER JOIN 
$catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinNotPushed(df) - checkAnswer(df, rows) - } - } - - test("Test join with subquery should be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, sub.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN ( - | SELECT id, surname FROM $catalogAndNamespace.${caseConvert(joinTableName2)} - | WHERE salary > 25000 - |) sub ON t1.id = sub.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkAnswer(df, rows) - } - } - - test("Test join with non-equality condition should be pushed down") { - val sqlQuery = s""" - |SELECT t1.id, t1.name, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id > t2.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - checkJoinPushed( - df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" - ) - checkAnswer(df, rows) - } - } -} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala index 329d8b72e9e62..dc2de6f17b550 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.jdbc.v2.join +import java.sql.Connection import java.util.Locale import org.apache.spark.SparkConf import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} +import org.apache.spark.sql.jdbc.v2.JDBCJoinPushdownIntegrationSuite import org.apache.spark.tags.DockerTest /** @@ -57,19 +59,23 @@ class OracleJoinPushdownIntegrationSuite with JDBCJoinPushdownIntegrationSuite { override val catalogName: String = "oracle" - override val namespaceOpt: Option[String] = Some("SYSTEM") + override def namespaceOpt: Option[String] = Some("SYSTEM") override val db = new OracleDatabaseOnDocker + override val url = db.getJdbcUrl(dockerIp, externalPort) + override def sparkConf: SparkConf = super.sparkConf .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) - .set(s"spark.sql.catalog.$catalogName.url", db.getJdbcUrl(dockerIp, externalPort)) + .set(s"spark.sql.catalog.$catalogName.url", url) .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") 
.set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") - override def jdbcDialect: JdbcDialect = OracleDialect() + override val jdbcDialect: JdbcDialect = OracleDialect() override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + override def schemaPreparation(connection: Connection): Unit = {} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala index b77e905fea5d0..8b16310b7a4cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala @@ -17,21 +17,28 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, DriverManager} -import java.util.Properties +import java.sql.Connection import org.apache.spark.SparkConf -import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, Join, Sort} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.v2.{JDBCJoinPushdownIntegrationSuite, V2JDBCPushdownTestUtils} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils -class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with ExplainSuiteHelper { +class JDBCV2JoinPushdownSuite + extends QueryTest + with SharedSparkSession + with ExplainSuiteHelper + with V2JDBCPushdownTestUtils + with JDBCJoinPushdownIntegrationSuite { val tempDir = Utils.createTempDir() - val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" + override val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" + + override val catalogName: String = "h2" + override def namespaceOpt: Option[String] = Some("test") + + override val jdbcDialect: JdbcDialect = H2Dialect() override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) @@ -42,372 +49,23 @@ class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with Exp .set("spark.sql.catalog.h2.pushDownOffset", "true") .set("spark.sql.catalog.h2.pushDownJoin", "true") - private def withConnection[T](f: Connection => T): T = { - val conn = DriverManager.getConnection(url, new Properties()) - try { - f(conn) - } finally { - conn.close() - } + override def qualifyTableName(tableName: String): String = namespaceOpt + .map(namespace => s""""$namespace"."$tableName"""").getOrElse(s""""$tableName"""") + + override def schemaPreparation(connection: Connection): Unit = { + connection + .prepareStatement(s"""CREATE SCHEMA IF NOT EXISTS "${namespaceOpt.get}"""") + .executeUpdate() } override def beforeAll(): Unit = { - super.beforeAll() Utils.classForName("org.h2.Driver") - withConnection { conn => - conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"people\" (name TEXT(32) NOT NULL, id INTEGER NOT NULL)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate() - 
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate() - conn.prepareStatement( - "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate() - - // scalastyle:off - conn.prepareStatement( - "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate() - // scalastyle:on - conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate() - conn.prepareStatement( - """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() - conn.prepareStatement( - """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate() - } + super.beforeAll() + withConnection(dataPreparation) } override def afterAll(): Unit = { Utils.deleteRecursively(tempDir) super.afterAll() } - - private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { - withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) - } - } - } - - // Conditionless joins are not supported in join pushdown - test("Test that 2-way join without condition should not have join pushed down") { - val sqlQuery = "SELECT * FROM h2.test.employee a, h2.test.employee b" - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - // Conditionless joins are not supported in join pushdown - test("Test that multi-way join without condition should not have join pushed down") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a, - |h2.test.employee b, - |h2.test.employee c - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - test("Test self join with condition") { - val sqlQuery = "SELECT * FROM h2.test.employee a JOIN h2.test.employee b ON a.dept = b.dept + 1" - - val rows = 
withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test multi-way self join with conditions") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - assert(!rows.isEmpty) - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test self join with column pruning") { - val sqlQuery = """ - |SELECT a.dept + 2, b.dept, b.salary FROM - |h2.test.employee a JOIN h2.test.employee b - |ON a.dept = b.dept + 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test 2-way join with column pruning - different tables") { - val sqlQuery = """ - |SELECT * FROM - |h2.test.employee a JOIN h2.test.people b - |ON a.dept = b.id - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.people]") - checkPushedInfo(df, - "PushedFilters: [DEPT IS NOT NULL, ID IS NOT NULL, DEPT = ID]") - checkAnswer(df, rows) - } - } - - test("Test multi-way self join with column pruning") { - val sqlQuery = """ - |SELECT a.dept, b.*, c.dept, c.salary + a.salary - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test aliases not supported in join pushdown") { - val sqlQuery = """ - |SELECT a.dept, bc.* - |FROM h2.test.employee a - |JOIN ( - | SELECT b.*, c.dept AS c_dept, c.salary AS c_salary - | FROM h2.test.employee b - | JOIN h2.test.employee c ON c.dept = b.dept - 1 - |) bc ON bc.dept = a.dept + 1 - |""".stripMargin - - val rows 
= withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.nonEmpty) - checkAnswer(df, rows) - } - } - - test("Test join with dataframe with duplicated columns") { - val df1 = sql("SELECT dept FROM h2.test.employee") - val df2 = sql("SELECT dept, dept FROM h2.test.employee") - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - df1.join(df2, "dept").collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val joinDf = df1.join(df2, "dept") - val joinNodes = joinDf.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - assert(joinNodes.isEmpty) - checkPushedInfo(joinDf, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(joinDf, rows) - } - } - - test("Test aggregate on top of 2-way self join") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), min(a.dept) - |FROM h2.test.employee a - |JOIN h2.test.employee b ON a.dept = b.dept + 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val aggNodes = df.queryExecution.optimizedPlan.collect { - case a: Aggregate => a - } - - assert(joinNodes.isEmpty) - assert(aggNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test aggregate on top of multi-way self join") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2) - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |JOIN h2.test.employee c ON c.dept = b.dept - 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val aggNodes = df.queryExecution.optimizedPlan.collect { - case a: Aggregate => a - } - - assert(joinNodes.isEmpty) - assert(aggNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]") - checkAnswer(df, rows) - } - } - - test("Test sort limit on top of join is pushed down") { - val sqlQuery = """ - |SELECT min(a.dept + b.dept), a.dept, b.dept - |FROM h2.test.employee a - |JOIN h2.test.employee b ON b.dept = a.dept + 1 - |GROUP BY a.dept, b.dept - |ORDER BY a.dept - |LIMIT 1 - |""".stripMargin - - val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { - sql(sqlQuery).collect().toSeq - } - - withSQLConf( - SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { - val df = sql(sqlQuery) - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - - val sortNodes = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } - - val limitNodes = df.queryExecution.optimizedPlan.collect { - case l: GlobalLimit => l - } - - assert(joinNodes.isEmpty) - assert(sortNodes.isEmpty) - assert(limitNodes.isEmpty) - checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]") - checkAnswer(df, 
rows) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala new file mode 100644 index 0000000000000..e867ff419ba6d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.JdbcDialect +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes} + +trait JDBCJoinPushdownIntegrationSuite + extends QueryTest + with SharedSparkSession + with V2JDBCPushdownTestUtils { + val catalogName: String + def namespaceOpt: Option[String] = None + val url: String + + val joinTableName1: String = "join_table_1" + val joinTableName2: String = "join_table_2" + + val jdbcDialect: JdbcDialect + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + + def qualifyTableName(tableName: String): String = namespaceOpt + .map(namespace => s"$namespace.$tableName").getOrElse(tableName) + + private val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) + + private val fullyQualifiedTableName2: String = qualifyTableName(joinTableName2) + + protected def getJDBCTypeString(dt: DataType): String = { + JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() + } + + protected def caseConvert(tableName: String): String = tableName + + protected def withConnection[T](f: Connection => T): T = { + val conn = DriverManager.getConnection(url, new Properties()) + try { + f(conn) + } finally { + conn.close() + } + } + + def dataPreparation(connection: Connection): Unit = { + schemaPreparation(connection) + tablePreparation(connection) + fillJoinTables(connection) + } + + def schemaPreparation(connection: Connection): Unit = { + connection.prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${namespaceOpt.get}").executeUpdate() + } + + def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName1 ( + | ID ${getJDBCTypeString(DataTypes.IntegerType)}, + | AMOUNT ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, + | ADDRESS ${getJDBCTypeString(DataTypes.StringType)} + |)""".stripMargin + ).executeUpdate() + + connection.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName2 ( + | ID 
${getJDBCTypeString(DataTypes.IntegerType)}, + | NEXT_ID ${getJDBCTypeString(DataTypes.IntegerType)}, + | SALARY ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, + | SURNAME ${getJDBCTypeString(DataTypes.StringType)} + |)""".stripMargin + ).executeUpdate() + } + + def fillJoinTables(connection: Connection): Unit = { + val random = new java.util.Random(42) + val table1Data = (1 to 100).map { i => + val id = i % 11 + val amount = BigDecimal.valueOf(random.nextDouble() * 10000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val address = s"address_$i" + (id, amount, address) + } + val table2Data = (1 to 100).map { i => + val id = (i % 17) + val next_id = (id + 1) % 17 + val salary = BigDecimal.valueOf(random.nextDouble() * 50000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val surname = s"surname_$i" + (id, next_id, salary, surname) + } + + val insertStmt1 = connection.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName1 (id, amount, address) VALUES (?, ?, ?)" + ) + table1Data.foreach { case (id, amount, address) => + insertStmt1.setInt(1, id) + insertStmt1.setBigDecimal(2, amount.bigDecimal) + insertStmt1.setString(3, address) + insertStmt1.executeUpdate() + } + insertStmt1.close() + + val insertStmt2 = connection.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName2 (id, next_id, salary, surname) VALUES (?, ?, ?, ?)" + ) + table2Data.foreach { case (id, next_id, salary, surname) => + insertStmt2.setInt(1, id) + insertStmt2.setInt(2, next_id) + insertStmt2.setBigDecimal(3, salary.bigDecimal) + insertStmt2.setString(4, surname) + insertStmt2.executeUpdate() + } + insertStmt2.close() + } + + // Condition-less joins are not supported in join pushdown + test("Test that 2-way join without condition should not have join pushed down") { + val sqlQuery = + s""" + |SELECT * FROM + |$catalogAndNamespace.${caseConvert(joinTableName1)} a, + |$catalogAndNamespace.${caseConvert(joinTableName1)} b + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + // Condition-less joins are not supported in join pushdown + test("Test that multi-way join without condition should not have join pushed down") { + val sqlQuery = s""" + |SELECT * FROM + |$catalogAndNamespace.${caseConvert(joinTableName1)} a, + |$catalogAndNamespace.${caseConvert(joinTableName1)} b, + |$catalogAndNamespace.${caseConvert(joinTableName1)} c + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test self join with condition") { + val sqlQuery = s""" + |SELECT * FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b + |ON a.id = b.id + 1""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test multi-way self join 
with conditions") { + val sqlQuery = s""" + |SELECT * FROM + |$catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + assert(!rows.isEmpty) + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed(df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test self join with column pruning") { + val sqlQuery = s""" + |SELECT a.id + 2, b.id, b.amount FROM + |$catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b + |ON a.id = b.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test 2-way join with column pruning - different tables") { + val sqlQuery = s""" + |SELECT a.id, b.next_id FROM + |$catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} b + |ON a.id = b.next_id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + + s" $catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkPushedInfo(df, + "PushedFilters: [ID IS NOT NULL, NEXT_ID IS NOT NULL, ID = NEXT_ID]") + checkAnswer(df, rows) + } + } + + test("Test multi-way self join with column pruning") { + val sqlQuery = s""" + |SELECT a.id, b.*, c.id, c.amount + a.amount + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") + checkAnswer(df, rows) + } + } + + test("Test aliases not supported in join pushdown") { + val sqlQuery = s""" + |SELECT a.id, bc.* + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN ( + | SELECT b.*, c.id AS c_id, c.amount AS c_amount + | FROM $catalogAndNamespace.${caseConvert(joinTableName1)} b + | JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |) bc ON bc.id = a.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + 
sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with dataframe with duplicated columns") { + val df1 = sql(s"SELECT id FROM $catalogAndNamespace.${caseConvert(joinTableName1)}") + val df2 = sql(s"SELECT id, id FROM $catalogAndNamespace.${caseConvert(joinTableName1)}") + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + df1.join(df2, "id").collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val joinDf = df1.join(df2, "id") + + checkJoinPushed( + joinDf, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(joinDf, rows) + } + } + + test("Test aggregate on top of 2-way self join") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), min(a.id) + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON a.id = b.id + 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkAggregateRemoved(df) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + + checkAnswer(df, rows) + } + } + + test("Test aggregate on top of multi-way self join") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), min(a.id), min(c.id - 2) + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + + s" $catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") + checkAnswer(df, rows) + } + } + + test("Test sort limit on top of join is pushed down") { + val sqlQuery = s""" + |SELECT min(a.id + b.id), a.id, b.id + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 + |GROUP BY a.id, b.id + |ORDER BY a.id + |LIMIT 1 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf( + SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + + checkSortRemoved(df) + checkLimitRemoved(df) + + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" + ) + checkAnswer(df, rows) + } + } + + test("Test join with additional filters") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname, t1.amount, t2.salary + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |WHERE t1.amount > 5000 AND t2.salary > 25000 + |""".stripMargin + + val rows = 
withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkFilterPushed(df) + checkAnswer(df, rows) + } + } + + test("Test join with complex condition") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname, t1.amount + t2.salary as total + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 + |ON t1.id = t2.id AND t1.amount > 1000 + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinPushed( + df, + s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + ) + checkAnswer(df, rows) + } + } + + test("Test left outer join should not be pushed down") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |LEFT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } + + test("Test right outer join should not be pushed down") { + val sqlQuery = + s""" + |SELECT t1.id, t1.address, t2.surname + |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 + |RIGHT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |""".stripMargin + + val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { + sql(sqlQuery).collect().toSeq + } + + withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { + val df = sql(sqlQuery) + checkJoinNotPushed(df) + checkAnswer(df, rows) + } + } +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala similarity index 75% rename from connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala index 78d1016a0d1f2..f7f9ba0798b91 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.jdbc.v2 -import org.apache.spark.sql.{DataFrame} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sample, Sort} +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, LocalLimit, Offset, Sample, Sort} import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} +import 
org.apache.spark.sql.internal.SQLConf -trait V2JDBCPushdownTestUtils { +trait V2JDBCPushdownTestUtils extends ExplainSuiteHelper { protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { val sample = df.queryExecution.optimizedPlan.collect { case s: Sample => s @@ -80,7 +81,6 @@ trait V2JDBCPushdownTestUtils { assert(aggregates.isEmpty) } - protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { df.queryExecution.optimizedPlan.collect { case DataSourceV2ScanRelation(_, scan, _, _, _) => @@ -89,9 +89,10 @@ trait V2JDBCPushdownTestUtils { assert(wrapper.pushedDownOperators.aggregation.isDefined) val aggregationExpressions = wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions.length == 1) - assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) - assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + assert(aggregationExpressions.exists { expr => + expr.isInstanceOf[GeneralAggregateFunc] && + expr.asInstanceOf[GeneralAggregateFunc].name() == funcName + }) } } @@ -127,4 +128,30 @@ trait V2JDBCPushdownTestUtils { } } } + + protected def checkJoinNotPushed(df: DataFrame): Unit = { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.nonEmpty, "Join should not be pushed down") + } + + protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.isEmpty, "Join should be pushed down") + if (expectedTables.nonEmpty) { + checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") + } + } + + protected def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = { + withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expectedPlanFragment: _*) + } + } + } } From 98a05c9f5121f8cea7a2405c7901c7e28a581b48 Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Thu, 17 Jul 2025 18:15:22 +0200 Subject: [PATCH 3/9] resolve comments --- .../org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 3 ++- .../v2/join/OracleJoinPushdownIntegrationSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala | 2 +- .../DataSourcePushdownTestUtils.scala} | 6 +++--- ...a => JDBCV2JoinPushdownIntegrationSuiteBase.scala} | 11 ++++++----- .../sql/jdbc/{ => v2}/JDBCV2JoinPushdownSuite.scala | 11 ++++++----- 6 files changed, 21 insertions(+), 18 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{jdbc/v2/V2JDBCPushdownTestUtils.scala => connector/DataSourcePushdownTestUtils.scala} (95%) rename sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/{JDBCJoinPushdownIntegrationSuite.scala => JDBCV2JoinPushdownIntegrationSuiteBase.scala} (97%) rename sql/core/src/test/scala/org/apache/spark/sql/jdbc/{ => v2}/JDBCV2JoinPushdownSuite.scala (89%) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 0c225a0592ace..fb18dca97994e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -22,6 +22,7 @@ import org.apache.logging.log4j.Level import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.NullOrdering @@ -33,7 +34,7 @@ import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCTest - extends V2JDBCPushdownTestUtils + extends DataSourcePushdownTestUtils with DockerIntegrationFunSuite with SharedSparkSession { import testImplicits._ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala index dc2de6f17b550..b86b2b800080d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -23,7 +23,7 @@ import java.util.Locale import org.apache.spark.SparkConf import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} -import org.apache.spark.sql.jdbc.v2.JDBCJoinPushdownIntegrationSuite +import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase import org.apache.spark.tags.DockerTest /** @@ -56,10 +56,10 @@ import org.apache.spark.tags.DockerTest @DockerTest class OracleJoinPushdownIntegrationSuite extends DockerJDBCIntegrationSuite - with JDBCJoinPushdownIntegrationSuite { + with JDBCV2JoinPushdownIntegrationSuiteBase { override val catalogName: String = "oracle" - override def namespaceOpt: Option[String] = Some("SYSTEM") + override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new OracleDatabaseOnDocker diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala index e25cfb9668e77..d0f53fa20e1aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala @@ -208,7 +208,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) { // If join has been pushed down, reuse join query as a subquery. Otherwise, fallback to // what is provided in options. - def tableOrQuery: String = joinQuery.getOrElse(options.tableOrQuery) + protected final def tableOrQuery: String = joinQuery.getOrElse(options.tableOrQuery) /** * Build the final SQL query that following dialect's SQL syntax. 
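Note on the JdbcSQLQueryBuilder hunk above: it only tightens the visibility of tableOrQuery to protected final; the fallback itself is unchanged — when a join has been pushed down, the builder reuses the generated join query as a subquery, otherwise it uses the table or query from the JDBC options. A minimal, self-contained sketch of that fallback follows; the object, method names, alias, and SQL shape are illustrative assumptions, not Spark's actual builder output — only the getOrElse fallback mirrors the hunk.

    // Sketch only: mirrors the joinQuery.getOrElse(options.tableOrQuery) fallback above.
    // Everything else (names, alias, final query shape) is hypothetical.
    object TableOrQuerySketch {
      // `joinQuery` is defined only when a join was pushed down into this scan.
      def tableOrQuery(joinQuery: Option[String], optionsTableOrQuery: String): String =
        joinQuery.getOrElse(optionsTableOrQuery)

      // Hypothetical final assembly; real dialects also append filters, sort, limit, etc.
      def build(joinQuery: Option[String], optionsTableOrQuery: String, columns: String): String =
        s"SELECT $columns FROM ${tableOrQuery(joinQuery, optionsTableOrQuery)}"

      def main(args: Array[String]): Unit = {
        // Pre-wrapped join query used as a derived table (alias chosen for illustration).
        val pushed = Some(
          """(SELECT "a"."ID", "b"."SURNAME" FROM "T1" "a" JOIN "T2" "b" ON "a"."ID" = "b"."ID") subq""")
        println(build(pushed, optionsTableOrQuery = "\"T1\"", columns = "*"))
        println(build(None, optionsTableOrQuery = "\"T1\"", columns = "*"))
      }
    }

Because every dialect builds its SELECT through this one accessor, a dialect that supports join pushdown needs no extra query-building code beyond supplying the join query.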
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala index f7f9ba0798b91..ac78faa09f22c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCPushdownTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc.v2 +package org.apache.spark.sql.connector import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, LocalLimit, Offset, Sample, Sort} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.internal.SQLConf -trait V2JDBCPushdownTestUtils extends ExplainSuiteHelper { +trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { val sample = df.queryExecution.optimizedPlan.collect { case s: Sample => s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index e867ff419ba6d..9bac5c31d6e4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCJoinPushdownIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -21,18 +21,19 @@ import java.sql.{Connection, DriverManager} import java.util.Properties import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialect import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{DataType, DataTypes} -trait JDBCJoinPushdownIntegrationSuite +trait JDBCV2JoinPushdownIntegrationSuiteBase extends QueryTest with SharedSparkSession - with V2JDBCPushdownTestUtils { + with DataSourcePushdownTestUtils { val catalogName: String - def namespaceOpt: Option[String] = None + val namespaceOpt: Option[String] = None val url: String val joinTableName1: String = "join_table_1" @@ -46,9 +47,9 @@ trait JDBCJoinPushdownIntegrationSuite def qualifyTableName(tableName: String): String = namespaceOpt .map(namespace => s"$namespace.$tableName").getOrElse(tableName) - private val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) + private lazy val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) - private val fullyQualifiedTableName2: String = qualifyTableName(joinTableName2) + private lazy val fullyQualifiedTableName2: String = qualifyTableName(joinTableName2) protected def 
getJDBCTypeString(dt: DataType): String = { JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala similarity index 89% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala index 8b16310b7a4cf..3b897b5a2b39f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -15,14 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} +import org.apache.spark.sql.connector.DataSourcePushdownTestUtils import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.v2.{JDBCJoinPushdownIntegrationSuite, V2JDBCPushdownTestUtils} +import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -30,13 +31,13 @@ class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with ExplainSuiteHelper - with V2JDBCPushdownTestUtils - with JDBCJoinPushdownIntegrationSuite { + with DataSourcePushdownTestUtils + with JDBCV2JoinPushdownIntegrationSuiteBase { val tempDir = Utils.createTempDir() override val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" override val catalogName: String = "h2" - override def namespaceOpt: Option[String] = Some("test") + override val namespaceOpt: Option[String] = Some("test") override val jdbcDialect: JdbcDialect = H2Dialect() From 064ba5eed57d5ae1ff1a943346dd90113c415c1e Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Fri, 18 Jul 2025 09:59:10 +0200 Subject: [PATCH 4/9] resolve comments --- .../v2/join/OracleJoinPushdownIntegrationSuite.scala | 12 ------------ .../v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala | 12 +++++++++++- .../spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala | 8 -------- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala index b86b2b800080d..8c4665608fdfe 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.jdbc.v2.join import java.sql.Connection import java.util.Locale -import org.apache.spark.SparkConf -import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase import org.apache.spark.tags.DockerTest @@ -57,22 +55,12 @@ import org.apache.spark.tags.DockerTest class OracleJoinPushdownIntegrationSuite extends DockerJDBCIntegrationSuite 
with JDBCV2JoinPushdownIntegrationSuiteBase { - override val catalogName: String = "oracle" - override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new OracleDatabaseOnDocker override val url = db.getJdbcUrl(dockerIp, externalPort) - override def sparkConf: SparkConf = super.sparkConf - .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) - .set(s"spark.sql.catalog.$catalogName.url", url) - .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") - .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") - .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") - .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") - override val jdbcDialect: JdbcDialect = OracleDialect() override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index 9bac5c31d6e4e..b8019095b304c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.{Connection, DriverManager} import java.util.Properties +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.connector.DataSourcePushdownTestUtils import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialect import org.apache.spark.sql.test.SharedSparkSession @@ -32,7 +34,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase extends QueryTest with SharedSparkSession with DataSourcePushdownTestUtils { - val catalogName: String + val catalogName: String = "join_pushdown_catalog" val namespaceOpt: Option[String] = None val url: String @@ -41,6 +43,14 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val jdbcDialect: JdbcDialect + override def sparkConf: SparkConf = super.sparkConf + .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) + .set(s"spark.sql.catalog.$catalogName.url", url) + .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") + .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") + private def catalogAndNamespace = namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala index 3b897b5a2b39f..804dc1fc761e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -22,7 +22,6 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} import org.apache.spark.sql.connector.DataSourcePushdownTestUtils -import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect} import 
org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -36,19 +35,12 @@ class JDBCV2JoinPushdownSuite val tempDir = Utils.createTempDir() override val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" - override val catalogName: String = "h2" override val namespaceOpt: Option[String] = Some("test") override val jdbcDialect: JdbcDialect = H2Dialect() override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) - .set("spark.sql.catalog.h2.url", url) .set("spark.sql.catalog.h2.driver", "org.h2.Driver") - .set("spark.sql.catalog.h2.pushDownAggregate", "true") - .set("spark.sql.catalog.h2.pushDownLimit", "true") - .set("spark.sql.catalog.h2.pushDownOffset", "true") - .set("spark.sql.catalog.h2.pushDownJoin", "true") override def qualifyTableName(tableName: String): String = namespaceOpt .map(namespace => s""""$namespace"."$tableName"""").getOrElse(s""""$tableName"""") From e5ec574e00a198f2b0cb81ab3cdb98faa8f4519c Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Fri, 18 Jul 2025 11:39:06 +0200 Subject: [PATCH 5/9] resolve comments --- .../OracleJoinPushdownIntegrationSuite.scala | 12 +- .../DataSourcePushdownTestUtils.scala | 222 +++++++++++------ ...BCV2JoinPushdownIntegrationSuiteBase.scala | 224 +++++++++++------- .../sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala | 15 +- 4 files changed, 302 insertions(+), 171 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala index 8c4665608fdfe..28d119e859908 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect} import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase +import org.apache.spark.sql.types.DataTypes import org.apache.spark.tags.DockerTest /** @@ -55,7 +56,7 @@ import org.apache.spark.tags.DockerTest class OracleJoinPushdownIntegrationSuite extends DockerJDBCIntegrationSuite with JDBCV2JoinPushdownIntegrationSuiteBase { - override val namespaceOpt: Option[String] = Some("SYSTEM") + override val namespace: String = "SYSTEM" override val db = new OracleDatabaseOnDocker @@ -63,7 +64,14 @@ class OracleJoinPushdownIntegrationSuite override val jdbcDialect: JdbcDialect = OracleDialect() + override val integerType = DataTypes.createDecimalType(10, 0) + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) - override def schemaPreparation(connection: Connection): Unit = {} + override def schemaPreparation(): Unit = {} + + // This method comes from DockerJDBCIntegrationSuite + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala index ac78faa09f22c..2816eb79f1f82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala @@ -22,127 +22,173 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { + protected val supportsSamplePushdown: Boolean = true + + protected val supportsFilterPushdown: Boolean = true + + protected val supportsLimitPushdown: Boolean = true + + protected val supportsAggregatePushdown: Boolean = true + + protected val supportsSortPushdown: Boolean = true + + protected val supportsOffsetPushdown: Boolean = true + + protected val supportsColumnPruning: Boolean = true + + protected val supportsJoinPushdown: Boolean = true + + protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { - val sample = df.queryExecution.optimizedPlan.collect { - case s: Sample => s - } - if (pushed) { - assert(sample.isEmpty) - } else { - assert(sample.nonEmpty) + if (supportsSamplePushdown) { + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } } } protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { - val filter = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - if (pushed) { - assert(filter.isEmpty) - } else { - assert(filter.nonEmpty) + if (supportsFilterPushdown) { + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } } } protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val limit = df.queryExecution.optimizedPlan.collect { - case l: LocalLimit => l - case g: GlobalLimit => g - } - if (pushed) { - assert(limit.isEmpty) - } else { - assert(limit.nonEmpty) + if (supportsLimitPushdown) { + val limit = df.queryExecution.optimizedPlan.collect { + case l: LocalLimit => l + case g: GlobalLimit => g + } + if (pushed) { + assert(limit.isEmpty) + } else { + assert(limit.nonEmpty) + } } } protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit == limit) + if (supportsLimitPushdown) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.limit == limit) + } } } } protected def checkColumnPruned(df: DataFrame, col: String): Unit = { - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq(col))) + if (supportsColumnPruning) { + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq(col))) + } } protected def checkAggregateRemoved(df: DataFrame): Unit = { - val aggregates = df.queryExecution.optimizedPlan.collect { - case agg: Aggregate => agg + if (supportsAggregatePushdown) { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + 
assert(aggregates.isEmpty) } - assert(aggregates.isEmpty) } protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _, _, _) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions.exists { expr => - expr.isInstanceOf[GeneralAggregateFunc] && - expr.asInstanceOf[GeneralAggregateFunc].name() == funcName - }) + if (supportsAggregatePushdown) { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _, _, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.exists { expr => + expr.isInstanceOf[GeneralAggregateFunc] && + expr.asInstanceOf[GeneralAggregateFunc].name() == funcName + }) + } } } - protected def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val sorts = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } + protected def checkSortRemoved( + df: DataFrame, + pushed: Boolean = true): Unit = { + if (supportsSortPushdown) { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } - if (pushed) { - assert(sorts.isEmpty) - } else { - assert(sorts.nonEmpty) + if (pushed) { + assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) + } } } - protected def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - val offsets = df.queryExecution.optimizedPlan.collect { - case o: Offset => o - } + protected def checkOffsetRemoved( + df: DataFrame, + pushed: Boolean = true): Unit = { + if (supportsOffsetPushdown) { + val offsets = df.queryExecution.optimizedPlan.collect { + case o: Offset => o + } - if (pushed) { - assert(offsets.isEmpty) - } else { - assert(offsets.nonEmpty) + if (pushed) { + assert(offsets.isEmpty) + } else { + assert(offsets.nonEmpty) + } } } protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.offset == offset) + if (supportsOffsetPushdown) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.offset == offset) + } } } } protected def checkJoinNotPushed(df: DataFrame): Unit = { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j + if (supportsJoinPushdown) { + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.nonEmpty, "Join should not be pushed down") } - assert(joinNodes.nonEmpty, "Join should not be pushed down") } protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - assert(joinNodes.isEmpty, "Join should be pushed down") - if (expectedTables.nonEmpty) { - checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") + if (supportsJoinPushdown) { + val joinNodes = 
df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.isEmpty, "Join should be pushed down") + if (expectedTables.nonEmpty) { + checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") + } } } @@ -154,4 +200,34 @@ trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { } } } + + /** + * Check if the output schema of dataframe {@code df} is same as {@code schema}. There is one + * limitation: if expected schema name is empty, assertion on same names will be skipped. + *
+ * For example, it is not really possible to use {@code checkPrunedColumns} for join pushdown, + * because in case of duplicate names, columns will have random UUID suffixes. For this reason, + * the best we can do is test that the size is same, and other fields beside names do match. + */ + protected def checkPrunedColumnsDataTypeAndNullability( + df: DataFrame, + schema: StructType): Unit = { + if (supportsColumnPruning) { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + val dfSchema = v1.readSchema() + + assert(dfSchema.length == schema.length) + dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) => + if (f2.name.nonEmpty) { + assert(f1.name == f2.name) + } + assert(f1.dataType == f2.dataType) + assert(f1.nullable == f2.nullable) + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index b8019095b304c..4d46091fb2a89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialect import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} trait JDBCV2JoinPushdownIntegrationSuiteBase extends QueryTest with SharedSparkSession with DataSourcePushdownTestUtils { val catalogName: String = "join_pushdown_catalog" - val namespaceOpt: Option[String] = None + val namespace: String = "join_schema" val url: String val joinTableName1: String = "join_table_1" @@ -51,11 +51,11 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") - private def catalogAndNamespace = - namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + private def catalogAndNamespace = s"$catalogName.$namespace" - def qualifyTableName(tableName: String): String = namespaceOpt - .map(namespace => s"$namespace.$tableName").getOrElse(tableName) + def qualifyTableName(tableName: String): String = s"$namespace.$tableName" + + def qualifySchemaName(schemaName: String): String = namespace private lazy val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) @@ -76,75 +76,103 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase } } - def dataPreparation(connection: Connection): Unit = { - schemaPreparation(connection) - tablePreparation(connection) - fillJoinTables(connection) + protected val integerType = DataTypes.IntegerType + + protected val stringType = DataTypes.StringType + + protected val decimalType = DataTypes.createDecimalType(10, 2) + + /** + * This method should cover the following: + *
+   * <ul>
+   *   <li>Create the schema where testing tables will be stored.</li>
+   *   <li>Create the testing tables {@code joinTableName1} and {@code joinTableName2}
+   *   in above schema.</li>
+   *   <li>Populate the tables with the data.</li>
+   * </ul>
+ */ + def dataPreparation(): Unit = { + schemaPreparation() + tablePreparation() + fillJoinTables() } - def schemaPreparation(connection: Connection): Unit = { - connection.prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${namespaceOpt.get}").executeUpdate() + def schemaPreparation(): Unit = { + withConnection {conn => + conn + .prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${qualifySchemaName(namespace)}") + .executeUpdate() + } } - def tablePreparation(connection: Connection): Unit = { - connection.prepareStatement( - s"""CREATE TABLE $fullyQualifiedTableName1 ( - | ID ${getJDBCTypeString(DataTypes.IntegerType)}, - | AMOUNT ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, - | ADDRESS ${getJDBCTypeString(DataTypes.StringType)} - |)""".stripMargin - ).executeUpdate() - - connection.prepareStatement( - s"""CREATE TABLE $fullyQualifiedTableName2 ( - | ID ${getJDBCTypeString(DataTypes.IntegerType)}, - | NEXT_ID ${getJDBCTypeString(DataTypes.IntegerType)}, - | SALARY ${getJDBCTypeString(DataTypes.createDecimalType(10, 2))}, - | SURNAME ${getJDBCTypeString(DataTypes.StringType)} - |)""".stripMargin - ).executeUpdate() + def tablePreparation(): Unit = { + withConnection{ conn => + conn.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName1 ( + | ID ${getJDBCTypeString(integerType)}, + | AMOUNT ${getJDBCTypeString(decimalType)}, + | ADDRESS ${getJDBCTypeString(stringType)} + |)""".stripMargin + ).executeUpdate() + + conn.prepareStatement( + s"""CREATE TABLE $fullyQualifiedTableName2 ( + | ID ${getJDBCTypeString(integerType)}, + | NEXT_ID ${getJDBCTypeString(integerType)}, + | SALARY ${getJDBCTypeString(decimalType)}, + | SURNAME ${getJDBCTypeString(stringType)} + |)""".stripMargin + ).executeUpdate() + } } - def fillJoinTables(connection: Connection): Unit = { - val random = new java.util.Random(42) - val table1Data = (1 to 100).map { i => - val id = i % 11 - val amount = BigDecimal.valueOf(random.nextDouble() * 10000) - .setScale(2, BigDecimal.RoundingMode.HALF_UP) - val address = s"address_$i" - (id, amount, address) - } - val table2Data = (1 to 100).map { i => - val id = (i % 17) - val next_id = (id + 1) % 17 - val salary = BigDecimal.valueOf(random.nextDouble() * 50000) - .setScale(2, BigDecimal.RoundingMode.HALF_UP) - val surname = s"surname_$i" - (id, next_id, salary, surname) - } - - val insertStmt1 = connection.prepareStatement( - s"INSERT INTO $fullyQualifiedTableName1 (id, amount, address) VALUES (?, ?, ?)" - ) - table1Data.foreach { case (id, amount, address) => - insertStmt1.setInt(1, id) - insertStmt1.setBigDecimal(2, amount.bigDecimal) - insertStmt1.setString(3, address) - insertStmt1.executeUpdate() - } - insertStmt1.close() - - val insertStmt2 = connection.prepareStatement( - s"INSERT INTO $fullyQualifiedTableName2 (id, next_id, salary, surname) VALUES (?, ?, ?, ?)" - ) - table2Data.foreach { case (id, next_id, salary, surname) => - insertStmt2.setInt(1, id) - insertStmt2.setInt(2, next_id) - insertStmt2.setBigDecimal(3, salary.bigDecimal) - insertStmt2.setString(4, surname) - insertStmt2.executeUpdate() - } - insertStmt2.close() + private val random = new java.util.Random(42) + + private val table1Data = (1 to 100).map { i => + val id = i % 11 + val amount = BigDecimal.valueOf(random.nextDouble() * 10000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val address = s"address_$i" + (id, amount, address) + } + + private val table2Data = (1 to 100).map { i => + val id = (i % 17) + val next_id = (id + 1) % 17 + val salary = BigDecimal.valueOf(random.nextDouble() 
* 50000) + .setScale(2, BigDecimal.RoundingMode.HALF_UP) + val surname = s"surname_$i" + (id, next_id, salary, surname) + } + + def fillJoinTables(): Unit = { + withConnection { conn => + val insertStmt1 = conn.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName1 (id, amount, address) VALUES (?, ?, ?)" + ) + table1Data.foreach { case (id, amount, address) => + insertStmt1.setInt(1, id) + insertStmt1.setBigDecimal(2, amount.bigDecimal) + insertStmt1.setString(3, address) + insertStmt1.addBatch() + } + insertStmt1.executeBatch() + insertStmt1.close() + + val insertStmt2 = conn.prepareStatement( + s"INSERT INTO $fullyQualifiedTableName2 (id, next_id, salary, surname) VALUES (?, ?, ?, ?)" + ) + table2Data.foreach { case (id, next_id, salary, surname) => + insertStmt2.setInt(1, id) + insertStmt2.setInt(2, next_id) + insertStmt2.setBigDecimal(3, salary.bigDecimal) + insertStmt2.setString(4, surname) + insertStmt2.addBatch() + } + insertStmt2.executeBatch() + insertStmt2.close() + + } } // Condition-less joins are not supported in join pushdown @@ -204,7 +232,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) checkAnswer(df, rows) @@ -228,8 +256,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) - checkJoinPushed(df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + checkJoinPushed( + df, + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) @@ -252,9 +281,17 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) + val expectedSchemaWithoutNames = StructType( + Seq( + StructField("", integerType), // ID + StructField("", integerType), // NEXT_ID + StructField("AMOUNT", decimalType) // AMOUNT + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) checkAnswer(df, rows) @@ -276,10 +313,17 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) + val expectedSchemaWithoutNames = StructType( + Seq( + StructField("ID", integerType), // ID + StructField("NEXT_ID", integerType) // NEXT_ID + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + - s" $catalogAndNamespace.${caseConvert(joinTableName2)}" + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" ) checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, NEXT_ID IS NOT NULL, ID = NEXT_ID]") @@ -302,9 +346,21 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) + val expectedSchemaWithoutNames = StructType( + Seq( + StructField("", integerType), // ID_UUID + StructField("", decimalType), // AMOUNT_UUID + StructField("", 
integerType), // ID_UUID + StructField("", decimalType), // AMOUNT_UUID + StructField("ADDRESS", stringType), // ADDRESS + StructField("ID", integerType), // ID + StructField("AMOUNT", decimalType) // AMOUNT + ) + ) + checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") checkAnswer(df, rows) @@ -347,7 +403,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase checkJoinPushed( joinDf, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) checkAnswer(joinDf, rows) @@ -371,7 +427,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase checkAggregateRemoved(df) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) @@ -396,7 +452,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}," + s" $catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}") checkAnswer(df, rows) @@ -426,7 +482,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + s"$catalogAndNamespace.${caseConvert(joinTableName1)}" ) checkAnswer(df, rows) @@ -450,8 +506,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val df = sql(sqlQuery) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" ) checkFilterPushed(df) checkAnswer(df, rows) @@ -475,8 +531,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val df = sql(sqlQuery) checkJoinPushed( df, - s"$catalogAndNamespace.${caseConvert(joinTableName1)}", - s"$catalogAndNamespace.${caseConvert(joinTableName2)}" + expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", + s"$catalogAndNamespace.${caseConvert(joinTableName2)}" ) checkAnswer(df, rows) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala index 804dc1fc761e4..90160fb1b305a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.jdbc.v2 -import java.sql.Connection - import org.apache.spark.SparkConf import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} import org.apache.spark.sql.connector.DataSourcePushdownTestUtils @@ -35,26 +33,19 @@ class JDBCV2JoinPushdownSuite val tempDir = Utils.createTempDir() override val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" - override val namespaceOpt: Option[String] = Some("test") - override val 
jdbcDialect: JdbcDialect = H2Dialect() override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.h2.driver", "org.h2.Driver") - override def qualifyTableName(tableName: String): String = namespaceOpt - .map(namespace => s""""$namespace"."$tableName"""").getOrElse(s""""$tableName"""") + override def qualifyTableName(tableName: String): String = s""""$namespace"."$tableName"""" - override def schemaPreparation(connection: Connection): Unit = { - connection - .prepareStatement(s"""CREATE SCHEMA IF NOT EXISTS "${namespaceOpt.get}"""") - .executeUpdate() - } + override def qualifySchemaName(schemaName: String): String = s""""$namespace"""" override def beforeAll(): Unit = { Utils.classForName("org.h2.Driver") super.beforeAll() - withConnection(dataPreparation) + dataPreparation() } override def afterAll(): Unit = { From 274a612a1ccb9bcb1e8759b7ce70c6214490f952 Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Fri, 18 Jul 2025 18:09:57 +0200 Subject: [PATCH 6/9] use quoteIdentifier from dialect --- .../OracleJoinPushdownIntegrationSuite.scala | 2 +- ...BCV2JoinPushdownIntegrationSuiteBase.scala | 108 ++++++++++-------- .../sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala | 8 +- 3 files changed, 64 insertions(+), 54 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala index 28d119e859908..ecc0c5489bceb 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala @@ -66,7 +66,7 @@ class OracleJoinPushdownIntegrationSuite override val integerType = DataTypes.createDecimalType(10, 0) - override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT) override def schemaPreparation(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index 4d46091fb2a89..c7f8675187548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -50,12 +50,20 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true") .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true") + .set(s"spark.sql.catalog.$catalogName.caseSensitive", "false") - private def catalogAndNamespace = s"$catalogName.$namespace" + private def catalogAndNamespace = s"$catalogName.${caseConvert(namespace)}" + private def casedJoinTableName1 = caseConvert(joinTableName1) + private def casedJoinTableName2 = caseConvert(joinTableName2) - def qualifyTableName(tableName: String): String = s"$namespace.$tableName" + def qualifyTableName(tableName: String): String = { + val fullyQualifiedCasedNamespace = jdbcDialect.quoteIdentifier(caseConvert(namespace)) + val fullyQualifiedCasedTableName = jdbcDialect.quoteIdentifier(caseConvert(tableName)) + 
s"$fullyQualifiedCasedNamespace.$fullyQualifiedCasedTableName" + } - def qualifySchemaName(schemaName: String): String = namespace + def quoteSchemaName(schemaName: String): String = + jdbcDialect.quoteIdentifier(caseConvert(namespace)) private lazy val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1) @@ -65,7 +73,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase() } - protected def caseConvert(tableName: String): String = tableName + protected def caseConvert(identifier: String): String = identifier protected def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -100,7 +108,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase def schemaPreparation(): Unit = { withConnection {conn => conn - .prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${qualifySchemaName(namespace)}") + .prepareStatement(s"CREATE SCHEMA IF NOT EXISTS ${quoteSchemaName(namespace)}") .executeUpdate() } } @@ -180,8 +188,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val sqlQuery = s""" |SELECT * FROM - |$catalogAndNamespace.${caseConvert(joinTableName1)} a, - |$catalogAndNamespace.${caseConvert(joinTableName1)} b + |$catalogAndNamespace.$casedJoinTableName1 a, + |$catalogAndNamespace.$casedJoinTableName1 b |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -200,9 +208,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test that multi-way join without condition should not have join pushed down") { val sqlQuery = s""" |SELECT * FROM - |$catalogAndNamespace.${caseConvert(joinTableName1)} a, - |$catalogAndNamespace.${caseConvert(joinTableName1)} b, - |$catalogAndNamespace.${caseConvert(joinTableName1)} c + |$catalogAndNamespace.$casedJoinTableName1 a, + |$catalogAndNamespace.$casedJoinTableName1 b, + |$catalogAndNamespace.$casedJoinTableName1 c |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -219,8 +227,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test self join with condition") { val sqlQuery = s""" - |SELECT * FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b + |SELECT * FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b |ON a.id = b.id + 1""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -242,9 +250,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test multi-way self join with conditions") { val sqlQuery = s""" |SELECT * FROM - |$catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -269,8 +277,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test self join with column pruning") { val sqlQuery = s""" |SELECT a.id + 2, b.id, b.amount FROM - |$catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN 
$catalogAndNamespace.$casedJoinTableName1 b |ON a.id = b.id + 1 |""".stripMargin @@ -285,7 +293,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase Seq( StructField("", integerType), // ID StructField("", integerType), // NEXT_ID - StructField("AMOUNT", decimalType) // AMOUNT + StructField(caseConvert("amount"), decimalType) // AMOUNT ) ) checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) @@ -301,8 +309,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test 2-way join with column pruning - different tables") { val sqlQuery = s""" |SELECT a.id, b.next_id FROM - |$catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} b + |$catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName2 b |ON a.id = b.next_id |""".stripMargin @@ -315,8 +323,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val expectedSchemaWithoutNames = StructType( Seq( - StructField("ID", integerType), // ID - StructField("NEXT_ID", integerType) // NEXT_ID + StructField(caseConvert("id"), integerType), // ID + StructField(caseConvert("next_id"), integerType) // NEXT_ID ) ) checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) @@ -326,7 +334,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase s"$catalogAndNamespace.${caseConvert(joinTableName2)}" ) checkPushedInfo(df, - "PushedFilters: [ID IS NOT NULL, NEXT_ID IS NOT NULL, ID = NEXT_ID]") + s"PushedFilters: [${caseConvert("id")} IS NOT NULL, " + + s"${caseConvert("next_id")} IS NOT NULL, " + + s"${caseConvert("id")} = ${caseConvert("next_id")}]") checkAnswer(df, rows) } } @@ -334,9 +344,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test multi-way self join with column pruning") { val sqlQuery = s""" |SELECT a.id, b.*, c.id, c.amount + a.amount - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -352,9 +362,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase StructField("", decimalType), // AMOUNT_UUID StructField("", integerType), // ID_UUID StructField("", decimalType), // AMOUNT_UUID - StructField("ADDRESS", stringType), // ADDRESS - StructField("ID", integerType), // ID - StructField("AMOUNT", decimalType) // AMOUNT + StructField(caseConvert("address"), stringType), // ADDRESS + StructField(caseConvert("id"), integerType), // ID + StructField(caseConvert("amount"), decimalType) // AMOUNT ) ) checkPrunedColumnsDataTypeAndNullability(df, expectedSchemaWithoutNames) @@ -370,11 +380,11 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test aliases not supported in join pushdown") { val sqlQuery = s""" |SELECT a.id, bc.* - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a + |FROM $catalogAndNamespace.$casedJoinTableName1 a |JOIN ( | SELECT b.*, c.id AS c_id, c.amount AS c_amount - | FROM $catalogAndNamespace.${caseConvert(joinTableName1)} b - | JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + | FROM $catalogAndNamespace.$casedJoinTableName1 b + | JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 |) bc ON bc.id = a.id + 1 
|""".stripMargin @@ -391,8 +401,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase } test("Test join with dataframe with duplicated columns") { - val df1 = sql(s"SELECT id FROM $catalogAndNamespace.${caseConvert(joinTableName1)}") - val df2 = sql(s"SELECT id, id FROM $catalogAndNamespace.${caseConvert(joinTableName1)}") + val df1 = sql(s"SELECT id FROM $catalogAndNamespace.$casedJoinTableName1") + val df2 = sql(s"SELECT id, id FROM $catalogAndNamespace.$casedJoinTableName1") val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { df1.join(df2, "id").collect().toSeq @@ -413,8 +423,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test aggregate on top of 2-way self join") { val sqlQuery = s""" |SELECT min(a.id + b.id), min(a.id) - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON a.id = b.id + 1 + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON a.id = b.id + 1 |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -438,9 +448,9 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test aggregate on top of multi-way self join") { val sqlQuery = s""" |SELECT min(a.id + b.id), min(a.id), min(c.id - 2) - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} c ON c.id = b.id - 1 + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 + |JOIN $catalogAndNamespace.$casedJoinTableName1 c ON c.id = b.id - 1 |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -462,8 +472,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase test("Test sort limit on top of join is pushed down") { val sqlQuery = s""" |SELECT min(a.id + b.id), a.id, b.id - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} a - |JOIN $catalogAndNamespace.${caseConvert(joinTableName1)} b ON b.id = a.id + 1 + |FROM $catalogAndNamespace.$casedJoinTableName1 a + |JOIN $catalogAndNamespace.$casedJoinTableName1 b ON b.id = a.id + 1 |GROUP BY a.id, b.id |ORDER BY a.id |LIMIT 1 @@ -493,8 +503,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val sqlQuery = s""" |SELECT t1.id, t1.address, t2.surname, t1.amount, t2.salary - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id |WHERE t1.amount > 5000 AND t2.salary > 25000 |""".stripMargin @@ -518,8 +528,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val sqlQuery = s""" |SELECT t1.id, t1.address, t2.surname, t1.amount + t2.salary as total - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |JOIN $catalogAndNamespace.$casedJoinTableName2 t2 |ON t1.id = t2.id AND t1.amount > 1000 |""".stripMargin @@ -542,8 +552,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val sqlQuery = s""" |SELECT t1.id, t1.address, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |LEFT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |FROM 
$catalogAndNamespace.$casedJoinTableName1 t1 + |LEFT JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { @@ -561,8 +571,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase val sqlQuery = s""" |SELECT t1.id, t1.address, t2.surname - |FROM $catalogAndNamespace.${caseConvert(joinTableName1)} t1 - |RIGHT JOIN $catalogAndNamespace.${caseConvert(joinTableName2)} t2 ON t1.id = t2.id + |FROM $catalogAndNamespace.$casedJoinTableName1 t1 + |RIGHT JOIN $catalogAndNamespace.$casedJoinTableName2 t2 ON t1.id = t2.id |""".stripMargin val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala index 90160fb1b305a..026fcf3126302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.jdbc.v2 +import java.util.Locale + import org.apache.spark.SparkConf import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} import org.apache.spark.sql.connector.DataSourcePushdownTestUtils @@ -36,11 +38,9 @@ class JDBCV2JoinPushdownSuite override val jdbcDialect: JdbcDialect = H2Dialect() override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.h2.driver", "org.h2.Driver") - - override def qualifyTableName(tableName: String): String = s""""$namespace"."$tableName"""" + .set(s"spark.sql.catalog.$catalogName.driver", "org.h2.Driver") - override def qualifySchemaName(schemaName: String): String = s""""$namespace"""" + override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT) override def beforeAll(): Unit = { Utils.classForName("org.h2.Driver") From 206bf5f74e5bddf511d3b6524048877f80f0da7d Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Mon, 21 Jul 2025 14:27:10 +0200 Subject: [PATCH 7/9] enable ansi: --- .../sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index c7f8675187548..244ae40c48a9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -45,6 +45,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase override def sparkConf: SparkConf = super.sparkConf .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName) + .set(SQLConf.ANSI_ENABLED.key, "true") .set(s"spark.sql.catalog.$catalogName.url", url) .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true") .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true") From cf1275eb737d531653dffb40d4e26a1546677d61 Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Mon, 21 Jul 2025 16:34:05 +0200 Subject: [PATCH 8/9] register h2 dialect --- .../apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala index 026fcf3126302..6304c1a1c54f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownSuite.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.SparkConf import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest} import org.apache.spark.sql.connector.DataSourcePushdownTestUtils -import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect} +import org.apache.spark.sql.jdbc.{H2Dialect, JdbcDialect, JdbcDialects} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -46,6 +46,11 @@ class JDBCV2JoinPushdownSuite Utils.classForName("org.h2.Driver") super.beforeAll() dataPreparation() + // Registering the dialect because of CI running multiple tests. For example, in + // QueryExecutionErrorsSuite H2 dialect is being registered, and somewhere it is + // not registered back. The suite should be fixed, but to be safe for now, we are + // always registering H2 dialect before test execution. + JdbcDialects.registerDialect(H2Dialect()) } override def afterAll(): Unit = { From 1a78dd757fc1a919c43345843357e54af956b038 Mon Sep 17 00:00:00 2001 From: Petar Vasiljevic Date: Tue, 22 Jul 2025 00:42:14 +0200 Subject: [PATCH 9/9] move supportsXYZ to JDBCV2JoinPushdownSuiteBase --- .../DataSourcePushdownTestUtils.scala | 213 +++++++----------- ...BCV2JoinPushdownIntegrationSuiteBase.scala | 22 +- 2 files changed, 105 insertions(+), 130 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala index 2816eb79f1f82..7b8774980d2cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala @@ -25,170 +25,133 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { - protected val supportsSamplePushdown: Boolean = true - - protected val supportsFilterPushdown: Boolean = true - - protected val supportsLimitPushdown: Boolean = true - - protected val supportsAggregatePushdown: Boolean = true - - protected val supportsSortPushdown: Boolean = true - - protected val supportsOffsetPushdown: Boolean = true - - protected val supportsColumnPruning: Boolean = true - - protected val supportsJoinPushdown: Boolean = true - - protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { - if (supportsSamplePushdown) { - val sample = df.queryExecution.optimizedPlan.collect { - case s: Sample => s - } - if (pushed) { - assert(sample.isEmpty) - } else { - assert(sample.nonEmpty) - } + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) } } protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { - if (supportsFilterPushdown) { - val filter = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - if (pushed) { - assert(filter.isEmpty) - } else { - assert(filter.nonEmpty) - } + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) } } 
protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = { - if (supportsLimitPushdown) { - val limit = df.queryExecution.optimizedPlan.collect { - case l: LocalLimit => l - case g: GlobalLimit => g - } - if (pushed) { - assert(limit.isEmpty) - } else { - assert(limit.nonEmpty) - } + val limit = df.queryExecution.optimizedPlan.collect { + case l: LocalLimit => l + case g: GlobalLimit => g + } + if (pushed) { + assert(limit.isEmpty) + } else { + assert(limit.nonEmpty) } } protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { - if (supportsLimitPushdown) { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit == limit) - } + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.limit == limit) } } } protected def checkColumnPruned(df: DataFrame, col: String): Unit = { - if (supportsColumnPruning) { - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq(col))) - } + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq(col))) } - protected def checkAggregateRemoved(df: DataFrame): Unit = { - if (supportsAggregatePushdown) { - val aggregates = df.queryExecution.optimizedPlan.collect { - case agg: Aggregate => agg - } + protected def checkAggregateRemoved(df: DataFrame, pushed: Boolean = true): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + if (pushed) { assert(aggregates.isEmpty) + } else { + assert(aggregates.nonEmpty) } } protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { - if (supportsAggregatePushdown) { - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _, _, _) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions.exists { expr => - expr.isInstanceOf[GeneralAggregateFunc] && - expr.asInstanceOf[GeneralAggregateFunc].name() == funcName - }) - } + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _, _, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.exists { expr => + expr.isInstanceOf[GeneralAggregateFunc] && + expr.asInstanceOf[GeneralAggregateFunc].name() == funcName + }) } } protected def checkSortRemoved( df: DataFrame, pushed: Boolean = true): Unit = { - if (supportsSortPushdown) { - val sorts = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } - if (pushed) { - assert(sorts.isEmpty) - } else { - assert(sorts.nonEmpty) - } + if (pushed) { + assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) } } protected def checkOffsetRemoved( df: DataFrame, pushed: Boolean = true): Unit = { - if 
(supportsOffsetPushdown) { - val offsets = df.queryExecution.optimizedPlan.collect { - case o: Offset => o - } + val offsets = df.queryExecution.optimizedPlan.collect { + case o: Offset => o + } - if (pushed) { - assert(offsets.isEmpty) - } else { - assert(offsets.nonEmpty) - } + if (pushed) { + assert(offsets.isEmpty) + } else { + assert(offsets.nonEmpty) } } protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = { - if (supportsOffsetPushdown) { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.offset == offset) - } + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.offset == offset) } } } protected def checkJoinNotPushed(df: DataFrame): Unit = { - if (supportsJoinPushdown) { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - assert(joinNodes.nonEmpty, "Join should not be pushed down") + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j } + assert(joinNodes.nonEmpty, "Join should not be pushed down") } protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = { - if (supportsJoinPushdown) { - val joinNodes = df.queryExecution.optimizedPlan.collect { - case j: Join => j - } - assert(joinNodes.isEmpty, "Join should be pushed down") - if (expectedTables.nonEmpty) { - checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") - } + val joinNodes = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joinNodes.isEmpty, "Join should be pushed down") + if (expectedTables.nonEmpty) { + checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]") } } @@ -212,21 +175,19 @@ trait DataSourcePushdownTestUtils extends ExplainSuiteHelper { protected def checkPrunedColumnsDataTypeAndNullability( df: DataFrame, schema: StructType): Unit = { - if (supportsColumnPruning) { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - val dfSchema = v1.readSchema() - - assert(dfSchema.length == schema.length) - dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) => - if (f2.name.nonEmpty) { - assert(f1.name == f2.name) - } - assert(f1.dataType == f2.dataType) - assert(f1.nullable == f2.nullable) + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + val dfSchema = v1.readSchema() + + assert(dfSchema.length == schema.length) + dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) => + if (f2.name.nonEmpty) { + assert(f1.name == f2.name) } - } + assert(f1.dataType == f2.dataType) + assert(f1.nullable == f2.nullable) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala index 244ae40c48a9d..27ca84bc2a2a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala @@ -184,6 +184,20 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase } } + protected val supportsFilterPushdown: Boolean = true + + protected val supportsLimitPushdown: Boolean = 
true + + protected val supportsAggregatePushdown: Boolean = true + + protected val supportsSortPushdown: Boolean = true + + protected val supportsOffsetPushdown: Boolean = true + + protected val supportsColumnPruning: Boolean = true + + protected val supportsJoinPushdown: Boolean = true + // Condition-less joins are not supported in join pushdown test("Test that 2-way join without condition should not have join pushed down") { val sqlQuery = @@ -435,7 +449,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) - checkAggregateRemoved(df) + checkAggregateRemoved(df, supportsAggregatePushdown) checkJoinPushed( df, expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}, " + @@ -488,8 +502,8 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") { val df = sql(sqlQuery) - checkSortRemoved(df) - checkLimitRemoved(df) + checkSortRemoved(df, supportsSortPushdown) + checkLimitRemoved(df, supportsLimitPushdown) checkJoinPushed( df, @@ -520,7 +534,7 @@ trait JDBCV2JoinPushdownIntegrationSuiteBase expectedTables = s"$catalogAndNamespace.${caseConvert(joinTableName1)}", s"$catalogAndNamespace.${caseConvert(joinTableName2)}" ) - checkFilterPushed(df) + checkFilterPushed(df, supportsFilterPushdown) checkAnswer(df, rows) } }
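
Illustrative note (not part of the patch series above): the refactored JDBCV2JoinPushdownIntegrationSuiteBase is meant to let a dialect-specific suite override only the connection details, identifier casing, and the supportsXYZ capability flags, as the Oracle and H2 suites in this series do. The sketch below shows what such a suite could look like under that assumption; the class name, the PostgreSQL URL, and the disabled offset flag are hypothetical and only for illustration.

package org.apache.spark.sql.jdbc.v2

import java.util.Locale

import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest}
import org.apache.spark.sql.connector.DataSourcePushdownTestUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.test.SharedSparkSession

// Sketch only: assumes the members introduced in this patch series
// (url, jdbcDialect, caseConvert, the supportsXYZ flags) keep the
// signatures shown in the hunks above.
class PostgresJoinPushdownSketchSuite
  extends QueryTest
  with SharedSparkSession
  with ExplainSuiteHelper
  with DataSourcePushdownTestUtils
  with JDBCV2JoinPushdownIntegrationSuiteBase {

  // Hypothetical connection string; a real suite would point at a test container.
  override val url: String =
    "jdbc:postgresql://localhost:5432/postgres?user=test&password=test"

  // Resolve the dialect from the URL rather than naming a dialect class directly.
  override val jdbcDialect: JdbcDialect = JdbcDialects.get(url)

  // PostgreSQL folds unquoted identifiers to lower case, so expected table names,
  // pushed-filter strings, and pruned schemas are compared in lower case.
  override def caseConvert(identifier: String): String =
    identifier.toLowerCase(Locale.ROOT)

  // Example of opting out of one capability flag: check helpers that receive a
  // false flag assert the corresponding operator stays in Spark's optimized plan
  // instead of being pushed to the database.
  override protected val supportsOffsetPushdown: Boolean = false
}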