From 9f387f25997b0466fc0062de6aa0f7c430ad6e45 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 21 Sep 2024 10:41:33 +0200 Subject: [PATCH 1/6] Fix grouping sets behavior when data contains nulls --- datafusion/core/src/physical_planner.rs | 16 +- .../src/combine_partial_final_agg.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 364 +++++++++++------- .../physical-plan/src/aggregates/row_hash.rs | 11 +- .../sqllogictest/test_files/aggregate.slt | 14 +- 5 files changed, 257 insertions(+), 150 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 78c70606bf688..f60622234a98d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner { physical_input_schema.clone(), )?); - // update group column indices based on partial aggregate plan evaluation - let final_group: Vec> = - initial_aggr.output_group_expr(); - let can_repartition = !groups.is_empty() && session_state.config().target_partitions() > 1 && session_state.config().repartition_aggregations(); @@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner { AggregateMode::Final }; - let final_grouping_set = PhysicalGroupBy::new_single( - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone())) - .collect(), - ); + let final_grouping_set = initial_aggr.group_expr().as_final(); Arc::new(AggregateExec::try_new( next_partition_mode, @@ -2053,7 +2043,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, 
true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_internal_exprs: 1 })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2080,7 +2070,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_internal_exprs: 1 })"#; assert_eq!(format!("{rollup:?}"), expected); diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 4e352e25b52c9..bc1642bf7952c 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { // Compare output expressions of the partial, and 
input expressions of the final operator. physical_exprs_equal( - &input_group_by.output_exprs(), + &input_group_by.output_exprs(&AggregateMode::Partial), &final_group_by.input_exprs(), ) && input_group_by.groups() == final_group_by.groups() && input_group_by.null_expr().len() == final_group_by.null_expr().len() diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 9466ff6dd4591..2bbfc3e41922b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,6 +36,8 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; @@ -109,6 +111,8 @@ impl AggregateMode { } } +const INTERNAL_GROUPING_ID: &str = "grouping_id"; + /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. @@ -138,6 +142,10 @@ pub struct PhysicalGroupBy { /// expression in null_expr. If `groups[i][j]` is true, then the /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. groups: Vec>, + // The number of internal expressions that are used to implement grouping + // sets. 
These outputs are removed from the final output and not in `expr`
+ fn grouping_id_type(&self) -> DataType { + if self.expr.len() <= 8 { + DataType::UInt8 + } else if self.expr.len() <= 16 { + DataType::UInt16 + } else if self.expr.len() <= 32 { + DataType::UInt32 + } else { + DataType::UInt64 + } + } + + /// Returns the fields that are used as the grouping keys. + fn group_fields(&self, input_schema: &Schema) -> Result> { + let mut fields = Vec::with_capacity(self.num_group_exprs()); + for ((expr, name), group_expr_nullable) in + self.expr.iter().zip(self.exprs_nullable().into_iter()) + { + fields.push(Field::new( + name, + expr.data_type(input_schema)?, + group_expr_nullable || expr.nullable(input_schema)?, + )) + .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()); + } + if !self.is_single() { + fields.push(Field::new( + INTERNAL_GROUPING_ID, + self.grouping_id_type(), + false, + )); + } + Ok(fields) + } + + /// Returns the output fields of the group by. + /// + /// This might be different from the `group_fields` that might contain internal expressions that + /// should not be part of the output schema. + fn output_fields( + &self, + input_schema: &Schema, + mode: &AggregateMode, + ) -> Result> { + let mut fields = self.group_fields(input_schema)?; + fields.truncate(self.num_output_exprs(mode)); + Ok(fields) + } + + /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial + /// aggregation. 
+ pub fn as_final(&self) -> PhysicalGroupBy { + let expr: Vec<_> = self + .output_exprs(&AggregateMode::Partial) + .into_iter() + .zip( + self.expr + .iter() + .map(|t| t.1.clone()) + .chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())), + ) + .collect(); + let num_exprs = expr.len(); + Self { + expr, + null_expr: vec![], + groups: vec![vec![false; num_exprs]], + num_internal_exprs: self.num_internal_exprs, + } } } @@ -321,13 +434,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( - &input.schema(), - &group_by.expr, - &aggr_expr, - group_by.exprs_nullable(), - mode, - )?; + let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); AggregateExec::try_new_with_schema( @@ -459,7 +566,7 @@ impl AggregateExec { /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { - self.group_by.output_exprs() + self.group_by.output_exprs(&AggregateMode::Partial) } /// Aggregate expressions @@ -789,25 +896,13 @@ impl ExecutionPlan for AggregateExec { fn create_schema( input_schema: &Schema, - group_expr: &[(Arc, String)], + group_by: &PhysicalGroupBy, aggr_expr: &[AggregateFunctionExpr], - group_expr_nullable: Vec, mode: AggregateMode, ) -> Result { - let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); - for (index, (expr, name)) in group_expr.iter().enumerate() { - fields.push( - Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. 
- group_expr_nullable[index] || expr.nullable(input_schema)?, - ) - .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()), - ) - } + let mut fields = + Vec::with_capacity(group_by.num_output_exprs(&mode) + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema, &mode)?); match mode { AggregateMode::Partial => { @@ -833,9 +928,8 @@ fn create_schema( )) } -fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { - let group_fields = schema.fields()[0..group_count].to_vec(); - Arc::new(Schema::new(group_fields)) +fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result { + Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?))) } /// Determines the lexical ordering requirement for an aggregate expression. @@ -1142,6 +1236,27 @@ fn evaluate_optional( .collect() } +fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { + if group.len() > 64 { + return not_impl_err!( + "Grouping sets with more than 64 columns are not supported" + ); + } + let group_id = group.iter().fold(0u64, |acc, &is_null| { + (acc << 1) | if is_null { 1 } else { 0 } + }); + let num_rows = batch.num_rows(); + if group.len() <= 8 { + Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) + } else if group.len() <= 16 { + Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) + } else if group.len() <= 32 { + Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + } else { + Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + } +} + /// Evaluate a group by expression against a `RecordBatch` /// /// Arguments: @@ -1174,23 +1289,24 @@ pub(crate) fn evaluate_group_by( }) .collect::>>()?; - Ok(group_by + group_by .groups .iter() .map(|group| { - group - .iter() - .enumerate() - .map(|(idx, is_null)| { - if *is_null { - Arc::clone(&null_exprs[idx]) - } else { - Arc::clone(&exprs[idx]) - } - }) - .collect() + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); + 
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { + if *is_null { + Arc::clone(&null_exprs[idx]) + } else { + Arc::clone(&exprs[idx]) + } + })); + if !group_by.is_single() { + group_values.push(group_id_array(group, batch)?); + } + Ok(group_values) }) - .collect()) + .collect() } #[cfg(test)] @@ -1348,21 +1464,21 @@ mod tests { ) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::UInt32(None)), "a".to_string()), (lit(ScalarValue::Float64(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![true, false], // (NULL, b) vec![false, false], // (a,b) ], - }; + ); let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) .schema(Arc::clone(&input_schema)) @@ -1392,63 +1508,56 @@ mod tests { // In spill mode, we test with the limited memory, if the mem usage exceeds, // we trigger the early emit rule, which turns out the partial aggregate result. 
vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 1 |", - "| | 1.0 | 1 |", - "| | 2.0 | 1 |", - "| | 2.0 | 1 |", - "| | 3.0 | 1 |", - "| | 3.0 | 1 |", - "| | 4.0 | 1 |", - "| | 4.0 | 1 |", - "| 2 | | 1 |", - "| 2 | | 1 |", - "| 2 | 1.0 | 1 |", - "| 2 | 1.0 | 1 |", - "| 3 | | 1 |", - "| 3 | | 2 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 1 |", - "| 4 | | 2 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+-------------+-----------------+", + "| a | b | grouping_id | COUNT(1)[count] |", + "+---+-----+-------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+-------------+-----------------+", ] } else { vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+-------------+-----------------+", + "| a | b | grouping_id | COUNT(1)[count] |", + "+---+-----+-------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + 
"+---+-----+-------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); - let groups = partial_aggregate.group_expr().expr().to_vec(); - let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = groups - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let task_ctx = if spill { new_spill_ctx(4, 3160) @@ -1503,11 +1612,11 @@ mod tests { async fn check_aggregates(input: Arc, spill: bool) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let grouping_set = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); let aggregates: Vec = vec![ @@ -1563,13 +1672,7 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = grouping_set - .expr - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, @@ -1825,11 +1928,11 @@ mod tests { let task_ctx = Arc::new(task_ctx); let groups_none = PhysicalGroupBy::default(); - let groups_some = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let groups_some = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); // something that allocates within the aggregator let aggregates_v0: Vec = @@ -2306,7 +2409,7 @@ mod tests { )?); let aggregate_exec 
= Arc::new(AggregateExec::try_new( - AggregateMode::Partial, + AggregateMode::Single, groups, aggregates.clone(), vec![None], @@ -2318,13 +2421,13 @@ mod tests { collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; let expected = [ - "+-----+-----+-------+----------+", - "| a | b | const | 1[count] |", - "+-----+-----+-------+----------+", - "| | 0.0 | | 32768 |", - "| 0.0 | | | 32768 |", - "| | | 1 | 32768 |", - "+-----+-----+-------+----------+", + "+-----+-----+-------+-------+", + "| a | b | const | 1 |", + "+-----+-----+-------+-------+", + "| | 0.0 | | 32768 |", + "| 0.0 | | | 32768 |", + "| | | 1 | 32768 |", + "+-----+-----+-------+-------+", ]; assert_batches_sorted_eq!(expected, &output); @@ -2638,25 +2741,24 @@ mod tests { .build()?, ]; - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::Float32(None)), "a".to_string()), (lit(ScalarValue::Float32(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![false, false], // (a,b) ], - }; + ); let aggr_schema = create_schema( &input_schema, - &grouping_set.expr, + &grouping_set, &aggr_expr, - grouping_set.exprs_nullable(), AggregateMode::Final, )?; let expected_schema = Schema::new(vec![ diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 9e4968f1123e7..384e440993368 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -449,13 +449,13 @@ impl GroupedHashAggregateStream { let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; // arguments for aggregating spilled data is the same as the one for final aggregation let 
merging_aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &AggregateMode::Final, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; let filter_expressions = match agg.mode { @@ -473,7 +473,7 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; let spill_expr = group_schema .fields .into_iter() @@ -491,7 +491,7 @@ impl GroupedHashAggregateStream { let (ordering, _) = agg .properties() .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs()); + .find_longest_permutation(&agg_group_by.output_exprs(&agg.mode)); let group_ordering = GroupOrdering::try_new( &group_schema, &agg.input_order_mode, @@ -885,6 +885,9 @@ impl GroupedHashAggregateStream { } let mut output = self.group_values.emit(emit_to)?; + if !spilling { + output.truncate(self.group_by.num_output_exprs(&self.mode)); + } if let EmitTo::First(n) = emit_to { self.group_ordering.remove_groups(n); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a78ade81eeba5..fad0e3dadd8b8 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3520,6 +3520,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls ---- 1 5 +# grouping_sets with null values +query II rowsort +SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value) +---- +1 1 +3 3 +4 4 +5 5 +NULL 1 +NULL NULL + + statement ok DROP TABLE integers_with_nulls; @@ -4884,7 +4896,7 @@ logical_plan 03)----TableScan: aggregate_test_100 projection=[c2, c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=3 -02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] +02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, grouping_id@2 as grouping_id], aggr=[], 
lim=[3] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 From 15aafe366cccbf41feb2e54a97907cf6d0c024ad Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Wed, 25 Sep 2024 20:09:18 +0200 Subject: [PATCH 2/6] PR suggestion comment --- datafusion/physical-plan/src/aggregates/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2bbfc3e41922b..cddd90cd08f06 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -262,6 +262,9 @@ impl PhysicalGroupBy { } /// Returns the data type of the grouping id. + /// The grouping ID value is a bitmask where each set bit + /// indicates that the corresponding grouping expression is + /// null fn grouping_id_type(&self) -> DataType { if self.expr.len() <= 8 { DataType::UInt8 From c2567379232ca68859192acb6564cd0da6ff9667 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 27 Sep 2024 15:47:17 +0200 Subject: [PATCH 3/6] Update new test case --- datafusion/sqllogictest/test_files/group_by.slt | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index f561fa9e9ac8d..a80a0891e9770 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5152,8 +5152,6 @@ drop table test_case_expr statement ok drop table t; -# TODO: Current grouping set result is not align with Postgres and DuckDB, we might want to change the result -# See https://github.com/apache/datafusion/issues/12570 # test multi group by for binary type with nulls statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 
null), (null, 0xb), (null, 0xb); @@ -5162,11 +5160,14 @@ query I?I select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); ---- 1 0a 2 -2 NULL 2 -NULL 0b 4 +2 NULL 1 +NULL 0b 2 1 NULL 2 -NULL NULL 3 +2 NULL 1 +NULL NULL 2 NULL 0a 2 +NULL NULL 1 +NULL 0b 2 statement ok drop table t; From 9e7c31408bbc6339b910a8093054c3cadf3f04f7 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 29 Sep 2024 09:58:21 +0200 Subject: [PATCH 4/6] Add grouping_id to the logical plan --- datafusion/core/src/dataframe/mod.rs | 17 ++ datafusion/core/src/physical_planner.rs | 6 +- datafusion/expr/src/logical_plan/plan.rs | 39 ++- datafusion/expr/src/utils.rs | 12 +- .../src/single_distinct_to_groupby.rs | 6 +- .../src/combine_partial_final_agg.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 233 ++++++++---------- .../physical-plan/src/aggregates/row_hash.rs | 5 +- .../sqllogictest/test_files/aggregate.slt | 20 +- .../tests/cases/roundtrip_logical_plan.rs | 5 +- 10 files changed, 190 insertions(+), 155 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index f5867881da139..67e2a4780d06c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -535,9 +535,26 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let aggr_expr_len = aggr_expr.len(); let plan = LogicalPlanBuilder::from(self.plan) .aggregate(group_expr, aggr_expr)? .build()?; + let plan = if is_grouping_set { + let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len; + // For grouping sets we do a project to not expose the internal grouping id + let exprs = plan + .schema() + .columns() + .into_iter() + .enumerate() + .filter(|(idx, _)| *idx != grouping_id_pos) + .map(|(_, column)| Expr::Column(column)) + .collect::>(); + LogicalPlanBuilder::from(plan).project(exprs)?.build()? 
+ } else { + plan + }; Ok(DataFrame { session_state: self.session_state, plan, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f60622234a98d..cf2a157b04b68 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2043,7 +2043,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_internal_exprs: 1 })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2070,7 +2070,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_internal_exprs: 1 })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: 
"c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2335,7 +2335,7 @@ mod tests { .expect("hash aggregate"); assert_eq!( "sum(aggregate_test_100.c3)", - final_hash_agg.schema().field(2).name() + final_hash_agg.schema().field(3).name() ); // we need access to the input to the partial aggregate so that other projects can // implement serde diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 19e73140b75ca..5578b27de1f35 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::dml::CopyTo; use super::DdlStatement; @@ -2965,6 +2965,15 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + qualified_fields.push(( + None, + Field::new( + Self::INTERNAL_GROUPING_ID, + Self::grouping_id_type(qualified_fields.len()), + false, + ) + .into(), + )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); @@ -3016,9 +3025,19 @@ impl Aggregate { }) } + fn is_grouping_set(&self) -> bool { + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + } + /// Get the output expressions. 
fn output_expressions(&self) -> Result> { + static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; + if self.is_grouping_set() { + exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { + Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + })); + } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); Ok(exprs) @@ -3030,6 +3049,24 @@ impl Aggregate { pub fn group_expr_len(&self) -> Result { grouping_set_expr_count(&self.group_expr) } + + /// Returns the data type of the grouping id. + /// The grouping ID value is a bitmask where each set bit + /// indicates that the corresponding grouping expression is + /// null + pub fn grouping_id_type(group_exprs: usize) -> DataType { + if group_exprs <= 8 { + DataType::UInt8 + } else if group_exprs <= 16 { + DataType::UInt16 + } else if group_exprs <= 32 { + DataType::UInt32 + } else { + DataType::UInt64 + } + } + + pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } // Manual implementation needed because of `schema` field. Comparison excludes this field. diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 9bb53a1d04a07..f377f0fe59bf5 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -60,7 +60,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result /// Count the number of distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. 
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { - grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) + if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if group_expr.len() > 1 { + return plan_err!( + "Invalid group by expressions, GroupingSet must be the only expression" + ); + } + // Groupings sets have an additional interal column for the grouping id + Ok(grouping_set.distinct_expr().len() + 1) + } else { + grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) + } } /// The [power set] (or powerset) of a set S is the set of all subsets of S, \ diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1c22c2a4375ad..74251e5caad2b 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -355,7 +355,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -373,7 +373,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -392,7 +392,7 @@ mod tests { .build()?; // Should not be optimized - let expected = 
"Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index bc1642bf7952c..4e352e25b52c9 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { // Compare output expressions of the partial, and input expressions of the final operator. physical_exprs_equal( - &input_group_by.output_exprs(&AggregateMode::Partial), + &input_group_by.output_exprs(), &final_group_by.input_exprs(), ) && input_group_by.groups() == final_group_by.groups() && input_group_by.null_expr().len() == final_group_by.null_expr().len() diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index cddd90cd08f06..f9dd973c814e4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -37,11 +37,10 @@ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; -use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, 
expressions::Column, @@ -111,8 +110,6 @@ impl AggregateMode { } } -const INTERNAL_GROUPING_ID: &str = "grouping_id"; - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. @@ -142,10 +139,6 @@ pub struct PhysicalGroupBy { /// expression in null_expr. If `groups[i][j]` is true, then the /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. groups: Vec>, - // The number of internal expressions that are used to implement grouping - // sets. These output are removed from the final output and not in `expr` - // as they are generated based on the value in `groups` - num_internal_exprs: usize, } impl PhysicalGroupBy { @@ -155,12 +148,10 @@ impl PhysicalGroupBy { null_expr: Vec<(Arc, String)>, groups: Vec>, ) -> Self { - let num_internal_exprs = if !null_expr.is_empty() { 1 } else { 0 }; Self { expr, null_expr, groups, - num_internal_exprs, } } @@ -172,7 +163,6 @@ impl PhysicalGroupBy { expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], - num_internal_exprs: 0, } } @@ -223,20 +213,17 @@ impl PhysicalGroupBy { } /// The number of expressions in the output schema. - fn num_output_exprs(&self, mode: &AggregateMode) -> usize { + fn num_output_exprs(&self) -> usize { let mut num_exprs = self.expr.len(); if !self.is_single() { - num_exprs += self.num_internal_exprs; - } - if *mode != AggregateMode::Partial { - num_exprs -= self.num_internal_exprs; + num_exprs += 1 } num_exprs } /// Return grouping expressions as they occur in the output schema. 
- pub fn output_exprs(&self, mode: &AggregateMode) -> Vec> { - let num_output_exprs = self.num_output_exprs(mode); + pub fn output_exprs(&self) -> Vec> { + let num_output_exprs = self.num_output_exprs(); let mut output_exprs = Vec::with_capacity(num_output_exprs); output_exprs.extend( self.expr @@ -245,9 +232,11 @@ impl PhysicalGroupBy { .take(num_output_exprs) .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), ); - if !self.is_single() && *mode == AggregateMode::Partial { - output_exprs - .push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _); + if !self.is_single() { + output_exprs.push(Arc::new(Column::new( + Aggregate::INTERNAL_GROUPING_ID, + self.expr.len(), + )) as _); } output_exprs } @@ -257,23 +246,7 @@ impl PhysicalGroupBy { if self.is_single() { self.expr.len() } else { - self.expr.len() + self.num_internal_exprs - } - } - - /// Returns the data type of the grouping id. - /// The grouping ID value is a bitmask where each set bit - /// indicates that the corresponding grouping expression is - /// null - fn grouping_id_type(&self) -> DataType { - if self.expr.len() <= 8 { - DataType::UInt8 - } else if self.expr.len() <= 16 { - DataType::UInt16 - } else if self.expr.len() <= 32 { - DataType::UInt32 - } else { - DataType::UInt64 + self.expr.len() + 1 } } @@ -283,17 +256,21 @@ impl PhysicalGroupBy { for ((expr, name), group_expr_nullable) in self.expr.iter().zip(self.exprs_nullable().into_iter()) { - fields.push(Field::new( - name, - expr.data_type(input_schema)?, - group_expr_nullable || expr.nullable(input_schema)?, - )) - .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()); + fields.push( + Field::new( + name, + expr.data_type(input_schema)?, + group_expr_nullable || expr.nullable(input_schema)?, + ) + .with_metadata( + get_field_metadata(expr, input_schema).unwrap_or_default(), + ), + ); } if !self.is_single() { fields.push(Field::new( - INTERNAL_GROUPING_ID, - self.grouping_id_type(), + 
Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), false, )); } @@ -304,35 +281,29 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. - fn output_fields( - &self, - input_schema: &Schema, - mode: &AggregateMode, - ) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; - fields.truncate(self.num_output_exprs(mode)); + fields.truncate(self.num_output_exprs()); Ok(fields) } /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial /// aggregation. pub fn as_final(&self) -> PhysicalGroupBy { - let expr: Vec<_> = self - .output_exprs(&AggregateMode::Partial) - .into_iter() - .zip( - self.expr - .iter() - .map(|t| t.1.clone()) - .chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())), - ) - .collect(); + let expr: Vec<_> = + self.output_exprs() + .into_iter() + .zip( + self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once( + Aggregate::INTERNAL_GROUPING_ID.to_owned(), + )), + ) + .collect(); let num_exprs = expr.len(); Self { expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], - num_internal_exprs: self.num_internal_exprs, } } } @@ -569,7 +540,7 @@ impl AggregateExec { /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { - self.group_by.output_exprs(&AggregateMode::Partial) + self.group_by.output_exprs() } /// Aggregate expressions @@ -903,9 +874,8 @@ fn create_schema( aggr_expr: &[AggregateFunctionExpr], mode: AggregateMode, ) -> Result { - let mut fields = - Vec::with_capacity(group_by.num_output_exprs(&mode) + aggr_expr.len()); - fields.extend(group_by.output_fields(input_schema, &mode)?); + let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema)?); match mode { 
AggregateMode::Partial => { @@ -1511,49 +1481,49 @@ mod tests { // In spill mode, we test with the limited memory, if the mem usage exceeds, // we trigger the early emit rule, which turns out the partial aggregate result. vec![ - "+---+-----+-------------+-----------------+", - "| a | b | grouping_id | COUNT(1)[count] |", - "+---+-----+-------------+-----------------+", - "| | 1.0 | 2 | 1 |", - "| | 1.0 | 2 | 1 |", - "| | 2.0 | 2 | 1 |", - "| | 2.0 | 2 | 1 |", - "| | 3.0 | 2 | 1 |", - "| | 3.0 | 2 | 1 |", - "| | 4.0 | 2 | 1 |", - "| | 4.0 | 2 | 1 |", - "| 2 | | 1 | 1 |", - "| 2 | | 1 | 1 |", - "| 2 | 1.0 | 0 | 1 |", - "| 2 | 1.0 | 0 | 1 |", - "| 3 | | 1 | 1 |", - "| 3 | | 1 | 2 |", - "| 3 | 2.0 | 0 | 2 |", - "| 3 | 3.0 | 0 | 1 |", - "| 4 | | 1 | 1 |", - "| 4 | | 1 | 2 |", - "| 4 | 3.0 | 0 | 1 |", - "| 4 | 4.0 | 0 | 2 |", - "+---+-----+-------------+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] } else { vec![ - "+---+-----+-------------+-----------------+", - "| a | b | grouping_id | COUNT(1)[count] |", - "+---+-----+-------------+-----------------+", - "| | 1.0 | 2 | 2 |", - "| | 2.0 | 2 | 2 |", - "| | 3.0 | 2 | 2 |", - "| | 4.0 | 2 | 2 |", - "| 2 | | 1 | 2 |", - "| 2 | 1.0 | 0 | 2 |", - "| 3 | | 1 | 3 |", - "| 3 | 2.0 | 0 | 2 |", - "| 3 | 3.0 | 0 | 1 |", - "| 4 | | 1 | 3 |", - "| 4 | 3.0 | 0 | 1 |", - "| 4 | 4.0 | 0 | 2 |", - 
"+---+-----+-------------+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); @@ -1580,26 +1550,26 @@ mod tests { let result = common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; - assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); let expected = vec![ - "+---+-----+----------+", - "| a | b | COUNT(1) |", - "+---+-----+----------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+----------+", + "+---+-----+---------------+----------+", + "| a | b | __grouping_id | COUNT(1) |", + "+---+-----+---------------+----------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+----------+", ]; assert_batches_sorted_eq!(&expected, &result); @@ -2424,13 +2394,13 @@ mod tests { collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; let expected = [ - "+-----+-----+-------+-------+", - "| a | b | const | 1 |", - "+-----+-----+-------+-------+", - "| | 0.0 | | 32768 |", - "| 
0.0 | | | 32768 |", - "| | | 1 | 32768 |", - "+-----+-----+-------+-------+", + "+-----+-----+-------+---------------+-------+", + "| a | b | const | __grouping_id | 1 |", + "+-----+-----+-------+---------------+-------+", + "| | | 1 | 6 | 32768 |", + "| | 0.0 | | 5 | 32768 |", + "| 0.0 | | | 3 | 32768 |", + "+-----+-----+-------+---------------+-------+", ]; assert_batches_sorted_eq!(expected, &output); @@ -2767,6 +2737,7 @@ mod tests { let expected_schema = Schema::new(vec![ Field::new("a", DataType::Float32, false), Field::new("b", DataType::Float32, true), + Field::new("__grouping_id", DataType::UInt8, false), Field::new("COUNT(a)", DataType::Int64, false), ]); assert_eq!(aggr_schema, expected_schema); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 384e440993368..5121e6cc3b354 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -491,7 +491,7 @@ impl GroupedHashAggregateStream { let (ordering, _) = agg .properties() .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs(&agg.mode)); + .find_longest_permutation(&agg_group_by.output_exprs()); let group_ordering = GroupOrdering::try_new( &group_schema, &agg.input_order_mode, @@ -885,9 +885,6 @@ impl GroupedHashAggregateStream { } let mut output = self.group_values.emit(emit_to)?; - if !spilling { - output.truncate(self.group_by.num_output_exprs(&self.mode)); - } if let EmitTo::First(n) = emit_to { self.group_ordering.remove_groups(n); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index fad0e3dadd8b8..250fa85cddef1 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4891,16 +4891,18 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -01)Limit: 
skip=0, fetch=3 -02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -03)----TableScan: aggregate_test_100 projection=[c2, c3] +01)Projection: aggregate_test_100.c2, aggregate_test_100.c3 +02)--Limit: skip=0, fetch=3 +03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=3 -02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, grouping_id@2 as grouping_id], aggr=[], lim=[3] -03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true +01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3] +04)------CoalescePartitionsExec +05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true query II SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 3b7d0fd296105..ce6d1825cd25c 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -294,8 +294,9 @@ async fn aggregate_grouping_sets() 
-> Result<()> { async fn aggregate_grouping_rollup() -> Result<()> { assert_expected_plan( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ - \n TableScan: data projection=[a, b, c, e]", + "Projection: data.a, data.c, data.e, avg(data.b)\ + \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ + \n TableScan: data projection=[a, b, c, e]", true ).await } From 920b3845b4d913e7ad8289b6e7f796d7a33daa85 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 5 Oct 2024 10:17:12 +0200 Subject: [PATCH 5/6] Add doc comment next to INTERNAL_GROUPING_ID --- datafusion/expr/src/logical_plan/plan.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 5578b27de1f35..0292274e57ee3 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3066,6 +3066,23 @@ impl Aggregate { } } + /// Internal column used when the aggregation is a grouping set. + /// + /// This column contains a bitmask where each bit represents a grouping + /// expression. The least significant bit corresponds to the rightmost + /// grouping expression. A bit value of 0 indicates that the corresponding + /// column is included in the grouping set, while a value of 1 means it is excluded. + /// + /// For example, for the grouping expressions CUBE(a, b), the grouping ID + /// column will have the following values: + /// 0b00: Both `a` and `b` are included + /// 0b01: `b` is excluded + /// 0b10: `a` is excluded + /// 0b11: Both `a` and `b` are excluded + /// + /// This internal column is necessary because excluded columns are replaced + /// with `NULL` values. 
To handle these cases correctly, we must distinguish + /// between an actual `NULL` value in a column and a column being excluded from the set. pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } From 1f61ddf0bbbc19ad30481e1932b6810fcad84a63 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 6 Oct 2024 21:15:48 +0200 Subject: [PATCH 6/6] Fix unparsing of Aggregate with grouping sets --- datafusion/sql/src/unparser/utils.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 0059aba257381..33d4104b015d0 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; + use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, @@ -166,10 +168,17 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result Ok(grouping_expr.into_iter().nth(index)), + Ordering::Equal => { + internal_err!( + "Tried to unproject column refereing to internal grouping id" + ) + } + Ordering::Greater => { + Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1)) + } + } } else { Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) }