@@ -61,7 +61,12 @@ def conf(cls):
61
61
cfg .set ("spark.sql.shuffle.partitions" , "5" )
62
62
return cfg
63
63
64
- def _test_apply_in_pandas_with_state_basic (self , func , check_results ):
64
+ def _test_apply_in_pandas_with_state_basic (self , func , check_results , output_type = None ):
65
+ if output_type is None :
66
+ output_type = StructType (
67
+ [StructField ("key" , StringType ()), StructField ("countAsString" , StringType ())]
68
+ )
69
+
65
70
input_path = tempfile .mkdtemp ()
66
71
67
72
def prepare_test_resource ():
@@ -77,9 +82,6 @@ def prepare_test_resource():
77
82
q .stop ()
78
83
self .assertTrue (df .isStreaming )
79
84
80
- output_type = StructType (
81
- [StructField ("key" , StringType ()), StructField ("countAsString" , StringType ())]
82
- )
83
85
state_type = StructType ([StructField ("c" , LongType ())])
84
86
85
87
q = (
@@ -316,90 +318,25 @@ def assert_test():
316
318
finally :
317
319
q .stop ()
318
320
319
- def _test_apply_in_pandas_with_state_decimal_coercion (self , coercion_enabled , should_succeed ):
320
- input_path = tempfile .mkdtemp ()
321
-
322
- with open (input_path + "/numeric-test.txt" , "w" ) as fw :
323
- fw .write ("group1,123\n group2,456\n group1,789\n " )
324
-
325
- df = (
326
- self .spark .readStream .format ("csv" )
327
- .option ("header" , "false" )
328
- .schema ("key string, value int" )
329
- .load (input_path )
330
- )
321
def test_apply_in_pandas_with_state_int_to_decimal_coercion(self):
    """Verify that plain Python ints emitted by the stateful function are
    coerced to the declared DecimalType output column when
    ``spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled`` is on.

    Reuses the shared ``_test_apply_in_pandas_with_state_basic`` harness,
    overriding only the output schema and the per-batch result check.
    """

    def func(key, pdf_iter, state):
        # Sanity-check the state handle, then emit an int where the
        # output schema declares DecimalType(10, 2).
        assert isinstance(state, GroupState)
        yield pd.DataFrame({"key": [key[0]], "decimal_sum": [1]})

    def check_results(batch_df, _):
        collected = batch_df.sort("key").collect()
        expected = {
            Row(key="hello", decimal_sum=Decimal("1.00")),
            Row(key="this", decimal_sum=Decimal("1.00")),
        }
        # The int 1 must arrive as Decimal("1.00") after coercion.
        assert set(collected) == expected, "Decimal coercion failed: " + str(collected)

    # Output schema with a decimal column, replacing the harness default.
    output_type = StructType(
        [
            StructField("key", StringType()),
            StructField("decimal_sum", DecimalType(10, 2)),
        ]
    )

    coercion_conf = "spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled"
    with self.sql_conf({coercion_conf: True}):
        self._test_apply_in_pandas_with_state_basic(func, check_results, output_type)
403
340
404
341
class GroupedApplyInPandasWithStateTests (
405
342
GroupedApplyInPandasWithStateTestsMixin , ReusedSQLTestCase
0 commit comments