Skip to content

Commit 5892585

Browse files
authored
Allow users to pass a single expression instead of a list of expressions for partition_by and order_by (#1187)
1 parent f947941 commit 5892585

File tree

4 files changed

+146
-61
lines changed

4 files changed

+146
-61
lines changed

python/datafusion/expr.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,11 @@
216216

217217

218218
def expr_list_to_raw_expr_list(
219-
expr_list: Optional[list[Expr]],
219+
expr_list: Optional[list[Expr] | Expr],
220220
) -> Optional[list[expr_internal.Expr]]:
221221
"""Helper function to convert an optional list to raw expressions."""
222+
if isinstance(expr_list, Expr):
223+
expr_list = [expr_list]
222224
return [e.expr for e in expr_list] if expr_list is not None else None
223225

224226

@@ -230,9 +232,11 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
230232

231233

232234
def sort_list_to_raw_sort_list(
233-
sort_list: Optional[list[Expr | SortExpr]],
235+
sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
234236
) -> Optional[list[expr_internal.SortExpr]]:
235237
"""Helper function to return an optional sort list to raw variant."""
238+
if isinstance(sort_list, (Expr, SortExpr)):
239+
sort_list = [sort_list]
236240
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
237241

238242

@@ -1140,9 +1144,9 @@ class Window:
11401144

11411145
def __init__(
11421146
self,
1143-
partition_by: Optional[list[Expr]] = None,
1147+
partition_by: Optional[list[Expr] | Expr] = None,
11441148
window_frame: Optional[WindowFrame] = None,
1145-
order_by: Optional[list[SortExpr | Expr]] = None,
1149+
order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
11461150
null_treatment: Optional[NullTreatment] = None,
11471151
) -> None:
11481152
"""Construct a window definition.

python/datafusion/functions.py

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,8 @@ def when(when: Expr, then: Expr) -> CaseBuilder:
428428
def window(
429429
name: str,
430430
args: list[Expr],
431-
partition_by: list[Expr] | None = None,
432-
order_by: list[Expr | SortExpr] | None = None,
431+
partition_by: list[Expr] | Expr | None = None,
432+
order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None,
433433
window_frame: WindowFrame | None = None,
434434
ctx: SessionContext | None = None,
435435
) -> Expr:
@@ -442,11 +442,11 @@ def window(
442442
df.select(functions.lag(col("a")).partition_by(col("b")).build())
443443
"""
444444
args = [a.expr for a in args]
445-
partition_by = expr_list_to_raw_expr_list(partition_by)
445+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
446446
order_by_raw = sort_list_to_raw_sort_list(order_by)
447447
window_frame = window_frame.window_frame if window_frame is not None else None
448448
ctx = ctx.ctx if ctx is not None else None
449-
return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx))
449+
return Expr(f.window(name, args, partition_by_raw, order_by_raw, window_frame, ctx))
450450

451451

