Commit 03037cd

Dandandan and claude committed
fix: propagate column statistics through CAST in join key expressions

When join keys contain CAST expressions (e.g. CAST(id AS Float64)), the cardinality estimator could not extract column statistics because it only handled plain Column references. This caused unknown stats, leading to poor join ordering (e.g. putting a 1.4M-row fact table on the hash join build side instead of a 5-row dimension table).

Extract the underlying column index through numeric CAST expressions, since casting can only reduce (never increase) distinct count, making the source column's stats a valid upper bound.

TPC-DS Q99: 10.4s → ~60ms.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

1 parent 5ba06ac commit 03037cd
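
The upper-bound argument in the commit message can be checked with a small standalone sketch. It is not part of the commit: the helper names `distinct_i64` and `distinct_after_cast_to_f64` are hypothetical, and it only looks at the Int64 to Float64 case. A cast maps each input value to exactly one output value, so the number of distinct outputs can never exceed the number of distinct inputs, and it can be smaller when two inputs round to the same float.

```rust
use std::collections::HashSet;

/// Count distinct values in a column of i64 join keys.
/// (Hypothetical helper for illustration, not a DataFusion API.)
fn distinct_i64(values: &[i64]) -> usize {
    values.iter().collect::<HashSet<_>>().len()
}

/// Count distinct values after casting each key to f64.
/// f64 is neither `Hash` nor `Eq`, so the bit pattern serves as the set key.
/// That is safe here because an integer cast never yields NaN or -0.0.
fn distinct_after_cast_to_f64(values: &[i64]) -> usize {
    values
        .iter()
        .map(|&v| (v as f64).to_bits())
        .collect::<HashSet<u64>>()
        .len()
}
```

For small keys the two counts agree. For keys above 2^53, neighbouring integers can collapse to the same float, so the count after the cast drops, and the source column's distinct count remains a valid (if loose) upper bound.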

1 file changed: +26 -7 lines changed

datafusion/physical-plan/src/joins/utils.rs

@@ -460,6 +460,28 @@ pub(crate) fn estimate_join_statistics(
     })
 }
 
+/// Extract the column index from a join key expression for statistics lookup.
+/// Handles plain `Column` references and `CAST(column AS numeric_type)`
+/// expressions. Casting can only merge values (many-to-one), never split
+/// them, so the source column's distinct count is always a valid upper
+/// bound for the cast result's distinct count.
+fn column_index_for_stats(expr: &Arc<dyn PhysicalExpr>) -> Option<usize> {
+    if let Some(col) = expr.as_any().downcast_ref::<Column>() {
+        return Some(col.index());
+    }
+    if let Some(cast) = expr
+        .as_any()
+        .downcast_ref::<datafusion_physical_expr::expressions::CastExpr>()
+        && let Some(col) = cast.expr.as_any().downcast_ref::<Column>()
+    {
+        let target = cast.cast_type();
+        if target.is_numeric() {
+            return Some(col.index());
+        }
+    }
+    None
+}
+
 // Estimate the cardinality for the given join with input statistics.
 fn estimate_join_cardinality(
     join_type: &JoinType,
@@ -470,13 +492,10 @@ fn estimate_join_cardinality(
     let (left_col_stats, right_col_stats) = on
         .iter()
         .map(|(left, right)| {
-            match (
-                left.as_any().downcast_ref::<Column>(),
-                right.as_any().downcast_ref::<Column>(),
-            ) {
-                (Some(left), Some(right)) => (
-                    left_stats.column_statistics[left.index()].clone(),
-                    right_stats.column_statistics[right.index()].clone(),
+            match (column_index_for_stats(left), column_index_for_stats(right)) {
+                (Some(left_idx), Some(right_idx)) => (
+                    left_stats.column_statistics[left_idx].clone(),
+                    right_stats.column_statistics[right_idx].clone(),
                 ),
                 _ => (
                     ColumnStatistics::new_unknown(),
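
The lookup rule this diff introduces can be tried outside DataFusion with a toy model. The sketch below is an illustration only: `ToyType`, `ToyExpr`, and `toy_column_index_for_stats` are hypothetical stand-ins for the Arrow data type, `PhysicalExpr`, and `column_index_for_stats`, and they use pattern matching in place of `downcast_ref`. It follows the same three cases as the patch: a plain column, a numeric cast directly over a column, and everything else.

```rust
/// Hypothetical stand-in for a data type; only `is_numeric` matters here.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToyType {
    Int64,
    Float64,
    Utf8,
}

impl ToyType {
    fn is_numeric(self) -> bool {
        matches!(self, ToyType::Int64 | ToyType::Float64)
    }
}

/// Hypothetical stand-in for a join key expression.
#[derive(Debug)]
enum ToyExpr {
    Column { index: usize },
    Cast { input: Box<ToyExpr>, target: ToyType },
    Literal(i64),
}

/// Return the index of the column whose statistics may be reused for `expr`.
/// Mirrors the patch: accept a plain column, or a cast to a numeric type
/// applied directly to a column; anything else yields `None`.
fn toy_column_index_for_stats(expr: &ToyExpr) -> Option<usize> {
    match expr {
        ToyExpr::Column { index } => Some(*index),
        ToyExpr::Cast { input, target } if target.is_numeric() => match input.as_ref() {
            ToyExpr::Column { index } => Some(*index),
            _ => None,
        },
        _ => None,
    }
}
```

Under these assumptions, a cast to a non-numeric type and a cast over a non-column input both return `None`, which corresponds to the `ColumnStatistics::new_unknown()` fallback arm in the patched `match`.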
