Skip to content

Commit b256458

Browse files
authored
feat: randn expression support (#2010)
1 parent d0812d5 commit b256458

File tree

11 files changed

+462
-111
lines changed

11 files changed

+462
-111
lines changed

docs/spark_expressions_support.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@
351351
- [ ] input_file_name
352352
- [ ] monotonically_increasing_id
353353
- [ ] raise_error
354+
- [x] rand
355+
- [x] randn
354356
- [ ] spark_partition_id
355357
- [ ] typeof
356358
- [x] user

native/core/src/execution/planner.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ use datafusion_comet_proto::{
103103
use datafusion_comet_spark_expr::{
104104
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
105105
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
106-
RandExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
106+
RandExpr, RandnExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
107107
TimestampTruncExpr, ToJson, UnboundColumn, Variance,
108108
};
109109
use itertools::Itertools;
@@ -791,8 +791,12 @@ impl PhysicalPlanner {
791791
)))
792792
}
793793
ExprStruct::Rand(expr) => {
794-
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
795-
Ok(Arc::new(RandExpr::new(child, self.partition)))
794+
let seed = expr.seed.wrapping_add(self.partition.into());
795+
Ok(Arc::new(RandExpr::new(seed)))
796+
}
797+
ExprStruct::Randn(expr) => {
798+
let seed = expr.seed.wrapping_add(self.partition.into());
799+
Ok(Arc::new(RandnExpr::new(seed)))
796800
}
797801
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
798802
}

native/proto/src/proto/expr.proto

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ message Expr {
8080
ArrayInsert array_insert = 58;
8181
MathExpr integral_divide = 59;
8282
ToPrettyString to_pretty_string = 60;
83-
UnaryExpr rand = 61;
83+
Rand rand = 61;
84+
Rand randn = 62;
8485
}
8586
}
8687

@@ -415,6 +416,10 @@ message ArrayJoin {
415416
Expr null_replacement_expr = 3;
416417
}
417418

419+
message Rand {
420+
int64 seed = 1;
421+
}
422+
418423
message DataType {
419424
enum DataTypeId {
420425
BOOL = 0;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod rand_utils;
19+
20+
pub use rand_utils::evaluate_batch_for_rand;
21+
pub use rand_utils::StatefulSeedValueGenerator;
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Float64Array, Float64Builder};
19+
use datafusion::logical_expr::ColumnarValue;
20+
use std::ops::Deref;
21+
use std::sync::{Arc, Mutex};
22+
23+
pub fn evaluate_batch_for_rand<R, S>(
24+
state_holder: &Arc<Mutex<Option<S>>>,
25+
seed: i64,
26+
num_rows: usize,
27+
) -> datafusion::common::Result<ColumnarValue>
28+
where
29+
R: StatefulSeedValueGenerator<S, f64>,
30+
S: Copy,
31+
{
32+
let seed_state = state_holder.lock().unwrap();
33+
let mut rnd = R::from_state_ref(seed_state, seed);
34+
let mut arr_builder = Float64Builder::with_capacity(num_rows);
35+
std::iter::repeat_with(|| rnd.next_value())
36+
.take(num_rows)
37+
.for_each(|v| arr_builder.append_value(v));
38+
let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
39+
let mut seed_state = state_holder.lock().unwrap();
40+
seed_state.replace(rnd.get_current_state());
41+
Ok(ColumnarValue::Array(array_ref))
42+
}
43+
44+
pub trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
45+
fn from_init_seed(init_seed: i64) -> Self;
46+
47+
fn from_stored_state(stored_state: State) -> Self;
48+
49+
fn next_value(&mut self) -> Value;
50+
51+
fn get_current_state(&self) -> State;
52+
53+
fn from_state_ref(state: impl Deref<Target = Option<State>>, init_value: i64) -> Self {
54+
if state.is_none() {
55+
Self::from_init_seed(init_value)
56+
} else {
57+
Self::from_stored_state(state.unwrap())
58+
}
59+
}
60+
}

native/spark-expr/src/nondetermenistic_funcs/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod internal;
1819
pub mod rand;
20+
pub mod randn;
1921

