@@ -24,6 +24,7 @@ use std::sync::Arc;
24
24
use crate :: error:: { DataFusionError , Result } ;
25
25
use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
26
26
use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27
+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
27
28
use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
28
29
use crate :: scalar:: ScalarValue ;
29
30
use arrow:: compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
49
50
pub struct Sum {
50
51
name : String ,
51
52
data_type : DataType ,
53
+ input_data_type : DataType ,
52
54
expr : Arc < dyn PhysicalExpr > ,
53
55
nullable : bool ,
54
56
}
@@ -80,11 +82,16 @@ impl Sum {
80
82
expr : Arc < dyn PhysicalExpr > ,
81
83
name : impl Into < String > ,
82
84
data_type : DataType ,
85
+ input_data_type : & DataType ,
83
86
) -> Self {
87
+ // Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't
88
+ // strictly need both type params. We keep all four params anyway so that this signature
89
+ // deliberately differs from other accumulators and from any code that assumes 3 params,
+ // such as the generic_test_op macro.
84
90
Self {
85
91
name : name. into ( ) ,
86
92
expr,
87
93
data_type,
94
+ input_data_type : input_data_type. clone ( ) ,
88
95
nullable : true ,
89
96
}
90
97
}
@@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
127
134
fn create_groups_accumulator (
128
135
& self ,
129
136
) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
130
- let data_type = self . data_type . clone ( ) ;
131
- Ok ( Some ( Box :: new (
132
- GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new ( move || {
133
- SumAccumulator :: try_new ( & data_type)
134
- } ) ,
135
- ) ) )
137
+ use arrow:: datatypes:: ArrowPrimitiveType ;
138
+
139
+ macro_rules! make_accumulator {
140
+ ( $T: ty, $U: ty) => {
141
+ Box :: new( PrimitiveGroupsAccumulator :: <$T, $U, _, _>:: new(
142
+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
143
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
144
+ y: <$U as ArrowPrimitiveType >:: Native | {
145
+ * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
146
+ } ,
147
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
148
+ y: <$T as ArrowPrimitiveType >:: Native | {
149
+ * x = * x + y;
150
+ } ,
151
+ ) )
152
+ } ;
153
+ }
154
+
155
+ // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
156
+ // the current datafusion Sum accumulator implementation using native +. (That native +
157
+ // specifically is the one in the expressions *x = *x + ... above.)
158
+ Ok ( Some ( match ( & self . data_type , & self . input_data_type ) {
159
+ ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! (
160
+ arrow:: datatypes:: Int64Type ,
161
+ arrow:: datatypes:: Int64Type
162
+ ) ,
163
+ ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! (
164
+ arrow:: datatypes:: Int64Type ,
165
+ arrow:: datatypes:: Int32Type
166
+ ) ,
167
+ ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! (
168
+ arrow:: datatypes:: Int64Type ,
169
+ arrow:: datatypes:: Int16Type
170
+ ) ,
171
+ ( DataType :: Int64 , DataType :: Int8 ) => {
172
+ make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type )
173
+ }
174
+
175
+ ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! (
176
+ arrow:: datatypes:: Int96Type ,
177
+ arrow:: datatypes:: Int96Type
178
+ ) ,
179
+
180
+ ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! (
181
+ arrow:: datatypes:: Int64Decimal0Type ,
182
+ arrow:: datatypes:: Int64Decimal0Type
183
+ ) ,
184
+ ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! (
185
+ arrow:: datatypes:: Int64Decimal1Type ,
186
+ arrow:: datatypes:: Int64Decimal1Type
187
+ ) ,
188
+ ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! (
189
+ arrow:: datatypes:: Int64Decimal2Type ,
190
+ arrow:: datatypes:: Int64Decimal2Type
191
+ ) ,
192
+ ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! (
193
+ arrow:: datatypes:: Int64Decimal3Type ,
194
+ arrow:: datatypes:: Int64Decimal3Type
195
+ ) ,
196
+ ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! (
197
+ arrow:: datatypes:: Int64Decimal4Type ,
198
+ arrow:: datatypes:: Int64Decimal4Type
199
+ ) ,
200
+ ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! (
201
+ arrow:: datatypes:: Int64Decimal5Type ,
202
+ arrow:: datatypes:: Int64Decimal5Type
203
+ ) ,
204
+ ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => {
205
+ make_accumulator ! (
206
+ arrow:: datatypes:: Int64Decimal10Type ,
207
+ arrow:: datatypes:: Int64Decimal10Type
208
+ )
209
+ }
210
+
211
+ ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! (
212
+ arrow:: datatypes:: Int96Decimal0Type ,
213
+ arrow:: datatypes:: Int96Decimal0Type
214
+ ) ,
215
+ ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! (
216
+ arrow:: datatypes:: Int96Decimal1Type ,
217
+ arrow:: datatypes:: Int96Decimal1Type
218
+ ) ,
219
+ ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! (
220
+ arrow:: datatypes:: Int96Decimal2Type ,
221
+ arrow:: datatypes:: Int96Decimal2Type
222
+ ) ,
223
+ ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! (
224
+ arrow:: datatypes:: Int96Decimal3Type ,
225
+ arrow:: datatypes:: Int96Decimal3Type
226
+ ) ,
227
+ ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! (
228
+ arrow:: datatypes:: Int96Decimal4Type ,
229
+ arrow:: datatypes:: Int96Decimal4Type
230
+ ) ,
231
+ ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! (
232
+ arrow:: datatypes:: Int96Decimal5Type ,
233
+ arrow:: datatypes:: Int96Decimal5Type
234
+ ) ,
235
+ ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => {
236
+ make_accumulator ! (
237
+ arrow:: datatypes:: Int96Decimal10Type ,
238
+ arrow:: datatypes:: Int96Decimal10Type
239
+ )
240
+ }
241
+
242
+ ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! (
243
+ arrow:: datatypes:: UInt64Type ,
244
+ arrow:: datatypes:: UInt64Type
245
+ ) ,
246
+ ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! (
247
+ arrow:: datatypes:: UInt64Type ,
248
+ arrow:: datatypes:: UInt32Type
249
+ ) ,
250
+ ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! (
251
+ arrow:: datatypes:: UInt64Type ,
252
+ arrow:: datatypes:: UInt16Type
253
+ ) ,
254
+ ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! (
255
+ arrow:: datatypes:: UInt64Type ,
256
+ arrow:: datatypes:: UInt8Type
257
+ ) ,
258
+
259
+ ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! (
260
+ arrow:: datatypes:: Float32Type ,
261
+ arrow:: datatypes:: Float32Type
262
+ ) ,
263
+ ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! (
264
+ arrow:: datatypes:: Float64Type ,
265
+ arrow:: datatypes:: Float64Type
266
+ ) ,
267
+
268
+ _ => {
269
+ // This case should never be reached because we've handled all sum_return_type
270
+ // arg_type values. Nonetheless:
271
+ let data_type = self . data_type . clone ( ) ;
272
+
273
+ Box :: new ( GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new (
274
+ move || SumAccumulator :: try_new ( & data_type) ,
275
+ ) )
276
+ }
277
+ } ) )
136
278
}
137
279
138
280
fn name ( & self ) -> & str {
@@ -416,13 +558,27 @@ mod tests {
416
558
use arrow:: datatypes:: * ;
417
559
use arrow:: record_batch:: RecordBatch ;
418
560
561
+ // A wrapper to make Sum::new, which now has an input_type argument, work with
562
+ // generic_test_op!.
563
+ struct SumTestStandin ;
564
+ impl SumTestStandin {
565
+ fn new (
566
+ expr : Arc < dyn PhysicalExpr > ,
567
+ name : impl Into < String > ,
568
+ data_type : DataType ,
569
+ ) -> Sum {
570
+ Sum :: new ( expr, name, data_type. clone ( ) , & data_type)
571
+ }
572
+ }
573
+
419
574
#[ test]
420
575
fn sum_i32 ( ) -> Result < ( ) > {
421
576
let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
577
+
422
578
generic_test_op ! (
423
579
a,
424
580
DataType :: Int32 ,
425
- Sum ,
581
+ SumTestStandin ,
426
582
ScalarValue :: from( 15i64 ) ,
427
583
DataType :: Int64
428
584
)
@@ -440,7 +596,7 @@ mod tests {
440
596
generic_test_op ! (
441
597
a,
442
598
DataType :: Int32 ,
443
- Sum ,
599
+ SumTestStandin ,
444
600
ScalarValue :: from( 13i64 ) ,
445
601
DataType :: Int64
446
602
)
@@ -452,7 +608,7 @@ mod tests {
452
608
generic_test_op ! (
453
609
a,
454
610
DataType :: Int32 ,
455
- Sum ,
611
+ SumTestStandin ,
456
612
ScalarValue :: Int64 ( None ) ,
457
613
DataType :: Int64
458
614
)
@@ -465,7 +621,7 @@ mod tests {
465
621
generic_test_op ! (
466
622
a,
467
623
DataType :: UInt32 ,
468
- Sum ,
624
+ SumTestStandin ,
469
625
ScalarValue :: from( 15u64 ) ,
470
626
DataType :: UInt64
471
627
)
@@ -478,7 +634,7 @@ mod tests {
478
634
generic_test_op ! (
479
635
a,
480
636
DataType :: Float32 ,
481
- Sum ,
637
+ SumTestStandin ,
482
638
ScalarValue :: from( 15_f32 ) ,
483
639
DataType :: Float32
484
640
)
@@ -491,7 +647,7 @@ mod tests {
491
647
generic_test_op ! (
492
648
a,
493
649
DataType :: Float64 ,
494
- Sum ,
650
+ SumTestStandin ,
495
651
ScalarValue :: from( 15_f64 ) ,
496
652
DataType :: Float64
497
653
)
0 commit comments