
Commit 0228de7

ueshin authored and haoyangeng-db committed
[SPARK-52861][PYTHON] Skip Row object creation in Arrow-optimized UDTF execution
### What changes were proposed in this pull request?

Skips `Row` object creation in Arrow-optimized UDTF execution.

### Why are the changes needed?

Arrow-optimized UDTF execution currently creates a `Row` object per input row, which is expensive and not necessary.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The existing tests, and manual benchmarks.

```py
def profile(f, *args, _n=10, **kwargs):
    import cProfile
    import pstats
    import gc

    st = None
    for _ in range(5):
        f(*args, **kwargs)
    for _ in range(_n):
        gc.collect()
        with cProfile.Profile() as pr:
            ret = f(*args, **kwargs)
        if st is None:
            st = pstats.Stats(pr)
        else:
            st.add(pstats.Stats(pr))
    st.sort_stats("time", "cumulative").print_stats()
    return ret


from pyspark.sql.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
from pyspark.sql.types import *

data = [
    (i if i % 1000 else None, str(i)) for i in range(1000000)
]
schema = (
    StructType()
    .add("i", IntegerType(), nullable=True)
    .add("s", StringType(), nullable=True)
)


def to_arrow():
    return LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)


def from_arrow(tbl, return_as_tuples):
    return ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=return_as_tuples)


tbl = to_arrow()

profile(from_arrow, tbl, return_as_tuples=False)
profile(from_arrow, tbl, return_as_tuples=True)
```

- before (`return_as_tuples=False`)

  ```
  60655810 function calls in 14.112 seconds
  ```

- after (`return_as_tuples=True`)

  ```
  20328060 function calls in 5.613 seconds
  ```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51546 from ueshin/issues/SPARK-52861/skip_row_creation.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
1 parent 46e0a28 commit 0228de7
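
For intuition on where the savings come from, here is a small standalone micro-benchmark (illustrative only, not part of the patch; it assumes a local PySpark installation, and numbers will vary by machine) that compares materializing `Row` objects against plain tuples for the same values:

```py
# Illustrative micro-benchmark (not from the patch): cost of building pyspark
# Row objects vs. plain tuples for the same data.
import timeit

from pyspark.sql import Row

# Row factory bound to the same two field names used in the benchmark above.
MyRow = Row("i", "s")

data = [(i, str(i)) for i in range(100_000)]

as_rows = timeit.timeit(lambda: [MyRow(*t) for t in data], number=10)
as_tuples = timeit.timeit(lambda: [tuple(t) for t in data], number=10)

print(f"Row objects: {as_rows:.3f}s, plain tuples: {as_tuples:.3f}s")
```

The tuple path avoids one object construction and the per-row field-name bookkeeping, which is consistent with the before/after profile numbers above.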

2 files changed: +30 -7 lines


python/pyspark/sql/conversion.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -18,7 +18,7 @@
 import array
 import datetime
 import decimal
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, overload
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload
 
 from pyspark.errors import PySparkValueError
 from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names, to_arrow_schema
@@ -446,7 +446,7 @@ def to_row(item: Any) -> tuple:
 
             return pa.Table.from_arrays(pylist, schema=pa_schema)
         else:
-            return pa.table({"_": [None] * len(rows)}).drop("_")
+            return pa.Table.from_struct_array(pa.array([{}] * len(rows)))
 
 
 class ArrowTableToRowsConversion:
@@ -687,8 +687,24 @@ def convert_variant(value: Any) -> Any:
         else:
             return lambda value: value
 
+    @overload
+    @staticmethod
+    def convert(  # type: ignore[overload-overlap]
+        table: "pa.Table", schema: StructType
+    ) -> List[Row]:
+        pass
+
+    @overload
     @staticmethod
-    def convert(table: "pa.Table", schema: StructType) -> List[Row]:
+    def convert(
+        table: "pa.Table", schema: StructType, *, return_as_tuples: bool = True
+    ) -> List[tuple]:
+        pass
+
+    @staticmethod  # type: ignore[misc]
+    def convert(
+        table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
+    ) -> List[Union[Row, tuple]]:
         require_minimum_pyarrow_version()
         import pyarrow as pa
 
@@ -709,8 +725,14 @@ def convert(table: "pa.Table", schema: StructType) -> List[Row]:
                 for column, conv in zip(table.columns, field_converters)
             ]
 
-            rows = [_create_row(fields, tuple(cols)) for cols in zip(*columnar_data)]
+            if return_as_tuples:
+                rows = [tuple(cols) for cols in zip(*columnar_data)]
+            else:
+                rows = [_create_row(fields, tuple(cols)) for cols in zip(*columnar_data)]
             assert len(rows) == table.num_rows, f"{len(rows)}, {table.num_rows}"
             return rows
         else:
-            return [_create_row(fields, tuple())] * table.num_rows
+            if return_as_tuples:
+                return [tuple()] * table.num_rows
+            else:
+                return [_create_row(fields, tuple())] * table.num_rows
```

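As a usage sketch of the new keyword-only argument (hypothetical snippet, not taken from the patch; it assumes a PyArrow table whose columns match the `StructType`):

```py
# Hypothetical usage of ArrowTableToRowsConversion.convert with the new
# return_as_tuples flag; the default behavior (Row objects) is unchanged.
import pyarrow as pa

from pyspark.sql.conversion import ArrowTableToRowsConversion
from pyspark.sql.types import IntegerType, StringType, StructType

schema = StructType().add("i", IntegerType()).add("s", StringType())
tbl = pa.table({"i": pa.array([1, 2, None], type=pa.int32()), "s": ["a", "b", "c"]})

rows = ArrowTableToRowsConversion.convert(tbl, schema)  # List[Row], as before
tuples = ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=True)  # List[tuple]

print(rows[0], tuples[0])
```

The `@overload` declarations keep the static return type precise: existing callers still see `List[Row]`, while callers passing `return_as_tuples=True` see `List[tuple]`.
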
python/pyspark/worker.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1694,9 +1694,10 @@ def evaluate(*args: pa.ChunkedArray):
                 names = [f"_{n}" for n in range(len(list_args))]
                 t = pa.Table.from_arrays(list_args, names=names)
                 schema = from_arrow_schema(t.schema, prefers_large_var_types)
-                rows = ArrowTableToRowsConversion.convert(t, schema=schema)
+                rows = ArrowTableToRowsConversion.convert(
+                    t, schema=schema, return_as_tuples=True
+                )
                 for row in rows:
-                    row = tuple(row)  # type: ignore[assignment]
                     for batch in convert_to_arrow(func(*row)):
                         yield verify_result(batch), arrow_return_type
```

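For context, here is a minimal sketch (illustrative only, not part of the patch; it assumes an active SparkSession and a Spark version with Arrow-optimized Python UDTFs) of the kind of UDTF whose execution path this worker change affects. The worker previously built a `Row` per input row and then converted it to a tuple before calling `eval`; it now requests tuples directly, so `eval` receives the same positional values with less overhead:

```py
# Illustrative Arrow-optimized Python UDTF (not from the patch). Its per-row
# input handling inside the worker is what the change above speeds up.
from pyspark.sql.functions import lit, udtf


@udtf(returnType="x: int, doubled: int", useArrow=True)
class DoubleIt:
    def eval(self, x: int):
        # eval sees plain positional values either way; only the worker's
        # intermediate representation changed.
        yield x, x * 2


DoubleIt(lit(1)).show()
```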