452452
# scalar functions
@@ -1723,7 +1723,7 @@ def array_agg(
17231723
expression: Expr,
17241724
distinct: bool = False,
17251725
filter: Optional[Expr] = None,
1726-
order_by: Optional[list[Expr | SortExpr]] = None,
1726+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
17271727
) -> Expr:
17281728
"""Aggregate values into an array.
17291729
@@ -2222,7 +2222,7 @@ def regr_syy(
22222222
def first_value(
22232223
expression: Expr,
22242224
filter: Optional[Expr] = None,
2225-
order_by: Optional[list[Expr | SortExpr]] = None,
2225+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
22262226
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22272227
) -> Expr:
22282228
"""Returns the first value in a group of values.
@@ -2254,7 +2254,7 @@ def first_value(
22542254
def last_value(
22552255
expression: Expr,
22562256
filter: Optional[Expr] = None,
2257-
order_by: Optional[list[Expr | SortExpr]] = None,
2257+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
22582258
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22592259
) -> Expr:
22602260
"""Returns the last value in a group of values.
@@ -2287,7 +2287,7 @@ def nth_value(
22872287
expression: Expr,
22882288
n: int,
22892289
filter: Optional[Expr] = None,
2290-
order_by: Optional[list[Expr | SortExpr]] = None,
2290+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
22912291
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22922292
) -> Expr:
22932293
"""Returns the n-th value in a group of values.
@@ -2407,8 +2407,8 @@ def lead(
24072407
arg: Expr,
24082408
shift_offset: int = 1,
24092409
default_value: Optional[Any] = None,
2410-
partition_by: Optional[list[Expr]] = None,
2411-
order_by: Optional[list[Expr | SortExpr]] = None,
2410+
partition_by: Optional[list[Expr] | Expr] = None,
2411+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
24122412
) -> Expr:
24132413
"""Create a lead window function.
24142414
@@ -2442,17 +2442,15 @@ def lead(
24422442
if not isinstance(default_value, pa.Scalar) and default_value is not None:
24432443
default_value = pa.scalar(default_value)
24442444

2445-
partition_cols = (
2446-
[col.expr for col in partition_by] if partition_by is not None else None
2447-
)
2445+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
24482446
order_by_raw = sort_list_to_raw_sort_list(order_by)
24492447

24502448
return Expr(
24512449
f.lead(
24522450
arg.expr,
24532451
shift_offset,
24542452
default_value,
2455-
partition_by=partition_cols,
2453+
partition_by=partition_by_raw,
24562454
order_by=order_by_raw,
24572455
)
24582456
)
@@ -2462,8 +2460,8 @@ def lag(
24622460
arg: Expr,
24632461
shift_offset: int = 1,
24642462
default_value: Optional[Any] = None,
2465-
partition_by: Optional[list[Expr]] = None,
2466-
order_by: Optional[list[Expr | SortExpr]] = None,
2463+
partition_by: Optional[list[Expr] | Expr] = None,
2464+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
24672465
) -> Expr:
24682466
"""Create a lag window function.
24692467
@@ -2494,25 +2492,23 @@ def lag(
24942492
if not isinstance(default_value, pa.Scalar):
24952493
default_value = pa.scalar(default_value)
24962494

2497-
partition_cols = (
2498-
[col.expr for col in partition_by] if partition_by is not None else None
2499-
)
2495+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
25002496
order_by_raw = sort_list_to_raw_sort_list(order_by)
25012497

25022498
return Expr(
25032499
f.lag(
25042500
arg.expr,
25052501
shift_offset,
25062502
default_value,
2507-
partition_by=partition_cols,
2503+
partition_by=partition_by_raw,
25082504
order_by=order_by_raw,
25092505
)
25102506
)
25112507

25122508

25132509
def row_number(
2514-
partition_by: Optional[list[Expr]] = None,
2515-
order_by: Optional[list[Expr | SortExpr]] = None,
2510+
partition_by: Optional[list[Expr] | Expr] = None,
2511+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
25162512
) -> Expr:
25172513
"""Create a row number window function.
25182514
@@ -2533,22 +2529,20 @@ def row_number(
25332529
partition_by: Expressions to partition the window frame on.
25342530
order_by: Set ordering within the window frame.
25352531
"""
2536-
partition_cols = (
2537-
[col.expr for col in partition_by] if partition_by is not None else None
2538-
)
2532+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
25392533
order_by_raw = sort_list_to_raw_sort_list(order_by)
25402534

25412535
return Expr(
25422536
f.row_number(
2543-
partition_by=partition_cols,
2537+
partition_by=partition_by_raw,
25442538
order_by=order_by_raw,
25452539
)
25462540
)
25472541

25482542

25492543
def rank(
2550-
partition_by: Optional[list[Expr]] = None,
2551-
order_by: Optional[list[Expr | SortExpr]] = None,
2544+
partition_by: Optional[list[Expr] | Expr] = None,
2545+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
25522546
) -> Expr:
25532547
"""Create a rank window function.
25542548
@@ -2574,22 +2568,20 @@ def rank(
25742568
partition_by: Expressions to partition the window frame on.
25752569
order_by: Set ordering within the window frame.
25762570
"""
2577-
partition_cols = (
2578-
[col.expr for col in partition_by] if partition_by is not None else None
2579-
)
2571+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
25802572
order_by_raw = sort_list_to_raw_sort_list(order_by)
25812573

25822574
return Expr(
25832575
f.rank(
2584-
partition_by=partition_cols,
2576+
partition_by=partition_by_raw,
25852577
order_by=order_by_raw,
25862578
)
25872579
)
25882580

25892581

25902582
def dense_rank(
2591-
partition_by: Optional[list[Expr]] = None,
2592-
order_by: Optional[list[Expr | SortExpr]] = None,
2583+
partition_by: Optional[list[Expr] | Expr] = None,
2584+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
25932585
) -> Expr:
25942586
"""Create a dense_rank window function.
25952587
@@ -2610,22 +2602,20 @@ def dense_rank(
26102602
partition_by: Expressions to partition the window frame on.
26112603
order_by: Set ordering within the window frame.
26122604
"""
2613-
partition_cols = (
2614-
[col.expr for col in partition_by] if partition_by is not None else None
2615-
)
2605+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
26162606
order_by_raw = sort_list_to_raw_sort_list(order_by)
26172607

26182608
return Expr(
26192609
f.dense_rank(
2620-
partition_by=partition_cols,
2610+
partition_by=partition_by_raw,
26212611
order_by=order_by_raw,
26222612
)
26232613
)
26242614

26252615

26262616
def percent_rank(
2627-
partition_by: Optional[list[Expr]] = None,
2628-
order_by: Optional[list[Expr | SortExpr]] = None,
2617+
partition_by: Optional[list[Expr] | Expr] = None,
2618+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
26292619
) -> Expr:
26302620
"""Create a percent_rank window function.
26312621
@@ -2647,22 +2637,20 @@ def percent_rank(
26472637
partition_by: Expressions to partition the window frame on.
26482638
order_by: Set ordering within the window frame.
26492639
"""
2650-
partition_cols = (
2651-
[col.expr for col in partition_by] if partition_by is not None else None
2652-
)
2640+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
26532641
order_by_raw = sort_list_to_raw_sort_list(order_by)
26542642

26552643
return Expr(
26562644
f.percent_rank(
2657-
partition_by=partition_cols,
2645+
partition_by=partition_by_raw,
26582646
order_by=order_by_raw,
26592647
)
26602648
)
26612649

26622650

26632651
def cume_dist(
2664-
partition_by: Optional[list[Expr]] = None,
2665-
order_by: Optional[list[Expr | SortExpr]] = None,
2652+
partition_by: Optional[list[Expr] | Expr] = None,
2653+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
26662654
) -> Expr:
26672655
"""Create a cumulative distribution window function.
26682656
@@ -2684,23 +2672,21 @@ def cume_dist(
26842672
partition_by: Expressions to partition the window frame on.
26852673
order_by: Set ordering within the window frame.
26862674
"""
2687-
partition_cols = (
2688-
[col.expr for col in partition_by] if partition_by is not None else None
2689-
)
2675+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
26902676
order_by_raw = sort_list_to_raw_sort_list(order_by)
26912677

26922678
return Expr(
26932679
f.cume_dist(
2694-
partition_by=partition_cols,
2680+
partition_by=partition_by_raw,
26952681
order_by=order_by_raw,
26962682
)
26972683
)
26982684

26992685

27002686
def ntile(
27012687
groups: int,
2702-
partition_by: Optional[list[Expr]] = None,
2703-
order_by: Optional[list[Expr | SortExpr]] = None,
2688+
partition_by: Optional[list[Expr] | Expr] = None,
2689+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
27042690
) -> Expr:
27052691
"""Create a n-tile window function.
27062692
@@ -2725,15 +2711,13 @@ def ntile(
27252711
partition_by: Expressions to partition the window frame on.
27262712
order_by: Set ordering within the window frame.
27272713
"""
2728-
partition_cols = (
2729-
[col.expr for col in partition_by] if partition_by is not None else None
2730-
)
2714+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
27312715
order_by_raw = sort_list_to_raw_sort_list(order_by)
27322716

27332717
return Expr(
27342718
f.ntile(
27352719
Expr.literal(groups).expr,
2736-
partition_by=partition_cols,
2720+
partition_by=partition_by_raw,
27372721
order_by=order_by_raw,
27382722
)
27392723
)
@@ -2743,7 +2727,7 @@ def string_agg(
27432727
expression: Expr,
27442728
delimiter: str,
27452729
filter: Optional[Expr] = None,
2746-
order_by: Optional[list[Expr | SortExpr]] = None,
2730+
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
27472731
) -> Expr:
27482732
"""Concatenates the input strings.
27492733

python/tests/test_aggregation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
154154
pa.array([[6, 4, 4]]),
155155
False,
156156
),
157+
(
158+
f.array_agg(column("b"), order_by=column("c")),
159+
pa.array([[6, 4, 4]]),
160+
False,
161+
),
157162
(f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False),
158163
(f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False),
159164
(f.count(column("b"), distinct=True), pa.array([2]), False),
@@ -329,6 +334,15 @@ def test_bit_and_bool_fns(df, name, expr, result):
329334
),
330335
[None, None],
331336
),
337+
(
338+
"first_value_no_list_order_by",
339+
f.first_value(
340+
column("b"),
341+
order_by=column("b"),
342+
null_treatment=NullTreatment.RESPECT_NULLS,
343+
),
344+
[None, None],
345+
),
332346
(
333347
"first_value_ignore_null",
334348
f.first_value(
@@ -343,6 +357,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
343357
f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]),
344358
[0, 4],
345359
),
360+
(
361+
"last_value_no_list_ordered",
362+
f.last_value(column("a"), order_by=column("a")),
363+
[3, 6],
364+
),
346365
(
347366
"last_value_with_null",
348367
f.last_value(
@@ -366,6 +385,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
366385
f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]),
367386
[2, 5],
368387
),
388+
(
389+
"nth_value_no_list_ordered",
390+
f.nth_value(column("a"), 2, order_by=column("a").sort(ascending=False)),
391+
[2, 5],
392+
),
369393
(
370394
"nth_value_with_null",
371395
f.nth_value(
@@ -414,6 +438,11 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None:
414438
f.string_agg(column("a"), ",", order_by=[column("b")]),
415439
"one,three,two,two",
416440
),
441+
(
442+
"string_agg",
443+
f.string_agg(column("a"), ",", order_by=column("b")),
444+
"one,three,two,two",
445+
),
417446
],
418447
)
419448
def test_string_agg(name, expr, result) -> None:

0 commit comments

Comments
 (0)