diff --git a/python/pyspark/ml/tests/connect/test_parity_classification.py b/python/pyspark/ml/tests/connect/test_parity_classification.py index 7805546dba707..c43e784fd9a1c 100644 --- a/python/pyspark/ml/tests/connect/test_parity_classification.py +++ b/python/pyspark/ml/tests/connect/test_parity_classification.py @@ -17,14 +17,46 @@ import unittest +from pyspark.testing.utils import eventually, timeout from pyspark.ml.tests.test_classification import ClassificationTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -# TODO(SPARK-52764): Re-enable this test after fixing the flakiness. -@unittest.skip("Disabled due to flakiness, should be enabled after fixing the issue") class ClassificationParityTests(ClassificationTestsMixin, ReusedConnectTestCase): - pass + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_binary_logistic_regression_summary(self): + super().test_binary_logistic_regression_summary() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_multiclass_logistic_regression_summary(self): + super().test_multiclass_logistic_regression_summary() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_linear_svc(self): + super().test_linear_svc() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_factorization_machine(self): + super().test_factorization_machine() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_binary_random_forest_classifier(self): + super().test_binary_random_forest_classifier() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_multiclass_random_forest_classifier(self): + super().test_multiclass_random_forest_classifier() + + @eventually(timeout=60, catch_timeout=True) + @timeout(timeout=10) + def test_mlp(self): + super().test_mlp() if __name__ == "__main__":