@@ -99,10 +99,17 @@ fn replace_grouping_exprs(
9999 {
100100 match expr {
101101 Expr :: AggregateFunction ( ref function) if is_grouping_function ( & expr) => {
102+ let grouping_id_type = is_grouping_set
103+ . then ( || {
104+ schema
105+ . field_with_name ( None , Aggregate :: INTERNAL_GROUPING_ID )
106+ . map ( |f| f. data_type ( ) . clone ( ) )
107+ } )
108+ . transpose ( ) ?;
102109 let grouping_expr = grouping_function_on_id (
103110 function,
104111 & group_expr_to_bitmap_index,
105- is_grouping_set ,
112+ grouping_id_type ,
106113 ) ?;
107114 projection_exprs. push ( Expr :: Alias ( Alias :: new (
108115 grouping_expr,
@@ -184,41 +191,40 @@ fn validate_args(
184191fn grouping_function_on_id (
185192 function : & AggregateFunction ,
186193 group_by_expr : & HashMap < & Expr , usize > ,
187- is_grouping_set : bool ,
194+ // None means not a grouping set (result is always 0).
195+ grouping_id_type : Option < DataType > ,
188196) -> Result < Expr > {
189197 validate_args ( function, group_by_expr) ?;
190198 let args = & function. params . args ;
191199
192200 // Postgres allows grouping function for group by without grouping sets, the result is then
193201 // always 0
194- if !is_grouping_set {
202+ let Some ( grouping_id_type ) = grouping_id_type else {
195203 return Ok ( Expr :: Literal ( ScalarValue :: from ( 0i32 ) , None ) ) ;
196- }
204+ } ;
197205
198- let group_by_expr_count = group_by_expr. len ( ) ;
199- let literal = |value : usize | {
200- if group_by_expr_count < 8 {
201- Expr :: Literal ( ScalarValue :: from ( value as u8 ) , None )
202- } else if group_by_expr_count < 16 {
203- Expr :: Literal ( ScalarValue :: from ( value as u16 ) , None )
204- } else if group_by_expr_count < 32 {
205- Expr :: Literal ( ScalarValue :: from ( value as u32 ) , None )
206- } else {
207- Expr :: Literal ( ScalarValue :: from ( value as u64 ) , None )
208- }
206+ // Use the actual __grouping_id column type to size literals correctly. This
207+ // accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type`
208+ // packs into the high bits of the column, which a simple count of grouping
209+ // expressions would miss.
210+ let literal = |value : usize | match & grouping_id_type {
211+ DataType :: UInt8 => Expr :: Literal ( ScalarValue :: from ( value as u8 ) , None ) ,
212+ DataType :: UInt16 => Expr :: Literal ( ScalarValue :: from ( value as u16 ) , None ) ,
213+ DataType :: UInt32 => Expr :: Literal ( ScalarValue :: from ( value as u32 ) , None ) ,
214+ _ => Expr :: Literal ( ScalarValue :: from ( value as u64 ) , None ) ,
209215 } ;
210216 let grouping_id_column = Expr :: Column ( Column :: from ( Aggregate :: INTERNAL_GROUPING_ID ) ) ;
211- if args. len ( ) == group_by_expr_count
217+ if args. len ( ) == group_by_expr . len ( )
212218 && args
213219 . iter ( )
214220 . rev ( )
215221 . enumerate ( )
216222 . all ( |( idx, expr) | group_by_expr. get ( expr) == Some ( & idx) )
217223 {
218- let n = group_by_expr_count ;
224+ let n = group_by_expr . len ( ) ;
219225 // Mask the ordinal bits above position `n` so only the semantic bitmask is visible.
220- // (1 << n) - 1 masks the low n bits .
221- let semantic_mask: u64 = if n >= 64 { u64:: MAX } else { ( 1u64 << n ) - 1 } ;
226+ // checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX .
227+ let semantic_mask: u64 = 1u64 . checked_shl ( n as u32 ) . map_or ( u64:: MAX , |m| m - 1 ) ;
222228 let masked_id =
223229 bitwise_and ( grouping_id_column. clone ( ) , literal ( semantic_mask as usize ) ) ;
224230 return Ok ( cast ( masked_id, DataType :: Int32 ) ) ;
0 commit comments