Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/duckdb/sql_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl<T, P> DuckSqlExec<T, P> {
filters,
limit,
Some(Engine::DuckDB),
None,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)?;

Ok(Self {
Expand Down
1 change: 1 addition & 0 deletions src/mysql/sql_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl MySQLSQLExec {
filters,
limit,
Some(Engine::MySQL),
None,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)?;

Ok(Self { base_exec })
Expand Down
14 changes: 14 additions & 0 deletions src/sql/db_connection_pool/dbconnection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,17 @@ pub async fn query_arrow<T, P>(
return Err(Error::UnableToDowncastConnection {});
}
}

#[cfg(test)]
pub(crate) struct DummyDbConnection(pub ());

#[cfg(test)]
impl DbConnection<(), ()> for DummyDbConnection {
fn as_any(&self) -> &dyn Any {
&self.0
}

fn as_any_mut(&mut self) -> &mut dyn Any {
&mut self.0
}
}
15 changes: 15 additions & 0 deletions src/sql/db_connection_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,18 @@ impl DbInstanceKey {
DbInstanceKey::File(path)
}
}

#[cfg(test)]
pub(crate) struct DummyDbConnectionPool;

#[cfg(test)]
#[async_trait]
impl DbConnectionPool<(), ()> for DummyDbConnectionPool {
async fn connect(&self) -> Result<Box<dyn DbConnection<(), ()>>> {
Ok(Box::new(dbconnection::DummyDbConnection(())))
}

fn join_push_down(&self) -> JoinPushDown {
JoinPushDown::Disallow
}
}
104 changes: 80 additions & 24 deletions src/sql/sql_provider_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl<T, P> SqlTable<T, P> {
filters,
limit,
self.engine,
self.dialect.clone(),
)?))
}

Expand Down Expand Up @@ -222,6 +223,7 @@ pub struct SqlExec<T, P> {
limit: Option<usize>,
properties: PlanProperties,
engine: Option<Engine>,
dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}

pub fn project_schema_safe(
Expand Down Expand Up @@ -250,6 +252,7 @@ impl<T, P> SqlExec<T, P> {
filters: &[Expr],
limit: Option<usize>,
engine: Option<Engine>,
dialect: Option<Arc<dyn Dialect + Send + Sync>>,
) -> DataFusionResult<Self> {
let projected_schema = project_schema_safe(schema, projections)?;

Expand All @@ -266,6 +269,7 @@ impl<T, P> SqlExec<T, P> {
Boundedness::Bounded,
),
engine,
dialect,
})
}
#[must_use]
Expand All @@ -274,15 +278,25 @@ impl<T, P> SqlExec<T, P> {
}

pub fn sql(&self) -> Result<String> {
let quote = |name| {
let quote = self
.dialect
.as_ref()
.and_then(|d| d.identifier_quote_style(name))
.unwrap_or('"');
format!("{quote}{name}{quote}")
};
Comment on lines +281 to +288

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this inline closure be moved to its own function?


let columns = self
.projected_schema
.fields()
.iter()
.map(|f| {
let name = f.name();
if let Some(Engine::ODBC) = self.engine {
f.name().to_owned()
name.to_owned()
} else {
format!("\"{}\"", f.name())
quote(name)
}
})
.collect::<Vec<_>>()
Expand All @@ -297,7 +311,7 @@ impl<T, P> SqlExec<T, P> {
String::new()
} else {
let dialect = self.engine.map_or(
Arc::new(DefaultDialect {}) as Arc<dyn Dialect + Send + Sync>,
self.dialect.clone().unwrap_or(Arc::new(DefaultDialect {})),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for learning purpose: when would the .sql() method be used when it uses the dialect that's other than the engine.dialect()

|e| e.dialect(),
);
let unparser = Unparser::new(dialect.as_ref());
Expand All @@ -311,9 +325,10 @@ impl<T, P> SqlExec<T, P> {
format!("WHERE {}", filter_expr.join(" AND "))
};

let table_reference = quote(self.table_reference.table());

Ok(format!(
"SELECT {columns} FROM {table_reference} {where_expr} {limit_expr}",
table_reference = self.table_reference.to_quoted_string()
))
}
}
Expand Down Expand Up @@ -398,27 +413,49 @@ pub fn to_execution_error(

#[cfg(test)]
mod tests {
use std::{error::Error, sync::Arc};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::sql::unparser::dialect::CustomDialectBuilder;
use db_connection_pool::DummyDbConnectionPool;

use datafusion::execution::context::SessionContext;
use datafusion::sql::TableReference;
use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

use crate::sql::sql_provider_datafusion::SqlTable;

fn setup_tracing() -> DefaultGuard {
let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
.with_max_level(LevelFilter::DEBUG)
.finish();

let dispatch = Dispatch::new(subscriber);
tracing::dispatcher::set_default(&dispatch)
}
use super::*;

#[test]
fn test_references() {
let table_ref = TableReference::bare("test");
assert_eq!(format!("{table_ref}"), "test");
fn test_sql_exec_backtick_dialect() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]));

let table_reference = TableReference::bare("my_table");

let dialect = CustomDialectBuilder::new()
.with_identifier_quote_style('`')
.build();

let sql_exec = SqlExec::<(), ()>::new(
None, // No projections
&schema,
&table_reference,
Arc::new(DummyDbConnectionPool),
&[],
None,
None,
Some(Arc::new(dialect)),
)
.expect("Failed to create SqlExec");

// Generate SQL
let sql = sql_exec.sql().expect("Failed to generate SQL");

// Expected SQL with backtick-quoted identifiers
let expected_sql = "SELECT `id`, `name` FROM `my_table`";

// Assert the generated SQL matches the expected output
assert_eq!(
sql.trim(),
expected_sql,
"SQL output does not match expected backtick-quoted format"
);
}

#[cfg(feature = "duckdb")]
Expand All @@ -428,10 +465,28 @@ mod tests {
DuckDBSyncParameter, DuckDbConnection,
};
use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
use datafusion::prelude::SessionContext;
use datafusion::sql::TableReference;
use duckdb::DuckdbConnectionManager;
use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

fn setup_tracing() -> DefaultGuard {
let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
.with_max_level(LevelFilter::DEBUG)
.finish();

let dispatch = Dispatch::new(subscriber);
tracing::dispatcher::set_default(&dispatch)
}

#[test]
fn test_references() {
let table_ref = TableReference::bare("test");
assert_eq!(format!("{table_ref}"), "test");
}

#[tokio::test]
async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
async fn test_duckdb_table() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let t = setup_tracing();
let ctx = SessionContext::new();
let pool: Arc<
Expand Down Expand Up @@ -466,7 +521,8 @@ mod tests {
}

#[tokio::test]
async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
async fn test_duckdb_table_filter() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
{
let t = setup_tracing();
let ctx = SessionContext::new();
let pool: Arc<
Expand Down
1 change: 1 addition & 0 deletions src/sqlite/sql_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ impl<T, P> SQLiteSqlExec<T, P> {
filters,
limit,
Some(Engine::SQLite),
None,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base table on the SQLiteTable has a dialect that could be passed down.

)?;

Ok(Self { base_exec })
Expand Down
Loading