Skip to content

Commit cebd97b

Browse files
committed
simplify test
1 parent 0c6d495 commit cebd97b

File tree

1 file changed

+18
-81
lines changed

1 file changed

+18
-81
lines changed

python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py

Lines changed: 18 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,12 @@ def conf(cls):
6161
cfg.set("spark.sql.shuffle.partitions", "5")
6262
return cfg
6363

64-
def _test_apply_in_pandas_with_state_basic(self, func, check_results):
64+
def _test_apply_in_pandas_with_state_basic(self, func, check_results, output_type=None):
65+
if output_type is None:
66+
output_type = StructType(
67+
[StructField("key", StringType()), StructField("countAsString", StringType())]
68+
)
69+
6570
input_path = tempfile.mkdtemp()
6671

6772
def prepare_test_resource():
@@ -77,9 +82,6 @@ def prepare_test_resource():
7782
q.stop()
7883
self.assertTrue(df.isStreaming)
7984

80-
output_type = StructType(
81-
[StructField("key", StringType()), StructField("countAsString", StringType())]
82-
)
8385
state_type = StructType([StructField("c", LongType())])
8486

8587
q = (
@@ -316,90 +318,25 @@ def assert_test():
316318
finally:
317319
q.stop()
318320

319-
def _test_apply_in_pandas_with_state_decimal_coercion(self, coercion_enabled, should_succeed):
320-
input_path = tempfile.mkdtemp()
321-
322-
with open(input_path + "/numeric-test.txt", "w") as fw:
323-
fw.write("group1,123\ngroup2,456\ngroup1,789\n")
324-
325-
df = (
326-
self.spark.readStream.format("csv")
327-
.option("header", "false")
328-
.schema("key string, value int")
329-
.load(input_path)
330-
)
321+
def test_apply_in_pandas_with_state_int_to_decimal_coercion(self):
322+
def func(key, pdf_iter, state):
323+
assert isinstance(state, GroupState)
324+
yield pd.DataFrame({"key": [key[0]], "decimal_sum": [1]})
331325

332-
for q in self.spark.streams.active:
333-
q.stop()
334-
self.assertTrue(df.isStreaming)
326+
def check_results(batch_df, _):
327+
assert set(batch_df.sort("key").collect()) == {
328+
Row(key="hello", decimal_sum=Decimal("1.00")),
329+
Row(key="this", decimal_sum=Decimal("1.00")),
330+
}, "Decimal coercion failed: " + str(batch_df.sort("key").collect())
335331

336332
output_type = StructType(
337-
[
338-
StructField("key", StringType()),
339-
StructField("decimal_sum", DecimalType(10, 2)),
340-
StructField("count", LongType()),
341-
]
333+
[StructField("key", StringType()), StructField("decimal_sum", DecimalType(10, 2))]
342334
)
343-
state_type = StructType([StructField("sum", LongType()), StructField("count", LongType())])
344-
345-
def stateful_func(key, pdf_iter, state):
346-
current_sum = state.get[0] if state.exists else 0
347-
current_count = state.get[1] if state.exists else 0
348-
349-
for pdf in pdf_iter:
350-
current_sum += pdf["value"].sum()
351-
current_count += len(pdf)
352-
353-
state.update((current_sum, current_count))
354-
yield pd.DataFrame(
355-
{"key": [key[0]], "decimal_sum": [current_sum], "count": [current_count]}
356-
)
357-
358-
def check_results(batch_df, _):
359-
if should_succeed:
360-
results = batch_df.sort("key").collect()
361-
for row in results:
362-
assert isinstance(row["decimal_sum"], Decimal)
363-
if row["key"] == "group1":
364-
assert row["decimal_sum"] == Decimal("912.00")
365-
elif row["key"] == "group2":
366-
assert row["decimal_sum"] == Decimal("456.00")
367335

368336
with self.sql_conf(
369-
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": coercion_enabled}
337+
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
370338
):
371-
q = (
372-
df.groupBy(df["key"])
373-
.applyInPandasWithState(
374-
stateful_func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout
375-
)
376-
.writeStream.queryName(f"test_coercion_{coercion_enabled}")
377-
.foreachBatch(check_results)
378-
.outputMode("update")
379-
.start()
380-
)
381-
382-
self.assertTrue(q.isActive)
383-
384-
if should_succeed:
385-
q.processAllAvailable()
386-
self.assertTrue(q.exception() is None)
387-
else:
388-
with self.assertRaises(Exception) as context:
389-
q.processAllAvailable()
390-
self.assertIn("STREAM_FAILED", str(context.exception))
391-
392-
q.stop()
393-
394-
def test_apply_in_pandas_with_state_int_to_decimal_coercion(self):
395-
self._test_apply_in_pandas_with_state_decimal_coercion(
396-
coercion_enabled=True, should_succeed=True
397-
)
398-
399-
self._test_apply_in_pandas_with_state_decimal_coercion(
400-
coercion_enabled=False, should_succeed=False
401-
)
402-
339+
self._test_apply_in_pandas_with_state_basic(func, check_results, output_type)
403340

404341
class GroupedApplyInPandasWithStateTests(
405342
GroupedApplyInPandasWithStateTestsMixin, ReusedSQLTestCase

0 commit comments

Comments (0)