@@ -137,50 +137,133 @@ impl AggregateExpr for Sum {
137137 use arrow:: datatypes:: ArrowPrimitiveType ;
138138
139139 macro_rules! make_accumulator {
140- ( $T: ty, $U: ty) => { Box :: new( PrimitiveGroupsAccumulator :: <
141- $T,
142- $U,
143- _,
144- _,
145- >:: new( & <$T as ArrowPrimitiveType >:: DATA_TYPE , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$U as ArrowPrimitiveType >:: Native | {
146- * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
147- } , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$T as ArrowPrimitiveType >:: Native | { * x = * x + y; } ) ) } ;
140+ ( $T: ty, $U: ty) => {
141+ Box :: new( PrimitiveGroupsAccumulator :: <$T, $U, _, _>:: new(
142+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
143+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
144+ y: <$U as ArrowPrimitiveType >:: Native | {
145+ * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
146+ } ,
147+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
148+ y: <$T as ArrowPrimitiveType >:: Native | {
149+ * x = * x + y;
150+ } ,
151+ ) )
152+ } ;
148153 }
149154
150155 // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
151156 // the current datafusion Sum accumulator implementation using native +. (That native +
152157 // specifically is the one in the expressions *x = *x + ... above.)
153158 Ok ( Some ( match ( & self . data_type , & self . input_data_type ) {
154- ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int64Type ) ,
155- ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int32Type ) ,
156- ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int16Type ) ,
157- ( DataType :: Int64 , DataType :: Int8 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type ) ,
158-
159- ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! ( arrow:: datatypes:: Int96Type , arrow:: datatypes:: Int96Type ) ,
160-
161- ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type , arrow:: datatypes:: Int64Decimal0Type ) ,
162- ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type , arrow:: datatypes:: Int64Decimal1Type ) ,
163- ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type , arrow:: datatypes:: Int64Decimal2Type ) ,
164- ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type , arrow:: datatypes:: Int64Decimal3Type ) ,
165- ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type , arrow:: datatypes:: Int64Decimal4Type ) ,
166- ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type , arrow:: datatypes:: Int64Decimal5Type ) ,
167- ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type , arrow:: datatypes:: Int64Decimal10Type ) ,
168-
169- ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type , arrow:: datatypes:: Int96Decimal0Type ) ,
170- ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type , arrow:: datatypes:: Int96Decimal1Type ) ,
171- ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type , arrow:: datatypes:: Int96Decimal2Type ) ,
172- ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type , arrow:: datatypes:: Int96Decimal3Type ) ,
173- ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type , arrow:: datatypes:: Int96Decimal4Type ) ,
174- ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type , arrow:: datatypes:: Int96Decimal5Type ) ,
175- ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type , arrow:: datatypes:: Int96Decimal10Type ) ,
176-
177- ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt64Type ) ,
178- ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt32Type ) ,
179- ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt16Type ) ,
180- ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt8Type ) ,
181-
182- ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! ( arrow:: datatypes:: Float32Type , arrow:: datatypes:: Float32Type ) ,
183- ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! ( arrow:: datatypes:: Float64Type , arrow:: datatypes:: Float64Type ) ,
159+ ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! (
160+ arrow:: datatypes:: Int64Type ,
161+ arrow:: datatypes:: Int64Type
162+ ) ,
163+ ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! (
164+ arrow:: datatypes:: Int64Type ,
165+ arrow:: datatypes:: Int32Type
166+ ) ,
167+ ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! (
168+ arrow:: datatypes:: Int64Type ,
169+ arrow:: datatypes:: Int16Type
170+ ) ,
171+ ( DataType :: Int64 , DataType :: Int8 ) => {
172+ make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type )
173+ }
174+
175+ ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! (
176+ arrow:: datatypes:: Int96Type ,
177+ arrow:: datatypes:: Int96Type
178+ ) ,
179+
180+ ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! (
181+ arrow:: datatypes:: Int64Decimal0Type ,
182+ arrow:: datatypes:: Int64Decimal0Type
183+ ) ,
184+ ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! (
185+ arrow:: datatypes:: Int64Decimal1Type ,
186+ arrow:: datatypes:: Int64Decimal1Type
187+ ) ,
188+ ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! (
189+ arrow:: datatypes:: Int64Decimal2Type ,
190+ arrow:: datatypes:: Int64Decimal2Type
191+ ) ,
192+ ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! (
193+ arrow:: datatypes:: Int64Decimal3Type ,
194+ arrow:: datatypes:: Int64Decimal3Type
195+ ) ,
196+ ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! (
197+ arrow:: datatypes:: Int64Decimal4Type ,
198+ arrow:: datatypes:: Int64Decimal4Type
199+ ) ,
200+ ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! (
201+ arrow:: datatypes:: Int64Decimal5Type ,
202+ arrow:: datatypes:: Int64Decimal5Type
203+ ) ,
204+ ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => {
205+ make_accumulator ! (
206+ arrow:: datatypes:: Int64Decimal10Type ,
207+ arrow:: datatypes:: Int64Decimal10Type
208+ )
209+ }
210+
211+ ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! (
212+ arrow:: datatypes:: Int96Decimal0Type ,
213+ arrow:: datatypes:: Int96Decimal0Type
214+ ) ,
215+ ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! (
216+ arrow:: datatypes:: Int96Decimal1Type ,
217+ arrow:: datatypes:: Int96Decimal1Type
218+ ) ,
219+ ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! (
220+ arrow:: datatypes:: Int96Decimal2Type ,
221+ arrow:: datatypes:: Int96Decimal2Type
222+ ) ,
223+ ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! (
224+ arrow:: datatypes:: Int96Decimal3Type ,
225+ arrow:: datatypes:: Int96Decimal3Type
226+ ) ,
227+ ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! (
228+ arrow:: datatypes:: Int96Decimal4Type ,
229+ arrow:: datatypes:: Int96Decimal4Type
230+ ) ,
231+ ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! (
232+ arrow:: datatypes:: Int96Decimal5Type ,
233+ arrow:: datatypes:: Int96Decimal5Type
234+ ) ,
235+ ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => {
236+ make_accumulator ! (
237+ arrow:: datatypes:: Int96Decimal10Type ,
238+ arrow:: datatypes:: Int96Decimal10Type
239+ )
240+ }
241+
242+ ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! (
243+ arrow:: datatypes:: UInt64Type ,
244+ arrow:: datatypes:: UInt64Type
245+ ) ,
246+ ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! (
247+ arrow:: datatypes:: UInt64Type ,
248+ arrow:: datatypes:: UInt32Type
249+ ) ,
250+ ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! (
251+ arrow:: datatypes:: UInt64Type ,
252+ arrow:: datatypes:: UInt16Type
253+ ) ,
254+ ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! (
255+ arrow:: datatypes:: UInt64Type ,
256+ arrow:: datatypes:: UInt8Type
257+ ) ,
258+
259+ ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! (
260+ arrow:: datatypes:: Float32Type ,
261+ arrow:: datatypes:: Float32Type
262+ ) ,
263+ ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! (
264+ arrow:: datatypes:: Float64Type ,
265+ arrow:: datatypes:: Float64Type
266+ ) ,
184267
185268 _ => {
186269 // This case should never be reached because we've handled all sum_return_type
@@ -479,9 +562,11 @@ mod tests {
479562 // generic_test_op!.
480563 struct SumTestStandin ;
481564 impl SumTestStandin {
482- fn new ( expr : Arc < dyn PhysicalExpr > ,
483- name : impl Into < String > ,
484- data_type : DataType ) -> Sum {
565+ fn new (
566+ expr : Arc < dyn PhysicalExpr > ,
567+ name : impl Into < String > ,
568+ data_type : DataType ,
569+ ) -> Sum {
485570 Sum :: new ( expr, name, data_type. clone ( ) , & data_type)
486571 }
487572 }
0 commit comments