Skip to content

Commit b46634c

Browse files
authored
fix: PostgreSQL dialect can not support tinyint type (#21445)
## Which issue does this PR close? - No linked issue. ## Rationale for this change DataFusion's PostgreSQL unparser should emit PostgreSQL-compatible SQL for integer casts.`Int8` was still being rendered as `TINYINT`, which is not valid PostgreSQL syntax. This change makes PostgreSQL output `SMALLINT` instead. ## What changes are included in this PR? - Added an `int8_cast_dtype` hook to the SQL unparser dialect abstraction. - Updated `PostgreSqlDialect` to map `Int8` to `SMALLINT`. - Routed `DataType::Int8` unparsing through the dialect hook. - Added a regression test covering `select cast(3 as tinyint)]` with PostgreSQL unparser output. ## Are these changes tested? - Yes. Ran: - `cargo test -p datafusion-sql --test sql_integration cases::plan_to_sql::test_cast_to_tinyint -- --exact` - The test passes. ## Are there any user-facing changes? - Yes. PostgreSQL dialect SQL generation now renders `CAST(... AS SMALLINT)` for `Int8` values instead of `CAST(... AS TINYINT)`.
1 parent 4389f14 commit b46634c

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ pub trait Dialect: Send + Sync {
100100
ast::DataType::BigInt(None)
101101
}
102102

103+
/// The SQL type to use for Arrow Int8 unparsing
104+
/// Most dialects use TinyInt, but PostgreSQL prefers SmallInt
105+
fn int8_cast_dtype(&self) -> ast::DataType {
106+
ast::DataType::TinyInt(None)
107+
}
108+
103109
/// The SQL type to use for Arrow Int32 unparsing
104110
/// Most dialects use Integer, but some, like MySQL, require SIGNED
105111
fn int32_cast_dtype(&self) -> ast::DataType {
@@ -345,6 +351,10 @@ impl Dialect for PostgreSqlDialect {
345351
ast::DataType::DoublePrecision
346352
}
347353

354+
fn int8_cast_dtype(&self) -> ast::DataType {
355+
ast::DataType::SmallInt(None)
356+
}
357+
348358
fn scalar_function_to_sql_overrides(
349359
&self,
350360
unparser: &Unparser,
@@ -664,6 +674,7 @@ pub struct CustomDialect {
664674
large_utf8_cast_dtype: ast::DataType,
665675
date_field_extract_style: DateFieldExtractStyle,
666676
character_length_style: CharacterLengthStyle,
677+
int8_cast_dtype: ast::DataType,
667678
int64_cast_dtype: ast::DataType,
668679
int32_cast_dtype: ast::DataType,
669680
timestamp_cast_dtype: ast::DataType,
@@ -689,6 +700,7 @@ impl Default for CustomDialect {
689700
large_utf8_cast_dtype: ast::DataType::Text,
690701
date_field_extract_style: DateFieldExtractStyle::DatePart,
691702
character_length_style: CharacterLengthStyle::CharacterLength,
703+
int8_cast_dtype: ast::DataType::TinyInt(None),
692704
int64_cast_dtype: ast::DataType::BigInt(None),
693705
int32_cast_dtype: ast::DataType::Integer(None),
694706
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
@@ -748,6 +760,10 @@ impl Dialect for CustomDialect {
748760
self.int64_cast_dtype.clone()
749761
}
750762

763+
fn int8_cast_dtype(&self) -> ast::DataType {
764+
self.int8_cast_dtype.clone()
765+
}
766+
751767
fn int32_cast_dtype(&self) -> ast::DataType {
752768
self.int32_cast_dtype.clone()
753769
}
@@ -839,6 +855,7 @@ pub struct CustomDialectBuilder {
839855
large_utf8_cast_dtype: ast::DataType,
840856
date_field_extract_style: DateFieldExtractStyle,
841857
character_length_style: CharacterLengthStyle,
858+
int8_cast_dtype: ast::DataType,
842859
int64_cast_dtype: ast::DataType,
843860
int32_cast_dtype: ast::DataType,
844861
timestamp_cast_dtype: ast::DataType,
@@ -870,6 +887,7 @@ impl CustomDialectBuilder {
870887
large_utf8_cast_dtype: ast::DataType::Text,
871888
date_field_extract_style: DateFieldExtractStyle::DatePart,
872889
character_length_style: CharacterLengthStyle::CharacterLength,
890+
int8_cast_dtype: ast::DataType::TinyInt(None),
873891
int64_cast_dtype: ast::DataType::BigInt(None),
874892
int32_cast_dtype: ast::DataType::Integer(None),
875893
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
@@ -898,6 +916,7 @@ impl CustomDialectBuilder {
898916
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
899917
date_field_extract_style: self.date_field_extract_style,
900918
character_length_style: self.character_length_style,
919+
int8_cast_dtype: self.int8_cast_dtype,
901920
int64_cast_dtype: self.int64_cast_dtype,
902921
int32_cast_dtype: self.int32_cast_dtype,
903922
timestamp_cast_dtype: self.timestamp_cast_dtype,
@@ -952,6 +971,12 @@ impl CustomDialectBuilder {
952971
self
953972
}
954973

974+
/// Customize the dialect with a specific SQL type for Int8 casting: TinyInt, SmallInt, etc.
975+
pub fn with_int8_cast_dtype(mut self, int8_cast_dtype: ast::DataType) -> Self {
976+
self.int8_cast_dtype = int8_cast_dtype;
977+
self
978+
}
979+
955980
/// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc.
956981
pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self {
957982
self.float64_ast_dtype = float64_ast_dtype;

datafusion/sql/src/unparser/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ impl Unparser<'_> {
17341734
not_impl_err!("Unsupported DataType: conversion: {data_type}")
17351735
}
17361736
DataType::Boolean => Ok(ast::DataType::Bool),
1737-
DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
1737+
DataType::Int8 => Ok(self.dialect.int8_cast_dtype()),
17381738
DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
17391739
DataType::Int32 => Ok(self.dialect.int32_cast_dtype()),
17401740
DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,6 +1937,28 @@ fn test_without_offset() {
19371937
)
19381938
}
19391939

1940+
#[test]
1941+
fn test_cast_to_tinyint() -> Result<(), DataFusionError> {
1942+
roundtrip_statement_with_dialect_helper!(
1943+
sql: "select cast(3 as tinyint)",
1944+
parser_dialect: GenericDialect {},
1945+
unparser_dialect: UnparserPostgreSqlDialect {},
1946+
expected: @"SELECT CAST(3 AS SMALLINT)",
1947+
);
1948+
Ok(())
1949+
}
1950+
1951+
#[test]
1952+
fn test_cast_to_tinyint_default_dialect() -> Result<(), DataFusionError> {
1953+
roundtrip_statement_with_dialect_helper!(
1954+
sql: "select cast(3 as tinyint)",
1955+
parser_dialect: GenericDialect {},
1956+
unparser_dialect: UnparserDefaultDialect {},
1957+
expected: @"SELECT CAST(3 AS TINYINT)",
1958+
);
1959+
Ok(())
1960+
}
1961+
19401962
#[test]
19411963
fn test_with_offset0() {
19421964
let statement = generate_round_trip_statement(MySqlDialect {}, "select 1 offset 0");

0 commit comments

Comments
 (0)