
Commit 1f1bacc

davidm-db and cloud-fan committed
[SPARK-53143][SQL] Fix self join in DataFrame API - Join is not the only expected output from analyzer
### What changes were proposed in this pull request?

`Dataset.resolveSelfJoinCondition` expects the analyzer output to always be of `Join` type, which is not correct: there are edge cases where the analyzer produces a plan with `Project` as the top node. This PR fixes failures in such cases. Check [SPARK-53143](https://issues.apache.org/jira/browse/SPARK-53143) for more details.

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually in the shell, plus a unit test covering the problematic case.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51873 from davidm-db/spark-53143.

Lead-authored-by: David Milicevic <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0c5797a · commit 1f1bacc

File tree: 2 files changed, +73 −25 lines
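Some orientation before the diff: the plan shape the old cast missed is easy to observe on its own, because the analyzer's `ResolveNaturalAndUsingJoin` rule rewrites a `USING` join into a `Project` over a `Join`, so the analyzed plan's top node is not a `Join`. A minimal spark-shell sketch (an illustration written for this page, not code from the patch; assumes `spark` and `spark.implicits._` are in scope):

```scala
val a = Seq(1, 2).toDF("id")
val b = Seq(2, 3).toDF("id")

// A "using" join: the analyzer rewrites it to Project(output, Join(...)),
// so the top node of the analyzed plan is a Project, not a Join.
val joined = a.join(b, usingColumns = Seq("id"))
println(joined.queryExecution.analyzed.getClass.getSimpleName)  // Project
```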

sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala

Lines changed: 47 additions & 25 deletions
```diff
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.text.StringEscapeUtils
 
-import org.apache.spark.{sql, TaskContext}
+import org.apache.spark.{sql, SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
```
```diff
@@ -649,7 +649,7 @@ class Dataset[T] private[sql](
   private def resolveSelfJoinCondition(
       right: Dataset[_],
       joinExprs: Option[Column],
-      joinType: String): Join = {
+      joinType: String): LogicalPlan = {
     // Note that in this function, we introduce a hack in the case of self-join to automatically
     // resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
     // Consider this case: df.join(df, df("key") === df("key"))
```
```diff
@@ -660,28 +660,40 @@
 
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
-    val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan,
-        JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
-      .queryExecution.analyzed.asInstanceOf[Join]
+    val planToAnalyze = Join(
+      logicalPlan, right.logicalPlan, JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE)
+    val analyzedJoinPlan = withPlan(planToAnalyze).queryExecution.analyzed
 
     // If auto self join alias is disabled, return the plan.
     if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
-      return plan
+      return analyzedJoinPlan
     }
 
     // If left/right have no output set intersection, return the plan.
     val lanalyzed = this.queryExecution.analyzed
     val ranalyzed = right.queryExecution.analyzed
     if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
-      return plan
+      return analyzedJoinPlan
     }
 
     // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
-
-    JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
+    analyzedJoinPlan match {
+      case project @ Project(_, join: Join) =>
+        // SPARK-53143: Handle the edge case where the `AddMetadataColumns` analyzer rule adds a
+        // `Project` node on top of the `Join` node.
+        // See "SPARK-53143: self join edge-case when Join is not returned by the analyzer" in
+        // `DataFrameSelfJoinSuite` for more details.
+        val newProject = project.copy(child = JoinWith.resolveSelfJoinCondition(
+          sparkSession.sessionState.analyzer.resolver, join))
+        newProject.copyTagsFrom(project)
+        newProject
+      case join: Join =>
+        JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, join)
+      case _ => throw SparkException.internalError(
+        s"Unexpected plan type: ${analyzedJoinPlan.getClass.getName} for self join resolution.")
+    }
   }
 
   /** @inheritdoc */
```
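One detail worth noting in the `Project` branch above: `project.copy(...)` produces a fresh node, and Catalyst tree tags (e.g. the hidden-output attributes the analyzer records on a using join's `Project`) do not travel through a case-class `copy`, hence the explicit `copyTagsFrom(project)`. A small self-contained sketch of that behavior (a demo against Catalyst internals written for this page, not code from the patch; requires spark-catalyst on the classpath):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object CopyTagsDemo extends App {
  val tag = TreeNodeTag[String]("demo")

  val project = Project(Seq(Literal(1).as("a")), LocalRelation())
  project.setTagValue(tag, "metadata the analyzer attached")

  val copied = project.copy()        // tags live outside the constructor args...
  println(copied.getTagValue(tag))   // ...so this prints None

  copied.copyTagsFrom(project)       // what the patch does explicitly
  println(copied.getTagValue(tag))   // Some(metadata the analyzer attached)
}
```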
```diff
@@ -781,28 +793,38 @@
       tolerance: Column,
       allowExactMatches: Boolean,
       direction: String): DataFrame = {
-    val joined = resolveSelfJoinCondition(other, Option(joinExprs), joinType)
-    val leftAsOfExpr = leftAsOf.expr.transformUp {
-      case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
-        val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
-        joined.left.output(index)
-    }
-    val rightAsOfExpr = rightAsOf.expr.transformUp {
-      case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
-        val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
-        joined.right.output(index)
-    }
-    withPlan {
+
+    def createAsOfJoinPlan(joinPlan: Join): AsOfJoin = {
+      val leftAsOfExpr = leftAsOf.expr.transformUp {
+        case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
+          val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
+          joinPlan.left.output(index)
+      }
+      val rightAsOfExpr = rightAsOf.expr.transformUp {
+        case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
+          val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
+          joinPlan.right.output(index)
+      }
       AsOfJoin(
-        joined.left, joined.right,
+        joinPlan.left, joinPlan.right,
         leftAsOfExpr, rightAsOfExpr,
-        joined.condition,
-        joined.joinType,
+        joinPlan.condition,
+        joinPlan.joinType,
         Option(tolerance).map(_.expr),
         allowExactMatches,
         AsOfJoinDirection(direction)
       )
     }
+
+    resolveSelfJoinCondition(other, Option(joinExprs), joinType) match {
+      case project @ Project(_, join: Join) =>
+        val newProjectPlan = project.copy(child = createAsOfJoinPlan(join))
+        newProjectPlan.copyTagsFrom(project)
+        withPlan { newProjectPlan }
+      case join: Join => withPlan { createAsOfJoinPlan(join) }
+      case plan => throw SparkException.internalError(
+        s"Unexpected plan type: ${plan.getClass.getName} returned from self join resolution.")
+    }
   }
 
   /** @inheritdoc */
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala

Lines changed: 26 additions & 0 deletions
```diff
@@ -527,4 +527,30 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-53143: self join edge-case when Join is not returned by the analyzer") {
+    withTable("table_1", "table_2") {
+      // Edge case with multiple joins. Example: two joins, where the latter one is a self join.
+      // The first one is a "using" join - in this case, the analyzer's
+      // `ResolveNaturalAndUsingJoin` rule adds a `Project` as the top node.
+      // The second join is a self join with an explicit join condition (i.e. `joinExprs`) -
+      // if the join condition uses columns that are not part of the project list (of the first
+      // join), the `AddMetadataColumns` rule kicks in to add metadata for those columns. As a
+      // consequence, a `Project` is added on top of the joined plan to return the
+      // original/expected list of projected columns.
+      // While a similar shape (i.e. a `Project` node on top) can occur in multiple other cases,
+      // from the `Dataset` perspective the issue is specific to self joins, since
+      // `resolveSelfJoinCondition` assumed the analyzed plan would always be of `Join` type.
+      sql("CREATE TABLE IF NOT EXISTS table_1 (id INT);")
+      sql("INSERT INTO table_1 VALUES (1), (2);")
+      sql("CREATE TABLE IF NOT EXISTS table_2 (id INT, col_1 STRING);")
+      sql("INSERT INTO table_2 VALUES (1, 'str'), (2, 'test');")
+      val df = spark.table("table_2").where("col_1 = 'test'").select("id")
+      assert(
+        spark.table("table_1").alias("t")
+          .join(df.alias("df1"), usingColumns = Seq("id"))
+          .join(df.alias("df2"), joinExprs = $"df1.id" === $"df2.id", joinType = "left")
+          .count() == 1)
+    }
+  }
 }
```
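To see the shape the fix handles end to end, the test's join chain can be replayed in a shell and its analyzed plan inspected. A sketch, assuming `table_1` and `table_2` exist with the same data the test inserts, and `spark.implicits._` is in scope:

```scala
val df = spark.table("table_2").where("col_1 = 'test'").select("id")

// Before this patch, constructing `joined` would already throw from
// resolveSelfJoinCondition: the analyzed self-join plan was a Project over
// a Join, so the old `asInstanceOf[Join]` cast failed. With the patch:
val joined = spark.table("table_1").alias("t")
  .join(df.alias("df1"), usingColumns = Seq("id"))
  .join(df.alias("df2"), joinExprs = $"df1.id" === $"df2.id", joinType = "left")

println(joined.queryExecution.analyzed.treeString)  // expect a Project over the left-outer Join
println(joined.count())  // 1, matching the test's assertion
```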
