Skip to content

Commit e65bd5f

Browse files
committed
[SPARK-52839][PYTHON][CONNECT][TESTS] Enable Arrow aggregation tests in connect
### What changes were proposed in this pull request? Enable Arrow aggregation tests in connect ### Why are the changes needed? For test coverage ### Does this PR introduce _any_ user-facing change? No, test-only ### How was this patch tested? added ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #51530 from zhengruifeng/arrow_udf_agg_connect. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent efbb06f commit e65bd5f

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,7 @@ def __hash__(self):
11131113
"pyspark.sql.tests.connect.arrow.test_parity_arrow_python_udf",
11141114
"pyspark.sql.tests.connect.arrow.test_parity_arrow_udf",
11151115
"pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_scalar",
1116+
"pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_grouped_agg",
11161117
"pyspark.sql.tests.connect.pandas.test_parity_pandas_map",
11171118
"pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map",
11181119
"pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map_with_state",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import os
19+
import time
20+
21+
from pyspark.sql.tests.arrow.test_arrow_udf_grouped_agg import GroupedAggArrowUDFTestsMixin
22+
from pyspark.testing.connectutils import ReusedConnectTestCase
23+
24+
25+
class GroupedAggArrowPythonUDFParityTests(GroupedAggArrowUDFTestsMixin, ReusedConnectTestCase):
26+
@classmethod
27+
def setUpClass(cls):
28+
ReusedConnectTestCase.setUpClass()
29+
30+
# Synchronize default timezone between Python and Java
31+
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
32+
tz = "America/Los_Angeles"
33+
os.environ["TZ"] = tz
34+
time.tzset()
35+
36+
cls.spark.conf.set("spark.sql.session.timeZone", tz)
37+
38+
@classmethod
39+
def tearDownClass(cls):
40+
del os.environ["TZ"]
41+
if cls.tz_prev is not None:
42+
os.environ["TZ"] = cls.tz_prev
43+
time.tzset()
44+
ReusedConnectTestCase.tearDownClass()
45+
46+
47+
if __name__ == "__main__":
48+
import unittest
49+
from pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_grouped_agg import * # noqa: F401
50+
51+
try:
52+
import xmlrunner # type: ignore[import]
53+
54+
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
55+
except ImportError:
56+
testRunner = None
57+
unittest.main(testRunner=testRunner, verbosity=2)

0 commit comments

Comments
 (0)