diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7267ee2805987..7ccb6db8d1023 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -51,6 +51,7 @@
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.stat import MultivariateGaussian
+from pyspark.ml.util import RemoteModelRef
 from pyspark.sql import DataFrame
 from pyspark.ml.linalg import Vector, Matrix
 from pyspark.sql.utils import is_remote
@@ -1581,10 +1582,9 @@ def toLocal(self) -> "LocalLDAModel":

         .. warning:: This involves collecting a large :py:func:`topicsMatrix` to the driver.
         """
-        model = LocalLDAModel(self._call_java("toLocal"))
         if is_remote():
-            return model
-
+            return LocalLDAModel(RemoteModelRef(self._call_java("toLocal")))
+        model = LocalLDAModel(self._call_java("toLocal"))
         # SPARK-10931: Temporary fix to be removed once LDAModel defines Params
         model._create_params_from_java()
         model._transfer_params_from_java()
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 1b8eb73135a96..5ed1315300de9 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -378,8 +378,6 @@ def test_local_lda(self):
         self.assertEqual(str(model), str(model2))

     def test_distributed_lda(self):
-        if is_remote():
-            self.skipTest("Do not support Spark Connect.")
         spark = self.spark
         df = (
             spark.createDataFrame(
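
A minimal usage sketch of what this patch enables, not part of the diff itself: calling DistributedLDAModel.toLocal() on a Spark Connect session, where the returned LocalLDAModel now wraps a server-side RemoteModelRef instead of a py4j JavaObject. It assumes an active session bound to the name `spark` (classic or Connect); the data and parameter values are illustrative only.

    # Sketch: fit a DistributedLDAModel and convert it to a local model.
    from pyspark.ml.clustering import LDA
    from pyspark.ml.linalg import Vectors

    df = spark.createDataFrame(
        [(1, Vectors.dense([0.0, 1.0])), (2, Vectors.sparse(2, {0: 1.0}))],
        ["id", "features"],
    )
    # optimizer="em" produces a DistributedLDAModel; the default online
    # optimizer would return a LocalLDAModel directly.
    lda = LDA(k=2, seed=1, optimizer="em")
    distributed_model = lda.fit(df)
    # With this change the call below also works under Spark Connect
    # (is_remote() is True); it still collects topicsMatrix to the driver.
    local_model = distributed_model.toLocal()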