Skip to content

Commit 93acd98

Browse files
authored
Merge branch 'main' into optimize_count_distinct
2 parents af053db + 4b1901f commit 93acd98

File tree

9 files changed

+752
-62
lines changed

9 files changed

+752
-62
lines changed

datafusion/core/tests/physical_optimizer/projection_pushdown.rs

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use datafusion_physical_optimizer::output_requirements::OutputRequirementExec;
4646
use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown;
4747
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
4848
use datafusion_physical_plan::coop::CooperativeExec;
49-
use datafusion_physical_plan::filter::FilterExec;
49+
use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder};
5050
use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
5151
use datafusion_physical_plan::joins::{
5252
HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
@@ -1754,3 +1754,121 @@ fn test_hash_join_empty_projection_embeds() -> Result<()> {
17541754

17551755
Ok(())
17561756
}
1757+
1758+
/// Regression test for <https://github.com/apache/datafusion/issues/21459>
1759+
///
1760+
/// When a `ProjectionExec` sits on top of a `FilterExec` that already carries
1761+
/// an embedded projection, the `ProjectionPushdown` optimizer must not panic.
1762+
///
1763+
/// Before the fix, `FilterExecBuilder::from(self)` copied stale projection
1764+
/// indices (e.g. `[0, 1, 2]`). After swapping, the new input was narrower
1765+
/// (2 columns), so `.build()` panicked with "project index out of bounds".
1766+
#[test]
1767+
fn test_filter_with_embedded_projection_after_projection() -> Result<()> {
1768+
// DataSourceExec: [a, b, c, d, e]
1769+
let csv = create_simple_csv_exec();
1770+
1771+
// FilterExec: a > 0, projection=[0, 1, 2] → output: [a, b, c]
1772+
let predicate = Arc::new(BinaryExpr::new(
1773+
Arc::new(Column::new("a", 0)),
1774+
Operator::Gt,
1775+
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1776+
));
1777+
let filter: Arc<dyn ExecutionPlan> = Arc::new(
1778+
FilterExecBuilder::new(predicate, csv)
1779+
.apply_projection(Some(vec![0, 1, 2]))?
1780+
.build()?,
1781+
);
1782+
1783+
// ProjectionExec: narrows [a, b, c] → [a, b]
1784+
let projection: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
1785+
vec![
1786+
ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"),
1787+
ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"),
1788+
],
1789+
filter,
1790+
)?);
1791+
1792+
let initial = displayable(projection.as_ref()).indent(true).to_string();
1793+
let actual = initial.trim();
1794+
assert_snapshot!(
1795+
actual,
1796+
@r"
1797+
ProjectionExec: expr=[a@0 as a, b@1 as b]
1798+
FilterExec: a@0 > 0, projection=[a@0, b@1, c@2]
1799+
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
1800+
"
1801+
);
1802+
1803+
// This must not panic
1804+
let after_optimize =
1805+
ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?;
1806+
let after_optimize_string = displayable(after_optimize.as_ref())
1807+
.indent(true)
1808+
.to_string();
1809+
let actual = after_optimize_string.trim();
1810+
assert_snapshot!(
1811+
actual,
1812+
@r"
1813+
FilterExec: a@0 > 0
1814+
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false
1815+
"
1816+
);
1817+
1818+
Ok(())
1819+
}
1820+
1821+
/// Same as above, but the outer ProjectionExec also renames columns.
1822+
/// Ensures the rename is preserved after the projection pushdown swap.
1823+
#[test]
1824+
fn test_filter_with_embedded_projection_after_renaming_projection() -> Result<()> {
1825+
let csv = create_simple_csv_exec();
1826+
1827+
// FilterExec: b > 10, projection=[0, 1, 2, 3] → output: [a, b, c, d]
1828+
let predicate = Arc::new(BinaryExpr::new(
1829+
Arc::new(Column::new("b", 1)),
1830+
Operator::Gt,
1831+
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
1832+
));
1833+
let filter: Arc<dyn ExecutionPlan> = Arc::new(
1834+
FilterExecBuilder::new(predicate, csv)
1835+
.apply_projection(Some(vec![0, 1, 2, 3]))?
1836+
.build()?,
1837+
);
1838+
1839+
// ProjectionExec: [a as x, b as y] — narrows and renames
1840+
let projection: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
1841+
vec![
1842+
ProjectionExpr::new(Arc::new(Column::new("a", 0)), "x"),
1843+
ProjectionExpr::new(Arc::new(Column::new("b", 1)), "y"),
1844+
],
1845+
filter,
1846+
)?);
1847+
1848+
let initial = displayable(projection.as_ref()).indent(true).to_string();
1849+
let actual = initial.trim();
1850+
assert_snapshot!(
1851+
actual,
1852+
@r"
1853+
ProjectionExec: expr=[a@0 as x, b@1 as y]
1854+
FilterExec: b@1 > 10, projection=[a@0, b@1, c@2, d@3]
1855+
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
1856+
"
1857+
);
1858+
1859+
let after_optimize =
1860+
ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?;
1861+
let after_optimize_string = displayable(after_optimize.as_ref())
1862+
.indent(true)
1863+
.to_string();
1864+
let actual = after_optimize_string.trim();
1865+
assert_snapshot!(
1866+
actual,
1867+
@r"
1868+
FilterExec: y@1 > 10
1869+
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as x, b@1 as y], file_type=csv, has_header=false
1870+
"
1871+
);
1872+
1873+
Ok(())
1874+
}

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::utils::{
4545
grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction,
4646
};
4747
use crate::{
48-
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable,
48+
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, GroupingSet,
4949
LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource,
5050
WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed,
5151
};
@@ -3595,11 +3595,12 @@ impl Aggregate {
35953595
.into_iter()
35963596
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
35973597
.collect::<Vec<_>>();
3598+
let max_ordinal = max_grouping_set_duplicate_ordinal(&group_expr);
35983599
qualified_fields.push((
35993600
None,
36003601
Field::new(
36013602
Self::INTERNAL_GROUPING_ID,
3602-
Self::grouping_id_type(qualified_fields.len()),
3603+
Self::grouping_id_type(qualified_fields.len(), max_ordinal),
36033604
false,
36043605
)
36053606
.into(),
@@ -3685,15 +3686,24 @@ impl Aggregate {
36853686
}
36863687

36873688
/// Returns the data type of the grouping id.
3688-
/// The grouping ID value is a bitmask where each set bit
3689-
/// indicates that the corresponding grouping expression is
3690-
/// null
3691-
pub fn grouping_id_type(group_exprs: usize) -> DataType {
3692-
if group_exprs <= 8 {
3689+
///
3690+
/// The grouping ID packs two pieces of information into a single integer:
3691+
/// - The low `group_exprs` bits are the semantic bitmask (a set bit means the
3692+
/// corresponding grouping expression is NULL for this grouping set).
3693+
/// - The bits above position `group_exprs` encode a duplicate ordinal that
3694+
/// distinguishes multiple occurrences of the same grouping set pattern.
3695+
///
3696+
/// `max_ordinal` is the highest ordinal value that will appear (0 when there
3697+
/// are no duplicate grouping sets). The type is chosen to be the smallest
3698+
/// unsigned integer that can represent both parts.
3699+
pub fn grouping_id_type(group_exprs: usize, max_ordinal: usize) -> DataType {
3700+
let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize;
3701+
let total_bits = group_exprs + ordinal_bits;
3702+
if total_bits <= 8 {
36933703
DataType::UInt8
3694-
} else if group_exprs <= 16 {
3704+
} else if total_bits <= 16 {
36953705
DataType::UInt16
3696-
} else if group_exprs <= 32 {
3706+
} else if total_bits <= 32 {
36973707
DataType::UInt32
36983708
} else {
36993709
DataType::UInt64
@@ -3702,21 +3712,36 @@ impl Aggregate {
37023712

37033713
/// Internal column used when the aggregation is a grouping set.
37043714
///
3705-
/// This column contains a bitmask where each bit represents a grouping
3706-
/// expression. The least significant bit corresponds to the rightmost
3707-
/// grouping expression. A bit value of 0 indicates that the corresponding
3708-
/// column is included in the grouping set, while a value of 1 means it is excluded.
3715+
/// This column packs two values into a single unsigned integer:
3716+
///
3717+
/// - **Low bits (positions 0 .. n-1)**: a semantic bitmask where each bit
3718+
/// represents one of the `n` grouping expressions. The least significant
3719+
/// bit corresponds to the rightmost grouping expression. A `1` bit means
3720+
/// the corresponding column is replaced with `NULL` for this grouping set;
3721+
/// a `0` bit means it is included.
3722+
/// - **High bits (positions n and above)**: a *duplicate ordinal* that
3723+
/// distinguishes multiple occurrences of the same semantic grouping set
3724+
/// pattern within a single query. The ordinal is `0` for the first
3725+
/// occurrence, `1` for the second, and so on.
3726+
///
3727+
/// The integer type is chosen by [`Self::grouping_id_type`] to be the
3728+
/// smallest `UInt8 / UInt16 / UInt32 / UInt64` that can represent both
3729+
/// parts.
37093730
///
3710-
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
3711-
/// column will have the following values:
3731+
/// For example, for the grouping expressions CUBE(a, b) (no duplicates),
3732+
/// the grouping ID column will have the following values:
37123733
/// 0b00: Both `a` and `b` are included
37133734
/// 0b01: `b` is excluded
37143735
/// 0b10: `a` is excluded
37153736
/// 0b11: Both `a` and `b` are excluded
37163737
///
3717-
/// This internal column is necessary because excluded columns are replaced
3718-
/// with `NULL` values. To handle these cases correctly, we must distinguish
3719-
/// between an actual `NULL` value in a column and a column being excluded from the set.
3738+
/// When the same set appears twice and `n = 2`, the duplicate ordinal is
3739+
/// packed into bit 2:
3740+
/// first occurrence: `0b0_01` (ordinal = 0, mask = 0b01)
3741+
/// second occurrence: `0b1_01` (ordinal = 1, mask = 0b01)
3742+
///
3743+
/// The GROUPING function always masks the value with `(1 << n) - 1` before
3744+
/// interpreting it so the ordinal bits are invisible to user-facing SQL.
37203745
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
37213746
}
37223747

@@ -3737,6 +3762,24 @@ impl PartialOrd for Aggregate {
37373762
}
37383763
}
37393764

3765+
/// Returns the highest duplicate ordinal across all grouping sets in `group_expr`.
3766+
///
3767+
/// The ordinal for each occurrence of a grouping set pattern is its 0-based
3768+
/// index among identical entries. For example, if the same set appears three
3769+
/// times, the ordinals are 0, 1, 2 and this function returns 2.
3770+
/// Returns 0 when no grouping set is duplicated.
3771+
fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize {
3772+
if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() {
3773+
let mut counts: HashMap<&[Expr], usize> = HashMap::new();
3774+
for set in sets {
3775+
*counts.entry(set).or_insert(0) += 1;
3776+
}
3777+
counts.into_values().max().unwrap_or(0).saturating_sub(1)
3778+
} else {
3779+
0
3780+
}
3781+
}
3782+
37403783
/// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`.
37413784
fn contains_grouping_set(group_expr: &[Expr]) -> bool {
37423785
group_expr
@@ -5053,6 +5096,14 @@ mod tests {
50535096
);
50545097
}
50555098

5099+
#[test]
5100+
fn grouping_id_type_accounts_for_duplicate_ordinal_bits() {
5101+
// 8 grouping columns fit in UInt8 when there are no duplicate ordinals,
5102+
// but adding one duplicate ordinal bit widens the type to UInt16.
5103+
assert_eq!(Aggregate::grouping_id_type(8, 0), DataType::UInt8);
5104+
assert_eq!(Aggregate::grouping_id_type(8, 1), DataType::UInt16);
5105+
}
5106+
50565107
#[test]
50575108
fn test_filter_is_scalar() {
50585109
// test empty placeholder

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,17 @@ fn replace_grouping_exprs(
9999
{
100100
match expr {
101101
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
102+
let grouping_id_type = is_grouping_set
103+
.then(|| {
104+
schema
105+
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
106+
.map(|f| f.data_type().clone())
107+
})
108+
.transpose()?;
102109
let grouping_expr = grouping_function_on_id(
103110
function,
104111
&group_expr_to_bitmap_index,
105-
is_grouping_set,
112+
grouping_id_type,
106113
)?;
107114
projection_exprs.push(Expr::Alias(Alias::new(
108115
grouping_expr,
@@ -184,40 +191,44 @@ fn validate_args(
184191
fn grouping_function_on_id(
185192
function: &AggregateFunction,
186193
group_by_expr: &HashMap<&Expr, usize>,
187-
is_grouping_set: bool,
194+
// None means not a grouping set (result is always 0).
195+
grouping_id_type: Option<DataType>,
188196
) -> Result<Expr> {
189197
validate_args(function, group_by_expr)?;
190198
let args = &function.params.args;
191199

192200
// Postgres allows grouping function for group by without grouping sets, the result is then
193201
// always 0
194-
if !is_grouping_set {
202+
let Some(grouping_id_type) = grouping_id_type else {
195203
return Ok(Expr::Literal(ScalarValue::from(0i32), None));
196-
}
197-
198-
let group_by_expr_count = group_by_expr.len();
199-
let literal = |value: usize| {
200-
if group_by_expr_count < 8 {
201-
Expr::Literal(ScalarValue::from(value as u8), None)
202-
} else if group_by_expr_count < 16 {
203-
Expr::Literal(ScalarValue::from(value as u16), None)
204-
} else if group_by_expr_count < 32 {
205-
Expr::Literal(ScalarValue::from(value as u32), None)
206-
} else {
207-
Expr::Literal(ScalarValue::from(value as u64), None)
208-
}
209204
};
210205

206+
// Use the actual __grouping_id column type to size literals correctly. This
207+
// accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type`
208+
// packs into the high bits of the column, which a simple count of grouping
209+
// expressions would miss.
210+
let literal = |value: usize| match &grouping_id_type {
211+
DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
212+
DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
213+
DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
214+
DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None),
215+
other => panic!("unexpected __grouping_id type: {other}"),
216+
};
211217
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
212-
// The grouping call is exactly our internal grouping id
213-
if args.len() == group_by_expr_count
218+
if args.len() == group_by_expr.len()
214219
&& args
215220
.iter()
216221
.rev()
217222
.enumerate()
218223
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
219224
{
220-
return Ok(cast(grouping_id_column, DataType::Int32));
225+
let n = group_by_expr.len();
226+
// Mask the ordinal bits above position `n` so only the semantic bitmask is visible.
227+
// checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX.
228+
let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
229+
let masked_id =
230+
bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
231+
return Ok(cast(masked_id, DataType::Int32));
221232
}
222233

223234
args.iter()

0 commit comments

Comments
 (0)