Skip to content

Commit 07cdc38

Browse files
committed
perf: Use flattened_group_by_values to accumulate group keys for output
1 parent fa95bd3 commit 07cdc38

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,15 @@ pub(crate) fn group_aggregate_batch(
463463
// 1.2
464464
.or_insert_with(|| {
465465
batch_keys.append_value(&key).expect("must not fail");
466+
// Note that we still use plain String objects in GroupByScalar. Thus flattened_group_by_values isn't that great.
466467
let _ = create_group_by_values(&group_values, row, &mut group_by_values);
467-
let mut taken_values =
468-
smallvec![GroupByScalar::UInt32(0); group_values.len()];
469-
std::mem::swap(&mut taken_values, &mut group_by_values);
468+
accumulation_state.flattened_group_by_values.extend(
469+
group_by_values.iter_mut().map(|x| std::mem::replace(x, GroupByScalar::UInt32(0))));
470470
let group_index = accumulation_state.next_group_index;
471471
accumulation_state.next_group_index += 1;
472472
(
473473
key.clone(),
474474
AccumulationGroupState {
475-
group_by_values: taken_values,
476475
indices: smallvec![row as u32],
477476
group_index,
478477
},
@@ -884,7 +883,6 @@ pub type Accumulators = HashMap<KeyVec, AccumulationGroupState, RandomState>;
884883

885884
#[allow(missing_docs)]
886885
pub struct AccumulationGroupState {
887-
group_by_values: SmallVec<[GroupByScalar; 2]>,
888886
indices: SmallVec<[u32; 4]>,
889887
group_index: usize,
890888
}
@@ -893,6 +891,8 @@ pub struct AccumulationGroupState {
893891
#[derive(Default)]
894892
pub struct AccumulationState {
895893
accumulators: HashMap<KeyVec, AccumulationGroupState, RandomState>,
894+
// Of length accumulators.len() * N where N is the number of group by columns.
895+
flattened_group_by_values: Vec<GroupByScalar>,
896896
groups_accumulators: Vec<Box<dyn GroupsAccumulator>>,
897897
// For now, always equal to accumulators.len()
898898
next_group_index: usize,
@@ -905,6 +905,7 @@ impl AccumulationState {
905905
) -> AccumulationState {
906906
AccumulationState {
907907
accumulators: HashMap::new(),
908+
flattened_group_by_values: Vec::new(),
908909
groups_accumulators,
909910
next_group_index: 0,
910911
}
@@ -1174,12 +1175,13 @@ pub(crate) fn create_batch_from_map(
11741175
for (
11751176
_,
11761177
AccumulationGroupState {
1177-
group_by_values,
11781178
group_index,
11791179
..
11801180
},
11811181
) in &accumulation_state.accumulators
11821182
{
1183+
let group_by_values: &[GroupByScalar] = &accumulation_state.flattened_group_by_values[num_group_expr * group_index..num_group_expr * (group_index + 1)];
1184+
11831185
// 2 and 3.
11841186
write_group_result_row_with_groups_accumulator(
11851187
*mode,

0 commit comments

Comments
 (0)