@@ -137,50 +137,133 @@ impl AggregateExpr for Sum {
137
137
use arrow:: datatypes:: ArrowPrimitiveType ;
138
138
139
139
macro_rules! make_accumulator {
140
- ( $T: ty, $U: ty) => { Box :: new( PrimitiveGroupsAccumulator :: <
141
- $T,
142
- $U,
143
- _,
144
- _,
145
- >:: new( & <$T as ArrowPrimitiveType >:: DATA_TYPE , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$U as ArrowPrimitiveType >:: Native | {
146
- * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
147
- } , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$T as ArrowPrimitiveType >:: Native | { * x = * x + y; } ) ) } ;
140
+ ( $T: ty, $U: ty) => {
141
+ Box :: new( PrimitiveGroupsAccumulator :: <$T, $U, _, _>:: new(
142
+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
143
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
144
+ y: <$U as ArrowPrimitiveType >:: Native | {
145
+ * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
146
+ } ,
147
+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
148
+ y: <$T as ArrowPrimitiveType >:: Native | {
149
+ * x = * x + y;
150
+ } ,
151
+ ) )
152
+ } ;
148
153
}
149
154
150
155
// Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
151
156
// the current datafusion Sum accumulator implementation using native +. (That native +
152
157
// specifically is the one in the expressions *x = *x + ... above.)
153
158
Ok ( Some ( match ( & self . data_type , & self . input_data_type ) {
154
- ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int64Type ) ,
155
- ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int32Type ) ,
156
- ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int16Type ) ,
157
- ( DataType :: Int64 , DataType :: Int8 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type ) ,
158
-
159
- ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! ( arrow:: datatypes:: Int96Type , arrow:: datatypes:: Int96Type ) ,
160
-
161
- ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type , arrow:: datatypes:: Int64Decimal0Type ) ,
162
- ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type , arrow:: datatypes:: Int64Decimal1Type ) ,
163
- ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type , arrow:: datatypes:: Int64Decimal2Type ) ,
164
- ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type , arrow:: datatypes:: Int64Decimal3Type ) ,
165
- ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type , arrow:: datatypes:: Int64Decimal4Type ) ,
166
- ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type , arrow:: datatypes:: Int64Decimal5Type ) ,
167
- ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type , arrow:: datatypes:: Int64Decimal10Type ) ,
168
-
169
- ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type , arrow:: datatypes:: Int96Decimal0Type ) ,
170
- ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type , arrow:: datatypes:: Int96Decimal1Type ) ,
171
- ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type , arrow:: datatypes:: Int96Decimal2Type ) ,
172
- ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type , arrow:: datatypes:: Int96Decimal3Type ) ,
173
- ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type , arrow:: datatypes:: Int96Decimal4Type ) ,
174
- ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type , arrow:: datatypes:: Int96Decimal5Type ) ,
175
- ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type , arrow:: datatypes:: Int96Decimal10Type ) ,
176
-
177
- ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt64Type ) ,
178
- ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt32Type ) ,
179
- ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt16Type ) ,
180
- ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt8Type ) ,
181
-
182
- ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! ( arrow:: datatypes:: Float32Type , arrow:: datatypes:: Float32Type ) ,
183
- ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! ( arrow:: datatypes:: Float64Type , arrow:: datatypes:: Float64Type ) ,
159
+ ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! (
160
+ arrow:: datatypes:: Int64Type ,
161
+ arrow:: datatypes:: Int64Type
162
+ ) ,
163
+ ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! (
164
+ arrow:: datatypes:: Int64Type ,
165
+ arrow:: datatypes:: Int32Type
166
+ ) ,
167
+ ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! (
168
+ arrow:: datatypes:: Int64Type ,
169
+ arrow:: datatypes:: Int16Type
170
+ ) ,
171
+ ( DataType :: Int64 , DataType :: Int8 ) => {
172
+ make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type )
173
+ }
174
+
175
+ ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! (
176
+ arrow:: datatypes:: Int96Type ,
177
+ arrow:: datatypes:: Int96Type
178
+ ) ,
179
+
180
+ ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! (
181
+ arrow:: datatypes:: Int64Decimal0Type ,
182
+ arrow:: datatypes:: Int64Decimal0Type
183
+ ) ,
184
+ ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! (
185
+ arrow:: datatypes:: Int64Decimal1Type ,
186
+ arrow:: datatypes:: Int64Decimal1Type
187
+ ) ,
188
+ ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! (
189
+ arrow:: datatypes:: Int64Decimal2Type ,
190
+ arrow:: datatypes:: Int64Decimal2Type
191
+ ) ,
192
+ ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! (
193
+ arrow:: datatypes:: Int64Decimal3Type ,
194
+ arrow:: datatypes:: Int64Decimal3Type
195
+ ) ,
196
+ ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! (
197
+ arrow:: datatypes:: Int64Decimal4Type ,
198
+ arrow:: datatypes:: Int64Decimal4Type
199
+ ) ,
200
+ ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! (
201
+ arrow:: datatypes:: Int64Decimal5Type ,
202
+ arrow:: datatypes:: Int64Decimal5Type
203
+ ) ,
204
+ ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => {
205
+ make_accumulator ! (
206
+ arrow:: datatypes:: Int64Decimal10Type ,
207
+ arrow:: datatypes:: Int64Decimal10Type
208
+ )
209
+ }
210
+
211
+ ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! (
212
+ arrow:: datatypes:: Int96Decimal0Type ,
213
+ arrow:: datatypes:: Int96Decimal0Type
214
+ ) ,
215
+ ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! (
216
+ arrow:: datatypes:: Int96Decimal1Type ,
217
+ arrow:: datatypes:: Int96Decimal1Type
218
+ ) ,
219
+ ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! (
220
+ arrow:: datatypes:: Int96Decimal2Type ,
221
+ arrow:: datatypes:: Int96Decimal2Type
222
+ ) ,
223
+ ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! (
224
+ arrow:: datatypes:: Int96Decimal3Type ,
225
+ arrow:: datatypes:: Int96Decimal3Type
226
+ ) ,
227
+ ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! (
228
+ arrow:: datatypes:: Int96Decimal4Type ,
229
+ arrow:: datatypes:: Int96Decimal4Type
230
+ ) ,
231
+ ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! (
232
+ arrow:: datatypes:: Int96Decimal5Type ,
233
+ arrow:: datatypes:: Int96Decimal5Type
234
+ ) ,
235
+ ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => {
236
+ make_accumulator ! (
237
+ arrow:: datatypes:: Int96Decimal10Type ,
238
+ arrow:: datatypes:: Int96Decimal10Type
239
+ )
240
+ }
241
+
242
+ ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! (
243
+ arrow:: datatypes:: UInt64Type ,
244
+ arrow:: datatypes:: UInt64Type
245
+ ) ,
246
+ ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! (
247
+ arrow:: datatypes:: UInt64Type ,
248
+ arrow:: datatypes:: UInt32Type
249
+ ) ,
250
+ ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! (
251
+ arrow:: datatypes:: UInt64Type ,
252
+ arrow:: datatypes:: UInt16Type
253
+ ) ,
254
+ ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! (
255
+ arrow:: datatypes:: UInt64Type ,
256
+ arrow:: datatypes:: UInt8Type
257
+ ) ,
258
+
259
+ ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! (
260
+ arrow:: datatypes:: Float32Type ,
261
+ arrow:: datatypes:: Float32Type
262
+ ) ,
263
+ ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! (
264
+ arrow:: datatypes:: Float64Type ,
265
+ arrow:: datatypes:: Float64Type
266
+ ) ,
184
267
185
268
_ => {
186
269
// This case should never be reached because we've handled all sum_return_type
@@ -479,9 +562,11 @@ mod tests {
479
562
// generic_test_op!.
480
563
struct SumTestStandin ;
481
564
impl SumTestStandin {
482
- fn new ( expr : Arc < dyn PhysicalExpr > ,
483
- name : impl Into < String > ,
484
- data_type : DataType ) -> Sum {
565
+ fn new (
566
+ expr : Arc < dyn PhysicalExpr > ,
567
+ name : impl Into < String > ,
568
+ data_type : DataType ,
569
+ ) -> Sum {
485
570
Sum :: new ( expr, name, data_type. clone ( ) , & data_type)
486
571
}
487
572
}
0 commit comments