
Commit d4e34f5

zhengruifeng authored and HyukjinKwon committed
[SPARK-54446][SQL][ML][CONNECT] FPGrowth supports local filesystem with Arrow file format
### What changes were proposed in this pull request?
FPGrowth supports the local filesystem using the Arrow file format.

### Why are the changes needed?
To make FPGrowth work with the local filesystem.

### Does this PR introduce _any_ user-facing change?
Yes, FPGrowth will work when local saving mode is on.

### How was this patch tested?
Updated tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #53232 from zhengruifeng/local_fs_fpg_with_file.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 250ccca commit d4e34f5
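
For context, a minimal sketch of the save/load path this commit affects (assuming an active SparkSession named spark; the toy data, parameter values, and /tmp path are illustrative). Before this change, saving or loading an FPGrowthModel under local saving mode threw UnsupportedOperationException; afterwards the frequent itemsets are written and read as an Arrow IPC file:

import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel}

// Toy transactions.
val dataset = spark.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b")),
  (2, Array("a"))
)).toDF("id", "items")

val model = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.5)
  .setMinConfidence(0.6)
  .fit(dataset)

// With local saving mode off, this still goes through Parquet as before.
model.write.overwrite().save("/tmp/fpgrowth-model")
val restored = FPGrowthModel.load("/tmp/fpgrowth-model")
restored.freqItemsets.show()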

File tree

7 files changed: +255 -35 lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 2 additions & 12 deletions
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

     override protected def saveImpl(path: String): Unit = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support saving to local filesystem path."
-        )
-      }
       val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
         extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
-      instance.freqItemsets.write.parquet(dataPath)
+      ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
     }
   }

@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName

     override def load(path: String): FPGrowthModel = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support loading from local filesystem path."
-        )
-      }
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
         (metadata.metadata \ "numTrainingRecords").extract[Long]
       }
       val dataPath = new Path(path, "data").toString
-      val frequentItems = sparkSession.read.parquet(dataPath)
+      val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
       val itemSupport = if (numTrainingRecords == 0L) {
         Map.empty[Any, Double]
       } else {

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 30 additions & 1 deletion
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite
 import org.apache.spark.util.{Utils, VersionUtils}

 /**
@@ -1142,4 +1143,32 @@ private[spark] object ReadWriteUtils {
       spark.read.parquet(path).as[T].collect()
     }
   }
+
+  def saveDataFrame(path: String, df: DataFrame): Unit = {
+    if (localSavingModeState.get()) {
+      df match {
+        case d: org.apache.spark.sql.classic.DataFrame =>
+          val filePath = Paths.get(path)
+          Files.createDirectories(filePath.getParent)
+          ArrowFileReadWrite.save(d, filePath)
+        case o => throw new UnsupportedOperationException(
+          s"Unsupported dataframe type: ${o.getClass.getName}")
+      }
+    } else {
+      df.write.parquet(path)
+    }
+  }
+
+  def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
+    if (localSavingModeState.get()) {
+      spark match {
+        case s: org.apache.spark.sql.classic.SparkSession =>
+          ArrowFileReadWrite.load(s, Paths.get(path))
+        case o => throw new UnsupportedOperationException(
+          s"Unsupported session type: ${o.getClass.getName}")
+      }
+    } else {
+      spark.read.parquet(path)
+    }
+  }
 }
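
A minimal sketch of how these helpers behave, for orientation only: ReadWriteUtils is private[spark], and only localSavingModeState.get() appears in this diff, so the .set(...) calls below are an assumption (treating it as a ThreadLocal-style flag toggled by the Spark Connect code path):

import org.apache.spark.ml.util.ReadWriteUtils
import org.apache.spark.sql.{DataFrame, SparkSession}

// Round-trips a DataFrame through the new helpers; with the flag on, the
// payload is an Arrow IPC file on the local filesystem, otherwise Parquet.
def roundTrip(df: DataFrame, spark: SparkSession, path: String): DataFrame = {
  ReadWriteUtils.localSavingModeState.set(true)  // assumed setter; see note above
  try {
    ReadWriteUtils.saveDataFrame(path, df)
    ReadWriteUtils.loadDataFrame(path, spark)
  } finally {
    ReadWriteUtils.localSavingModeState.set(false)
  }
}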

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     }
     val fPGrowth = new FPGrowth()
     testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
-      FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala

Lines changed: 24 additions & 8 deletions
@@ -2340,14 +2340,13 @@
   }

   /** Convert to an RDD of serialized ArrowRecordBatches. */
-  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+  private def toArrowBatchRddImpl(
+      plan: SparkPlan,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): RDD[Array[Byte]] = {
     val schemaCaptured = this.schema
-    val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
-    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    val errorOnDuplicatedFieldNames =
-      sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
-    val largeVarTypes =
-      sparkSession.sessionState.conf.arrowUseLargeVarTypes
     plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toBatchIterator(
@@ -2361,7 +2360,24 @@
     }
   }

-  // This is only used in tests, for now.
+  private[sql] def toArrowBatchRdd(
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): RDD[Array[Byte]] = {
+    toArrowBatchRddImpl(queryExecution.executedPlan,
+      maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+  }
+
+  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+    toArrowBatchRddImpl(
+      plan,
+      sparkSession.sessionState.conf.arrowMaxRecordsPerBatch,
+      sparkSession.sessionState.conf.sessionLocalTimeZone,
+      sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy",
+      sparkSession.sessionState.conf.arrowUseLargeVarTypes)
+  }
+
   private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
     toArrowBatchRdd(queryExecution.executedPlan)
   }
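
As a quick illustration of the refactor (a sketch only: toArrowBatchRdd is private[sql], so this compiles only from within org.apache.spark.sql, and the values passed are illustrative rather than recommended defaults):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.classic.Dataset

// New overload: the caller pins the Arrow options explicitly instead of
// reading them from the session conf, which is what ArrowFileReadWrite.save
// relies on further down in this commit.
def toUtcArrowBatches(ds: Dataset[_]): RDD[Array[Byte]] =
  ds.toArrowBatchRdd(
    maxRecordsPerBatch = 10000,
    timeZoneId = "UTC",
    errorOnDuplicatedFieldNames = true,
    largeVarTypes = false)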

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 35 additions & 13 deletions
@@ -128,8 +128,7 @@ private[sql] object ArrowConverters extends Logging {
     }

     override def next(): Array[Byte] = {
-      val out = new ByteArrayOutputStream()
-      val writeChannel = new WriteChannel(Channels.newChannel(out))
+      var bytes: Array[Byte] = null

       Utils.tryWithSafeFinally {
         var rowCount = 0L
@@ -140,13 +139,13 @@
         }
         arrowWriter.finish()
         val batch = unloader.getRecordBatch()
-        MessageSerializer.serialize(writeChannel, batch)
+        bytes = serializeBatch(batch)
         batch.close()
       } {
         arrowWriter.reset()
       }

-      out.toByteArray
+      bytes
     }

     override def close(): Unit = {
@@ -548,32 +547,55 @@
       new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
   }

+  private[arrow] def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    val writeChannel = new WriteChannel(Channels.newChannel(out))
+    MessageSerializer.serialize(writeChannel, batch)
+    out.toByteArray
+  }
+
   /**
    * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
    */
   def toDataFrame(
       arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    toDataFrame(
+      arrowBatches,
+      DataType.fromJson(schemaString).asInstanceOf[StructType],
+      session,
+      session.sessionState.conf.sessionLocalTimeZone,
+      false,
+      session.sessionState.conf.arrowUseLargeVarTypes)
+  }
+
+  /**
+   * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
+   */
+  private[sql] def toDataFrame(
+      arrowBatches: Iterator[Array[Byte]],
+      schema: StructType,
+      session: SparkSession,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): DataFrame = {
     val attrs = toAttributes(schema)
     val batchesInDriver = arrowBatches.toArray
-    val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
     val shouldUseRDD = session.sessionState.conf
       .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum

     if (shouldUseRDD) {
       logDebug("Using RDD-based createDataFrame with Arrow optimization.")
-      val timezone = session.sessionState.conf.sessionLocalTimeZone
       val rdd = session.sparkContext
         .parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length)
         .mapPartitions { batchesInExecutors =>
           ArrowConverters.fromBatchIterator(
             batchesInExecutors,
             schema,
-            timezone,
-            errorOnDuplicatedFieldNames = false,
-            largeVarTypes = largeVarTypes,
+            timeZoneId,
+            errorOnDuplicatedFieldNames,
+            largeVarTypes,
             TaskContext.get())
         }
       session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -582,9 +604,9 @@
     val data = ArrowConverters.fromBatchIterator(
       batchesInDriver.iterator,
       schema,
-      session.sessionState.conf.sessionLocalTimeZone,
-      errorOnDuplicatedFieldNames = false,
-      largeVarTypes = largeVarTypes,
+      timeZoneId,
+      errorOnDuplicatedFieldNames,
+      largeVarTypes,
       TaskContext.get())

     // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import java.nio.channels.Channels
+import java.nio.file.{Files, Path}
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
+import org.apache.arrow.vector.types.pojo.Schema
+
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+import org.apache.spark.sql.util.ArrowUtils
+
+private[sql] class SparkArrowFileWriter(schema: Schema, path: Path) extends AutoCloseable {
+  private val allocator = ArrowUtils.rootAllocator
+    .newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
+
+  protected val root = VectorSchemaRoot.create(schema, allocator)
+  protected val loader = new VectorLoader(root)
+
+  protected val fileWriter =
+    new ArrowFileWriter(root, null, Channels.newChannel(Files.newOutputStream(path)))
+
+  override def close(): Unit = {
+    fileWriter.close()
+    root.close()
+    allocator.close()
+  }
+
+  def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
+    fileWriter.start()
+    while (batchBytesIter.hasNext) {
+      val batchBytes = batchBytesIter.next()
+      val batch = ArrowConverters.loadBatch(batchBytes, allocator)
+      loader.load(batch)
+      fileWriter.writeBatch()
+      batch.close()
+    }
+    fileWriter.close()
+  }
+}
+
+private[sql] class SparkArrowFileReader(path: Path) extends AutoCloseable {
+  private val allocator = ArrowUtils.rootAllocator
+    .newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
+
+  protected val fileReader =
+    new ArrowFileReader(Files.newByteChannel(path), allocator)
+
+  override def close(): Unit = {
+    fileReader.close()
+    allocator.close()
+  }
+
+  val schema: Schema = fileReader.getVectorSchemaRoot.getSchema
+
+  def read(): Iterator[Array[Byte]] = {
+    fileReader.getRecordBlocks.iterator().asScala.map { block =>
+      fileReader.loadRecordBatch(block)
+      val root = fileReader.getVectorSchemaRoot
+      val unloader = new VectorUnloader(root)
+      val batch = unloader.getRecordBatch
+      val bytes = ArrowConverters.serializeBatch(batch)
+      batch.close()
+      bytes
+    }
+  }
+}
+
+private[spark] object ArrowFileReadWrite {
+  def save(df: DataFrame, path: Path): Unit = {
+    val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
+    val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
+    val writer = new SparkArrowFileWriter(arrowSchema, path)
+    writer.write(rdd.toLocalIterator)
+  }
+
+  def load(spark: SparkSession, path: Path): DataFrame = {
+    val reader = new SparkArrowFileReader(path)
+    val schema = ArrowUtils.fromArrowSchema(reader.schema)
+    ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false)
+  }
+}
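
To tie the pieces together, a hedged usage sketch of the new helpers (they are private[sql]/private[spark], so this is only callable from Spark-internal code or tests; the session setup, temp path, and data are illustrative):

import java.nio.file.Files
import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite

// Obtain a classic (non-Connect) session, mirroring the pattern match that
// ReadWriteUtils.loadDataFrame uses above.
val spark = org.apache.spark.sql.SparkSession.builder()
  .master("local[2]")
  .appName("arrow-file-demo")
  .getOrCreate() match {
    case s: org.apache.spark.sql.classic.SparkSession => s
  }

val df = spark.range(0, 10).toDF("id")
val file = Files.createTempDirectory("arrow-demo").resolve("data.arrow")

ArrowFileReadWrite.save(df, file)                   // one Arrow IPC file, UTC timestamps
val restored = ArrowFileReadWrite.load(spark, file)
assert(restored.count() == 10)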

0 commit comments
