Skip to content

Commit 9c8b0eb

Browse files
committed
[SPARK-51911] Support lateralJoin in DataFrame
### What changes were proposed in this pull request? This PR aims to support `lateralJoin` API in `DataFrame`. ### Why are the changes needed? To provide a foundation of `lateralJoin` API although `column` is not supported yet. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #88 from dongjoon-hyun/SPARK-51911. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d069034 commit 9c8b0eb

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,82 @@ public actor DataFrame: Sendable {
609609
return DataFrame(spark: self.spark, plan: plan)
610610
}
611611

612+
/// Lateral join with another ``DataFrame``.
613+
///
614+
/// Behaves as an JOIN LATERAL.
615+
///
616+
/// - Parameters:
617+
/// - right: Right side of the join operation.
618+
/// - Returns: A ``DataFrame``.
619+
public func lateralJoin(_ right: DataFrame) async -> DataFrame {
620+
let rightPlan = await (right.getPlan() as! Plan).root
621+
let plan = SparkConnectClient.getLateralJoin(
622+
self.plan.root,
623+
rightPlan,
624+
JoinType.inner
625+
)
626+
return DataFrame(spark: self.spark, plan: plan)
627+
}
628+
629+
/// Lateral join with another ``DataFrame``.
630+
///
631+
/// Behaves as an JOIN LATERAL.
632+
///
633+
/// - Parameters:
634+
/// - right: Right side of the join operation.
635+
/// - joinType: One of `inner` (default), `cross`, `left`, `leftouter`, `left_outer`.
636+
/// - Returns: A ``DataFrame``.
637+
public func lateralJoin(_ right: DataFrame, joinType: String) async -> DataFrame {
638+
let rightPlan = await (right.getPlan() as! Plan).root
639+
let plan = SparkConnectClient.getLateralJoin(
640+
self.plan.root,
641+
rightPlan,
642+
joinType.toJoinType
643+
)
644+
return DataFrame(spark: self.spark, plan: plan)
645+
}
646+
647+
/// Lateral join with another ``DataFrame``.
648+
///
649+
/// Behaves as an JOIN LATERAL.
650+
///
651+
/// - Parameters:
652+
/// - right: Right side of the join operation.
653+
/// - joinExprs: A join expression string.
654+
/// - Returns: A ``DataFrame``.
655+
public func lateralJoin(_ right: DataFrame, joinExprs: String) async -> DataFrame {
656+
let rightPlan = await (right.getPlan() as! Plan).root
657+
let plan = SparkConnectClient.getLateralJoin(
658+
self.plan.root,
659+
rightPlan,
660+
JoinType.inner,
661+
joinCondition: joinExprs
662+
)
663+
return DataFrame(spark: self.spark, plan: plan)
664+
}
665+
666+
/// Lateral join with another ``DataFrame``.
667+
///
668+
/// Behaves as an JOIN LATERAL.
669+
///
670+
/// - Parameters:
671+
/// - right: Right side of the join operation.
672+
/// - joinType: One of `inner` (default), `cross`, `left`, `leftouter`, `left_outer`.
673+
/// - joinExprs: A join expression string.
674+
/// - Returns: A ``DataFrame``.
675+
public func lateralJoin(
676+
_ right: DataFrame, joinExprs: String, joinType: String = "inner"
677+
) async -> DataFrame {
678+
let rightPlan = await (right.getPlan() as! Plan).root
679+
let plan = SparkConnectClient.getLateralJoin(
680+
self.plan.root,
681+
rightPlan,
682+
joinType.toJoinType,
683+
joinCondition: joinExprs
684+
)
685+
return DataFrame(spark: self.spark, plan: plan)
686+
}
687+
612688
/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
613689
/// This is equivalent to `EXCEPT DISTINCT` in SQL.
614690
/// - Parameter other: A `DataFrame` to exclude.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,24 @@ public actor SparkConnectClient {
617617
return plan
618618
}
619619

620+
static func getLateralJoin(
621+
_ left: Relation, _ right: Relation, _ joinType: JoinType,
622+
joinCondition: String? = nil
623+
) -> Plan {
624+
var lateralJoin = LateralJoin()
625+
lateralJoin.left = left
626+
lateralJoin.right = right
627+
lateralJoin.joinType = joinType
628+
if let joinCondition {
629+
lateralJoin.joinCondition.expressionString = joinCondition.toExpressionString
630+
}
631+
var relation = Relation()
632+
relation.lateralJoin = lateralJoin
633+
var plan = Plan()
634+
plan.opType = .root(relation)
635+
return plan
636+
}
637+
620638
static func getSetOperation(
621639
_ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
622640
byName: Bool = false, allowMissingColumns: Bool = false

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ typealias GroupType = Spark_Connect_Aggregate.GroupType
3434
typealias Join = Spark_Connect_Join
3535
typealias JoinType = Spark_Connect_Join.JoinType
3636
typealias KeyValue = Spark_Connect_KeyValue
37+
typealias LateralJoin = Spark_Connect_LateralJoin
3738
typealias Limit = Spark_Connect_Limit
3839
typealias MapType = Spark_Connect_DataType.Map
3940
typealias NamedTable = Spark_Connect_Read.NamedTable

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,26 @@ struct DataFrameTests {
479479
await spark.stop()
480480
}
481481

482+
@Test
483+
func lateralJoin() async throws {
484+
let spark = try await SparkSession.builder.getOrCreate()
485+
let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
486+
let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
487+
let expectedCross = [
488+
Row("a", "1", "c", "2"),
489+
Row("a", "1", "d", "3"),
490+
Row("b", "2", "c", "2"),
491+
Row("b", "2", "d", "3"),
492+
]
493+
#expect(try await df1.lateralJoin(df2).collect() == expectedCross)
494+
#expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross)
495+
496+
let expected = [Row("b", "2", "c", "2")]
497+
#expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected)
498+
#expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
499+
await spark.stop()
500+
}
501+
482502
@Test
483503
func except() async throws {
484504
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)