@@ -32,7 +32,7 @@ use datafusion::{
3232 stream:: RecordBatchStreamAdapter , DisplayAs , DisplayFormatType , ExecutionMode ,
3333 ExecutionPlan , Partitioning , PlanProperties , SendableRecordBatchStream ,
3434 } ,
35- sql:: { unparser:: Unparser , TableReference } ,
35+ sql:: { sqlparser :: ast , unparser:: Unparser , TableReference } ,
3636} ;
3737
3838pub mod federation;
@@ -59,6 +59,7 @@ pub enum Engine {
5959 ODBC ,
6060 Postgres ,
6161 MySQL ,
62+ Default ,
6263}
6364
6465impl Engine {
@@ -68,19 +69,22 @@ impl Engine {
6869 Engine :: SQLite => Arc :: new ( SqliteDialect { } ) ,
6970 Engine :: Postgres => Arc :: new ( PostgreSqlDialect { } ) ,
7071 Engine :: MySQL => Arc :: new ( MySqlDialect { } ) ,
71- Engine :: Spark | Engine :: DuckDB | Engine :: ODBC => Arc :: new ( DefaultDialect { } ) ,
72+ Engine :: Spark | Engine :: DuckDB | Engine :: ODBC | Engine :: Default => {
73+ Arc :: new ( DefaultDialect { } )
74+ }
7275 }
7376 }
7477}
7578
7679pub type Result < T , E = Error > = std:: result:: Result < T , E > ;
7780
81+ #[ derive( Clone ) ]
7882pub struct SqlTable < T : ' static , P : ' static > {
7983 name : & ' static str ,
8084 pool : Arc < dyn DbConnectionPool < T , P > + Send + Sync > ,
8185 schema : SchemaRef ,
8286 pub table_reference : TableReference ,
83- engine : Option < Engine > ,
87+ engine : Engine ,
8488}
8589
8690impl < T , P > SqlTable < T , P > {
@@ -116,6 +120,7 @@ impl<T, P> SqlTable<T, P> {
116120 table_reference : impl Into < TableReference > ,
117121 engine : Option < Engine > ,
118122 ) -> Self {
123+ let engine = engine. unwrap_or ( Engine :: Default ) ;
119124 Self {
120125 name,
121126 pool : Arc :: clone ( pool) ,
@@ -139,7 +144,7 @@ impl<T, P> SqlTable<T, P> {
139144 Arc :: clone ( & self . pool ) ,
140145 filters,
141146 limit,
142- self . engine ,
147+ Some ( self . engine ) ,
143148 ) ?) )
144149 }
145150
@@ -177,16 +182,14 @@ impl<T, P> TableProvider for SqlTable<T, P> {
177182 & self ,
178183 filters : & [ & Expr ] ,
179184 ) -> DataFusionResult < Vec < TableProviderFilterPushDown > > {
180- let dialect = self
181- . engine
182- . map ( |e| e. dialect ( ) )
183- . unwrap_or_else ( || Arc :: new ( DefaultDialect { } ) ) ;
184185 let filter_push_down: Vec < TableProviderFilterPushDown > = filters
185186 . iter ( )
186- . map ( |f| match Unparser :: new ( dialect. as_ref ( ) ) . expr_to_sql ( f) {
187- Ok ( _) => TableProviderFilterPushDown :: Exact ,
188- Err ( _) => TableProviderFilterPushDown :: Unsupported ,
189- } )
187+ . map (
188+ |f| match Unparser :: new ( self . engine . dialect ( ) . as_ref ( ) ) . expr_to_sql ( f) {
189+ Ok ( _) => TableProviderFilterPushDown :: Exact ,
190+ Err ( _) => TableProviderFilterPushDown :: Unsupported ,
191+ } ,
192+ )
190193 . collect ( ) ;
191194
192195 Ok ( filter_push_down)
@@ -217,7 +220,7 @@ pub struct SqlExec<T, P> {
217220 filters : Vec < Expr > ,
218221 limit : Option < usize > ,
219222 properties : PlanProperties ,
220- engine : Option < Engine > ,
223+ engine : Engine ,
221224}
222225
223226pub fn project_schema_safe (
@@ -248,6 +251,7 @@ impl<T, P> SqlExec<T, P> {
248251 engine : Option < Engine > ,
249252 ) -> DataFusionResult < Self > {
250253 let projected_schema = project_schema_safe ( schema, projections) ?;
254+ let engine = engine. unwrap_or ( Engine :: Default ) ;
251255
252256 Ok ( Self {
253257 projected_schema : Arc :: clone ( & projected_schema) ,
@@ -286,15 +290,11 @@ impl<T, P> SqlExec<T, P> {
286290 let where_expr = if self . filters . is_empty ( ) {
287291 String :: new ( )
288292 } else {
289- let dialect = self
290- . engine
291- . map ( |e| e. dialect ( ) )
292- . unwrap_or_else ( || Arc :: new ( DefaultDialect { } ) ) ;
293293 let filter_expr = self
294294 . filters
295295 . iter ( )
296296 . map ( |f| {
297- Unparser :: new ( dialect. as_ref ( ) )
297+ Unparser :: new ( self . engine . dialect ( ) . as_ref ( ) )
298298 . expr_to_sql ( f)
299299 . map ( |e| e. to_string ( ) )
300300 } )
@@ -311,16 +311,20 @@ impl<T, P> SqlExec<T, P> {
311311 }
312312
313313 fn table_name_escaped ( & self ) -> String {
314- self . table_reference . to_quoted_string ( )
314+ self . ident_escaped ( & self . table_reference . to_string ( ) )
315315 }
316316
317317 fn column_name_escaped ( & self , column_name : & str ) -> String {
318- match self . engine {
319- Some ( Engine :: ODBC ) => column_name. to_string ( ) ,
320- _ => {
321- format ! ( "\" {}\" " , column_name)
322- }
323- }
318+ self . ident_escaped ( column_name)
319+ }
320+
321+ fn ident_escaped ( & self , ident : & str ) -> String {
322+ let quote_style = self . engine . dialect ( ) . identifier_quote_style ( ident) ;
323+ ast:: Expr :: Identifier ( ast:: Ident {
324+ value : ident. to_string ( ) ,
325+ quote_style,
326+ } )
327+ . to_string ( )
324328 }
325329}
326330
@@ -485,6 +489,7 @@ mod tests {
485489 ) ,
486490 Field :: new( "userId" , DataType :: LargeUtf8 , false ) ,
487491 Field :: new( "active" , DataType :: Boolean , false ) ,
492+ Field :: new( "5e48" , DataType :: LargeUtf8 , false ) ,
488493 ] ;
489494 let schema = Arc :: new ( Schema :: new ( fields) ) ;
490495 let pool = Arc :: new ( MockDBPool { } )
@@ -511,18 +516,18 @@ mod tests {
511516
512517 #[ tokio:: test]
513518 async fn test_sql_to_string_with_limit ( ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
514- let sql_exec = new_sql_exec ( Some ( & vec ! [ 0 ] ) , "users" , & [ ] , Some ( 3 ) , None ) ?;
515- assert_eq ! ( sql_exec. sql( ) ?, r#"SELECT "name" FROM users LIMIT 3"# ) ;
519+ let sql_exec = new_sql_exec ( Some ( & vec ! [ 0 , 1 ] ) , "users" , & [ ] , Some ( 3 ) , None ) ?;
520+ assert_eq ! ( sql_exec. sql( ) ?, r#"SELECT "name", age FROM users LIMIT 3"# ) ;
516521 Ok ( ( ) )
517522 }
518523
519524 #[ tokio:: test]
520525 async fn test_sql_to_string_with_filters ( ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
521526 let filters = vec ! [ col( "age" ) . gt_eq( lit( 30 ) ) . and( col( "name" ) . eq( lit( "x" ) ) ) ] ;
522- let sql_exec = new_sql_exec ( None , "users" , & filters, None , None ) ?;
527+ let sql_exec = new_sql_exec ( Some ( & vec ! [ 0 , 1 ] ) , "users" , & filters, None , None ) ?;
523528 assert_eq ! (
524529 sql_exec. sql( ) ?,
525- r#"SELECT "name", " age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30) AND ("name" = 'x')) "#
530+ r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) "#
526531 ) ;
527532 Ok ( ( ) )
528533 }
@@ -531,21 +536,57 @@ mod tests {
531536 async fn test_sql_to_string_with_filters_and_limit (
532537 ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
533538 let filters = vec ! [ col( "age" ) . gt_eq( lit( 30 ) ) . and( col( "name" ) . eq( lit( "x" ) ) ) ] ;
534- let sql_exec = new_sql_exec ( None , "users" , & filters, Some ( 3 ) , None ) ?;
539+ let sql_exec = new_sql_exec ( Some ( & vec ! [ 0 , 1 ] ) , "users" , & filters, Some ( 3 ) , None ) ?;
535540 assert_eq ! (
536541 sql_exec. sql( ) ?,
537- r#"SELECT "name", " age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
542+ r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
538543 ) ;
539544 Ok ( ( ) )
540545 }
541546
542547 #[ tokio:: test]
543548 async fn test_sql_to_string_with_engine ( ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
544549 let filters = vec ! [ col( "age" ) . gt_eq( lit( 30 ) ) . and( col( "name" ) . eq( lit( "x" ) ) ) ] ;
545- let sql_exec = new_sql_exec ( None , "users" , & filters, Some ( 3 ) , Some ( Engine :: DuckDB ) ) ?;
550+ let sql_exec = new_sql_exec (
551+ Some ( & vec ! [ 0 , 1 ] ) ,
552+ "users" ,
553+ & filters,
554+ Some ( 3 ) ,
555+ Some ( Engine :: DuckDB ) ,
556+ ) ?;
557+ assert_eq ! (
558+ sql_exec. sql( ) ?,
559+ r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
560+ ) ;
561+ Ok ( ( ) )
562+ }
563+
564+ #[ tokio:: test]
565+ async fn test_sql_to_string_with_not_reasonable_name (
566+ ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
567+ let filters = vec ! [ col( "5e48" ) . eq( lit( "test" ) ) . and( col( "name" ) . eq( lit( "x" ) ) ) ] ;
568+ let sql_exec = new_sql_exec ( Some ( & vec ! [ 0 , 1 , 5 ] ) , "users" , & filters, Some ( 3 ) , None ) ?;
569+ assert_eq ! (
570+ sql_exec. sql( ) ?,
571+ r#"SELECT "name", age, "5e48" FROM users WHERE (("5e48" = 'test') AND ("name" = 'x')) LIMIT 3"#
572+ ) ;
573+ Ok ( ( ) )
574+ }
575+
576+ #[ tokio:: test]
577+ async fn test_sql_to_string_with_not_reasonable_name_mysql (
578+ ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
579+ let filters = vec ! [ col( "5e48" ) . eq( lit( "test" ) ) . and( col( "name" ) . eq( lit( "x" ) ) ) ] ;
580+ let sql_exec = new_sql_exec (
581+ Some ( & vec ! [ 0 , 1 , 5 ] ) ,
582+ "users" ,
583+ & filters,
584+ Some ( 3 ) ,
585+ Some ( Engine :: MySQL ) ,
586+ ) ?;
546587 assert_eq ! (
547588 sql_exec. sql( ) ?,
548- r#"SELECT " name", " age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30 ) AND (" name" = 'x')) LIMIT 3"#
589+ r#"SELECT ` name`, ` age`, `5e48` FROM ` users` WHERE ((`5e48` = 'test' ) AND (` name` = 'x')) LIMIT 3"#
549590 ) ;
550591 Ok ( ( ) )
551592 }
0 commit comments