@@ -25,7 +25,7 @@ use std::hash::Hash;
2525use std:: sync:: Arc ;
2626
2727use arrow:: array:: RecordBatch ;
28- use arrow:: datatypes:: { DataType , Field , FieldRef , SchemaRef } ;
28+ use arrow:: datatypes:: { DataType , FieldRef , SchemaRef } ;
2929use datafusion_common:: {
3030 DataFusionError , Result , ScalarValue , exec_err,
3131 metadata:: FieldMetadata ,
@@ -34,11 +34,10 @@ use datafusion_common::{
3434} ;
3535use datafusion_functions:: core:: getfield:: GetFieldFunc ;
3636use datafusion_physical_expr:: PhysicalExprSimplifier ;
37- use datafusion_physical_expr:: expressions:: CastColumnExpr ;
3837use datafusion_physical_expr:: projection:: { ProjectionExprs , Projector } ;
3938use datafusion_physical_expr:: {
4039 ScalarFunctionExpr ,
41- expressions:: { self , Column } ,
40+ expressions:: { self , CastExpr , Column } ,
4241} ;
4342use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
4443use itertools:: Itertools ;
@@ -423,13 +422,12 @@ impl DefaultPhysicalExprAdapterRewriter {
423422 ) ) ) ;
424423 } ;
425424
426- if resolved_column . index ( ) == column . index ( )
427- && logical_field == physical_field . as_ref ( )
428- {
429- return Ok ( Transformed :: no ( expr) ) ;
430- }
425+ let fields_match = logical_field == physical_field . as_ref ( ) ;
426+ if fields_match {
427+ if resolved_column . index ( ) == column . index ( ) {
428+ return Ok ( Transformed :: no ( expr) ) ;
429+ }
431430
432- if logical_field == physical_field. as_ref ( ) {
433431 // If the fields match (including metadata/nullability), we can use the column as is
434432 return Ok ( Transformed :: yes ( Arc :: new ( resolved_column) ) ) ;
435433 }
@@ -439,7 +437,25 @@ impl DefaultPhysicalExprAdapterRewriter {
439437 // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
440438 // since that's much cheaper to evaluate.
441439 // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
442- self . create_cast_column_expr ( resolved_column, physical_field, logical_field)
440+ validate_data_type_compatibility (
441+ resolved_column. name ( ) ,
442+ physical_field. data_type ( ) ,
443+ logical_field. data_type ( ) ,
444+ )
445+ . map_err ( |e| {
446+ DataFusionError :: Execution ( format ! (
447+ "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}" ,
448+ resolved_column. name( ) ,
449+ physical_field. data_type( ) ,
450+ logical_field. data_type( )
451+ ) )
452+ } ) ?;
453+
454+ Ok ( Transformed :: yes ( Arc :: new ( CastExpr :: new_with_target_field (
455+ Arc :: new ( resolved_column) ,
456+ Arc :: new ( logical_field. clone ( ) ) ,
457+ None ,
458+ ) ) ) )
443459 }
444460
445461 /// Resolves a logical column to the corresponding physical column and field.
@@ -465,48 +481,13 @@ impl DefaultPhysicalExprAdapterRewriter {
465481 Column :: new_with_schema ( column. name ( ) , self . physical_file_schema . as_ref ( ) ) ?
466482 } ;
467483
468- Ok ( Some ( (
469- column,
470- Arc :: new (
471- self . physical_file_schema
472- . field ( physical_column_index)
473- . clone ( ) ,
474- ) ,
475- ) ) )
476- }
477-
478- /// Validates type compatibility and creates a CastColumnExpr if needed.
479- ///
480- /// Checks whether the physical field can be cast to the logical field type,
481- /// handling both struct and scalar types. Returns a CastColumnExpr with the
482- /// appropriate configuration.
483- fn create_cast_column_expr (
484- & self ,
485- column : Column ,
486- physical_field : FieldRef ,
487- logical_field : & Field ,
488- ) -> Result < Transformed < Arc < dyn PhysicalExpr > > > {
489- validate_data_type_compatibility (
490- column. name ( ) ,
491- physical_field. data_type ( ) ,
492- logical_field. data_type ( ) ,
493- )
494- . map_err ( |e|
495- DataFusionError :: Execution ( format ! (
496- "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}" ,
497- column. name( ) ,
498- physical_field. data_type( ) ,
499- logical_field. data_type( )
500- ) ) ) ?;
501-
502- let cast_expr = Arc :: new ( CastColumnExpr :: new (
503- Arc :: new ( column) ,
504- physical_field,
505- Arc :: new ( logical_field. clone ( ) ) ,
506- None ,
507- ) ) ;
484+ let physical_field = Arc :: new (
485+ self . physical_file_schema
486+ . field ( physical_column_index)
487+ . clone ( ) ,
488+ ) ;
508489
509- Ok ( Transformed :: yes ( cast_expr ) )
490+ Ok ( Some ( ( column , physical_field ) ) )
510491 }
511492}
512493
@@ -652,10 +633,40 @@ mod tests {
652633 Array , BooleanArray , GenericListArray , Int32Array , Int64Array , RecordBatch ,
653634 RecordBatchOptions , StringArray , StringViewArray , StructArray ,
654635 } ;
655- use arrow:: datatypes:: { Fields , Schema } ;
636+ use arrow:: datatypes:: { Field , Fields , Schema } ;
656637 use datafusion_common:: { assert_contains, record_batch} ;
657638 use datafusion_expr:: Operator ;
658- use datafusion_physical_expr:: expressions:: { Column , Literal , col, lit} ;
639+ use datafusion_physical_expr:: expressions:: { Column , Literal , col} ;
640+
/// Test helper: downcasts `expr` to a `CastExpr` through `as_any()` and returns a
/// reference to it. Panics with "Expected CastExpr" if `expr` is any other type.
641+ fn assert_cast_expr ( expr : & Arc < dyn PhysicalExpr > ) -> & CastExpr {
642+ expr. as_any ( )
643+ . downcast_ref :: < CastExpr > ( )
644+ . expect ( "Expected CastExpr" )
645+ }
646+
/// Test helper: checks that the expression wrapped by `cast_expr` is a `Column`
/// with the given `name` and `index`.
/// Panics with "Expected inner Column" if the inner expression is not a `Column`,
/// or via `assert_eq!` if the name or index differs.
647+ fn assert_cast_column ( cast_expr : & CastExpr , name : & str , index : usize ) {
648+ let inner_col = cast_expr
649+ . expr ( )
650+ . as_any ( )
651+ . downcast_ref :: < Column > ( )
652+ . expect ( "Expected inner Column" ) ;
653+ assert_eq ! ( inner_col. name( ) , name) ;
654+ assert_eq ! ( inner_col. index( ) , index) ;
655+ }
656+
/// Test fixture: builds a pair of schemas whose column order and types disagree.
/// Returns `(logical_schema, physical_schema)` - note the logical schema comes first.
///
/// Physical: `b: Binary` (nullable) at index 0, `a: Int32` (non-null) at index 1.
/// Logical:  `a: Int64` (non-null) at index 0, `b: Binary` (nullable) at index 1.
///
/// Column `a` therefore sits at a different index in each schema and has a
/// different data type (Int32 physically, Int64 logically).
657+ fn stale_index_cast_schemas ( ) -> ( SchemaRef , SchemaRef ) {
658+ let physical_schema = Arc :: new ( Schema :: new ( vec ! [
659+ Field :: new( "b" , DataType :: Binary , true ) ,
660+ Field :: new( "a" , DataType :: Int32 , false ) ,
661+ ] ) ) ;
662+
663+ let logical_schema = Arc :: new ( Schema :: new ( vec ! [
664+ Field :: new( "a" , DataType :: Int64 , false ) ,
665+ Field :: new( "b" , DataType :: Binary , true ) ,
666+ ] ) ) ;
667+
668+ ( logical_schema, physical_schema)
669+ }
659670
660671 fn create_test_schema ( ) -> ( Schema , Schema ) {
661672 let physical_schema = Schema :: new ( vec ! [
@@ -685,7 +696,7 @@ mod tests {
685696 let result = adapter. rewrite ( column_expr) . unwrap ( ) ;
686697
687698 // Should be wrapped in a cast expression
688- assert ! ( result. as_any( ) . downcast_ref:: <CastColumnExpr >( ) . is_some( ) ) ;
699+ assert ! ( result. as_any( ) . downcast_ref:: <CastExpr >( ) . is_some( ) ) ;
689700 }
690701
691702 #[ test]
@@ -702,24 +713,19 @@ mod tests {
702713 . unwrap ( ) ;
703714
704715 let result = adapter. rewrite ( Arc :: new ( Column :: new ( "a" , 0 ) ) ) ?;
705- let cast = result
706- . as_any ( )
707- . downcast_ref :: < CastColumnExpr > ( )
708- . expect ( "Expected CastColumnExpr" ) ;
709716
710- assert_eq ! ( cast. target_field( ) . data_type( ) , & DataType :: Int64 ) ;
711- assert ! ( !cast. target_field( ) . is_nullable( ) ) ;
717+ // Ensure the expression preserves the logical field nullability/metadata.
718+ let return_field = result. return_field ( physical_schema. as_ref ( ) ) ?;
719+ assert_eq ! ( return_field. data_type( ) , & DataType :: Int64 ) ;
720+ assert ! ( !return_field. is_nullable( ) ) ;
712721 assert_eq ! (
713- cast . target_field ( )
722+ return_field
714723 . metadata( )
715724 . get( "logical_meta" )
716725 . map( String :: as_str) ,
717726 Some ( "1" )
718727 ) ;
719728
720- // Ensure the expression reports the logical nullability regardless of input schema
721- assert ! ( !result. nullable( physical_schema. as_ref( ) ) ?) ;
722-
723729 Ok ( ( ) )
724730 }
725731
@@ -750,33 +756,35 @@ mod tests {
750756 ) ;
751757
752758 let result = adapter. rewrite ( Arc :: new ( expr) ) . unwrap ( ) ;
753- println ! ( "Rewritten expression: {result}" ) ;
754-
755- let expected = expressions:: BinaryExpr :: new (
756- Arc :: new ( CastColumnExpr :: new (
757- Arc :: new ( Column :: new ( "a" , 0 ) ) ,
758- Arc :: new ( Field :: new ( "a" , DataType :: Int32 , false ) ) ,
759- Arc :: new ( Field :: new ( "a" , DataType :: Int64 , false ) ) ,
760- None ,
761- ) ) ,
762- Operator :: Plus ,
763- Arc :: new ( Literal :: new ( ScalarValue :: Int64 ( Some ( 5 ) ) ) ) ,
764- ) ;
765- let expected = Arc :: new ( expressions:: BinaryExpr :: new (
766- Arc :: new ( expected) ,
767- Operator :: Or ,
768- Arc :: new ( expressions:: BinaryExpr :: new (
769- lit ( ScalarValue :: Float64 ( None ) ) , // c is missing, so it becomes null
770- Operator :: Gt ,
771- Arc :: new ( Literal :: new ( ScalarValue :: Float64 ( Some ( 0.0 ) ) ) ) ,
772- ) ) ,
773- ) ) as Arc < dyn PhysicalExpr > ;
759+ let outer = result
760+ . as_any ( )
761+ . downcast_ref :: < expressions:: BinaryExpr > ( )
762+ . expect ( "Expected outer BinaryExpr" ) ;
763+ assert_eq ! ( * outer. op( ) , Operator :: Or ) ;
774764
775- assert_eq ! (
776- result. to_string( ) ,
777- expected. to_string( ) ,
778- "The rewritten expression did not match the expected output"
779- ) ;
765+ let left = outer
766+ . left ( )
767+ . as_any ( )
768+ . downcast_ref :: < expressions:: BinaryExpr > ( )
769+ . expect ( "Expected left BinaryExpr" ) ;
770+ assert_eq ! ( * left. op( ) , Operator :: Plus ) ;
771+
772+ let left_cast = assert_cast_expr ( left. left ( ) ) ;
773+ assert_eq ! ( left_cast. target_field( ) . data_type( ) , & DataType :: Int64 ) ;
774+ assert_cast_column ( left_cast, "a" , 0 ) ;
775+
776+ let right = outer
777+ . right ( )
778+ . as_any ( )
779+ . downcast_ref :: < expressions:: BinaryExpr > ( )
780+ . expect ( "Expected right BinaryExpr" ) ;
781+ assert_eq ! ( * right. op( ) , Operator :: Gt ) ;
782+ let null_literal = right
783+ . left ( )
784+ . as_any ( )
785+ . downcast_ref :: < Literal > ( )
786+ . expect ( "Expected null literal" ) ;
787+ assert_eq ! ( * null_literal. value( ) , ScalarValue :: Float64 ( None ) ) ;
780788 }
781789
782790 #[ test]
@@ -841,17 +849,6 @@ mod tests {
841849
842850 let result = adapter. rewrite ( column_expr) . unwrap ( ) ;
843851
844- let physical_struct_fields: Fields = vec ! [
845- Field :: new( "id" , DataType :: Int32 , false ) ,
846- Field :: new( "name" , DataType :: Utf8 , true ) ,
847- ]
848- . into ( ) ;
849- let physical_field = Arc :: new ( Field :: new (
850- "data" ,
851- DataType :: Struct ( physical_struct_fields) ,
852- false ,
853- ) ) ;
854-
855852 let logical_struct_fields: Fields = vec ! [
856853 Field :: new( "id" , DataType :: Int64 , false ) ,
857854 Field :: new( "name" , DataType :: Utf8View , true ) ,
@@ -863,9 +860,8 @@ mod tests {
863860 false ,
864861 ) ) ;
865862
866- let expected = Arc :: new ( CastColumnExpr :: new (
863+ let expected = Arc :: new ( CastExpr :: new_with_target_field (
867864 Arc :: new ( Column :: new ( "data" , 0 ) ) ,
868- physical_field,
869865 logical_field,
870866 None ,
871867 ) ) as Arc < dyn PhysicalExpr > ;
@@ -1663,8 +1659,7 @@ mod tests {
16631659 Field :: new( "b" , DataType :: Utf8 , true ) ,
16641660 ] ) ;
16651661
1666- let factory = DefaultPhysicalExprAdapterFactory ;
1667- let adapter = factory
1662+ let adapter = DefaultPhysicalExprAdapterFactory
16681663 . create ( Arc :: new ( logical_schema) , Arc :: new ( physical_schema) )
16691664 . unwrap ( ) ;
16701665
@@ -1673,20 +1668,11 @@ mod tests {
16731668
16741669 let result = adapter. rewrite ( column_expr) . unwrap ( ) ;
16751670
1676- // Should be a CastColumnExpr
1677- let cast_expr = result
1678- . as_any ( )
1679- . downcast_ref :: < CastColumnExpr > ( )
1680- . expect ( "Expected CastColumnExpr" ) ;
1671+ // Should be a CastExpr
1672+ let cast_expr = assert_cast_expr ( & result) ;
16811673
16821674 // Verify the inner column points to the correct physical index (1)
1683- let inner_col = cast_expr
1684- . expr ( )
1685- . as_any ( )
1686- . downcast_ref :: < Column > ( )
1687- . expect ( "Expected inner Column" ) ;
1688- assert_eq ! ( inner_col. name( ) , "a" ) ;
1689- assert_eq ! ( inner_col. index( ) , 1 ) ; // Physical index is 1
1675+ assert_cast_column ( cast_expr, "a" , 1 ) ;
16901676
16911677 // Verify cast types
16921678 assert_eq ! (
@@ -1696,41 +1682,17 @@ mod tests {
16961682 }
16971683
16981684 #[ test]
1699- fn test_create_cast_column_expr_uses_name_lookup_not_column_index ( ) {
1700- // Physical schema has column `a` at index 1; index 0 is an incompatible type.
1701- let physical_schema = Arc :: new ( Schema :: new ( vec ! [
1702- Field :: new( "b" , DataType :: Binary , true ) ,
1703- Field :: new( "a" , DataType :: Int32 , false ) ,
1704- ] ) ) ;
1705-
1706- let logical_schema = Arc :: new ( Schema :: new ( vec ! [
1707- Field :: new( "a" , DataType :: Int64 , false ) ,
1708- Field :: new( "b" , DataType :: Binary , true ) ,
1709- ] ) ) ;
1710-
1711- let rewriter = DefaultPhysicalExprAdapterRewriter {
1712- logical_file_schema : Arc :: clone ( & logical_schema) ,
1713- physical_file_schema : Arc :: clone ( & physical_schema) ,
1714- } ;
1685+ fn test_rewrite_resolves_physical_column_by_name_before_casting ( ) {
1686+ let ( logical_schema, physical_schema) = stale_index_cast_schemas ( ) ;
1687+ let adapter = DefaultPhysicalExprAdapterFactory
1688+ . create ( logical_schema, physical_schema)
1689+ . unwrap ( ) ;
17151690
17161691 // Deliberately provide the wrong index for column `a`.
17171692 // Regression: this must still resolve against physical field `a` by name.
1718- let transformed = rewriter
1719- . create_cast_column_expr (
1720- Column :: new ( "a" , 0 ) ,
1721- Arc :: new ( physical_schema. field_with_name ( "a" ) . unwrap ( ) . clone ( ) ) ,
1722- logical_schema. field_with_name ( "a" ) . unwrap ( ) ,
1723- )
1724- . unwrap ( ) ;
1725-
1726- let cast_expr = transformed
1727- . data
1728- . as_any ( )
1729- . downcast_ref :: < CastColumnExpr > ( )
1730- . expect ( "Expected CastColumnExpr" ) ;
1731-
1732- assert_eq ! ( cast_expr. input_field( ) . name( ) , "a" ) ;
1733- assert_eq ! ( cast_expr. input_field( ) . data_type( ) , & DataType :: Int32 ) ;
1693+ let rewritten = adapter. rewrite ( Arc :: new ( Column :: new ( "a" , 0 ) ) ) . unwrap ( ) ;
1694+ let cast_expr = assert_cast_expr ( & rewritten) ;
1695+ assert_cast_column ( cast_expr, "a" , 1 ) ;
17341696 assert_eq ! ( cast_expr. target_field( ) . data_type( ) , & DataType :: Int64 ) ;
17351697 }
17361698}
0 commit comments