diff --git a/native/spark-expr/src/array_funcs/array_repeat.rs b/native/spark-expr/src/array_funcs/array_repeat.rs
index 7ba8f0b910..c38e145baf 100644
--- a/native/spark-expr/src/array_funcs/array_repeat.rs
+++ b/native/spark-expr/src/array_funcs/array_repeat.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::make_scalar_function;
 use arrow::array::{
     new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
     NullBufferBuilder, OffsetSizeTrait, UInt64Array,
@@ -25,48 +26,16 @@ use arrow::compute::cast;
 use arrow::datatypes::DataType::{LargeList, List};
 use arrow::datatypes::{DataType, Field};
 use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
-use datafusion::common::{exec_err, DataFusionError, ScalarValue};
+use datafusion::common::{exec_err, DataFusionError};
 use datafusion::logical_expr::ColumnarValue;
 use std::sync::Arc;
 
-pub fn make_scalar_function<F>(
-    inner: F,
-) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
-where
-    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
-{
-    move |args: &[ColumnarValue]| {
-        // first, identify if any of the arguments is an Array. If yes, store its `len`,
-        // as any scalar will need to be converted to an array of len `len`.
-        let len = args
-            .iter()
-            .fold(Option::<usize>::None, |acc, arg| match arg {
-                ColumnarValue::Scalar(_) => acc,
-                ColumnarValue::Array(a) => Some(a.len()),
-            });
-
-        let is_scalar = len.is_none();
-
-        let args = ColumnarValue::values_to_arrays(args)?;
-
-        let result = (inner)(&args);
-
-        if is_scalar {
-            // If all inputs are scalar, keeps output as scalar
-            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
-            result.map(ColumnarValue::Scalar)
-        } else {
-            result.map(ColumnarValue::Array)
-        }
-    }
-}
-
 pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     make_scalar_function(spark_array_repeat_inner)(args)
 }
 
 /// Array_repeat SQL function
-fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
+fn spark_array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
     let element = &args[0];
     let count_array = &args[1];
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 166bb6ddf9..c30029602e 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -15,210 +15,177 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::make_scalar_function;
 use arrow::array::builder::GenericStringBuilder;
-use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, AsArray, DictionaryArray};
+use arrow::array::{as_dictionary_array, make_array, Array, AsArray, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::common::{cast::as_generic_string_array, DataFusionError};
 use datafusion::physical_plan::ColumnarValue;
 use std::sync::Arc;
 
 const SPACE: &str = " ";
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
 pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_read_side_padding2(args, false)
+    make_scalar_function(spark_read_side_padding_no_truncate)(args)
 }
 
 /// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
 pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    make_scalar_function(spark_read_side_padding_truncate)(args)
+}
+
+pub fn spark_read_side_padding_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
     spark_read_side_padding2(args, true)
 }
 
+pub fn spark_read_side_padding_no_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
+    spark_read_side_padding2(args, false)
+}
+
 fn spark_read_side_padding2(
-    args: &[ColumnarValue],
+    args: &[ArrayRef],
     truncate: bool,
-) -> Result<ColumnarValue, DataFusionError> {
+) -> Result<ArrayRef, DataFusionError> {
     match args {
-        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    SPACE,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    SPACE,
-                ),
-                // Dictionary support required for SPARK-48498
-                DataType::Dictionary(_, value_type) => {
-                    let dict = as_dictionary_array::<Int32Type>(array);
-                    let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(
-                            dict.values(),
-                            truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
-                        )?
-                    } else {
-                        spark_read_side_padding_internal::<i64>(
-                            dict.values(),
-                            truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
-                        )?
-                    };
-                    // col consists of an array, so arg of to_array() is not used. Can be anything
-                    let values = col.to_array(0)?;
-                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
-                    Ok(ColumnarValue::Array(make_array(result.into())))
-                }
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
-                ))),
+        [array, array_int] => match array.data_type() {
+            DataType::Utf8 => {
+                spark_read_side_padding_space_internal::<i32>(array, truncate, array_int)
             }
-        }
-        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] =>
-        {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    string,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    string,
-                ),
+            DataType::LargeUtf8 => {
+                spark_read_side_padding_space_internal::<i64>(array, truncate, array_int)
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {other:?} for function rpad/read_side_padding",
+            ))),
+        },
+        [array, array_int, array_pad_string] => {
+            match (array.data_type(), array_pad_string.data_type()) {
+                (DataType::Utf8, DataType::Utf8) => {
+                    spark_read_side_padding_internal::<i32, i32>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::Utf8, DataType::LargeUtf8) => {
+                    spark_read_side_padding_internal::<i32, i64>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::LargeUtf8, DataType::Utf8) => {
+                    spark_read_side_padding_internal::<i64, i32>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                    spark_read_side_padding_internal::<i64, i64>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
                 // Dictionary support required for SPARK-48498
-                DataType::Dictionary(_, value_type) => {
+                (DataType::Dictionary(_, value_type), DataType::Utf8) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
-                    let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(
+                    let values = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_read_side_padding_internal::<i32, i32>(
                             dict.values(),
                             truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
+                            array_int,
+                            array_pad_string,
                         )?
                     } else {
-                        spark_read_side_padding_internal::<i64>(
+                        spark_read_side_padding_internal::<i64, i32>(
                             dict.values(),
                             truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
+                            array_int,
+                            array_pad_string,
                         )?
                     };
-                    // col consists of an array, so arg of to_array() is not used. Can be anything
-                    let values = col.to_array(0)?;
                     let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
-                    Ok(ColumnarValue::Array(make_array(result.into())))
+                    Ok(make_array(result.into()))
                 }
                 other => Err(DataFusionError::Internal(format!(
                     "Unsupported data type {other:?} for function rpad/read_side_padding",
                 ))),
             }
         }
-        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
-            DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                array,
-                truncate,
-                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                SPACE,
-            ),
-            DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                array,
-                truncate,
-                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                SPACE,
-            ),
-            other => Err(DataFusionError::Internal(format!(
-                "Unsupported data type {other:?} for function rpad/read_side_padding",
-            ))),
-        },
-        [ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                    string,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                    string,
-                ),
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
-                ))),
-            }
-        }
         other => Err(DataFusionError::Internal(format!(
             "Unsupported arguments {other:?} for function rpad/read_side_padding",
         ))),
     }
 }
 
-fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_space_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     truncate: bool,
-    pad_type: ColumnarValue,
-    pad_string: &str,
-) -> Result<ColumnarValue, DataFusionError> {
+    array_int: &ArrayRef,
+) -> Result<ArrayRef, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    match pad_type {
-        ColumnarValue::Array(array_int) => {
-            let int_pad_array = array_int.as_primitive::<Int32Type>();
+    let int_pad_array = array_int.as_primitive::<Int32Type>();
 
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                string_array.len(),
-                string_array.len() * int_pad_array.len(),
-            );
+    let mut builder = GenericStringBuilder::<T>::with_capacity(
+        string_array.len(),
+        string_array.len() * int_pad_array.len(),
+    );
 
-            for (string, length) in string_array.iter().zip(int_pad_array) {
-                match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length.unwrap() as usize,
-                        truncate,
-                        pad_string,
-                    )?),
-                    _ => builder.append_null(),
-                }
-            }
-            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+    for (string, length) in string_array.iter().zip(int_pad_array) {
+        match (string, length) {
+            (Some(string), Some(length)) => builder.append_value(add_padding_string(
+                string.parse().unwrap(),
+                length as usize,
+                truncate,
+                SPACE,
+            )?),
+            _ => builder.append_null(),
         }
-        ColumnarValue::Scalar(const_pad_length) => {
-            let length = 0.max(i32::try_from(const_pad_length)?) as usize;
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+fn spark_read_side_padding_internal<T: OffsetSizeTrait, S: OffsetSizeTrait>(
+    array: &ArrayRef,
+    truncate: bool,
+    array_int: &ArrayRef,
+    pad_string_array: &ArrayRef,
+) -> Result<ArrayRef, DataFusionError> {
+    let string_array = as_generic_string_array::<T>(array)?;
+    let int_pad_array = array_int.as_primitive::<Int32Type>();
+    let pad_string_array = as_generic_string_array::<S>(pad_string_array)?;
 
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                string_array.len(),
-                string_array.len() * length,
-            );
+    let mut builder = GenericStringBuilder::<T>::with_capacity(
+        string_array.len(),
+        string_array.len() * int_pad_array.len(),
+    );
 
-            for string in string_array.iter() {
-                match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length,
-                        truncate,
-                        pad_string,
-                    )?),
-                    _ => builder.append_null(),
-                }
+    for ((string, length), pad_string) in string_array
+        .iter()
+        .zip(int_pad_array)
+        .zip(pad_string_array.iter())
+    {
+        match (string, length, pad_string) {
+            (Some(string), Some(length), Some(pad_string)) => {
+                builder.append_value(add_padding_string(
+                    string.parse().unwrap(),
+                    length as usize,
+                    truncate,
+                    pad_string,
+                )?)
             }
-            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+            _ => builder.append_null(),
         }
     }
+    Ok(Arc::new(builder.finish()))
 }
 
 fn add_padding_string(
diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs
index 60ffe84a93..f06276c09e 100644
--- a/native/spark-expr/src/utils.rs
+++ b/native/spark-expr/src/utils.rs
@@ -24,7 +24,7 @@ use arrow::{
     },
     buffer::BooleanBuffer,
 };
-use datafusion::logical_expr::EmitTo;
+use datafusion::logical_expr::{ColumnarValue, EmitTo};
 use std::sync::Arc;
 
 use crate::timezone::Tz;
@@ -36,6 +36,7 @@ use arrow::{
     temporal_conversions::as_datetime,
 };
 use chrono::{DateTime, Offset, TimeZone};
+use datafusion::common::{DataFusionError, ScalarValue};
 
 /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
 /// to apply timezone offset.
@@ -242,6 +243,38 @@ pub fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> B
     }
 }
 
+pub fn make_scalar_function<F>(
+    inner: F,
+) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
+where
+    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
+{
+    move |args: &[ColumnarValue]| {
+        // first, identify if any of the arguments is an Array. If yes, store its `len`,
+        // as any scalar will need to be converted to an array of len `len`.
+        let len = args
+            .iter()
+            .fold(Option::<usize>::None, |acc, arg| match arg {
+                ColumnarValue::Scalar(_) => acc,
+                ColumnarValue::Array(a) => Some(a.len()),
+            });
+
+        let is_scalar = len.is_none();
+
+        let args = ColumnarValue::values_to_arrays(args)?;
+
+        let result = (inner)(&args);
+
+        if is_scalar {
+            // If all inputs are scalar, keeps output as scalar
+            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
+        }
+    }
+}
+
 // These are borrowed from hashbrown crate:
 //   https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index daf0e45cc8..b25c76e3fa 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2273,26 +2273,34 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   test("rpad") {
     val table = "rpad"
     val gen = new DataGenerator(new Random(42))
-    withTable(table) {
-      // generate some data
-      val dataChars = "abc123"
-      sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using parquet")
-      val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
-        "é", // unicode 'e\\u{301}'
-        "é" // unicode '\\u{e9}'
-      )
-      testData.zipWithIndex.foreach { x =>
-        sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
-      }
-      // test 2-arg version
-      checkSparkAnswerAndOperator(
-        s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
-      // test 3-arg version
-      for (length <- Seq(2, 10)) {
-        checkSparkAnswerAndOperator(
-          s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY id")
+    withSQLConf(
+      "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      withTable(table) {
+        // generate some data
+        val dataChars = "abc123"
+        sql(
+          s"create table $table(id int, name1 char(8), name2 varchar(8), len int) using parquet")
+        val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
+          "é", // unicode 'e\\u{301}'
+          "é", // unicode '\\u{e9}'
+          null)
+        testData.zipWithIndex.foreach { x =>
+          val len = Random.nextInt(10)
+          sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}', $len)")
+        }
+        // test 2-arg version
         checkSparkAnswerAndOperator(
-          s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
+          "SELECT id, rpad(name1, 10), rpad(name2, 10), rpad('111', 10), rpad('111', null)," +
+            s" rpad('11', len), rpad(name1, len), rpad(name1, null) FROM $table ORDER BY id")
+        // test 3-arg version
+        for (length <- Seq(2, 10)) {
+          checkSparkAnswerAndOperator(
+            s"SELECT id, name1, rpad(name1, $length, ' '), rpad('name1', 10, ' ')," +
+              " rpad(name1, len, name2), rpad('111', 10, name2), rpad(name1, 10, null)," +
+              s" rpad(name1, null, name2) FROM $table ORDER BY id")
+          checkSparkAnswerAndOperator(
+            s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
+        }
       }
     }
   }
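Note: the snippet below is a usage sketch, not part of the patch. It shows how a kernel written against `&[ArrayRef]` (like `spark_read_side_padding_truncate` above) behaves once wrapped by the relocated `make_scalar_function` helper: all-scalar inputs come back as a `ColumnarValue::Scalar`, while any array input keeps the result as an array. The `upper_inner` kernel and the `datafusion_comet_spark_expr::utils` import path are illustrative assumptions, not code from this patch.

use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{cast::as_string_array, DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;

// Assumed import path for the helper added to utils.rs in this patch.
use datafusion_comet_spark_expr::utils::make_scalar_function;

// Hypothetical kernel: like the inner functions above, it only ever sees
// materialized arrays; scalar/array dispatch lives entirely in the wrapper.
fn upper_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    let input = as_string_array(&args[0])?;
    let out: StringArray = input.iter().map(|v| v.map(str::to_uppercase)).collect();
    Ok(Arc::new(out))
}

fn main() -> Result<(), DataFusionError> {
    let f = make_scalar_function(upper_inner);

    // All-scalar input: the wrapper materializes a 1-row array, runs the
    // kernel, then folds row 0 back into a scalar via ScalarValue::try_from_array.
    let scalar = f(&[ColumnarValue::Scalar(ScalarValue::from("spark"))])?;
    assert!(matches!(
        scalar,
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(ref s))) if s.as_str() == "SPARK"
    ));

    // Array input: values_to_arrays broadcasts any scalar arguments to the
    // array length, and the result stays a ColumnarValue::Array.
    let arr: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None]));
    let out = f(&[ColumnarValue::Array(arr)])?;
    assert!(matches!(out, ColumnarValue::Array(ref a) if a.len() == 2));
    Ok(())
}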