Skip to content

Commit 19090f4

Browse files
kosiewJefffrey
authored andcommitted
Refactor nvl2 Function to Support Lazy Evaluation and Simplification via CASE Expression (apache#18191)
## Which issue does this PR close? * Closes apache#17983 ## Rationale for this change The current implementation of the `nvl2` function in DataFusion eagerly evaluates all its arguments, which can lead to unnecessary computation and incorrect behavior when handling expressions that should only be conditionally evaluated. This PR introduces **lazy evaluation** for `nvl2`, aligning its behavior with other conditional expressions like `coalesce` and improving both performance and correctness. This change also introduces a **simplification rule** that rewrites `nvl2` expressions into equivalent `CASE` statements, allowing for better optimization during query planning and execution. ## What changes are included in this PR? * Refactored `nvl2` implementation in `datafusion/functions/src/core/nvl2.rs`: * Added support for **short-circuit (lazy) evaluation** using `short_circuits()`. * Implemented **simplify()** method to rewrite expressions into `CASE` form. * Introduced **return_field_from_args()** for correct nullability and type inference. * Replaced the previous eager `nvl2_func()` logic with an optimized, more declarative approach. * Added comprehensive **unit tests**: * `test_nvl2_short_circuit` in `dataframe_functions.rs` verifies correct short-circuit behavior. * `test_create_physical_expr_nvl2` in `expr_api/mod.rs` validates physical expression creation and output correctness. ## Are these changes tested? ✅ Yes, multiple new tests are included: * **`test_nvl2_short_circuit`** ensures `nvl2` does not evaluate unnecessary branches. * **`test_create_physical_expr_nvl2`** checks the correctness of evaluation and type coercion behavior. All existing and new tests pass successfully. ## Are there any user-facing changes? Yes, but they are **non-breaking** and **performance-enhancing**: * `nvl2` now evaluates lazily, meaning only the required branch is computed based on the nullity of the test expression. * Expression simplification will yield more optimized query plans. There are **no API-breaking changes**. However, users may observe improved performance and reduced computation for expressions involving `nvl2`. --------- Co-authored-by: Jeffrey Vo <[email protected]>
1 parent 4f42c32 commit 19090f4

File tree

3 files changed

+84
-48
lines changed

3 files changed

+84
-48
lines changed

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,33 @@ async fn test_nvl2() -> Result<()> {
274274

275275
Ok(())
276276
}
277+
278+
#[tokio::test]
279+
async fn test_nvl2_short_circuit() -> Result<()> {
280+
let expr = nvl2(
281+
col("a"),
282+
arrow_cast(lit("1"), lit("Int32")),
283+
arrow_cast(col("a"), lit("Int32")),
284+
);
285+
286+
let batches = get_batches(expr).await?;
287+
288+
assert_snapshot!(
289+
batches_to_string(&batches),
290+
@r#"
291+
+-----------------------------------------------------------------------------------+
292+
| nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32"))) |
293+
+-----------------------------------------------------------------------------------+
294+
| 1 |
295+
| 1 |
296+
| 1 |
297+
| 1 |
298+
+-----------------------------------------------------------------------------------+
299+
"#
300+
);
301+
302+
Ok(())
303+
}
277304
#[tokio::test]
278305
async fn test_fn_arrow_typeof() -> Result<()> {
279306
let expr = arrow_typeof(col("l"));

datafusion/core/tests/expr_api/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,26 @@ async fn test_create_physical_expr() {
320320
create_simplified_expr_test(lit(1i32) + lit(2i32), "3");
321321
}
322322

323+
#[test]
324+
fn test_create_physical_expr_nvl2() {
325+
let batch = &TEST_BATCH;
326+
let df_schema = DFSchema::try_from(batch.schema()).unwrap();
327+
let ctx = SessionContext::new();
328+
329+
let expect_err = |expr| {
330+
let physical_expr = ctx.create_physical_expr(expr, &df_schema).unwrap();
331+
let err = physical_expr.evaluate(batch).unwrap_err();
332+
assert!(
333+
err.to_string()
334+
.contains("nvl2 should have been simplified to case"),
335+
"unexpected error: {err:?}"
336+
);
337+
};
338+
339+
expect_err(nvl2(col("i"), lit(1i64), lit(0i64)));
340+
expect_err(nvl2(lit(1i64), col("i"), lit(0i64)));
341+
}
342+
323343
#[tokio::test]
324344
async fn test_create_physical_expr_coercion() {
325345
// create_physical_expr does apply type coercion and unwrapping in cast

datafusion/functions/src/core/nvl2.rs

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::Array;
19-
use arrow::compute::is_not_null;
20-
use arrow::compute::kernels::zip::zip;
21-
use arrow::datatypes::DataType;
18+
use arrow::datatypes::{DataType, Field, FieldRef};
2219
use datafusion_common::{internal_err, utils::take_function_args, Result};
2320
use datafusion_expr::{
24-
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
25-
ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
21+
conditional_expressions::CaseBuilder,
22+
simplify::{ExprSimplifyResult, SimplifyInfo},
23+
type_coercion::binary::comparison_coercion,
24+
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
25+
ScalarUDFImpl, Signature, Volatility,
2626
};
2727
use datafusion_macros::user_doc;
28-
use std::sync::Arc;
2928

3029
#[user_doc(
3130
doc_section(label = "Conditional Functions"),
@@ -95,8 +94,37 @@ impl ScalarUDFImpl for NVL2Func {
9594
Ok(arg_types[1].clone())
9695
}
9796

98-
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99-
nvl2_func(&args.args)
97+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
98+
let nullable =
99+
args.arg_fields[1].is_nullable() || args.arg_fields[2].is_nullable();
100+
let return_type = args.arg_fields[1].data_type().clone();
101+
Ok(Field::new(self.name(), return_type, nullable).into())
102+
}
103+
104+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105+
internal_err!("nvl2 should have been simplified to case")
106+
}
107+
108+
fn simplify(
109+
&self,
110+
args: Vec<Expr>,
111+
_info: &dyn SimplifyInfo,
112+
) -> Result<ExprSimplifyResult> {
113+
let [test, if_non_null, if_null] = take_function_args(self.name(), args)?;
114+
115+
let expr = CaseBuilder::new(
116+
None,
117+
vec![test.is_not_null()],
118+
vec![if_non_null],
119+
Some(Box::new(if_null)),
120+
)
121+
.end()?;
122+
123+
Ok(ExprSimplifyResult::Simplified(expr))
124+
}
125+
126+
fn short_circuits(&self) -> bool {
127+
true
100128
}
101129

102130
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -123,42 +151,3 @@ impl ScalarUDFImpl for NVL2Func {
123151
self.doc()
124152
}
125153
}
126-
127-
fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
128-
let mut len = 1;
129-
let mut is_array = false;
130-
for arg in args {
131-
if let ColumnarValue::Array(array) = arg {
132-
len = array.len();
133-
is_array = true;
134-
break;
135-
}
136-
}
137-
if is_array {
138-
let args = args
139-
.iter()
140-
.map(|arg| match arg {
141-
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len),
142-
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
143-
})
144-
.collect::<Result<Vec<_>>>()?;
145-
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
146-
let to_apply = is_not_null(&tested)?;
147-
let value = zip(&to_apply, &if_non_null, &if_null)?;
148-
Ok(ColumnarValue::Array(value))
149-
} else {
150-
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
151-
match &tested {
152-
ColumnarValue::Array(_) => {
153-
internal_err!("except Scalar value, but got Array")
154-
}
155-
ColumnarValue::Scalar(scalar) => {
156-
if scalar.is_null() {
157-
Ok(if_null.clone())
158-
} else {
159-
Ok(if_non_null.clone())
160-
}
161-
}
162-
}
163-
}
164-
}

0 commit comments

Comments
 (0)