diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index c202be6b6299..e45b817dc6e8 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -18,10 +18,21 @@ //! [`zip`]: Combine values from two arrays based on boolean mask use crate::filter::{SlicesIterator, prep_null_mask_filter}; +use arrow_array::cast::AsArray; +use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type}; use arrow_array::*; -use arrow_buffer::BooleanBuffer; +use arrow_buffer::{ + BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder, + ScalarBuffer, +}; +use arrow_data::ArrayData; use arrow_data::transform::MutableArrayData; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType}; +use std::fmt::{Debug, Formatter}; +use std::hash::Hash; +use std::marker::PhantomData; +use std::ops::Not; +use std::sync::Arc; /// Zip two arrays by some boolean mask. /// @@ -87,8 +98,16 @@ pub fn zip( truthy: &dyn Datum, falsy: &dyn Datum, ) -> Result { - let (truthy, truthy_is_scalar) = truthy.get(); - let (falsy, falsy_is_scalar) = falsy.get(); + let (truthy_array, truthy_is_scalar) = truthy.get(); + let (falsy_array, falsy_is_scalar) = falsy.get(); + + if falsy_is_scalar && truthy_is_scalar { + let zipper = ScalarZipper::try_new(truthy, falsy)?; + return zipper.zip_impl.create_output(mask); + } + + let truthy = truthy_array; + let falsy = falsy_array; if truthy.data_type() != falsy.data_type() { return Err(ArrowError::InvalidArgumentError( @@ -120,7 +139,17 @@ pub fn zip( let falsy = falsy.to_data(); let truthy = truthy.to_data(); - let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len()); + zip_impl(mask, &truthy, truthy_is_scalar, &falsy, falsy_is_scalar) +} + +fn zip_impl( + mask: &BooleanArray, + truthy: &ArrayData, + truthy_is_scalar: bool, + falsy: &ArrayData, + falsy_is_scalar: bool, +) -> Result { + let mut mutable = MutableArrayData::new(vec![truthy, falsy], false, truthy.len()); // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to // fill with falsy values @@ -128,8 +157,8 @@ pub fn zip( // keep track of how much is filled let mut filled = 0; - let mask = maybe_prep_null_mask_filter(mask); - SlicesIterator::from(&mask).for_each(|(start, end)| { + let mask_buffer = maybe_prep_null_mask_filter(mask); + SlicesIterator::from(&mask_buffer).for_each(|(start, end)| { // the gap needs to be filled with falsy values if start > filled { if falsy_is_scalar { @@ -168,6 +197,455 @@ pub fn zip( Ok(make_array(data)) } +/// Zipper for 2 scalars +/// +/// Useful for using in `IF THEN ELSE END` expressions +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array, Scalar, cast::AsArray, types::Int32Type}; +/// +/// # use arrow_select::zip::ScalarZipper; +/// let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); +/// let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); +/// let zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap(); +/// +/// // Later when we have a boolean mask +/// let mask = BooleanArray::from(vec![true, false, true, false, true]); +/// let result = zipper.zip(&mask).unwrap(); +/// let actual = result.as_primitive::(); +/// let expected = Int32Array::from(vec![Some(42), Some(123), Some(42), Some(123), Some(42)]); +/// ``` +/// +#[derive(Debug, Clone)] +pub struct ScalarZipper { + zip_impl: Arc, +} + +impl ScalarZipper { + /// Try to create a new ScalarZipper from two scalar Datum + /// + /// # Errors + /// returns error if: + /// - the two Datum have different data types + /// - either Datum is not a scalar (or has more than 1 element) + /// + pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result { + let (truthy, truthy_is_scalar) = truthy.get(); + let (falsy, falsy_is_scalar) = falsy.get(); + + if truthy.data_type() != falsy.data_type() { + return Err(ArrowError::InvalidArgumentError( + "arguments need to have the same data type".into(), + )); + } + + if !truthy_is_scalar { + return Err(ArrowError::InvalidArgumentError( + "only scalar arrays are supported".into(), + )); + } + + if !falsy_is_scalar { + return Err(ArrowError::InvalidArgumentError( + "only scalar arrays are supported".into(), + )); + } + + if truthy.len() != 1 { + return Err(ArrowError::InvalidArgumentError( + "scalar arrays must have 1 element".into(), + )); + } + if falsy.len() != 1 { + return Err(ArrowError::InvalidArgumentError( + "scalar arrays must have 1 element".into(), + )); + } + + macro_rules! primitive_size_helper { + ($t:ty) => { + Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as Arc + }; + } + + let zip_impl = downcast_primitive! { + truthy.data_type() => (primitive_size_helper), + DataType::Utf8 => { + Arc::new(BytesScalarImpl::::new(truthy, falsy)) as Arc + }, + DataType::LargeUtf8 => { + Arc::new(BytesScalarImpl::::new(truthy, falsy)) as Arc + }, + DataType::Binary => { + Arc::new(BytesScalarImpl::::new(truthy, falsy)) as Arc + }, + DataType::LargeBinary => { + Arc::new(BytesScalarImpl::::new(truthy, falsy)) as Arc + }, + // TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724 + _ => { + Arc::new(FallbackImpl::new(truthy, falsy)) as Arc + }, + }; + + Ok(Self { zip_impl }) + } + + /// Creating output array based on input boolean array and the two scalar values the zipper was created with + /// See struct level documentation for examples. + pub fn zip(&self, mask: &BooleanArray) -> Result { + self.zip_impl.create_output(mask) + } +} + +/// Impl for creating output array based on a mask +trait ZipImpl: Debug + Send + Sync { + /// Creating output array based on input boolean array + fn create_output(&self, input: &BooleanArray) -> Result; +} + +#[derive(Debug, PartialEq)] +struct FallbackImpl { + truthy: ArrayData, + falsy: ArrayData, +} + +impl FallbackImpl { + fn new(left: &dyn Array, right: &dyn Array) -> Self { + Self { + truthy: left.to_data(), + falsy: right.to_data(), + } + } +} + +impl ZipImpl for FallbackImpl { + fn create_output(&self, predicate: &BooleanArray) -> Result { + zip_impl(predicate, &self.truthy, true, &self.falsy, true) + } +} + +struct PrimitiveScalarImpl { + data_type: DataType, + truthy: Option, + falsy: Option, +} + +impl Debug for PrimitiveScalarImpl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveScalarImpl") + .field("data_type", &self.data_type) + .field("truthy", &self.truthy) + .field("falsy", &self.falsy) + .finish() + } +} + +impl PrimitiveScalarImpl { + fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self { + Self { + data_type: truthy.data_type().clone(), + truthy: Self::get_value_from_scalar(truthy), + falsy: Self::get_value_from_scalar(falsy), + } + } + + fn get_value_from_scalar(scalar: &dyn Array) -> Option { + if scalar.is_null(0) { + None + } else { + let value = scalar.as_primitive::().value(0); + + Some(value) + } + } + + /// return an output array that has + /// `value` in all locations where predicate is true + /// `null` otherwise + fn get_scalar_and_null_buffer_for_single_non_nullable( + predicate: BooleanBuffer, + value: T::Native, + ) -> (Vec, Option) { + let result_len = predicate.len(); + let nulls = NullBuffer::new(predicate); + let scalars = vec![value; result_len]; + + (scalars, Some(nulls)) + } +} + +impl ZipImpl for PrimitiveScalarImpl { + fn create_output(&self, predicate: &BooleanArray) -> Result { + let result_len = predicate.len(); + // Nulls are treated as false + let predicate = maybe_prep_null_mask_filter(predicate); + + let (scalars, nulls): (Vec, Option) = match (self.truthy, self.falsy) + { + (Some(truthy_val), Some(falsy_val)) => { + let scalars: Vec = predicate + .iter() + .map(|b| if b { truthy_val } else { falsy_val }) + .collect(); + + (scalars, None) + } + (Some(truthy_val), None) => { + // If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null) + // If a value is false we need the FALSY and the null buffer will have 0 (meaning null) + + Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val) + } + (None, Some(falsy_val)) => { + // Flipping the boolean buffer as we want the opposite of the TRUE case + // + // if the condition is true we want null so we need to NOT the value so we get 0 (meaning null) + // if the condition is false we want the FALSY value so we need to NOT the value so we get 1 (meaning not null) + let predicate = predicate.not(); + + Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val) + } + (None, None) => { + // All values are null + let nulls = NullBuffer::new_null(result_len); + let scalars = vec![T::default_value(); result_len]; + + (scalars, Some(nulls)) + } + }; + + let scalars = ScalarBuffer::::from(scalars); + let output = PrimitiveArray::::try_new(scalars, nulls)?; + + // Keep decimal precisions, scales or timestamps timezones + let output = output.with_data_type(self.data_type.clone()); + + Ok(Arc::new(output)) + } +} + +#[derive(PartialEq, Hash)] +struct BytesScalarImpl { + truthy: Option>, + falsy: Option>, + phantom: PhantomData, +} + +impl Debug for BytesScalarImpl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BytesScalarImpl") + .field("truthy", &self.truthy) + .field("falsy", &self.falsy) + .finish() + } +} + +impl BytesScalarImpl { + fn new(truthy_value: &dyn Array, falsy_value: &dyn Array) -> Self { + Self { + truthy: Self::get_value_from_scalar(truthy_value), + falsy: Self::get_value_from_scalar(falsy_value), + phantom: PhantomData, + } + } + + fn get_value_from_scalar(scalar: &dyn Array) -> Option> { + if scalar.is_null(0) { + None + } else { + let bytes: &[u8] = scalar.as_bytes::().value(0).as_ref(); + + Some(bytes.to_vec()) + } + } + + /// return an output array that has + /// `value` in all locations where predicate is true + /// `null` otherwise + fn get_scalar_and_null_buffer_for_single_non_nullable( + predicate: BooleanBuffer, + value: &[u8], + ) -> (Buffer, OffsetBuffer, Option) { + let value_length = value.len(); + + let number_of_true = predicate.count_set_bits(); + + // Fast path for all nulls + if number_of_true == 0 { + // All values are null + let nulls = NullBuffer::new_null(predicate.len()); + + return ( + // Empty bytes + Buffer::from(&[]), + // All nulls so all lengths are 0 + OffsetBuffer::::new_zeroed(predicate.len()), + Some(nulls), + ); + } + + let offsets = OffsetBuffer::::from_lengths( + predicate.iter().map(|b| if b { value_length } else { 0 }), + ); + + let mut bytes = MutableBuffer::with_capacity(0); + bytes.repeat_slice_n_times(value, number_of_true); + + let bytes = Buffer::from(bytes); + + // If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null) + // If a value is false we need the FALSY and the null buffer will have 0 (meaning null) + let nulls = NullBuffer::new(predicate); + + (bytes, offsets, Some(nulls)) + } + + /// Create a [`Buffer`] where `value` slice is repeated `number_of_values` times + /// and [`OffsetBuffer`] where there are `number_of_values` lengths, and all equals to `value` length + fn get_bytes_and_offset_for_all_same_value( + number_of_values: usize, + value: &[u8], + ) -> (Buffer, OffsetBuffer) { + let value_length = value.len(); + + let offsets = + OffsetBuffer::::from_repeated_length(value_length, number_of_values); + + let mut bytes = MutableBuffer::with_capacity(0); + bytes.repeat_slice_n_times(value, number_of_values); + let bytes = Buffer::from(bytes); + + (bytes, offsets) + } + + fn create_output_on_non_nulls( + predicate: &BooleanBuffer, + truthy_val: &[u8], + falsy_val: &[u8], + ) -> (Buffer, OffsetBuffer<::Offset>) { + let true_count = predicate.count_set_bits(); + + match true_count { + 0 => { + // All values are falsy + + let (bytes, offsets) = + Self::get_bytes_and_offset_for_all_same_value(predicate.len(), falsy_val); + + return (bytes, offsets); + } + n if n == predicate.len() => { + // All values are truthy + let (bytes, offsets) = + Self::get_bytes_and_offset_for_all_same_value(predicate.len(), truthy_val); + + return (bytes, offsets); + } + + _ => { + // Fallback + } + } + + let total_number_of_bytes = + true_count * truthy_val.len() + (predicate.len() - true_count) * falsy_val.len(); + let mut mutable = MutableBuffer::with_capacity(total_number_of_bytes); + let mut offset_buffer_builder = OffsetBufferBuilder::::new(predicate.len()); + + // keep track of how much is filled + let mut filled = 0; + + let truthy_len = truthy_val.len(); + let falsy_len = falsy_val.len(); + + SlicesIterator::from(predicate).for_each(|(start, end)| { + // the gap needs to be filled with falsy values + if start > filled { + let false_repeat_count = start - filled; + // Push false value `repeat_count` times + mutable.repeat_slice_n_times(falsy_val, false_repeat_count); + + for _ in 0..false_repeat_count { + offset_buffer_builder.push_length(falsy_len) + } + } + + let true_repeat_count = end - start; + // fill with truthy values + mutable.repeat_slice_n_times(truthy_val, true_repeat_count); + + for _ in 0..true_repeat_count { + offset_buffer_builder.push_length(truthy_len) + } + filled = end; + }); + // the remaining part is falsy + if filled < predicate.len() { + let false_repeat_count = predicate.len() - filled; + // Copy the first item from the 'falsy' array into the output buffer. + mutable.repeat_slice_n_times(falsy_val, false_repeat_count); + + for _ in 0..false_repeat_count { + offset_buffer_builder.push_length(falsy_len) + } + } + + (mutable.into(), offset_buffer_builder.finish()) + } +} + +impl ZipImpl for BytesScalarImpl { + fn create_output(&self, predicate: &BooleanArray) -> Result { + let result_len = predicate.len(); + // Nulls are treated as false + let predicate = maybe_prep_null_mask_filter(predicate); + + let (bytes, offsets, nulls): (Buffer, OffsetBuffer, Option) = + match (self.truthy.as_deref(), self.falsy.as_deref()) { + (Some(truthy_val), Some(falsy_val)) => { + let (bytes, offsets) = + Self::create_output_on_non_nulls(&predicate, truthy_val, falsy_val); + + (bytes, offsets, None) + } + (Some(truthy_val), None) => { + Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val) + } + (None, Some(falsy_val)) => { + // Flipping the boolean buffer as we want the opposite of the TRUE case + // + // if the condition is true we want null so we need to NOT the value so we get 0 (meaning null) + // if the condition is false we want the FALSE value so we need to NOT the value so we get 1 (meaning not null) + let predicate = predicate.not(); + Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val) + } + (None, None) => { + // All values are null + let nulls = NullBuffer::new_null(result_len); + + ( + // Empty bytes + Buffer::from(&[]), + // All nulls so all lengths are 0 + OffsetBuffer::::new_zeroed(predicate.len()), + Some(nulls), + ) + } + }; + + let output = unsafe { + // Safety: the values are based on valid inputs + // and `try_new` is expensive for strings as it validate that the input is valid utf8 + GenericByteArray::::new_unchecked(offsets, bytes, nulls) + }; + + Ok(Arc::new(output)) + } +} + fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { // Nulls are treated as false if predicate.null_count() == 0 { @@ -182,8 +660,7 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { #[cfg(test)] mod test { use super::*; - use arrow_array::cast::AsArray; - use arrow_buffer::{BooleanBuffer, NullBuffer}; + use arrow_array::types::Int32Type; #[test] fn test_zip_kernel_one() { @@ -260,7 +737,7 @@ mod test { } #[test] - fn test_zip_kernel_scalar_both() { + fn test_zip_kernel_scalar_both_mask_ends_with_true() { let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); @@ -272,7 +749,26 @@ mod test { } #[test] - fn test_zip_kernel_scalar_none_1() { + fn test_zip_kernel_scalar_both_mask_ends_with_false() { + let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); + let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); + + let mask = BooleanArray::from(vec![true, true, false, true, false, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![ + Some(42), + Some(42), + Some(123), + Some(42), + Some(123), + Some(123), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_primitive_scalar_none_1() { let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); let scalar_falsy = Scalar::new(Int32Array::new_null(1)); @@ -284,7 +780,7 @@ mod test { } #[test] - fn test_zip_kernel_scalar_none_2() { + fn test_zip_kernel_primitive_scalar_none_2() { let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); let scalar_falsy = Scalar::new(Int32Array::new_null(1)); @@ -295,6 +791,18 @@ mod test { assert_eq!(actual, &expected); } + #[test] + fn test_zip_kernel_primitive_scalar_both_null() { + let scalar_truthy = Scalar::new(Int32Array::new_null(1)); + let scalar_falsy = Scalar::new(Int32Array::new_null(1)); + + let mask = BooleanArray::from(vec![false, false, true, true, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = Int32Array::from(vec![None, None, None, None, None]); + assert_eq!(actual, &expected); + } + #[test] fn test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() { let truthy = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]); @@ -400,4 +908,318 @@ mod test { ]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_kernel_bytes_scalar_none_1() { + let scalar_truthy = Scalar::new(StringArray::from_iter_values(["hello"])); + let scalar_falsy = Scalar::new(StringArray::new_null(1)); + + let mask = BooleanArray::from(vec![true, true, false, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringArray::from_iter(vec![ + Some("hello"), + Some("hello"), + None, + None, + Some("hello"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_bytes_scalar_none_2() { + let scalar_truthy = Scalar::new(StringArray::new_null(1)); + let scalar_falsy = Scalar::new(StringArray::from_iter_values(["hello"])); + + let mask = BooleanArray::from(vec![true, true, false, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringArray::from_iter(vec![None, None, Some("hello"), Some("hello"), None]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_bytes_scalar_both() { + let scalar_truthy = Scalar::new(StringArray::from_iter_values(["test"])); + let scalar_falsy = Scalar::new(StringArray::from_iter_values(["something else"])); + + // mask ends with false + let mask = BooleanArray::from(vec![true, true, false, true, false, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringArray::from_iter(vec![ + Some("test"), + Some("test"), + Some("something else"), + Some("test"), + Some("something else"), + Some("something else"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_scalar_bytes_only_taking_one_side() { + let mask_len = 5; + let all_true_mask = BooleanArray::from(vec![true; mask_len]); + let all_false_mask = BooleanArray::from(vec![false; mask_len]); + + let null_scalar = Scalar::new(StringArray::new_null(1)); + let non_null_scalar_1 = Scalar::new(StringArray::from_iter_values(["test"])); + let non_null_scalar_2 = Scalar::new(StringArray::from_iter_values(["something else"])); + + { + // 1. Test where left is null and right is non-null + // and mask is all true + let out = zip(&all_true_mask, &null_scalar, &non_null_scalar_1).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len)); + assert_eq!(actual, &expected); + } + + { + // 2. Test where left is null and right is non-null + // and mask is all false + let out = zip(&all_false_mask, &null_scalar, &non_null_scalar_1).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len)); + assert_eq!(actual, &expected); + } + + { + // 3. Test where left is non-null and right is null + // and mask is all true + let out = zip(&all_true_mask, &non_null_scalar_1, &null_scalar).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len)); + assert_eq!(actual, &expected); + } + + { + // 4. Test where left is non-null and right is null + // and mask is all false + let out = zip(&all_false_mask, &non_null_scalar_1, &null_scalar).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len)); + assert_eq!(actual, &expected); + } + + { + // 5. Test where both left and right are not null + // and mask is all true + let out = zip(&all_true_mask, &non_null_scalar_1, &non_null_scalar_2).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len)); + assert_eq!(actual, &expected); + } + + { + // 6. Test where both left and right are not null + // and mask is all false + let out = zip(&all_false_mask, &non_null_scalar_1, &non_null_scalar_2).unwrap(); + let actual = out.as_string::(); + let expected = + StringArray::from_iter(std::iter::repeat_n(Some("something else"), mask_len)); + assert_eq!(actual, &expected); + } + + { + // 7. Test where both left and right are null + // and mask is random + let mask = BooleanArray::from(vec![true, false, true, false, true]); + let out = zip(&mask, &null_scalar, &null_scalar).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len)); + assert_eq!(actual, &expected); + } + } + + #[test] + fn test_scalar_zipper() { + let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1)); + let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1)); + + let mask = BooleanArray::from(vec![false, false, true, true, false]); + + let scalar_zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap(); + let out = scalar_zipper.zip(&mask).unwrap(); + let actual = out.as_primitive::(); + let expected = Int32Array::from(vec![Some(123), Some(123), Some(42), Some(42), Some(123)]); + assert_eq!(actual, &expected); + + // test with different mask length as well + let mask = BooleanArray::from(vec![true, false, true]); + let out = scalar_zipper.zip(&mask).unwrap(); + let actual = out.as_primitive::(); + let expected = Int32Array::from(vec![Some(42), Some(123), Some(42)]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings() { + let scalar_truthy = Scalar::new(StringArray::from(vec!["hello"])); + let scalar_falsy = Scalar::new(StringArray::from(vec!["world"])); + + let mask = BooleanArray::from(vec![true, false, true, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string::(); + let expected = StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("hello"), + Some("world"), + Some("hello"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_binary() { + let truthy_bytes: &[u8] = b"\xFF\xFE\xFD"; + let falsy_bytes: &[u8] = b"world"; + let scalar_truthy = Scalar::new(BinaryArray::from_iter_values( + // Non valid UTF8 bytes + vec![truthy_bytes], + )); + let scalar_falsy = Scalar::new(BinaryArray::from_iter_values(vec![falsy_bytes])); + + let mask = BooleanArray::from(vec![true, false, true, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_binary::(); + let expected = BinaryArray::from(vec![ + Some(truthy_bytes), + Some(falsy_bytes), + Some(truthy_bytes), + Some(falsy_bytes), + Some(truthy_bytes), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_large_binary() { + let truthy_bytes: &[u8] = b"hey"; + let falsy_bytes: &[u8] = b"world"; + let scalar_truthy = Scalar::new(LargeBinaryArray::from_iter_values(vec![truthy_bytes])); + let scalar_falsy = Scalar::new(LargeBinaryArray::from_iter_values(vec![falsy_bytes])); + + let mask = BooleanArray::from(vec![true, false, true, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_binary::(); + let expected = LargeBinaryArray::from(vec![ + Some(truthy_bytes), + Some(falsy_bytes), + Some(truthy_bytes), + Some(falsy_bytes), + Some(truthy_bytes), + ]); + assert_eq!(actual, &expected); + } + + // Test to ensure that the precision and scale are kept when zipping Decimal128 data + #[test] + fn test_zip_decimal_with_custom_precision_and_scale() { + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) + .with_precision_and_scale(20, 2) + .unwrap(); + + let arr: ArrayRef = Arc::new(arr); + + let scalar_1 = Scalar::new(arr.slice(0, 1)); + let scalar_2 = Scalar::new(arr.slice(1, 1)); + let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1)); + let array_1: ArrayRef = arr.slice(0, 2); + let array_2: ArrayRef = arr.slice(2, 2); + + test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar, array_1, array_2); + } + + // Test to ensure that the timezone is kept when zipping TimestampArray data + #[test] + fn test_zip_timestamp_with_timezone() { + let arr = TimestampSecondArray::from(vec![0, 1000, 2000, 4000]) + .with_timezone("+01:00".to_string()); + + let arr: ArrayRef = Arc::new(arr); + + let scalar_1 = Scalar::new(arr.slice(0, 1)); + let scalar_2 = Scalar::new(arr.slice(1, 1)); + let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1)); + let array_1: ArrayRef = arr.slice(0, 2); + let array_2: ArrayRef = arr.slice(2, 2); + + test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar, array_1, array_2); + } + + fn test_zip_output_data_types_for_input( + scalar_1: Scalar, + scalar_2: Scalar, + null_scalar: Scalar, + array_1: ArrayRef, + array_2: ArrayRef, + ) { + // non null Scalar vs non null Scalar + test_zip_output_data_type(&scalar_1, &scalar_2, 10); + + // null Scalar vs non-null Scalar (and vice versa) + test_zip_output_data_type(&null_scalar, &scalar_1, 10); + test_zip_output_data_type(&scalar_1, &null_scalar, 10); + + // non-null Scalar and array (and vice versa) + test_zip_output_data_type(&array_1.as_ref(), &scalar_1, array_1.len()); + test_zip_output_data_type(&scalar_1, &array_1.as_ref(), array_1.len()); + + // Array and null scalar (and vice versa) + test_zip_output_data_type(&array_1.as_ref(), &null_scalar, array_1.len()); + + test_zip_output_data_type(&null_scalar, &array_1.as_ref(), array_1.len()); + + // Both arrays + test_zip_output_data_type(&array_1.as_ref(), &array_2.as_ref(), array_1.len()); + } + + fn test_zip_output_data_type(truthy: &dyn Datum, falsy: &dyn Datum, mask_length: usize) { + let expected_data_type = truthy.get().0.data_type().clone(); + assert_eq!(&expected_data_type, falsy.get().0.data_type()); + + // Try different masks to test different paths + let mask_all_true = BooleanArray::from(vec![true; mask_length]); + let mask_all_false = BooleanArray::from(vec![false; mask_length]); + let mask_some_true_and_false = + BooleanArray::from((0..mask_length).map(|i| i % 2 == 0).collect::>()); + + for mask in [&mask_all_true, &mask_all_false, &mask_some_true_and_false] { + let out = zip(mask, truthy, falsy).unwrap(); + assert_eq!(out.data_type(), &expected_data_type); + } + } + + #[test] + fn zip_scalar_fallback_impl() { + let truthy_list_item_scalar = Some(vec![Some(1), None, Some(3)]); + let truthy_list_array_scalar = + Scalar::new(ListArray::from_iter_primitive::(vec![ + truthy_list_item_scalar.clone(), + ])); + let falsy_list_item_scalar = Some(vec![None, Some(2), Some(4)]); + let falsy_list_array_scalar = + Scalar::new(ListArray::from_iter_primitive::(vec![ + falsy_list_item_scalar.clone(), + ])); + let mask = BooleanArray::from(vec![true, false, true, false, false, true, false]); + let out = zip(&mask, &truthy_list_array_scalar, &falsy_list_array_scalar).unwrap(); + let actual = out.as_list::(); + + let expected = ListArray::from_iter_primitive::(vec![ + truthy_list_item_scalar.clone(), + falsy_list_item_scalar.clone(), + truthy_list_item_scalar.clone(), + falsy_list_item_scalar.clone(), + falsy_list_item_scalar.clone(), + truthy_list_item_scalar.clone(), + falsy_list_item_scalar.clone(), + ]); + assert_eq!(actual, &expected); + } }