Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,33 @@ async fn test_nvl2() -> Result<()> {

Ok(())
}

// Checks that `nvl2` yields the second argument for every row of the fixture
// even though the third argument casts column `a` itself to Int32: the
// expected snapshot contains `1` in all four rows.
// NOTE(review): presumably `a` holds non-null values that cannot be cast to
// Int32, so eagerly evaluating the third argument would fail — confirm
// against the fixture behind `get_batches`.
#[tokio::test]
async fn test_nvl2_short_circuit() -> Result<()> {
// Second argument: constant "1" cast to Int32; third argument: `a` cast to Int32.
let expr = nvl2(
col("a"),
arrow_cast(lit("1"), lit("Int32")),
arrow_cast(col("a"), lit("Int32")),
);

let batches = get_batches(expr).await?;

// Inline snapshot of the rendered result; do not reformat the literal.
assert_snapshot!(
batches_to_string(&batches),
@r#"
+-----------------------------------------------------------------------------------+
| nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32"))) |
+-----------------------------------------------------------------------------------+
| 1 |
| 1 |
| 1 |
| 1 |
+-----------------------------------------------------------------------------------+
"#
);

Ok(())
}
#[tokio::test]
async fn test_fn_arrow_typeof() -> Result<()> {
let expr = arrow_typeof(col("l"));
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,26 @@ async fn test_create_physical_expr() {
create_simplified_expr_test(lit(1i32) + lit(2i32), "3");
}

/// Building a physical expression for `nvl2` succeeds, but evaluating it must
/// fail with the "should have been simplified" internal error, whether the
/// tested argument is a column or a literal.
#[test]
fn test_create_physical_expr_nvl2() {
    let batch = &TEST_BATCH;
    let schema = DFSchema::try_from(batch.schema()).unwrap();
    let session = SessionContext::new();

    // One case with a column as the tested value, one with a literal.
    let cases = [
        nvl2(col("i"), lit(1i64), lit(0i64)),
        nvl2(lit(1i64), col("i"), lit(0i64)),
    ];

    for expr in cases {
        let planned = session.create_physical_expr(expr, &schema).unwrap();
        let err = planned.evaluate(batch).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("nvl2 should have been simplified to case"),
            "unexpected error: {err:?}"
        );
    }
}

#[tokio::test]
async fn test_create_physical_expr_coercion() {
// create_physical_expr does apply type coercion and unwrapping in cast
Expand Down
85 changes: 37 additions & 48 deletions datafusion/functions/src/core/nvl2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{internal_err, utils::take_function_args, Result};
use datafusion_expr::{
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
conditional_expressions::CaseBuilder,
simplify::{ExprSimplifyResult, SimplifyInfo},
type_coercion::binary::comparison_coercion,
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "Conditional Functions"),
Expand Down Expand Up @@ -95,8 +94,37 @@ impl ScalarUDFImpl for NVL2Func {
Ok(arg_types[1].clone())
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
nvl2_func(&args.args)
/// Describes the output field of `nvl2`.
///
/// The data type is that of the second argument (the "if non-null" branch);
/// the result is nullable when either branch argument (second or third) is
/// nullable. The field is named after the function.
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
    let fields = args.arg_fields;
    let if_non_null = &fields[1];

    // The third field is only inspected when the second one is not nullable.
    let either_branch_nullable = if_non_null.is_nullable() || fields[2].is_nullable();
    let output_type = if_non_null.data_type().clone();

    let field = Field::new(self.name(), output_type, either_branch_nullable);
    Ok(field.into())
}

/// Always returns an internal error and ignores its arguments.
///
/// As the error message states, `nvl2` is expected to have been rewritten to
/// a CASE expression (see `simplify`) before evaluation, so reaching this
/// method indicates the rewrite did not happen.
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("nvl2 should have been simplified to case")
}

/// Rewrites `nvl2(test, if_non_null, if_null)` into the equivalent
/// `CASE WHEN test IS NOT NULL THEN if_non_null ELSE if_null END`.
///
/// Returns an error if the argument count is not exactly three or if the
/// CASE expression cannot be built.
fn simplify(
    &self,
    args: Vec<Expr>,
    _info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
    let [tested, non_null_branch, null_branch] =
        take_function_args(self.name(), args)?;

    // Single WHEN/THEN pair with an explicit ELSE branch and no base expression.
    let when_exprs = vec![tested.is_not_null()];
    let then_exprs = vec![non_null_branch];
    let else_expr = Some(Box::new(null_branch));

    CaseBuilder::new(None, when_exprs, then_exprs, else_expr)
        .end()
        .map(ExprSimplifyResult::Simplified)
}

/// Reports that `nvl2` short-circuits.
// NOTE(review): presumably this tells the planner that some arguments may not
// be evaluated for every row — confirm against the `ScalarUDFImpl` trait docs.
fn short_circuits(&self) -> bool {
true
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Expand All @@ -123,42 +151,3 @@ impl ScalarUDFImpl for NVL2Func {
self.doc()
}
}

/// Evaluates `nvl2(tested, if_non_null, if_null)` on already-computed arguments.
///
/// * If at least one argument is an array, every scalar argument is expanded
///   to an array with the length of the first array argument, and the result
///   is an array that takes `if_non_null` where `tested` is not null and
///   `if_null` elsewhere.
/// * If all arguments are scalars, the result is a clone of `if_null` when
///   `tested` is null and a clone of `if_non_null` otherwise.
///
/// # Errors
///
/// Returns an error when the number of arguments is not exactly three, or
/// when scalar expansion or one of the Arrow kernels fails.
fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // Length of the first array argument, if any; scalars are broadcast to it.
    let first_array_len = args.iter().find_map(|arg| match arg {
        ColumnarValue::Array(array) => Some(array.len()),
        ColumnarValue::Scalar(_) => None,
    });

    match first_array_len {
        Some(len) => {
            let arrays = args
                .iter()
                .map(|arg| match arg {
                    ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len),
                    ColumnarValue::Array(array) => Ok(Arc::clone(array)),
                })
                .collect::<Result<Vec<_>>>()?;
            let [tested, if_non_null, if_null] = take_function_args("nvl2", arrays)?;
            // Mask is true where `tested` is not null; `zip` selects per row.
            let to_apply = is_not_null(&tested)?;
            let value = zip(&to_apply, &if_non_null, &if_null)?;
            Ok(ColumnarValue::Array(value))
        }
        None => {
            let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
            match tested {
                ColumnarValue::Scalar(scalar) if scalar.is_null() => Ok(if_null.clone()),
                ColumnarValue::Scalar(_) => Ok(if_non_null.clone()),
                // No argument is an array on this path, so this arm is a
                // defensive guard only.
                ColumnarValue::Array(_) => {
                    internal_err!("expected Scalar value, but got Array")
                }
            }
        }
    }
}