2022
pub use rand::RandExpr;
23+
pub use randn::RandnExpr;

native/spark-expr/src/nondetermenistic_funcs/rand.rs

Lines changed: 39 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
// under the License.
1717

1818
use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
19-
use arrow::array::{Float64Array, Float64Builder, RecordBatch};
19+
20+
use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
21+
use arrow::array::RecordBatch;
2022
use arrow::datatypes::{DataType, Schema};
2123
use datafusion::common::Result;
22-
use datafusion::common::ScalarValue;
23-
use datafusion::error::DataFusionError;
2424
use datafusion::logical_expr::ColumnarValue;
2525
use datafusion::physical_expr::PhysicalExpr;
2626
use std::any::Any;
@@ -42,21 +42,11 @@ const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
4242
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;
4343

4444
#[derive(Debug, Clone)]
45-
struct XorShiftRandom {
46-
seed: i64,
45+
pub(crate) struct XorShiftRandom {
46+
pub(crate) seed: i64,
4747
}
4848

4949
impl XorShiftRandom {
50-
fn from_init_seed(init_seed: i64) -> Self {
51-
XorShiftRandom {
52-
seed: Self::init_seed(init_seed),
53-
}
54-
}
55-
56-
fn from_stored_seed(stored_seed: i64) -> Self {
57-
XorShiftRandom { seed: stored_seed }
58-
}
59-
6050
fn next(&mut self, bits: u8) -> i32 {
6151
let mut next_seed = self.seed ^ (self.seed << 21);
6252
next_seed ^= ((next_seed as u64) >> 35) as i64;
@@ -70,60 +60,43 @@ impl XorShiftRandom {
7060
let b = self.next(27) as i64;
7161
((a << 27) + b) as f64 * DOUBLE_UNIT
7262
}
63+
}
7364

74-
fn init_seed(init: i64) -> i64 {
75-
let bytes_repr = init.to_be_bytes();
65+
impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
66+
fn from_init_seed(init_seed: i64) -> Self {
67+
let bytes_repr = init_seed.to_be_bytes();
7668
let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
7769
let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
78-
((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
70+
let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
71+
XorShiftRandom { seed: init_seed }
72+
}
73+
74+
fn from_stored_state(stored_state: i64) -> Self {
75+
XorShiftRandom { seed: stored_state }
76+
}
77+
78+
fn next_value(&mut self) -> f64 {
79+
self.next_f64()
80+
}
81+
82+
fn get_current_state(&self) -> i64 {
83+
self.seed
7984
}
8085
}
8186

8287
#[derive(Debug)]
8388
pub struct RandExpr {
84-
seed: Arc<dyn PhysicalExpr>,
85-
init_seed_shift: i32,
89+
seed: i64,
8690
state_holder: Arc<Mutex<Option<i64>>>,
8791
}
8892

8993
impl RandExpr {
90-
pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
94+
pub fn new(seed: i64) -> Self {
9195
Self {
9296
seed,
93-
init_seed_shift,
9497
state_holder: Arc::new(Mutex::new(None::<i64>)),
9598
}
9699
}
97-
98-
fn extract_init_state(seed: ScalarValue) -> Result<i64> {
99-
if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
100-
Ok(seed_opt.unwrap_or(0))
101-
} else {
102-
Err(DataFusionError::Internal(
103-
"unexpected execution branch".to_string(),
104-
))
105-
}
106-
}
107-
fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
108-
let mut seed_state = self.state_holder.lock().unwrap();
109-
let mut rnd = if seed_state.is_none() {
110-
let init_seed = RandExpr::extract_init_state(seed)?;
111-
let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
112-
*seed_state = Some(init_seed);
113-
XorShiftRandom::from_init_seed(init_seed)
114-
} else {
115-
let stored_seed = seed_state.unwrap();
116-
XorShiftRandom::from_stored_seed(stored_seed)
117-
};
118-
119-
let mut arr_builder = Float64Builder::with_capacity(num_rows);
120-
std::iter::repeat_with(|| rnd.next_f64())
121-
.take(num_rows)
122-
.for_each(|v| arr_builder.append_value(v));
123-
let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
124-
*seed_state = Some(rnd.seed);
125-
Ok(ColumnarValue::Array(array_ref))
126-
}
127100
}
128101

129102
impl Display for RandExpr {
@@ -134,7 +107,7 @@ impl Display for RandExpr {
134107

135108
impl PartialEq for RandExpr {
136109
fn eq(&self, other: &Self) -> bool {
137-
self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
110+
self.seed.eq(&other.seed)
138111
}
139112
}
140113

@@ -160,16 +133,15 @@ impl PhysicalExpr for RandExpr {
160133
}
161134

162135
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
163-
match self.seed.evaluate(batch)? {
164-
ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
165-
ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
166-
"Only literal seeds are supported for {self}"
167-
))),
168-
}
136+
evaluate_batch_for_rand::<XorShiftRandom, i64>(
137+
&self.state_holder,
138+
self.seed,
139+
batch.num_rows(),
140+
)
169141
}
170142

