
Commit cab692c

[SPARK-51917] Add DataFrameWriterV2 actor
### What changes were proposed in this pull request?

This PR aims to add `DataFrameWriterV2` actor.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No behavior change because this is an additional API.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #91 from dongjoon-hyun/SPARK-51917.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent ebc4dbf commit cab692c
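
For reference, a minimal usage sketch of the new API (the session setup, table name, and `orc` provider below are illustrative, not part of the commit):

let spark = try await SparkSession.builder.getOrCreate()
// Configure a v2 write for a (hypothetical) table and create it.
let writer = try await spark.range(10).writeTo("testcat.db.tbl").using("orc")
try await writer.create()
await spark.stop()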

File tree

4 files changed: +249 -0 lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 7 additions & 0 deletions
@@ -852,4 +852,11 @@ public actor DataFrame: Sendable {
     return DataFrameWriter(df: self)
   }
+
+  /// Create a write configuration builder for v2 sources.
+  /// - Parameter table: A table name, e.g., `catalog.db.table`.
+  /// - Returns: A ``DataFrameWriterV2`` instance.
+  public func writeTo(_ table: String) -> DataFrameWriterV2 {
+    return DataFrameWriterV2(table, self)
+  }
 }
Sources/SparkConnect/DataFrameWriterV2.swift

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

/// Interface used to write a ``DataFrame`` to external storage using the v2 API.
public actor DataFrameWriterV2: Sendable {

  let tableName: String

  let df: DataFrame

  var provider: String? = nil

  var extraOptions: CaseInsensitiveDictionary = CaseInsensitiveDictionary()

  var tableProperties: CaseInsensitiveDictionary = CaseInsensitiveDictionary()

  var partitioningColumns: [Spark_Connect_Expression] = []

  var clusteringColumns: [String]? = nil

  var overwriteCondition: String? = nil

  init(_ table: String, _ df: DataFrame) {
    self.tableName = table
    self.df = df
  }

  /// Specifies a provider for the underlying output data source. Spark's default catalog supports
  /// "orc", "json", etc.
  /// - Parameter provider: A data source provider name.
  /// - Returns: A ``DataFrameWriterV2``.
  public func using(_ provider: String) -> DataFrameWriterV2 {
    self.provider = provider
    return self
  }

  /// Adds an output option for the underlying data source.
  /// - Parameters:
  ///   - key: A key string.
  ///   - value: A value string.
  /// - Returns: A ``DataFrameWriterV2``.
  public func option(_ key: String, _ value: String) -> DataFrameWriterV2 {
    self.extraOptions[key] = value
    return self
  }

  /// Adds a table property.
  /// - Parameters:
  ///   - property: A property name.
  ///   - value: A property value.
  /// - Returns: A ``DataFrameWriterV2``.
  public func tableProperty(property: String, value: String) -> DataFrameWriterV2 {
    self.tableProperties[property] = value
    return self
  }

  /// Partitions the output table created by `create`, `createOrReplace`, or `replace` using the
  /// given columns or transforms.
  /// - Parameter columns: Columns to partition by.
  /// - Returns: A ``DataFrameWriterV2``.
  public func partitionBy(_ columns: String...) -> DataFrameWriterV2 {
    self.partitioningColumns = columns.map {
      var expr = Spark_Connect_Expression()
      expr.expressionString = $0.toExpressionString
      return expr
    }
    return self
  }

  /// Clusters the output by the given columns on the storage. Rows with matching values in the
  /// specified clustering columns will be consolidated within the same group.
  /// - Parameter columns: Columns to cluster by.
  /// - Returns: A ``DataFrameWriterV2``.
  public func clusterBy(_ columns: String...) -> DataFrameWriterV2 {
    self.clusteringColumns = columns
    return self
  }

  /// Creates a new table from the contents of the data frame.
  public func create() async throws {
    try await executeWriteOperation(.create)
  }

  /// Replaces an existing table with the contents of the data frame.
  public func replace() async throws {
    try await executeWriteOperation(.replace)
  }

  /// Creates a new table or replaces an existing table with the contents of the data frame.
  public func createOrReplace() async throws {
    try await executeWriteOperation(.createOrReplace)
  }

  /// Appends the contents of the data frame to the output table.
  public func append() async throws {
    try await executeWriteOperation(.append)
  }

  /// Overwrites rows matching the given filter condition with the contents of the ``DataFrame``
  /// in the output table.
  /// - Parameter condition: A filter condition.
  public func overwrite(condition: String) async throws {
    // Record the condition so `executeWriteOperation` can attach it to the request.
    self.overwriteCondition = condition
    try await executeWriteOperation(.overwrite)
  }

  /// Overwrites all partitions for which the ``DataFrame`` contains at least one row with the
  /// contents of the data frame in the output table.
  /// This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
  /// partitions dynamically depending on the contents of the ``DataFrame``.
  public func overwritePartitions() async throws {
    try await executeWriteOperation(.overwritePartitions)
  }

  private func executeWriteOperation(_ mode: WriteOperationV2.Mode) async throws {
    var write = WriteOperationV2()

    let plan = await self.df.getPlan() as! Plan
    write.input = plan.root
    write.tableName = self.tableName
    if let provider = self.provider {
      write.provider = provider
    }
    write.partitioningColumns = self.partitioningColumns
    if let clusteringColumns = self.clusteringColumns {
      write.clusteringColumns = clusteringColumns
    }
    if let condition = self.overwriteCondition {
      // Send the overwrite filter as an expression string.
      var expr = Spark_Connect_Expression()
      expr.expressionString = condition.toExpressionString
      write.overwriteCondition = expr
    }
    for option in self.extraOptions.toStringDictionary() {
      write.options[option.key] = option.value
    }
    for property in self.tableProperties.toStringDictionary() {
      write.tableProperties[property.key] = property.value
    }
    write.mode = mode

    var command = Spark_Connect_Command()
    command.writeOperationV2 = write
    _ = try await df.spark.client.execute(df.spark.sessionID, command)
  }
}
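
Taken together, the builder methods compose into a fluent configuration that is only executed by a terminal call such as `create()`. A sketch under assumptions (the table name, property, partition column, and filter are illustrative, and the success path presumes a catalog whose provider supports v2 creates and conditional overwrites):

let df = try await spark.range(100)
// Accumulate provider, table property, and partitioning before executing.
let writer = try await df.writeTo("testcat.db.events")
  .using("orc")
  .tableProperty(property: "owner", value: "etl")
  .partitionBy("id")
try await writer.create()
// Replace only the rows matching the condition; other rows are kept.
try await writer.overwrite(condition: "id < 50")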

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
@@ -60,4 +60,5 @@ typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
 typealias WithColumnsRenamed = Spark_Connect_WithColumnsRenamed
 typealias WriteOperation = Spark_Connect_WriteOperation
+typealias WriteOperationV2 = Spark_Connect_WriteOperationV2
 typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
Tests/SparkConnectTests/DataFrameWriterV2Tests.swift

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import SparkConnect
import Testing

/// A test suite for `DataFrameWriterV2`
struct DataFrameWriterV2Tests {

  @Test
  func create() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
    try await SQLHelper.withTable(spark, tableName)({
      let write = try await spark.range(2).writeTo(tableName).using("orc")
      try await write.create()
      #expect(try await spark.table(tableName).count() == 2)
      try await #require(throws: Error.self) {
        try await write.create()
      }
    })
    await spark.stop()
  }

  @Test
  func createOrReplace() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
    try await SQLHelper.withTable(spark, tableName)({
      let write = try await spark.range(2).writeTo(tableName).using("orc")
      try await write.create()
      #expect(try await spark.table(tableName).count() == 2)
      // TODO: Use Iceberg to verify success case after Iceberg supports Apache Spark 4
      try await #require(throws: Error.self) {
        try await write.createOrReplace()
      }
    })
    await spark.stop()
  }

  @Test
  func replace() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
    try await SQLHelper.withTable(spark, tableName)({
      let write = try await spark.range(2).writeTo(tableName).using("orc")
      try await write.create()
      #expect(try await spark.table(tableName).count() == 2)
      // TODO: Use Iceberg to verify success case after Iceberg supports Apache Spark 4
      try await #require(throws: Error.self) {
        try await write.replace()
      }
    })
    await spark.stop()
  }

  @Test
  func append() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
    try await SQLHelper.withTable(spark, tableName)({
      let write = try await spark.range(2).writeTo(tableName).using("orc")
      try await write.create()
      #expect(try await spark.table(tableName).count() == 2)
      // TODO: Use Iceberg to verify success case after Iceberg supports Apache Spark 4
      try await #require(throws: Error.self) {
        try await write.append()
      }
    })
    await spark.stop()
  }
}
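
Once a v2-capable catalog is available (e.g., Iceberg, per the TODO comments above), the success paths could be exercised roughly as follows. This is a hypothetical sketch: the `local.db` catalog prefix and the `iceberg` provider are assumptions, not part of this commit.

@Test
func replaceWithV2Catalog() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let tableName = "local.db.TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
  try await SQLHelper.withTable(spark, tableName)({
    let write = try await spark.range(2).writeTo(tableName).using("iceberg")
    try await write.create()
    // Against a v2 catalog, replace should succeed instead of throwing.
    try await write.replace()
    #expect(try await spark.table(tableName).count() == 2)
  })
  await spark.stop()
}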
