diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index e25fdc3e05ab..e270294ef2be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.VersionUtils
+import org.apache.spark.util.{SizeEstimator, VersionUtils}
 
 /**
  * Common params for FPGrowth and FPGrowthModel
@@ -324,8 +324,13 @@ class FPGrowthModel private[ml] (
   }
 
   override def estimatedSize: Long = {
-    // TODO: Implement this method.
-    throw new UnsupportedOperationException
+    freqItemsets match {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _) +
+          SizeEstimator.estimate(itemSupport)
+      case o => throw new UnsupportedOperationException(
+        s"Unsupported dataframe type: ${o.getClass.getName}")
+    }
   }
 }
 
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index 7b949763c398..ea94216c9860 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
 import tempfile
 import unittest
 
-from pyspark.sql import is_remote, Row
+from pyspark.sql import Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
@@ -30,8 +30,6 @@
 
 class FPMTestsMixin:
     def test_fp_growth(self):
-        if is_remote():
-            self.skipTest("Do not support Spark Connect.")
         df = self.spark.createDataFrame(
             [
                 ["r z h k p"],
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 40f1172677a5..3a53aa77fde6 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,11 +229,6 @@ private[connect] object MLHandler extends Logging {
         } catch {
           case _: UnsupportedOperationException => ()
         }
-        if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
-          throw MlUnsupportedException(
-            "FPGrowth algorithm is not supported " +
-              "if Spark Connect model cache offloading is enabled.")
-        }
         if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
           && estimator
             .asInstanceOf[org.apache.spark.ml.clustering.LDA]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d73918586b09..d02b63b49ca5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2378,7 +2378,7 @@ class Dataset[T] private[sql](
       sparkSession.sessionState.conf.arrowUseLargeVarTypes)
   }
 
-  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+  private[spark] def toArrowBatchRdd: RDD[Array[Byte]] = {
     toArrowBatchRdd(queryExecution.executedPlan)
   }
 }
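
Reviewer note (illustrative, not part of the patch): the sketch below fits an FPGrowthModel through the public Scala API, using toy transactions of the same shape as the test data in test_fpm.py. With this change, estimatedSize returns an actual byte count (the Arrow-serialized size of freqItemsets plus a SizeEstimator pass over itemSupport) instead of throwing, which is what lets the Connect model cache offload the model. The object name and parameter values below are made up for the harness, not code from this PR.

// Standalone sketch: fit an FPGrowthModel and touch the DataFrame whose
// Arrow-serialized size the new estimatedSize implementation measures.
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.SparkSession

object FPGrowthSizeCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("fp-growth-size-check")
      .getOrCreate()
    import spark.implicits._

    // Toy transactions; space-separated items split into an array column.
    val dataset = Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p"
    ).map(_.split(" ")).toDF("items")

    val model = new FPGrowth()
      .setItemsCol("items")
      .setMinSupport(0.2)
      .setMinConfidence(0.7)
      .fit(dataset)

    // freqItemsets is the DataFrame the patched estimatedSize serializes to
    // Arrow batches when the Connect model cache asks for the model's size.
    model.freqItemsets.show()
    model.associationRules.show()

    spark.stop()
  }
}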