171143
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
172-
vec![&self.seed]
144+
vec![]
173145
}
174146

175147
fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
@@ -178,26 +150,22 @@ impl PhysicalExpr for RandExpr {
178150

179151
fn with_new_children(
180152
self: Arc<Self>,
181-
children: Vec<Arc<dyn PhysicalExpr>>,
153+
_children: Vec<Arc<dyn PhysicalExpr>>,
182154
) -> Result<Arc<dyn PhysicalExpr>> {
183-
Ok(Arc::new(RandExpr::new(
184-
Arc::clone(&children[0]),
185-
self.init_seed_shift,
186-
)))
155+
Ok(Arc::new(RandExpr::new(self.seed)))
187156
}
188157
}
189158

190-
pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
191-
Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
159+
pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
160+
Arc::new(RandExpr::new(seed))
192161
}
193162

194163
#[cfg(test)]
195164
mod tests {
196165
use super::*;
197-
use arrow::array::{Array, BooleanArray, Int64Array};
166+
use arrow::array::{Array, Float64Array, Int64Array};
198167
use arrow::{array::StringArray, compute::concat, datatypes::*};
199168
use datafusion::common::cast::as_float64_array;
200-
use datafusion::physical_expr::expressions::lit;
201169

202170
const SPARK_SEED_42_FIRST_5: [f64; 5] = [
203171
0.619189370225301,
@@ -212,7 +180,7 @@ mod tests {
212180
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
213181
let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
214182
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
215-
let rand_expr = rand(lit(42), 0)?;
183+
let rand_expr = rand(42);
216184
let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
217185
let result = as_float64_array(&result)?;
218186
let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
@@ -226,7 +194,7 @@ mod tests {
226194
let first_batch_data = Int64Array::from(vec![Some(42), None]);
227195
let second_batch_schema = first_batch_schema.clone();
228196
let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
229-
let rand_expr = rand(lit(42), 0)?;
197+
let rand_expr = rand(42);
230198
let first_batch = RecordBatch::try_new(
231199
Arc::new(first_batch_schema),
232200
vec![Arc::new(first_batch_data)],
@@ -251,23 +219,4 @@ mod tests {
251219
assert_eq!(final_result, expected);
252220
Ok(())
253221
}
254-
255-
#[test]
256-
fn test_overflow_shift_seed() -> Result<()> {
257-
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
258-
let data = BooleanArray::from(vec![Some(true), Some(false)]);
259-
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
260-
let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
261-
let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
262-
let first_expr_result = max_seed_and_shift_expr
263-
.evaluate(&batch)?
264-
.into_array(batch.num_rows())?;
265-
let first_expr_result = as_float64_array(&first_expr_result)?;
266-
let second_expr_result = min_seed_no_shift_expr
267-
.evaluate(&batch)?
268-
.into_array(batch.num_rows())?;
269-
let second_expr_result = as_float64_array(&second_expr_result)?;
270-
assert_eq!(first_expr_result, second_expr_result);
271-
Ok(())
272-
}
273222
}

0 commit comments

Comments
 (0)