@@ -24,10 +24,11 @@ use std::sync::Arc;
24
24
use crate :: error:: { DataFusionError , Result } ;
25
25
use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
26
26
use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27
+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
27
28
use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
28
29
use crate :: scalar:: ScalarValue ;
29
30
use arrow:: compute;
30
- use arrow:: datatypes:: { DataType , TimeUnit } ;
31
+ use arrow:: datatypes:: { ArrowPrimitiveType , DataType , TimeUnit } ;
31
32
use arrow:: {
32
33
array:: {
33
34
ArrayRef , Float32Array , Float64Array , Int16Array , Int32Array , Int64Array ,
@@ -108,12 +109,92 @@ impl AggregateExpr for Max {
108
109
fn create_groups_accumulator (
109
110
& self ,
110
111
) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
111
- let data_type = self . data_type . clone ( ) ;
112
- Ok ( Some ( Box :: new (
113
- GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new ( move || {
114
- MaxAccumulator :: try_new ( & data_type)
115
- } ) ,
116
- ) ) )
112
+ use arrow:: datatypes:: ArrowPrimitiveType ;
113
+
114
+ macro_rules! make_max_accumulator {
115
+ ( $T: ty) => {
116
+ Box :: new(
117
+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
118
+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
119
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
120
+ y: <$T as ArrowPrimitiveType >:: Native | {
121
+ * x = ( * x) . max( y) ;
122
+ } ,
123
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
124
+ y: <$T as ArrowPrimitiveType >:: Native | {
125
+ * x = ( * x) . max( y) ;
126
+ } ,
127
+ )
128
+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MIN ) ,
129
+ )
130
+ } ;
131
+ }
132
+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
133
+ DataType :: Float64 => make_max_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
134
+ DataType :: Float32 => make_max_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
135
+ DataType :: Int64 => make_max_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
136
+ DataType :: Int96 => make_max_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
137
+ DataType :: Int64Decimal ( 0 ) => {
138
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
139
+ }
140
+ DataType :: Int64Decimal ( 1 ) => {
141
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
142
+ }
143
+ DataType :: Int64Decimal ( 2 ) => {
144
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
145
+ }
146
+ DataType :: Int64Decimal ( 3 ) => {
147
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
148
+ }
149
+ DataType :: Int64Decimal ( 4 ) => {
150
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
151
+ }
152
+ DataType :: Int64Decimal ( 5 ) => {
153
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
154
+ }
155
+ DataType :: Int64Decimal ( 10 ) => {
156
+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
157
+ }
158
+ DataType :: Int96Decimal ( 0 ) => {
159
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
160
+ }
161
+ DataType :: Int96Decimal ( 1 ) => {
162
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
163
+ }
164
+ DataType :: Int96Decimal ( 2 ) => {
165
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
166
+ }
167
+ DataType :: Int96Decimal ( 3 ) => {
168
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
169
+ }
170
+ DataType :: Int96Decimal ( 4 ) => {
171
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
172
+ }
173
+ DataType :: Int96Decimal ( 5 ) => {
174
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
175
+ }
176
+ DataType :: Int96Decimal ( 10 ) => {
177
+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
178
+ }
179
+ DataType :: Int32 => make_max_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
180
+ DataType :: Int16 => make_max_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
181
+ DataType :: Int8 => make_max_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
182
+ DataType :: UInt64 => make_max_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
183
+ DataType :: UInt32 => make_max_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
184
+ DataType :: UInt16 => make_max_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
185
+ DataType :: UInt8 => make_max_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
186
+ _ => {
187
+ // Not all types (strings) can use primitive accumulators. And strings use
188
+ // max_string as the $OP in typed_min_match_batch.
189
+
190
+ // Timestamps presently take this branch.
191
+ let data_type = self . data_type . clone ( ) ;
192
+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new (
193
+ move || MaxAccumulator :: try_new ( & data_type) ,
194
+ ) )
195
+ }
196
+ } ;
197
+ Ok ( Some ( acc) )
117
198
}
118
199
119
200
fn name ( & self ) -> & str {
@@ -547,12 +628,91 @@ impl AggregateExpr for Min {
547
628
fn create_groups_accumulator (
548
629
& self ,
549
630
) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
550
- let data_type = self . data_type . clone ( ) ;
551
- Ok ( Some ( Box :: new (
552
- GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new ( move || {
553
- MinAccumulator :: try_new ( & data_type)
554
- } ) ,
555
- ) ) )
631
+ macro_rules! make_min_accumulator {
632
+ ( $T: ty) => {
633
+ Box :: new(
634
+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
635
+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
636
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
637
+ y: <$T as ArrowPrimitiveType >:: Native | {
638
+ * x = ( * x) . min( y) ;
639
+ } ,
640
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
641
+ y: <$T as ArrowPrimitiveType >:: Native | {
642
+ * x = ( * x) . min( y) ;
643
+ } ,
644
+ )
645
+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MAX ) ,
646
+ )
647
+ } ;
648
+ }
649
+
650
+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
651
+ DataType :: Float64 => make_min_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
652
+ DataType :: Float32 => make_min_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
653
+ DataType :: Int64 => make_min_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
654
+ DataType :: Int96 => make_min_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
655
+ DataType :: Int64Decimal ( 0 ) => {
656
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
657
+ }
658
+ DataType :: Int64Decimal ( 1 ) => {
659
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
660
+ }
661
+ DataType :: Int64Decimal ( 2 ) => {
662
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
663
+ }
664
+ DataType :: Int64Decimal ( 3 ) => {
665
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
666
+ }
667
+ DataType :: Int64Decimal ( 4 ) => {
668
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
669
+ }
670
+ DataType :: Int64Decimal ( 5 ) => {
671
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
672
+ }
673
+ DataType :: Int64Decimal ( 10 ) => {
674
+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
675
+ }
676
+ DataType :: Int96Decimal ( 0 ) => {
677
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
678
+ }
679
+ DataType :: Int96Decimal ( 1 ) => {
680
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
681
+ }
682
+ DataType :: Int96Decimal ( 2 ) => {
683
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
684
+ }
685
+ DataType :: Int96Decimal ( 3 ) => {
686
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
687
+ }
688
+ DataType :: Int96Decimal ( 4 ) => {
689
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
690
+ }
691
+ DataType :: Int96Decimal ( 5 ) => {
692
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
693
+ }
694
+ DataType :: Int96Decimal ( 10 ) => {
695
+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
696
+ }
697
+ DataType :: Int32 => make_min_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
698
+ DataType :: Int16 => make_min_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
699
+ DataType :: Int8 => make_min_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
700
+ DataType :: UInt64 => make_min_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
701
+ DataType :: UInt32 => make_min_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
702
+ DataType :: UInt16 => make_min_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
703
+ DataType :: UInt8 => make_min_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
704
+ _ => {
705
+ // Not all types (strings) can use primitive accumulators. And strings use
706
+ // min_string as the $OP in typed_min_match_batch.
707
+
708
+ // Timestamps presently take this branch.
709
+ let data_type = self . data_type . clone ( ) ;
710
+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new (
711
+ move || MinAccumulator :: try_new ( & data_type) ,
712
+ ) )
713
+ }
714
+ } ;
715
+ Ok ( Some ( acc) )
556
716
}
557
717
558
718
fn name ( & self ) -> & str {
0 commit comments