
Commit 7251e95

Yicong-Huang authored and zhengruifeng committed
[SPARK-53615][PYTHON] Introduce iterator API for Arrow grouped aggregation UDF
### What changes were proposed in this pull request? This PR introduces an iterator API for Arrow grouped aggregation UDFs in PySpark. It adds support for two new UDF patterns: - `Iterator[pa.Array] -> Any` for single column aggregations - `Iterator[Tuple[pa.Array, ...]] -> Any` for multiple column aggregations The implementation adds a new Python eval type `SQL_GROUPED_AGG_ARROW_ITER_UDF` with corresponding support in type inference, worker serialization, and Scala execution planning. ### Why are the changes needed? The current Arrow grouped aggregation API requires loading all data for a group into memory at once, which can be problematic for groups with large amounts of data. The iterator API allows processing data in batches, providing: 1. **Memory Efficiency**: Processes data incrementally rather than loading entire group into memory 2. **Consistency**: Aligns with existing iterator APIs (e.g., `SQL_SCALAR_ARROW_ITER_UDF`) 3. **Flexibility**: Allows initialization of expensive state once per group while processing batches iteratively ### Does this PR introduce _any_ user-facing change? Yes. This PR adds a new API pattern for Arrow grouped aggregation UDFs: **Single column aggregation:** ```python import pyarrow as pa from typing import Iterator from pyspark.sql.functions import arrow_udf arrow_udf("double") def arrow_mean(it: Iterator[pa.Array]) -> float: sum_val = 0.0 cnt = 0 for v in it: sum_val += pa.compute.sum(v).as_py() cnt += len(v) return sum_val / cnt if cnt > 0 else 0.0 df.groupby("id").agg(arrow_mean(df['v'])).show() ``` **Multiple column aggregation:** ```python import pyarrow as pa import numpy as np from typing import Iterator, Tuple from pyspark.sql.functions import arrow_udf arrow_udf("double") def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float: weighted_sum = 0.0 weight = 0.0 for v, w in it: weighted_sum += np.dot(v.to_numpy(), w.to_numpy()) weight += pa.compute.sum(w).as_py() return weighted_sum / weight if weight > 0 else 0.0 df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show() ``` ### How was this patch tested? Added comprehensive unit tests in `python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py`: 1. `test_iterator_grouped_agg_single_column()` - Tests single column iterator aggregation with `Iterator[pa.Array]` 2. `test_iterator_grouped_agg_multiple_columns()` - Tests multiple column iterator aggregation with `Iterator[Tuple[pa.Array, pa.Array]]` 3. `test_iterator_grouped_agg_eval_type()` - Verifies correct eval type inference from type hints ### Was this patch authored or co-authored using generative AI tooling? Co-Generated-by: Cursor with Claude Sonnet 4.5 Closes #53035 from Yicong-Huang/SPARK-53615/feat/arrow-grouped-agg-iterator-api. Authored-by: Yicong-Huang <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent d4e34f5 commit 7251e95

File tree

12 files changed: +456 −3 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ private[spark] object PythonEvalType {
   val SQL_SCALAR_ARROW_ITER_UDF = 251
   val SQL_GROUPED_AGG_ARROW_UDF = 252
   val SQL_WINDOW_AGG_ARROW_UDF = 253
+  val SQL_GROUPED_AGG_ARROW_ITER_UDF = 254
 
   val SQL_TABLE_UDF = 300
   val SQL_ARROW_TABLE_UDF = 301
@@ -112,6 +113,7 @@ private[spark] object PythonEvalType {
     case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
     case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
     case SQL_WINDOW_AGG_ARROW_UDF => "SQL_WINDOW_AGG_ARROW_UDF"
+    case SQL_GROUPED_AGG_ARROW_ITER_UDF => "SQL_GROUPED_AGG_ARROW_ITER_UDF"
   }
 }

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ ArrowScalarUDFType = Literal[250]
 ArrowScalarIterUDFType = Literal[251]
 ArrowGroupedAggUDFType = Literal[252]
 ArrowWindowAggUDFType = Literal[253]
+ArrowGroupedAggIterUDFType = Literal[254]
 
 class ArrowVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...

python/pyspark/sql/pandas/functions.py

Lines changed: 67 additions & 0 deletions
@@ -50,6 +50,8 @@ class ArrowUDFType:
 
     GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
 
+    GROUPED_AGG_ITER = PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
+
 
 def arrow_udf(f=None, returnType=None, functionType=None):
     """
@@ -301,6 +303,69 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
         Therefore, mutating the input arrays is not allowed and will cause incorrect results.
         For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        .. note:: Only a single UDF is supported per aggregation.
+
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()
+        +---+-------------+
+        | id|arrow_mean(v)|
+        +---+-------------+
+        |  1|          1.5|
+        |  2|          6.0|
+        +---+-------------+
+
+    * Iterator of Multiple Arrays to Scalar
+        `Iterator[Tuple[pyarrow.Array, ...]]` -> `Any`
+
+        The function takes an iterator of a tuple of multiple `pyarrow.Array` and returns a
+        scalar value. This is useful for grouped aggregations with multiple input columns.
+
+        .. note:: Only a single UDF is supported per aggregation.
+
+        >>> from typing import Iterator, Tuple
+        >>> import numpy as np
+        >>> @arrow_udf("double")
+        ... def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+        ...     weighted_sum = 0.0
+        ...     weight = 0.0
+        ...     for v, w in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         assert isinstance(w, pa.Array)
+        ...         weighted_sum += np.dot(v, w)
+        ...         weight += pa.compute.sum(w).as_py()
+        ...     return weighted_sum / weight
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()
+        +---+-------------------------+
+        | id|arrow_weighted_mean(v, w)|
+        +---+-------------------------+
+        |  1|       1.6666666666666...|
+        |  2|        7.166666666666...|
+        +---+-------------------------+
+
     Notes
     -----
     The user-defined functions do not support conditional expressions or short circuiting
@@ -720,6 +785,7 @@ def vectorized_udf(
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
         raise PySparkTypeError(
@@ -768,6 +834,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
     ]:
         warnings.warn(
             "It is preferred to specify type hints for "

python/pyspark/sql/pandas/functions.pyi

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@ from pyspark.sql.pandas._typing import (
     ArrowScalarIterFunction,
     ArrowScalarIterUDFType,
     ArrowGroupedAggUDFType,
+    ArrowGroupedAggIterUDFType,
 )
 
 from pyspark import since as since  # noqa: F401
@@ -57,6 +58,7 @@ class ArrowUDFType:
     SCALAR: ArrowScalarUDFType
     SCALAR_ITER: ArrowScalarIterUDFType
     GROUPED_AGG: ArrowGroupedAggUDFType
+    GROUPED_AGG_ITER: ArrowGroupedAggIterUDFType
 
 @overload
 def arrow_udf(

python/pyspark/sql/pandas/serializers.py

Lines changed: 52 additions & 0 deletions
@@ -1185,6 +1185,58 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Yield an iterator that produces one list of column arrays per batch.
+        Each group yields Iterator[List[pa.Array]], allowing UDF to process batches one by one
+        without consuming all batches upfront.
+        """
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                # Lazily read and convert Arrow batches one at a time from the stream.
+                # This avoids loading all batches into memory for the group.
+                batch_iter = (
+                    batch.columns for batch in ArrowStreamSerializer.load_stream(self, stream)
+                )
+                yield batch_iter
+                # Make sure the batches are fully iterated before getting the next group
+                for _ in batch_iter:
+                    pass
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+                )
+
+    def __repr__(self):
+        return "ArrowStreamAggArrowIterUDFSerializer"
+
+
 # Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
 class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
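The `yield batch_iter` / drain-loop pairing in the serializer above is the load-bearing detail: every group's lazy iterator reads from the same underlying stream, so any batches the UDF leaves unconsumed must still be read before the next group's header. A minimal standalone sketch of the same pattern, with toy data and no Arrow or Spark involved:

```python
from typing import Iterator, List

def load_groups(groups: List[List[int]]) -> Iterator[Iterator[int]]:
    """Mimics load_stream: one shared cursor, one lazy iterator per group."""
    for group in groups:
        batch_iter = iter(group)
        yield batch_iter
        # Drain whatever the consumer left behind, so the "stream"
        # position is at the start of the next group.
        for _ in batch_iter:
            pass

# A consumer that stops early, like an aggregation that short-circuits:
for g in load_groups([[1, 2, 3], [4, 5, 6]]):
    print(next(g))  # prints 1, then 4 -- the drain loop skipped 2 and 3
```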

python/pyspark/sql/pandas/typehints.py

Lines changed: 47 additions & 1 deletion
@@ -29,6 +29,7 @@
     ArrowScalarUDFType,
     ArrowScalarIterUDFType,
     ArrowGroupedAggUDFType,
+    ArrowGroupedAggIterUDFType,
     ArrowGroupedMapIterUDFType,
     ArrowGroupedMapUDFType,
     ArrowGroupedMapFunction,
@@ -156,7 +157,14 @@ def infer_pandas_eval_type(
 
 def infer_arrow_eval_type(
     sig: Signature, type_hints: Dict[str, Any]
-) -> Optional[Union["ArrowScalarUDFType", "ArrowScalarIterUDFType", "ArrowGroupedAggUDFType"]]:
+) -> Optional[
+    Union[
+        "ArrowScalarUDFType",
+        "ArrowScalarIterUDFType",
+        "ArrowGroupedAggUDFType",
+        "ArrowGroupedAggIterUDFType",
+    ]
+]:
     """
     Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
     :class:`inspect.Signature` instance and type hints.
@@ -235,6 +243,41 @@
     if is_array_agg:
         return ArrowUDFType.GROUPED_AGG
 
+    # Iterator[Tuple[pa.Array, ...]] -> Any
+    is_iterator_tuple_array_agg = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda a: check_tuple_annotation(  # Tuple
+                a,
+                parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
+            ),
+        )
+        and (
+            return_annotation != pa.Array
+            and not check_iterator_annotation(return_annotation)
+            and not check_tuple_annotation(return_annotation)
+        )
+    )
+    if is_iterator_tuple_array_agg:
+        return ArrowUDFType.GROUPED_AGG_ITER
+
+    # Iterator[pa.Array] -> Any
+    is_iterator_array_agg = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(
+            parameters_sig[0],
+            parameter_check_func=lambda a: a == pa.Array,
+        )
+        and (
+            return_annotation != pa.Array
+            and not check_iterator_annotation(return_annotation)
+            and not check_tuple_annotation(return_annotation)
+        )
+    )
+    if is_iterator_array_agg:
+        return ArrowUDFType.GROUPED_AGG_ITER
+
     return None
 
 
@@ -249,6 +292,7 @@ def infer_eval_type(
     "ArrowScalarUDFType",
     "ArrowScalarIterUDFType",
     "ArrowGroupedAggUDFType",
+    "ArrowGroupedAggIterUDFType",
 ]:
     """
     Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
@@ -264,6 +308,7 @@ def infer_eval_type(
             "ArrowScalarUDFType",
             "ArrowScalarIterUDFType",
             "ArrowGroupedAggUDFType",
+            "ArrowGroupedAggIterUDFType",
         ]
     ] = None
     if kind == "pandas":
@@ -295,6 +340,7 @@ def infer_eval_type_for_udf(  # type: ignore[no-untyped-def]
         "ArrowScalarUDFType",
         "ArrowScalarIterUDFType",
        "ArrowGroupedAggUDFType",
+        "ArrowGroupedAggIterUDFType",
     ]
 ]:
     argspec = getfullargspec(f)
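The two `is_iterator_*_agg` checks above mean the new eval type can be recovered purely from a function's annotations. A minimal sketch of what the inference accepts, assuming a build with this commit and that the wrapped UDF exposes `evalType` the way existing pandas_udf tests rely on (this mirrors `test_iterator_grouped_agg_eval_type`):

```python
import pyarrow as pa
import numpy as np
from typing import Iterator, Tuple
from pyspark.sql.functions import arrow_udf
from pyspark.util import PythonEvalType

@arrow_udf("double")
def arrow_mean(it: Iterator[pa.Array]) -> float:
    total, cnt = 0.0, 0
    for v in it:
        total += pa.compute.sum(v).as_py()
        cnt += len(v)
    return total / cnt

@arrow_udf("double")
def arrow_dot(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
    return float(sum(np.dot(v, w) for v, w in it))

# Both hint patterns should infer the new iterator eval type.
assert arrow_mean.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
assert arrow_dot.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
```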
