Skip to content

Commit 4aed81a

Browse files
xiedeyantu and alamb authored
fix: preserve duplicate GROUPING SETS rows (#21058)
## Which issue does this PR close? - Closes #21316. ## Rationale for this change `GROUPING SETS` with duplicate grouping lists were incorrectly collapsed during execution. The internal grouping id only encoded the semantic null mask, so repeated grouping sets shared the same execution key and were merged, which caused rows to be lost compared with PostgreSQL behavior. For example, with: ```sql create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int); insert into duplicate_grouping_sets values (10, 'CLERK', 1300, null), (20, 'MANAGER', 3000, null); select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal) from duplicate_grouping_sets group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal); ``` PostgreSQL preserves the duplicate grouping set and returns: ```text deptno | job | sal | sum | grouping | grouping | grouping --------+---------+------+-----+----------+----------+---------- 10 | CLERK | | | 0 | 0 | 1 10 | CLERK | | | 0 | 0 | 1 10 | | 1300 | | 0 | 1 | 0 20 | MANAGER | | | 0 | 0 | 1 20 | MANAGER | | | 0 | 0 | 1 20 | | 3000 | | 0 | 1 | 0 (6 rows) ``` Before this fix, DataFusion collapsed the duplicate `(deptno, job)` grouping set and returned only 4 rows for the same query shape. 
```text +--------+---------+------+-----------------------------------+------------------------------------------+---------------------------------------+---------------------------------------+ | deptno | job | sal | sum(duplicate_grouping_sets.comm) | grouping(duplicate_grouping_sets.deptno) | grouping(duplicate_grouping_sets.job) | grouping(duplicate_grouping_sets.sal) | +--------+---------+------+-----------------------------------+------------------------------------------+---------------------------------------+---------------------------------------+ | 10 | CLERK | NULL | NULL | 0 | 0 | 1 | | 10 | NULL | 1300 | NULL | 0 | 1 | 0 | | 20 | MANAGER | NULL | NULL | 0 | 0 | 1 | | 20 | NULL | 3000 | NULL | 0 | 1 | 0 | +--------+---------+------+-----------------------------------+------------------------------------------+---------------------------------------+---------------------------------------+ ``` ## What changes are included in this PR? - Preserve duplicate grouping sets by packing a duplicate ordinal into the high bits of `__grouping_id`, so repeated occurrences of the same grouping set pattern produce distinct execution keys. - `GROUPING()` now reads the actual `__grouping_id` column type directly from the schema (via `Aggregate::grouping_id_type`) rather than inferring bit width from the count of grouping expressions alone. This ensures bitmask literals are correctly sized when duplicate-ordinal bits widen the column type beyond what the expression count would imply. - `GROUPING()` masks off the ordinal bits before returning the result, so the duplicate-ordinal encoding is invisible to user-facing SQL and semantics remain unchanged. - Add regression coverage for the duplicate `GROUPING SETS` case in: - `datafusion/core/tests/sql/aggregates/basic.rs` - `datafusion/sqllogictest/test_files/group_by.slt` ## Are these changes tested? 
- `cargo fmt --all` - `cargo test -p datafusion duplicate_grouping_sets_are_preserved` - `cargo test -p datafusion-physical-plan grouping_sets_preserve_duplicate_groups` - `cargo test -p datafusion-physical-plan evaluate_group_by_supports_duplicate_grouping_sets_with_eight_columns` - PostgreSQL validation against the same query/result shape ## Are there any user-facing changes? - Yes. Queries that contain duplicate `GROUPING SETS` entries now return the correct duplicated result rows, matching PostgreSQL behavior. --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 91c2e04 commit 4aed81a

File tree

4 files changed

+212
-50
lines changed

4 files changed

+212
-50
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::utils::{
4545
grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction,
4646
};
4747
use crate::{
48-
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable,
48+
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, GroupingSet,
4949
LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource,
5050
WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed,
5151
};
@@ -3595,11 +3595,12 @@ impl Aggregate {
35953595
.into_iter()
35963596
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
35973597
.collect::<Vec<_>>();
3598+
let max_ordinal = max_grouping_set_duplicate_ordinal(&group_expr);
35983599
qualified_fields.push((
35993600
None,
36003601
Field::new(
36013602
Self::INTERNAL_GROUPING_ID,
3602-
Self::grouping_id_type(qualified_fields.len()),
3603+
Self::grouping_id_type(qualified_fields.len(), max_ordinal),
36033604
false,
36043605
)
36053606
.into(),
@@ -3685,15 +3686,24 @@ impl Aggregate {
36853686
}
36863687

36873688
/// Returns the data type of the grouping id.
3688-
/// The grouping ID value is a bitmask where each set bit
3689-
/// indicates that the corresponding grouping expression is
3690-
/// null
3691-
pub fn grouping_id_type(group_exprs: usize) -> DataType {
3692-
if group_exprs <= 8 {
3689+
///
3690+
/// The grouping ID packs two pieces of information into a single integer:
3691+
/// - The low `group_exprs` bits are the semantic bitmask (a set bit means the
3692+
/// corresponding grouping expression is NULL for this grouping set).
3693+
/// - The bits above position `group_exprs` encode a duplicate ordinal that
3694+
/// distinguishes multiple occurrences of the same grouping set pattern.
3695+
///
3696+
/// `max_ordinal` is the highest ordinal value that will appear (0 when there
3697+
/// are no duplicate grouping sets). The type is chosen to be the smallest
3698+
/// unsigned integer that can represent both parts.
3699+
pub fn grouping_id_type(group_exprs: usize, max_ordinal: usize) -> DataType {
3700+
let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize;
3701+
let total_bits = group_exprs + ordinal_bits;
3702+
if total_bits <= 8 {
36933703
DataType::UInt8
3694-
} else if group_exprs <= 16 {
3704+
} else if total_bits <= 16 {
36953705
DataType::UInt16
3696-
} else if group_exprs <= 32 {
3706+
} else if total_bits <= 32 {
36973707
DataType::UInt32
36983708
} else {
36993709
DataType::UInt64
@@ -3702,21 +3712,36 @@ impl Aggregate {
37023712

37033713
/// Internal column used when the aggregation is a grouping set.
37043714
///
3705-
/// This column contains a bitmask where each bit represents a grouping
3706-
/// expression. The least significant bit corresponds to the rightmost
3707-
/// grouping expression. A bit value of 0 indicates that the corresponding
3708-
/// column is included in the grouping set, while a value of 1 means it is excluded.
3715+
/// This column packs two values into a single unsigned integer:
3716+
///
3717+
/// - **Low bits (positions 0 .. n-1)**: a semantic bitmask where each bit
3718+
/// represents one of the `n` grouping expressions. The least significant
3719+
/// bit corresponds to the rightmost grouping expression. A `1` bit means
3720+
/// the corresponding column is replaced with `NULL` for this grouping set;
3721+
/// a `0` bit means it is included.
3722+
/// - **High bits (positions n and above)**: a *duplicate ordinal* that
3723+
/// distinguishes multiple occurrences of the same semantic grouping set
3724+
/// pattern within a single query. The ordinal is `0` for the first
3725+
/// occurrence, `1` for the second, and so on.
3726+
///
3727+
/// The integer type is chosen by [`Self::grouping_id_type`] to be the
3728+
/// smallest `UInt8 / UInt16 / UInt32 / UInt64` that can represent both
3729+
/// parts.
37093730
///
3710-
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
3711-
/// column will have the following values:
3731+
/// For example, for the grouping expressions CUBE(a, b) (no duplicates),
3732+
/// the grouping ID column will have the following values:
37123733
/// 0b00: Both `a` and `b` are included
37133734
/// 0b01: `b` is excluded
37143735
/// 0b10: `a` is excluded
37153736
/// 0b11: Both `a` and `b` are excluded
37163737
///
3717-
/// This internal column is necessary because excluded columns are replaced
3718-
/// with `NULL` values. To handle these cases correctly, we must distinguish
3719-
/// between an actual `NULL` value in a column and a column being excluded from the set.
3738+
/// When the same set appears twice and `n = 2`, the duplicate ordinal is
3739+
/// packed into bit 2:
3740+
/// first occurrence: `0b0_01` (ordinal = 0, mask = 0b01)
3741+
/// second occurrence: `0b1_01` (ordinal = 1, mask = 0b01)
3742+
///
3743+
/// The GROUPING function always masks the value with `(1 << n) - 1` before
3744+
/// interpreting it so the ordinal bits are invisible to user-facing SQL.
37203745
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
37213746
}
37223747

@@ -3737,6 +3762,24 @@ impl PartialOrd for Aggregate {
37373762
}
37383763
}
37393764

3765+
/// Returns the highest duplicate ordinal across all grouping sets in `group_expr`.
3766+
///
3767+
/// The ordinal for each occurrence of a grouping set pattern is its 0-based
3768+
/// index among identical entries. For example, if the same set appears three
3769+
/// times, the ordinals are 0, 1, 2 and this function returns 2.
3770+
/// Returns 0 when no grouping set is duplicated.
3771+
fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize {
3772+
if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() {
3773+
let mut counts: HashMap<&[Expr], usize> = HashMap::new();
3774+
for set in sets {
3775+
*counts.entry(set).or_insert(0) += 1;
3776+
}
3777+
counts.into_values().max().unwrap_or(0).saturating_sub(1)
3778+
} else {
3779+
0
3780+
}
3781+
}
3782+
37403783
/// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`.
37413784
fn contains_grouping_set(group_expr: &[Expr]) -> bool {
37423785
group_expr
@@ -5053,6 +5096,14 @@ mod tests {
50535096
);
50545097
}
50555098

5099+
#[test]
5100+
fn grouping_id_type_accounts_for_duplicate_ordinal_bits() {
5101+
// 8 grouping columns fit in UInt8 when there are no duplicate ordinals,
5102+
// but adding one duplicate ordinal bit widens the type to UInt16.
5103+
assert_eq!(Aggregate::grouping_id_type(8, 0), DataType::UInt8);
5104+
assert_eq!(Aggregate::grouping_id_type(8, 1), DataType::UInt16);
5105+
}
5106+
50565107
#[test]
50575108
fn test_filter_is_scalar() {
50585109
// test empty placeholder

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,17 @@ fn replace_grouping_exprs(
9999
{
100100
match expr {
101101
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
102+
let grouping_id_type = is_grouping_set
103+
.then(|| {
104+
schema
105+
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
106+
.map(|f| f.data_type().clone())
107+
})
108+
.transpose()?;
102109
let grouping_expr = grouping_function_on_id(
103110
function,
104111
&group_expr_to_bitmap_index,
105-
is_grouping_set,
112+
grouping_id_type,
106113
)?;
107114
projection_exprs.push(Expr::Alias(Alias::new(
108115
grouping_expr,
@@ -184,40 +191,44 @@ fn validate_args(
184191
fn grouping_function_on_id(
185192
function: &AggregateFunction,
186193
group_by_expr: &HashMap<&Expr, usize>,
187-
is_grouping_set: bool,
194+
// None means not a grouping set (result is always 0).
195+
grouping_id_type: Option<DataType>,
188196
) -> Result<Expr> {
189197
validate_args(function, group_by_expr)?;
190198
let args = &function.params.args;
191199

192200
// Postgres allows grouping function for group by without grouping sets, the result is then
193201
// always 0
194-
if !is_grouping_set {
202+
let Some(grouping_id_type) = grouping_id_type else {
195203
return Ok(Expr::Literal(ScalarValue::from(0i32), None));
196-
}
197-
198-
let group_by_expr_count = group_by_expr.len();
199-
let literal = |value: usize| {
200-
if group_by_expr_count < 8 {
201-
Expr::Literal(ScalarValue::from(value as u8), None)
202-
} else if group_by_expr_count < 16 {
203-
Expr::Literal(ScalarValue::from(value as u16), None)
204-
} else if group_by_expr_count < 32 {
205-
Expr::Literal(ScalarValue::from(value as u32), None)
206-
} else {
207-
Expr::Literal(ScalarValue::from(value as u64), None)
208-
}
209204
};
210205

206+
// Use the actual __grouping_id column type to size literals correctly. This
207+
// accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type`
208+
// packs into the high bits of the column, which a simple count of grouping
209+
// expressions would miss.
210+
let literal = |value: usize| match &grouping_id_type {
211+
DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
212+
DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
213+
DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
214+
DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None),
215+
other => panic!("unexpected __grouping_id type: {other}"),
216+
};
211217
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
212-
// The grouping call is exactly our internal grouping id
213-
if args.len() == group_by_expr_count
218+
if args.len() == group_by_expr.len()
214219
&& args
215220
.iter()
216221
.rev()
217222
.enumerate()
218223
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
219224
{
220-
return Ok(cast(grouping_id_column, DataType::Int32));
225+
let n = group_by_expr.len();
226+
// Mask the ordinal bits above position `n` so only the semantic bitmask is visible.
227+
// checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX.
228+
let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
229+
let masked_id =
230+
bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
231+
return Ok(cast(masked_id, DataType::Int32));
221232
}
222233

223234
args.iter()

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::{
3737
use datafusion_common::config::ConfigOptions;
3838
use datafusion_physical_expr::utils::collect_columns;
3939
use parking_lot::Mutex;
40-
use std::collections::HashSet;
40+
use std::collections::{HashMap, HashSet};
4141

4242
use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
4343
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -396,6 +396,15 @@ impl PhysicalGroupBy {
396396
self.expr.len() + usize::from(self.has_grouping_set)
397397
}
398398

399+
/// Returns the Arrow data type of the `__grouping_id` column.
400+
///
401+
/// The type is chosen to be wide enough to hold both the semantic bitmask
402+
/// (in the low `n` bits, where `n` is the number of grouping expressions)
403+
/// and the duplicate ordinal (in the high bits).
404+
fn grouping_id_data_type(&self) -> DataType {
405+
Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups))
406+
}
407+
399408
pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
400409
Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
401410
}
@@ -420,7 +429,7 @@ impl PhysicalGroupBy {
420429
fields.push(
421430
Field::new(
422431
Aggregate::INTERNAL_GROUPING_ID,
423-
Aggregate::grouping_id_type(self.expr.len()),
432+
self.grouping_id_data_type(),
424433
false,
425434
)
426435
.into(),
@@ -2039,27 +2048,72 @@ fn evaluate_optional(
20392048
.collect()
20402049
}
20412050

2042-
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
2043-
if group.len() > 64 {
2051+
/// Builds the internal `__grouping_id` array for a single grouping set.
2052+
///
2053+
/// The returned array packs two values into a single integer:
2054+
///
2055+
/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask. A `1` bit
2056+
/// at position `i` means that the `i`-th grouping column (counting from the
2057+
/// least significant bit, i.e. the *last* column in the `group` slice) is
2058+
/// `NULL` for this grouping set.
2059+
/// - High bits (positions n and above): the duplicate `ordinal`, which
2060+
/// distinguishes multiple occurrences of the same grouping-set pattern. The
2061+
/// ordinal is `0` for the first occurrence, `1` for the second, and so on.
2062+
///
2063+
/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
2064+
/// UInt64` that can represent both parts. It matches the type returned by
2065+
/// [`Aggregate::grouping_id_type`].
2066+
fn group_id_array(
2067+
group: &[bool],
2068+
ordinal: usize,
2069+
max_ordinal: usize,
2070+
batch: &RecordBatch,
2071+
) -> Result<ArrayRef> {
2072+
let n = group.len();
2073+
if n > 64 {
20442074
return not_impl_err!(
20452075
"Grouping sets with more than 64 columns are not supported"
20462076
);
20472077
}
2048-
let group_id = group.iter().fold(0u64, |acc, &is_null| {
2078+
let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize;
2079+
let total_bits = n + ordinal_bits;
2080+
if total_bits > 64 {
2081+
return not_impl_err!(
2082+
"Grouping sets with {n} columns and a maximum duplicate ordinal of \
2083+
{max_ordinal} require {total_bits} bits, which exceeds 64"
2084+
);
2085+
}
2086+
let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
20492087
(acc << 1) | if is_null { 1 } else { 0 }
20502088
});
2089+
let full_id = semantic_id | ((ordinal as u64) << n);
20512090
let num_rows = batch.num_rows();
2052-
if group.len() <= 8 {
2053-
Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
2054-
} else if group.len() <= 16 {
2055-
Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
2056-
} else if group.len() <= 32 {
2057-
Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
2091+
if total_bits <= 8 {
2092+
Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
2093+
} else if total_bits <= 16 {
2094+
Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
2095+
} else if total_bits <= 32 {
2096+
Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
20582097
} else {
2059-
Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
2098+
Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
20602099
}
20612100
}
20622101

2102+
/// Returns the highest duplicate ordinal across all grouping sets.
2103+
///
2104+
/// At the call-site, the ordinal is the 0-based index assigned to each
2105+
/// occurrence of a repeated grouping-set pattern: the first occurrence gets
2106+
/// ordinal 0, the second gets 1, and so on. If the same `Vec<bool>` appears
2107+
/// three times the ordinals are 0, 1, 2 and this function returns 2.
2108+
/// Returns 0 when no grouping set is duplicated.
2109+
fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
2110+
let mut counts: HashMap<&[bool], usize> = HashMap::new();
2111+
for group in groups {
2112+
*counts.entry(group).or_insert(0) += 1;
2113+
}
2114+
counts.into_values().max().unwrap_or(0).saturating_sub(1)
2115+
}
2116+
20632117
/// Evaluate a group by expression against a `RecordBatch`
20642118
///
20652119
/// Arguments:
@@ -2074,6 +2128,8 @@ pub fn evaluate_group_by(
20742128
group_by: &PhysicalGroupBy,
20752129
batch: &RecordBatch,
20762130
) -> Result<Vec<Vec<ArrayRef>>> {
2131+
let max_ordinal = max_duplicate_ordinal(&group_by.groups);
2132+
let mut ordinal_per_pattern: HashMap<&[bool], usize> = HashMap::new();
20772133
let exprs = evaluate_expressions_to_arrays(
20782134
group_by.expr.iter().map(|(expr, _)| expr),
20792135
batch,
@@ -2087,6 +2143,10 @@ pub fn evaluate_group_by(
20872143
.groups
20882144
.iter()
20892145
.map(|group| {
2146+
let ordinal = ordinal_per_pattern.entry(group).or_insert(0);
2147+
let current_ordinal = *ordinal;
2148+
*ordinal += 1;
2149+
20902150
let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
20912151
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
20922152
if *is_null {
@@ -2096,7 +2156,12 @@ pub fn evaluate_group_by(
20962156
}
20972157
}));
20982158
if !group_by.is_single() {
2099-
group_values.push(group_id_array(group, batch)?);
2159+
group_values.push(group_id_array(
2160+
group,
2161+
current_ordinal,
2162+
max_ordinal,
2163+
batch,
2164+
)?);
21002165
}
21012166
Ok(group_values)
21022167
})

0 commit comments

Comments
 (0)