From 714f9d68199a44a3261f9ef07ddddad0e73bda5b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:08:21 +0300 Subject: [PATCH 01/22] move case to a folder for later additions --- datafusion/physical-expr/src/expressions/{ => case}/case.rs | 2 +- datafusion/physical-expr/src/expressions/case/mod.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) rename datafusion/physical-expr/src/expressions/{ => case}/case.rs (99%) create mode 100644 datafusion/physical-expr/src/expressions/case/mod.rs diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs similarity index 99% rename from datafusion/physical-expr/src/expressions/case.rs rename to datafusion/physical-expr/src/expressions/case/case.rs index 2db599047bcd..e861605db373 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case/case.rs @@ -31,7 +31,7 @@ use datafusion_common::{ }; use datafusion_expr::ColumnarValue; -use super::{Column, Literal}; +use super::super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs new file mode 100644 index 000000000000..a8f9c5389213 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -0,0 +1,3 @@ +mod case; + +pub use case::*; From b4b19704481679eff0451d6828420c3aae416ac2 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 20 Oct 2025 22:53:22 +0300 Subject: [PATCH 02/22] started implementation for literal lookup for case when --- .../src/expressions/case/case.rs | 89 ++- .../literal_values/literal_lookup_table.rs | 585 ++++++++++++++++++ .../expressions/case/literal_values/mod.rs | 4 + .../case/literal_values/wrapper.rs | 348 +++++++++++ 
.../physical-expr/src/expressions/case/mod.rs | 1 + 5 files changed, 1021 insertions(+), 6 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/mod.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs diff --git a/datafusion/physical-expr/src/expressions/case/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs index e861605db373..2e33cff2113b 100644 --- a/datafusion/physical-expr/src/expressions/case/case.rs +++ b/datafusion/physical-expr/src/expressions/case/case.rs @@ -34,6 +34,7 @@ use datafusion_expr::ColumnarValue; use super::super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; +use crate::expressions::case::literal_values::LookupTable; type WhenThen = (Arc, Arc); @@ -66,6 +67,58 @@ enum EvalMethod { /// /// CASE WHEN condition THEN expression ELSE expression END ExpressionOrExpression, + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// `CASE WHEN` pattern on supported lookup types: + /// + /// This optimization applies to CASE expressions of the form: + /// ```sql + /// CASE + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// ELSE + /// END + /// ``` + /// + /// all the `WHEN` expressions are equality comparisons on the same expression against literals, + /// and all the `THEN` expressions are literals + /// the expression `` can be any expression as long as it does not have any state (e.g. random number generator, current timestamp, etc.) + /// + /// TODO - how to assert that the expression is stateless and deterministic + /// + /// # Improvement idea + /// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons + /// so it will use this optimization as well, e.g. 
+ /// ```sql + /// -- Before + /// CASE + /// WHEN ( = ) THEN + /// WHEN ( in (, ) THEN + /// WHEN ( = ) THEN + /// ELSE + /// + /// -- After + /// CASE + /// WHEN ( = ) THEN + /// WHEN ( = ) THEN + /// WHEN ( = ) THEN + /// WHEN ( = ) THEN + /// ELSE + /// END + /// ``` + /// + WithExpressionOnlyScalarValuesAndResults(ScalarsOrNullLookup) +} + +#[derive(Debug)] +struct ScalarsOrNullLookup { + /// The lookup table to use for evaluating the CASE expression + lookup: Arc, + + values_to_take_from: ArrayRef, } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -455,22 +508,45 @@ impl CaseExpr { }; let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, &when_value)? - .into_array(batch.num_rows())?; + .1 + .evaluate_selection(batch, &when_value)? + .into_array(batch.num_rows())?; // evaluate else expression on the values not covered by when_value let remainder = not(&when_value)?; let e = self.else_expr.as_ref().unwrap(); + // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); + .unwrap_or_else(|_| Arc::clone(e)); let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; + .evaluate_selection(batch, &remainder)? 
+ .into_array(batch.num_rows())?; Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } + + fn with_expression_scalars_values_and_results(&self, batch: &RecordBatch, scalars_or_null_lookup: &ScalarsOrNullLookup) -> Result { + let expr = self.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; + + let take_indices = scalars_or_null_lookup.lookup.match_values(&evaluated_expression)?; + + // Zero-copy conversion + let take_indices = Int32Array::from(take_indices); + + let output = arrow::compute::take(&scalars_or_null_lookup.values_to_take_from, &take_indices, None)?; + + let result = if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) + } else { + ColumnarValue::Array(output) + }; + + Ok(result) + } } impl PhysicalExpr for CaseExpr { @@ -535,6 +611,7 @@ impl PhysicalExpr for CaseExpr { } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), + EvalMethod::WithExpressionOnlyScalarValuesAndResults(ref e) => self.with_expression_scalars_values_and_results(batch, e), } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs new file mode 100644 index 000000000000..47a3d5892cf3 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs @@ -0,0 +1,585 @@ +use crate::expressions::Literal; +use arrow::array::AsArray; +use arrow::array::{downcast_integer, downcast_primitive, Array, ArrayAccessor, ArrayIter, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeBinaryIter, GenericByteViewArray, TypedDictionaryArray}; +use arrow::array::GenericByteArray; +use arrow::datatypes::{i256, 
ArrowDictionaryKeyType, BinaryViewType, ByteArrayType, ByteViewType, DataType, GenericBinaryType, GenericStringType, IntervalDayTime, IntervalMonthDayNano, StringViewType}; +use datafusion_common::{exec_datafusion_err, plan_datafusion_err, ScalarValue}; +use half::f16; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::{Hash}; +use std::iter::Map; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Lookup table for mapping literal values to their corresponding indices +/// +/// The else index is used when a value is not found in the lookup table +pub(crate) trait LookupTable: Debug + Send + Sync { + /// Try creating a new lookup table from the given literals and else index + fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result + where + Self: Sized; + + /// Return indices to take from the literals based on the values in the given array + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; +} + +pub(crate) fn try_creating_lookup_table( + literals: &[Arc], + else_index: i32, +) -> datafusion_common::Result> { + assert_ne!(literals.len(), 0, "Must have at least one literal"); + match literals[0].value().data_type() { + DataType::Boolean => { + let lookup_table = BooleanLookupMap::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { + ($t:ty) => {{ + let lookup_table = + PrimitiveArrayMapHolder::<$t>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + }}; + } + + downcast_primitive! 
{ + data_type => (create_matching_map), + _ => Err(plan_datafusion_err!( + "Unsupported field type for primitive: {:?}", + data_type + )), + } + } + + DataType::Utf8 => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table = + BytesLookupTable::::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = + BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = + BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Dictionary(key, value) => { + macro_rules! downcast_dictionary_array_helper { + ($t:ty) => {{ + create_lookup_table_for_dictionary_input::<$t>( + value.as_ref(), + literals, + else_index, + ) + }}; + } + + downcast_integer! 
{ + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + _ => Err(plan_datafusion_err!( + "Unsupported data type for lookup table: {}", + literals[0].value().data_type() + )), + } +} + +fn create_lookup_table_for_dictionary_input( + value: &DataType, + literals: &[Arc], + else_index: i32, +) -> datafusion_common::Result> { + match value { + DataType::Utf8 => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table =BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + _ => Err(plan_datafusion_err!( + "Unsupported dictionary value type for lookup table: {}", + value + )), + } +} + +#[derive(Clone)] +struct PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap::HashableKey>, i32>, + else_index: i32, +} + +impl LookupTable for PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result + where + 
Self: Sized, + { + let input = ScalarValue::iter_to_array(literals.iter().map(|item| item.value().clone()))?; + + let map = input + .as_primitive::() + .into_iter() + .enumerate() + .map(|(map_index, value)| (value.map(|v| v.into_hashable_key()), map_index as i32)) + .collect(); + + Ok(Self { map, else_index }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let indices = array + .as_primitive::() + .into_iter() + .map(|value| self.map.get(&value.map(|item| item.into_hashable_key())).copied().unwrap_or(self.else_index)) + .collect::>(); + + Ok(indices) + } +} + + +trait BytesMapHelperWrapperTrait: Send + Sync +{ + type IntoIter<'a>: Iterator> + 'a; + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; +} + + +#[derive(Debug, Clone, Default)] +struct GenericBytesHelper(PhantomData); + +impl BytesMapHelperWrapperTrait for GenericBytesHelper { + type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array + .as_bytes::() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +#[derive(Debug, Clone, Default)] +struct FixedBinaryHelper; + +impl BytesMapHelperWrapperTrait for FixedBinaryHelper { + type IntoIter<'a> = FixedSizeBinaryIter<'a>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_fixed_size_binary().into_iter()) + } +} + + +#[derive(Debug, Clone, Default)] +struct GenericBytesViewHelper(PhantomData); +impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { + type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_byte_view::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + + +#[derive(Debug, Clone, Default)] +struct BytesDictionaryHelper(PhantomData<(Key, Value)>); + +impl 
BytesMapHelperWrapperTrait for BytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteArrayType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: + IntoIterator> { + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }))) + } +} + +#[derive(Debug, Clone, Default)] +struct FixedBytesDictionaryHelper(PhantomData); + +impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: IntoIterator> { + type IntoIter<'a> = as IntoIterator>::IntoIter; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", + array.data_type() + ) + })?; + + Ok(dict_array.into_iter()) + } +} + +#[derive(Debug, Clone, Default)] +struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); + +impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteViewType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: + IntoIterator> { + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + 
exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }))) + } +} + +#[derive(Clone)] +struct BytesLookupTable { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap, i32>, + null_index: i32, + else_index: i32, + + _phantom_data: PhantomData, +} + +impl Debug for BytesLookupTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BytesMapHelper") + .field("map", &self.map) + .field("null_index", &self.null_index) + .field("else_index", &self.else_index) + .finish() + } +} + +impl LookupTable for BytesLookupTable { + fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(literals.iter().map(|item| item.value().clone()))?; + let bytes_iter = Helper::array_to_iter(&input)?; + + let mut null_index = None; + + let mut map: HashMap, i32> = HashMap::new(); + + for (map_index, value) in bytes_iter.enumerate() { + match value { + Some(value) => { + let slice_value: &[u8] = value.as_ref(); + + // Insert only the first occurrence + map.entry(slice_value.to_vec()).or_insert(map_index as i32); + } + None => { + // Only set the null index once + if null_index.is_none() { + null_index = Some(map_index as i32); + } + } + } + } + + Ok(Self { + map, + null_index: null_index.unwrap_or(else_index), + else_index, + _phantom_data: Default::default(), + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let bytes_iter = Helper::array_to_iter(array)?; + let indices = bytes_iter + .map(|value| { + match value { + Some(value) => { + let slice_value: &[u8] = value.as_ref(); + self.map.get(slice_value).copied().unwrap_or(self.else_index) + } + None => { + self.null_index + } + } + }) + 
.collect::>(); + + Ok(indices) + } +} + + +#[derive(Clone, Debug)] +struct BooleanLookupMap { + true_index: i32, + false_index: i32, + null_index: i32, +} + +impl LookupTable for BooleanLookupMap { + fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result + where + Self: Sized, + { + fn get_first_index( + literals: &[Arc], + target: Option, + ) -> Option { + literals + .iter() + .position(|literal| matches!(literal.value(), ScalarValue::Boolean(target))) + .map(|pos| pos as i32) + } + + Ok(Self { + false_index: get_first_index(literals, Some(false)).unwrap_or(else_index), + true_index: get_first_index(literals, Some(true)).unwrap_or(else_index), + null_index: get_first_index(literals, None).unwrap_or(else_index), + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + Ok( + array + .as_boolean() + .into_iter() + .map(|value| match value { + Some(true) => self.true_index, + Some(false) => self.false_index, + None => self.null_index, + }) + .collect::>() + ) + } +} + +macro_rules! 
impl_lookup_table_super_traits { + (impl _ for $MyType:ty) => { + impl_lookup_table_super_traits!(impl<> _ for $MyType where); + }; + (impl<$($impl_generics:ident),*> _ for $MyType:ty where $($where_clause:tt)*) => { + impl<$($impl_generics),*> Debug for $MyType + where + $($where_clause)* + { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + f.debug_struct(stringify!($MyType)) + .field("map", &self.map) + .field("else_index", &self.else_index) + .finish() + } + } + }; +} + +impl_lookup_table_super_traits!( + impl _ for PrimitiveArrayMapHolder where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +); + +// TODO - We need to port it to arrow so that it can be reused in other places + +/// Trait that help convert a value to a key that is hashable and equatable +/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly +trait ToHashableKey: ArrowNativeTypeOp { + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; +} + +macro_rules! impl_to_hashable_key { + (@single_already_hashable | $t:ty) => { + impl ToHashableKey for $t { + type HashableKey = $t; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self + } + } + }; + (@already_hashable | $($t:ty),+ $(,)?) 
=> { + $( + impl_to_hashable_key!(@single_already_hashable | $t); + )+ + }; + (@float | $t:ty => $hashable:ty) => { + impl ToHashableKey for $t { + type HashableKey = $hashable; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self.to_bits() + } + } + }; +} + +impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); +impl_to_hashable_key!(@float | f16 => u16); +impl_to_hashable_key!(@float | f32 => u32); +impl_to_hashable_key!(@float | f64 => u64); + +#[cfg(test)] +mod tests { + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } + + macro_rules! create_matching_set { + ($t:ty) => {{ + let _lookup_table = ExampleSet::<$t> { + _map: Default::default() + }; + + return; + }}; + } + + let data_type = arrow::datatypes::DataType::Float16; + + downcast_primitive! 
{ + data_type => (create_matching_set), + _ => panic!("not implemented for {data_type}"), + } + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs new file mode 100644 index 000000000000..e97ec9706612 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs @@ -0,0 +1,4 @@ +mod wrapper; +mod literal_lookup_table; + +pub(super) use literal_lookup_table::{LookupTable, try_creating_lookup_table}; \ No newline at end of file diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs b/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs new file mode 100644 index 000000000000..518e08023c28 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs @@ -0,0 +1,348 @@ +use arrow::array::ArrayRef; +use arrow::datatypes::Schema; +use std::ops::Deref; +use std::sync::Arc; +use datafusion_common::{internal_datafusion_err, plan_datafusion_err, ScalarValue}; +use datafusion_expr::expr::FieldMetadata; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use crate::expressions::{BinaryExpr, Literal}; + +/// All the optimizations in this module are based on the assumption of this `CASE WHEN` pattern: +/// ```sql +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// END +/// ``` +/// +/// all the `WHEN` expressions are equality comparisons on the same expression against literals, +/// and all the `THEN` expressions are literals +/// the expression `` can be any expression as long as it does not have any state (e.g. random number generator, current timestamp, etc.) 
+pub(super) struct CaseWhenLiteralMapping { + /// The expression that is being compared against the literals in the when clauses + /// This expression must be deterministic + /// In the example above this is `` + pub(super) expression_to_match_on: Arc, + + /// The literals that are being compared against the expression in the when clauses + /// In the example above this is ``, ``, ``, `` + /// These literals must all be of the same data type as the expression_to_match_on + pub(super) when_equality_literals: Vec>, + + /// The literals that are being returned in the then clauses + /// In the example above this is ``, ``, ``, `` + /// These literals must all be of the same data type + pub(super) then_literals: Vec>, + + /// The literal that is being returned in the else clause + /// In the example above this is `` + /// This literal must be of the same data type as the then_literals + /// + /// If no else clause is provided, this will be a null literal of the same data type as the then_literals + pub(super) else_expr: Arc, +} + +impl CaseWhenLiteralMapping { + /// Will return None if the optimization cannot be used + /// Otherwise will return the optimized expression + pub fn map_case_when( + when_then_pairs: Vec<(Arc, Arc)>, + else_expr: Option>, + input_schema: &Schema, + ) -> Option { + // let expression_to_match_on: Arc; + // let when_equality_literals: Vec>; + + // We can't use the optimization if we don't have any when then pairs + if when_then_pairs.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if when_then_pairs.len() == 1 { + return None; + } + + let when_exprs = when_then_pairs + .iter() + .map(|(when, _)| Arc::clone(when)) + .collect::>(); + + // If any of the when expressions is not a binary expression we cannot use this optimization + + let when_binary_exprs = when_exprs + .iter() + .map(|when| { + let binary = when + .as_any() + .downcast_ref::(); + binary.cloned() + }) + .collect::>>()?; + let when_exprs 
= when_binary_exprs; + + // If not all the binary expression are equality we cannot use this optimization + if when_exprs + .iter() + .any(|when| !matches!(when.op(), Operator::Eq)) + { + return None; + } + + let expressions_to_match_on = when_exprs + .iter() + .map(|when| Arc::clone(when.left())) + .collect::>(); + + let first_expression_to_match_on = &expressions_to_match_on[0]; + + // Check if all expressions are the same + if expressions_to_match_on + .iter() + .any(|expr| !expr.dyn_eq(first_expression_to_match_on.deref().as_any())) + { + return None; + } + // TODO - Test that the expression is deterministic + let expression_to_match_on: Arc = + Arc::clone(first_expression_to_match_on); + + let equality_value_exprs = when_exprs + .iter() + .map(|when| when.right()) + .collect::>(); + + let when_equality_literals: Vec> = { + // TODO - spark should do constant folding but we should support expression on literal anyway + // Test that all of the expressions are literals + if equality_value_exprs + .iter() + .any(|expr| expr.as_any().downcast_ref::().is_none()) + { + return None; + } + + equality_value_exprs + .iter() + .map(|expr| { + let literal = expr.as_any().downcast_ref::().unwrap(); + let literal = Literal::new_with_metadata( + literal.value().clone(), + // Empty schema as it is not used by literal + Some(FieldMetadata::from( + literal.return_field(&Schema::empty()).unwrap().deref(), + )), + ); + + Arc::new(literal) + }) + .collect::>() + }; + + { + let Ok(data_type) = expression_to_match_on.data_type(input_schema) else { + return None; + }; + + if data_type != when_equality_literals[0].value().data_type() { + return None; + } + } + + let then_literals: Vec> = { + let then_literal_values = when_then_pairs + .iter() + .map(|(_, then)| then) + .collect::>(); + + // TODO - spark should do constant folding but we should support expression on literal anyway + // Test that all of the expressions are literals + if then_literal_values + .iter() + .any(|expr| 
expr.as_any().downcast_ref::().is_none()) + { + return None; + } + + then_literal_values + .iter() + .map(|expr| { + let literal = expr.as_any().downcast_ref::().unwrap(); + let literal = Literal::new_with_metadata( + literal.value().clone(), + // Empty schema as it is not used by literal + Some(FieldMetadata::from( + literal.return_field(&Schema::empty()).unwrap().deref(), + )), + ); + + Arc::new(literal) + }) + .collect::>() + }; + + let else_expr: Arc = if let Some(else_expr) = else_expr { + // TODO - spark should do constant folding but we should support expression on literal anyway + + let literal = else_expr.as_any().downcast_ref::()?; + let literal = Literal::new_with_metadata( + literal.value().clone(), + // Empty schema as it is not used by literal + Some(FieldMetadata::from( + literal.return_field(&Schema::empty()).unwrap().deref(), + )), + ); + + Arc::new(literal) + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then_literals[0].value().data_type()) + else { + return None; + }; + Arc::new(Literal::new(null_scalar)) + }; + + let this = Self { + expression_to_match_on, + when_equality_literals, + then_literals, + else_expr, + }; + + this.assert_requirements_are_met().ok()?; + + Some(this) + } + + /// Assert that the requirements for the optimization are met so we can use it + pub fn assert_requirements_are_met(&self) -> datafusion_common::Result<()> { + // If expression_to_match_on is not deterministic we cannot use this optimization + // TODO - we need a way to check if an expression is deterministic + + if self.when_equality_literals.len() != self.then_literals.len() { + return Err(plan_datafusion_err!( + "when_equality_literals and then_literals must be the same length" + )); + } + + if self.when_equality_literals.is_empty() { + return Err(plan_datafusion_err!( + "when_equality_literals and then_literals cannot be empty" + )); + } + + // Assert that all when equality literals are the same type and no nulls + { + let data_type = 
self.when_equality_literals[0].value().data_type(); + + for when_lit in &self.when_equality_literals { + if when_lit.value().data_type() != data_type { + return Err(plan_datafusion_err!( + "All when_equality_literals must have the same data type, found {} and {}", + when_lit.value().data_type(), + data_type + )); + } + } + } + + // Assert that all output values are the same type + { + let data_type = self.then_literals[0].value().data_type(); + + for then_lit in &self.then_literals { + if then_lit.value().data_type() != data_type { + return Err(plan_datafusion_err!( + "All then_literals must have the same data type, found {} and {}", + then_lit.value().data_type(), + data_type + )); + } + } + + if self.else_expr.value().data_type() != data_type { + return Err(plan_datafusion_err!( + "else_expr must have the same data type as then_literals, found {} and {}", + self.else_expr.value().data_type(), + data_type + )); + } + } + + Ok(()) + } + + /// Return ArrayRef where array[i] = then_literals[i] + /// the last value in the array is the else_expr + pub fn build_dense_output_values(&self) -> datafusion_common::Result { + // Create the dictionary values array filled with the else value + let mut dictionary_values = vec![]; + + // Fill the dictionary values array with the then literals + for then_lit in self.then_literals.iter() { + dictionary_values.push(then_lit.value().clone()); + } + + // Add the else + dictionary_values.push(self.else_expr.value().clone()); + + let dictionary_values = ScalarValue::iter_to_array(dictionary_values)?; + + Ok(dictionary_values) + } + + /// Normalized all literal values to i128 to ease one-time computations + /// + /// this is i128 as we don't know if the input is signed or unsigned + /// as it can be used to validate the requirements of no negative, + /// and we don't want to lose information + pub fn get_when_literals_values_normalized_for_non_nullable_integer_literals( + &self, + ) -> datafusion_common::Result> { + 
self.when_equality_literals + .iter() + .map(|lit| lit.value()) + .map(|lit| { + if !lit.data_type().is_integer() { + return Err(plan_datafusion_err!( + "All when_equality_literals must be integer type, found {}", + lit.data_type() + )); + } + + if !lit.data_type().is_dictionary_key_type() { + return Err(plan_datafusion_err!( + "All when_equality_literals must be valid dictionary key type, found {}", + lit.data_type() + )); + } + + if lit.is_null() { + return Err(plan_datafusion_err!( + "All when_equality_literals must be non-null numeric types, found null" + )); + } + + match lit { + ScalarValue::Int8(Some(v)) => Ok(*v as i128), + ScalarValue::Int16(Some(v)) => Ok(*v as i128), + ScalarValue::Int32(Some(v)) => Ok(*v as i128), + ScalarValue::Int64(Some(v)) => Ok(*v as i128), + ScalarValue::UInt8(Some(v)) => Ok(*v as i128), + ScalarValue::UInt16(Some(v)) => Ok(*v as i128), + ScalarValue::UInt32(Some(v)) => Ok(*v as i128), + ScalarValue::UInt64(Some(v)) => Ok(*v as i128), + _ => Err(internal_datafusion_err!( + "dictionary key type is not supported {}, value: {}", + lit.data_type(), + lit + )), + } + }) + .collect::>>() + } +} diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index a8f9c5389213..1695ad2cb231 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -1,3 +1,4 @@ mod case; +mod literal_values; pub use case::*; From 1045071ff885af00fc027523c38ccd91f1fa2fc4 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 14:53:26 +0300 Subject: [PATCH 03/22] feat: add to `ExprProperties` the `volatility` I don't like that it is in `sort_properties.rs` but this struct is used in `get_properties` which seems like the most appropriate place --- .../src/expressions/case/case.rs | 179 +++++- .../boolean_lookup_table.rs | 47 ++ .../bytes_like_lookup_table.rs | 225 +++++++ 
.../case/literal_lookup_table/mod.rs | 161 +++++ .../primitive_lookup_table.rs | 140 +++++ .../literal_values/literal_lookup_table.rs | 585 ------------------ .../expressions/case/literal_values/mod.rs | 4 - .../case/literal_values/wrapper.rs | 348 ----------- .../physical-expr/src/expressions/case/mod.rs | 2 +- 9 files changed, 721 insertions(+), 970 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs delete mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs delete mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/mod.rs delete mode 100644 datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs diff --git a/datafusion/physical-expr/src/expressions/case/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs index 2e33cff2113b..ac119455c7f3 100644 --- a/datafusion/physical-expr/src/expressions/case/case.rs +++ b/datafusion/physical-expr/src/expressions/case/case.rs @@ -26,17 +26,15 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use super::super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use 
itertools::Itertools; -use crate::expressions::case::literal_values::LookupTable; +use crate::expressions::case::literal_values::{try_creating_lookup_table, LookupTable}; -type WhenThen = (Arc, Arc); +pub(super) type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { @@ -110,7 +108,7 @@ enum EvalMethod { /// END /// ``` /// - WithExpressionOnlyScalarValuesAndResults(ScalarsOrNullLookup) + WithExprScalarLookupTable(ScalarsOrNullLookup) } #[derive(Debug)] @@ -118,9 +116,120 @@ struct ScalarsOrNullLookup { /// The lookup table to use for evaluating the CASE expression lookup: Arc, + /// ArrayRef where array[i] = then_literals[i] + /// the last value in the array is the else_expr values_to_take_from: ArrayRef, } +impl ScalarsOrNullLookup { + pub fn maybe_new( + when_then_expr: &Vec, else_expr: &Option> + ) -> Option { + // We can't use the optimization if we don't have any when then pairs + if when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if when_then_expr.len() == 1 { + return None; + } + + let when_then_exprs_maybe_literals = when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::(); + let then_maybe_literal = then.as_any().downcast_ref::(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); + + // If not all the when/then expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let (when_literals, then_literals): (Vec, Vec) = when_then_exprs_maybe_literals + .iter() + // Unwrap the options as we have already checked they are all Some + .flatten() + .map(|(when_lit, then_lit)| (when_lit.value().clone(), then_lit.value().clone())) + .unzip(); + + + let else_expr: ScalarValue = if let Some(else_expr) = else_expr { + let literal = else_expr.as_any().downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = 
ScalarValue::try_new_null(&then_literals[0].data_type()) + else { + return None; + }; + + null_scalar + }; + + { + let data_type = when_literals[0].data_type(); + + // If not all the when literals are the same data type we cannot use this optimization + if when_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + } + + { + let data_type = then_literals[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_expr.data_type() != data_type { + return None; + } + } + + + let output_array = ScalarValue::iter_to_array( + then_literals.iter() + // The else is in the end + .chain(std::iter::once(&else_expr)) + .cloned() + ).ok()?; + + let lookup = try_creating_lookup_table( + when_literals, + + // The else expression is in the end + output_array.len() as i32 - 1, + ).ok()?; + + Some(Self { + lookup, + values_to_take_from: output_array, + }) + } + + fn create_output(&self, expr_array: &ArrayRef) -> Result { + let take_indices = self.lookup.match_values(&expr_array)?; + + // Zero-copy conversion + let take_indices = Int32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = arrow::compute::take(&self.values_to_take_from, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. 
@@ -195,24 +304,7 @@ impl CaseExpr { if when_then_expr.is_empty() { exec_err!("There must be at least one WHEN clause") } else { - let eval_method = if expr.is_some() { - EvalMethod::WithExpression - } else if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if when_then_expr.len() == 1 - && when_then_expr[0].1.as_any().is::() - && else_expr.is_some() - && else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression - } else { - EvalMethod::NoExpression - }; + let eval_method = Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); Ok(Self { expr, @@ -223,6 +315,33 @@ impl CaseExpr { } } + fn find_best_eval_method(expr: &Option>, when_then_expr: &Vec, else_expr: &Option>) -> EvalMethod { + if expr.is_some() { + if let Some(mapping) = ScalarsOrNullLookup::maybe_new(when_then_expr, else_expr) { + return EvalMethod::WithExprScalarLookupTable(mapping); + } + + return EvalMethod::WithExpression + } + + if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if when_then_expr.len() == 1 && else_expr.is_some() { + EvalMethod::ExpressionOrExpression + } else { + EvalMethod::NoExpression + } + } + /// Optional base expression that can be compared to literal values in the "when" expressions pub fn expr(&self) -> Option<&Arc> { self.expr.as_ref() @@ -526,18 +645,14 @@ impl CaseExpr { Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } - fn with_expression_scalars_values_and_results(&self, batch: &RecordBatch, scalars_or_null_lookup: &ScalarsOrNullLookup) -> 
Result { + fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &ScalarsOrNullLookup) -> Result { let expr = self.expr.as_ref().unwrap(); let evaluated_expression = expr.evaluate(batch)?; + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); let evaluated_expression = evaluated_expression.to_array(1)?; - let take_indices = scalars_or_null_lookup.lookup.match_values(&evaluated_expression)?; - - // Zero-copy conversion - let take_indices = Int32Array::from(take_indices); - - let output = arrow::compute::take(&scalars_or_null_lookup.values_to_take_from, &take_indices, None)?; + let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; let result = if is_scalar { ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) @@ -611,7 +726,7 @@ impl PhysicalExpr for CaseExpr { } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), - EvalMethod::WithExpressionOnlyScalarValuesAndResults(ref e) => self.with_expression_scalars_values_and_results(batch, e), + EvalMethod::WithExprScalarLookupTable(ref e) => self.with_lookup_table(batch, e), } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs new file mode 100644 index 000000000000..4b3e0057db5c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -0,0 +1,47 @@ +use arrow::array::ArrayRef; +use datafusion_common::ScalarValue; +use crate::expressions::case::literal_lookup_table::LookupTable; + +#[derive(Clone, Debug)] +pub(super) struct BooleanLookupMap { + true_index: i32, + false_index: i32, + null_index: i32, +} + +impl LookupTable for BooleanLookupMap { + fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result + where + Self: Sized, + { + fn get_first_index( + 
literals: &[ScalarValue], + target: Option, + ) -> Option { + literals + .iter() + .position(|literal| matches!(literal, ScalarValue::Boolean(target))) + .map(|pos| pos as i32) + } + + Ok(Self { + false_index: get_first_index(&literals, Some(false)).unwrap_or(else_index), + true_index: get_first_index(&literals, Some(true)).unwrap_or(else_index), + null_index: get_first_index(&literals, None).unwrap_or(else_index), + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + Ok( + array + .as_boolean() + .into_iter() + .map(|value| match value { + Some(true) => self.true_index, + Some(false) => self.false_index, + None => self.null_index, + }) + .collect::>() + ) + } +} \ No newline at end of file diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs new file mode 100644 index 000000000000..dfd49d3af86c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -0,0 +1,225 @@ +use std::fmt::Debug; +use std::iter::Map; +use std::marker::PhantomData; +use arrow::array::{ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, GenericByteArray, GenericByteViewArray, TypedDictionaryArray}; +use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; +use datafusion_common::{exec_datafusion_err, HashMap, ScalarValue}; +use crate::expressions::case::literal_lookup_table::LookupTable; + +trait BytesMapHelperWrapperTrait: Send + Sync { + type IntoIter<'a>: Iterator> + 'a; + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; +} + + +#[derive(Debug, Clone, Default)] +struct GenericBytesHelper(PhantomData); + +impl BytesMapHelperWrapperTrait for GenericBytesHelper { + type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> 
datafusion_common::Result> { + Ok(array + .as_bytes::() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +#[derive(Debug, Clone, Default)] +struct FixedBinaryHelper; + +impl BytesMapHelperWrapperTrait for FixedBinaryHelper { + type IntoIter<'a> = FixedSizeBinaryIter<'a>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_fixed_size_binary().into_iter()) + } +} + + +#[derive(Debug, Clone, Default)] +struct GenericBytesViewHelper(PhantomData); +impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { + type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_byte_view::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + + +#[derive(Debug, Clone, Default)] +struct BytesDictionaryHelper(PhantomData<(Key, Value)>); + +impl BytesMapHelperWrapperTrait for BytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteArrayType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: + IntoIterator> { + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }))) + } +} + +#[derive(Debug, Clone, Default)] +struct FixedBytesDictionaryHelper(PhantomData); + +impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: IntoIterator> { + type 
IntoIter<'a> = as IntoIterator>::IntoIter; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", + array.data_type() + ) + })?; + + Ok(dict_array.into_iter()) + } +} + +#[derive(Debug, Clone, Default)] +struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); + +impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteViewType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: + IntoIterator> { + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }))) + } +} + +#[derive(Clone)] +pub(super) struct BytesLookupTable { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap, i32>, + null_index: i32, + else_index: i32, + + _phantom_data: PhantomData, +} + +impl Debug for BytesLookupTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BytesMapHelper") + .field("map", &self.map) + .field("null_index", &self.null_index) + .field("else_index", &self.else_index) + .finish() + } +} + +impl LookupTable for BytesLookupTable { + fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(literals)?; + let bytes_iter = Helper::array_to_iter(&input)?; + + let mut 
null_index = None; + + let mut map: HashMap, i32> = HashMap::new(); + + for (map_index, value) in bytes_iter.enumerate() { + match value { + Some(value) => { + let slice_value: &[u8] = value.as_ref(); + + // Insert only the first occurrence + map.entry(slice_value.to_vec()).or_insert(map_index as i32); + } + None => { + // Only set the null index once + if null_index.is_none() { + null_index = Some(map_index as i32); + } + } + } + } + + Ok(Self { + map, + null_index: null_index.unwrap_or(else_index), + else_index, + _phantom_data: Default::default(), + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let bytes_iter = Helper::array_to_iter(array)?; + let indices = bytes_iter + .map(|value| { + match value { + Some(value) => { + let slice_value: &[u8] = value.as_ref(); + self.map.get(slice_value).copied().unwrap_or(self.else_index) + } + None => { + self.null_index + } + } + }) + .collect::>(); + + Ok(indices) + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs new file mode 100644 index 000000000000..2e6ee2248e32 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -0,0 +1,161 @@ +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use std::fmt::Debug; +use std::sync::Arc; +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::ScalarValue; + +/// Lookup table for mapping literal values to their corresponding indices +/// +/// The else index is used when a value is not found in the lookup table +pub(super) trait LookupTable: Debug + Send + Sync { + /// Try creating a new lookup table from the given literals and else index + fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result + where + Self: Sized; + + /// Return indices to take from the literals based on the values in the given array + 
fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; +} + +pub(crate) fn try_creating_lookup_table( + literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { + assert_ne!(literals.len(), 0, "Must have at least one literal"); + match literals[0].data_type() { + DataType::Boolean => { + let lookup_table = BooleanLookupMap::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { + ($t:ty) => {{ + let lookup_table = + PrimitiveArrayMapHolder::<$t>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + }}; + } + + downcast_primitive! { + data_type => (create_matching_map), + _ => Err(plan_datafusion_err!( + "Unsupported field type for primitive: {:?}", + data_type + )), + } + } + + DataType::Utf8 => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = + BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table = + BytesLookupTable::::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = + BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = + BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Dictionary(key, value) => { + macro_rules! 
downcast_dictionary_array_helper { + ($t:ty) => {{ + create_lookup_table_for_dictionary_input::<$t>( + value.as_ref(), + literals, + else_index, + ) + }}; + } + + downcast_integer! { + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + _ => Err(plan_datafusion_err!( + "Unsupported data type for lookup table: {}", + literals[0].data_type() + )), + } +} + +fn create_lookup_table_for_dictionary_input( + value: &DataType, + literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { + match value { + DataType::Utf8 => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table =BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + _ => Err(plan_datafusion_err!( + "Unsupported dictionary value type for lookup table: {}", + value + )), + } +} \ No newline at end of file diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs new file mode 100644 index 000000000000..c4380c7491da --- /dev/null +++ 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -0,0 +1,140 @@ +use std::fmt::Debug; +use std::hash::Hash; +use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; +use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; +use half::f16; +use datafusion_common::{HashMap, ScalarValue}; +use crate::expressions::case::literal_lookup_table::LookupTable; + +#[derive(Clone)] +struct PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap::HashableKey>, i32>, + else_index: i32, +} + + + +impl LookupTable for PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(literals)?; + + let map = input + .as_primitive::() + .into_iter() + .enumerate() + .map(|(map_index, value)| (value.map(|v| v.into_hashable_key()), map_index as i32)) + .collect(); + + Ok(Self { map, else_index }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let indices = array + .as_primitive::() + .into_iter() + .map(|value| self.map.get(&value.map(|item| item.into_hashable_key())).copied().unwrap_or(self.else_index)) + .collect::>(); + + Ok(indices) + } +} + + +// TODO - We need to port it to arrow so that it can be reused in other places + +/// Trait that help convert a value to a key that is hashable and equatable +/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly +trait ToHashableKey: ArrowNativeTypeOp { + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same 
code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; +} + +macro_rules! impl_to_hashable_key { + (@single_already_hashable | $t:ty) => { + impl ToHashableKey for $t { + type HashableKey = $t; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self + } + } + }; + (@already_hashable | $($t:ty),+ $(,)?) => { + $( + impl_to_hashable_key!(@single_already_hashable | $t); + )+ + }; + (@float | $t:ty => $hashable:ty) => { + impl ToHashableKey for $t { + type HashableKey = $hashable; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self.to_bits() + } + } + }; +} + +impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); +impl_to_hashable_key!(@float | f16 => u16); +impl_to_hashable_key!(@float | f32 => u32); +impl_to_hashable_key!(@float | f64 => u64); + +#[cfg(test)] +mod tests { + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } + + macro_rules! create_matching_set { + ($t:ty) => {{ + let _lookup_table = ExampleSet::<$t> { + _map: Default::default() + }; + + return; + }}; + } + + let data_type = arrow::datatypes::DataType::Float16; + + downcast_primitive! 
{ + data_type => (create_matching_set), + _ => panic!("not implemented for {data_type}"), + } + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs deleted file mode 100644 index 47a3d5892cf3..000000000000 --- a/datafusion/physical-expr/src/expressions/case/literal_values/literal_lookup_table.rs +++ /dev/null @@ -1,585 +0,0 @@ -use crate::expressions::Literal; -use arrow::array::AsArray; -use arrow::array::{downcast_integer, downcast_primitive, Array, ArrayAccessor, ArrayIter, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeBinaryIter, GenericByteViewArray, TypedDictionaryArray}; -use arrow::array::GenericByteArray; -use arrow::datatypes::{i256, ArrowDictionaryKeyType, BinaryViewType, ByteArrayType, ByteViewType, DataType, GenericBinaryType, GenericStringType, IntervalDayTime, IntervalMonthDayNano, StringViewType}; -use datafusion_common::{exec_datafusion_err, plan_datafusion_err, ScalarValue}; -use half::f16; -use std::collections::HashMap; -use std::fmt::Debug; -use std::hash::{Hash}; -use std::iter::Map; -use std::marker::PhantomData; -use std::sync::Arc; - -/// Lookup table for mapping literal values to their corresponding indices -/// -/// The else index is used when a value is not found in the lookup table -pub(crate) trait LookupTable: Debug + Send + Sync { - /// Try creating a new lookup table from the given literals and else index - fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result - where - Self: Sized; - - /// Return indices to take from the literals based on the values in the given array - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; -} - -pub(crate) fn try_creating_lookup_table( - literals: &[Arc], - else_index: i32, -) -> datafusion_common::Result> { - assert_ne!(literals.len(), 0, "Must have at least one literal"); - match 
literals[0].value().data_type() { - DataType::Boolean => { - let lookup_table = BooleanLookupMap::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - data_type if data_type.is_primitive() => { - macro_rules! create_matching_map { - ($t:ty) => {{ - let lookup_table = - PrimitiveArrayMapHolder::<$t>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - }}; - } - - downcast_primitive! { - data_type => (create_matching_map), - _ => Err(plan_datafusion_err!( - "Unsupported field type for primitive: {:?}", - data_type - )), - } - } - - DataType::Utf8 => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::LargeUtf8 => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::Binary => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::LargeBinary => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::FixedSizeBinary(_) => { - let lookup_table = - BytesLookupTable::::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::Utf8View => { - let lookup_table = - BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - DataType::BinaryView => { - let lookup_table = - BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::Dictionary(key, value) => { - macro_rules! downcast_dictionary_array_helper { - ($t:ty) => {{ - create_lookup_table_for_dictionary_input::<$t>( - value.as_ref(), - literals, - else_index, - ) - }}; - } - - downcast_integer! 
{ - key.as_ref() => (downcast_dictionary_array_helper), - k => unreachable!("unsupported dictionary key type: {}", k) - } - } - _ => Err(plan_datafusion_err!( - "Unsupported data type for lookup table: {}", - literals[0].value().data_type() - )), - } -} - -fn create_lookup_table_for_dictionary_input( - value: &DataType, - literals: &[Arc], - else_index: i32, -) -> datafusion_common::Result> { - match value { - DataType::Utf8 => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::LargeUtf8 => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::Binary => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::LargeBinary => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::FixedSizeBinary(_) => { - let lookup_table =BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - - DataType::Utf8View => { - let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - DataType::BinaryView => { - let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - _ => Err(plan_datafusion_err!( - "Unsupported dictionary value type for lookup table: {}", - value - )), - } -} - -#[derive(Clone)] -struct PrimitiveArrayMapHolder -where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, -{ - /// Literal value to map index - /// - /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps - map: HashMap::HashableKey>, i32>, - else_index: i32, -} - -impl LookupTable for PrimitiveArrayMapHolder -where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, -{ - fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result - where - 
Self: Sized, - { - let input = ScalarValue::iter_to_array(literals.iter().map(|item| item.value().clone()))?; - - let map = input - .as_primitive::() - .into_iter() - .enumerate() - .map(|(map_index, value)| (value.map(|v| v.into_hashable_key()), map_index as i32)) - .collect(); - - Ok(Self { map, else_index }) - } - - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - let indices = array - .as_primitive::() - .into_iter() - .map(|value| self.map.get(&value.map(|item| item.into_hashable_key())).copied().unwrap_or(self.else_index)) - .collect::>(); - - Ok(indices) - } -} - - -trait BytesMapHelperWrapperTrait: Send + Sync -{ - type IntoIter<'a>: Iterator> + 'a; - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; -} - - -#[derive(Debug, Clone, Default)] -struct GenericBytesHelper(PhantomData); - -impl BytesMapHelperWrapperTrait for GenericBytesHelper { - type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - Ok(array - .as_bytes::() - .into_iter() - .map(|item| { - item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }) - })) - } -} - -#[derive(Debug, Clone, Default)] -struct FixedBinaryHelper; - -impl BytesMapHelperWrapperTrait for FixedBinaryHelper { - type IntoIter<'a> = FixedSizeBinaryIter<'a>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - Ok(array.as_fixed_size_binary().into_iter()) - } -} - - -#[derive(Debug, Clone, Default)] -struct GenericBytesViewHelper(PhantomData); -impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { - type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - Ok(array.as_byte_view::().into_iter().map(|item| { - item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }) - })) - } -} - - -#[derive(Debug, Clone, Default)] -struct BytesDictionaryHelper(PhantomData<(Key, Value)>); - -impl 
BytesMapHelperWrapperTrait for BytesDictionaryHelper -where - Key: ArrowDictionaryKeyType + Send + Sync, - Value: ByteArrayType, - for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: - IntoIterator> { - type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array - .as_dictionary::() - .downcast_dict::>() - .ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary value {}", - array.data_type(), - Value::DATA_TYPE - ) - })?; - - Ok(dict_array.into_iter().map(|item| item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }))) - } -} - -#[derive(Debug, Clone, Default)] -struct FixedBytesDictionaryHelper(PhantomData); - -impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper -where - Key: ArrowDictionaryKeyType + Send + Sync, - for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: IntoIterator> { - type IntoIter<'a> = as IntoIterator>::IntoIter; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", - array.data_type() - ) - })?; - - Ok(dict_array.into_iter()) - } -} - -#[derive(Debug, Clone, Default)] -struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); - -impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper -where - Key: ArrowDictionaryKeyType + Send + Sync, - Value: ByteViewType, - for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: - IntoIterator> { - type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array - .as_dictionary::() - .downcast_dict::>() - .ok_or_else(|| { - 
exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary value {}", - array.data_type(), - Value::DATA_TYPE - ) - })?; - - Ok(dict_array.into_iter().map(|item| item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }))) - } -} - -#[derive(Clone)] -struct BytesLookupTable { - /// Map from non-null literal value the first occurrence index in the literals - map: HashMap, i32>, - null_index: i32, - else_index: i32, - - _phantom_data: PhantomData, -} - -impl Debug for BytesLookupTable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("BytesMapHelper") - .field("map", &self.map) - .field("null_index", &self.null_index) - .field("else_index", &self.else_index) - .finish() - } -} - -impl LookupTable for BytesLookupTable { - fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result - where - Self: Sized, - { - let input = ScalarValue::iter_to_array(literals.iter().map(|item| item.value().clone()))?; - let bytes_iter = Helper::array_to_iter(&input)?; - - let mut null_index = None; - - let mut map: HashMap, i32> = HashMap::new(); - - for (map_index, value) in bytes_iter.enumerate() { - match value { - Some(value) => { - let slice_value: &[u8] = value.as_ref(); - - // Insert only the first occurrence - map.entry(slice_value.to_vec()).or_insert(map_index as i32); - } - None => { - // Only set the null index once - if null_index.is_none() { - null_index = Some(map_index as i32); - } - } - } - } - - Ok(Self { - map, - null_index: null_index.unwrap_or(else_index), - else_index, - _phantom_data: Default::default(), - }) - } - - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - let bytes_iter = Helper::array_to_iter(array)?; - let indices = bytes_iter - .map(|value| { - match value { - Some(value) => { - let slice_value: &[u8] = value.as_ref(); - self.map.get(slice_value).copied().unwrap_or(self.else_index) - } - None => { - self.null_index - } - } - }) - 
.collect::>(); - - Ok(indices) - } -} - - -#[derive(Clone, Debug)] -struct BooleanLookupMap { - true_index: i32, - false_index: i32, - null_index: i32, -} - -impl LookupTable for BooleanLookupMap { - fn try_new(literals: &[Arc], else_index: i32) -> datafusion_common::Result - where - Self: Sized, - { - fn get_first_index( - literals: &[Arc], - target: Option, - ) -> Option { - literals - .iter() - .position(|literal| matches!(literal.value(), ScalarValue::Boolean(target))) - .map(|pos| pos as i32) - } - - Ok(Self { - false_index: get_first_index(literals, Some(false)).unwrap_or(else_index), - true_index: get_first_index(literals, Some(true)).unwrap_or(else_index), - null_index: get_first_index(literals, None).unwrap_or(else_index), - }) - } - - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - Ok( - array - .as_boolean() - .into_iter() - .map(|value| match value { - Some(true) => self.true_index, - Some(false) => self.false_index, - None => self.null_index, - }) - .collect::>() - ) - } -} - -macro_rules! 
impl_lookup_table_super_traits { - (impl _ for $MyType:ty) => { - impl_lookup_table_super_traits!(impl<> _ for $MyType where); - }; - (impl<$($impl_generics:ident),*> _ for $MyType:ty where $($where_clause:tt)*) => { - impl<$($impl_generics),*> Debug for $MyType - where - $($where_clause)* - { - fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { - f.debug_struct(stringify!($MyType)) - .field("map", &self.map) - .field("else_index", &self.else_index) - .finish() - } - } - }; -} - -impl_lookup_table_super_traits!( - impl _ for PrimitiveArrayMapHolder where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, -); - -// TODO - We need to port it to arrow so that it can be reused in other places - -/// Trait that help convert a value to a key that is hashable and equatable -/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly -trait ToHashableKey: ArrowNativeTypeOp { - /// The type that is hashable and equatable - /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self - /// this is just a helper trait so you can reuse the same code for all arrow native types - type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; - - /// Converts self to a hashable key - /// the result of this value can be used as the key in hash maps/sets - fn into_hashable_key(self) -> Self::HashableKey; -} - -macro_rules! impl_to_hashable_key { - (@single_already_hashable | $t:ty) => { - impl ToHashableKey for $t { - type HashableKey = $t; - - #[inline] - fn into_hashable_key(self) -> Self::HashableKey { - self - } - } - }; - (@already_hashable | $($t:ty),+ $(,)?) 
=> { - $( - impl_to_hashable_key!(@single_already_hashable | $t); - )+ - }; - (@float | $t:ty => $hashable:ty) => { - impl ToHashableKey for $t { - type HashableKey = $hashable; - - #[inline] - fn into_hashable_key(self) -> Self::HashableKey { - self.to_bits() - } - } - }; -} - -impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); -impl_to_hashable_key!(@float | f16 => u16); -impl_to_hashable_key!(@float | f32 => u32); -impl_to_hashable_key!(@float | f64 => u64); - -#[cfg(test)] -mod tests { - use super::ToHashableKey; - use arrow::array::downcast_primitive; - - // This test ensure that all arrow primitive types implement ToHashableKey - // otherwise the code will not compile - #[test] - fn should_implement_to_hashable_key_for_all_primitives() { - #[derive(Debug, Default)] - struct ExampleSet - where - T: arrow::datatypes::ArrowPrimitiveType, - T::Native: ToHashableKey, - { - _map: std::collections::HashSet<::HashableKey>, - } - - macro_rules! create_matching_set { - ($t:ty) => {{ - let _lookup_table = ExampleSet::<$t> { - _map: Default::default() - }; - - return; - }}; - } - - let data_type = arrow::datatypes::DataType::Float16; - - downcast_primitive! 
{ - data_type => (create_matching_set), - _ => panic!("not implemented for {data_type}"), - } - } -} diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs deleted file mode 100644 index e97ec9706612..000000000000 --- a/datafusion/physical-expr/src/expressions/case/literal_values/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod wrapper; -mod literal_lookup_table; - -pub(super) use literal_lookup_table::{LookupTable, try_creating_lookup_table}; \ No newline at end of file diff --git a/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs b/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs deleted file mode 100644 index 518e08023c28..000000000000 --- a/datafusion/physical-expr/src/expressions/case/literal_values/wrapper.rs +++ /dev/null @@ -1,348 +0,0 @@ -use arrow::array::ArrayRef; -use arrow::datatypes::Schema; -use std::ops::Deref; -use std::sync::Arc; -use datafusion_common::{internal_datafusion_err, plan_datafusion_err, ScalarValue}; -use datafusion_expr::expr::FieldMetadata; -use datafusion_expr_common::operator::Operator; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expressions::{BinaryExpr, Literal}; - -/// All the optimizations in this module are based on the assumption of this `CASE WHEN` pattern: -/// ```sql -/// CASE -/// WHEN ( = ) THEN -/// WHEN ( = ) THEN -/// WHEN ( = ) THEN -/// WHEN ( = ) THEN -/// ELSE -/// END -/// ``` -/// -/// all the `WHEN` expressions are equality comparisons on the same expression against literals, -/// and all the `THEN` expressions are literals -/// the expression `` can be any expression as long as it does not have any state (e.g. random number generator, current timestamp, etc.) 
-pub(super) struct CaseWhenLiteralMapping { - /// The expression that is being compared against the literals in the when clauses - /// This expression must be deterministic - /// In the example above this is `` - pub(super) expression_to_match_on: Arc, - - /// The literals that are being compared against the expression in the when clauses - /// In the example above this is ``, ``, ``, `` - /// These literals must all be of the same data type as the expression_to_match_on - pub(super) when_equality_literals: Vec>, - - /// The literals that are being returned in the then clauses - /// In the example above this is ``, ``, ``, `` - /// These literals must all be of the same data type - pub(super) then_literals: Vec>, - - /// The literal that is being returned in the else clause - /// In the example above this is `` - /// This literal must be of the same data type as the then_literals - /// - /// If no else clause is provided, this will be a null literal of the same data type as the then_literals - pub(super) else_expr: Arc, -} - -impl CaseWhenLiteralMapping { - /// Will return None if the optimization cannot be used - /// Otherwise will return the optimized expression - pub fn map_case_when( - when_then_pairs: Vec<(Arc, Arc)>, - else_expr: Option>, - input_schema: &Schema, - ) -> Option { - // let expression_to_match_on: Arc; - // let when_equality_literals: Vec>; - - // We can't use the optimization if we don't have any when then pairs - if when_then_pairs.is_empty() { - return None; - } - - // If we only have 1 than this optimization is not useful - if when_then_pairs.len() == 1 { - return None; - } - - let when_exprs = when_then_pairs - .iter() - .map(|(when, _)| Arc::clone(when)) - .collect::>(); - - // If any of the when expressions is not a binary expression we cannot use this optimization - - let when_binary_exprs = when_exprs - .iter() - .map(|when| { - let binary = when - .as_any() - .downcast_ref::(); - binary.cloned() - }) - .collect::>>()?; - let when_exprs 
= when_binary_exprs; - - // If not all the binary expression are equality we cannot use this optimization - if when_exprs - .iter() - .any(|when| !matches!(when.op(), Operator::Eq)) - { - return None; - } - - let expressions_to_match_on = when_exprs - .iter() - .map(|when| Arc::clone(when.left())) - .collect::>(); - - let first_expression_to_match_on = &expressions_to_match_on[0]; - - // Check if all expressions are the same - if expressions_to_match_on - .iter() - .any(|expr| !expr.dyn_eq(first_expression_to_match_on.deref().as_any())) - { - return None; - } - // TODO - Test that the expression is deterministic - let expression_to_match_on: Arc = - Arc::clone(first_expression_to_match_on); - - let equality_value_exprs = when_exprs - .iter() - .map(|when| when.right()) - .collect::>(); - - let when_equality_literals: Vec> = { - // TODO - spark should do constant folding but we should support expression on literal anyway - // Test that all of the expressions are literals - if equality_value_exprs - .iter() - .any(|expr| expr.as_any().downcast_ref::().is_none()) - { - return None; - } - - equality_value_exprs - .iter() - .map(|expr| { - let literal = expr.as_any().downcast_ref::().unwrap(); - let literal = Literal::new_with_metadata( - literal.value().clone(), - // Empty schema as it is not used by literal - Some(FieldMetadata::from( - literal.return_field(&Schema::empty()).unwrap().deref(), - )), - ); - - Arc::new(literal) - }) - .collect::>() - }; - - { - let Ok(data_type) = expression_to_match_on.data_type(input_schema) else { - return None; - }; - - if data_type != when_equality_literals[0].value().data_type() { - return None; - } - } - - let then_literals: Vec> = { - let then_literal_values = when_then_pairs - .iter() - .map(|(_, then)| then) - .collect::>(); - - // TODO - spark should do constant folding but we should support expression on literal anyway - // Test that all of the expressions are literals - if then_literal_values - .iter() - .any(|expr| 
expr.as_any().downcast_ref::().is_none()) - { - return None; - } - - then_literal_values - .iter() - .map(|expr| { - let literal = expr.as_any().downcast_ref::().unwrap(); - let literal = Literal::new_with_metadata( - literal.value().clone(), - // Empty schema as it is not used by literal - Some(FieldMetadata::from( - literal.return_field(&Schema::empty()).unwrap().deref(), - )), - ); - - Arc::new(literal) - }) - .collect::>() - }; - - let else_expr: Arc = if let Some(else_expr) = else_expr { - // TODO - spark should do constant folding but we should support expression on literal anyway - - let literal = else_expr.as_any().downcast_ref::()?; - let literal = Literal::new_with_metadata( - literal.value().clone(), - // Empty schema as it is not used by literal - Some(FieldMetadata::from( - literal.return_field(&Schema::empty()).unwrap().deref(), - )), - ); - - Arc::new(literal) - } else { - let Ok(null_scalar) = ScalarValue::try_new_null(&then_literals[0].value().data_type()) - else { - return None; - }; - Arc::new(Literal::new(null_scalar)) - }; - - let this = Self { - expression_to_match_on, - when_equality_literals, - then_literals, - else_expr, - }; - - this.assert_requirements_are_met().ok()?; - - Some(this) - } - - /// Assert that the requirements for the optimization are met so we can use it - pub fn assert_requirements_are_met(&self) -> datafusion_common::Result<()> { - // If expression_to_match_on is not deterministic we cannot use this optimization - // TODO - we need a way to check if an expression is deterministic - - if self.when_equality_literals.len() != self.then_literals.len() { - return Err(plan_datafusion_err!( - "when_equality_literals and then_literals must be the same length" - )); - } - - if self.when_equality_literals.is_empty() { - return Err(plan_datafusion_err!( - "when_equality_literals and then_literals cannot be empty" - )); - } - - // Assert that all when equality literals are the same type and no nulls - { - let data_type = 
self.when_equality_literals[0].value().data_type(); - - for when_lit in &self.when_equality_literals { - if when_lit.value().data_type() != data_type { - return Err(plan_datafusion_err!( - "All when_equality_literals must have the same data type, found {} and {}", - when_lit.value().data_type(), - data_type - )); - } - } - } - - // Assert that all output values are the same type - { - let data_type = self.then_literals[0].value().data_type(); - - for then_lit in &self.then_literals { - if then_lit.value().data_type() != data_type { - return Err(plan_datafusion_err!( - "All then_literals must have the same data type, found {} and {}", - then_lit.value().data_type(), - data_type - )); - } - } - - if self.else_expr.value().data_type() != data_type { - return Err(plan_datafusion_err!( - "else_expr must have the same data type as then_literals, found {} and {}", - self.else_expr.value().data_type(), - data_type - )); - } - } - - Ok(()) - } - - /// Return ArrayRef where array[i] = then_literals[i] - /// the last value in the array is the else_expr - pub fn build_dense_output_values(&self) -> datafusion_common::Result { - // Create the dictionary values array filled with the else value - let mut dictionary_values = vec![]; - - // Fill the dictionary values array with the then literals - for then_lit in self.then_literals.iter() { - dictionary_values.push(then_lit.value().clone()); - } - - // Add the else - dictionary_values.push(self.else_expr.value().clone()); - - let dictionary_values = ScalarValue::iter_to_array(dictionary_values)?; - - Ok(dictionary_values) - } - - /// Normalized all literal values to i128 to ease one-time computations - /// - /// this is i128 as we don't know if the input is signed or unsigned - /// as it can be used to validate the requirements of no negative, - /// and we don't want to lose information - pub fn get_when_literals_values_normalized_for_non_nullable_integer_literals( - &self, - ) -> datafusion_common::Result> { - 
self.when_equality_literals - .iter() - .map(|lit| lit.value()) - .map(|lit| { - if !lit.data_type().is_integer() { - return Err(plan_datafusion_err!( - "All when_equality_literals must be integer type, found {}", - lit.data_type() - )); - } - - if !lit.data_type().is_dictionary_key_type() { - return Err(plan_datafusion_err!( - "All when_equality_literals must be valid dictionary key type, found {}", - lit.data_type() - )); - } - - if lit.is_null() { - return Err(plan_datafusion_err!( - "All when_equality_literals must be non-null numeric types, found null" - )); - } - - match lit { - ScalarValue::Int8(Some(v)) => Ok(*v as i128), - ScalarValue::Int16(Some(v)) => Ok(*v as i128), - ScalarValue::Int32(Some(v)) => Ok(*v as i128), - ScalarValue::Int64(Some(v)) => Ok(*v as i128), - ScalarValue::UInt8(Some(v)) => Ok(*v as i128), - ScalarValue::UInt16(Some(v)) => Ok(*v as i128), - ScalarValue::UInt32(Some(v)) => Ok(*v as i128), - ScalarValue::UInt64(Some(v)) => Ok(*v as i128), - _ => Err(internal_datafusion_err!( - "dictionary key type is not supported {}, value: {}", - lit.data_type(), - lit - )), - } - }) - .collect::>>() - } -} diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 1695ad2cb231..31a16f844b6a 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -1,4 +1,4 @@ mod case; -mod literal_values; +mod literal_lookup_table; pub use case::*; From 9a74f79ea510bf9c731a1f60fc0d4c802c6dab6a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:10:04 +0300 Subject: [PATCH 04/22] extract and cleanup --- .../src/expressions/case/case.rs | 169 +------- .../boolean_lookup_table.rs | 8 +- .../bytes_like_lookup_table.rs | 34 +- .../case/literal_lookup_table/mod.rs | 399 +++++++++++++----- .../primitive_lookup_table.rs | 19 +- 5 files changed, 344 insertions(+), 285 
deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs index ac119455c7f3..e0feda5ba851 100644 --- a/datafusion/physical-expr/src/expressions/case/case.rs +++ b/datafusion/physical-expr/src/expressions/case/case.rs @@ -32,7 +32,7 @@ use datafusion_expr::ColumnarValue; use super::super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; -use crate::expressions::case::literal_values::{try_creating_lookup_table, LookupTable}; +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; pub(super) type WhenThen = (Arc, Arc); @@ -67,168 +67,11 @@ enum EvalMethod { ExpressionOrExpression, /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals - /// - /// `CASE WHEN` pattern on supported lookup types: - /// - /// This optimization applies to CASE expressions of the form: - /// ```sql - /// CASE - /// WHEN THEN - /// WHEN THEN - /// WHEN THEN - /// WHEN THEN - /// ELSE - /// END - /// ``` - /// - /// all the `WHEN` expressions are equality comparisons on the same expression against literals, - /// and all the `THEN` expressions are literals - /// the expression `` can be any expression as long as it does not have any state (e.g. random number generator, current timestamp, etc.) - /// - /// TODO - how to assert that the expression is stateless and deterministic - /// - /// # Improvement idea - /// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons - /// so it will use this optimization as well, e.g. 
- /// ```sql - /// -- Before - /// CASE - /// WHEN ( = ) THEN - /// WHEN ( in (, ) THEN - /// WHEN ( = ) THEN - /// ELSE - /// - /// -- After - /// CASE - /// WHEN ( = ) THEN - /// WHEN ( = ) THEN - /// WHEN ( = ) THEN - /// WHEN ( = ) THEN - /// ELSE - /// END - /// ``` - /// - WithExprScalarLookupTable(ScalarsOrNullLookup) -} - -#[derive(Debug)] -struct ScalarsOrNullLookup { - /// The lookup table to use for evaluating the CASE expression - lookup: Arc, - - /// ArrayRef where array[i] = then_literals[i] - /// the last value in the array is the else_expr - values_to_take_from: ArrayRef, + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable) } -impl ScalarsOrNullLookup { - pub fn maybe_new( - when_then_expr: &Vec, else_expr: &Option> - ) -> Option { - // We can't use the optimization if we don't have any when then pairs - if when_then_expr.is_empty() { - return None; - } - - // If we only have 1 than this optimization is not useful - if when_then_expr.len() == 1 { - return None; - } - - let when_then_exprs_maybe_literals = when_then_expr - .iter() - .map(|(when, then)| { - let when_maybe_literal = when.as_any().downcast_ref::(); - let then_maybe_literal = then.as_any().downcast_ref::(); - - when_maybe_literal.zip(then_maybe_literal) - }) - .collect::>(); - - // If not all the when/then expressions are literals we cannot use this optimization - if when_then_exprs_maybe_literals.contains(&None) { - return None; - } - - let (when_literals, then_literals): (Vec, Vec) = when_then_exprs_maybe_literals - .iter() - // Unwrap the options as we have already checked they are all Some - .flatten() - .map(|(when_lit, then_lit)| (when_lit.value().clone(), then_lit.value().clone())) - .unzip(); - - - let else_expr: ScalarValue = if let Some(else_expr) = else_expr { - let literal = else_expr.as_any().downcast_ref::()?; - - literal.value().clone() - } else { - let Ok(null_scalar) = 
ScalarValue::try_new_null(&then_literals[0].data_type()) - else { - return None; - }; - - null_scalar - }; - - { - let data_type = when_literals[0].data_type(); - - // If not all the when literals are the same data type we cannot use this optimization - if when_literals.iter().any(|l| l.data_type() != data_type) { - return None; - } - } - - { - let data_type = then_literals[0].data_type(); - - // If not all the then and the else literals are the same data type we cannot use this optimization - if then_literals.iter().any(|l| l.data_type() != data_type) { - return None; - } - - if else_expr.data_type() != data_type { - return None; - } - } - - - let output_array = ScalarValue::iter_to_array( - then_literals.iter() - // The else is in the end - .chain(std::iter::once(&else_expr)) - .cloned() - ).ok()?; - - let lookup = try_creating_lookup_table( - when_literals, - - // The else expression is in the end - output_array.len() as i32 - 1, - ).ok()?; - - Some(Self { - lookup, - values_to_take_from: output_array, - }) - } - - fn create_output(&self, expr_array: &ArrayRef) -> Result { - let take_indices = self.lookup.match_values(&expr_array)?; - - // Zero-copy conversion - let take_indices = Int32Array::from(take_indices); - - // An optimize version would depend on the type of the values_to_take_from - // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) - // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array - let output = arrow::compute::take(&self.values_to_take_from, &take_indices, None) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(output) - } -} /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. 
The first form consists of a series of boolean "when" expressions with @@ -317,7 +160,7 @@ impl CaseExpr { fn find_best_eval_method(expr: &Option>, when_then_expr: &Vec, else_expr: &Option>) -> EvalMethod { if expr.is_some() { - if let Some(mapping) = ScalarsOrNullLookup::maybe_new(when_then_expr, else_expr) { + if let Some(mapping) = LiteralLookupTable::maybe_new(when_then_expr, else_expr) { return EvalMethod::WithExprScalarLookupTable(mapping); } @@ -645,7 +488,7 @@ impl CaseExpr { Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } - fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &ScalarsOrNullLookup) -> Result { + fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &LiteralLookupTable) -> Result { let expr = self.expr.as_ref().unwrap(); let evaluated_expression = expr.evaluate(batch)?; diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs index 4b3e0057db5c..fb4167d2fb1a 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -1,15 +1,15 @@ -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, AsArray}; use datafusion_common::ScalarValue; -use crate::expressions::case::literal_lookup_table::LookupTable; +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; #[derive(Clone, Debug)] -pub(super) struct BooleanLookupMap { +pub(super) struct BooleanIndexMap { true_index: i32, false_index: i32, null_index: i32, } -impl LookupTable for BooleanLookupMap { +impl WhenLiteralIndexMap for BooleanIndexMap { fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result where Self: Sized, diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index dfd49d3af86c..fb8db56a1961 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -4,16 +4,20 @@ use std::marker::PhantomData; use arrow::array::{ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, GenericByteArray, GenericByteViewArray, TypedDictionaryArray}; use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; use datafusion_common::{exec_datafusion_err, HashMap, ScalarValue}; -use crate::expressions::case::literal_lookup_table::LookupTable; +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; -trait BytesMapHelperWrapperTrait: Send + Sync { +/// Helper trait to convert various byte-like array types to iterator over byte slices +pub(super) trait BytesMapHelperWrapperTrait: Send + Sync { + /// Iterator over byte slices that will return type IntoIter<'a>: Iterator> + 'a; + + /// Convert the array to an iterator over byte slices fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; } #[derive(Debug, Clone, Default)] -struct GenericBytesHelper(PhantomData); +pub(super) struct GenericBytesHelper(PhantomData); impl BytesMapHelperWrapperTrait for GenericBytesHelper { type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; @@ -33,7 +37,7 @@ impl BytesMapHelperWrapperTrait for GenericBytesHelper { } #[derive(Debug, Clone, Default)] -struct FixedBinaryHelper; +pub(super) struct FixedBinaryHelper; impl BytesMapHelperWrapperTrait for FixedBinaryHelper { type IntoIter<'a> = FixedSizeBinaryIter<'a>; @@ -45,7 +49,7 @@ impl BytesMapHelperWrapperTrait for FixedBinaryHelper { #[derive(Debug, Clone, Default)] -struct GenericBytesViewHelper(PhantomData); +pub(super) struct GenericBytesViewHelper(PhantomData); impl BytesMapHelperWrapperTrait 
for GenericBytesViewHelper { type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; @@ -60,9 +64,8 @@ impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { } } - #[derive(Debug, Clone, Default)] -struct BytesDictionaryHelper(PhantomData<(Key, Value)>); +pub(super) struct BytesDictionaryHelper(PhantomData<(Key, Value)>); impl BytesMapHelperWrapperTrait for BytesDictionaryHelper where @@ -93,7 +96,7 @@ where } #[derive(Debug, Clone, Default)] -struct FixedBytesDictionaryHelper(PhantomData); +pub(super) struct FixedBytesDictionaryHelper(PhantomData); impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper where @@ -117,7 +120,7 @@ where } #[derive(Debug, Clone, Default)] -struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); +pub(super) struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper where @@ -147,17 +150,24 @@ where } } +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps #[derive(Clone)] -pub(super) struct BytesLookupTable { +pub(super) struct BytesLikeIndexMap { /// Map from non-null literal value the first occurrence index in the literals map: HashMap, i32>, + + /// The index for null literal value (when no null value this will equal to `else_index`) null_index: i32, + + /// The index to return when no match is found else_index: i32, _phantom_data: PhantomData, } -impl Debug for BytesLookupTable { +impl Debug for BytesLikeIndexMap { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("BytesMapHelper") .field("map", &self.map) @@ -167,7 +177,7 @@ impl Debug for BytesLookupTable { } } -impl LookupTable for BytesLookupTable { +impl WhenLiteralIndexMap for BytesLikeIndexMap { fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result where Self: Sized, diff --git 
a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 2e6ee2248e32..0f8813a1def3 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -2,38 +2,208 @@ mod boolean_lookup_table; mod bytes_like_lookup_table; mod primitive_lookup_table; +use datafusion_common::DataFusionError; +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::{ + BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper, + FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper, + GenericBytesViewHelper, +}; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveArrayMapHolder; +use arrow::array::{downcast_integer, downcast_primitive, ArrayRef, Int32Array}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, GenericBinaryType, + GenericStringType, StringViewType, +}; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; use std::fmt::Debug; use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use crate::expressions::case::case::WhenThen; +use crate::expressions::Literal; -/// Lookup table for mapping literal values to their corresponding indices +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. 
+/// ```sql +/// -- Before +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( in (, ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// +/// -- After +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Arc, + + /// ArrayRef where array[i] = then_literals[i] + /// the last value in the array is the else_expr + values_to_take_from: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new( + when_then_expr: &Vec, else_expr: &Option> + ) -> Option { + // We can't use the optimization if we don't have any when then pairs + if when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if when_then_expr.len() == 1 { + return None; + } + + let when_then_exprs_maybe_literals = when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::(); + let then_maybe_literal = then.as_any().downcast_ref::(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); + + // If not all the when/then expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let (when_literals, then_literals): (Vec, Vec) = when_then_exprs_maybe_literals + .iter() + // Unwrap the options as we have already checked they are all Some + .flatten() + .map(|(when_lit, then_lit)| (when_lit.value().clone(), then_lit.value().clone())) + .unzip(); + + + let else_expr: ScalarValue = if let Some(else_expr) = else_expr { + let literal = else_expr.as_any().downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then_literals[0].data_type()) + else { + return None; + }; + + null_scalar + }; + + { + let data_type = when_literals[0].data_type(); + + // If not all the when 
literals are the same data type we cannot use this optimization + if when_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + } + + { + let data_type = then_literals[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_expr.data_type() != data_type { + return None; + } + } + + + let output_array = ScalarValue::iter_to_array( + then_literals.iter() + // The else is in the end + .chain(std::iter::once(&else_expr)) + .cloned() + ).ok()?; + + let lookup = try_creating_lookup_table( + when_literals, + + // The else expression is in the end + output_array.len() as i32 - 1, + ).ok()?; + + Some(Self { + lookup, + values_to_take_from: output_array, + }) + } + + pub(in super::super) fn create_output(&self, expr_array: &ArrayRef) -> datafusion_common::Result { + let take_indices = self.lookup.match_values(&expr_array)?; + + // Zero-copy conversion + let take_indices = Int32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = arrow::compute::take(&self.values_to_take_from, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + +/// Lookup table for mapping literal values to their corresponding indices in the THEN clauses /// /// The else index is used when a value is not found in the lookup table -pub(super) trait LookupTable: Debug + Send + Sync { - /// Try creating a new lookup table from the given literals and else index - fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result - where - Self: Sized; - - /// Return indices to take from the 
literals based on the values in the given array - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; +pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { + /// Try creating a new lookup table from the given literals and else index + fn try_new( + literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized; + + /// Return indices to take from the literals based on the values in the given array + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; } pub(crate) fn try_creating_lookup_table( - literals: Vec, - else_index: i32, -) -> datafusion_common::Result> { - assert_ne!(literals.len(), 0, "Must have at least one literal"); - match literals[0].data_type() { - DataType::Boolean => { - let lookup_table = BooleanLookupMap::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { + assert_ne!(literals.len(), 0, "Must have at least one literal"); + match literals[0].data_type() { + DataType::Boolean => { + let lookup_table = BooleanIndexMap::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - data_type if data_type.is_primitive() => { - macro_rules! create_matching_map { + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { ($t:ty) => {{ let lookup_table = PrimitiveArrayMapHolder::<$t>::try_new(literals, else_index)?; @@ -41,58 +211,66 @@ pub(crate) fn try_creating_lookup_table( }}; } - downcast_primitive! { + downcast_primitive! 
{ data_type => (create_matching_map), _ => Err(plan_datafusion_err!( "Unsupported field type for primitive: {:?}", data_type )), } - } + } - DataType::Utf8 => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::Utf8 => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::LargeUtf8 => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::LargeUtf8 => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::Binary => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::Binary => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::LargeBinary => { - let lookup_table = - BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::LargeBinary => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::FixedSizeBinary(_) => { - let lookup_table = - BytesLookupTable::::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::FixedSizeBinary(_) => { + let lookup_table = + BytesLikeIndexMap::::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::Utf8View => { - let lookup_table = - BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - DataType::BinaryView => { - let lookup_table = - BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::Utf8View => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + literals, else_index, + )?; + 
Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + literals, else_index, + )?; + Ok(Arc::new(lookup_table)) + } - DataType::Dictionary(key, value) => { - macro_rules! downcast_dictionary_array_helper { + DataType::Dictionary(key, value) => { + macro_rules! downcast_dictionary_array_helper { ($t:ty) => {{ create_lookup_table_for_dictionary_input::<$t>( value.as_ref(), @@ -102,60 +280,77 @@ pub(crate) fn try_creating_lookup_table( }}; } - downcast_integer! { + downcast_integer! { key.as_ref() => (downcast_dictionary_array_helper), k => unreachable!("unsupported dictionary key type: {}", k) } - } - _ => Err(plan_datafusion_err!( + } + _ => Err(plan_datafusion_err!( "Unsupported data type for lookup table: {}", literals[0].data_type() )), - } + } } fn create_lookup_table_for_dictionary_input( - value: &DataType, - literals: Vec, - else_index: i32, -) -> datafusion_common::Result> { - match value { - DataType::Utf8 => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + value: &DataType, + literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { - DataType::LargeUtf8 => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + // TODO - optimize dictionary to use different wrapper that takes advantage of it being a dictionary + match value { + DataType::Utf8 => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::Binary => { - let lookup_table = BytesLookupTable::>>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::LargeUtf8 => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::LargeBinary => { - let lookup_table = BytesLookupTable::>>::try_new(literals, 
else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::Binary => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::FixedSizeBinary(_) => { - let lookup_table =BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } + DataType::LargeBinary => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } - DataType::Utf8View => { - let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - DataType::BinaryView => { - let lookup_table = BytesLookupTable::>::try_new(literals, else_index)?; - Ok(Arc::new(lookup_table)) - } - _ => Err(plan_datafusion_err!( + DataType::FixedSizeBinary(_) => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + literals, else_index, + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = BytesLikeIndexMap::< + BytesViewDictionaryHelper, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = BytesLikeIndexMap::< + BytesViewDictionaryHelper, + >::try_new(literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + _ => Err(plan_datafusion_err!( "Unsupported dictionary value type for lookup table: {}", value )), - } -} \ No newline at end of file + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs index c4380c7491da..fd7462868c75 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -4,10 +4,10 @@ use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; use 
arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; use half::f16; use datafusion_common::{HashMap, ScalarValue}; -use crate::expressions::case::literal_lookup_table::LookupTable; +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; #[derive(Clone)] -struct PrimitiveArrayMapHolder +pub(super) struct PrimitiveArrayMapHolder where T: ArrowPrimitiveType, T::Native: ToHashableKey, @@ -19,9 +19,20 @@ where else_index: i32, } +impl Debug for PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveArrayMapHolder") + .field("map", &self.map) + .field("else_index", &self.else_index) + .finish() + } +} - -impl LookupTable for PrimitiveArrayMapHolder +impl WhenLiteralIndexMap for PrimitiveArrayMapHolder where T: ArrowPrimitiveType, T::Native: ToHashableKey, From 98d2cca76ef4e6c9bd65b74cd296d00bda804305 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:19:12 +0300 Subject: [PATCH 05/22] finish --- .../src/expressions/case/case.rs | 1 - .../case/literal_lookup_table/mod.rs | 25 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs index e0feda5ba851..529587dc5190 100644 --- a/datafusion/physical-expr/src/expressions/case/case.rs +++ b/datafusion/physical-expr/src/expressions/case/case.rs @@ -72,7 +72,6 @@ enum EvalMethod { WithExprScalarLookupTable(LiteralLookupTable) } - /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. 
diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 0f8813a1def3..a3bf9ad2abe8 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -17,6 +17,7 @@ use arrow::datatypes::{ }; use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expressions::case::case::WhenThen; @@ -66,6 +67,30 @@ pub(in super::super) struct LiteralLookupTable { values_to_take_from: ArrayRef, } +impl Hash for LiteralLookupTable { + fn hash(&self, state: &mut H) { + // Hashing the pointer as this is the best we can do here + + let lookup_ptr = Arc::as_ptr(&self.lookup); + lookup_ptr.hash(state); + + let values_ptr = Arc::as_ptr(&self.lookup); + values_ptr.hash(state); + } +} + +impl PartialEq for LiteralLookupTable { + fn eq(&self, other: &Self) -> bool { + // Comparing the pointers as this is the best we can do here + Arc::ptr_eq(&self.lookup, &other.lookup) && + &self.values_to_take_from == &other.values_to_take_from + } +} + +impl Eq for LiteralLookupTable { + +} + impl LiteralLookupTable { pub(in super::super) fn maybe_new( when_then_expr: &Vec, else_expr: &Option> From cfaf4a3d5e212f1c76ccc50d9ed067360d400096 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:27:12 +0300 Subject: [PATCH 06/22] cleanup --- .../src/expressions/case/case.rs | 1495 ---------------- .../boolean_lookup_table.rs | 4 +- .../bytes_like_lookup_table.rs | 15 +- .../case/literal_lookup_table/mod.rs | 6 +- .../primitive_lookup_table.rs | 2 +- .../physical-expr/src/expressions/case/mod.rs | 1497 ++++++++++++++++- 6 files changed, 1507 insertions(+), 1512 deletions(-) 
delete mode 100644 datafusion/physical-expr/src/expressions/case/case.rs diff --git a/datafusion/physical-expr/src/expressions/case/case.rs b/datafusion/physical-expr/src/expressions/case/case.rs deleted file mode 100644 index 529587dc5190..000000000000 --- a/datafusion/physical-expr/src/expressions/case/case.rs +++ /dev/null @@ -1,1495 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::expressions::try_cast; -use crate::PhysicalExpr; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - -use arrow::array::*; -use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; -use arrow::datatypes::{DataType, Schema}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::{arrow_datafusion_err, exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::ColumnarValue; - -use super::super::{Column, Literal}; -use datafusion_physical_expr_common::datum::compare_with_eq; -use itertools::Itertools; -use crate::expressions::case::literal_lookup_table::LiteralLookupTable; - -pub(super) type WhenThen = (Arc, Arc); - -#[derive(Debug, Hash, PartialEq, Eq)] -enum EvalMethod { - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - NoExpression, - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - WithExpression, - /// This is a specialization for a specific use case where we can take a fast path - /// for expressions that are infallible and can be cheaply computed for the entire - /// record batch rather than just for the rows where the predicate is true. 
- /// - /// CASE WHEN condition THEN column [ELSE NULL] END - InfallibleExprOrNull, - /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` expressions - /// are literal values - /// CASE WHEN condition THEN literal ELSE literal END - ScalarOrScalar, - /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions - /// - /// CASE WHEN condition THEN expression ELSE expression END - ExpressionOrExpression, - - /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals - /// - /// See [`LiteralLookupTable`] for more details - WithExprScalarLookupTable(LiteralLookupTable) -} - -/// The CASE expression is similar to a series of nested if/else and there are two forms that -/// can be used. The first form consists of a series of boolean "when" expressions with -/// corresponding "then" expressions, and an optional "else" expression. -/// -/// CASE WHEN condition THEN result -/// [WHEN ...] -/// [ELSE result] -/// END -/// -/// The second form uses a base expression and then a series of "when" clauses that match on a -/// literal value. -/// -/// CASE expression -/// WHEN value THEN result -/// [WHEN ...] 
-/// [ELSE result] -/// END -#[derive(Debug, Hash, PartialEq, Eq)] -pub struct CaseExpr { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec, - /// Optional "else" expression - else_expr: Option>, - /// Evaluation method to use - eval_method: EvalMethod, -} - -impl std::fmt::Display for CaseExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "CASE ")?; - if let Some(e) = &self.expr { - write!(f, "{e} ")?; - } - for (w, t) in &self.when_then_expr { - write!(f, "WHEN {w} THEN {t} ")?; - } - if let Some(e) = &self.else_expr { - write!(f, "ELSE {e} ")?; - } - write!(f, "END") - } -} - -/// This is a specialization for a specific use case where we can take a fast path -/// for expressions that are infallible and can be cheaply computed for the entire -/// record batch rather than just for the rows where the predicate is true. For now, -/// this is limited to use with Column expressions but could potentially be used for other -/// expressions in the future -fn is_cheap_and_infallible(expr: &Arc) -> bool { - expr.as_any().is::() -} - -impl CaseExpr { - /// Create a new CASE WHEN expression - pub fn try_new( - expr: Option>, - when_then_expr: Vec, - else_expr: Option>, - ) -> Result { - // normalize null literals to None in the else_expr (this already happens - // during SQL planning, but not necessarily for other use cases) - let else_expr = match &else_expr { - Some(e) => match e.as_any().downcast_ref::() { - Some(lit) if lit.value().is_null() => None, - _ => else_expr, - }, - _ => else_expr, - }; - - if when_then_expr.is_empty() { - exec_err!("There must be at least one WHEN clause") - } else { - let eval_method = Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); - - Ok(Self { - expr, - when_then_expr, - else_expr, - eval_method, - }) - } - } - - fn find_best_eval_method(expr: &Option>, 
when_then_expr: &Vec, else_expr: &Option>) -> EvalMethod { - if expr.is_some() { - if let Some(mapping) = LiteralLookupTable::maybe_new(when_then_expr, else_expr) { - return EvalMethod::WithExprScalarLookupTable(mapping); - } - - return EvalMethod::WithExpression - } - - if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if when_then_expr.len() == 1 - && when_then_expr[0].1.as_any().is::() - && else_expr.is_some() - && else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression - } else { - EvalMethod::NoExpression - } - } - - /// Optional base expression that can be compared to literal values in the "when" expressions - pub fn expr(&self) -> Option<&Arc> { - self.expr.as_ref() - } - - /// One or more when/then expressions - pub fn when_then_expr(&self) -> &[WhenThen] { - &self.when_then_expr - } - - /// Optional "else" expression - pub fn else_expr(&self) -> Option<&Arc> { - self.else_expr.as_ref() - } -} - -impl CaseExpr { - /// This function evaluates the form of CASE that matches an expression to fixed values. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] 
- /// [ELSE result] - /// END - fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - let expr = self.expr.as_ref().unwrap(); - let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows())?; - let base_nulls = is_null(base_value.as_ref())?; - - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - // We only consider non-null values while comparing with whens - let mut remainder = not(&base_nulls)?; - let mut non_null_remainder_count = remainder.true_count(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if non_null_remainder_count == 0 { - break; - } - - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - // build boolean array representing which rows match the "when" value - let when_match = compare_with_eq( - &when_value, - &base_value, - // The types of case and when expressions will be coerced to match. - // We only need to check if the base_value is nested. 
- base_value.data_type().is_nested(), - )?; - // Treat nulls as false - let when_match = match when_match.null_count() { - 0 => Cow::Borrowed(&when_match), - _ => Cow::Owned(prep_null_mask_filter(&when_match)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_match, &remainder)?; - - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { - continue; - } - - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; - - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? - } - }; - - remainder = and_not(&remainder, &when_value)?; - non_null_remainder_count -= when_match_count; - } - - if let Some(e) = self.else_expr() { - // null and unmatched tuples should be assigned else value - remainder = or(&base_nulls, &remainder)?; - - if remainder.true_count() > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } - } - - Ok(ColumnarValue::Array(current_value)) - } - - /// This function evaluates the form of CASE where each WHEN expression is a boolean - /// expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] 
- /// [ELSE result] - /// END - fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - let mut remainder_count = batch.num_rows(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if remainder_count == 0 { - break; - } - - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|_| { - internal_datafusion_err!("WHEN expression did not return a BooleanArray") - })?; - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_value, &remainder)?; - - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { - continue; - } - - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; - - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? 
- } - }; - - // Succeed tuples should be filtered out for short-circuit evaluation, - // null values for the current when expr should be kept - remainder = and_not(&remainder, &when_value)?; - remainder_count -= when_match_count; - } - - if let Some(e) = self.else_expr() { - if remainder_count > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } - } - - Ok(ColumnarValue::Array(current_value)) - } - - /// This function evaluates the specialized case of: - /// - /// CASE WHEN condition THEN column - /// [ELSE NULL] - /// END - /// - /// Note that this function is only safe to use for "then" expressions - /// that are infallible because the expression will be evaluated for all - /// rows in the input batch. - fn case_column_or_null(&self, batch: &RecordBatch) -> Result { - let when_expr = &self.when_then_expr[0].0; - let then_expr = &self.when_then_expr[0].1; - - match when_expr.evaluate(batch)? { - // WHEN true --> column - ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => { - then_expr.evaluate(batch) - } - // WHEN [false | null] --> NULL - ColumnarValue::Scalar(_) => { - // return scalar NULL value - ScalarValue::try_from(self.data_type(&batch.schema())?) - .map(ColumnarValue::Scalar) - } - // WHEN column --> column - ColumnarValue::Array(bit_mask) => { - let bit_mask = bit_mask - .as_any() - .downcast_ref::() - .expect("predicate should evaluate to a boolean array"); - // invert the bitmask - let bit_mask = match bit_mask.null_count() { - 0 => not(bit_mask)?, - _ => not(&prep_null_mask_filter(bit_mask))?, - }; - match then_expr.evaluate(batch)? 
{ - ColumnarValue::Array(array) => { - Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) - } - ColumnarValue::Scalar(_) => { - internal_err!("expression did not evaluate to an array") - } - } - } - } - } - - fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // evaluate when expression - let when_value = self.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|_| { - internal_datafusion_err!("WHEN expression did not return a BooleanArray") - })?; - - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - - // evaluate then_value - let then_value = self.when_then_expr[0].1.evaluate(batch)?; - let then_value = Scalar::new(then_value.into_array(1)?); - - let Some(e) = self.else_expr() else { - return internal_err!("expression did not evaluate to an array"); - }; - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; - let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); - Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) - } - - fn expr_or_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // evaluate when condition on batch - let when_value = self.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a BooleanArray".to_string(), - Box::new(e), - ) - })?; - - // For the true and false/null selection vectors, bypass `evaluate_selection` and merging - // results. This avoids materializing the array for the other branch which we will discard - // entirely anyway. 
- let true_count = when_value.true_count(); - if true_count == batch.num_rows() { - return self.when_then_expr[0].1.evaluate(batch); - } else if true_count == 0 { - return self.else_expr.as_ref().unwrap().evaluate(batch); - } - - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - - let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, &when_value)? - .into_array(batch.num_rows())?; - - // evaluate else expression on the values not covered by when_value - let remainder = not(&when_value)?; - let e = self.else_expr.as_ref().unwrap(); - - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - - Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) - } - - fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &LiteralLookupTable) -> Result { - let expr = self.expr.as_ref().unwrap(); - let evaluated_expression = expr.evaluate(batch)?; - - let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); - let evaluated_expression = evaluated_expression.to_array(1)?; - - let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; - - let result = if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) - } else { - ColumnarValue::Array(output) - }; - - Ok(result) - } -} - -impl PhysicalExpr for CaseExpr { - /// Return a reference to Any that can be used for down-casting - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> Result { - // since all then results have the same data type, we can choose any one as the - // return data type except for the null. 
- let mut data_type = DataType::Null; - for i in 0..self.when_then_expr.len() { - data_type = self.when_then_expr[i].1.data_type(input_schema)?; - if !data_type.equals_datatype(&DataType::Null) { - break; - } - } - // if all then results are null, we use data type of else expr instead if possible. - if data_type.equals_datatype(&DataType::Null) { - if let Some(e) = &self.else_expr { - data_type = e.data_type(input_schema)?; - } - } - - Ok(data_type) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = self - .when_then_expr - .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) - } else if let Some(e) = &self.else_expr { - e.nullable(input_schema) - } else { - // CASE produces NULL if there is no `else` expr - // (aka when none of the `when_then_exprs` match) - Ok(true) - } - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - match self.eval_method { - EvalMethod::WithExpression => { - // this use case evaluates "expr" and then compares the values with the "when" - // values - self.case_when_with_expr(batch) - } - EvalMethod::NoExpression => { - // The "when" conditions all evaluate to boolean in this use case and can be - // arbitrary expressions - self.case_when_no_expr(batch) - } - EvalMethod::InfallibleExprOrNull => { - // Specialization for CASE WHEN expr THEN column [ELSE NULL] END - self.case_column_or_null(batch) - } - EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), - EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), - EvalMethod::WithExprScalarLookupTable(ref e) => self.with_lookup_table(batch, e), - } - } - - fn children(&self) -> Vec<&Arc> { - let mut children = vec![]; - if let Some(expr) = &self.expr { - children.push(expr) - } - self.when_then_expr.iter().for_each(|(cond, value)| { - children.push(cond); - children.push(value); - }); - - if let 
Some(else_expr) = &self.else_expr { - children.push(else_expr) - } - children - } - - // For physical CaseExpr, we do not allow modifying children size - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - if children.len() != self.children().len() { - internal_err!("CaseExpr: Wrong number of children") - } else { - let (expr, when_then_expr, else_expr) = - match (self.expr().is_some(), self.else_expr().is_some()) { - (true, true) => ( - Some(&children[0]), - &children[1..children.len() - 1], - Some(&children[children.len() - 1]), - ), - (true, false) => { - (Some(&children[0]), &children[1..children.len()], None) - } - (false, true) => ( - None, - &children[0..children.len() - 1], - Some(&children[children.len() - 1]), - ), - (false, false) => (None, &children[0..children.len()], None), - }; - Ok(Arc::new(CaseExpr::try_new( - expr.cloned(), - when_then_expr.iter().cloned().tuples().collect(), - else_expr.cloned(), - )?)) - } - } - - fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "CASE ")?; - if let Some(e) = &self.expr { - e.fmt_sql(f)?; - write!(f, " ")?; - } - - for (w, t) in &self.when_then_expr { - write!(f, "WHEN ")?; - w.fmt_sql(f)?; - write!(f, " THEN ")?; - t.fmt_sql(f)?; - write!(f, " ")?; - } - - if let Some(e) = &self.else_expr { - write!(f, "ELSE ")?; - e.fmt_sql(f)?; - write!(f, " ")?; - } - write!(f, "END") - } -} - -/// Create a CASE expression -pub fn case( - expr: Option>, - when_thens: Vec, - else_expr: Option>, -) -> Result> { - Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; - use arrow::buffer::Buffer; - use arrow::datatypes::DataType::Float64; - use arrow::datatypes::Field; - use datafusion_common::cast::{as_float64_array, as_int32_array}; - use datafusion_common::plan_err; - use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; - use 
datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; - use datafusion_physical_expr_common::physical_expr::fmt_sql; - - #[test] - fn case_with_expr() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_else() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_divide_by_zero() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - - // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END - let when1 = lit(0i32); - let then1 = lit(ScalarValue::Float64(None)); - let else_value = binary( - lit(25.0f64), - Operator::Divide, - cast(col("a", &schema)?, &batch.schema(), Float64)?, - &batch.schema(), - )?; - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_without_expr() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_when_null() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END - let when1 = lit(ScalarValue::Utf8(None)); - let then1 = lit(0i32); - let when2 = col("a", &schema)?; - let then2 = lit(123i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_without_expr_divide_by_zero() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - - // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END - let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?; - let then1 = binary( - lit(25.0f64), - Operator::Divide, - cast(col("a", &schema)?, &batch.schema(), Float64)?, - &batch.schema(), - )?; - let x = lit(ScalarValue::Float64(None)); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1)], - Some(x), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - fn case_test_batch1() -> Result { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); - let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]); - let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]); - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], - )?; - Ok(batch) - } - - #[test] - fn case_without_expr_else() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_type_cast() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END - let when = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then = lit(123.3f64); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_matches_and_nulls() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE WHEN load4 = 1.77 THEN load4 END - let when = binary( - col("load4", &schema)?, - Operator::Eq, - lit(1.77f64), - &batch.schema(), - )?; - let then = col("load4", &schema)?; - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_scalar_predicate() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE WHEN TRUE THEN load4 END - let when = lit(true); - let then = col("load4", &schema)?; - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - None, - schema.as_ref(), - )?; - - // many rows - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - let expected = &Float64Array::from(vec![ - Some(1.77), - None, - None, - Some(1.78), - None, - Some(1.77), - ]); - assert_eq!(expected, result); - - // one row - let expected = Float64Array::from(vec![Some(1.1)]); - let batch = - RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - assert_eq!(&expected, result); - - Ok(()) - } - - #[test] - fn case_expr_matches_and_nulls() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE load4 WHEN 1.77 THEN load4 END - let expr = col("load4", &schema)?; - let when = lit(1.77f64); - let then = col("load4", &schema)?; - - let expr = generate_case_when_with_type_coercion( - Some(expr), - vec![(when, then)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn test_when_null_and_some_cond_else_null() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - let when = binary( - Arc::new(Literal::new(ScalarValue::Boolean(None))), - Operator::And, - binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, - &schema, - )?; - let then = col("a", &schema)?; - - // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END - let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_string_array(&result); - - // all result values should be null - assert_eq!(result.logical_null_count(), batch.num_rows()); - Ok(()) - } - - fn case_test_batch() -> Result { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - Ok(batch) - } - - // Construct an array that has several NULL values whose - // underlying buffer actually matches the where expr predicate - fn case_test_batch_nulls() -> Result { - let load4: Float64Array = vec![ - Some(1.77), // 1.77 - Some(1.77), // null <-- same value, but will be set to null - Some(1.77), // null <-- same value, but will be set to null - Some(1.78), // 1.78 - None, // null - Some(1.77), // 1.77 - ] - .into_iter() - .collect(); - - let null_buffer = Buffer::from([0b00101001u8]); - let load4 = load4 - .into_data() - .into_builder() - .null_bit_buffer(Some(null_buffer)) - .build() - .unwrap(); - let load4: Float64Array = load4.into(); - - 
let batch = - RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; - Ok(batch) - } - - #[test] - fn case_test_incompatible() -> Result<()> { - // 1 then is int64 - // 2 then is boolean - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(true); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - ); - assert!(expr.is_err()); - - // then 1 is int32 - // then 2 is int64 - // else is float - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i64); - let else_expr = lit(1.23f64); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - Some(else_expr), - schema.as_ref(), - ); - assert!(expr.is_ok()); - let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(Float64, result_type); - Ok(()) - } - - #[test] - fn case_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr1 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr2 = generate_case_when_with_type_coercion( - 
Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr3 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], - None, - &schema, - )?; - - let expr4 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1)], - Some(else_value), - &schema, - )?; - - assert!(expr1.eq(&expr2)); - assert!(expr2.eq(&expr1)); - - assert!(expr2.ne(&expr3)); - assert!(expr3.ne(&expr2)); - - assert!(expr1.ne(&expr4)); - assert!(expr4.ne(&expr1)); - - Ok(()) - } - - #[test] - fn case_transform() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr2 = Arc::clone(&expr) - .transform(|e| { - let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, - _ => None, - }; - Ok(if let Some(transformed) = transformed { - Transformed::yes(transformed) - } else { - Transformed::no(e) - }) - }) - .data() - .unwrap(); - - let expr3 = Arc::clone(&expr) - .transform_down(|e| { - let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, - _ => None, - }; - Ok(if let Some(transformed) = transformed { - Transformed::yes(transformed) - } else { - Transformed::no(e) - 
}) - }) - .data() - .unwrap(); - - assert!(expr.ne(&expr2)); - assert!(expr2.eq(&expr3)); - - Ok(()) - } - - #[test] - fn test_column_or_null_specialization() -> Result<()> { - // create input data - let mut c1 = Int32Builder::new(); - let mut c2 = StringBuilder::new(); - for i in 0..1000 { - c1.append_value(i); - if i % 7 == 0 { - c2.append_null(); - } else { - c2.append_value(format!("string {i}")); - } - } - let c1 = Arc::new(c1.finish()); - let c2 = Arc::new(c2.finish()); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); - - // CaseWhenExprOrNull should produce same results as CaseExpr - let predicate = Arc::new(BinaryExpr::new( - make_col("c1", 0), - Operator::LtEq, - make_lit_i32(250), - )); - let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; - assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); - match expr.evaluate(&batch)? { - ColumnarValue::Array(array) => { - assert_eq!(1000, array.len()); - assert_eq!(785, array.null_count()); - } - _ => unreachable!(), - } - Ok(()) - } - - #[test] - fn test_expr_or_expr_specialization() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - let when = binary( - col("a", &schema)?, - Operator::LtEq, - lit(2i32), - &batch.schema(), - )?; - let then = col("b", &schema)?; - let else_expr = col("c", &schema)?; - let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; - assert!(matches!( - expr.eval_method, - EvalMethod::ExpressionOrExpression - )); - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); - - let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]); - - assert_eq!(expected, result); - Ok(()) - } - - fn make_col(name: &str, index: usize) -> Arc { - Arc::new(Column::new(name, index)) - } - - fn make_lit_i32(n: i32) -> Arc { - Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) - } - - fn generate_case_when_with_type_coercion( - expr: Option>, - when_thens: Vec, - else_expr: Option>, - input_schema: &Schema, - ) -> Result> { - let coerce_type = - get_case_common_type(&when_thens, else_expr.clone(), input_schema); - let (when_thens, else_expr) = match coerce_type { - None => plan_err!( - "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" - ), - Some(data_type) => { - // cast then expr - let left = when_thens - .into_iter() - .map(|(when, then)| { - let then = try_cast(then, input_schema, data_type.clone())?; - Ok((when, then)) - }) - .collect::>>()?; - let right = match else_expr { - None => None, - Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), - }; - - Ok((left, right)) - } - }?; - case(expr, when_thens, else_expr) - } - - fn get_case_common_type( - when_thens: &[WhenThen], - else_expr: Option>, - input_schema: &Schema, - ) -> Option { - let thens_type = when_thens - .iter() - .map(|when_then| { - let data_type = &when_then.1.data_type(input_schema).unwrap(); - data_type.clone() - }) - .collect::>(); - let else_type = match else_expr { - None => { - // case when then exprs must have one then value - thens_type[0].clone() - } - Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), - }; - thens_type - .iter() - .try_fold(else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. 
- comparison_coercion(&left_type, right_type) - }) - } - - #[test] - fn test_fmt_sql() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - - // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END - let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?; - let then = lit(123.3f64); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - Some(else_value), - &schema, - )?; - - let display_string = expr.to_string(); - assert_eq!( - display_string, - "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" - ); - - let sql_string = fmt_sql(expr.as_ref()).to_string(); - assert_eq!( - sql_string, - "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" - ); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs index fb4167d2fb1a..427b0405fd2b 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -20,7 +20,7 @@ impl WhenLiteralIndexMap for BooleanIndexMap { ) -> Option { literals .iter() - .position(|literal| matches!(literal, ScalarValue::Boolean(target))) + .position(|literal| matches!(literal, ScalarValue::Boolean(v) if v == &target)) .map(|pos| pos as i32) } @@ -44,4 +44,4 @@ impl WhenLiteralIndexMap for BooleanIndexMap { .collect::>() ) } -} \ No newline at end of file +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index fb8db56a1961..3aff3e945714 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -10,7 +10,7 @@ use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; pub(super) trait BytesMapHelperWrapperTrait: Send + Sync { /// Iterator over byte slices that will return type IntoIter<'a>: Iterator> + 'a; - + /// Convert the array to an iterator over byte slices fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; } @@ -151,16 +151,16 @@ where } /// Map from byte-like literal values to their first occurrence index -/// +/// /// This is a wrapper for handling different kinds of literal maps #[derive(Clone)] pub(super) struct BytesLikeIndexMap { /// Map from non-null literal value the first occurrence index in the literals map: HashMap, i32>, - + /// The index for null literal value (when no null value this will equal to `else_index`) null_index: i32, - + /// The index to return when no match is found else_index: i32, @@ -192,10 +192,8 @@ impl WhenLiteralIndexMap for BytesLikeIndexM for (map_index, value) in bytes_iter.enumerate() { match value { Some(value) => { - let slice_value: &[u8] = value.as_ref(); - // Insert only the first occurrence - map.entry(slice_value.to_vec()).or_insert(map_index as i32); + map.entry(value.to_vec()).or_insert(map_index as i32); } None => { // Only set the null index once @@ -220,8 +218,7 @@ impl WhenLiteralIndexMap for BytesLikeIndexM .map(|value| { match value { Some(value) => { - let slice_value: &[u8] = value.as_ref(); - self.map.get(slice_value).copied().unwrap_or(self.else_index) + self.map.get(value).copied().unwrap_or(self.else_index) } None => { self.null_index diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index a3bf9ad2abe8..b56d6590b037 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expressions::case::case::WhenThen; +use crate::expressions::case::WhenThen; use crate::expressions::Literal; /// Optimization for CASE expressions with literal WHEN and THEN clauses @@ -83,7 +83,7 @@ impl PartialEq for LiteralLookupTable { fn eq(&self, other: &Self) -> bool { // Comparing the pointers as this is the best we can do here Arc::ptr_eq(&self.lookup, &other.lookup) && - &self.values_to_take_from == &other.values_to_take_from + self.values_to_take_from.as_ref() == other.values_to_take_from.as_ref() } } @@ -185,7 +185,7 @@ impl LiteralLookupTable { } pub(in super::super) fn create_output(&self, expr_array: &ArrayRef) -> datafusion_common::Result { - let take_indices = self.lookup.match_values(&expr_array)?; + let take_indices = self.lookup.match_values(expr_array)?; // Zero-copy conversion let take_indices = Int32Array::from(take_indices); diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs index fd7462868c75..ec58bf31a661 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -69,7 +69,7 @@ where /// Trait that help convert a value to a key that is hashable and equatable /// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly -trait ToHashableKey: ArrowNativeTypeOp { +pub(super) trait ToHashableKey: ArrowNativeTypeOp { /// The type that is hashable and equatable /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self /// this is just a helper trait so you can reuse the same 
code for all arrow native types diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 31a16f844b6a..e7947164b068 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -1,4 +1,1497 @@ -mod case; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ + mod literal_lookup_table; -pub use case::*; +use crate::expressions::{try_cast, Column, Literal}; +use crate::PhysicalExpr; +use std::borrow::Cow; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; + +use arrow::array::*; +use arrow::compute::kernels::zip::zip; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::ColumnarValue; + +use datafusion_physical_expr_common::datum::compare_with_eq; +use itertools::Itertools; +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; + +pub(super) type WhenThen = (Arc, Arc); + +#[derive(Debug, Hash, PartialEq, Eq)] +enum EvalMethod { + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + NoExpression, + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + WithExpression, + /// This is a specialization for a specific use case where we can take a fast path + /// for expressions that are infallible and can be cheaply computed for the entire + /// record batch rather than just for the rows where the predicate is true. 
+ /// + /// CASE WHEN condition THEN column [ELSE NULL] END + InfallibleExprOrNull, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` expressions + /// are literal values + /// CASE WHEN condition THEN literal ELSE literal END + ScalarOrScalar, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` are expressions + /// + /// CASE WHEN condition THEN expression ELSE expression END + ExpressionOrExpression, + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable) +} + +/// The CASE expression is similar to a series of nested if/else and there are two forms that +/// can be used. The first form consists of a series of boolean "when" expressions with +/// corresponding "then" expressions, and an optional "else" expression. +/// +/// CASE WHEN condition THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// +/// The second form uses a base expression and then a series of "when" clauses that match on a +/// literal value. +/// +/// CASE expression +/// WHEN value THEN result +/// [WHEN ...] 
+/// [ELSE result] +/// END +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct CaseExpr { + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option>, + /// One or more when/then expressions + when_then_expr: Vec, + /// Optional "else" expression + else_expr: Option>, + /// Evaluation method to use + eval_method: EvalMethod, +} + +impl std::fmt::Display for CaseExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CASE ")?; + if let Some(e) = &self.expr { + write!(f, "{e} ")?; + } + for (w, t) in &self.when_then_expr { + write!(f, "WHEN {w} THEN {t} ")?; + } + if let Some(e) = &self.else_expr { + write!(f, "ELSE {e} ")?; + } + write!(f, "END") + } +} + +/// This is a specialization for a specific use case where we can take a fast path +/// for expressions that are infallible and can be cheaply computed for the entire +/// record batch rather than just for the rows where the predicate is true. For now, +/// this is limited to use with Column expressions but could potentially be used for other +/// expressions in the future +fn is_cheap_and_infallible(expr: &Arc) -> bool { + expr.as_any().is::() +} + +impl CaseExpr { + /// Create a new CASE WHEN expression + pub fn try_new( + expr: Option>, + when_then_expr: Vec, + else_expr: Option>, + ) -> Result { + // normalize null literals to None in the else_expr (this already happens + // during SQL planning, but not necessarily for other use cases) + let else_expr = match &else_expr { + Some(e) => match e.as_any().downcast_ref::() { + Some(lit) if lit.value().is_null() => None, + _ => else_expr, + }, + _ => else_expr, + }; + + if when_then_expr.is_empty() { + exec_err!("There must be at least one WHEN clause") + } else { + let eval_method = Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); + + Ok(Self { + expr, + when_then_expr, + else_expr, + eval_method, + }) + } + } + + fn find_best_eval_method(expr: &Option>, 
when_then_expr: &Vec, else_expr: &Option>) -> EvalMethod { + if expr.is_some() { + if let Some(mapping) = LiteralLookupTable::maybe_new(when_then_expr, else_expr) { + return EvalMethod::WithExprScalarLookupTable(mapping); + } + + return EvalMethod::WithExpression + } + + if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if when_then_expr.len() == 1 && else_expr.is_some() { + EvalMethod::ExpressionOrExpression + } else { + EvalMethod::NoExpression + } + } + + /// Optional base expression that can be compared to literal values in the "when" expressions + pub fn expr(&self) -> Option<&Arc> { + self.expr.as_ref() + } + + /// One or more when/then expressions + pub fn when_then_expr(&self) -> &[WhenThen] { + &self.when_then_expr + } + + /// Optional "else" expression + pub fn else_expr(&self) -> Option<&Arc> { + self.else_expr.as_ref() + } +} + +impl CaseExpr { + /// This function evaluates the form of CASE that matches an expression to fixed values. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] 
+ /// [ELSE result] + /// END + fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + let expr = self.expr.as_ref().unwrap(); + let base_value = expr.evaluate(batch)?; + let base_value = base_value.into_array(batch.num_rows())?; + let base_nulls = is_null(base_value.as_ref())?; + + // start with nulls as default output + let mut current_value = new_null_array(&return_type, batch.num_rows()); + // We only consider non-null values while comparing with whens + let mut remainder = not(&base_nulls)?; + let mut non_null_remainder_count = remainder.true_count(); + for i in 0..self.when_then_expr.len() { + // If there are no rows left to process, break out of the loop early + if non_null_remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; + let when_value = when_value.into_array(batch.num_rows())?; + // build boolean array representing which rows match the "when" value + let when_match = compare_with_eq( + &when_value, + &base_value, + // The types of case and when expressions will be coerced to match. + // We only need to check if the base_value is nested. 
+ base_value.data_type().is_nested(), + )?; + // Treat nulls as false + let when_match = match when_match.null_count() { + 0 => Cow::Borrowed(&when_match), + _ => Cow::Owned(prep_null_mask_filter(&when_match)), + }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_match, &remainder)?; + + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { + continue; + } + + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; + + current_value = match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => { + nullif(current_value.as_ref(), &when_value)? + } + ColumnarValue::Scalar(then_value) => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + ColumnarValue::Array(then_value) => { + zip(&when_value, &then_value, ¤t_value)? + } + }; + + remainder = and_not(&remainder, &when_value)?; + non_null_remainder_count -= when_match_count; + } + + if let Some(e) = self.else_expr() { + // null and unmatched tuples should be assigned else value + remainder = or(&base_nulls, &remainder)?; + + if remainder.true_count() > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } + } + + Ok(ColumnarValue::Array(current_value)) + } + + /// This function evaluates the form of CASE where each WHEN expression is a boolean + /// expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] 
+ /// [ELSE result] + /// END + fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // start with nulls as default output + let mut current_value = new_null_array(&return_type, batch.num_rows()); + let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_count = batch.num_rows(); + for i in 0..self.when_then_expr.len() { + // If there are no rows left to process, break out of the loop early + if remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_value, &remainder)?; + + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { + continue; + } + + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; + + current_value = match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => { + nullif(current_value.as_ref(), &when_value)? + } + ColumnarValue::Scalar(then_value) => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + ColumnarValue::Array(then_value) => { + zip(&when_value, &then_value, ¤t_value)? 
+ } + }; + + // Succeed tuples should be filtered out for short-circuit evaluation, + // null values for the current when expr should be kept + remainder = and_not(&remainder, &when_value)?; + remainder_count -= when_match_count; + } + + if let Some(e) = self.else_expr() { + if remainder_count > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } + } + + Ok(ColumnarValue::Array(current_value)) + } + + /// This function evaluates the specialized case of: + /// + /// CASE WHEN condition THEN column + /// [ELSE NULL] + /// END + /// + /// Note that this function is only safe to use for "then" expressions + /// that are infallible because the expression will be evaluated for all + /// rows in the input batch. + fn case_column_or_null(&self, batch: &RecordBatch) -> Result { + let when_expr = &self.when_then_expr[0].0; + let then_expr = &self.when_then_expr[0].1; + + match when_expr.evaluate(batch)? { + // WHEN true --> column + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => { + then_expr.evaluate(batch) + } + // WHEN [false | null] --> NULL + ColumnarValue::Scalar(_) => { + // return scalar NULL value + ScalarValue::try_from(self.data_type(&batch.schema())?) + .map(ColumnarValue::Scalar) + } + // WHEN column --> column + ColumnarValue::Array(bit_mask) => { + let bit_mask = bit_mask + .as_any() + .downcast_ref::() + .expect("predicate should evaluate to a boolean array"); + // invert the bitmask + let bit_mask = match bit_mask.null_count() { + 0 => not(bit_mask)?, + _ => not(&prep_null_mask_filter(bit_mask))?, + }; + match then_expr.evaluate(batch)? 
{ + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) + } + ColumnarValue::Scalar(_) => { + internal_err!("expression did not evaluate to an array") + } + } + } + } + } + + fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when expression + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + let Some(e) = self.else_expr() else { + return internal_err!("expression did not evaluate to an array"); + }; + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } + + fn expr_or_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when condition on batch + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; + + // For the true and false/null selection vectors, bypass `evaluate_selection` and merging + // results. This avoids materializing the array for the other branch which we will discard + // entirely anyway. 
+ let true_count = when_value.true_count(); + if true_count == batch.num_rows() { + return self.when_then_expr[0].1.evaluate(batch); + } else if true_count == 0 { + return self.else_expr.as_ref().unwrap().evaluate(batch); + } + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + let then_value = self.when_then_expr[0] + .1 + .evaluate_selection(batch, &when_value)? + .into_array(batch.num_rows())?; + + // evaluate else expression on the values not covered by when_value + let remainder = not(&when_value)?; + let e = self.else_expr.as_ref().unwrap(); + + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + + Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) + } + + fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &LiteralLookupTable) -> Result { + let expr = self.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; + + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; + + let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; + + let result = if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) + } else { + ColumnarValue::Array(output) + }; + + Ok(result) + } +} + +impl PhysicalExpr for CaseExpr { + /// Return a reference to Any that can be used for down-casting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + // since all then results have the same data type, we can choose any one as the + // return data type except for the null. 
+ let mut data_type = DataType::Null; + for i in 0..self.when_then_expr.len() { + data_type = self.when_then_expr[i].1.data_type(input_schema)?; + if !data_type.equals_datatype(&DataType::Null) { + break; + } + } + // if all then results are null, we use data type of else expr instead if possible. + if data_type.equals_datatype(&DataType::Null) { + if let Some(e) = &self.else_expr { + data_type = e.data_type(input_schema)?; + } + } + + Ok(data_type) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + // this expression is nullable if any of the input expressions are nullable + let then_nullable = self + .when_then_expr + .iter() + .map(|(_, t)| t.nullable(input_schema)) + .collect::>>()?; + if then_nullable.contains(&true) { + Ok(true) + } else if let Some(e) = &self.else_expr { + e.nullable(input_schema) + } else { + // CASE produces NULL if there is no `else` expr + // (aka when none of the `when_then_exprs` match) + Ok(true) + } + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + match self.eval_method { + EvalMethod::WithExpression => { + // this use case evaluates "expr" and then compares the values with the "when" + // values + self.case_when_with_expr(batch) + } + EvalMethod::NoExpression => { + // The "when" conditions all evaluate to boolean in this use case and can be + // arbitrary expressions + self.case_when_no_expr(batch) + } + EvalMethod::InfallibleExprOrNull => { + // Specialization for CASE WHEN expr THEN column [ELSE NULL] END + self.case_column_or_null(batch) + } + EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), + EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), + EvalMethod::WithExprScalarLookupTable(ref e) => self.with_lookup_table(batch, e), + } + } + + fn children(&self) -> Vec<&Arc> { + let mut children = vec![]; + if let Some(expr) = &self.expr { + children.push(expr) + } + self.when_then_expr.iter().for_each(|(cond, value)| { + children.push(cond); + children.push(value); + }); + + if let 
Some(else_expr) = &self.else_expr { + children.push(else_expr) + } + children + } + + // For physical CaseExpr, we do not allow modifying children size + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != self.children().len() { + internal_err!("CaseExpr: Wrong number of children") + } else { + let (expr, when_then_expr, else_expr) = + match (self.expr().is_some(), self.else_expr().is_some()) { + (true, true) => ( + Some(&children[0]), + &children[1..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (true, false) => { + (Some(&children[0]), &children[1..children.len()], None) + } + (false, true) => ( + None, + &children[0..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (false, false) => (None, &children[0..children.len()], None), + }; + Ok(Arc::new(CaseExpr::try_new( + expr.cloned(), + when_then_expr.iter().cloned().tuples().collect(), + else_expr.cloned(), + )?)) + } + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CASE ")?; + if let Some(e) = &self.expr { + e.fmt_sql(f)?; + write!(f, " ")?; + } + + for (w, t) in &self.when_then_expr { + write!(f, "WHEN ")?; + w.fmt_sql(f)?; + write!(f, " THEN ")?; + t.fmt_sql(f)?; + write!(f, " ")?; + } + + if let Some(e) = &self.else_expr { + write!(f, "ELSE ")?; + e.fmt_sql(f)?; + write!(f, " ")?; + } + write!(f, "END") + } +} + +/// Create a CASE expression +pub fn case( + expr: Option>, + when_thens: Vec, + else_expr: Option>, +) -> Result> { + Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use arrow::buffer::Buffer; + use arrow::datatypes::DataType::Float64; + use arrow::datatypes::Field; + use datafusion_common::cast::{as_float64_array, as_int32_array}; + use datafusion_common::plan_err; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + use 
datafusion_expr::type_coercion::binary::comparison_coercion; + use datafusion_expr::Operator; + use datafusion_physical_expr_common::physical_expr::fmt_sql; + + #[test] + fn case_with_expr() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_expr_else() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_expr_divide_by_zero() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END + let when1 = lit(0i32); + let then1 = lit(ScalarValue::Float64(None)); + let else_value = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &batch.schema(), Float64)?, + &batch.schema(), + )?; + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_without_expr() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_expr_when_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END + let when1 = lit(ScalarValue::Utf8(None)); + let then1 = lit(0i32); + let when2 = col("a", &schema)?; + let then2 = lit(123i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_without_expr_divide_by_zero() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END + let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?; + let then1 = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &batch.schema(), Float64)?, + &batch.schema(), + )?; + let x = lit(ScalarValue::Float64(None)); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1)], + Some(x), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + + assert_eq!(expected, result); + + Ok(()) + } + + fn case_test_batch1() -> Result { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); + let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]); + let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + )?; + Ok(batch) + } + + #[test] + fn case_without_expr_else() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_type_cast() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then = lit(123.3f64); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_matches_and_nulls() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE WHEN load4 = 1.77 THEN load4 END + let when = binary( + col("load4", &schema)?, + Operator::Eq, + lit(1.77f64), + &batch.schema(), + )?; + let then = col("load4", &schema)?; + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_scalar_predicate() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE WHEN TRUE THEN load4 END + let when = lit(true); + let then = col("load4", &schema)?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + None, + schema.as_ref(), + )?; + + // many rows + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + let expected = &Float64Array::from(vec![ + Some(1.77), + None, + None, + Some(1.78), + None, + Some(1.77), + ]); + assert_eq!(expected, result); + + // one row + let expected = Float64Array::from(vec![Some(1.1)]); + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn case_expr_matches_and_nulls() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE load4 WHEN 1.77 THEN load4 END + let expr = col("load4", &schema)?; + let when = lit(1.77f64); + let then = col("load4", &schema)?; + + let expr = generate_case_when_with_type_coercion( + Some(expr), + vec![(when, then)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_when_null_and_some_cond_else_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + let when = binary( + Arc::new(Literal::new(ScalarValue::Boolean(None))), + Operator::And, + binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, + &schema, + )?; + let then = col("a", &schema)?; + + // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END + let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_string_array(&result); + + // all result values should be null + assert_eq!(result.logical_null_count(), batch.num_rows()); + Ok(()) + } + + fn case_test_batch() -> Result { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + Ok(batch) + } + + // Construct an array that has several NULL values whose + // underlying buffer actually matches the where expr predicate + fn case_test_batch_nulls() -> Result { + let load4: Float64Array = vec![ + Some(1.77), // 1.77 + Some(1.77), // null <-- same value, but will be set to null + Some(1.77), // null <-- same value, but will be set to null + Some(1.78), // 1.78 + None, // null + Some(1.77), // 1.77 + ] + .into_iter() + .collect(); + + let null_buffer = Buffer::from([0b00101001u8]); + let load4 = load4 + .into_data() + .into_builder() + .null_bit_buffer(Some(null_buffer)) + .build() + .unwrap(); + let load4: Float64Array = load4.into(); + + 
let batch = + RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; + Ok(batch) + } + + #[test] + fn case_test_incompatible() -> Result<()> { + // 1 then is int64 + // 2 then is boolean + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(true); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + ); + assert!(expr.is_err()); + + // then 1 is int32 + // then 2 is int64 + // else is float + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i64); + let else_expr = lit(1.23f64); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + Some(else_expr), + schema.as_ref(), + ); + assert!(expr.is_ok()); + let result_type = expr.unwrap().data_type(schema.as_ref())?; + assert_eq!(Float64, result_type); + Ok(()) + } + + #[test] + fn case_eq() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr1 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr2 = generate_case_when_with_type_coercion( + 
Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr3 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], + None, + &schema, + )?; + + let expr4 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1)], + Some(else_value), + &schema, + )?; + + assert!(expr1.eq(&expr2)); + assert!(expr2.eq(&expr1)); + + assert!(expr2.ne(&expr3)); + assert!(expr3.ne(&expr2)); + + assert!(expr1.ne(&expr4)); + assert!(expr4.ne(&expr1)); + + Ok(()) + } + + #[test] + fn case_transform() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr2 = Arc::clone(&expr) + .transform(|e| { + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if let Some(transformed) = transformed { + Transformed::yes(transformed) + } else { + Transformed::no(e) + }) + }) + .data() + .unwrap(); + + let expr3 = Arc::clone(&expr) + .transform_down(|e| { + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if let Some(transformed) = transformed { + Transformed::yes(transformed) + } else { + Transformed::no(e) + 
}) + }) + .data() + .unwrap(); + + assert!(expr.ne(&expr2)); + assert!(expr2.eq(&expr3)); + + Ok(()) + } + + #[test] + fn test_column_or_null_specialization() -> Result<()> { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(format!("string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // CaseWhenExprOrNull should produce same results as CaseExpr + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(250), + )); + let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; + assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + match expr.evaluate(&batch)? { + ColumnarValue::Array(array) => { + assert_eq!(1000, array.len()); + assert_eq!(785, array.null_count()); + } + _ => unreachable!(), + } + Ok(()) + } + + #[test] + fn test_expr_or_expr_specialization() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + let when = binary( + col("a", &schema)?, + Operator::LtEq, + lit(2i32), + &batch.schema(), + )?; + let then = col("b", &schema)?; + let else_expr = col("c", &schema)?; + let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; + assert!(matches!( + expr.eval_method, + EvalMethod::ExpressionOrExpression + )); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); + + let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]); + + assert_eq!(expected, result); + Ok(()) + } + + fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } + + fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + } + + fn generate_case_when_with_type_coercion( + expr: Option>, + when_thens: Vec, + else_expr: Option>, + input_schema: &Schema, + ) -> Result> { + let coerce_type = + get_case_common_type(&when_thens, else_expr.clone(), input_schema); + let (when_thens, else_expr) = match coerce_type { + None => plan_err!( + "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" + ), + Some(data_type) => { + // cast then expr + let left = when_thens + .into_iter() + .map(|(when, then)| { + let then = try_cast(then, input_schema, data_type.clone())?; + Ok((when, then)) + }) + .collect::>>()?; + let right = match else_expr { + None => None, + Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), + }; + + Ok((left, right)) + } + }?; + case(expr, when_thens, else_expr) + } + + fn get_case_common_type( + when_thens: &[WhenThen], + else_expr: Option>, + input_schema: &Schema, + ) -> Option { + let thens_type = when_thens + .iter() + .map(|when_then| { + let data_type = &when_then.1.data_type(input_schema).unwrap(); + data_type.clone() + }) + .collect::>(); + let else_type = match else_expr { + None => { + // case when then exprs must have one then value + thens_type[0].clone() + } + Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), + }; + thens_type + .iter() + .try_fold(else_type, |left_type, right_type| { + // TODO: now just use the `equal` coercion rule for case when. If find the issue, and + // refactor again. 
+ comparison_coercion(&left_type, right_type) + }) + } + + #[test] + fn test_fmt_sql() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?; + let then = lit(123.3f64); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + Some(else_value), + &schema, + )?; + + let display_string = expr.to_string(); + assert_eq!( + display_string, + "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" + ); + + let sql_string = fmt_sql(expr.as_ref()).to_string(); + assert_eq!( + sql_string, + "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" + ); + + Ok(()) + } +} From f95888236300c7520a9aacfcbc5c384a339698d7 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:30:23 +0300 Subject: [PATCH 07/22] format --- .../boolean_lookup_table.rs | 75 +- .../bytes_like_lookup_table.rs | 315 +- .../case/literal_lookup_table/mod.rs | 247 +- .../primitive_lookup_table.rs | 159 +- .../physical-expr/src/expressions/case/mod.rs | 2656 +++++++++-------- 5 files changed, 1751 insertions(+), 1701 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs index 427b0405fd2b..b85f877339a1 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -1,47 +1,50 @@ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ArrayRef, AsArray}; use datafusion_common::ScalarValue; -use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; #[derive(Clone, Debug)] 
pub(super) struct BooleanIndexMap { - true_index: i32, - false_index: i32, - null_index: i32, + true_index: i32, + false_index: i32, + null_index: i32, } impl WhenLiteralIndexMap for BooleanIndexMap { - fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result - where - Self: Sized, - { - fn get_first_index( - literals: &[ScalarValue], - target: Option, - ) -> Option { - literals - .iter() - .position(|literal| matches!(literal, ScalarValue::Boolean(v) if v == &target)) - .map(|pos| pos as i32) - } - - Ok(Self { - false_index: get_first_index(&literals, Some(false)).unwrap_or(else_index), - true_index: get_first_index(&literals, Some(true)).unwrap_or(else_index), - null_index: get_first_index(&literals, None).unwrap_or(else_index), - }) - } + fn try_new( + literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + fn get_first_index( + literals: &[ScalarValue], + target: Option, + ) -> Option { + literals + .iter() + .position( + |literal| matches!(literal, ScalarValue::Boolean(v) if v == &target), + ) + .map(|pos| pos as i32) + } - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - Ok( - array - .as_boolean() - .into_iter() - .map(|value| match value { - Some(true) => self.true_index, - Some(false) => self.false_index, - None => self.null_index, + Ok(Self { + false_index: get_first_index(&literals, Some(false)).unwrap_or(else_index), + true_index: get_first_index(&literals, Some(true)).unwrap_or(else_index), + null_index: get_first_index(&literals, None).unwrap_or(else_index), }) - .collect::>() - ) - } + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + Ok(array + .as_boolean() + .into_iter() + .map(|value| match value { + Some(true) => self.true_index, + Some(false) => self.false_index, + None => self.null_index, + }) + .collect::>()) + } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index 3aff3e945714..969309716c63 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -1,111 +1,125 @@ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, + GenericByteArray, GenericByteViewArray, TypedDictionaryArray, +}; +use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; +use datafusion_common::{exec_datafusion_err, HashMap, ScalarValue}; use std::fmt::Debug; use std::iter::Map; use std::marker::PhantomData; -use arrow::array::{ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, GenericByteArray, GenericByteViewArray, TypedDictionaryArray}; -use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; -use datafusion_common::{exec_datafusion_err, HashMap, ScalarValue}; -use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; /// Helper trait to convert various byte-like array types to iterator over byte slices pub(super) trait BytesMapHelperWrapperTrait: Send + Sync { - /// Iterator over byte slices that will return - type IntoIter<'a>: Iterator> + 'a; + /// Iterator over byte slices that will return + type IntoIter<'a>: Iterator> + 'a; - /// Convert the array to an iterator over byte slices - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; + /// Convert the array to an iterator over byte slices + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; } - #[derive(Debug, Clone, Default)] pub(super) struct GenericBytesHelper(PhantomData); impl BytesMapHelperWrapperTrait for GenericBytesHelper { - type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: 
&ArrayRef) -> datafusion_common::Result> { - Ok(array - .as_bytes::() - .into_iter() - .map(|item| { - item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }) - })) - } + type IntoIter<'a> = Map< + ArrayIter<&'a GenericByteArray>, + fn(Option<&'a ::Native>) -> Option<&[u8]>, + >; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_bytes::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } } #[derive(Debug, Clone, Default)] pub(super) struct FixedBinaryHelper; impl BytesMapHelperWrapperTrait for FixedBinaryHelper { - type IntoIter<'a> = FixedSizeBinaryIter<'a>; + type IntoIter<'a> = FixedSizeBinaryIter<'a>; - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - Ok(array.as_fixed_size_binary().into_iter()) - } + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_fixed_size_binary().into_iter()) + } } - #[derive(Debug, Clone, Default)] pub(super) struct GenericBytesViewHelper(PhantomData); impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { - type IntoIter<'a> = Map>, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - Ok(array.as_byte_view::().into_iter().map(|item| { - item.map(|v| { - let bytes: &[u8] = v.as_ref(); - - bytes - }) - })) - } + type IntoIter<'a> = Map< + ArrayIter<&'a GenericByteViewArray>, + fn(Option<&'a ::Native>) -> Option<&[u8]>, + >; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_byte_view::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } } #[derive(Debug, Clone, Default)] -pub(super) struct BytesDictionaryHelper(PhantomData<(Key, Value)>); +pub(super) struct BytesDictionaryHelper( + PhantomData<(Key, Value)>, +); impl BytesMapHelperWrapperTrait for BytesDictionaryHelper where - Key: ArrowDictionaryKeyType + Send + Sync, - Value: ByteArrayType, - 
for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: - IntoIterator> { - type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array - .as_dictionary::() - .downcast_dict::>() - .ok_or_else(|| { - exec_datafusion_err!( + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteArrayType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: + IntoIterator>, +{ + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( "Failed to downcast dictionary array {} to expected dictionary value {}", array.data_type(), Value::DATA_TYPE ) - })?; + })?; - Ok(dict_array.into_iter().map(|item| item.map(|v| { - let bytes: &[u8] = v.as_ref(); + Ok(dict_array.into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); - bytes - }))) - } + bytes + }) + })) + } } #[derive(Debug, Clone, Default)] -pub(super) struct FixedBytesDictionaryHelper(PhantomData); +pub(super) struct FixedBytesDictionaryHelper( + PhantomData, +); impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper where - Key: ArrowDictionaryKeyType + Send + Sync, - for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: IntoIterator> { - type IntoIter<'a> = as IntoIterator>::IntoIter; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array + Key: ArrowDictionaryKeyType + Send + Sync, + for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: + IntoIterator>, +{ + type IntoIter<'a> = + as IntoIterator>::IntoIter; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array .as_dictionary::() .downcast_dict::() .ok_or_else(|| { @@ -115,39 +129,45 @@ where ) })?; 
- Ok(dict_array.into_iter()) - } + Ok(dict_array.into_iter()) + } } #[derive(Debug, Clone, Default)] -pub(super) struct BytesViewDictionaryHelper(PhantomData<(Key, Value)>); +pub(super) struct BytesViewDictionaryHelper< + Key: ArrowDictionaryKeyType, + Value: ByteViewType, +>(PhantomData<(Key, Value)>); impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper where - Key: ArrowDictionaryKeyType + Send + Sync, - Value: ByteViewType, - for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: - IntoIterator> { - type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; - - fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { - let dict_array = array - .as_dictionary::() - .downcast_dict::>() - .ok_or_else(|| { - exec_datafusion_err!( + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteViewType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: + IntoIterator>, +{ + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( "Failed to downcast dictionary array {} to expected dictionary value {}", array.data_type(), Value::DATA_TYPE ) - })?; + })?; - Ok(dict_array.into_iter().map(|item| item.map(|v| { - let bytes: &[u8] = v.as_ref(); + Ok(dict_array.into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); - bytes - }))) - } + bytes + }) + })) + } } /// Map from byte-like literal values to their first occurrence index @@ -155,78 +175,77 @@ where /// This is a wrapper for handling different kinds of literal maps #[derive(Clone)] pub(super) struct BytesLikeIndexMap { - /// Map from non-null literal value the first occurrence index in the literals - map: HashMap, i32>, + /// Map from non-null literal value the first occurrence index in the literals + map: 
HashMap, i32>, - /// The index for null literal value (when no null value this will equal to `else_index`) - null_index: i32, + /// The index for null literal value (when there is no null value this will be equal to `else_index`) + null_index: i32, - /// The index to return when no match is found - else_index: i32, + /// The index to return when no match is found + else_index: i32, - _phantom_data: PhantomData, + _phantom_data: PhantomData, } impl Debug for BytesLikeIndexMap { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("BytesMapHelper") - .field("map", &self.map) - .field("null_index", &self.null_index) - .field("else_index", &self.else_index) - .finish() - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BytesMapHelper") + .field("map", &self.map) + .field("null_index", &self.null_index) + .field("else_index", &self.else_index) + .finish() + } } -impl WhenLiteralIndexMap for BytesLikeIndexMap { - fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result - where - Self: Sized, - { - let input = ScalarValue::iter_to_array(literals)?; - let bytes_iter = Helper::array_to_iter(&input)?; - - let mut null_index = None; - - let mut map: HashMap, i32> = HashMap::new(); - - for (map_index, value) in bytes_iter.enumerate() { - match value { - Some(value) => { - // Insert only the first occurrence - map.entry(value.to_vec()).or_insert(map_index as i32); - } - None => { - // Only set the null index once - if null_index.is_none() { - null_index = Some(map_index as i32); - } +impl WhenLiteralIndexMap + for BytesLikeIndexMap +{ + fn try_new( + literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(literals)?; + let bytes_iter = Helper::array_to_iter(&input)?; + + let mut null_index = None; + + let mut map: HashMap, i32> = HashMap::new(); + + for (map_index, value) in bytes_iter.enumerate() { + match value { +
Some(value) => { + // Insert only the first occurrence + map.entry(value.to_vec()).or_insert(map_index as i32); + } + None => { + // Only set the null index once + if null_index.is_none() { + null_index = Some(map_index as i32); + } + } + } } - } + + Ok(Self { + map, + null_index: null_index.unwrap_or(else_index), + else_index, + _phantom_data: Default::default(), + }) } - Ok(Self { - map, - null_index: null_index.unwrap_or(else_index), - else_index, - _phantom_data: Default::default(), - }) - } - - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - let bytes_iter = Helper::array_to_iter(array)?; - let indices = bytes_iter - .map(|value| { - match value { - Some(value) => { - self.map.get(value).copied().unwrap_or(self.else_index) - } - None => { - self.null_index - } - } - }) - .collect::>(); + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let bytes_iter = Helper::array_to_iter(array)?; + let indices = bytes_iter + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(self.else_index), + None => self.null_index, + }) + .collect::>(); - Ok(indices) - } + Ok(indices) + } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index b56d6590b037..b46707a10063 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -2,26 +2,26 @@ mod boolean_lookup_table; mod bytes_like_lookup_table; mod primitive_lookup_table; -use datafusion_common::DataFusionError; use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::{ - BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper, - FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper, - 
GenericBytesViewHelper, + BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper, + FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper, + GenericBytesViewHelper, }; use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveArrayMapHolder; +use crate::expressions::case::WhenThen; +use crate::expressions::Literal; use arrow::array::{downcast_integer, downcast_primitive, ArrayRef, Int32Array}; use arrow::datatypes::{ ArrowDictionaryKeyType, BinaryViewType, DataType, GenericBinaryType, GenericStringType, StringViewType, }; +use datafusion_common::DataFusionError; use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expressions::case::WhenThen; -use crate::expressions::Literal; /// Optimization for CASE expressions with literal WHEN and THEN clauses /// @@ -59,145 +59,151 @@ use crate::expressions::Literal; /// #[derive(Debug)] pub(in super::super) struct LiteralLookupTable { - /// The lookup table to use for evaluating the CASE expression - lookup: Arc, + /// The lookup table to use for evaluating the CASE expression + lookup: Arc, - /// ArrayRef where array[i] = then_literals[i] - /// the last value in the array is the else_expr - values_to_take_from: ArrayRef, + /// ArrayRef where array[i] = then_literals[i] + /// the last value in the array is the else_expr + values_to_take_from: ArrayRef, } impl Hash for LiteralLookupTable { - fn hash(&self, state: &mut H) { - // Hashing the pointer as this is the best we can do here + fn hash(&self, state: &mut H) { + // Hashing the pointer as this is the best we can do here - let lookup_ptr = Arc::as_ptr(&self.lookup); - lookup_ptr.hash(state); + let lookup_ptr = Arc::as_ptr(&self.lookup); + lookup_ptr.hash(state); - let values_ptr = 
Arc::as_ptr(&self.lookup); - values_ptr.hash(state); - } + let values_ptr = Arc::as_ptr(&self.lookup); + values_ptr.hash(state); + } } impl PartialEq for LiteralLookupTable { - fn eq(&self, other: &Self) -> bool { - // Comparing the pointers as this is the best we can do here - Arc::ptr_eq(&self.lookup, &other.lookup) && - self.values_to_take_from.as_ref() == other.values_to_take_from.as_ref() - } + fn eq(&self, other: &Self) -> bool { + // Comparing the pointers as this is the best we can do here + Arc::ptr_eq(&self.lookup, &other.lookup) + && self.values_to_take_from.as_ref() == other.values_to_take_from.as_ref() + } } -impl Eq for LiteralLookupTable { - -} +impl Eq for LiteralLookupTable {} impl LiteralLookupTable { - pub(in super::super) fn maybe_new( - when_then_expr: &Vec, else_expr: &Option> - ) -> Option { - // We can't use the optimization if we don't have any when then pairs - if when_then_expr.is_empty() { - return None; - } - - // If we only have 1 than this optimization is not useful - if when_then_expr.len() == 1 { - return None; - } - - let when_then_exprs_maybe_literals = when_then_expr - .iter() - .map(|(when, then)| { - let when_maybe_literal = when.as_any().downcast_ref::(); - let then_maybe_literal = then.as_any().downcast_ref::(); - - when_maybe_literal.zip(then_maybe_literal) - }) - .collect::>(); - - // If not all the when/then expressions are literals we cannot use this optimization - if when_then_exprs_maybe_literals.contains(&None) { - return None; - } - - let (when_literals, then_literals): (Vec, Vec) = when_then_exprs_maybe_literals - .iter() - // Unwrap the options as we have already checked they are all Some - .flatten() - .map(|(when_lit, then_lit)| (when_lit.value().clone(), then_lit.value().clone())) - .unzip(); + pub(in super::super) fn maybe_new( + when_then_expr: &Vec, + else_expr: &Option>, + ) -> Option { + // We can't use the optimization if we don't have any when then pairs + if when_then_expr.is_empty() { + return None; + } 
+ // If we only have 1 than this optimization is not useful + if when_then_expr.len() == 1 { + return None; + } - let else_expr: ScalarValue = if let Some(else_expr) = else_expr { - let literal = else_expr.as_any().downcast_ref::()?; + let when_then_exprs_maybe_literals = when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::(); + let then_maybe_literal = then.as_any().downcast_ref::(); - literal.value().clone() - } else { - let Ok(null_scalar) = ScalarValue::try_new_null(&then_literals[0].data_type()) - else { - return None; - }; + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); - null_scalar - }; + // If not all the when/then expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } - { - let data_type = when_literals[0].data_type(); + let (when_literals, then_literals): (Vec, Vec) = + when_then_exprs_maybe_literals + .iter() + // Unwrap the options as we have already checked they are all Some + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + .unzip(); + + let else_expr: ScalarValue = if let Some(else_expr) = else_expr { + let literal = else_expr.as_any().downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = + ScalarValue::try_new_null(&then_literals[0].data_type()) + else { + return None; + }; + + null_scalar + }; + + { + let data_type = when_literals[0].data_type(); + + // If not all the when literals are the same data type we cannot use this optimization + if when_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + } - // If not all the when literals are the same data type we cannot use this optimization - if when_literals.iter().any(|l| l.data_type() != data_type) { - return None; - } - } + { + let data_type = then_literals[0].data_type(); - { - let data_type = then_literals[0].data_type(); + // If not all the then 
and the else literals are the same data type we cannot use this optimization + if then_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } - // If not all the then and the else literals are the same data type we cannot use this optimization - if then_literals.iter().any(|l| l.data_type() != data_type) { - return None; - } + if else_expr.data_type() != data_type { + return None; + } + } - if else_expr.data_type() != data_type { - return None; - } + let output_array = ScalarValue::iter_to_array( + then_literals + .iter() + // The else is in the end + .chain(std::iter::once(&else_expr)) + .cloned(), + ) + .ok()?; + + let lookup = try_creating_lookup_table( + when_literals, + // The else expression is in the end + output_array.len() as i32 - 1, + ) + .ok()?; + + Some(Self { + lookup, + values_to_take_from: output_array, + }) } + pub(in super::super) fn create_output( + &self, + expr_array: &ArrayRef, + ) -> datafusion_common::Result { + let take_indices = self.lookup.match_values(expr_array)?; - let output_array = ScalarValue::iter_to_array( - then_literals.iter() - // The else is in the end - .chain(std::iter::once(&else_expr)) - .cloned() - ).ok()?; - - let lookup = try_creating_lookup_table( - when_literals, - - // The else expression is in the end - output_array.len() as i32 - 1, - ).ok()?; - - Some(Self { - lookup, - values_to_take_from: output_array, - }) - } + // Zero-copy conversion + let take_indices = Int32Array::from(take_indices); - pub(in super::super) fn create_output(&self, expr_array: &ArrayRef) -> datafusion_common::Result { - let take_indices = self.lookup.match_values(expr_array)?; + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = 
arrow::compute::take(&self.values_to_take_from, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; - // Zero-copy conversion - let take_indices = Int32Array::from(take_indices); - - // An optimize version would depend on the type of the values_to_take_from - // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) - // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array - let output = arrow::compute::take(&self.values_to_take_from, &take_indices, None) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(output) - } + Ok(output) + } } /// Lookup table for mapping literal values to their corresponding indices in the THEN clauses @@ -322,8 +328,7 @@ fn create_lookup_table_for_dictionary_input, else_index: i32, ) -> datafusion_common::Result> { - - // TODO - optimize dictionary to use different wrapper that takes advantage of it being a dictionary + // TODO - optimize dictionary to use different wrapper that takes advantage of it being a dictionary match value { DataType::Utf8 => { let lookup_table = BytesLikeIndexMap::< diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs index ec58bf31a661..86d707b1d9ca 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -1,83 +1,92 @@ -use std::fmt::Debug; -use std::hash::Hash; +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; -use half::f16; use datafusion_common::{HashMap, ScalarValue}; -use 
crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use half::f16; +use std::fmt::Debug; +use std::hash::Hash; #[derive(Clone)] pub(super) struct PrimitiveArrayMapHolder where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, + T: ArrowPrimitiveType, + T::Native: ToHashableKey, { - /// Literal value to map index - /// - /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps - map: HashMap::HashableKey>, i32>, - else_index: i32, + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap::HashableKey>, i32>, + else_index: i32, } impl Debug for PrimitiveArrayMapHolder where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, + T: ArrowPrimitiveType, + T::Native: ToHashableKey, { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PrimitiveArrayMapHolder") - .field("map", &self.map) - .field("else_index", &self.else_index) - .finish() - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveArrayMapHolder") + .field("map", &self.map) + .field("else_index", &self.else_index) + .finish() + } } impl WhenLiteralIndexMap for PrimitiveArrayMapHolder where - T: ArrowPrimitiveType, - T::Native: ToHashableKey, + T: ArrowPrimitiveType, + T::Native: ToHashableKey, { - fn try_new(literals: Vec, else_index: i32) -> datafusion_common::Result - where - Self: Sized, - { - let input = ScalarValue::iter_to_array(literals)?; - - let map = input - .as_primitive::() - .into_iter() - .enumerate() - .map(|(map_index, value)| (value.map(|v| v.into_hashable_key()), map_index as i32)) - .collect(); - - Ok(Self { map, else_index }) - } - - fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { - let indices = array - .as_primitive::() - .into_iter() - .map(|value| self.map.get(&value.map(|item| 
item.into_hashable_key())).copied().unwrap_or(self.else_index)) - .collect::>(); - - Ok(indices) - } -} + fn try_new( + literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(literals)?; + + let map = input + .as_primitive::() + .into_iter() + .enumerate() + .map(|(map_index, value)| { + (value.map(|v| v.into_hashable_key()), map_index as i32) + }) + .collect(); + + Ok(Self { map, else_index }) + } + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let indices = array + .as_primitive::() + .into_iter() + .map(|value| { + self.map + .get(&value.map(|item| item.into_hashable_key())) + .copied() + .unwrap_or(self.else_index) + }) + .collect::>(); + + Ok(indices) + } +} // TODO - We need to port it to arrow so that it can be reused in other places /// Trait that help convert a value to a key that is hashable and equatable /// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly pub(super) trait ToHashableKey: ArrowNativeTypeOp { - /// The type that is hashable and equatable - /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self - /// this is just a helper trait so you can reuse the same code for all arrow native types - type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; - - /// Converts self to a hashable key - /// the result of this value can be used as the key in hash maps/sets - fn into_hashable_key(self) -> Self::HashableKey; + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; } macro_rules! 
impl_to_hashable_key { @@ -115,37 +124,37 @@ impl_to_hashable_key!(@float | f64 => u64); #[cfg(test)] mod tests { - use super::ToHashableKey; - use arrow::array::downcast_primitive; - - // This test ensure that all arrow primitive types implement ToHashableKey - // otherwise the code will not compile - #[test] - fn should_implement_to_hashable_key_for_all_primitives() { - #[derive(Debug, Default)] - struct ExampleSet - where - T: arrow::datatypes::ArrowPrimitiveType, - T::Native: ToHashableKey, - { - _map: std::collections::HashSet<::HashableKey>, - } + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } - macro_rules! create_matching_set { + macro_rules! create_matching_set { ($t:ty) => {{ let _lookup_table = ExampleSet::<$t> { - _map: Default::default() + _map: Default::default(), }; return; }}; } - let data_type = arrow::datatypes::DataType::Float16; + let data_type = arrow::datatypes::DataType::Float16; - downcast_primitive! { + downcast_primitive! { data_type => (create_matching_set), _ => panic!("not implemented for {data_type}"), } - } + } } diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index e7947164b068..3dccdfbb594e 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
- mod literal_lookup_table; use crate::expressions::{try_cast, Column, Literal}; @@ -29,49 +28,51 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; -use crate::expressions::case::literal_lookup_table::LiteralLookupTable; pub(super) type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - NoExpression, - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - WithExpression, - /// This is a specialization for a specific use case where we can take a fast path - /// for expressions that are infallible and can be cheaply computed for the entire - /// record batch rather than just for the rows where the predicate is true. 
- /// - /// CASE WHEN condition THEN column [ELSE NULL] END - InfallibleExprOrNull, - /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` expressions - /// are literal values - /// CASE WHEN condition THEN literal ELSE literal END - ScalarOrScalar, - /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions - /// - /// CASE WHEN condition THEN expression ELSE expression END - ExpressionOrExpression, - - /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals - /// - /// See [`LiteralLookupTable`] for more details - WithExprScalarLookupTable(LiteralLookupTable) + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + NoExpression, + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + WithExpression, + /// This is a specialization for a specific use case where we can take a fast path + /// for expressions that are infallible and can be cheaply computed for the entire + /// record batch rather than just for the rows where the predicate is true. 
+ /// + /// CASE WHEN condition THEN column [ELSE NULL] END + InfallibleExprOrNull, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` expressions + /// are literal values + /// CASE WHEN condition THEN literal ELSE literal END + ScalarOrScalar, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` are expressions + /// + /// CASE WHEN condition THEN expression ELSE expression END + ExpressionOrExpression, + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable), } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -93,30 +94,30 @@ enum EvalMethod { /// END #[derive(Debug, Hash, PartialEq, Eq)] pub struct CaseExpr { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec, - /// Optional "else" expression - else_expr: Option>, - /// Evaluation method to use - eval_method: EvalMethod, + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option>, + /// One or more when/then expressions + when_then_expr: Vec, + /// Optional "else" expression + else_expr: Option>, + /// Evaluation method to use + eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "CASE ")?; - if let Some(e) = &self.expr { - write!(f, "{e} ")?; - } - for (w, t) in &self.when_then_expr { - write!(f, "WHEN {w} THEN {t} ")?; - } - if let Some(e) = &self.else_expr { - write!(f, "ELSE {e} ")?; + fn fmt(&self, f: &mut std::fmt::Formatter) -> 
std::fmt::Result { + write!(f, "CASE ")?; + if let Some(e) = &self.expr { + write!(f, "{e} ")?; + } + for (w, t) in &self.when_then_expr { + write!(f, "WHEN {w} THEN {t} ")?; + } + if let Some(e) = &self.else_expr { + write!(f, "ELSE {e} ")?; + } + write!(f, "END") } - write!(f, "END") - } } /// This is a specialization for a specific use case where we can take a fast path @@ -125,1294 +126,1307 @@ impl std::fmt::Display for CaseExpr { /// this is limited to use with Column expressions but could potentially be used for other /// expressions in the future fn is_cheap_and_infallible(expr: &Arc) -> bool { - expr.as_any().is::() + expr.as_any().is::() } impl CaseExpr { - /// Create a new CASE WHEN expression - pub fn try_new( - expr: Option>, - when_then_expr: Vec, - else_expr: Option>, - ) -> Result { - // normalize null literals to None in the else_expr (this already happens - // during SQL planning, but not necessarily for other use cases) - let else_expr = match &else_expr { - Some(e) => match e.as_any().downcast_ref::() { - Some(lit) if lit.value().is_null() => None, - _ => else_expr, - }, - _ => else_expr, - }; - - if when_then_expr.is_empty() { - exec_err!("There must be at least one WHEN clause") - } else { - let eval_method = Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); - - Ok(Self { - expr, - when_then_expr, - else_expr, - eval_method, - }) + /// Create a new CASE WHEN expression + pub fn try_new( + expr: Option>, + when_then_expr: Vec, + else_expr: Option>, + ) -> Result { + // normalize null literals to None in the else_expr (this already happens + // during SQL planning, but not necessarily for other use cases) + let else_expr = match &else_expr { + Some(e) => match e.as_any().downcast_ref::() { + Some(lit) if lit.value().is_null() => None, + _ => else_expr, + }, + _ => else_expr, + }; + + if when_then_expr.is_empty() { + exec_err!("There must be at least one WHEN clause") + } else { + let eval_method = + 
Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); + + Ok(Self { + expr, + when_then_expr, + else_expr, + eval_method, + }) + } } - } - fn find_best_eval_method(expr: &Option>, when_then_expr: &Vec, else_expr: &Option>) -> EvalMethod { - if expr.is_some() { - if let Some(mapping) = LiteralLookupTable::maybe_new(when_then_expr, else_expr) { - return EvalMethod::WithExprScalarLookupTable(mapping); - } + fn find_best_eval_method( + expr: &Option>, + when_then_expr: &Vec, + else_expr: &Option>, + ) -> EvalMethod { + if expr.is_some() { + if let Some(mapping) = + LiteralLookupTable::maybe_new(when_then_expr, else_expr) + { + return EvalMethod::WithExprScalarLookupTable(mapping); + } + + return EvalMethod::WithExpression; + } + + if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if when_then_expr.len() == 1 && else_expr.is_some() { + EvalMethod::ExpressionOrExpression + } else { + EvalMethod::NoExpression + } + } + + /// Optional base expression that can be compared to literal values in the "when" expressions + pub fn expr(&self) -> Option<&Arc> { + self.expr.as_ref() + } - return EvalMethod::WithExpression + /// One or more when/then expressions + pub fn when_then_expr(&self) -> &[WhenThen] { + &self.when_then_expr } - if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if when_then_expr.len() == 1 - && when_then_expr[0].1.as_any().is::() - && else_expr.is_some() - && else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression - } else { - 
EvalMethod::NoExpression + /// Optional "else" expression + pub fn else_expr(&self) -> Option<&Arc> { + self.else_expr.as_ref() } - } - - /// Optional base expression that can be compared to literal values in the "when" expressions - pub fn expr(&self) -> Option<&Arc> { - self.expr.as_ref() - } - - /// One or more when/then expressions - pub fn when_then_expr(&self) -> &[WhenThen] { - &self.when_then_expr - } - - /// Optional "else" expression - pub fn else_expr(&self) -> Option<&Arc> { - self.else_expr.as_ref() - } } impl CaseExpr { - /// This function evaluates the form of CASE that matches an expression to fixed values. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - let expr = self.expr.as_ref().unwrap(); - let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows())?; - let base_nulls = is_null(base_value.as_ref())?; - - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - // We only consider non-null values while comparing with whens - let mut remainder = not(&base_nulls)?; - let mut non_null_remainder_count = remainder.true_count(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if non_null_remainder_count == 0 { - break; - } + /// This function evaluates the form of CASE that matches an expression to fixed values. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] 
+ /// [ELSE result] + /// END + fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + let expr = self.expr.as_ref().unwrap(); + let base_value = expr.evaluate(batch)?; + let base_value = base_value.into_array(batch.num_rows())?; + let base_nulls = is_null(base_value.as_ref())?; + + // start with nulls as default output + let mut current_value = new_null_array(&return_type, batch.num_rows()); + // We only consider non-null values while comparing with whens + let mut remainder = not(&base_nulls)?; + let mut non_null_remainder_count = remainder.true_count(); + for i in 0..self.when_then_expr.len() { + // If there are no rows left to process, break out of the loop early + if non_null_remainder_count == 0 { + break; + } - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - // build boolean array representing which rows match the "when" value - let when_match = compare_with_eq( - &when_value, - &base_value, - // The types of case and when expressions will be coerced to match. - // We only need to check if the base_value is nested. 
- base_value.data_type().is_nested(), - )?; - // Treat nulls as false - let when_match = match when_match.null_count() { - 0 => Cow::Borrowed(&when_match), - _ => Cow::Owned(prep_null_mask_filter(&when_match)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_match, &remainder)?; - - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { - continue; - } + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; + let when_value = when_value.into_array(batch.num_rows())?; + // build boolean array representing which rows match the "when" value + let when_match = compare_with_eq( + &when_value, + &base_value, + // The types of case and when expressions will be coerced to match. + // We only need to check if the base_value is nested. + base_value.data_type().is_nested(), + )?; + // Treat nulls as false + let when_match = match when_match.null_count() { + 0 => Cow::Borrowed(&when_match), + _ => Cow::Owned(prep_null_mask_filter(&when_match)), + }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_match, &remainder)?; + + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { + continue; + } - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; + + current_value = match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => { + nullif(current_value.as_ref(), &when_value)? + } + ColumnarValue::Scalar(then_value) => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? 
+ } + ColumnarValue::Array(then_value) => { + zip(&when_value, &then_value, ¤t_value)? + } + }; + + remainder = and_not(&remainder, &when_value)?; + non_null_remainder_count -= when_match_count; + } + + if let Some(e) = self.else_expr() { + // null and unmatched tuples should be assigned else value + remainder = or(&base_nulls, &remainder)?; + + if remainder.true_count() > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + + Ok(ColumnarValue::Array(current_value)) + } + + /// This function evaluates the form of CASE where each WHEN expression is a boolean + /// expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] 
+ /// [ELSE result] + /// END + fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // start with nulls as default output + let mut current_value = new_null_array(&return_type, batch.num_rows()); + let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_count = batch.num_rows(); + for i in 0..self.when_then_expr.len() { + // If there are no rows left to process, break out of the loop early + if remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_value, &remainder)?; + + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { + continue; + } + + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; + + current_value = match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => { + nullif(current_value.as_ref(), &when_value)? + } + ColumnarValue::Scalar(then_value) => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + ColumnarValue::Array(then_value) => { + zip(&when_value, &then_value, ¤t_value)? 
+ } + }; + + // Succeed tuples should be filtered out for short-circuit evaluation, + // null values for the current when expr should be kept + remainder = and_not(&remainder, &when_value)?; + remainder_count -= when_match_count; } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? + + if let Some(e) = self.else_expr() { + if remainder_count > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } } - }; - remainder = and_not(&remainder, &when_value)?; - non_null_remainder_count -= when_match_count; + Ok(ColumnarValue::Array(current_value)) + } + + /// This function evaluates the specialized case of: + /// + /// CASE WHEN condition THEN column + /// [ELSE NULL] + /// END + /// + /// Note that this function is only safe to use for "then" expressions + /// that are infallible because the expression will be evaluated for all + /// rows in the input batch. + fn case_column_or_null(&self, batch: &RecordBatch) -> Result { + let when_expr = &self.when_then_expr[0].0; + let then_expr = &self.when_then_expr[0].1; + + match when_expr.evaluate(batch)? { + // WHEN true --> column + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => { + then_expr.evaluate(batch) + } + // WHEN [false | null] --> NULL + ColumnarValue::Scalar(_) => { + // return scalar NULL value + ScalarValue::try_from(self.data_type(&batch.schema())?) 
+ .map(ColumnarValue::Scalar) + } + // WHEN column --> column + ColumnarValue::Array(bit_mask) => { + let bit_mask = bit_mask + .as_any() + .downcast_ref::() + .expect("predicate should evaluate to a boolean array"); + // invert the bitmask + let bit_mask = match bit_mask.null_count() { + 0 => not(bit_mask)?, + _ => not(&prep_null_mask_filter(bit_mask))?, + }; + match then_expr.evaluate(batch)? { + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) + } + ColumnarValue::Scalar(_) => { + internal_err!("expression did not evaluate to an array") + } + } + } + } } - if let Some(e) = self.else_expr() { - // null and unmatched tuples should be assigned else value - remainder = or(&base_nulls, &remainder)?; + fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; - if remainder.true_count() > 0 { + // evaluate when expression + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|_| { + internal_datafusion_err!("WHEN expression did not return a BooleanArray") + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + let Some(e) = self.else_expr() else { + return internal_err!("expression did not evaluate to an array"); + }; // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } + fn 
expr_or_expr(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when condition on batch + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; + + // For the true and false/null selection vectors, bypass `evaluate_selection` and merging + // results. This avoids materializing the array for the other branch which we will discard + // entirely anyway. + let true_count = when_value.true_count(); + if true_count == batch.num_rows() { + return self.when_then_expr[0].1.evaluate(batch); + } else if true_count == 0 { + return self.else_expr.as_ref().unwrap().evaluate(batch); + } + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + let then_value = self.when_then_expr[0] + .1 + .evaluate_selection(batch, &when_value)? + .into_array(batch.num_rows())?; + + // evaluate else expression on the values not covered by when_value + let remainder = not(&when_value)?; + let e = self.else_expr.as_ref().unwrap(); + + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + + Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } - Ok(ColumnarValue::Array(current_value)) - } - - /// This function evaluates the form of CASE where each WHEN expression is a boolean - /// expression. 
- /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - let mut remainder_count = batch.num_rows(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if remainder_count == 0 { - break; - } + fn with_lookup_table( + &self, + batch: &RecordBatch, + scalars_or_null_lookup: &LiteralLookupTable, + ) -> Result { + let expr = self.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|_| { - internal_datafusion_err!("WHEN expression did not return a BooleanArray") - })?; - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_value, &remainder)?; - - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { - continue; - } + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; + let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; + + let result = if is_scalar { + 
ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) + } else { + ColumnarValue::Array(output) + }; + + Ok(result) + } +} + +impl PhysicalExpr for CaseExpr { + /// Return a reference to Any that can be used for down-casting + fn as_any(&self) -> &dyn Any { + self + } - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? + fn data_type(&self, input_schema: &Schema) -> Result { + // since all then results have the same data type, we can choose any one as the + // return data type except for the null. + let mut data_type = DataType::Null; + for i in 0..self.when_then_expr.len() { + data_type = self.when_then_expr[i].1.data_type(input_schema)?; + if !data_type.equals_datatype(&DataType::Null) { + break; + } } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + // if all then results are null, we use data type of else expr instead if possible. + if data_type.equals_datatype(&DataType::Null) { + if let Some(e) = &self.else_expr { + data_type = e.data_type(input_schema)?; + } } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? 
+ + Ok(data_type) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + // this expression is nullable if any of the input expressions are nullable + let then_nullable = self + .when_then_expr + .iter() + .map(|(_, t)| t.nullable(input_schema)) + .collect::>>()?; + if then_nullable.contains(&true) { + Ok(true) + } else if let Some(e) = &self.else_expr { + e.nullable(input_schema) + } else { + // CASE produces NULL if there is no `else` expr + // (aka when none of the `when_then_exprs` match) + Ok(true) } - }; + } - // Succeed tuples should be filtered out for short-circuit evaluation, - // null values for the current when expr should be kept - remainder = and_not(&remainder, &when_value)?; - remainder_count -= when_match_count; + fn evaluate(&self, batch: &RecordBatch) -> Result { + match self.eval_method { + EvalMethod::WithExpression => { + // this use case evaluates "expr" and then compares the values with the "when" + // values + self.case_when_with_expr(batch) + } + EvalMethod::NoExpression => { + // The "when" conditions all evaluate to boolean in this use case and can be + // arbitrary expressions + self.case_when_no_expr(batch) + } + EvalMethod::InfallibleExprOrNull => { + // Specialization for CASE WHEN expr THEN column [ELSE NULL] END + self.case_column_or_null(batch) + } + EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), + EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), + EvalMethod::WithExprScalarLookupTable(ref e) => { + self.with_lookup_table(batch, e) + } + } } - if let Some(e) = self.else_expr() { - if remainder_count > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_ = expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } + fn children(&self) -> Vec<&Arc> { + let mut children = vec![]; + if let Some(expr) = &self.expr { + children.push(expr) + } + self.when_then_expr.iter().for_each(|(cond, value)| { + children.push(cond); + children.push(value); + }); + + if let Some(else_expr) = &self.else_expr { + children.push(else_expr) + } + children } - Ok(ColumnarValue::Array(current_value)) - } - - /// This function evaluates the specialized case of: - /// - /// CASE WHEN condition THEN column - /// [ELSE NULL] - /// END - /// - /// Note that this function is only safe to use for "then" expressions - /// that are infallible because the expression will be evaluated for all - /// rows in the input batch. - fn case_column_or_null(&self, batch: &RecordBatch) -> Result { - let when_expr = &self.when_then_expr[0].0; - let then_expr = &self.when_then_expr[0].1; - - match when_expr.evaluate(batch)? { - // WHEN true --> column - ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => { - then_expr.evaluate(batch) - } - // WHEN [false | null] --> NULL - ColumnarValue::Scalar(_) => { - // return scalar NULL value - ScalarValue::try_from(self.data_type(&batch.schema())?) - .map(ColumnarValue::Scalar) - } - // WHEN column --> column - ColumnarValue::Array(bit_mask) => { - let bit_mask = bit_mask - .as_any() - .downcast_ref::() - .expect("predicate should evaluate to a boolean array"); - // invert the bitmask - let bit_mask = match bit_mask.null_count() { - 0 => not(bit_mask)?, - _ => not(&prep_null_mask_filter(bit_mask))?, - }; - match then_expr.evaluate(batch)? 
{ - ColumnarValue::Array(array) => { - Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) - } - ColumnarValue::Scalar(_) => { - internal_err!("expression did not evaluate to an array") - } + // For physical CaseExpr, we do not allow modifying children size + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != self.children().len() { + internal_err!("CaseExpr: Wrong number of children") + } else { + let (expr, when_then_expr, else_expr) = + match (self.expr().is_some(), self.else_expr().is_some()) { + (true, true) => ( + Some(&children[0]), + &children[1..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (true, false) => { + (Some(&children[0]), &children[1..children.len()], None) + } + (false, true) => ( + None, + &children[0..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (false, false) => (None, &children[0..children.len()], None), + }; + Ok(Arc::new(CaseExpr::try_new( + expr.cloned(), + when_then_expr.iter().cloned().tuples().collect(), + else_expr.cloned(), + )?)) } - } } - } - - fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // evaluate when expression - let when_value = self.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|_| { - internal_datafusion_err!("WHEN expression did not return a BooleanArray") - })?; - - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - - // evaluate then_value - let then_value = self.when_then_expr[0].1.evaluate(batch)?; - let then_value = Scalar::new(then_value.into_array(1)?); - - let Some(e) = self.else_expr() else { - return internal_err!("expression did not evaluate to an array"); - }; - // keep `else_expr`'s data type and return type consistent - 
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?; - let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); - Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) - } - - fn expr_or_expr(&self, batch: &RecordBatch) -> Result { - let return_type = self.data_type(&batch.schema())?; - - // evaluate when condition on batch - let when_value = self.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a BooleanArray".to_string(), - Box::new(e), - ) - })?; - - // For the true and false/null selection vectors, bypass `evaluate_selection` and merging - // results. This avoids materializing the array for the other branch which we will discard - // entirely anyway. - let true_count = when_value.true_count(); - if true_count == batch.num_rows() { - return self.when_then_expr[0].1.evaluate(batch); - } else if true_count == 0 { - return self.else_expr.as_ref().unwrap().evaluate(batch); + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CASE ")?; + if let Some(e) = &self.expr { + e.fmt_sql(f)?; + write!(f, " ")?; + } + + for (w, t) in &self.when_then_expr { + write!(f, "WHEN ")?; + w.fmt_sql(f)?; + write!(f, " THEN ")?; + t.fmt_sql(f)?; + write!(f, " ")?; + } + + if let Some(e) = &self.else_expr { + write!(f, "ELSE ")?; + e.fmt_sql(f)?; + write!(f, " ")?; + } + write!(f, "END") } +} - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; +/// Create a CASE expression +pub fn case( + expr: Option>, + when_thens: Vec, + else_expr: Option>, +) -> Result> { + Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) +} - let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, &when_value)? 
- .into_array(batch.num_rows())?; +#[cfg(test)] +mod tests { + use super::*; + + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use arrow::buffer::Buffer; + use arrow::datatypes::DataType::Float64; + use arrow::datatypes::Field; + use datafusion_common::cast::{as_float64_array, as_int32_array}; + use datafusion_common::plan_err; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + use datafusion_expr::type_coercion::binary::comparison_coercion; + use datafusion_expr::Operator; + use datafusion_physical_expr_common::physical_expr::fmt_sql; + + #[test] + fn case_with_expr() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; - // evaluate else expression on the values not covered by when_value - let remainder = not(&when_value)?; - let e = self.else_expr.as_ref().unwrap(); + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - let else_ = expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; + assert_eq!(expected, result); - Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) - } + Ok(()) + } - fn with_lookup_table(&self, batch: &RecordBatch, scalars_or_null_lookup: &LiteralLookupTable) -> Result { - let expr = self.expr.as_ref().unwrap(); - let evaluated_expression = expr.evaluate(batch)?; + #[test] + fn case_with_expr_else() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } - let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); - let evaluated_expression = evaluated_expression.to_array(1)?; + #[test] + fn case_with_expr_divide_by_zero() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END + let when1 = lit(0i32); + let then1 = lit(ScalarValue::Float64(None)); + let else_value = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &batch.schema(), Float64)?, + &batch.schema(), + )?; + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + + assert_eq!(expected, result); + + Ok(()) + } - let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; + #[test] + fn case_without_expr() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; - let result = if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) - } else { - ColumnarValue::Array(output) - }; + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - Ok(result) - } -} + assert_eq!(expected, result); -impl PhysicalExpr for CaseExpr { - /// Return a reference to Any that can be used for down-casting - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> Result { - // since all then results have the same data type, we can choose any one as the - // return data type except for the null. 
- let mut data_type = DataType::Null; - for i in 0..self.when_then_expr.len() { - data_type = self.when_then_expr[i].1.data_type(input_schema)?; - if !data_type.equals_datatype(&DataType::Null) { - break; - } + Ok(()) } - // if all then results are null, we use data type of else expr instead if possible. - if data_type.equals_datatype(&DataType::Null) { - if let Some(e) = &self.else_expr { - data_type = e.data_type(input_schema)?; - } + + #[test] + fn case_with_expr_when_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END + let when1 = lit(ScalarValue::Utf8(None)); + let then1 = lit(0i32); + let when2 = col("a", &schema)?; + let then2 = lit(123i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); + + assert_eq!(expected, result); + + Ok(()) } - Ok(data_type) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = self - .when_then_expr - .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) - } else if let Some(e) = &self.else_expr { - e.nullable(input_schema) - } else { - // CASE produces NULL if there is no `else` expr - // (aka when none of the `when_then_exprs` match) - Ok(true) + #[test] + fn case_without_expr_divide_by_zero() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END + let when1 = binary(col("a", 
&schema)?, Operator::Gt, lit(0i32), &batch.schema())?; + let then1 = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &batch.schema(), Float64)?, + &batch.schema(), + )?; + let x = lit(ScalarValue::Float64(None)); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1)], + Some(x), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + + assert_eq!(expected, result); + + Ok(()) } - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - match self.eval_method { - EvalMethod::WithExpression => { - // this use case evaluates "expr" and then compares the values with the "when" - // values - self.case_when_with_expr(batch) - } - EvalMethod::NoExpression => { - // The "when" conditions all evaluate to boolean in this use case and can be - // arbitrary expressions - self.case_when_no_expr(batch) - } - EvalMethod::InfallibleExprOrNull => { - // Specialization for CASE WHEN expr THEN column [ELSE NULL] END - self.case_column_or_null(batch) - } - EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), - EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), - EvalMethod::WithExprScalarLookupTable(ref e) => self.with_lookup_table(batch, e), + + fn case_test_batch1() -> Result { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); + let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]); + let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + )?; + 
Ok(batch) } - } - fn children(&self) -> Vec<&Arc> { - let mut children = vec![]; - if let Some(expr) = &self.expr { - children.push(expr) + #[test] + fn case_without_expr_else() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); + + assert_eq!(expected, result); + + Ok(()) } - self.when_then_expr.iter().for_each(|(cond, value)| { - children.push(cond); - children.push(value); - }); - if let Some(else_expr) = &self.else_expr { - children.push(else_expr) + #[test] + fn case_with_type_cast() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then = lit(123.3f64); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); + + assert_eq!(expected, result); + + Ok(()) } - children - } - - // For physical CaseExpr, we do not allow modifying children size - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - if children.len() != self.children().len() { - internal_err!("CaseExpr: Wrong number of children") - } else { - let (expr, when_then_expr, else_expr) = - match (self.expr().is_some(), self.else_expr().is_some()) { - (true, true) => ( - Some(&children[0]), - &children[1..children.len() - 1], - Some(&children[children.len() - 1]), - ), - (true, false) => { - (Some(&children[0]), &children[1..children.len()], None) - } - (false, true) => ( + + #[test] + fn case_with_matches_and_nulls() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE WHEN load4 = 1.77 THEN load4 END + let when = binary( + col("load4", &schema)?, + Operator::Eq, + lit(1.77f64), + &batch.schema(), + )?; + let then = col("load4", &schema)?; + + let expr = generate_case_when_with_type_coercion( None, - &children[0..children.len() - 1], - Some(&children[children.len() - 1]), - ), - (false, false) => (None, &children[0..children.len()], None), - }; - Ok(Arc::new(CaseExpr::try_new( - expr.cloned(), - when_then_expr.iter().cloned().tuples().collect(), - else_expr.cloned(), - )?)) + vec![(when, then)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); + + assert_eq!(expected, result); + + Ok(()) } - } - fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "CASE ")?; - if let Some(e) = &self.expr { - e.fmt_sql(f)?; - write!(f, " ")?; + #[test] + fn case_with_scalar_predicate() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE WHEN TRUE THEN load4 END + let when = lit(true); + let then = col("load4", &schema)?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + None, + schema.as_ref(), + )?; + + // many rows + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + let expected = &Float64Array::from(vec![ + Some(1.77), + None, + None, + Some(1.78), + None, + Some(1.77), + ]); + assert_eq!(expected, result); + + // one row + let expected = Float64Array::from(vec![Some(1.1)]); + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + assert_eq!(&expected, result); + + Ok(()) } - for (w, t) in &self.when_then_expr { - write!(f, "WHEN ")?; - w.fmt_sql(f)?; - write!(f, " THEN ")?; - t.fmt_sql(f)?; - write!(f, " ")?; + #[test] + fn case_expr_matches_and_nulls() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE load4 WHEN 1.77 THEN load4 END + let expr = col("load4", &schema)?; + let when = lit(1.77f64); + let then = col("load4", &schema)?; + + let expr = generate_case_when_with_type_coercion( + Some(expr), + vec![(when, then)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); + + assert_eq!(expected, result); + + Ok(()) } - if let Some(e) = &self.else_expr { - write!(f, "ELSE ")?; - e.fmt_sql(f)?; - write!(f, " ")?; + #[test] + fn test_when_null_and_some_cond_else_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + let when = binary( + Arc::new(Literal::new(ScalarValue::Boolean(None))), + Operator::And, + binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, + &schema, + )?; + let then = col("a", &schema)?; + + // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END + let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_string_array(&result); + + // all result values should be null + assert_eq!(result.logical_null_count(), batch.num_rows()); + Ok(()) } - write!(f, "END") - } -} -/// Create a CASE expression -pub fn case( - expr: Option>, - when_thens: Vec, - else_expr: Option>, -) -> Result> { - Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) -} + fn case_test_batch() -> Result { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + Ok(batch) + } -#[cfg(test)] -mod tests { - use super::*; - - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; - use arrow::buffer::Buffer; - use arrow::datatypes::DataType::Float64; - use arrow::datatypes::Field; - use datafusion_common::cast::{as_float64_array, as_int32_array}; - use datafusion_common::plan_err; - use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; - use datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; - use datafusion_physical_expr_common::physical_expr::fmt_sql; - - #[test] - fn case_with_expr() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_else() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_divide_by_zero() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - - // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END - let when1 = lit(0i32); - let then1 = lit(ScalarValue::Float64(None)); - let else_value = binary( - lit(25.0f64), - Operator::Divide, - cast(col("a", &schema)?, &batch.schema(), Float64)?, - &batch.schema(), - )?; - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_without_expr() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_expr_when_null() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END - let when1 = lit(ScalarValue::Utf8(None)); - let then1 = lit(0i32); - let when2 = col("a", &schema)?; - let then2 = lit(123i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_without_expr_divide_by_zero() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - - // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END - let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?; - let then1 = binary( - lit(25.0f64), - Operator::Divide, - cast(col("a", &schema)?, &batch.schema(), Float64)?, - &batch.schema(), - )?; - let x = lit(ScalarValue::Float64(None)); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1)], - Some(x), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - fn case_test_batch1() -> Result { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); - let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]); - let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]); - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], - )?; - Ok(batch) - } - - #[test] - fn case_without_expr_else() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - 
&batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = - &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_type_cast() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END - let when = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then = lit(123.3f64); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - Some(else_value), - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_matches_and_nulls() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE WHEN load4 = 1.77 THEN load4 END - let when = binary( - col("load4", &schema)?, - Operator::Eq, - lit(1.77f64), - &batch.schema(), - )?; - let then = col("load4", &schema)?; - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn case_with_scalar_predicate() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE WHEN TRUE THEN load4 END - let when = lit(true); - let then = col("load4", &schema)?; - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - None, - schema.as_ref(), - )?; - - // many rows - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - let expected = &Float64Array::from(vec![ - Some(1.77), - None, - None, - Some(1.78), - None, - Some(1.77), - ]); - assert_eq!(expected, result); - - // one row - let expected = Float64Array::from(vec![Some(1.1)]); - let batch = - RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?; - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - assert_eq!(&expected, result); - - Ok(()) - } - - #[test] - fn case_expr_matches_and_nulls() -> Result<()> { - let batch = case_test_batch_nulls()?; - let schema = batch.schema(); - - // SELECT CASE load4 WHEN 1.77 THEN load4 END - let expr = col("load4", &schema)?; - let when = lit(1.77f64); - let then = col("load4", &schema)?; - - let expr = generate_case_when_with_type_coercion( - Some(expr), - vec![(when, then)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = - as_float64_array(&result).expect("failed to downcast to Float64Array"); - - let expected = - &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]); - - assert_eq!(expected, result); - - Ok(()) - } - - #[test] - fn test_when_null_and_some_cond_else_null() -> Result<()> { - let batch = case_test_batch()?; - let schema = batch.schema(); - - let when = binary( - Arc::new(Literal::new(ScalarValue::Boolean(None))), - Operator::And, - binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, - &schema, - )?; - let then = col("a", &schema)?; - - // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END - let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_string_array(&result); - - // all result values should be null - assert_eq!(result.logical_null_count(), batch.num_rows()); - Ok(()) - } - - fn case_test_batch() -> Result { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - Ok(batch) - } - - // Construct an array that has several NULL values whose - // underlying buffer actually matches the where expr predicate - fn case_test_batch_nulls() -> Result { - let load4: Float64Array = vec![ - Some(1.77), // 1.77 - Some(1.77), // null <-- same value, but will be set to null - Some(1.77), // null <-- same value, but will be set to null - Some(1.78), // 1.78 - None, // null - Some(1.77), // 1.77 - ] - .into_iter() - .collect(); - - let null_buffer = Buffer::from([0b00101001u8]); - let load4 = load4 - .into_data() - .into_builder() - .null_bit_buffer(Some(null_buffer)) - .build() - .unwrap(); - let load4: Float64Array = load4.into(); - - 
let batch = - RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; - Ok(batch) - } - - #[test] - fn case_test_incompatible() -> Result<()> { - // 1 then is int64 - // 2 then is boolean - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(true); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - ); - assert!(expr.is_err()); - - // then 1 is int32 - // then 2 is int64 - // else is float - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END - let when1 = binary( - col("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - col("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i64); - let else_expr = lit(1.23f64); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - Some(else_expr), - schema.as_ref(), - ); - assert!(expr.is_ok()); - let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(Float64, result_type); - Ok(()) - } - - #[test] - fn case_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr1 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr2 = generate_case_when_with_type_coercion( - 
Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr3 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], - None, - &schema, - )?; - - let expr4 = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![(when1, then1)], - Some(else_value), - &schema, - )?; - - assert!(expr1.eq(&expr2)); - assert!(expr2.eq(&expr1)); - - assert!(expr2.ne(&expr3)); - assert!(expr3.ne(&expr2)); - - assert!(expr1.ne(&expr4)); - assert!(expr4.ne(&expr1)); - - Ok(()) - } - - #[test] - fn case_transform() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - - let when1 = lit("foo"); - let then1 = lit(123i32); - let when2 = lit("bar"); - let then2 = lit(456i32); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - Some(col("a", &schema)?), - vec![ - (Arc::clone(&when1), Arc::clone(&then1)), - (Arc::clone(&when2), Arc::clone(&then2)), - ], - Some(Arc::clone(&else_value)), - &schema, - )?; - - let expr2 = Arc::clone(&expr) - .transform(|e| { - let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) + // Construct an array that has several NULL values whose + // underlying buffer actually matches the where expr predicate + fn case_test_batch_nulls() -> Result { + let load4: Float64Array = vec![ + Some(1.77), // 1.77 + Some(1.77), // null <-- same value, but will be set to null + Some(1.77), // null <-- same value, but will be set to null + Some(1.78), // 1.78 + None, // null + Some(1.77), // 1.77 + ] + .into_iter() + .collect(); + + let null_buffer = Buffer::from([0b00101001u8]); + let load4 = load4 + .into_data() + .into_builder() + 
.null_bit_buffer(Some(null_buffer)) + .build() + .unwrap(); + let load4: Float64Array = load4.into(); + + let batch = + RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?; + Ok(batch) + } + + #[test] + fn case_test_incompatible() -> Result<()> { + // 1 then is int64 + // 2 then is boolean + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(true); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + ); + assert!(expr.is_err()); + + // then 1 is int32 + // then 2 is int64 + // else is float + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END + let when1 = binary( + col("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + col("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i64); + let else_expr = lit(1.23f64); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + Some(else_expr), + schema.as_ref(), + ); + assert!(expr.is_ok()); + let result_type = expr.unwrap().data_type(schema.as_ref())?; + assert_eq!(Float64, result_type); + Ok(()) + } + + #[test] + fn case_eq() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr1 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + 
Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr2 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr3 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], + None, + &schema, + )?; + + let expr4 = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1)], + Some(else_value), + &schema, + )?; + + assert!(expr1.eq(&expr2)); + assert!(expr2.eq(&expr1)); + + assert!(expr2.ne(&expr3)); + assert!(expr3.ne(&expr2)); + + assert!(expr1.ne(&expr4)); + assert!(expr4.ne(&expr1)); + + Ok(()) + } + + #[test] + fn case_transform() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![ + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), + ], + Some(Arc::clone(&else_value)), + &schema, + )?; + + let expr2 = Arc::clone(&expr) + .transform(|e| { + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if let Some(transformed) = transformed { + Transformed::yes(transformed) + } else { + Transformed::no(e) + }) + }) + .data() + .unwrap(); + + let expr3 = Arc::clone(&expr) + .transform_down(|e| { + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } + _ => None, + }, + _ => None, + }; + Ok(if 
let Some(transformed) = transformed { + Transformed::yes(transformed) + } else { + Transformed::no(e) + }) + }) + .data() + .unwrap(); + + assert!(expr.ne(&expr2)); + assert!(expr2.eq(&expr3)); + + Ok(()) + } + + #[test] + fn test_column_or_null_specialization() -> Result<()> { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(format!("string {i}")); } - _ => None, - }, - _ => None, - }; - Ok(if let Some(transformed) = transformed { - Transformed::yes(transformed) - } else { - Transformed::no(e) - }) - }) - .data() - .unwrap(); - - let expr3 = Arc::clone(&expr) - .transform_down(|e| { - let transformed = match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // CaseWhenExprOrNull should produce same results as CaseExpr + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(250), + )); + let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; + assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + match expr.evaluate(&batch)? 
{ + ColumnarValue::Array(array) => { + assert_eq!(1000, array.len()); + assert_eq!(785, array.null_count()); } - _ => None, - }, - _ => None, - }; - Ok(if let Some(transformed) = transformed { - Transformed::yes(transformed) - } else { - Transformed::no(e) - }) - }) - .data() - .unwrap(); - - assert!(expr.ne(&expr2)); - assert!(expr2.eq(&expr3)); - - Ok(()) - } - - #[test] - fn test_column_or_null_specialization() -> Result<()> { - // create input data - let mut c1 = Int32Builder::new(); - let mut c2 = StringBuilder::new(); - for i in 0..1000 { - c1.append_value(i); - if i % 7 == 0 { - c2.append_null(); - } else { - c2.append_value(format!("string {i}")); - } - } - let c1 = Arc::new(c1.finish()); - let c2 = Arc::new(c2.finish()); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); - - // CaseWhenExprOrNull should produce same results as CaseExpr - let predicate = Arc::new(BinaryExpr::new( - make_col("c1", 0), - Operator::LtEq, - make_lit_i32(250), - )); - let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; - assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); - match expr.evaluate(&batch)? 
{ - ColumnarValue::Array(array) => { - assert_eq!(1000, array.len()); - assert_eq!(785, array.null_count()); - } - _ => unreachable!(), + _ => unreachable!(), + } + Ok(()) } - Ok(()) - } - - #[test] - fn test_expr_or_expr_specialization() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - let when = binary( - col("a", &schema)?, - Operator::LtEq, - lit(2i32), - &batch.schema(), - )?; - let then = col("b", &schema)?; - let else_expr = col("c", &schema)?; - let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; - assert!(matches!( + + #[test] + fn test_expr_or_expr_specialization() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + let when = binary( + col("a", &schema)?, + Operator::LtEq, + lit(2i32), + &batch.schema(), + )?; + let then = col("b", &schema)?; + let else_expr = col("c", &schema)?; + let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; + assert!(matches!( expr.eval_method, EvalMethod::ExpressionOrExpression )); - let result = expr - .evaluate(&batch)? - .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); - let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]); + let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]); - assert_eq!(expected, result); - Ok(()) - } + assert_eq!(expected, result); + Ok(()) + } - fn make_col(name: &str, index: usize) -> Arc { - Arc::new(Column::new(name, index)) - } + fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } - fn make_lit_i32(n: i32) -> Arc { - Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) - } + fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + } - fn generate_case_when_with_type_coercion( - expr: Option>, - when_thens: Vec, - else_expr: Option>, - input_schema: &Schema, - ) -> Result> { - let coerce_type = - get_case_common_type(&when_thens, else_expr.clone(), input_schema); - let (when_thens, else_expr) = match coerce_type { + fn generate_case_when_with_type_coercion( + expr: Option>, + when_thens: Vec, + else_expr: Option>, + input_schema: &Schema, + ) -> Result> { + let coerce_type = + get_case_common_type(&when_thens, else_expr.clone(), input_schema); + let (when_thens, else_expr) = match coerce_type { None => plan_err!( "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" ), @@ -1433,65 +1447,65 @@ mod tests { Ok((left, right)) } }?; - case(expr, when_thens, else_expr) - } + case(expr, when_thens, else_expr) + } - fn get_case_common_type( - when_thens: &[WhenThen], - else_expr: Option>, - input_schema: &Schema, - ) -> Option { - let thens_type = when_thens - .iter() - .map(|when_then| { - let data_type = &when_then.1.data_type(input_schema).unwrap(); - data_type.clone() - }) - .collect::>(); - let else_type = match else_expr { - None => { - // case when then exprs must have one then value - thens_type[0].clone() - } - 
Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), - }; - thens_type - .iter() - .try_fold(else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - comparison_coercion(&left_type, right_type) - }) - } - - #[test] - fn test_fmt_sql() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - - // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END - let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?; - let then = lit(123.3f64); - let else_value = lit(999i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when, then)], - Some(else_value), - &schema, - )?; - - let display_string = expr.to_string(); - assert_eq!( - display_string, - "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" - ); - - let sql_string = fmt_sql(expr.as_ref()).to_string(); - assert_eq!( - sql_string, - "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" - ); - - Ok(()) - } + fn get_case_common_type( + when_thens: &[WhenThen], + else_expr: Option>, + input_schema: &Schema, + ) -> Option { + let thens_type = when_thens + .iter() + .map(|when_then| { + let data_type = &when_then.1.data_type(input_schema).unwrap(); + data_type.clone() + }) + .collect::>(); + let else_type = match else_expr { + None => { + // case when then exprs must have one then value + thens_type[0].clone() + } + Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), + }; + thens_type + .iter() + .try_fold(else_type, |left_type, right_type| { + // TODO: now just use the `equal` coercion rule for case when. If find the issue, and + // refactor again. 
+ comparison_coercion(&left_type, right_type) + }) + } + + #[test] + fn test_fmt_sql() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?; + let then = lit(123.3f64); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + Some(else_value), + &schema, + )?; + + let display_string = expr.to_string(); + assert_eq!( + display_string, + "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" + ); + + let sql_string = fmt_sql(expr.as_ref()).to_string(); + assert_eq!( + sql_string, + "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END" + ); + + Ok(()) + } } From 1787d54ca71f17dbba749d46de72ca9f6d27ac64 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:06:55 +0300 Subject: [PATCH 08/22] add benchmarks for lookup --- datafusion/physical-expr/benches/case_when.rs | 305 +++++++++++++++++- .../physical-expr/src/expressions/case/mod.rs | 2 + 2 files changed, 304 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index ec850047e586..414f2e9dfdf2 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,13 +15,24 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder}; -use arrow::datatypes::{Field, Schema}; +use arrow::array::{ + Array, ArrayRef, Int32Array, Int32Builder, + StringArray, +}; +use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use arrow::util::test_util::seedable_rng; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, RngCore}; +use std::fmt::{Display, Formatter}; +use std::ops::Range; use std::sync::Arc; fn make_x_cmp_y( @@ -82,6 +93,8 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 3)); run_benchmarks(c, &make_batch(8192, 50)); run_benchmarks(c, &make_batch(8192, 100)); + + benchmark_lookup_table_case_when(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -228,6 +241,292 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { ); b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); + + // Lookup table with literal values + // when(col("platform_type") === 1, "Desktop") + // .when(col("platform_type") === 2, "iPhone") + // .when(col("platform_type") === 3, "iPad") + // .when(col("platform_type") === 4, "Android") + // .when(col("platform_type") === 5, "iPod") + // .when(col("platform_type") === 6, "Mobile-Other") + // .when(col("platform_type") === 7, "Android-Tablet") + // .when(col("platform_type") === 30, "Email") + // .when(col("platform_type") === 31, "Email") + // .when(col("platform_type") === 90, "Facebook") + // .when(col("platform_type") === 91, "Facebook") + // .when(col("platform_type") === 92, 
"Facebook") + // .when(col("platform_type") === 93, "Facebook") + // .when(col("platform_type") === 120, "API-Feed") + // .when(col("platform_type") === 240, "Web-group") + // .when(col("platform_type") === 241, "Mobile") + // .when(col("platform_type") === 244, "Facebook-group") + // .otherwise("Else") + // .alias("platform_type"), +} + +struct Options { + number_of_rows: usize, + range_of_values: Vec, + in_range_probability: f32, + null_probability: f32, +} + +fn generate_other_primitive_value( + rng: &mut impl RngCore, + exclude: &[T], +) -> T { + let mut value; + let retry_limit = 100; + for _ in 0..retry_limit { + value = rng.random_range(T::MIN_TOTAL_ORDER..=T::MAX_TOTAL_ORDER); + if !exclude.contains(&value) { + return value; + } + } + + panic!( + "Could not generate out of range value after {retry_limit} attempts" + ); +} + +fn create_random_string_generator( + length: Range, +) -> impl Fn(&mut dyn RngCore, &[String]) -> String { + assert!(length.end > length.start); + + move |rng, exclude| { + let retry_limit = 100; + for _ in 0..retry_limit { + let length = rng.random_range(length.clone()); + let value: String = rng + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + + if !exclude.contains(&value) { + return value; + } + } + + panic!( + "Could not generate out of range value after {retry_limit} attempts" + ); + } +} + +/// Create column with the provided number of rows +/// `in_range_percentage` is the percentage of values that should be inside the specified range +/// `null_percentage` is the percentage of null values +/// The rest of the values will be outside the specified range +fn generate_values_for_lookup( + options: Options, + generate_other_value: impl Fn(&mut StdRng, &[T]) -> T, +) -> A +where + T: Clone, + A: FromIterator>, +{ + // Create a value with specified range most of the time, but also some nulls and the rest is generic + + assert!( + options.in_range_probability + options.null_probability <= 1.0, + 
"Percentages must sum to 1.0 or less" + ); + + let rng = &mut seedable_rng(); + + let in_range_probability = 0.0..options.in_range_probability; + let null_range_probability = in_range_probability.start + ..in_range_probability.start + options.null_probability; + let out_range_probability = null_range_probability.end..1.0; + + (0..options.number_of_rows) + .map(|_| { + let roll: f32 = rng.random(); + + match roll { + v if out_range_probability.contains(&v) => { + let index = rng.random_range(0..options.range_of_values.len()); + // Generate value in range + Some(options.range_of_values[index].clone()) + } + v if null_range_probability.contains(&v) => None, + _ => { + // Generate value out of range + Some(generate_other_value(rng, &options.range_of_values)) + } + } + }) + .collect::() +} + +fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { + #[derive(Clone, Copy, Debug)] + struct CaseWhenLookupInput { + batch_size: usize, + + in_range_probability: f32, + null_probability: f32, + } + + impl Display for CaseWhenLookupInput { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "case_when {} rows: in_range: {}, nulls: {}", + self.batch_size, self.in_range_probability, self.null_probability, + ) + } + } + + let mut case_when_lookup = c.benchmark_group("lookup_table_case_when"); + + for in_range_probability in [0.1, 0.5, 0.9, 1.0] { + for null_probability in [0.0, 0.1, 0.5] { + if in_range_probability + null_probability > 1.0 { + continue; + } + + let input = CaseWhenLookupInput { + batch_size, + in_range_probability, + null_probability, + }; + + let when_thens_primitive_to_string = vec![ + (1, "something"), + (2, "very"), + (3, "interesting"), + (4, "is"), + (5, "going"), + (6, "to"), + (7, "happen"), + (30, "in"), + (31, "datafusion"), + (90, "when"), + (91, "you"), + (92, "find"), + (93, "it"), + (120, "let"), + (240, "me"), + (241, "know"), + (244, "please"), + (246, "thank"), + (250, "you"), + ]; + let 
when_thens_string_to_primitive = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (value, key)) + .collect_vec(); + + for num_entries in [5, 10, 20] { + let when_thens_primitive_to_string = + when_thens_primitive_to_string[..num_entries].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!("case when i32 -> utf8, {num_entries} entries"), + input, + ), + &input, + |b, input| { + let array: Int32Array = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_primitive_to_string + .iter() + .map(|(key, _)| *key) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + generate_other_primitive_value::(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit("whatever")), + ) + .unwrap(), + ); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + + let when_thens_string_to_primitive = + when_thens_string_to_primitive[..num_entries].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!("case when utf8 -> i32, {num_entries} entries"), + input, + ), + &input, + |b, input| { + let array: StringArray = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_string_to_primitive + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + create_random_string_generator(3..10)(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( 
+ Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_string_to_primitive + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit(1000)), + ) + .unwrap(), + ); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + } + } + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 3dccdfbb594e..d117ebb63b32 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -1508,4 +1508,6 @@ mod tests { Ok(()) } + + // TODO - add tests for case when with lookup table specialization } From b2d4b5195f57a345a00b3bf7e425efbdd0086ef2 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:09:32 +0300 Subject: [PATCH 09/22] fix test --- datafusion/physical-expr/benches/case_when.rs | 1 + .../case/literal_lookup_table/mod.rs | 22 ------------------- .../physical-expr/src/expressions/case/mod.rs | 16 ++++++++++++++ 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 414f2e9dfdf2..c40d0fd984ae 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -414,6 +414,7 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { (244, "please"), (246, "thank"), (250, "you"), + (252, "!"), ]; let when_thens_string_to_primitive = when_thens_primitive_to_string .iter() diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs 
b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index b46707a10063..3776a7543b12 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -67,28 +67,6 @@ pub(in super::super) struct LiteralLookupTable { values_to_take_from: ArrayRef, } -impl Hash for LiteralLookupTable { - fn hash(&self, state: &mut H) { - // Hashing the pointer as this is the best we can do here - - let lookup_ptr = Arc::as_ptr(&self.lookup); - lookup_ptr.hash(state); - - let values_ptr = Arc::as_ptr(&self.lookup); - values_ptr.hash(state); - } -} - -impl PartialEq for LiteralLookupTable { - fn eq(&self, other: &Self) -> bool { - // Comparing the pointers as this is the best we can do here - Arc::ptr_eq(&self.lookup, &other.lookup) - && self.values_to_take_from.as_ref() == other.values_to_take_from.as_ref() - } -} - -impl Eq for LiteralLookupTable {} - impl LiteralLookupTable { pub(in super::super) fn maybe_new( when_then_expr: &Vec, diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index d117ebb63b32..31490e30fab7 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -75,6 +75,22 @@ enum EvalMethod { WithExprScalarLookupTable(LiteralLookupTable), } + +// Implement empty hash as the data is derived from PhysicalExprs which are already hashed +impl Hash for LiteralLookupTable { + fn hash(&self, _state: &mut H) { + } +} + +// Implement always equal as the data is derived from PhysicalExprs which are already compared +impl PartialEq for LiteralLookupTable { + fn eq(&self, other: &Self) -> bool { + true + } +} + +impl Eq for LiteralLookupTable {} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. 
The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. From 81d82e36231f3324f39011849a51e9aa16c03079 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:10:21 +0300 Subject: [PATCH 10/22] remove --- datafusion/physical-expr/benches/case_when.rs | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index c40d0fd984ae..d3961649de52 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -90,9 +90,9 @@ fn make_batch(row_count: usize, column_count: usize) -> RecordBatch { } fn criterion_benchmark(c: &mut Criterion) { - run_benchmarks(c, &make_batch(8192, 3)); - run_benchmarks(c, &make_batch(8192, 50)); - run_benchmarks(c, &make_batch(8192, 100)); + // run_benchmarks(c, &make_batch(8192, 3)); + // run_benchmarks(c, &make_batch(8192, 50)); + // run_benchmarks(c, &make_batch(8192, 100)); benchmark_lookup_table_case_when(c, 8192); } @@ -241,27 +241,6 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { ); b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); - - // Lookup table with literal values - // when(col("platform_type") === 1, "Desktop") - // .when(col("platform_type") === 2, "iPhone") - // .when(col("platform_type") === 3, "iPad") - // .when(col("platform_type") === 4, "Android") - // .when(col("platform_type") === 5, "iPod") - // .when(col("platform_type") === 6, "Mobile-Other") - // .when(col("platform_type") === 7, "Android-Tablet") - // .when(col("platform_type") === 30, "Email") - // .when(col("platform_type") === 31, "Email") - // .when(col("platform_type") === 90, "Facebook") - // .when(col("platform_type") === 91, "Facebook") - // .when(col("platform_type") === 92, "Facebook") - // 
.when(col("platform_type") === 93, "Facebook") - // .when(col("platform_type") === 120, "API-Feed") - // .when(col("platform_type") === 240, "Web-group") - // .when(col("platform_type") === 241, "Mobile") - // .when(col("platform_type") === 244, "Facebook-group") - // .otherwise("Else") - // .alias("platform_type"), } struct Options { From a7f54b4131415407c866388270a147957e521179 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:20:16 +0300 Subject: [PATCH 11/22] remove --- datafusion/physical-expr/benches/case_when.rs | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index d3961649de52..4f06066e72c6 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use arrow::array::{ Array, ArrayRef, Int32Array, Int32Builder, StringArray, @@ -247,6 +248,13 @@ struct Options { number_of_rows: usize, range_of_values: Vec, in_range_probability: f32, + + /// (value index, probability) + /// Used to weight the selection of in-range values + /// If empty, all in-range values are equally likely + /// the rest of the in-range values will have equal probability + /// the sum of all probabilities must be less than or equal to 1.0 + value_probability: Vec<(usize, f32)>, null_probability: f32, } @@ -313,6 +321,15 @@ where "Percentages must sum to 1.0 or less" ); + let total_value_probability: f32 = options + .value_probability + .iter() + .map(|(_, p)| *p) + .sum(); + assert!(total_value_probability <= 1.0, "Value probabilities must sum to 1.0 or less"); + options.value_probability.into_iter().collect::>() + let filled_value_probability = + let rng = &mut seedable_rng(); let in_range_probability = 
0.0..options.in_range_probability; @@ -326,6 +343,10 @@ where match roll { v if out_range_probability.contains(&v) => { + if options.value_probability.is_empty() { + // No values in range, generate any value + Some(generate_other_value(rng, &[])) + } let index = rng.random_range(0..options.range_of_values.len()); // Generate value in range Some(options.range_of_values[index].clone()) From 67933888007c5cd98db5cf940ceadc1144285c6c Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:37:00 +0300 Subject: [PATCH 12/22] cleanup --- datafusion/physical-expr/benches/case_when.rs | 229 +++++++++--------- 1 file changed, 108 insertions(+), 121 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 4f06066e72c6..4f9868d9703e 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::collections::HashMap; use arrow::array::{ Array, ArrayRef, Int32Array, Int32Builder, StringArray, @@ -248,13 +247,6 @@ struct Options { number_of_rows: usize, range_of_values: Vec, in_range_probability: f32, - - /// (value index, probability) - /// Used to weight the selection of in-range values - /// If empty, all in-range values are equally likely - /// the rest of the in-range values will have equal probability - /// the sum of all probabilities must be less than or equal to 1.0 - value_probability: Vec<(usize, f32)>, null_probability: f32, } @@ -321,15 +313,6 @@ where "Percentages must sum to 1.0 or less" ); - let total_value_probability: f32 = options - .value_probability - .iter() - .map(|(_, p)| *p) - .sum(); - assert!(total_value_probability <= 1.0, "Value probabilities must sum to 1.0 or less"); - options.value_probability.into_iter().collect::>() - let filled_value_probability = - let rng = &mut seedable_rng(); let in_range_probability = 0.0..options.in_range_probability; @@ -343,10 +326,6 @@ where match roll { v if out_range_probability.contains(&v) => { - if options.value_probability.is_empty() { - // No values in range, generate any value - Some(generate_other_value(rng, &[])) - } let index = rng.random_range(0..options.range_of_values.len()); // Generate value in range Some(options.range_of_values[index].clone()) @@ -422,109 +401,117 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { .collect_vec(); for num_entries in [5, 10, 20] { - let when_thens_primitive_to_string = - when_thens_primitive_to_string[..num_entries].to_vec(); - - case_when_lookup.bench_with_input( - BenchmarkId::new( - format!("case when i32 -> utf8, {num_entries} entries"), - input, - ), - &input, - |b, input| { - let array: Int32Array = generate_values_for_lookup( - Options:: { - number_of_rows: batch_size, - range_of_values: when_thens_primitive_to_string - .iter() - .map(|(key, _)| *key) - .collect(), - in_range_probability: 
input.in_range_probability, - null_probability: input.null_probability, - }, - |rng, exclude| { - generate_other_primitive_value::(rng, exclude) - }, - ); - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new( - "col1", - array.data_type().clone(), - true, - )])), - vec![Arc::new(array)], - ) - .unwrap(); - - let when_thens = when_thens_primitive_to_string - .iter() - .map(|&(key, value)| (lit(key), lit(value))) - .collect(); - - let expr = Arc::new( - case( - Some(col("col1", batch.schema_ref()).unwrap()), - when_thens, - Some(lit("whatever")), + + for (name, values_range) in [ + ("all equally true", 0..num_entries), + ("only first 2 are true", 0..2), + ] { + + let when_thens_primitive_to_string = + when_thens_primitive_to_string[values_range.clone()].to_vec(); + + let when_thens_string_to_primitive = + when_thens_string_to_primitive[values_range].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!("case when i32 -> utf8, {num_entries} entries, {name}"), + input, + ), + &input, + |b, input| { + let array: Int32Array = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_primitive_to_string + .iter() + .map(|(key, _)| *key) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + generate_other_primitive_value::(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], ) - .unwrap(), - ); - - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) - }, - ); - - let when_thens_string_to_primitive = - when_thens_string_to_primitive[..num_entries].to_vec(); - - case_when_lookup.bench_with_input( - BenchmarkId::new( - format!("case when utf8 -> i32, {num_entries} entries"), - input, - ), - &input, - |b, input| { - let array: StringArray = generate_values_for_lookup( - 
Options:: { - number_of_rows: batch_size, - range_of_values: when_thens_string_to_primitive - .iter() - .map(|(key, _)| (*key).to_string()) - .collect(), - in_range_probability: input.in_range_probability, - null_probability: input.null_probability, - }, - |rng, exclude| { - create_random_string_generator(3..10)(rng, exclude) - }, - ); - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new( - "col1", - array.data_type().clone(), - true, - )])), - vec![Arc::new(array)], - ) - .unwrap(); - - let when_thens = when_thens_string_to_primitive - .iter() - .map(|&(key, value)| (lit(key), lit(value))) - .collect(); - - let expr = Arc::new( - case( - Some(col("col1", batch.schema_ref()).unwrap()), - when_thens, - Some(lit(1000)), + .unwrap(); + + let when_thens = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit("whatever")), + ) + .unwrap(), + ); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!("case when utf8 -> i32, {num_entries} entries, {name}"), + input, + ), + &input, + |b, input| { + let array: StringArray = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_string_to_primitive + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + create_random_string_generator(3..10)(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], ) - .unwrap(), - ); + .unwrap(); + + let when_thens = when_thens_string_to_primitive + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = 
Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit(1000)), + ) + .unwrap(), + ); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + } - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) - }, - ); } } } From d9448251b317c122071750e421165d559a0e83c2 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:39:48 +0300 Subject: [PATCH 13/22] format and lint --- datafusion/physical-expr/benches/case_when.rs | 78 +++++++++---------- .../case/literal_lookup_table/mod.rs | 1 - .../physical-expr/src/expressions/case/mod.rs | 6 +- 3 files changed, 40 insertions(+), 45 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 4f9868d9703e..7cd43e7d4735 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{ - Array, ArrayRef, Int32Array, Int32Builder, - StringArray, -}; +use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder, StringArray}; use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::seedable_rng; @@ -90,9 +87,9 @@ fn make_batch(row_count: usize, column_count: usize) -> RecordBatch { } fn criterion_benchmark(c: &mut Criterion) { - // run_benchmarks(c, &make_batch(8192, 3)); - // run_benchmarks(c, &make_batch(8192, 50)); - // run_benchmarks(c, &make_batch(8192, 100)); + run_benchmarks(c, &make_batch(8192, 3)); + run_benchmarks(c, &make_batch(8192, 50)); + run_benchmarks(c, &make_batch(8192, 100)); benchmark_lookup_table_case_when(c, 8192); } @@ -263,9 +260,7 @@ fn generate_other_primitive_value( } } - panic!( - "Could not generate out of range value after {retry_limit} attempts" - ); + panic!("Could not generate out of range value after {retry_limit} attempts"); } fn create_random_string_generator( @@ -288,9 +283,7 @@ fn create_random_string_generator( } } - panic!( - "Could not generate out of range value after {retry_limit} attempts" - ); + panic!("Could not generate out of range value after {retry_limit} attempts"); } } @@ -316,8 +309,8 @@ where let rng = &mut seedable_rng(); let in_range_probability = 0.0..options.in_range_probability; - let null_range_probability = in_range_probability.start - ..in_range_probability.start + options.null_probability; + let null_range_probability = + in_range_probability.start..in_range_probability.start + options.null_probability; let out_range_probability = null_range_probability.end..1.0; (0..options.number_of_rows) @@ -401,21 +394,21 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { .collect_vec(); for num_entries in [5, 10, 20] { - for (name, values_range) in [ ("all equally true", 0..num_entries), ("only first 2 are true", 0..2), ] { - let when_thens_primitive_to_string = - 
when_thens_primitive_to_string[values_range.clone()].to_vec(); + when_thens_primitive_to_string[values_range.clone()].to_vec(); let when_thens_string_to_primitive = - when_thens_string_to_primitive[values_range].to_vec(); + when_thens_string_to_primitive[values_range].to_vec(); case_when_lookup.bench_with_input( BenchmarkId::new( - format!("case when i32 -> utf8, {num_entries} entries, {name}"), + format!( + "case when i32 -> utf8, {num_entries} entries, {name}" + ), input, ), &input, @@ -424,9 +417,9 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { Options:: { number_of_rows: batch_size, range_of_values: when_thens_primitive_to_string - .iter() - .map(|(key, _)| *key) - .collect(), + .iter() + .map(|(key, _)| *key) + .collect(), in_range_probability: input.in_range_probability, null_probability: input.null_probability, }, @@ -442,12 +435,12 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { )])), vec![Arc::new(array)], ) - .unwrap(); + .unwrap(); let when_thens = when_thens_primitive_to_string - .iter() - .map(|&(key, value)| (lit(key), lit(value))) - .collect(); + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); let expr = Arc::new( case( @@ -455,16 +448,20 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { when_thens, Some(lit("whatever")), ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) }, ); case_when_lookup.bench_with_input( BenchmarkId::new( - format!("case when utf8 -> i32, {num_entries} entries, {name}"), + format!( + "case when utf8 -> i32, {num_entries} entries, {name}" + ), input, ), &input, @@ -473,9 +470,9 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { Options:: { number_of_rows: batch_size, range_of_values: when_thens_string_to_primitive - .iter() - .map(|(key, _)| (*key).to_string()) - 
.collect(), + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), in_range_probability: input.in_range_probability, null_probability: input.null_probability, }, @@ -491,12 +488,12 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { )])), vec![Arc::new(array)], ) - .unwrap(); + .unwrap(); let when_thens = when_thens_string_to_primitive - .iter() - .map(|&(key, value)| (lit(key), lit(value))) - .collect(); + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); let expr = Arc::new( case( @@ -504,14 +501,15 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { when_thens, Some(lit(1000)), ) - .unwrap(), + .unwrap(), ); - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) }, ); } - } } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 3776a7543b12..fe66ba75ebea 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -20,7 +20,6 @@ use datafusion_common::DataFusionError; use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::fmt::Debug; -use std::hash::Hash; use std::sync::Arc; /// Optimization for CASE expressions with literal WHEN and THEN clauses diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 31490e30fab7..04d1646dfddf 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -75,16 +75,14 @@ enum EvalMethod { WithExprScalarLookupTable(LiteralLookupTable), } - // Implement empty hash as the data is derived from PhysicalExprs which 
are already hashed impl Hash for LiteralLookupTable { - fn hash(&self, _state: &mut H) { - } + fn hash(&self, _state: &mut H) {} } // Implement always equal as the data is derived from PhysicalExprs which are already compared impl PartialEq for LiteralLookupTable { - fn eq(&self, other: &Self) -> bool { + fn eq(&self, _other: &Self) -> bool { true } } From d3d5a32e5bbf8f860fc228d640f06e9c8b8db358 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:41:53 +0300 Subject: [PATCH 14/22] bench: create benchmark for lookup table like case when --- datafusion/physical-expr/benches/case_when.rs | 291 +++++++++++++++++- 1 file changed, 288 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index ec850047e586..7cd43e7d4735 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder}; -use arrow::datatypes::{Field, Schema}; +use arrow::array::{Array, ArrayRef, Int32Array, Int32Builder, StringArray}; +use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use arrow::util::test_util::seedable_rng; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, RngCore}; +use std::fmt::{Display, Formatter}; +use std::ops::Range; use std::sync::Arc; fn make_x_cmp_y( @@ -82,6 +90,8 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 3)); run_benchmarks(c, &make_batch(8192, 50)); run_benchmarks(c, &make_batch(8192, 100)); + + benchmark_lookup_table_case_when(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -230,5 +240,280 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { }); } +struct Options { + number_of_rows: usize, + range_of_values: Vec, + in_range_probability: f32, + null_probability: f32, +} + +fn generate_other_primitive_value( + rng: &mut impl RngCore, + exclude: &[T], +) -> T { + let mut value; + let retry_limit = 100; + for _ in 0..retry_limit { + value = rng.random_range(T::MIN_TOTAL_ORDER..=T::MAX_TOTAL_ORDER); + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); +} + +fn create_random_string_generator( + length: Range, +) -> impl Fn(&mut dyn RngCore, &[String]) -> String { + assert!(length.end > length.start); + + move |rng, exclude| { + let retry_limit = 100; + for _ in 
0..retry_limit { + let length = rng.random_range(length.clone()); + let value: String = rng + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + + if !exclude.contains(&value) { + return value; + } + } + + panic!("Could not generate out of range value after {retry_limit} attempts"); + } +} + +/// Create column with the provided number of rows +/// `in_range_percentage` is the percentage of values that should be inside the specified range +/// `null_percentage` is the percentage of null values +/// The rest of the values will be outside the specified range +fn generate_values_for_lookup( + options: Options, + generate_other_value: impl Fn(&mut StdRng, &[T]) -> T, +) -> A +where + T: Clone, + A: FromIterator>, +{ + // Create a value with specified range most of the time, but also some nulls and the rest is generic + + assert!( + options.in_range_probability + options.null_probability <= 1.0, + "Percentages must sum to 1.0 or less" + ); + + let rng = &mut seedable_rng(); + + let in_range_probability = 0.0..options.in_range_probability; + let null_range_probability = + in_range_probability.start..in_range_probability.start + options.null_probability; + let out_range_probability = null_range_probability.end..1.0; + + (0..options.number_of_rows) + .map(|_| { + let roll: f32 = rng.random(); + + match roll { + v if out_range_probability.contains(&v) => { + let index = rng.random_range(0..options.range_of_values.len()); + // Generate value in range + Some(options.range_of_values[index].clone()) + } + v if null_range_probability.contains(&v) => None, + _ => { + // Generate value out of range + Some(generate_other_value(rng, &options.range_of_values)) + } + } + }) + .collect::() +} + +fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { + #[derive(Clone, Copy, Debug)] + struct CaseWhenLookupInput { + batch_size: usize, + + in_range_probability: f32, + null_probability: f32, + } + + impl Display for CaseWhenLookupInput { + fn 
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "case_when {} rows: in_range: {}, nulls: {}", + self.batch_size, self.in_range_probability, self.null_probability, + ) + } + } + + let mut case_when_lookup = c.benchmark_group("lookup_table_case_when"); + + for in_range_probability in [0.1, 0.5, 0.9, 1.0] { + for null_probability in [0.0, 0.1, 0.5] { + if in_range_probability + null_probability > 1.0 { + continue; + } + + let input = CaseWhenLookupInput { + batch_size, + in_range_probability, + null_probability, + }; + + let when_thens_primitive_to_string = vec![ + (1, "something"), + (2, "very"), + (3, "interesting"), + (4, "is"), + (5, "going"), + (6, "to"), + (7, "happen"), + (30, "in"), + (31, "datafusion"), + (90, "when"), + (91, "you"), + (92, "find"), + (93, "it"), + (120, "let"), + (240, "me"), + (241, "know"), + (244, "please"), + (246, "thank"), + (250, "you"), + (252, "!"), + ]; + let when_thens_string_to_primitive = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (value, key)) + .collect_vec(); + + for num_entries in [5, 10, 20] { + for (name, values_range) in [ + ("all equally true", 0..num_entries), + ("only first 2 are true", 0..2), + ] { + let when_thens_primitive_to_string = + when_thens_primitive_to_string[values_range.clone()].to_vec(); + + let when_thens_string_to_primitive = + when_thens_string_to_primitive[values_range].to_vec(); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when i32 -> utf8, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: Int32Array = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_primitive_to_string + .iter() + .map(|(key, _)| *key) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + generate_other_primitive_value::(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + 
Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_primitive_to_string + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit("whatever")), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + + case_when_lookup.bench_with_input( + BenchmarkId::new( + format!( + "case when utf8 -> i32, {num_entries} entries, {name}" + ), + input, + ), + &input, + |b, input| { + let array: StringArray = generate_values_for_lookup( + Options:: { + number_of_rows: batch_size, + range_of_values: when_thens_string_to_primitive + .iter() + .map(|(key, _)| (*key).to_string()) + .collect(), + in_range_probability: input.in_range_probability, + null_probability: input.null_probability, + }, + |rng, exclude| { + create_random_string_generator(3..10)(rng, exclude) + }, + ); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "col1", + array.data_type().clone(), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(); + + let when_thens = when_thens_string_to_primitive + .iter() + .map(|&(key, value)| (lit(key), lit(value))) + .collect(); + + let expr = Arc::new( + case( + Some(col("col1", batch.schema_ref()).unwrap()), + when_thens, + Some(lit(1000)), + ) + .unwrap(), + ); + + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()) + }) + }, + ); + } + } + } + } +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches); From 154710b2f610c718df35509c692d707204e7cb8b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:47:30 +0300 Subject: [PATCH 15/22] added comment --- datafusion/physical-expr/benches/case_when.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 7cd43e7d4735..78895298e08a 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -396,6 +396,8 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { for num_entries in [5, 10, 20] { for (name, values_range) in [ ("all equally true", 0..num_entries), + + // Test when early termination is beneficial ("only first 2 are true", 0..2), ] { let when_thens_primitive_to_string = From 3a8db91e407400a54fe4c448fe3e4c6f65b3c7a5 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:59:53 +0300 Subject: [PATCH 16/22] format --- datafusion/physical-expr/benches/case_when.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 78895298e08a..e52aeb1aee12 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -396,8 +396,7 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { for num_entries in [5, 10, 20] { for (name, values_range) in [ ("all equally true", 0..num_entries), - - // Test when early termination is beneficial + // Test when early termination is beneficial ("only first 2 are true", 0..2), ] { let when_thens_primitive_to_string = From e8f5cbb4b64cf0ac80becb08286b57c607074421 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 20:35:31 +0300 Subject: [PATCH 17/22] only keep first occurrence --- .../case/literal_lookup_table/mod.rs | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 
fe66ba75ebea..7225695bf38b 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -19,6 +19,7 @@ use arrow::datatypes::{ use datafusion_common::DataFusionError; use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use indexmap::IndexMap; use std::fmt::Debug; use std::sync::Arc; @@ -96,15 +97,28 @@ impl LiteralLookupTable { return None; } - let (when_literals, then_literals): (Vec, Vec) = - when_then_exprs_maybe_literals - .iter() - // Unwrap the options as we have already checked they are all Some - .flatten() - .map(|(when_lit, then_lit)| { - (when_lit.value().clone(), then_lit.value().clone()) - }) - .unzip(); + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have already checked they are all Some + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + .collect::>(); + + // Keep only the first occurrence of each when literal + let (when_literals, then_literals): (Vec, Vec) = { + let mut map = IndexMap::with_capacity(when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; let else_expr: ScalarValue = if let Some(else_expr) = else_expr { let literal = else_expr.as_any().downcast_ref::()?; From a803734b85e3abaa1d155fbe0c73688b010d6158 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 20:43:33 +0300 Subject: [PATCH 18/22] add license header --- .../boolean_lookup_table.rs | 17 +++++++++++++++++ .../bytes_like_lookup_table.rs | 17 +++++++++++++++++ .../case/literal_lookup_table/mod.rs | 17 
+++++++++++++++++ .../primitive_lookup_table.rs | 17 +++++++++++++++++ 4 files changed, 68 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs index b85f877339a1..40efc9755649 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ArrayRef, AsArray}; use datafusion_common::ScalarValue; diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index 969309716c63..b912b9f15320 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 7225695bf38b..75f7e66a378d 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ mod boolean_lookup_table; mod bytes_like_lookup_table; mod primitive_lookup_table; diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs index 86d707b1d9ca..8d75e2350cbd 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; From ca059da4c1e5a6a9e623145fdaacb1a28d9cc8d6 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Tue, 21 Oct 2025 21:27:56 +0300 Subject: [PATCH 19/22] fix doc --- .../src/expressions/case/literal_lookup_table/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 75f7e66a378d..b9547b18170e 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -79,7 +79,7 @@ pub(in super::super) struct LiteralLookupTable { /// The lookup table to use for evaluating the CASE expression lookup: Arc, - /// ArrayRef where array[i] = then_literals[i] + /// [`ArrayRef`] where `array[i] = then_literals[i]` /// the last value in the array is the else_expr values_to_take_from: ArrayRef, } From 4ebacb4723b92288a0072987a019800b8344c930 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:38:38 +0300 Subject: [PATCH 20/22] fix null handling in WHEN --- .../boolean_lookup_table.rs | 56 ++++++--- .../bytes_like_lookup_table.rs | 44 +++----- .../case/literal_lookup_table/mod.rs | 106 +++++++++++++----- .../primitive_lookup_table.rs | 34 +++--- .../physical-expr/src/expressions/case/mod.rs | 2 - 5 files changed, 153 insertions(+), 89 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs index 40efc9755649..4c47e904d78d 100644 --- 
a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -17,39 +17,61 @@ use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; use arrow::array::{ArrayRef, AsArray}; -use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, ScalarValue}; #[derive(Clone, Debug)] pub(super) struct BooleanIndexMap { true_index: i32, false_index: i32, - null_index: i32, + else_index: i32, } impl WhenLiteralIndexMap for BooleanIndexMap { fn try_new( - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> datafusion_common::Result where Self: Sized, { - fn get_first_index( - literals: &[ScalarValue], - target: Option, - ) -> Option { - literals - .iter() - .position( - |literal| matches!(literal, ScalarValue::Boolean(v) if v == &target), - ) - .map(|pos| pos as i32) + let mut true_index: Option = None; + let mut false_index: Option = None; + + for (index, literal) in unique_non_null_literals.into_iter().enumerate() { + match literal { + ScalarValue::Boolean(Some(true)) => { + if true_index.is_some() { + return internal_err!( + "Duplicate true literal found in literals for BooleanIndexMap" + ); + } + true_index = Some(index as i32); + } + ScalarValue::Boolean(Some(false)) => { + if false_index.is_some() { + return internal_err!( + "Duplicate false literal found in literals for BooleanIndexMap" + ); + } + false_index = Some(index as i32); + } + ScalarValue::Boolean(None) => { + return internal_err!( + "Null literal found in non-null literals for BooleanIndexMap" + ) + } + _ => { + return internal_err!( + "Non-boolean literal found in literals for BooleanIndexMap" + ) + } + } } Ok(Self { - false_index: get_first_index(&literals, Some(false)).unwrap_or(else_index), - true_index: get_first_index(&literals, Some(true)).unwrap_or(else_index), - null_index: get_first_index(&literals, 
None).unwrap_or(else_index), + true_index: true_index.unwrap_or(else_index), + false_index: false_index.unwrap_or(else_index), + else_index, }) } @@ -60,7 +82,7 @@ impl WhenLiteralIndexMap for BooleanIndexMap { .map(|value| match value { Some(true) => self.true_index, Some(false) => self.false_index, - None => self.null_index, + None => self.else_index, }) .collect::>()) } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index b912b9f15320..6e9a465db7c5 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -21,7 +21,7 @@ use arrow::array::{ GenericByteArray, GenericByteViewArray, TypedDictionaryArray, }; use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; -use datafusion_common::{exec_datafusion_err, HashMap, ScalarValue}; +use datafusion_common::{exec_datafusion_err, internal_err, HashMap, ScalarValue}; use std::fmt::Debug; use std::iter::Map; use std::marker::PhantomData; @@ -195,9 +195,6 @@ pub(super) struct BytesLikeIndexMap { /// Map from non-null literal value the first occurrence index in the literals map: HashMap, i32>, - /// The index for null literal value (when no null value this will equal to `else_index`) - null_index: i32, - /// The index to return when no match is found else_index: i32, @@ -208,7 +205,6 @@ impl Debug for BytesLikeIndexMap { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("BytesMapHelper") .field("map", &self.map) - .field("null_index", &self.null_index) .field("else_index", &self.else_index) .finish() } @@ -218,37 +214,31 @@ impl WhenLiteralIndexMap for BytesLikeIndexMap { fn try_new( - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> 
datafusion_common::Result where Self: Sized, { - let input = ScalarValue::iter_to_array(literals)?; - let bytes_iter = Helper::array_to_iter(&input)?; + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; - let mut null_index = None; - - let mut map: HashMap, i32> = HashMap::new(); - - for (map_index, value) in bytes_iter.enumerate() { - match value { - Some(value) => { - // Insert only the first occurrence - map.entry(value.to_vec()).or_insert(map_index as i32); - } - None => { - // Only set the null index once - if null_index.is_none() { - null_index = Some(map_index as i32); - } - } - } + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); } + let bytes_iter = Helper::array_to_iter(&input)?; + + let map: HashMap, i32> = bytes_iter + // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value): (usize, &[u8])| (value.to_vec(), map_index as i32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + Ok(Self { map, - null_index: null_index.unwrap_or(else_index), else_index, _phantom_data: Default::default(), }) @@ -259,7 +249,7 @@ impl WhenLiteralIndexMap let indices = bytes_iter .map(|value| match value { Some(value) => self.map.get(value).copied().unwrap_or(self.else_index), - None => self.null_index, + None => self.else_index, }) .collect::>(); diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index b9547b18170e..b7e3d63954bf 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -99,6 +99,7 @@ impl LiteralLookupTable { return None; } + // Try to downcast all the WHEN/THEN expressions 
to literals let when_then_exprs_maybe_literals = when_then_expr .iter() .map(|(when, then)| { @@ -109,26 +110,37 @@ impl LiteralLookupTable { }) .collect::>(); - // If not all the when/then expressions are literals we cannot use this optimization + // If not all the WHEN/THEN expressions are literals we cannot use this optimization if when_then_exprs_maybe_literals.contains(&None) { return None; } let when_then_exprs_scalars = when_then_exprs_maybe_literals .into_iter() - // Unwrap the options as we have already checked they are all Some + // Unwrap the options as we have already checked there is no None .flatten() .map(|(when_lit, then_lit)| { (when_lit.value().clone(), then_lit.value().clone()) }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) .collect::>(); - // Keep only the first occurrence of each when literal + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... 
END always goes to ELSE) let (when_literals, then_literals): (Vec, Vec) = { let mut map = IndexMap::with_capacity(when_then_expr.len()); for (when, then) in when_then_exprs_scalars.into_iter() { - // Don't overwrite existing entries to keep the first occurrence + // Don't overwrite existing entries as we want to keep the first occurrence if !map.contains_key(&when) { map.insert(when, then); } @@ -154,7 +166,7 @@ impl LiteralLookupTable { { let data_type = when_literals[0].data_type(); - // If not all the when literals are the same data type we cannot use this optimization + // If not all the WHEN literals are the same data type we cannot use this optimization if when_literals.iter().any(|l| l.data_type() != data_type) { return None; } @@ -219,8 +231,10 @@ impl LiteralLookupTable { /// The else index is used when a value is not found in the lookup table pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { /// Try creating a new lookup table from the given literals and else index + /// + /// `literals` are guaranteed to be unique and non-nullable fn try_new( - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> datafusion_common::Result where @@ -231,21 +245,28 @@ pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { } pub(crate) fn try_creating_lookup_table( - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> datafusion_common::Result> { - assert_ne!(literals.len(), 0, "Must have at least one literal"); - match literals[0].data_type() { + assert_ne!( + unique_non_null_literals.len(), + 0, + "Must have at least one literal" + ); + match unique_non_null_literals[0].data_type() { DataType::Boolean => { - let lookup_table = BooleanIndexMap::try_new(literals, else_index)?; + let lookup_table = + BooleanIndexMap::try_new(unique_non_null_literals, else_index)?; Ok(Arc::new(lookup_table)) } data_type if data_type.is_primitive() => { macro_rules! 
create_matching_map { ($t:ty) => {{ - let lookup_table = - PrimitiveArrayMapHolder::<$t>::try_new(literals, else_index)?; + let lookup_table = PrimitiveArrayMapHolder::<$t>::try_new( + unique_non_null_literals, + else_index, + )?; Ok(Arc::new(lookup_table)) }}; } @@ -262,48 +283,60 @@ pub(crate) fn try_creating_lookup_table( DataType::Utf8 => { let lookup_table = BytesLikeIndexMap::< GenericBytesHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::LargeUtf8 => { let lookup_table = BytesLikeIndexMap::< GenericBytesHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::Binary => { let lookup_table = BytesLikeIndexMap::< GenericBytesHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::LargeBinary => { let lookup_table = BytesLikeIndexMap::< GenericBytesHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::FixedSizeBinary(_) => { - let lookup_table = - BytesLikeIndexMap::::try_new(literals, else_index)?; + let lookup_table = BytesLikeIndexMap::::try_new( + unique_non_null_literals, + else_index, + )?; Ok(Arc::new(lookup_table)) } DataType::Utf8View => { let lookup_table = BytesLikeIndexMap::>::try_new( - literals, else_index, + unique_non_null_literals, + else_index, )?; Ok(Arc::new(lookup_table)) } DataType::BinaryView => { let lookup_table = BytesLikeIndexMap::>::try_new( - literals, else_index, + unique_non_null_literals, + else_index, )?; Ok(Arc::new(lookup_table)) } @@ -313,7 +346,7 @@ pub(crate) fn try_creating_lookup_table( ($t:ty) => {{ create_lookup_table_for_dictionary_input::<$t>( value.as_ref(), - literals, + unique_non_null_literals, else_index, ) }}; @@ -326,14 +359,14 @@ pub(crate) fn 
try_creating_lookup_table( } _ => Err(plan_datafusion_err!( "Unsupported data type for lookup table: {}", - literals[0].data_type() + unique_non_null_literals[0].data_type() )), } } fn create_lookup_table_for_dictionary_input( value: &DataType, - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> datafusion_common::Result> { // TODO - optimize dictionary to use different wrapper that takes advantage of it being a dictionary @@ -341,35 +374,44 @@ fn create_lookup_table_for_dictionary_input { let lookup_table = BytesLikeIndexMap::< BytesDictionaryHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::LargeUtf8 => { let lookup_table = BytesLikeIndexMap::< BytesDictionaryHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::Binary => { let lookup_table = BytesLikeIndexMap::< BytesDictionaryHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::LargeBinary => { let lookup_table = BytesLikeIndexMap::< BytesDictionaryHelper>, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::FixedSizeBinary(_) => { let lookup_table = BytesLikeIndexMap::>::try_new( - literals, else_index, + unique_non_null_literals, + else_index, )?; Ok(Arc::new(lookup_table)) } @@ -377,13 +419,17 @@ fn create_lookup_table_for_dictionary_input { let lookup_table = BytesLikeIndexMap::< BytesViewDictionaryHelper, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; Ok(Arc::new(lookup_table)) } DataType::BinaryView => { let lookup_table = BytesLikeIndexMap::< BytesViewDictionaryHelper, - >::try_new(literals, else_index)?; + >::try_new( + unique_non_null_literals, else_index + )?; 
Ok(Arc::new(lookup_table)) } _ => Err(plan_datafusion_err!( diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs index 8d75e2350cbd..bc466b1a4dbe 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -16,9 +16,9 @@ // under the License. use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; -use arrow::array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; +use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; -use datafusion_common::{HashMap, ScalarValue}; +use datafusion_common::{internal_err, HashMap, ScalarValue}; use half::f16; use std::fmt::Debug; use std::hash::Hash; @@ -32,7 +32,7 @@ where /// Literal value to map index /// /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps - map: HashMap::HashableKey>, i32>, + map: HashMap<::HashableKey, i32>, else_index: i32, } @@ -55,21 +55,26 @@ where T::Native: ToHashableKey, { fn try_new( - literals: Vec, + unique_non_null_literals: Vec, else_index: i32, ) -> datafusion_common::Result where Self: Sized, { - let input = ScalarValue::iter_to_array(literals)?; + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } let map = input .as_primitive::() - .into_iter() + .values() + .iter() .enumerate() - .map(|(map_index, value)| { - (value.map(|v| v.into_hashable_key()), map_index as i32) - }) + // Because literals are unique we can collect directly, and we can avoid 
only inserting the first occurrence + .map(|(map_index, value)| (value.into_hashable_key(), map_index as i32)) .collect(); Ok(Self { map, else_index }) @@ -79,11 +84,14 @@ where let indices = array .as_primitive::() .into_iter() - .map(|value| { - self.map - .get(&value.map(|item| item.into_hashable_key())) + .map(|value| match value { + Some(value) => self + .map + .get(&value.into_hashable_key()) .copied() - .unwrap_or(self.else_index) + .unwrap_or(self.else_index), + + None => self.else_index, }) .collect::>(); diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 04d1646dfddf..52db5eb88770 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -1522,6 +1522,4 @@ mod tests { Ok(()) } - - // TODO - add tests for case when with lookup table specialization } From 934f7838f29a17bf12b711ee6c363f8e960a4933 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:42:02 +0300 Subject: [PATCH 21/22] revert format --- .../physical-expr/src/expressions/case/mod.rs | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/mod.rs b/datafusion/physical-expr/src/expressions/case/mod.rs index 52db5eb88770..66830c7ebe8b 100644 --- a/datafusion/physical-expr/src/expressions/case/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -500,7 +500,6 @@ impl CaseExpr { // evaluate else expression on the values not covered by when_value let remainder = not(&when_value)?; let e = self.else_expr.as_ref().unwrap(); - // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) .unwrap_or_else(|_| Arc::clone(e)); @@ -1441,26 +1440,26 @@ mod tests { let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), 
input_schema); let (when_thens, else_expr) = match coerce_type { - None => plan_err!( + None => plan_err!( "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" ), - Some(data_type) => { - // cast then expr - let left = when_thens - .into_iter() - .map(|(when, then)| { - let then = try_cast(then, input_schema, data_type.clone())?; - Ok((when, then)) - }) - .collect::>>()?; - let right = match else_expr { - None => None, - Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), - }; + Some(data_type) => { + // cast then expr + let left = when_thens + .into_iter() + .map(|(when, then)| { + let then = try_cast(then, input_schema, data_type.clone())?; + Ok((when, then)) + }) + .collect::>>()?; + let right = match else_expr { + None => None, + Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), + }; - Ok((left, right)) - } - }?; + Ok((left, right)) + } + }?; case(expr, when_thens, else_expr) } From 601bcf3a8ff4dd333b9f1cdc6373c45b40437338 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:45:12 +0300 Subject: [PATCH 22/22] format --- .../bytes_like_lookup_table.rs | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs index 6e9a465db7c5..10cf34ca7d95 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -105,10 +105,10 @@ where .downcast_dict::>() .ok_or_else(|| { exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary value {}", - array.data_type(), - Value::DATA_TYPE - ) + "Failed to downcast dictionary array {} to expected dictionary value 
{}", + array.data_type(), + Value::DATA_TYPE + ) })?; Ok(dict_array.into_iter().map(|item| { @@ -137,14 +137,12 @@ where fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", - array.data_type() - ) - })?; + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", + array.data_type() + ))?; Ok(dict_array.into_iter()) } @@ -171,10 +169,10 @@ where .downcast_dict::>() .ok_or_else(|| { exec_datafusion_err!( - "Failed to downcast dictionary array {} to expected dictionary value {}", - array.data_type(), - Value::DATA_TYPE - ) + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) })?; Ok(dict_array.into_iter().map(|item| {