Skip to content

Commit 97114a4

Browse files
committed
Addressed comments
1 parent f6ef0fa commit 97114a4

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,17 @@ fn replace_grouping_exprs(
9999
{
100100
match expr {
101101
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
102+
let grouping_id_type = is_grouping_set
103+
.then(|| {
104+
schema
105+
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
106+
.map(|f| f.data_type().clone())
107+
})
108+
.transpose()?;
102109
let grouping_expr = grouping_function_on_id(
103110
function,
104111
&group_expr_to_bitmap_index,
105-
is_grouping_set,
112+
grouping_id_type,
106113
)?;
107114
projection_exprs.push(Expr::Alias(Alias::new(
108115
grouping_expr,
@@ -184,41 +191,40 @@ fn validate_args(
184191
fn grouping_function_on_id(
185192
function: &AggregateFunction,
186193
group_by_expr: &HashMap<&Expr, usize>,
187-
is_grouping_set: bool,
194+
// None means not a grouping set (result is always 0).
195+
grouping_id_type: Option<DataType>,
188196
) -> Result<Expr> {
189197
validate_args(function, group_by_expr)?;
190198
let args = &function.params.args;
191199

192200
// Postgres allows grouping function for group by without grouping sets, the result is then
193201
// always 0
194-
if !is_grouping_set {
202+
let Some(grouping_id_type) = grouping_id_type else {
195203
return Ok(Expr::Literal(ScalarValue::from(0i32), None));
196-
}
204+
};
197205

198-
let group_by_expr_count = group_by_expr.len();
199-
let literal = |value: usize| {
200-
if group_by_expr_count < 8 {
201-
Expr::Literal(ScalarValue::from(value as u8), None)
202-
} else if group_by_expr_count < 16 {
203-
Expr::Literal(ScalarValue::from(value as u16), None)
204-
} else if group_by_expr_count < 32 {
205-
Expr::Literal(ScalarValue::from(value as u32), None)
206-
} else {
207-
Expr::Literal(ScalarValue::from(value as u64), None)
208-
}
206+
// Use the actual __grouping_id column type to size literals correctly. This
207+
// accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type`
208+
// packs into the high bits of the column, which a simple count of grouping
209+
// expressions would miss.
210+
let literal = |value: usize| match &grouping_id_type {
211+
DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
212+
DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
213+
DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
214+
_ => Expr::Literal(ScalarValue::from(value as u64), None),
209215
};
210216
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
211-
if args.len() == group_by_expr_count
217+
if args.len() == group_by_expr.len()
212218
&& args
213219
.iter()
214220
.rev()
215221
.enumerate()
216222
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
217223
{
218-
let n = group_by_expr_count;
224+
let n = group_by_expr.len();
219225
// Mask the ordinal bits above position `n` so only the semantic bitmask is visible.
220-
// (1 << n) - 1 masks the low n bits.
221-
let semantic_mask: u64 = if n >= 64 { u64::MAX } else { (1u64 << n) - 1 };
226+
// checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX.
227+
let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
222228
let masked_id =
223229
bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
224230
return Ok(cast(masked_id, DataType::Int32));

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2027,7 +2027,7 @@ pub fn evaluate_group_by(
20272027
batch: &RecordBatch,
20282028
) -> Result<Vec<Vec<ArrayRef>>> {
20292029
let max_ordinal = max_duplicate_ordinal(&group_by.groups);
2030-
let mut ordinal_per_pattern: HashMap<&Vec<bool>, usize> = HashMap::new();
2030+
let mut ordinal_per_pattern: HashMap<&[bool], usize> = HashMap::new();
20312031
let exprs = evaluate_expressions_to_arrays(
20322032
group_by.expr.iter().map(|(expr, _)| expr),
20332033
batch,

0 commit comments

Comments
 (0)