16
16
// under the License.
17
17
18
18
use crate :: hash_funcs:: murmur3:: spark_compatible_murmur3_hash;
19
- use arrow:: array:: { Float64Array , Float64Builder , RecordBatch } ;
19
+
20
+ use crate :: internal:: { evaluate_batch_for_rand, StatefulSeedValueGenerator } ;
21
+ use arrow:: array:: RecordBatch ;
20
22
use arrow:: datatypes:: { DataType , Schema } ;
21
23
use datafusion:: common:: Result ;
22
- use datafusion:: common:: ScalarValue ;
23
- use datafusion:: error:: DataFusionError ;
24
24
use datafusion:: logical_expr:: ColumnarValue ;
25
25
use datafusion:: physical_expr:: PhysicalExpr ;
26
26
use std:: any:: Any ;
@@ -42,21 +42,11 @@ const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
42
42
const SPARK_MURMUR_ARRAY_SEED : u32 = 0x3c074a61 ;
43
43
44
44
#[ derive( Debug , Clone ) ]
45
- struct XorShiftRandom {
46
- seed : i64 ,
45
+ pub ( crate ) struct XorShiftRandom {
46
+ pub ( crate ) seed : i64 ,
47
47
}
48
48
49
49
impl XorShiftRandom {
50
- fn from_init_seed ( init_seed : i64 ) -> Self {
51
- XorShiftRandom {
52
- seed : Self :: init_seed ( init_seed) ,
53
- }
54
- }
55
-
56
- fn from_stored_seed ( stored_seed : i64 ) -> Self {
57
- XorShiftRandom { seed : stored_seed }
58
- }
59
-
60
50
fn next ( & mut self , bits : u8 ) -> i32 {
61
51
let mut next_seed = self . seed ^ ( self . seed << 21 ) ;
62
52
next_seed ^= ( ( next_seed as u64 ) >> 35 ) as i64 ;
@@ -70,60 +60,43 @@ impl XorShiftRandom {
70
60
let b = self . next ( 27 ) as i64 ;
71
61
( ( a << 27 ) + b) as f64 * DOUBLE_UNIT
72
62
}
63
+ }
73
64
74
- fn init_seed ( init : i64 ) -> i64 {
75
- let bytes_repr = init. to_be_bytes ( ) ;
65
+ impl StatefulSeedValueGenerator < i64 , f64 > for XorShiftRandom {
66
+ fn from_init_seed ( init_seed : i64 ) -> Self {
67
+ let bytes_repr = init_seed. to_be_bytes ( ) ;
76
68
let low_bits = spark_compatible_murmur3_hash ( bytes_repr, SPARK_MURMUR_ARRAY_SEED ) ;
77
69
let high_bits = spark_compatible_murmur3_hash ( bytes_repr, low_bits) ;
78
- ( ( high_bits as i64 ) << 32 ) | ( low_bits as i64 & 0xFFFFFFFFi64 )
70
+ let init_seed = ( ( high_bits as i64 ) << 32 ) | ( low_bits as i64 & 0xFFFFFFFFi64 ) ;
71
+ XorShiftRandom { seed : init_seed }
72
+ }
73
+
74
+ fn from_stored_state ( stored_state : i64 ) -> Self {
75
+ XorShiftRandom { seed : stored_state }
76
+ }
77
+
78
+ fn next_value ( & mut self ) -> f64 {
79
+ self . next_f64 ( )
80
+ }
81
+
82
+ fn get_current_state ( & self ) -> i64 {
83
+ self . seed
79
84
}
80
85
}
81
86
82
87
#[ derive( Debug ) ]
83
88
pub struct RandExpr {
84
- seed : Arc < dyn PhysicalExpr > ,
85
- init_seed_shift : i32 ,
89
+ seed : i64 ,
86
90
state_holder : Arc < Mutex < Option < i64 > > > ,
87
91
}
88
92
89
93
impl RandExpr {
90
- pub fn new ( seed : Arc < dyn PhysicalExpr > , init_seed_shift : i32 ) -> Self {
94
+ pub fn new ( seed : i64 ) -> Self {
91
95
Self {
92
96
seed,
93
- init_seed_shift,
94
97
state_holder : Arc :: new ( Mutex :: new ( None :: < i64 > ) ) ,
95
98
}
96
99
}
97
-
98
- fn extract_init_state ( seed : ScalarValue ) -> Result < i64 > {
99
- if let ScalarValue :: Int64 ( seed_opt) = seed. cast_to ( & DataType :: Int64 ) ? {
100
- Ok ( seed_opt. unwrap_or ( 0 ) )
101
- } else {
102
- Err ( DataFusionError :: Internal (
103
- "unexpected execution branch" . to_string ( ) ,
104
- ) )
105
- }
106
- }
107
- fn evaluate_batch ( & self , seed : ScalarValue , num_rows : usize ) -> Result < ColumnarValue > {
108
- let mut seed_state = self . state_holder . lock ( ) . unwrap ( ) ;
109
- let mut rnd = if seed_state. is_none ( ) {
110
- let init_seed = RandExpr :: extract_init_state ( seed) ?;
111
- let init_seed = init_seed. wrapping_add ( self . init_seed_shift as i64 ) ;
112
- * seed_state = Some ( init_seed) ;
113
- XorShiftRandom :: from_init_seed ( init_seed)
114
- } else {
115
- let stored_seed = seed_state. unwrap ( ) ;
116
- XorShiftRandom :: from_stored_seed ( stored_seed)
117
- } ;
118
-
119
- let mut arr_builder = Float64Builder :: with_capacity ( num_rows) ;
120
- std:: iter:: repeat_with ( || rnd. next_f64 ( ) )
121
- . take ( num_rows)
122
- . for_each ( |v| arr_builder. append_value ( v) ) ;
123
- let array_ref = Arc :: new ( Float64Array :: from ( arr_builder. finish ( ) ) ) ;
124
- * seed_state = Some ( rnd. seed ) ;
125
- Ok ( ColumnarValue :: Array ( array_ref) )
126
- }
127
100
}
128
101
129
102
impl Display for RandExpr {
@@ -134,7 +107,7 @@ impl Display for RandExpr {
134
107
135
108
impl PartialEq for RandExpr {
136
109
fn eq ( & self , other : & Self ) -> bool {
137
- self . seed . eq ( & other. seed ) && self . init_seed_shift == other . init_seed_shift
110
+ self . seed . eq ( & other. seed )
138
111
}
139
112
}
140
113
@@ -160,16 +133,15 @@ impl PhysicalExpr for RandExpr {
160
133
}
161
134
162
135
fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
163
- match self . seed . evaluate ( batch) ? {
164
- ColumnarValue :: Scalar ( seed) => self . evaluate_batch ( seed, batch. num_rows ( ) ) ,
165
- ColumnarValue :: Array ( _arr) => Err ( DataFusionError :: NotImplemented ( format ! (
166
- "Only literal seeds are supported for {self}"
167
- ) ) ) ,
168
- }
136
+ evaluate_batch_for_rand :: < XorShiftRandom , i64 > (
137
+ & self . state_holder ,
138
+ self . seed ,
139
+ batch. num_rows ( ) ,
140
+ )
169
141
}
170
142
171
143
fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
172
- vec ! [ & self . seed ]
144
+ vec ! [ ]
173
145
}
174
146
175
147
fn fmt_sql ( & self , _: & mut Formatter < ' _ > ) -> std:: fmt:: Result {
@@ -178,26 +150,22 @@ impl PhysicalExpr for RandExpr {
178
150
179
151
fn with_new_children (
180
152
self : Arc < Self > ,
181
- children : Vec < Arc < dyn PhysicalExpr > > ,
153
+ _children : Vec < Arc < dyn PhysicalExpr > > ,
182
154
) -> Result < Arc < dyn PhysicalExpr > > {
183
- Ok ( Arc :: new ( RandExpr :: new (
184
- Arc :: clone ( & children[ 0 ] ) ,
185
- self . init_seed_shift ,
186
- ) ) )
155
+ Ok ( Arc :: new ( RandExpr :: new ( self . seed ) ) )
187
156
}
188
157
}
189
158
190
- pub fn rand ( seed : Arc < dyn PhysicalExpr > , init_seed_shift : i32 ) -> Result < Arc < dyn PhysicalExpr > > {
191
- Ok ( Arc :: new ( RandExpr :: new ( seed, init_seed_shift ) ) )
159
+ pub fn rand ( seed : i64 ) -> Arc < dyn PhysicalExpr > {
160
+ Arc :: new ( RandExpr :: new ( seed) )
192
161
}
193
162
194
163
#[ cfg( test) ]
195
164
mod tests {
196
165
use super :: * ;
197
- use arrow:: array:: { Array , BooleanArray , Int64Array } ;
166
+ use arrow:: array:: { Array , Float64Array , Int64Array } ;
198
167
use arrow:: { array:: StringArray , compute:: concat, datatypes:: * } ;
199
168
use datafusion:: common:: cast:: as_float64_array;
200
- use datafusion:: physical_expr:: expressions:: lit;
201
169
202
170
const SPARK_SEED_42_FIRST_5 : [ f64 ; 5 ] = [
203
171
0.619189370225301 ,
@@ -212,7 +180,7 @@ mod tests {
212
180
let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ;
213
181
let data = StringArray :: from ( vec ! [ Some ( "foo" ) , None , None , Some ( "bar" ) , None ] ) ;
214
182
let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( data) ] ) ?;
215
- let rand_expr = rand ( lit ( 42 ) , 0 ) ? ;
183
+ let rand_expr = rand ( 42 ) ;
216
184
let result = rand_expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?;
217
185
let result = as_float64_array ( & result) ?;
218
186
let expected = & Float64Array :: from ( Vec :: from ( SPARK_SEED_42_FIRST_5 ) ) ;
@@ -226,7 +194,7 @@ mod tests {
226
194
let first_batch_data = Int64Array :: from ( vec ! [ Some ( 42 ) , None ] ) ;
227
195
let second_batch_schema = first_batch_schema. clone ( ) ;
228
196
let second_batch_data = Int64Array :: from ( vec ! [ None , Some ( -42 ) , None ] ) ;
229
- let rand_expr = rand ( lit ( 42 ) , 0 ) ? ;
197
+ let rand_expr = rand ( 42 ) ;
230
198
let first_batch = RecordBatch :: try_new (
231
199
Arc :: new ( first_batch_schema) ,
232
200
vec ! [ Arc :: new( first_batch_data) ] ,
@@ -251,23 +219,4 @@ mod tests {
251
219
assert_eq ! ( final_result, expected) ;
252
220
Ok ( ( ) )
253
221
}
254
-
255
- #[ test]
256
- fn test_overflow_shift_seed ( ) -> Result < ( ) > {
257
- let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Boolean , false ) ] ) ;
258
- let data = BooleanArray :: from ( vec ! [ Some ( true ) , Some ( false ) ] ) ;
259
- let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( data) ] ) ?;
260
- let max_seed_and_shift_expr = rand ( lit ( i64:: MAX ) , 1 ) ?;
261
- let min_seed_no_shift_expr = rand ( lit ( i64:: MIN ) , 0 ) ?;
262
- let first_expr_result = max_seed_and_shift_expr
263
- . evaluate ( & batch) ?
264
- . into_array ( batch. num_rows ( ) ) ?;
265
- let first_expr_result = as_float64_array ( & first_expr_result) ?;
266
- let second_expr_result = min_seed_no_shift_expr
267
- . evaluate ( & batch) ?
268
- . into_array ( batch. num_rows ( ) ) ?;
269
- let second_expr_result = as_float64_array ( & second_expr_result) ?;
270
- assert_eq ! ( first_expr_result, second_expr_result) ;
271
- Ok ( ( ) )
272
- }
273
222
}
0 commit comments