
Commit 7db7f3c

[SPARK-54574][ML][CONNECT] Reenable FPGrowth on connect
### What changes were proposed in this pull request?
Reenable FPGrowth on Connect.

### Why are the changes needed?
For feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, FPGrowth will be available on Connect.

### How was this patch tested?
Updated tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #53294 from zhengruifeng/fpgrowth_model_size.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent a00709b
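As a quick illustration of the user-facing change, here is a minimal PySpark sketch (the data and thresholds are illustrative, not part of this patch) that should now run unchanged against a Spark Connect session:

```python
from pyspark.ml.fpm import FPGrowth

# Assumes an active SparkSession `spark`; with this change the same code
# works whether the session is classic or Spark Connect.
df = spark.createDataFrame(
    [(["r", "z", "h", "k", "p"],), (["z", "y", "x", "w", "v", "u", "t", "s"],)],
    ["items"],
)
fp = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6)
model = fp.fit(df)
model.freqItemsets.show()       # frequent itemsets and their frequencies
model.associationRules.show()   # antecedent/consequent rules with confidence
```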


4 files changed (+10, -12 lines)


mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 8 additions & 3 deletions
@@ -36,7 +36,7 @@ import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.VersionUtils
+import org.apache.spark.util.{SizeEstimator, VersionUtils}
 
 /**
  * Common params for FPGrowth and FPGrowthModel
@@ -324,8 +324,13 @@ class FPGrowthModel private[ml] (
   }
 
   override def estimatedSize: Long = {
-    // TODO: Implement this method.
-    throw new UnsupportedOperationException
+    freqItemsets match {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _) +
+          SizeEstimator.estimate(itemSupport)
+      case o => throw new UnsupportedOperationException(
+        s"Unsupported dataframe type: ${o.getClass.getName}")
+    }
   }
 }
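The new `estimatedSize` approximates the model's footprint as the total byte length of the `freqItemsets` DataFrame's serialized Arrow batches, plus a `SizeEstimator` estimate of the in-memory `itemSupport` map; this is the figure the Connect model cache can consult when deciding whether to offload. A rough PySpark analogue of the Arrow-batch part, sketched with the public `mapInArrow` API rather than the internal `toArrowBatchRdd` (and using in-memory buffer sizes, so it is only an approximation):

```python
import pyarrow as pa
from pyspark.sql import functions as sf

def batch_nbytes(batches):
    # One output row per input Arrow batch, carrying that batch's buffer size.
    for batch in batches:
        yield pa.RecordBatch.from_pydict({"nbytes": [batch.nbytes]})

approx_bytes = (
    model.freqItemsets
    .mapInArrow(batch_nbytes, "nbytes long")
    .agg(sf.sum("nbytes"))
    .first()[0]
)
```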

python/pyspark/ml/tests/test_fpm.py

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,7 @@
 import tempfile
 import unittest
 
-from pyspark.sql import is_remote, Row
+from pyspark.sql import Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
@@ -30,8 +30,6 @@
 
 class FPMTestsMixin:
     def test_fp_growth(self):
-        if is_remote():
-            self.skipTest("Do not support Spark Connect.")
         df = self.spark.createDataFrame(
             [
                 ["r z h k p"],

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala

Lines changed: 0 additions & 5 deletions
@@ -229,11 +229,6 @@ private[connect] object MLHandler extends Logging {
     } catch {
       case _: UnsupportedOperationException => ()
     }
-    if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
-      throw MlUnsupportedException(
-        "FPGrowth algorithm is not supported " +
-          "if Spark Connect model cache offloading is enabled.")
-    }
     if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
       && estimator
         .asInstanceOf[org.apache.spark.ml.clustering.LDA]

sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala

Lines changed: 1 addition & 1 deletion
@@ -2378,7 +2378,7 @@ class Dataset[T] private[sql](
       sparkSession.sessionState.conf.arrowUseLargeVarTypes)
   }
 
-  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+  private[spark] def toArrowBatchRdd: RDD[Array[Byte]] = {
     toArrowBatchRdd(queryExecution.executedPlan)
   }
 }
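Note on the `Dataset.scala` change: `private[sql]` limits access to `org.apache.spark.sql`, while `private[spark]` opens it to everything under `org.apache.spark`, which is what lets `FPGrowthModel` (in `org.apache.spark.ml.fpm`) call `toArrowBatchRdd` in the `estimatedSize` implementation above.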
