From aa2e4806b5e2419bf82d4d2b4ab770a7bb8e00b2 Mon Sep 17 00:00:00 2001 From: Artem Kupchinskiy Date: Thu, 10 Jul 2025 16:03:50 +0400 Subject: [PATCH 1/7] feat: randn expression support --- docs/spark_expressions_support.md | 2 + native/core/src/execution/planner.rs | 6 +- native/proto/src/proto/expr.proto | 1 + .../nondetermenistic_funcs/internal/mod.rs | 4 + .../internal/rand_utils.rs | 74 ++++++ .../src/nondetermenistic_funcs/mod.rs | 3 + .../src/nondetermenistic_funcs/rand.rs | 79 +++--- .../src/nondetermenistic_funcs/randn.rs | 251 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 8 +- .../apache/comet/CometExpressionSuite.scala | 22 +- .../org/apache/spark/sql/CometTestBase.scala | 14 + 11 files changed, 412 insertions(+), 52 deletions(-) create mode 100644 native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs create mode 100644 native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs create mode 100644 native/spark-expr/src/nondetermenistic_funcs/randn.rs diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index be7c911814..6c4364f26c 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -351,6 +351,8 @@ - [ ] input_file_name - [ ] monotonically_increasing_id - [ ] raise_error + - [x] rand + - [x] randn - [ ] spark_partition_id - [ ] typeof - [x] user diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 558a785bb6..8ef86e97a1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -106,7 +106,7 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, - RandExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, + RandExpr, RandnExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; @@ -782,6 +782,10 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(RandExpr::new(child, self.partition))) } + ExprStruct::Randn(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + Ok(Arc::new(RandnExpr::new(child, self.partition))) + } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8f4c875eec..076143a039 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -81,6 +81,7 @@ message Expr { MathExpr integral_divide = 59; ToPrettyString to_pretty_string = 60; UnaryExpr rand = 61; + UnaryExpr randn = 62; } } diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs new file mode 100644 index 0000000000..0428573cb7 --- /dev/null +++ b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs @@ -0,0 +1,4 @@ +mod rand_utils; + +pub use rand_utils::evaluate_batch_for_rand; +pub use rand_utils::StatefulSeedValueGenerator; diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs new file mode 100644 index 0000000000..5b87baaba4 --- /dev/null +++ 
b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Float64Array, Float64Builder};
+use arrow::datatypes::DataType;
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use std::ops::Deref;
+use std::sync::{Arc, Mutex};
+
+fn extract_seed_from_scalar_value(seed: &ScalarValue) -> datafusion::common::Result<i64> {
+    if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
+        Ok(seed_opt.unwrap_or(0))
+    } else {
+        Err(DataFusionError::Internal(
+            "unexpected execution branch".to_string(),
+        ))
+    }
+}
+
+pub fn evaluate_batch_for_rand<R, S>(
+    state_holder: &Arc<Mutex<Option<S>>>,
+    seed: ScalarValue,
+    init_seed_shift: i64,
+    num_rows: usize,
+) -> datafusion::common::Result<ColumnarValue>
+where
+    R: StatefulSeedValueGenerator<S, f64>,
+    S: Copy,
+{
+    let seed_state = state_holder.lock().unwrap();
+    let init = extract_seed_from_scalar_value(&seed)?.wrapping_add(init_seed_shift);
+    let mut rnd = R::from_state_ref(seed_state, init);
+    let mut arr_builder = Float64Builder::with_capacity(num_rows);
+    std::iter::repeat_with(|| rnd.next_value())
+        .take(num_rows)
+        .for_each(|v| arr_builder.append_value(v));
+    let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
+    let mut seed_state = state_holder.lock().unwrap();
+    seed_state.replace(rnd.get_current_state());
+    Ok(ColumnarValue::Array(array_ref))
+}
+
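+// Contract shared by the rand/randn generators: `State` is the serializable
+// generator state kept in the expression's state holder between batches (so
+// consecutive batches continue a single deterministic stream), and `Value`
+// is the type of sample the generator produces.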
+pub trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
+    fn from_init_seed(init_seed: i64) -> Self;
+
+    fn from_stored_state(stored_state: State) -> Self;
+
+    fn next_value(&mut self) -> Value;
+
+    fn get_current_state(&self) -> State;
+
+    fn from_state_ref(state: impl Deref<Target = Option<State>>, init_value: i64) -> Self {
+        if state.is_none() {
+            Self::from_init_seed(init_value)
+        } else {
+            Self::from_stored_state(state.unwrap())
+        }
+    }
+}
diff --git a/native/spark-expr/src/nondetermenistic_funcs/mod.rs b/native/spark-expr/src/nondetermenistic_funcs/mod.rs
index c5ff894e8e..94774acd51 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/mod.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/mod.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod internal;
 pub mod rand;
+pub mod randn;
 
 pub use rand::RandExpr;
+pub use randn::RandnExpr;
diff --git a/native/spark-expr/src/nondetermenistic_funcs/rand.rs b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
index d82c2cd92e..22a4093ce6 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/rand.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
@@ -16,10 +16,11 @@
 // under the License.
 
 use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
-use arrow::array::{Float64Array, Float64Builder, RecordBatch};
+
+use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
+use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Schema};
 use datafusion::common::Result;
-use datafusion::common::ScalarValue;
 use datafusion::error::DataFusionError;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
@@ -42,21 +43,11 @@ const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
 const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;
 
 #[derive(Debug, Clone)]
-struct XorShiftRandom {
-    seed: i64,
+pub(crate) struct XorShiftRandom {
+    pub(crate) seed: i64,
 }
 
 impl XorShiftRandom {
-    fn from_init_seed(init_seed: i64) -> Self {
-        XorShiftRandom {
-            seed: Self::init_seed(init_seed),
-        }
-    }
-
-    fn from_stored_seed(stored_seed: i64) -> Self {
-        XorShiftRandom { seed: stored_seed }
-    }
-
     fn next(&mut self, bits: u8) -> i32 {
         let mut next_seed = self.seed ^ (self.seed << 21);
         next_seed ^= ((next_seed as u64) >> 35) as i64;
@@ -70,12 +61,27 @@ impl XorShiftRandom {
         let b = self.next(27) as i64;
         ((a << 27) + b) as f64 * DOUBLE_UNIT
     }
+}
 
-    fn init_seed(init: i64) -> i64 {
-        let bytes_repr = init.to_be_bytes();
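+// Seed initialization matches Spark's XORShiftRandom.hashSeed: the 64-bit seed
+// is expanded into two murmur3 hashes of its big-endian bytes, so that similar
+// user-supplied seeds still start from well-spread internal states.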
+impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
+    fn from_init_seed(init_seed: i64) -> Self {
+        let bytes_repr = init_seed.to_be_bytes();
         let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
         let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
-        ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
+        let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
+        XorShiftRandom { seed: init_seed }
+    }
+
+    fn from_stored_state(stored_state: i64) -> Self {
+        XorShiftRandom { seed: stored_state }
+    }
+
+    fn next_value(&mut self) -> f64 {
+        self.next_f64()
+    }
+
+    fn get_current_state(&self) -> i64 {
+        self.seed
     }
 }
 
@@ -94,36 +100,6 @@ impl RandExpr {
             state_holder: Arc::new(Mutex::new(None::<i64>)),
         }
     }
-
-    fn extract_init_state(seed: ScalarValue) -> Result<i64> {
-        if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
-            Ok(seed_opt.unwrap_or(0))
-        } else {
-            Err(DataFusionError::Internal(
-                "unexpected execution branch".to_string(),
-            ))
-        }
-    }
-    fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
-        let mut seed_state = self.state_holder.lock().unwrap();
-        let mut rnd = if seed_state.is_none() {
-            let init_seed = RandExpr::extract_init_state(seed)?;
-            let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
-            *seed_state = Some(init_seed);
-            XorShiftRandom::from_init_seed(init_seed)
-        } else {
-            let stored_seed = seed_state.unwrap();
-            XorShiftRandom::from_stored_seed(stored_seed)
-        };
-
-        let mut arr_builder = Float64Builder::with_capacity(num_rows);
-        std::iter::repeat_with(|| rnd.next_f64())
-            .take(num_rows)
-            .for_each(|v| arr_builder.append_value(v));
-        let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
-        *seed_state = Some(rnd.seed);
-        Ok(ColumnarValue::Array(array_ref))
-    }
 }
 
 impl Display for RandExpr {
@@ -161,7 +137,12 @@ impl PhysicalExpr for RandExpr {
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         match self.seed.evaluate(batch)? {
-            ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
+            ColumnarValue::Scalar(seed) => evaluate_batch_for_rand::<XorShiftRandom, i64>(
+                &self.state_holder,
+                seed,
+                self.init_seed_shift as i64,
+                batch.num_rows(),
+            ),
             ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
                 "Only literal seeds are supported for {self}"
             ))),
@@ -194,7 +175,7 @@ pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>>
diff --git a/native/spark-expr/src/nondetermenistic_funcs/randn.rs b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
new file mode 100644
index 0000000000..8f4e85815b
--- /dev/null
+++ b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
@@ -0,0 +1,251 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::nondetermenistic_funcs::rand::XorShiftRandom;
+
+use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Schema};
+use datafusion::common::DataFusionError;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::{Arc, Mutex};
+
+/// Spark equivalence:
+/// Under the hood, the Spark algorithm relies on java.util.Random, which in turn delegates to
+/// StrictMath. StrictMath uses native implementations of floating-point operations
+/// (ln, exp, sin, cos) that are guaranteed to be stable across platforms.
+/// See: https://github.com/openjdk/jdk/blob/07c9f7138affdf0d42ecdc30adcb854515569985/src/java.base/share/classes/java/util/Random.java#L745
+/// The Rust standard library gives no such guarantee (https://doc.rust-lang.org/std/primitive.f64.html#method.ln),
+/// and an external library such as rug (https://docs.rs/rug/latest/rug/) would not help either,
+/// since there is no guarantee that it matches the StrictMath JVM implementation.
+/// Consequently, we can only ensure equivalence between Rust and Spark (JVM) within some
+/// error tolerance.
+
+#[derive(Debug, Clone)]
+struct XorShiftRandomForGaussian {
+    base_generator: XorShiftRandom,
+    next_gaussian: Option<f64>,
+}
+
+impl XorShiftRandomForGaussian {
+    pub fn next_gaussian(&mut self) -> f64 {
+        if let Some(stored_value) = self.next_gaussian {
+            self.next_gaussian = None;
+            return stored_value;
+        }
+        let mut v1: f64;
+        let mut v2: f64;
+        let mut s: f64;
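+        // Marsaglia polar method, as in java.util.Random.nextGaussian(): sample
+        // (v1, v2) uniformly from the square [-1, 1)^2 until the point falls
+        // strictly inside the unit circle, then scale both coordinates; one
+        // Gaussian value is returned, the other cached in `next_gaussian`.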
+        loop {
+            v1 = 2f64 * self.base_generator.next_f64() - 1f64;
+            v2 = 2f64 * self.base_generator.next_f64() - 1f64;
+            s = v1 * v1 + v2 * v2;
+            if s < 1f64 && s != 0f64 {
+                break;
+            }
+        }
+        let multiplier = (-2f64 * s.ln() / s).sqrt();
+        self.next_gaussian = Some(v2 * multiplier);
+        v1 * multiplier
+    }
+}
+
+type RandomGaussianState = (i64, Option<f64>);
+
+impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForGaussian {
+    fn from_init_seed(init_value: i64) -> Self {
+        XorShiftRandomForGaussian {
+            base_generator: XorShiftRandom::from_init_seed(init_value),
+            next_gaussian: None,
+        }
+    }
+
+    fn from_stored_state(stored_state: RandomGaussianState) -> Self {
+        XorShiftRandomForGaussian {
+            base_generator: XorShiftRandom::from_stored_state(stored_state.0),
+            next_gaussian: stored_state.1,
+        }
+    }
+
+    fn next_value(&mut self) -> f64 {
+        self.next_gaussian()
+    }
+
+    fn get_current_state(&self) -> RandomGaussianState {
+        (self.base_generator.seed, self.next_gaussian)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct RandnExpr {
+    seed: Arc<dyn PhysicalExpr>,
+    init_seed_shift: i32,
+    state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
+}
+
+impl RandnExpr {
+    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
+        Self {
+            seed,
+            init_seed_shift,
+            state_holder: Arc::new(Mutex::new(None)),
+        }
+    }
+}
+
+impl Display for RandnExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "RANDN({})", self.seed)
+    }
+}
+
+impl PartialEq for RandnExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
+    }
+}
+
+impl Eq for RandnExpr {}
+
+impl Hash for RandnExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.children().hash(state);
+    }
+}
+
+impl PhysicalExpr for RandnExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> datafusion::common::Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> datafusion::common::Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+        match self.seed.evaluate(batch)? {
+            ColumnarValue::Scalar(seed) => {
+                evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
+                    &self.state_holder,
+                    seed,
+                    self.init_seed_shift as i64,
+                    batch.num_rows(),
+                )
+            }
+            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
+                "Only literal seeds are supported for {self}"
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.seed]
+    }
+
+    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
+        unimplemented!()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(RandnExpr::new(
+            Arc::clone(&children[0]),
+            self.init_seed_shift,
+        )))
+    }
+}
+
+pub fn randn(
+    seed: Arc<dyn PhysicalExpr>,
+    init_seed_shift: i32,
+) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+    Ok(Arc::new(RandnExpr::new(seed, init_seed_shift)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Array, Float64Array, Int64Array};
+    use arrow::{array::StringArray, compute::concat, datatypes::*};
+    use datafusion::common::cast::as_float64_array;
+    use datafusion::physical_expr::expressions::lit;
+
+    const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [
+        2.384479054241165,
+        0.1920934041293524,
+        0.7337336533286575,
+        -0.5224480195716871,
+        2.060084179317831,
+    ];
+
+    #[test]
+    fn test_rand_single_batch() -> datafusion::common::Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
+        let randn_expr = randn(lit(42), 0)?;
+        let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_float64_array(&result)?;
+        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
+        assert_eq!(result, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_rand_multi_batch() -> datafusion::common::Result<()> {
+        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+        let first_batch_data = Int64Array::from(vec![Some(24), None, None]);
+        let second_batch_schema = first_batch_schema.clone();
+        let second_batch_data = Int64Array::from(vec![None, Some(22)]);
+        let randn_expr = randn(lit(42), 0)?;
+        let first_batch = RecordBatch::try_new(
+            Arc::new(first_batch_schema),
+            vec![Arc::new(first_batch_data)],
+        )?;
+        let first_batch_result = randn_expr
+            .evaluate(&first_batch)?
+            .into_array(first_batch.num_rows())?;
+        let second_batch = RecordBatch::try_new(
+            Arc::new(second_batch_schema),
+            vec![Arc::new(second_batch_data)],
+        )?;
+        let second_batch_result = randn_expr
+            .evaluate(&second_batch)?
+ .into_array(second_batch.num_rows())?; + let result_arrays: Vec<&dyn Array> = vec![ + as_float64_array(&first_batch_result)?, + as_float64_array(&second_batch_result)?, + ]; + let result_arrays = &concat(&result_arrays)?; + let final_result = as_float64_array(result_arrays)?; + let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN)); + assert_eq!(final_result, expected); + Ok(()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5c94942830..eed4561fd1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1922,7 +1922,13 @@ object QueryPlanSerde extends Logging with CometExprShim { inputs, binding, (builder, unaryExpr) => builder.setRand(unaryExpr)) - + case Randn(child, _) => + createUnaryExpr( + expr, + child, + inputs, + binding, + (builder, unaryExpr) => builder.setRandn(unaryExpr)) case mk: MapKeys => val childExpr = exprToProtoInternal(mk.child, inputs, binding) scalarFunctionExprToProto("map_keys", childExpr) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 7ca8129b6d..303b8e3f01 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2749,7 +2749,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val rowsNumber = Random.nextInt(500) val seed = Random.nextLong() // use this value to have both single-batch and multi-batch partitions - val cometBatchSize = math.max(1, math.ceil(rowsNumber.toDouble / partitionsNumber).toInt) + val cometBatchSize = math.max(1, math.floor(rowsNumber.toDouble / partitionsNumber).toInt) withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) { withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df => val dfWithRandParameters = df.repartition(partitionsNumber).withColumn("rnd", rand(seed)) @@ -2764,6 +2764,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("randn expression with random parameters") { + val partitionsNumber = Random.nextInt(10) + 1 + val rowsNumber = Random.nextInt(500) + val seed = Random.nextLong() + val cometBatchSize = math.max(1, math.floor(rowsNumber.toDouble / partitionsNumber).toInt) + withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) { + withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df => + val dfWithRandParameters = + df.repartition(partitionsNumber).withColumn("randn", randn(seed)) + checkSparkAnswerAndOperatorWithTol(dfWithRandParameters) + val dfWithOverflowSeed = + df.repartition(partitionsNumber).withColumn("randn", randn(Long.MaxValue)) + checkSparkAnswerAndOperatorWithTol(dfWithOverflowSeed) + val dfWithNullSeed = + df.repartition(partitionsNumber).selectExpr("_1", "randn(null) as randn") + checkSparkAnswerAndOperatorWithTol(dfWithNullSeed) + } + } + } + test("window query with rangeBetween") { // values are int diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 9d51c69196..365003aa8c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -179,6 +179,20 @@ abstract class CometTestBase 
checkSparkAnswer(df) } + protected def checkSparkAnswerAndOperatorWithTol(df: => DataFrame, tol: Double = 1e-6): Unit = { + checkSparkAnswerAndOperatorWithTol(df, tol, Seq.empty) + } + + protected def checkSparkAnswerAndOperatorWithTol( + df: => DataFrame, + tol: Double, + includeClasses: Seq[Class[_]], + excludedClasses: Class[_]*): Unit = { + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*) + checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*) + checkSparkAnswerWithTol(df, tol) + } + protected def checkCometOperators(plan: SparkPlan, excludedClasses: Class[_]*): Unit = { val wrapped = wrapCometSparkToColumnar(plan) wrapped.foreach { From e7913adc4d0a6f8e8b9bc8ffc85ed3945079a267 Mon Sep 17 00:00:00 2001 From: Artem Kupchinskiy Date: Thu, 10 Jul 2025 16:43:23 +0400 Subject: [PATCH 2/7] added a missing license --- .../src/nondetermenistic_funcs/internal/mod.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs index 0428573cb7..c7437f0667 100644 --- a/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs +++ b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ mod rand_utils; pub use rand_utils::evaluate_batch_for_rand; From a00ec9c3a62762c05902d377f1eaa2275f9a633d Mon Sep 17 00:00:00 2001 From: Artem Kupchinskiy Date: Fri, 11 Jul 2025 00:34:40 +0400 Subject: [PATCH 3/7] added tolerance for randn test in rust --- .../src/nondetermenistic_funcs/randn.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/nondetermenistic_funcs/randn.rs b/native/spark-expr/src/nondetermenistic_funcs/randn.rs index 8f4e85815b..353892bbd2 100644 --- a/native/spark-expr/src/nondetermenistic_funcs/randn.rs +++ b/native/spark-expr/src/nondetermenistic_funcs/randn.rs @@ -196,6 +196,8 @@ mod tests { use datafusion::common::cast::as_float64_array; use datafusion::physical_expr::expressions::lit; + const PRECISION_TOLERANCE: f64 = 1e-6; + const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [ 2.384479054241165, 0.1920934041293524, @@ -213,7 +215,7 @@ mod tests { let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?; let result = as_float64_array(&result)?; let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN)); - assert_eq!(result, expected); + assert_eq_with_tolerance(result, expected); Ok(()) } @@ -245,7 +247,19 @@ mod tests { let result_arrays = &concat(&result_arrays)?; let final_result = as_float64_array(result_arrays)?; let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN)); - assert_eq!(final_result, expected); + assert_eq_with_tolerance(final_result, expected); Ok(()) } + + fn assert_eq_with_tolerance(left: &Float64Array, right: &Float64Array) { + assert_eq!(left.len(), right.len()); + left.iter().zip(right.iter()).for_each(|(l, r)| { + assert!( + (l.unwrap() - r.unwrap()).abs() < PRECISION_TOLERANCE, + "difference between {:?} and {:?} is larger than acceptable precision", + l.unwrap(), + r.unwrap() + ) + }) + } } From 9bf6fbd9a6c7e2746ee5b8fad08a09ff41c693ce Mon Sep 17 00:00:00 2001 From: Artem Kupchinskiy Date: Thu, 17 Jul 2025 12:28:59 +0400 Subject: [PATCH 4/7] simplified serde logic for rand expressions and added a test case for multiple expressions evaluation --- native/core/src/execution/planner.rs | 8 +-- native/proto/src/proto/expr.proto | 13 +++- .../internal/rand_utils.rs | 18 +----- .../src/nondetermenistic_funcs/rand.rs | 64 +++++-------------- .../src/nondetermenistic_funcs/randn.rs | 48 +++++--------- .../apache/comet/serde/QueryPlanSerde.scala | 32 ++++++---- .../apache/comet/CometExpressionSuite.scala | 63 ++++++++++-------- 7 files changed, 105 insertions(+), 141 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 6c902f1ddf..f76e10199f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -791,12 +791,12 @@ impl PhysicalPlanner { ))) } ExprStruct::Rand(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(RandExpr::new(child, self.partition))) + let seed = expr.seed.wrapping_add(self.partition.into()); + Ok(Arc::new(RandExpr::new(seed))) } ExprStruct::Randn(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(RandnExpr::new(child, self.partition))) + let seed = expr.seed.wrapping_add(self.partition.into()); + Ok(Arc::new(RandnExpr::new(seed))) } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 076143a039..252bf60488 100644 --- 
a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -80,8 +80,8 @@ message Expr {
     ArrayInsert array_insert = 58;
     MathExpr integral_divide = 59;
     ToPrettyString to_pretty_string = 60;
-    UnaryExpr rand = 61;
-    UnaryExpr randn = 62;
+    Rand rand = 61;
+    Randn randn = 62;
   }
 }
 
@@ -416,6 +416,15 @@ message ArrayJoin {
   Expr null_replacement_expr = 3;
 }
 
+message Rand {
+  int64 seed = 1;
+}
+
+message Randn {
+  int64 seed = 1;
+}
+
+
 message DataType {
   enum DataTypeId {
     BOOL = 0;
diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
index 5b87baaba4..9abaaa9396 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
@@ -16,26 +16,13 @@
 // under the License.
 
 use arrow::array::{Float64Array, Float64Builder};
-use arrow::datatypes::DataType;
-use datafusion::common::{DataFusionError, ScalarValue};
 use datafusion::logical_expr::ColumnarValue;
 use std::ops::Deref;
 use std::sync::{Arc, Mutex};
 
-fn extract_seed_from_scalar_value(seed: &ScalarValue) -> datafusion::common::Result<i64> {
-    if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
-        Ok(seed_opt.unwrap_or(0))
-    } else {
-        Err(DataFusionError::Internal(
-            "unexpected execution branch".to_string(),
-        ))
-    }
-}
-
 pub fn evaluate_batch_for_rand<R, S>(
     state_holder: &Arc<Mutex<Option<S>>>,
-    seed: ScalarValue,
-    init_seed_shift: i64,
+    seed: i64,
     num_rows: usize,
 ) -> datafusion::common::Result<ColumnarValue>
@@ -43,8 +30,7 @@ where
     S: Copy,
 {
     let seed_state = state_holder.lock().unwrap();
-    let init = extract_seed_from_scalar_value(&seed)?.wrapping_add(init_seed_shift);
-    let mut rnd = R::from_state_ref(seed_state, init);
+    let mut rnd = R::from_state_ref(seed_state, seed);
     let mut arr_builder = Float64Builder::with_capacity(num_rows);
diff --git a/native/spark-expr/src/nondetermenistic_funcs/rand.rs b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
index 22a4093ce6..e548f78909 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/rand.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
@@ -21,7 +21,6 @@ use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
 use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Schema};
 use datafusion::common::Result;
-use datafusion::error::DataFusionError;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
 use std::any::Any;
@@ -87,16 +86,14 @@ impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
 
 #[derive(Debug)]
 pub struct RandExpr {
-    seed: Arc<dyn PhysicalExpr>,
-    init_seed_shift: i32,
+    seed: i64,
     state_holder: Arc<Mutex<Option<i64>>>,
 }
 
 impl RandExpr {
-    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
+    pub fn new(seed: i64) -> Self {
         Self {
             seed,
-            init_seed_shift,
             state_holder: Arc::new(Mutex::new(None::<i64>)),
         }
     }
@@ -110,7 +107,7 @@ impl Display for RandExpr {
 
 impl PartialEq for RandExpr {
     fn eq(&self, other: &Self) -> bool {
-        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
+        self.seed.eq(&other.seed)
     }
 }
 
@@ -136,21 +133,15 @@ impl PhysicalExpr for RandExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
-        match self.seed.evaluate(batch)? {
-            ColumnarValue::Scalar(seed) => evaluate_batch_for_rand::<XorShiftRandom, i64>(
-                &self.state_holder,
-                seed,
-                self.init_seed_shift as i64,
-                batch.num_rows(),
-            ),
-            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
-                "Only literal seeds are supported for {self}"
-            ))),
-        }
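+        // The partition shift is already folded into `seed` by the planner
+        // (see planner.rs above), so no child expression is evaluated here.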
+        evaluate_batch_for_rand::<XorShiftRandom, i64>(
+            &self.state_holder,
+            self.seed,
+            batch.num_rows(),
+        )
     }
 
     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.seed]
+        vec![]
     }
 
     fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
@@ -159,26 +150,22 @@ impl PhysicalExpr for RandExpr {
 
     fn with_new_children(
         self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(RandExpr::new(
-            Arc::clone(&children[0]),
-            self.init_seed_shift,
-        )))
+        Ok(Arc::new(RandExpr::new(self.seed)))
     }
 }
 
-pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
+pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
+    Arc::new(RandExpr::new(seed))
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::{Array, BooleanArray, Float64Array, Int64Array};
+    use arrow::array::{Array, Float64Array, Int64Array};
     use arrow::{array::StringArray, compute::concat, datatypes::*};
     use datafusion::common::cast::as_float64_array;
-    use datafusion::physical_expr::expressions::lit;
 
     const SPARK_SEED_42_FIRST_5: [f64; 5] = [
         0.619189370225301,
@@ -193,7 +180,7 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
-        let rand_expr = rand(lit(42), 0)?;
+        let rand_expr = rand(42);
         let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
         let result = as_float64_array(&result)?;
         let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
@@ -207,7 +194,7 @@ mod tests {
         let first_batch_data = Int64Array::from(vec![Some(42), None]);
         let second_batch_schema = first_batch_schema.clone();
         let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
-        let rand_expr = rand(lit(42), 0)?;
+        let rand_expr = rand(42);
         let first_batch = RecordBatch::try_new(
             Arc::new(first_batch_schema),
             vec![Arc::new(first_batch_data)],
@@ -232,23 +219,4 @@ mod tests {
         assert_eq!(final_result, expected);
         Ok(())
     }
-
-    #[test]
-    fn test_overflow_shift_seed() -> Result<()> {
-        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
-        let data = BooleanArray::from(vec![Some(true), Some(false)]);
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
-        let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
-        let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
-        let first_expr_result = max_seed_and_shift_expr
-            .evaluate(&batch)?
-            .into_array(batch.num_rows())?;
-        let first_expr_result = as_float64_array(&first_expr_result)?;
-        let second_expr_result = min_seed_no_shift_expr
-            .evaluate(&batch)?
-            .into_array(batch.num_rows())?;
-        let second_expr_result = as_float64_array(&second_expr_result)?;
-        assert_eq!(first_expr_result, second_expr_result);
-        Ok(())
-    }
 }
diff --git a/native/spark-expr/src/nondetermenistic_funcs/randn.rs b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
index 353892bbd2..e1455b68e8 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/randn.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
@@ -20,7 +20,6 @@ use crate::nondetermenistic_funcs::rand::XorShiftRandom;
 use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
 use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Schema};
-use datafusion::common::DataFusionError;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
 use std::any::Any;
@@ -97,16 +96,14 @@ impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForG
 
 #[derive(Debug, Clone)]
 pub struct RandnExpr {
-    seed: Arc<dyn PhysicalExpr>,
-    init_seed_shift: i32,
+    seed: i64,
     state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
 }
 
 impl RandnExpr {
-    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
+    pub fn new(seed: i64) -> Self {
         Self {
             seed,
-            init_seed_shift,
             state_holder: Arc::new(Mutex::new(None)),
         }
     }
@@ -121,7 +118,7 @@ impl Display for RandnExpr {
 
 impl PartialEq for RandnExpr {
     fn eq(&self, other: &Self) -> bool {
-        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
+        self.seed.eq(&other.seed)
     }
 }
 
@@ -147,23 +144,15 @@ impl PhysicalExpr for RandnExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
-        match self.seed.evaluate(batch)? {
-            ColumnarValue::Scalar(seed) => {
-                evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
-                    &self.state_holder,
-                    seed,
-                    self.init_seed_shift as i64,
-                    batch.num_rows(),
-                )
-            }
-            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
-                "Only literal seeds are supported for {self}"
-            ))),
-        }
+        evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
+            &self.state_holder,
+            self.seed,
+            batch.num_rows(),
+        )
     }
 
     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.seed]
+        vec![]
     }
 
     fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
@@ -172,20 +161,14 @@ impl PhysicalExpr for RandnExpr {
 
     fn with_new_children(
         self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(RandnExpr::new(
-            Arc::clone(&children[0]),
-            self.init_seed_shift,
-        )))
+        Ok(Arc::new(RandnExpr::new(self.seed)))
     }
 }
 
-pub fn randn(
-    seed: Arc<dyn PhysicalExpr>,
-    init_seed_shift: i32,
-) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(RandnExpr::new(seed, init_seed_shift)))
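+// Convenience constructor kept for the unit tests below; the planner builds
+// RandnExpr directly from the proto seed.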
+pub fn randn(seed: i64) -> Arc<dyn PhysicalExpr> {
+    Arc::new(RandnExpr::new(seed))
 }
 
 #[cfg(test)]
@@ -194,7 +177,6 @@ mod tests {
     use arrow::array::{Array, Float64Array, Int64Array};
     use arrow::{array::StringArray, compute::concat, datatypes::*};
     use datafusion::common::cast::as_float64_array;
-    use datafusion::physical_expr::expressions::lit;
 
     const PRECISION_TOLERANCE: f64 = 1e-6;
@@ -211,7 +193,7 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
-        let randn_expr = randn(lit(42), 0)?;
+        let randn_expr = randn(42);
         let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
         let result = as_float64_array(&result)?;
         let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
@@ -225,7 +207,7 @@ mod tests {
         let first_batch_data = Int64Array::from(vec![Some(24), None, None]);
         let second_batch_schema =
first_batch_schema.clone(); let second_batch_data = Int64Array::from(vec![None, Some(22)]); - let randn_expr = randn(lit(42), 0)?; + let randn_expr = randn(42); let first_batch = RecordBatch::try_new( Arc::new(first_batch_schema), vec![Arc::new(first_batch_data)], diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index ffc7e14416..4e580f6fbc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1860,19 +1860,27 @@ object QueryPlanSerde extends Logging with CometExprShim { case _: ArrayExcept => convert(CometArrayExcept) case Rand(child, _) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setRand(unaryExpr)) + val seed = child match { + case Literal(seed: Long, _) => Some(seed) + case Literal(null, _) => Some(0L) + case _ => None + } + seed.map(seed => + ExprOuterClass.Expr + .newBuilder() + .setRand(ExprOuterClass.Rand.newBuilder().setSeed(seed)) + .build()) case Randn(child, _) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setRandn(unaryExpr)) + val seed = child match { + case Literal(seed: Long, _) => Some(seed) + case Literal(null, _) => Some(0L) + case _ => None + } + seed.map(seed => + ExprOuterClass.Expr + .newBuilder() + .setRandn(ExprOuterClass.Randn.newBuilder().setSeed(seed)) + .build()) case mk: MapKeys => val childExpr = exprToProtoInternal(mk.child, inputs, binding) scalarFunctionExprToProto("map_keys", childExpr) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index aa46a316d5..3a950d17ef 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2745,43 +2745,54 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("rand expression with random parameters") { + private def testOnShuffledRangeWithRandomParameters(testLogic: DataFrame => Unit): Unit = { val partitionsNumber = Random.nextInt(10) + 1 val rowsNumber = Random.nextInt(500) - val seed = Random.nextLong() // use this value to have both single-batch and multi-batch partitions val cometBatchSize = math.max(1, math.floor(rowsNumber.toDouble / partitionsNumber).toInt) withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) { withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df => - val dfWithRandParameters = df.repartition(partitionsNumber).withColumn("rnd", rand(seed)) - checkSparkAnswerAndOperator(dfWithRandParameters) - val dfWithOverflowSeed = - df.repartition(partitionsNumber).withColumn("rnd", rand(Long.MaxValue)) - checkSparkAnswerAndOperator(dfWithOverflowSeed) - val dfWithNullSeed = - df.repartition(partitionsNumber).selectExpr("_1", "rand(null) as rnd") - checkSparkAnswerAndOperator(dfWithNullSeed) + testLogic(df.repartition(partitionsNumber)) } } } + test("rand expression with random parameters") { + testOnShuffledRangeWithRandomParameters { df => + val seed = Random.nextLong() + val dfWithRandParameters = df.withColumn("rnd", rand(seed)) + checkSparkAnswerAndOperator(dfWithRandParameters) + val dfWithOverflowSeed = df.withColumn("rnd", rand(Long.MaxValue)) + checkSparkAnswerAndOperator(dfWithOverflowSeed) + val dfWithNullSeed = df.selectExpr("_1", "rand(null) as rnd") + 
checkSparkAnswerAndOperator(dfWithNullSeed)
+    }
+  }
+
   test("randn expression with random parameters") {
-    val partitionsNumber = Random.nextInt(10) + 1
-    val rowsNumber = Random.nextInt(500)
-    val seed = Random.nextLong()
-    val cometBatchSize = math.max(1, math.floor(rowsNumber.toDouble / partitionsNumber).toInt)
-    withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) {
-      withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df =>
-        val dfWithRandParameters =
-          df.repartition(partitionsNumber).withColumn("randn", randn(seed))
-        checkSparkAnswerAndOperatorWithTol(dfWithRandParameters)
-        val dfWithOverflowSeed =
-          df.repartition(partitionsNumber).withColumn("randn", randn(Long.MaxValue))
-        checkSparkAnswerAndOperatorWithTol(dfWithOverflowSeed)
-        val dfWithNullSeed =
-          df.repartition(partitionsNumber).selectExpr("_1", "randn(null) as randn")
-        checkSparkAnswerAndOperatorWithTol(dfWithNullSeed)
-      }
+    testOnShuffledRangeWithRandomParameters { df =>
+      val seed = Random.nextLong()
+      val dfWithRandParameters = df.withColumn("randn", randn(seed))
+      checkSparkAnswerAndOperatorWithTol(dfWithRandParameters)
+      val dfWithOverflowSeed = df.withColumn("randn", randn(Long.MaxValue))
+      checkSparkAnswerAndOperatorWithTol(dfWithOverflowSeed)
+      val dfWithNullSeed = df.selectExpr("_1", "randn(null) as randn")
+      checkSparkAnswerAndOperatorWithTol(dfWithNullSeed)
+    }
+  }
+
+  test("multiple nondeterministic expressions with shuffle") {
+    testOnShuffledRangeWithRandomParameters { df =>
+      val seed1 = Random.nextLong()
+      val seed2 = Random.nextLong()
+      val complexRandDf = df
+        .withColumn("rand1", rand(seed1))
+        .withColumn("randn1", randn(seed1))
+        .repartition(2, col("_1"))
+        .sortWithinPartitions("_1")
+        .withColumn("rand2", rand(seed2))
+        .withColumn("randn2", randn(seed2))
+      checkSparkAnswerAndOperator(complexRandDf)
     }
   }
 
From 1fd08b01917ca4443f4ff97407ba11dc9cb8b4e6 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Thu, 17 Jul 2025 12:33:27 +0400
Subject: [PATCH 5/7] removed accidentally added case

---
 .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4e580f6fbc..5ac7b6da03 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1881,16 +1881,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .newBuilder()
             .setRandn(ExprOuterClass.Randn.newBuilder().setSeed(seed))
             .build())
-      case mk: MapKeys =>
-        val childExpr = exprToProtoInternal(mk.child, inputs, binding)
-        scalarFunctionExprToProto("map_keys", childExpr)
-      case mv: MapValues =>
-        val childExpr = exprToProtoInternal(mv.child, inputs, binding)
-        scalarFunctionExprToProto("map_values", childExpr)
-      case gmv: GetMapValue =>
-        val mapExpr = exprToProtoInternal(gmv.child, inputs, binding)
-        val keyExpr = exprToProtoInternal(gmv.key, inputs, binding)
-        scalarFunctionExprToProto("map_extract", mapExpr, keyExpr)
       case expr =>
         QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
           case Some(handler) => convert(handler)

From 754d0c89d768979c6eef68f79d26814d71162d50 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Fri, 18 Jul 2025 09:58:27 +0400
Subject: [PATCH 6/7] added a missing tolerance to the new test

---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff
--git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 3a950d17ef..a09f337f8b 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2792,7 +2792,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { .sortWithinPartitions("_1") .withColumn("rand2", rand(seed2)) .withColumn("randn2", randn(seed2)) - checkSparkAnswerAndOperator(complexRandDf) + checkSparkAnswerAndOperatorWithTol(complexRandDf) } } From 3a5ebabddf0300635220b96b8193e20d01db8537 Mon Sep 17 00:00:00 2001 From: Artem Kupchinskiy Date: Fri, 18 Jul 2025 15:03:32 +0400 Subject: [PATCH 7/7] review fix: simplifying serde --- native/proto/src/proto/expr.proto | 7 +------ .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 252bf60488..9f31beffdd 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -81,7 +81,7 @@ message Expr { MathExpr integral_divide = 59; ToPrettyString to_pretty_string = 60; Rand rand = 61; - Randn randn = 62; + Rand randn = 62; } } @@ -420,11 +420,6 @@ message Rand { int64 seed = 1; } -message Randn { - int64 seed = 1; -} - - message DataType { enum DataTypeId { BOOL = 0; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5ac7b6da03..4e5631ed2c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1879,7 +1879,7 @@ object QueryPlanSerde extends Logging with CometExprShim { seed.map(seed => ExprOuterClass.Expr .newBuilder() - .setRandn(ExprOuterClass.Randn.newBuilder().setSeed(seed)) + .setRandn(ExprOuterClass.Rand.newBuilder().setSeed(seed)) .build()) case expr => QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {