Skip to content

Commit 5bbbb32

Browse files
use quote style for column and table names based on the sql dialect (#102)
* don't wrap Engine with Options in SqlTable and SqlExec
* sql/sql_provider: use quote style for column and table names based on the dialect

---------

Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech>
1 parent 2c2ad24 commit 5bbbb32

2 files changed

Lines changed: 75 additions & 37 deletions

File tree

src/sql/sql_provider_datafusion/federation.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use crate::sql::db_connection_pool::{dbconnection::get_schema, JoinPushDown};
22
use async_trait::async_trait;
3-
use datafusion::sql::unparser::dialect::DefaultDialect;
43
use datafusion_federation::sql::{SQLExecutor, SQLFederationProvider, SQLTableSource};
54
use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource};
65
use futures::TryStreamExt;
@@ -58,9 +57,7 @@ impl<T, P> SQLExecutor for SqlTable<T, P> {
5857
}
5958

6059
fn dialect(&self) -> Arc<dyn Dialect> {
61-
self.engine
62-
.map(|e| e.dialect())
63-
.unwrap_or_else(|| Arc::new(DefaultDialect {}))
60+
self.engine.dialect()
6461
}
6562

6663
fn execute(

src/sql/sql_provider_datafusion/mod.rs

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use datafusion::{
3232
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionMode,
3333
ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream,
3434
},
35-
sql::{unparser::Unparser, TableReference},
35+
sql::{sqlparser::ast, unparser::Unparser, TableReference},
3636
};
3737

3838
pub mod federation;
@@ -59,6 +59,7 @@ pub enum Engine {
5959
ODBC,
6060
Postgres,
6161
MySQL,
62+
Default,
6263
}
6364

6465
impl Engine {
@@ -68,19 +69,22 @@ impl Engine {
6869
Engine::SQLite => Arc::new(SqliteDialect {}),
6970
Engine::Postgres => Arc::new(PostgreSqlDialect {}),
7071
Engine::MySQL => Arc::new(MySqlDialect {}),
71-
Engine::Spark | Engine::DuckDB | Engine::ODBC => Arc::new(DefaultDialect {}),
72+
Engine::Spark | Engine::DuckDB | Engine::ODBC | Engine::Default => {
73+
Arc::new(DefaultDialect {})
74+
}
7275
}
7376
}
7477
}
7578

7679
pub type Result<T, E = Error> = std::result::Result<T, E>;
7780

81+
#[derive(Clone)]
7882
pub struct SqlTable<T: 'static, P: 'static> {
7983
name: &'static str,
8084
pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
8185
schema: SchemaRef,
8286
pub table_reference: TableReference,
83-
engine: Option<Engine>,
87+
engine: Engine,
8488
}
8589

8690
impl<T, P> SqlTable<T, P> {
@@ -116,6 +120,7 @@ impl<T, P> SqlTable<T, P> {
116120
table_reference: impl Into<TableReference>,
117121
engine: Option<Engine>,
118122
) -> Self {
123+
let engine = engine.unwrap_or(Engine::Default);
119124
Self {
120125
name,
121126
pool: Arc::clone(pool),
@@ -139,7 +144,7 @@ impl<T, P> SqlTable<T, P> {
139144
Arc::clone(&self.pool),
140145
filters,
141146
limit,
142-
self.engine,
147+
Some(self.engine),
143148
)?))
144149
}
145150

@@ -177,16 +182,14 @@ impl<T, P> TableProvider for SqlTable<T, P> {
177182
&self,
178183
filters: &[&Expr],
179184
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
180-
let dialect = self
181-
.engine
182-
.map(|e| e.dialect())
183-
.unwrap_or_else(|| Arc::new(DefaultDialect {}));
184185
let filter_push_down: Vec<TableProviderFilterPushDown> = filters
185186
.iter()
186-
.map(|f| match Unparser::new(dialect.as_ref()).expr_to_sql(f) {
187-
Ok(_) => TableProviderFilterPushDown::Exact,
188-
Err(_) => TableProviderFilterPushDown::Unsupported,
189-
})
187+
.map(
188+
|f| match Unparser::new(self.engine.dialect().as_ref()).expr_to_sql(f) {
189+
Ok(_) => TableProviderFilterPushDown::Exact,
190+
Err(_) => TableProviderFilterPushDown::Unsupported,
191+
},
192+
)
190193
.collect();
191194

192195
Ok(filter_push_down)
@@ -217,7 +220,7 @@ pub struct SqlExec<T, P> {
217220
filters: Vec<Expr>,
218221
limit: Option<usize>,
219222
properties: PlanProperties,
220-
engine: Option<Engine>,
223+
engine: Engine,
221224
}
222225

223226
pub fn project_schema_safe(
@@ -248,6 +251,7 @@ impl<T, P> SqlExec<T, P> {
248251
engine: Option<Engine>,
249252
) -> DataFusionResult<Self> {
250253
let projected_schema = project_schema_safe(schema, projections)?;
254+
let engine = engine.unwrap_or(Engine::Default);
251255

252256
Ok(Self {
253257
projected_schema: Arc::clone(&projected_schema),
@@ -286,15 +290,11 @@ impl<T, P> SqlExec<T, P> {
286290
let where_expr = if self.filters.is_empty() {
287291
String::new()
288292
} else {
289-
let dialect = self
290-
.engine
291-
.map(|e| e.dialect())
292-
.unwrap_or_else(|| Arc::new(DefaultDialect {}));
293293
let filter_expr = self
294294
.filters
295295
.iter()
296296
.map(|f| {
297-
Unparser::new(dialect.as_ref())
297+
Unparser::new(self.engine.dialect().as_ref())
298298
.expr_to_sql(f)
299299
.map(|e| e.to_string())
300300
})
@@ -311,16 +311,20 @@ impl<T, P> SqlExec<T, P> {
311311
}
312312

313313
fn table_name_escaped(&self) -> String {
314-
self.table_reference.to_quoted_string()
314+
self.ident_escaped(&self.table_reference.to_string())
315315
}
316316

317317
fn column_name_escaped(&self, column_name: &str) -> String {
318-
match self.engine {
319-
Some(Engine::ODBC) => column_name.to_string(),
320-
_ => {
321-
format!("\"{}\"", column_name)
322-
}
323-
}
318+
self.ident_escaped(column_name)
319+
}
320+
321+
fn ident_escaped(&self, ident: &str) -> String {
322+
let quote_style = self.engine.dialect().identifier_quote_style(ident);
323+
ast::Expr::Identifier(ast::Ident {
324+
value: ident.to_string(),
325+
quote_style,
326+
})
327+
.to_string()
324328
}
325329
}
326330

@@ -485,6 +489,7 @@ mod tests {
485489
),
486490
Field::new("userId", DataType::LargeUtf8, false),
487491
Field::new("active", DataType::Boolean, false),
492+
Field::new("5e48", DataType::LargeUtf8, false),
488493
];
489494
let schema = Arc::new(Schema::new(fields));
490495
let pool = Arc::new(MockDBPool {})
@@ -511,18 +516,18 @@ mod tests {
511516

512517
#[tokio::test]
513518
async fn test_sql_to_string_with_limit() -> Result<(), Box<dyn Error + Send + Sync>> {
514-
let sql_exec = new_sql_exec(Some(&vec![0]), "users", &[], Some(3), None)?;
515-
assert_eq!(sql_exec.sql()?, r#"SELECT "name" FROM users LIMIT 3"#);
519+
let sql_exec = new_sql_exec(Some(&vec![0, 1]), "users", &[], Some(3), None)?;
520+
assert_eq!(sql_exec.sql()?, r#"SELECT "name", age FROM users LIMIT 3"#);
516521
Ok(())
517522
}
518523

519524
#[tokio::test]
520525
async fn test_sql_to_string_with_filters() -> Result<(), Box<dyn Error + Send + Sync>> {
521526
let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
522-
let sql_exec = new_sql_exec(None, "users", &filters, None, None)?;
527+
let sql_exec = new_sql_exec(Some(&vec![0, 1]), "users", &filters, None, None)?;
523528
assert_eq!(
524529
sql_exec.sql()?,
525-
r#"SELECT "name", "age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30) AND ("name" = 'x')) "#
530+
r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) "#
526531
);
527532
Ok(())
528533
}
@@ -531,21 +536,57 @@ mod tests {
531536
async fn test_sql_to_string_with_filters_and_limit(
532537
) -> Result<(), Box<dyn Error + Send + Sync>> {
533538
let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
534-
let sql_exec = new_sql_exec(None, "users", &filters, Some(3), None)?;
539+
let sql_exec = new_sql_exec(Some(&vec![0, 1]), "users", &filters, Some(3), None)?;
535540
assert_eq!(
536541
sql_exec.sql()?,
537-
r#"SELECT "name", "age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
542+
r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
538543
);
539544
Ok(())
540545
}
541546

542547
#[tokio::test]
543548
async fn test_sql_to_string_with_engine() -> Result<(), Box<dyn Error + Send + Sync>> {
544549
let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
545-
let sql_exec = new_sql_exec(None, "users", &filters, Some(3), Some(Engine::DuckDB))?;
550+
let sql_exec = new_sql_exec(
551+
Some(&vec![0, 1]),
552+
"users",
553+
&filters,
554+
Some(3),
555+
Some(Engine::DuckDB),
556+
)?;
557+
assert_eq!(
558+
sql_exec.sql()?,
559+
r#"SELECT "name", age FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
560+
);
561+
Ok(())
562+
}
563+
564+
#[tokio::test]
565+
async fn test_sql_to_string_with_not_reasonable_name(
566+
) -> Result<(), Box<dyn Error + Send + Sync>> {
567+
let filters = vec![col("5e48").eq(lit("test")).and(col("name").eq(lit("x")))];
568+
let sql_exec = new_sql_exec(Some(&vec![0, 1, 5]), "users", &filters, Some(3), None)?;
569+
assert_eq!(
570+
sql_exec.sql()?,
571+
r#"SELECT "name", age, "5e48" FROM users WHERE (("5e48" = 'test') AND ("name" = 'x')) LIMIT 3"#
572+
);
573+
Ok(())
574+
}
575+
576+
#[tokio::test]
577+
async fn test_sql_to_string_with_not_reasonable_name_mysql(
578+
) -> Result<(), Box<dyn Error + Send + Sync>> {
579+
let filters = vec![col("5e48").eq(lit("test")).and(col("name").eq(lit("x")))];
580+
let sql_exec = new_sql_exec(
581+
Some(&vec![0, 1, 5]),
582+
"users",
583+
&filters,
584+
Some(3),
585+
Some(Engine::MySQL),
586+
)?;
546587
assert_eq!(
547588
sql_exec.sql()?,
548-
r#"SELECT "name", "age", "createdDate", "userId", "active" FROM users WHERE ((age >= 30) AND ("name" = 'x')) LIMIT 3"#
589+
r#"SELECT `name`, `age`, `5e48` FROM `users` WHERE ((`5e48` = 'test') AND (`name` = 'x')) LIMIT 3"#
549590
);
550591
Ok(())
551592
}

0 commit comments

Comments
 (0)