Skip to content

Commit 06e2704

Browse files
refactor find_window_fn for debug clarity
1 parent a521310 commit 06e2704

File tree

1 file changed

+56
-36
lines changed

1 file changed

+56
-36
lines changed

src/functions.rs

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -610,48 +610,68 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
610610
})
611611
}
612612

613-
/// Helper function to find the appropriate window function. First, if a session
614-
/// context is defined check it's registered functions. If no context is defined,
615-
/// attempt to find from all default functions. Lastly, as a fall back attempt
616-
/// to use built in window functions, which are being deprecated.
613+
/// Helper function to find the appropriate window function.
614+
///
615+
/// Search procedure:
616+
/// 1) If a session context is provided:
617+
/// a) search User Defined Aggregate Functions (UDAFs)
618+
/// b) search registered window functions
619+
/// c) search registered aggregate functions
620+
/// 2) If no function has been found, search default aggregate functions.
621+
/// 3) Lastly, as a fall back attempt, search built in window functions, which are being deprecated.
617622
fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
618-
let mut maybe_fn = match &ctx {
619-
Some(ctx) => {
620-
let session_state = ctx.ctx.state();
621-
622-
match session_state.window_functions().contains_key(name) {
623-
true => session_state
624-
.window_functions()
625-
.get(name)
626-
.map(|f| WindowFunctionDefinition::WindowUDF(f.clone())),
627-
false => session_state
628-
.aggregate_functions()
629-
.get(name)
630-
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())),
631-
}
623+
if let Some(ctx) = ctx {
624+
// search UDAFs
625+
let udaf = ctx
626+
.ctx
627+
.udaf(name)
628+
.map(WindowFunctionDefinition::AggregateUDF)
629+
.ok();
630+
631+
if let Some(udaf) = udaf {
632+
return Ok(udaf);
632633
}
633-
None => {
634-
let default_aggregate_fns = all_default_aggregate_functions();
635634

636-
default_aggregate_fns
637-
.iter()
638-
.find(|v| v.aliases().contains(&name.to_string()))
639-
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()))
635+
let session_state = ctx.ctx.state();
636+
637+
// search registered window functions
638+
let window_fn = session_state
639+
.window_functions()
640+
.get(name)
641+
.map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));
642+
643+
if let Some(window_fn) = window_fn {
644+
return Ok(window_fn);
640645
}
641-
};
642646

643-
if maybe_fn.is_none() {
644-
maybe_fn = find_df_window_func(name).or_else(|| {
645-
ctx.and_then(|ctx| {
646-
ctx.ctx
647-
.udaf(name)
648-
.map(WindowFunctionDefinition::AggregateUDF)
649-
.ok()
650-
})
651-
});
647+
// search registered aggregate functions
648+
let agg_fn = session_state
649+
.aggregate_functions()
650+
.get(name)
651+
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
652+
653+
if let Some(agg_fn) = agg_fn {
654+
return Ok(agg_fn);
655+
}
656+
}
657+
658+
// search default aggregate functions
659+
let agg_fn = all_default_aggregate_functions()
660+
.iter()
661+
.find(|v| v.aliases().contains(&name.to_string()))
662+
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
663+
664+
if let Some(agg_fn) = agg_fn {
665+
return Ok(agg_fn);
652666
}
653667

654-
maybe_fn.ok_or(DataFusionError::Common(format!("window function `{name}` not found")).into())
668+
// search built in window functions (soon to be deprecated)
669+
let df_window_func = find_df_window_func(name);
670+
if let Some(df_window_func) = df_window_func {
671+
return Ok(df_window_func);
672+
}
673+
674+
Err(DataFusionError::Common(format!("window function `{name}` not found")).into())
655675
}
656676

657677
/// Creates a new Window function expression
@@ -1206,4 +1226,4 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
12061226
m.add_wrapped(wrap_pyfunction!(flatten))?;
12071227

12081228
Ok(())
1209-
}
1229+
}

0 commit comments

Comments
 (0)