@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
28
28
29
29
import org .apache .commons .text .StringEscapeUtils
30
30
31
- import org .apache .spark .{sql , TaskContext }
31
+ import org .apache .spark .{sql , SparkException , TaskContext }
32
32
import org .apache .spark .annotation .{DeveloperApi , Stable , Unstable }
33
33
import org .apache .spark .api .java .JavaRDD
34
34
import org .apache .spark .api .java .function ._
@@ -649,7 +649,7 @@ class Dataset[T] private[sql](
649
649
private def resolveSelfJoinCondition (
650
650
right : Dataset [_],
651
651
joinExprs : Option [Column ],
652
- joinType : String ): Join = {
652
+ joinType : String ): LogicalPlan = {
653
653
// Note that in this function, we introduce a hack in the case of self-join to automatically
654
654
// resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
655
655
// Consider this case: df.join(df, df("key") === df("key"))
@@ -660,28 +660,40 @@ class Dataset[T] private[sql](
660
660
661
661
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
662
662
// After the cloning, left and right side will have distinct expression ids.
663
- val plan = withPlan(
664
- Join (logicalPlan, right.logicalPlan,
665
- JoinType (joinType), joinExprs.map(_.expr), JoinHint .NONE ))
666
- .queryExecution.analyzed.asInstanceOf [Join ]
663
+ val planToAnalyze = Join (
664
+ logicalPlan, right.logicalPlan, JoinType (joinType), joinExprs.map(_.expr), JoinHint .NONE )
665
+ val analyzedJoinPlan = withPlan(planToAnalyze).queryExecution.analyzed
667
666
668
667
// If auto self join alias is disabled, return the plan.
669
668
if (! sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
670
- return plan
669
+ return analyzedJoinPlan
671
670
}
672
671
673
672
// If left/right have no output set intersection, return the plan.
674
673
val lanalyzed = this .queryExecution.analyzed
675
674
val ranalyzed = right.queryExecution.analyzed
676
675
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
677
- return plan
676
+ return analyzedJoinPlan
678
677
}
679
678
680
679
// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
681
680
// By the time we get here, since we have already run analysis, all attributes should've been
682
681
// resolved and become AttributeReference.
683
-
684
- JoinWith .resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
682
+ analyzedJoinPlan match {
683
+ case project @ Project (_, join : Join ) =>
684
+ // SPARK-53143: Handling edge-cases when `AddMetadataColumns` analyzer rule adds `Project`
685
+ // node on top of `Join` node.
686
+ // Check "SPARK-53143: self join edge-case when Join is not returned by the analyzer" in
687
+ // `DataframeSelfJoinSuite` for more details.
688
+ val newProject = project.copy(child = JoinWith .resolveSelfJoinCondition(
689
+ sparkSession.sessionState.analyzer.resolver, join))
690
+ newProject.copyTagsFrom(project)
691
+ newProject
692
+ case join : Join =>
693
+ JoinWith .resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, join)
694
+ case _ => throw SparkException .internalError(
695
+ s " Unexpected plan type: ${analyzedJoinPlan.getClass.getName} for self join resolution. " )
696
+ }
685
697
}
686
698
687
699
/** @inheritdoc */
@@ -781,28 +793,38 @@ class Dataset[T] private[sql](
781
793
tolerance : Column ,
782
794
allowExactMatches : Boolean ,
783
795
direction : String ): DataFrame = {
784
- val joined = resolveSelfJoinCondition(other, Option (joinExprs), joinType)
785
- val leftAsOfExpr = leftAsOf.expr.transformUp {
786
- case a : AttributeReference if logicalPlan.outputSet.contains(a) =>
787
- val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
788
- joined.left. output(index )
789
- }
790
- val rightAsOfExpr = rightAsOf.expr.transformUp {
791
- case a : AttributeReference if other.logicalPlan.outputSet.contains(a) =>
792
- val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
793
- joined.right .output(index )
794
- }
795
- withPlan {
796
+
797
+ def createAsOfJoinPlan ( joinPlan : Join ) : AsOfJoin = {
798
+ val leftAsOfExpr = leftAsOf.expr.transformUp {
799
+ case a : AttributeReference if logicalPlan.outputSet.contains(a) =>
800
+ val index = logicalPlan. output.indexWhere(_.exprId == a.exprId )
801
+ joinPlan.left.output(index)
802
+ }
803
+ val rightAsOfExpr = rightAsOf.expr.transformUp {
804
+ case a : AttributeReference if other.logicalPlan.outputSet.contains(a) =>
805
+ val index = other.logicalPlan .output.indexWhere(_.exprId == a.exprId )
806
+ joinPlan.right.output(index)
807
+ }
796
808
AsOfJoin (
797
- joined .left, joined .right,
809
+ joinPlan .left, joinPlan .right,
798
810
leftAsOfExpr, rightAsOfExpr,
799
- joined .condition,
800
- joined .joinType,
811
+ joinPlan .condition,
812
+ joinPlan .joinType,
801
813
Option (tolerance).map(_.expr),
802
814
allowExactMatches,
803
815
AsOfJoinDirection (direction)
804
816
)
805
817
}
818
+
819
+ resolveSelfJoinCondition(other, Option (joinExprs), joinType) match {
820
+ case project @ Project (_, join : Join ) =>
821
+ val newProjectPlan = project.copy(child = createAsOfJoinPlan(join))
822
+ newProjectPlan.copyTagsFrom(project)
823
+ withPlan { newProjectPlan }
824
+ case join : Join => withPlan { createAsOfJoinPlan(join) }
825
+ case plan => throw SparkException .internalError(
826
+ s " Unexpected plan type: ${plan.getClass.getName} returned from self join resolution. " )
827
+ }
806
828
}
807
829
808
830
/** @inheritdoc */
0 commit comments