@@ -610,48 +610,68 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
610
610
} )
611
611
}
612
612
613
- /// Helper function to find the appropriate window function. First, if a session
614
- /// context is defined check it's registered functions. If no context is defined,
615
- /// attempt to find from all default functions. Lastly, as a fall back attempt
616
- /// to use built in window functions, which are being deprecated.
613
+ /// Helper function to find the appropriate window function.
614
+ ///
615
+ /// Search procedure:
616
+ /// 1) If a session context is provided:
617
+ /// a) search User Defined Aggregate Functions (UDAFs)
618
+ /// b) search registered window functions
619
+ /// c) search registered aggregate functions
620
+ /// 2) If no function has been found, search default aggregate functions.
621
+ /// 3) Lastly, as a fall back attempt, search built in window functions, which are being deprecated.
617
622
fn find_window_fn ( name : & str , ctx : Option < PySessionContext > ) -> PyResult < WindowFunctionDefinition > {
618
- let mut maybe_fn = match & ctx {
619
- Some ( ctx) => {
620
- let session_state = ctx. ctx . state ( ) ;
621
-
622
- match session_state. window_functions ( ) . contains_key ( name) {
623
- true => session_state
624
- . window_functions ( )
625
- . get ( name)
626
- . map ( |f| WindowFunctionDefinition :: WindowUDF ( f. clone ( ) ) ) ,
627
- false => session_state
628
- . aggregate_functions ( )
629
- . get ( name)
630
- . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ,
631
- }
623
+ if let Some ( ctx) = ctx {
624
+ // search UDAFs
625
+ let udaf = ctx
626
+ . ctx
627
+ . udaf ( name)
628
+ . map ( WindowFunctionDefinition :: AggregateUDF )
629
+ . ok ( ) ;
630
+
631
+ if let Some ( udaf) = udaf {
632
+ return Ok ( udaf) ;
632
633
}
633
- None => {
634
- let default_aggregate_fns = all_default_aggregate_functions ( ) ;
635
634
636
- default_aggregate_fns
637
- . iter ( )
638
- . find ( |v| v. aliases ( ) . contains ( & name. to_string ( ) ) )
639
- . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) )
635
+ let session_state = ctx. ctx . state ( ) ;
636
+
637
+ // search registered window functions
638
+ let window_fn = session_state
639
+ . window_functions ( )
640
+ . get ( name)
641
+ . map ( |f| WindowFunctionDefinition :: WindowUDF ( f. clone ( ) ) ) ;
642
+
643
+ if let Some ( window_fn) = window_fn {
644
+ return Ok ( window_fn) ;
640
645
}
641
- } ;
642
646
643
- if maybe_fn. is_none ( ) {
644
- maybe_fn = find_df_window_func ( name) . or_else ( || {
645
- ctx. and_then ( |ctx| {
646
- ctx. ctx
647
- . udaf ( name)
648
- . map ( WindowFunctionDefinition :: AggregateUDF )
649
- . ok ( )
650
- } )
651
- } ) ;
647
+ // search registered aggregate functions
648
+ let agg_fn = session_state
649
+ . aggregate_functions ( )
650
+ . get ( name)
651
+ . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ;
652
+
653
+ if let Some ( agg_fn) = agg_fn {
654
+ return Ok ( agg_fn) ;
655
+ }
656
+ }
657
+
658
+ // search default aggregate functions
659
+ let agg_fn = all_default_aggregate_functions ( )
660
+ . iter ( )
661
+ . find ( |v| v. aliases ( ) . contains ( & name. to_string ( ) ) )
662
+ . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ;
663
+
664
+ if let Some ( agg_fn) = agg_fn {
665
+ return Ok ( agg_fn) ;
652
666
}
653
667
654
- maybe_fn. ok_or ( DataFusionError :: Common ( format ! ( "window function `{name}` not found" ) ) . into ( ) )
668
+ // search built in window functions (soon to be deprecated)
669
+ let df_window_func = find_df_window_func ( name) ;
670
+ if let Some ( df_window_func) = df_window_func {
671
+ return Ok ( df_window_func) ;
672
+ }
673
+
674
+ Err ( DataFusionError :: Common ( format ! ( "window function `{name}` not found" ) ) . into ( ) )
655
675
}
656
676
657
677
/// Creates a new Window function expression
@@ -1206,4 +1226,4 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
1206
1226
m. add_wrapped ( wrap_pyfunction ! ( flatten) ) ?;
1207
1227
1208
1228
Ok ( ( ) )
1209
- }
1229
+ }
0 commit comments