diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index ee4a479960145..fb18dca97994e 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -22,20 +22,21 @@ import org.apache.logging.log4j.Level
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sample, Sort}
+import org.apache.spark.sql.connector.DataSourcePushdownTestUtils
import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.NullOrdering
-import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.DockerTest
@DockerTest
-private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite {
+private[v2] trait V2JDBCTest
+ extends DataSourcePushdownTestUtils
+ with DockerIntegrationFunSuite
+ with SharedSparkSession {
import testImplicits._
val catalogName: String
@@ -468,56 +469,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
def supportsTableSample: Boolean = false
- private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = {
- val sample = df.queryExecution.optimizedPlan.collect {
- case s: Sample => s
- }
- if (pushed) {
- assert(sample.isEmpty)
- } else {
- assert(sample.nonEmpty)
- }
- }
-
- protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
- val filter = df.queryExecution.optimizedPlan.collect {
- case f: Filter => f
- }
- if (pushed) {
- assert(filter.isEmpty)
- } else {
- assert(filter.nonEmpty)
- }
- }
-
- protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
- val limit = df.queryExecution.optimizedPlan.collect {
- case l: LocalLimit => l
- case g: GlobalLimit => g
- }
- if (pushed) {
- assert(limit.isEmpty)
- } else {
- assert(limit.nonEmpty)
- }
- }
-
- private def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = {
- df.queryExecution.optimizedPlan.collect {
- case relation: DataSourceV2ScanRelation => relation.scan match {
- case v1: V1ScanWrapper =>
- assert(v1.pushedDownOperators.limit == limit)
- }
- }
- }
-
- private def checkColumnPruned(df: DataFrame, col: String): Unit = {
- val scan = df.queryExecution.optimizedPlan.collectFirst {
- case s: DataSourceV2ScanRelation => s
- }.get
- assert(scan.schema.names.sameElements(Seq(col)))
- }
-
test("SPARK-48172: Test CONTAINS") {
val df1 = spark.sql(
s"""
@@ -841,39 +792,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}
- private def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
- val sorts = df.queryExecution.optimizedPlan.collect {
- case s: Sort => s
- }
-
- if (pushed) {
- assert(sorts.isEmpty)
- } else {
- assert(sorts.nonEmpty)
- }
- }
-
- private def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
- val offsets = df.queryExecution.optimizedPlan.collect {
- case o: Offset => o
- }
-
- if (pushed) {
- assert(offsets.isEmpty)
- } else {
- assert(offsets.nonEmpty)
- }
- }
-
- private def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = {
- df.queryExecution.optimizedPlan.collect {
- case relation: DataSourceV2ScanRelation => relation.scan match {
- case v1: V1ScanWrapper =>
- assert(v1.pushedDownOperators.offset == offset)
- }
- }
- }
-
gridTest("simple scan")(partitioningEnabledTestCase) { partitioningEnabled =>
val (tableOptions, partitionInfo) = getTableOptions("employee", partitioningEnabled)
val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
@@ -1028,27 +946,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}
- private def checkAggregateRemoved(df: DataFrame): Unit = {
- val aggregates = df.queryExecution.optimizedPlan.collect {
- case agg: Aggregate => agg
- }
- assert(aggregates.isEmpty)
- }
-
- private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = {
- df.queryExecution.optimizedPlan.collect {
- case DataSourceV2ScanRelation(_, scan, _, _, _) =>
- assert(scan.isInstanceOf[V1ScanWrapper])
- val wrapper = scan.asInstanceOf[V1ScanWrapper]
- assert(wrapper.pushedDownOperators.aggregation.isDefined)
- val aggregationExpressions =
- wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
- assert(aggregationExpressions.length == 1)
- assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc])
- assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName)
- }
- }
-
protected def caseConvert(tableName: String): String = tableName
Seq(true, false).foreach { isDistinct =>
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala
new file mode 100644
index 0000000000000..ecc0c5489bceb
--- /dev/null
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc.v2.join
+
+import java.sql.Connection
+import java.util.Locale
+
+import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect}
+import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase
+import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.tags.DockerTest
+
+/**
+ * The following are the steps to test this:
+ *
+ * 1. Choose to use a prebuilt image or build Oracle database in a container
+ * - The documentation on how to build Oracle RDBMS in a container is at
+ * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md
+ * - Official Oracle container images can be found at https://container-registry.oracle.com
+ *   - Trusted and streamlined Oracle Database Free images can be found on Docker Hub at
+ * https://hub.docker.com/r/gvenzl/oracle-free
+ * see also https://github.com/gvenzl/oci-oracle-free
+ * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing
+ * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-free:latest
+ * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1
+ * 4. Start docker: sudo service docker start
+ * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME
+ * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests
+ *    "testOnly org.apache.spark.sql.jdbc.v2.join.OracleJoinPushdownIntegrationSuite"
+ *
+ * A sequence of commands to build the Oracle Database Free container image:
+ * $ git clone https://github.com/oracle/docker-images.git
+ * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles
+ * $ ./buildContainerImage.sh -v 23.4.0 -f
+ * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:23.4.0-free
+ *
+ * This procedure has been validated with Oracle Database Free version 23.4.0,
+ * and with Oracle Express Edition versions 18.4.0 and 21.4.0.
+ */
+@DockerTest
+class OracleJoinPushdownIntegrationSuite
+ extends DockerJDBCIntegrationSuite
+ with JDBCV2JoinPushdownIntegrationSuiteBase {
+ override val namespace: String = "SYSTEM"
+
+ override val db = new OracleDatabaseOnDocker
+
+ override val url = db.getJdbcUrl(dockerIp, externalPort)
+
+ override val jdbcDialect: JdbcDialect = OracleDialect()
+
+ override val integerType = DataTypes.createDecimalType(10, 0)
+
+ override def caseConvert(identifier: String): String = identifier.toUpperCase(Locale.ROOT)
+
+ override def schemaPreparation(): Unit = {}
+
+  // Overridden from DockerJDBCIntegrationSuite; delegates to the base trait's no-arg
+  // dataPreparation(), which creates and populates the join test tables.
+ override def dataPreparation(connection: Connection): Unit = {
+ super.dataPreparation()
+ }
+}
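
Other dialects can plug into the same base suite by overriding the handful of members shown above. The sketch below illustrates the minimal wiring for a hypothetical database; MyDatabaseOnDocker and MyDialect stand in for a dialect's existing DatabaseOnDocker helper and JdbcDialect implementation and are not real Spark classes:

    import java.sql.Connection

    import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect}
    import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase
    import org.apache.spark.tags.DockerTest

    @DockerTest
    class MyJoinPushdownIntegrationSuite
      extends DockerJDBCIntegrationSuite
      with JDBCV2JoinPushdownIntegrationSuiteBase {
      override val db = new MyDatabaseOnDocker            // hypothetical DatabaseOnDocker helper
      override val url = db.getJdbcUrl(dockerIp, externalPort)
      override val jdbcDialect: JdbcDialect = MyDialect() // hypothetical JdbcDialect

      // Reuse the base trait's schema and data setup for the Docker preparation hook.
      override def dataPreparation(connection: Connection): Unit = super.dataPreparation()
    }
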
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
index 97ec093a0e297..d0f53fa20e1aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
@@ -208,7 +208,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {
// If join has been pushed down, reuse join query as a subquery. Otherwise, fallback to
// what is provided in options.
- private def tableOrQuery = joinQuery.getOrElse(options.tableOrQuery)
+ protected final def tableOrQuery: String = joinQuery.getOrElse(options.tableOrQuery)
/**
   * Build the final SQL query following the dialect's SQL syntax.
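
With tableOrQuery now protected, dialect-specific query builders can compose their FROM clause from the pushed-down join subquery instead of always reading options.tableOrQuery. A minimal sketch of such a builder, assuming the protected clause fields (columnList, whereClause, and friends) that JdbcSQLQueryBuilder already exposes; MyDialectSQLQueryBuilder itself is illustrative:

    import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcSQLQueryBuilder}

    class MyDialectSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
      extends JdbcSQLQueryBuilder(dialect, options) {

      override def build(): String = {
        // tableOrQuery is the pushed-down join subquery when a join was pushed,
        // and options.tableOrQuery otherwise.
        s"SELECT $columnList FROM $tableOrQuery" +
          s" $tableSampleClause $whereClause $groupByClause $orderByClause"
      }
    }

The Oracle builder in the next hunk makes exactly this substitution.
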
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index a9f6a727a7241..0c9c84f3f3e75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -233,7 +233,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
extends JdbcSQLQueryBuilder(dialect, options) {
override def build(): String = {
- val selectStmt = s"SELECT $hintClause$columnList FROM ${options.tableOrQuery}" +
+ val selectStmt = s"SELECT $hintClause$columnList FROM $tableOrQuery" +
s" $tableSampleClause $whereClause $groupByClause $orderByClause"
val finalSelectStmt = if (limit > 0) {
if (offset > 0) {
@@ -268,6 +268,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
override def supportsHint: Boolean = true
+ override def supportsJoin: Boolean = true
+
override def classifyException(
e: Throwable,
condition: String,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala
new file mode 100644
index 0000000000000..2816eb79f1f82
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
+ protected val supportsSamplePushdown: Boolean = true
+
+ protected val supportsFilterPushdown: Boolean = true
+
+ protected val supportsLimitPushdown: Boolean = true
+
+ protected val supportsAggregatePushdown: Boolean = true
+
+ protected val supportsSortPushdown: Boolean = true
+
+ protected val supportsOffsetPushdown: Boolean = true
+
+ protected val supportsColumnPruning: Boolean = true
+
+ protected val supportsJoinPushdown: Boolean = true
+
+
+ protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = {
+ if (supportsSamplePushdown) {
+ val sample = df.queryExecution.optimizedPlan.collect {
+ case s: Sample => s
+ }
+ if (pushed) {
+ assert(sample.isEmpty)
+ } else {
+ assert(sample.nonEmpty)
+ }
+ }
+ }
+
+ protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
+ if (supportsFilterPushdown) {
+ val filter = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ if (pushed) {
+ assert(filter.isEmpty)
+ } else {
+ assert(filter.nonEmpty)
+ }
+ }
+ }
+
+ protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
+ if (supportsLimitPushdown) {
+ val limit = df.queryExecution.optimizedPlan.collect {
+ case l: LocalLimit => l
+ case g: GlobalLimit => g
+ }
+ if (pushed) {
+ assert(limit.isEmpty)
+ } else {
+ assert(limit.nonEmpty)
+ }
+ }
+ }
+
+ protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = {
+ if (supportsLimitPushdown) {
+ df.queryExecution.optimizedPlan.collect {
+ case relation: DataSourceV2ScanRelation => relation.scan match {
+ case v1: V1ScanWrapper =>
+ assert(v1.pushedDownOperators.limit == limit)
+ }
+ }
+ }
+ }
+
+ protected def checkColumnPruned(df: DataFrame, col: String): Unit = {
+ if (supportsColumnPruning) {
+ val scan = df.queryExecution.optimizedPlan.collectFirst {
+ case s: DataSourceV2ScanRelation => s
+ }.get
+ assert(scan.schema.names.sameElements(Seq(col)))
+ }
+ }
+
+ protected def checkAggregateRemoved(df: DataFrame): Unit = {
+ if (supportsAggregatePushdown) {
+ val aggregates = df.queryExecution.optimizedPlan.collect {
+ case agg: Aggregate => agg
+ }
+ assert(aggregates.isEmpty)
+ }
+ }
+
+ protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = {
+ if (supportsAggregatePushdown) {
+ df.queryExecution.optimizedPlan.collect {
+ case DataSourceV2ScanRelation(_, scan, _, _, _) =>
+ assert(scan.isInstanceOf[V1ScanWrapper])
+ val wrapper = scan.asInstanceOf[V1ScanWrapper]
+ assert(wrapper.pushedDownOperators.aggregation.isDefined)
+ val aggregationExpressions =
+ wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+ assert(aggregationExpressions.exists { expr =>
+ expr.isInstanceOf[GeneralAggregateFunc] &&
+ expr.asInstanceOf[GeneralAggregateFunc].name() == funcName
+ })
+ }
+ }
+ }
+
+ protected def checkSortRemoved(
+ df: DataFrame,
+ pushed: Boolean = true): Unit = {
+ if (supportsSortPushdown) {
+ val sorts = df.queryExecution.optimizedPlan.collect {
+ case s: Sort => s
+ }
+
+ if (pushed) {
+ assert(sorts.isEmpty)
+ } else {
+ assert(sorts.nonEmpty)
+ }
+ }
+ }
+
+ protected def checkOffsetRemoved(
+ df: DataFrame,
+ pushed: Boolean = true): Unit = {
+ if (supportsOffsetPushdown) {
+ val offsets = df.queryExecution.optimizedPlan.collect {
+ case o: Offset => o
+ }
+
+ if (pushed) {
+ assert(offsets.isEmpty)
+ } else {
+ assert(offsets.nonEmpty)
+ }
+ }
+ }
+
+ protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = {
+ if (supportsOffsetPushdown) {
+ df.queryExecution.optimizedPlan.collect {
+ case relation: DataSourceV2ScanRelation => relation.scan match {
+ case v1: V1ScanWrapper =>
+ assert(v1.pushedDownOperators.offset == offset)
+ }
+ }
+ }
+ }
+
+ protected def checkJoinNotPushed(df: DataFrame): Unit = {
+ if (supportsJoinPushdown) {
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+ assert(joinNodes.nonEmpty, "Join should not be pushed down")
+ }
+ }
+
+ protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = {
+ if (supportsJoinPushdown) {
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+ assert(joinNodes.isEmpty, "Join should be pushed down")
+ if (expectedTables.nonEmpty) {
+ checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
+ }
+ }
+ }
+
+ protected def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = {
+ withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") {
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ checkKeywordsExistsInExplain(df, expectedPlanFragment: _*)
+ }
+ }
+ }
+
+ /**
+   * Check if the output schema of dataframe {@code df} is the same as {@code schema}. There is
+   * one limitation: if an expected field name is empty, the assertion on names is skipped.
+   *
+   * For example, it is not really possible to use {@code checkColumnPruned} for join pushdown,
+   * because in case of duplicate names, columns get random UUID suffixes. For this reason, the
+   * best we can do is check that the sizes match and that the fields other than names match.
+ */
+ protected def checkPrunedColumnsDataTypeAndNullability(
+ df: DataFrame,
+ schema: StructType): Unit = {
+ if (supportsColumnPruning) {
+ df.queryExecution.optimizedPlan.collect {
+ case relation: DataSourceV2ScanRelation => relation.scan match {
+ case v1: V1ScanWrapper =>
+ val dfSchema = v1.readSchema()
+
+ assert(dfSchema.length == schema.length)
+ dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) =>
+ if (f2.name.nonEmpty) {
+ assert(f1.name == f2.name)
+ }
+ assert(f1.dataType == f2.dataType)
+ assert(f1.nullable == f2.nullable)
+ }
+ }
+ }
+ }
+ }
+}
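
Suites mix this trait in alongside their existing test base; a dialect that cannot push a particular operator can flip the corresponding supports* flag to turn those checks into no-ops. A minimal usage sketch, with a hypothetical suite name, catalog, and tables:

    import org.apache.spark.sql.QueryTest
    import org.apache.spark.sql.connector.DataSourcePushdownTestUtils
    import org.apache.spark.sql.test.SharedSparkSession

    class MyPushdownSuite extends QueryTest with SharedSparkSession
      with DataSourcePushdownTestUtils {

      // The dialect under test cannot push OFFSET, so the offset checks become no-ops.
      override protected val supportsOffsetPushdown: Boolean = false

      test("filter and join are pushed down") {
        val df = sql(
          "SELECT * FROM my_catalog.ns.t1 a JOIN my_catalog.ns.t2 b ON a.id = b.id WHERE a.id > 1")
        checkFilterPushed(df)
        checkJoinPushed(df, "my_catalog.ns.t1", "my_catalog.ns.t2")
      }
    }
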
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala
deleted file mode 100644
index b77e905fea5d0..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala
+++ /dev/null
@@ -1,413 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc
-
-import java.sql.{Connection, DriverManager}
-import java.util.Properties
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, Join, Sort}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
-import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.Utils
-
-class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
- val tempDir = Utils.createTempDir()
- val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
-
- override def sparkConf: SparkConf = super.sparkConf
- .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
- .set("spark.sql.catalog.h2.url", url)
- .set("spark.sql.catalog.h2.driver", "org.h2.Driver")
- .set("spark.sql.catalog.h2.pushDownAggregate", "true")
- .set("spark.sql.catalog.h2.pushDownLimit", "true")
- .set("spark.sql.catalog.h2.pushDownOffset", "true")
- .set("spark.sql.catalog.h2.pushDownJoin", "true")
-
- private def withConnection[T](f: Connection => T): T = {
- val conn = DriverManager.getConnection(url, new Properties())
- try {
- f(conn)
- } finally {
- conn.close()
- }
- }
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- Utils.classForName("org.h2.Driver")
- withConnection { conn =>
- conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate()
- conn.prepareStatement(
- "CREATE TABLE \"test\".\"people\" (name TEXT(32) NOT NULL, id INTEGER NOT NULL)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
- conn.prepareStatement(
- "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
- " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate()
- conn.prepareStatement(
- "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate()
- conn.prepareStatement(
- "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate()
- conn.prepareStatement(
- "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate()
- conn.prepareStatement(
- "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate()
- conn.prepareStatement(
- "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate()
- conn.prepareStatement(
- "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate()
-
- // scalastyle:off
- conn.prepareStatement(
- "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate()
- // scalastyle:on
- conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate()
- conn.prepareStatement(
- """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate()
- conn.prepareStatement(
- """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate()
- }
- }
-
- override def afterAll(): Unit = {
- Utils.deleteRecursively(tempDir)
- super.afterAll()
- }
-
- private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = {
- withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") {
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- checkKeywordsExistsInExplain(df, expectedPlanFragment: _*)
- }
- }
- }
-
- // Conditionless joins are not supported in join pushdown
- test("Test that 2-way join without condition should not have join pushed down") {
- val sqlQuery = "SELECT * FROM h2.test.employee a, h2.test.employee b"
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.nonEmpty)
- checkAnswer(df, rows)
- }
- }
-
- // Conditionless joins are not supported in join pushdown
- test("Test that multi-way join without condition should not have join pushed down") {
- val sqlQuery = """
- |SELECT * FROM
- |h2.test.employee a,
- |h2.test.employee b,
- |h2.test.employee c
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
-
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.nonEmpty)
- checkAnswer(df, rows)
- }
- }
-
- test("Test self join with condition") {
- val sqlQuery = "SELECT * FROM h2.test.employee a JOIN h2.test.employee b ON a.dept = b.dept + 1"
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
-
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test multi-way self join with conditions") {
- val sqlQuery = """
- |SELECT * FROM
- |h2.test.employee a
- |JOIN h2.test.employee b ON b.dept = a.dept + 1
- |JOIN h2.test.employee c ON c.dept = b.dept - 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- assert(!rows.isEmpty)
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test self join with column pruning") {
- val sqlQuery = """
- |SELECT a.dept + 2, b.dept, b.salary FROM
- |h2.test.employee a JOIN h2.test.employee b
- |ON a.dept = b.dept + 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
-
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test 2-way join with column pruning - different tables") {
- val sqlQuery = """
- |SELECT * FROM
- |h2.test.employee a JOIN h2.test.people b
- |ON a.dept = b.id
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
-
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.people]")
- checkPushedInfo(df,
- "PushedFilters: [DEPT IS NOT NULL, ID IS NOT NULL, DEPT = ID]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test multi-way self join with column pruning") {
- val sqlQuery = """
- |SELECT a.dept, b.*, c.dept, c.salary + a.salary
- |FROM h2.test.employee a
- |JOIN h2.test.employee b ON b.dept = a.dept + 1
- |JOIN h2.test.employee c ON c.dept = b.dept - 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
-
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test aliases not supported in join pushdown") {
- val sqlQuery = """
- |SELECT a.dept, bc.*
- |FROM h2.test.employee a
- |JOIN (
- | SELECT b.*, c.dept AS c_dept, c.salary AS c_salary
- | FROM h2.test.employee b
- | JOIN h2.test.employee c ON c.dept = b.dept - 1
- |) bc ON bc.dept = a.dept + 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.nonEmpty)
- checkAnswer(df, rows)
- }
- }
-
- test("Test join with dataframe with duplicated columns") {
- val df1 = sql("SELECT dept FROM h2.test.employee")
- val df2 = sql("SELECT dept, dept FROM h2.test.employee")
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- df1.join(df2, "dept").collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val joinDf = df1.join(df2, "dept")
- val joinNodes = joinDf.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- assert(joinNodes.isEmpty)
- checkPushedInfo(joinDf, "PushedJoins: [h2.test.employee, h2.test.employee]")
- checkAnswer(joinDf, rows)
- }
- }
-
- test("Test aggregate on top of 2-way self join") {
- val sqlQuery = """
- |SELECT min(a.dept + b.dept), min(a.dept)
- |FROM h2.test.employee a
- |JOIN h2.test.employee b ON a.dept = b.dept + 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- val aggNodes = df.queryExecution.optimizedPlan.collect {
- case a: Aggregate => a
- }
-
- assert(joinNodes.isEmpty)
- assert(aggNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test aggregate on top of multi-way self join") {
- val sqlQuery = """
- |SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2)
- |FROM h2.test.employee a
- |JOIN h2.test.employee b ON b.dept = a.dept + 1
- |JOIN h2.test.employee c ON c.dept = b.dept - 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- val aggNodes = df.queryExecution.optimizedPlan.collect {
- case a: Aggregate => a
- }
-
- assert(joinNodes.isEmpty)
- assert(aggNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-
- test("Test sort limit on top of join is pushed down") {
- val sqlQuery = """
- |SELECT min(a.dept + b.dept), a.dept, b.dept
- |FROM h2.test.employee a
- |JOIN h2.test.employee b ON b.dept = a.dept + 1
- |GROUP BY a.dept, b.dept
- |ORDER BY a.dept
- |LIMIT 1
- |""".stripMargin
-
- val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
- sql(sqlQuery).collect().toSeq
- }
-
- withSQLConf(
- SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
- val df = sql(sqlQuery)
- val joinNodes = df.queryExecution.optimizedPlan.collect {
- case j: Join => j
- }
-
- val sortNodes = df.queryExecution.optimizedPlan.collect {
- case s: Sort => s
- }
-
- val limitNodes = df.queryExecution.optimizedPlan.collect {
- case l: GlobalLimit => l
- }
-
- assert(joinNodes.isEmpty)
- assert(sortNodes.isEmpty)
- assert(limitNodes.isEmpty)
- checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
- checkAnswer(df, rows)
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala
new file mode 100644
index 0000000000000..244ae40c48a9d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/v2/JDBCV2JoinPushdownIntegrationSuiteBase.scala
@@ -0,0 +1,589 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc.v2
+
+import java.sql.{Connection, DriverManager}
+import java.util.Properties
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.connector.DataSourcePushdownTestUtils
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.jdbc.JdbcDialect
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
+
+trait JDBCV2JoinPushdownIntegrationSuiteBase
+ extends QueryTest
+ with SharedSparkSession
+ with DataSourcePushdownTestUtils {
+ val catalogName: String = "join_pushdown_catalog"
+ val namespace: String = "join_schema"
+ val url: String
+
+ val joinTableName1: String = "join_table_1"
+ val joinTableName2: String = "join_table_2"
+
+ val jdbcDialect: JdbcDialect
+
+ override def sparkConf: SparkConf = super.sparkConf
+ .set(s"spark.sql.catalog.$catalogName", classOf[JDBCTableCatalog].getName)
+ .set(SQLConf.ANSI_ENABLED.key, "true")
+ .set(s"spark.sql.catalog.$catalogName.url", url)
+ .set(s"spark.sql.catalog.$catalogName.pushDownJoin", "true")
+ .set(s"spark.sql.catalog.$catalogName.pushDownAggregate", "true")
+ .set(s"spark.sql.catalog.$catalogName.pushDownLimit", "true")
+ .set(s"spark.sql.catalog.$catalogName.pushDownOffset", "true")
+ .set(s"spark.sql.catalog.$catalogName.caseSensitive", "false")
+
+ private def catalogAndNamespace = s"$catalogName.${caseConvert(namespace)}"
+ private def casedJoinTableName1 = caseConvert(joinTableName1)
+ private def casedJoinTableName2 = caseConvert(joinTableName2)
+
+ def qualifyTableName(tableName: String): String = {
+ val fullyQualifiedCasedNamespace = jdbcDialect.quoteIdentifier(caseConvert(namespace))
+ val fullyQualifiedCasedTableName = jdbcDialect.quoteIdentifier(caseConvert(tableName))
+ s"$fullyQualifiedCasedNamespace.$fullyQualifiedCasedTableName"
+ }
+
+ def quoteSchemaName(schemaName: String): String =
+    jdbcDialect.quoteIdentifier(caseConvert(schemaName))
+
+ private lazy val fullyQualifiedTableName1: String = qualifyTableName(joinTableName1)
+
+ private lazy val fullyQualifiedTableName2: String = qualifyTableName(joinTableName2)
+
+ protected def getJDBCTypeString(dt: DataType): String = {
+ JdbcUtils.getJdbcType(dt, jdbcDialect).databaseTypeDefinition.toUpperCase()
+ }
+
+ protected def caseConvert(identifier: String): String = identifier
+
+ protected def withConnection[T](f: Connection => T): T = {
+ val conn = DriverManager.getConnection(url, new Properties())
+ try {
+ f(conn)
+ } finally {
+ conn.close()
+ }
+ }
+
+ protected val integerType = DataTypes.IntegerType
+
+ protected val stringType = DataTypes.StringType
+
+ protected val decimalType = DataTypes.createDecimalType(10, 2)
+
+ /**
+ * This method should cover the following:
+ *